From 6b518dc9e4c601c0108768932dc1450c036075fd Mon Sep 17 00:00:00 2001 From: Taylor Mullen Date: Fri, 9 May 2025 23:29:02 -0700 Subject: [PATCH] Enable tools to cancel active execution. - Plumbed abort signals through to tools - Updated the shell tool to properly cancel active requests by killing the entire child process tree of the underlying shell process and then report that the shell itself was canceled. Fixes https://b.corp.google.com/issues/416829935 --- .../cli/src/ui/hooks/atCommandProcessor.ts | 4 +- packages/cli/src/ui/hooks/useGeminiStream.ts | 115 +++++++++++++----- packages/server/src/core/client.ts | 11 +- packages/server/src/core/turn.ts | 13 +- packages/server/src/tools/edit.ts | 5 +- packages/server/src/tools/glob.ts | 5 +- packages/server/src/tools/grep.ts | 5 +- packages/server/src/tools/ls.ts | 5 +- packages/server/src/tools/read-file.ts | 5 +- packages/server/src/tools/read-many-files.ts | 5 +- packages/server/src/tools/shell.ts | 70 ++++++++--- packages/server/src/tools/terminal.ts | 5 +- packages/server/src/tools/tools.ts | 5 +- packages/server/src/tools/web-fetch.ts | 5 +- packages/server/src/tools/write-file.ts | 5 +- 15 files changed, 191 insertions(+), 72 deletions(-) diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts index 5ffa5383d9..a13a7d3698 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts @@ -26,6 +26,7 @@ interface HandleAtCommandParams { addItem: UseHistoryManagerReturn['addItem']; setDebugMessage: React.Dispatch>; messageId: number; + signal: AbortSignal; } interface HandleAtCommandResult { @@ -90,6 +91,7 @@ export async function handleAtCommand({ addItem, setDebugMessage, messageId: userMessageTimestamp, + signal, }: HandleAtCommandParams): Promise { const trimmedQuery = query.trim(); const parsedCommand = parseAtCommand(trimmedQuery); @@ -163,7 +165,7 @@ export async function handleAtCommand({ let toolCallDisplay: IndividualToolCallDisplay; try { - const result = await readManyFilesTool.execute(toolArgs); + const result = await readManyFilesTool.execute(toolArgs, signal); const fileContent = result.llmContent || ''; toolCallDisplay = { diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 3f8cee405d..e86ae0b937 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -89,7 +89,7 @@ export const useGeminiStream = ( }, [config, addItem]); useInput((_input, key) => { - if (streamingState === StreamingState.Responding && key.escape) { + if (streamingState !== StreamingState.Idle && key.escape) { abortControllerRef.current?.abort(); } }); @@ -104,6 +104,9 @@ export const useGeminiStream = ( setShowHelp(false); + abortControllerRef.current ??= new AbortController(); + const signal = abortControllerRef.current.signal; + if (typeof query === 'string') { const trimmedQuery = query.trim(); setDebugMessage(`User query: '${trimmedQuery}'`); @@ -120,6 +123,7 @@ export const useGeminiStream = ( addItem, setDebugMessage, messageId: userMessageTimestamp, + signal, }); if (!atCommandResult.shouldProceed) return; queryToSendToGemini = atCommandResult.processedQuery; @@ -165,9 +169,6 @@ export const useGeminiStream = ( const chat = chatSessionRef.current; try { - abortControllerRef.current = new AbortController(); - const signal = abortControllerRef.current.signal; - const stream = client.sendMessageStream( chat, queryToSendToGemini, @@ -294,7 +295,26 @@ export const useGeminiStream = ( } else if (event.type === ServerGeminiEventType.UserCancelled) { // Flush out existing pending history item. if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); + // If the pending item is a tool_group, update statuses to Canceled + if (pendingHistoryItemRef.current.type === 'tool_group') { + const updatedTools = pendingHistoryItemRef.current.tools.map( + (tool) => { + if ( + tool.status === ToolCallStatus.Pending || + tool.status === ToolCallStatus.Confirming || + tool.status === ToolCallStatus.Executing + ) { + return { ...tool, status: ToolCallStatus.Canceled }; + } + return tool; + }, + ); + const pendingHistoryItem = pendingHistoryItemRef.current; + pendingHistoryItem.tools = updatedTools; + addItem(pendingHistoryItem, userMessageTimestamp); + } else { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } setPendingHistoryItem(null); } addItem( @@ -412,6 +432,59 @@ export const useGeminiStream = ( } if (outcome === ToolConfirmationOutcome.Cancel) { + declineToolExecution( + 'User rejected function call.', + ToolCallStatus.Error, + ); + } else { + const tool = toolRegistry.getTool(request.name); + if (!tool) { + throw new Error( + `Tool "${request.name}" not found or is not registered.`, + ); + } + + try { + abortControllerRef.current = new AbortController(); + const result = await tool.execute( + request.args, + abortControllerRef.current.signal, + ); + + if (abortControllerRef.current.signal.aborted) { + declineToolExecution( + result.llmContent, + ToolCallStatus.Canceled, + ); + return; + } + + const functionResponse: Part = { + functionResponse: { + name: request.name, + id: request.callId, + response: { output: result.llmContent }, + }, + }; + + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay: result.returnDisplay, + error: undefined, + }; + updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + setStreamingState(StreamingState.Idle); + await submitQuery(functionResponse); + } finally { + abortControllerRef.current = null; + } + } + + function declineToolExecution( + declineMessage: string, + status: ToolCallStatus, + ) { let resultDisplay: ToolResultDisplay | undefined; if ('fileDiff' in originalConfirmationDetails) { resultDisplay = { @@ -426,43 +499,19 @@ export const useGeminiStream = ( functionResponse: { id: request.callId, name: request.name, - response: { error: 'User rejected function call.' }, + response: { error: declineMessage }, }, }; const responseInfo: ToolCallResponseInfo = { callId: request.callId, responsePart: functionResponse, resultDisplay, - error: new Error('User rejected function call.'), - }; - // Update UI to show cancellation/error - updateFunctionResponseUI(responseInfo, ToolCallStatus.Error); - setStreamingState(StreamingState.Idle); - } else { - const tool = toolRegistry.getTool(request.name); - if (!tool) { - throw new Error( - `Tool "${request.name}" not found or is not registered.`, - ); - } - const result = await tool.execute(request.args); - const functionResponse: Part = { - functionResponse: { - name: request.name, - id: request.callId, - response: { output: result.llmContent }, - }, + error: new Error(declineMessage), }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + // Update UI to show cancellation/error + updateFunctionResponseUI(responseInfo, status); setStreamingState(StreamingState.Idle); - await submitQuery(functionResponse); } }; diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts index 904e944cf3..46af465a7f 100644 --- a/packages/server/src/core/client.ts +++ b/packages/server/src/core/client.ts @@ -64,10 +64,13 @@ export class GeminiClient { .getTool('read_many_files') as ReadManyFilesTool; if (readManyFilesTool) { // Read all files in the target directory - const result = await readManyFilesTool.execute({ - paths: ['**/*'], // Read everything recursively - useDefaultExcludes: true, // Use default excludes - }); + const result = await readManyFilesTool.execute( + { + paths: ['**/*'], // Read everything recursively + useDefaultExcludes: true, // Use default excludes + }, + AbortSignal.timeout(30000), + ); if (result.llmContent) { initialParts.push({ text: `\n--- Full File Context ---\n${result.llmContent}`, diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index 7d8bf7b65d..622199382f 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -36,7 +36,10 @@ export interface ServerTool { name: string; schema: FunctionDeclaration; // The execute method signature might differ slightly or be wrapped - execute(params: Record): Promise; + execute( + params: Record, + signal?: AbortSignal, + ): Promise; shouldConfirmExecute( params: Record, ): Promise; @@ -153,7 +156,7 @@ export class Turn { if (confirmationDetails) { return { ...pendingToolCall, confirmationDetails }; } - const result = await tool.execute(pendingToolCall.args); + const result = await tool.execute(pendingToolCall.args, signal); return { ...pendingToolCall, result, @@ -199,7 +202,11 @@ export class Turn { resultDisplay: outcome.result?.returnDisplay, error: outcome.error, }; - yield { type: GeminiEventType.ToolCallResponse, value: responseInfo }; + + // If aborted we're already yielding the user cancellations elsewhere. + if (!signal?.aborted) { + yield { type: GeminiEventType.ToolCallResponse, value: responseInfo }; + } } } diff --git a/packages/server/src/tools/edit.ts b/packages/server/src/tools/edit.ts index c40b9e440b..fd57d97d5e 100644 --- a/packages/server/src/tools/edit.ts +++ b/packages/server/src/tools/edit.ts @@ -333,7 +333,10 @@ Expectation for parameters: * @param params Parameters for the edit operation * @returns Result of the edit operation */ - async execute(params: EditToolParams): Promise { + async execute( + params: EditToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/glob.ts b/packages/server/src/tools/glob.ts index 9e7df0e883..b1b9d0cfc5 100644 --- a/packages/server/src/tools/glob.ts +++ b/packages/server/src/tools/glob.ts @@ -138,7 +138,10 @@ export class GlobTool extends BaseTool { /** * Executes the glob search with the given parameters */ - async execute(params: GlobToolParams): Promise { + async execute( + params: GlobToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/grep.ts b/packages/server/src/tools/grep.ts index e3253ecf0d..543918329c 100644 --- a/packages/server/src/tools/grep.ts +++ b/packages/server/src/tools/grep.ts @@ -166,7 +166,10 @@ export class GrepTool extends BaseTool { * @param params Parameters for the grep search * @returns Result of the grep search */ - async execute(params: GrepToolParams): Promise { + async execute( + params: GrepToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { console.error( diff --git a/packages/server/src/tools/ls.ts b/packages/server/src/tools/ls.ts index 01da512164..fea9518774 100644 --- a/packages/server/src/tools/ls.ts +++ b/packages/server/src/tools/ls.ts @@ -184,7 +184,10 @@ export class LSTool extends BaseTool { * @param params Parameters for the LS operation * @returns Result of the LS operation */ - async execute(params: LSToolParams): Promise { + async execute( + params: LSToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return this.errorResult( diff --git a/packages/server/src/tools/read-file.ts b/packages/server/src/tools/read-file.ts index 598b4691b9..de09161d6b 100644 --- a/packages/server/src/tools/read-file.ts +++ b/packages/server/src/tools/read-file.ts @@ -193,7 +193,10 @@ export class ReadFileTool extends BaseTool { * @param params Parameters for the file reading * @returns Result with file contents */ - async execute(params: ReadFileToolParams): Promise { + async execute( + params: ReadFileToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/read-many-files.ts b/packages/server/src/tools/read-many-files.ts index 0b4b090d26..44882e4417 100644 --- a/packages/server/src/tools/read-many-files.ts +++ b/packages/server/src/tools/read-many-files.ts @@ -237,7 +237,10 @@ Default excludes apply to common non-text files and large dependency directories return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace('{filePath}', 'path/to/file.ext')}".`; } - async execute(params: ReadManyFilesParams): Promise { + async execute( + params: ReadManyFilesParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/shell.ts b/packages/server/src/tools/shell.ts index fd8a6b1a86..7851b76afe 100644 --- a/packages/server/src/tools/shell.ts +++ b/packages/server/src/tools/shell.ts @@ -118,7 +118,10 @@ export class ShellTool extends BaseTool { return confirmationDetails; } - async execute(params: ShellToolParams): Promise { + async execute( + params: ShellToolParams, + abortSignal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { @@ -174,18 +177,38 @@ export class ShellTool extends BaseTool { }); let code: number | null = null; - let signal: NodeJS.Signals | null = null; - shell.on( - 'close', - (_code: number | null, _signal: NodeJS.Signals | null) => { - code = _code; - signal = _signal; - }, - ); + let processSignal: NodeJS.Signals | null = null; + const closeHandler = ( + _code: number | null, + _signal: NodeJS.Signals | null, + ) => { + code = _code; + processSignal = _signal; + }; + shell.on('close', closeHandler); + + const abortHandler = () => { + if (shell.pid) { + try { + // Kill the entire process group + process.kill(-shell.pid, 'SIGTERM'); + } catch (_e) { + // Fallback to killing the main process if group kill fails + try { + shell.kill('SIGKILL'); // or 'SIGTERM' + } catch (_killError) { + // Ignore errors if the process is already dead + } + } + } + }; + abortSignal.addEventListener('abort', abortHandler); // wait for the shell to exit await new Promise((resolve) => shell.on('close', resolve)); + abortSignal.removeEventListener('abort', abortHandler); + // parse pids (pgrep output) from temporary file and remove it const backgroundPIDs: number[] = []; if (fs.existsSync(tempFilePath)) { @@ -205,19 +228,26 @@ export class ShellTool extends BaseTool { } fs.unlinkSync(tempFilePath); } else { - console.error('missing pgrep output'); + if (!abortSignal.aborted) { + console.error('missing pgrep output'); + } } - const llmContent = [ - `Command: ${params.command}`, - `Directory: ${params.directory || '(root)'}`, - `Stdout: ${stdout || '(empty)'}`, - `Stderr: ${stderr || '(empty)'}`, - `Error: ${error ?? '(none)'}`, - `Exit Code: ${code ?? '(none)'}`, - `Signal: ${signal ?? '(none)'}`, - `Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`, - ].join('\n'); + let llmContent = ''; + if (abortSignal.aborted) { + llmContent = 'Command did not complete, it was cancelled by the user'; + } else { + llmContent = [ + `Command: ${params.command}`, + `Directory: ${params.directory || '(root)'}`, + `Stdout: ${stdout || '(empty)'}`, + `Stderr: ${stderr || '(empty)'}`, + `Error: ${error ?? '(none)'}`, + `Exit Code: ${code ?? '(none)'}`, + `Signal: ${processSignal ?? '(none)'}`, + `Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`, + ].join('\n'); + } const returnDisplay = this.config.getDebugMode() ? llmContent : output; diff --git a/packages/server/src/tools/terminal.ts b/packages/server/src/tools/terminal.ts index 7320cfb264..af558fb0a5 100644 --- a/packages/server/src/tools/terminal.ts +++ b/packages/server/src/tools/terminal.ts @@ -265,7 +265,10 @@ Use this tool for running build steps (\`npm install\`, \`make\`), linters (\`es return confirmationDetails; } - async execute(params: TerminalToolParams): Promise { + async execute( + params: TerminalToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/tools.ts b/packages/server/src/tools/tools.ts index ac04450d94..7bb05a9561 100644 --- a/packages/server/src/tools/tools.ts +++ b/packages/server/src/tools/tools.ts @@ -64,7 +64,7 @@ export interface Tool< * @param params Parameters for the tool execution * @returns Result of the tool execution */ - execute(params: TParams): Promise; + execute(params: TParams, signal: AbortSignal): Promise; } /** @@ -141,9 +141,10 @@ export abstract class BaseTool< * Abstract method to execute the tool with the given parameters * Must be implemented by derived classes * @param params Parameters for the tool execution + * @param signal AbortSignal for tool cancellation * @returns Result of the tool execution */ - abstract execute(params: TParams): Promise; + abstract execute(params: TParams, signal: AbortSignal): Promise; } export interface ToolResult { diff --git a/packages/server/src/tools/web-fetch.ts b/packages/server/src/tools/web-fetch.ts index 12584231eb..62ca21625f 100644 --- a/packages/server/src/tools/web-fetch.ts +++ b/packages/server/src/tools/web-fetch.ts @@ -70,7 +70,10 @@ export class WebFetchTool extends BaseTool { return `Fetching content from ${displayUrl}`; } - async execute(params: WebFetchToolParams): Promise { + async execute( + params: WebFetchToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/write-file.ts b/packages/server/src/tools/write-file.ts index c9a47296a2..1f4c0d94b5 100644 --- a/packages/server/src/tools/write-file.ts +++ b/packages/server/src/tools/write-file.ts @@ -150,7 +150,10 @@ export class WriteFileTool extends BaseTool { return confirmationDetails; } - async execute(params: WriteFileToolParams): Promise { + async execute( + params: WriteFileToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return {