From b0ceb7446211ebba8ee0f6518a64ef2c4a5019a9 Mon Sep 17 00:00:00 2001 From: Steven Robertson Date: Mon, 23 Feb 2026 19:57:00 -0800 Subject: [PATCH] feat: implement AfterTool tail tool calls (#18486) --- docs/extensions/reference.md | 52 +++++--- docs/hooks/reference.md | 8 ++ .../hooks-system.tail-tool-call.responses | 2 + integration-tests/hooks-system.test.ts | 107 +++++++++++++++++ .../components/messages/ShellToolMessage.tsx | 4 + .../ui/components/messages/ToolMessage.tsx | 2 + .../src/ui/components/messages/ToolShared.tsx | 8 ++ packages/cli/src/ui/hooks/toolMapping.test.ts | 15 +++ packages/cli/src/ui/hooks/toolMapping.ts | 1 + .../cli/src/ui/hooks/useToolScheduler.test.ts | 61 +++++++++- packages/cli/src/ui/hooks/useToolScheduler.ts | 28 ++++- packages/cli/src/ui/types.ts | 1 + .../core/src/core/coreToolHookTriggers.ts | 9 ++ packages/core/src/hooks/hookEventHandler.ts | 8 ++ packages/core/src/hooks/hookSystem.ts | 4 + packages/core/src/hooks/types.ts | 37 ++++++ packages/core/src/scheduler/scheduler.test.ts | 113 ++++++++++++++++++ packages/core/src/scheduler/scheduler.ts | 73 ++++++++++- packages/core/src/scheduler/state-manager.ts | 13 ++ .../core/src/scheduler/tool-executor.test.ts | 12 +- packages/core/src/scheduler/tool-executor.ts | 12 +- packages/core/src/scheduler/types.ts | 14 +++ packages/core/src/tools/tools.ts | 9 ++ 23 files changed, 567 insertions(+), 26 deletions(-) create mode 100644 integration-tests/hooks-system.tail-tool-call.responses diff --git a/docs/extensions/reference.md b/docs/extensions/reference.md index b4a0df7336..d36df94d78 100644 --- a/docs/extensions/reference.md +++ b/docs/extensions/reference.md @@ -116,7 +116,9 @@ The manifest file defines the extension's behavior and configuration. "description": "My awesome extension", "mcpServers": { "my-server": { - "command": "node my-server.js" + "command": "node", + "args": ["${extensionPath}/my-server.js"], + "cwd": "${extensionPath}" } }, "contextFileName": "GEMINI.md", @@ -124,19 +126,41 @@ The manifest file defines the extension's behavior and configuration. } ``` -- `name`: A unique identifier for the extension. Use lowercase letters, numbers, - and dashes. This name must match the extension's directory name. -- `version`: The current version of the extension. -- `description`: A short summary shown in the extension gallery. -- `mcpServers`: A map of Model Context Protocol (MCP) - servers. Extension servers follow the same format as standard - [CLI configuration](../reference/configuration.md). -- `contextFileName`: The name of the context file (defaults to `GEMINI.md`). Can - also be an array of strings to load multiple context files. -- `excludeTools`: An array of tools to block from the model. You can restrict - specific arguments, such as `run_shell_command(rm -rf)`. -- `themes`: An optional list of themes provided by the extension. See - [Themes](../cli/themes.md) for more information. +- `name`: The name of the extension. This is used to uniquely identify the + extension and for conflict resolution when extension commands have the same + name as user or project commands. The name should be lowercase or numbers and + use dashes instead of underscores or spaces. This is how users will refer to + your extension in the CLI. Note that we expect this name to match the + extension directory name. +- `version`: The version of the extension. +- `description`: A short description of the extension. This will be displayed on + [geminicli.com/extensions](https://geminicli.com/extensions). +- `mcpServers`: A map of MCP servers to settings. The key is the name of the + server, and the value is the server configuration. These servers will be + loaded on startup just like MCP servers defined in a + [`settings.json` file](../reference/configuration.md). If both an extension + and a `settings.json` file define an MCP server with the same name, the server + defined in the `settings.json` file takes precedence. + - Note that all MCP server configuration options are supported except for + `trust`. + - For portability, you should use `${extensionPath}` to refer to files within + your extension directory. + - Separate your executable and its arguments using `command` and `args` + instead of putting them both in `command`. +- `contextFileName`: The name of the file that contains the context for the + extension. This will be used to load the context from the extension directory. + If this property is not used but a `GEMINI.md` file is present in your + extension directory, then that file will be loaded. +- `excludeTools`: An array of tool names to exclude from the model. You can also + specify command-specific restrictions for tools that support it, like the + `run_shell_command` tool. For example, + `"excludeTools": ["run_shell_command(rm -rf)"]` will block the `rm -rf` + command. Note that this differs from the MCP server `excludeTools` + functionality, which can be listed in the MCP server config. + +When Gemini CLI starts, it loads all the extensions and merges their +configurations. If there are any conflicts, the workspace configuration takes +precedence. ### Extension settings diff --git a/docs/hooks/reference.md b/docs/hooks/reference.md index 452edb378d..9b7226ac05 100644 --- a/docs/hooks/reference.md +++ b/docs/hooks/reference.md @@ -98,6 +98,8 @@ and parameter rewriting. - `tool_name`: (`string`) The name of the tool being called. - `tool_input`: (`object`) The raw arguments generated by the model. - `mcp_context`: (`object`) Optional metadata for MCP-based tools. + - `original_request_name`: (`string`) The original name of the tool being + called, if this is a tail tool call. - **Relevant Output Fields**: - `decision`: Set to `"deny"` (or `"block"`) to prevent the tool from executing. @@ -120,12 +122,18 @@ hiding sensitive output from the agent. - `tool_response`: (`object`) The result containing `llmContent`, `returnDisplay`, and optional `error`. - `mcp_context`: (`object`) + - `original_request_name`: (`string`) The original name of the tool being + called, if this is a tail tool call. - **Relevant Output Fields**: - `decision`: Set to `"deny"` to hide the real tool output from the agent. - `reason`: Required if denied. This text **replaces** the tool result sent back to the model. - `hookSpecificOutput.additionalContext`: Text that is **appended** to the tool result for the agent. + - `hookSpecificOutput.tailToolCallRequest`: (`{ name: string, args: object }`) + A request to execute another tool immediately after this one. The result of + this "tail call" will replace the original tool's response. Ideal for + programmatic tool routing. - `continue`: Set to `false` to **kill the entire agent loop** immediately. - **Exit Code 2 (Block Result)**: Hides the tool result. Uses `stderr` as the replacement content sent to the agent. **The turn continues.** diff --git a/integration-tests/hooks-system.tail-tool-call.responses b/integration-tests/hooks-system.tail-tool-call.responses new file mode 100644 index 0000000000..13dc3fde4d --- /dev/null +++ b/integration-tests/hooks-system.tail-tool-call.responses @@ -0,0 +1,2 @@ +{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"functionCall":{"name":"read_file","args":{"file_path":"original.txt"}}}],"role":"model"},"finishReason":"STOP","index":0}]}]} +{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"Tail call completed successfully."}],"role":"model"},"finishReason":"STOP","index":0}]}]} \ No newline at end of file diff --git a/integration-tests/hooks-system.test.ts b/integration-tests/hooks-system.test.ts index 2db1019c5f..479851957b 100644 --- a/integration-tests/hooks-system.test.ts +++ b/integration-tests/hooks-system.test.ts @@ -286,6 +286,113 @@ describe('Hooks System Integration', () => { }); }); + describe('Command Hooks - Tail Tool Calls', () => { + it('should execute a tail tool call from AfterTool hooks and replace original response', async () => { + // Create a script that acts as the hook. + // It will trigger on "read_file" and issue a tail call to "write_file". + rig.setup('should execute a tail tool call from AfterTool hooks', { + fakeResponsesPath: join( + import.meta.dirname, + 'hooks-system.tail-tool-call.responses', + ), + }); + + const hookOutput = { + decision: 'allow', + hookSpecificOutput: { + hookEventName: 'AfterTool', + tailToolCallRequest: { + name: 'write_file', + args: { + file_path: 'tail-called-file.txt', + content: 'Content from tail call', + }, + }, + }, + }; + + const hookScript = `console.log(JSON.stringify(${JSON.stringify( + hookOutput, + )})); process.exit(0);`; + + const scriptPath = join(rig.testDir!, 'tail_call_hook.js'); + writeFileSync(scriptPath, hookScript); + const commandPath = scriptPath.replace(/\\/g, '/'); + + rig.setup('should execute a tail tool call from AfterTool hooks', { + fakeResponsesPath: join( + import.meta.dirname, + 'hooks-system.tail-tool-call.responses', + ), + settings: { + hooksConfig: { + enabled: true, + }, + hooks: { + AfterTool: [ + { + matcher: 'read_file', + hooks: [ + { + type: 'command', + command: `node "${commandPath}"`, + timeout: 5000, + }, + ], + }, + ], + }, + }, + }); + + // Create a test file to trigger the read_file tool + rig.createFile('original.txt', 'Original content'); + + const cliOutput = await rig.run({ + args: 'Read original.txt', // Fake responses should trigger read_file on this + }); + + // 1. Verify that write_file was called (as a tail call replacing read_file) + // Since read_file was replaced before finalizing, it will not appear in the tool logs. + const foundWriteFile = await rig.waitForToolCall('write_file'); + expect(foundWriteFile).toBeTruthy(); + + // Ensure hook logs are flushed and the final LLM response is received. + // The mock LLM is configured to respond with "Tail call completed successfully." + expect(cliOutput).toContain('Tail call completed successfully.'); + + // Ensure telemetry is written to disk + await rig.waitForTelemetryReady(); + + // Read hook logs to debug + const hookLogs = rig.readHookLogs(); + const relevantHookLog = hookLogs.find( + (l) => l.hookCall.hook_event_name === 'AfterTool', + ); + + expect(relevantHookLog).toBeDefined(); + + // 2. Verify write_file was executed. + // In non-interactive mode, the CLI deduplicates tool execution logs by callId. + // Since a tail call reuses the original callId, "Tool: write_file" is not printed. + // Instead, we verify the side-effect (file creation) and the telemetry log. + + // 3. Verify the tail-called tool actually wrote the file + const modifiedContent = rig.readFile('tail-called-file.txt'); + expect(modifiedContent).toBe('Content from tail call'); + + // 4. Verify telemetry for the final tool call. + // The original 'read_file' call is replaced, so only 'write_file' is finalized and logged. + const toolLogs = rig.readToolLogs(); + const successfulTools = toolLogs.filter((t) => t.toolRequest.success); + expect( + successfulTools.some((t) => t.toolRequest.name === 'write_file'), + ).toBeTruthy(); + // The original request name should be preserved in the log payload if possible, + // but the executed tool name is 'write_file'. + }); + }); + describe('BeforeModel Hooks - LLM Request Modification', () => { it('should modify LLM requests with BeforeModel hooks', async () => { // Create a hook script that replaces the LLM request with a modified version diff --git a/packages/cli/src/ui/components/messages/ShellToolMessage.tsx b/packages/cli/src/ui/components/messages/ShellToolMessage.tsx index 54abbc09d3..8e760b28e7 100644 --- a/packages/cli/src/ui/components/messages/ShellToolMessage.tsx +++ b/packages/cli/src/ui/components/messages/ShellToolMessage.tsx @@ -58,7 +58,10 @@ export const ShellToolMessage: React.FC = ({ borderColor, borderDimColor, + isExpandable, + + originalRequestName, }) => { const { activePtyId: activeShellPtyId, @@ -129,6 +132,7 @@ export const ShellToolMessage: React.FC = ({ status={status} description={description} emphasis={emphasis} + originalRequestName={originalRequestName} /> = ({ config, progressMessage, progressPercent, + originalRequestName, }) => { const isThisShellFocused = checkIsShellFocused( name, @@ -93,6 +94,7 @@ export const ToolMessage: React.FC = ({ emphasis={emphasis} progressMessage={progressMessage} progressPercent={progressPercent} + originalRequestName={originalRequestName} /> = ({ @@ -198,6 +199,7 @@ export const ToolInfo: React.FC = ({ emphasis, progressMessage, progressPercent, + originalRequestName, }) => { const status = mapCoreStatusToDisplayStatus(coreStatus); const nameColor = React.useMemo(() => { @@ -242,6 +244,12 @@ export const ToolInfo: React.FC = ({ {name} + {originalRequestName && originalRequestName !== name && ( + + {' '} + (redirection from {originalRequestName}) + + )} {!isCompletedAskUser && ( <> {' '} diff --git a/packages/cli/src/ui/hooks/toolMapping.test.ts b/packages/cli/src/ui/hooks/toolMapping.test.ts index 241b5d94f0..c97f4a526d 100644 --- a/packages/cli/src/ui/hooks/toolMapping.test.ts +++ b/packages/cli/src/ui/hooks/toolMapping.test.ts @@ -275,5 +275,20 @@ describe('toolMapping', () => { expect(result.tools[0].resultDisplay).toBeUndefined(); expect(result.tools[0].status).toBe(CoreToolCallStatus.Scheduled); }); + + it('propagates originalRequestName correctly', () => { + const toolCall: ScheduledToolCall = { + status: CoreToolCallStatus.Scheduled, + request: { + ...mockRequest, + originalRequestName: 'original_tool', + }, + tool: mockTool, + invocation: mockInvocation, + }; + + const result = mapToDisplay(toolCall); + expect(result.tools[0].originalRequestName).toBe('original_tool'); + }); }); }); diff --git a/packages/cli/src/ui/hooks/toolMapping.ts b/packages/cli/src/ui/hooks/toolMapping.ts index ded17f29a9..6f484d5d25 100644 --- a/packages/cli/src/ui/hooks/toolMapping.ts +++ b/packages/cli/src/ui/hooks/toolMapping.ts @@ -107,6 +107,7 @@ export function mapToDisplay( progressMessage, progressPercent, approvalMode: call.approvalMode, + originalRequestName: call.request.originalRequestName, }; }); diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index ddf43944f6..ca9df3d5d3 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -13,6 +13,7 @@ import { Scheduler, type Config, type MessageBus, + type ExecutingToolCall, type CompletedToolCall, type ToolCallsUpdateMessage, type AnyDeclarativeTool, @@ -110,7 +111,7 @@ describe('useToolScheduler', () => { tool: createMockTool(), invocation: createMockInvocation(), liveOutput: 'Loading...', - }; + } as ExecutingToolCall; act(() => { void mockMessageBus.publish({ @@ -405,4 +406,62 @@ describe('useToolScheduler', () => { toolCalls.find((t) => t.request.callId === 'call-sub')?.schedulerId, ).toBe('subagent-1'); }); + + it('adapts success/error status to executing when a tail call is present', () => { + vi.useFakeTimers(); + const { result } = renderHook(() => + useToolScheduler( + vi.fn().mockResolvedValue(undefined), + mockConfig, + () => undefined, + ), + ); + + const startTime = Date.now(); + vi.advanceTimersByTime(1000); + + const mockToolCall = { + status: CoreToolCallStatus.Success as const, + request: { + callId: 'call-1', + name: 'test_tool', + args: {}, + isClientInitiated: false, + prompt_id: 'p1', + }, + tool: createMockTool(), + invocation: createMockInvocation(), + response: { + callId: 'call-1', + resultDisplay: 'OK', + responseParts: [], + error: undefined, + errorType: undefined, + }, + tailToolCallRequest: { + name: 'tail_tool', + args: {}, + isClientInitiated: false, + prompt_id: '123', + }, + }; + + act(() => { + void mockMessageBus.publish({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [mockToolCall], + schedulerId: ROOT_SCHEDULER_ID, + } as ToolCallsUpdateMessage); + }); + + const [toolCalls, , , , , lastOutputTime] = result.current; + + // Check if status has been adapted to 'executing' + expect(toolCalls[0].status).toBe(CoreToolCallStatus.Executing); + + // Check if lastOutputTime was updated due to the transitional state + expect(lastOutputTime).toBeGreaterThan(startTime); + + vi.useRealTimers(); + }); }); diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts index 56b1622468..f09ed9b81f 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.ts @@ -14,6 +14,7 @@ import { Scheduler, type EditorType, type ToolCallsUpdateMessage, + CoreToolCallStatus, } from '@google/gemini-cli-core'; import { useCallback, useState, useMemo, useEffect, useRef } from 'react'; @@ -115,7 +116,16 @@ export function useToolScheduler( useEffect(() => { const handler = (event: ToolCallsUpdateMessage) => { // Update output timer for UI spinners (Side Effect) - if (event.toolCalls.some((tc) => tc.status === 'executing')) { + const hasExecuting = event.toolCalls.some( + (tc) => + tc.status === CoreToolCallStatus.Executing || + ((tc.status === CoreToolCallStatus.Success || + tc.status === CoreToolCallStatus.Error) && + 'tailToolCallRequest' in tc && + tc.tailToolCallRequest != null), + ); + + if (hasExecuting) { setLastToolOutputTime(Date.now()); } @@ -238,9 +248,23 @@ function adaptToolCalls( const prev = prevMap.get(coreCall.request.callId); const responseSubmittedToGemini = prev?.responseSubmittedToGemini ?? false; + let status = coreCall.status; + // If a tool call has completed but scheduled a tail call, it is in a transitional + // state. Force the UI to render it as "executing". + if ( + (status === CoreToolCallStatus.Success || + status === CoreToolCallStatus.Error) && + 'tailToolCallRequest' in coreCall && + coreCall.tailToolCallRequest != null + ) { + status = CoreToolCallStatus.Executing; + } + + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion return { ...coreCall, + status, responseSubmittedToGemini, - }; + } as TrackedToolCall; }); } diff --git a/packages/cli/src/ui/types.ts b/packages/cli/src/ui/types.ts index 2d40f0a48c..68a029e267 100644 --- a/packages/cli/src/ui/types.ts +++ b/packages/cli/src/ui/types.ts @@ -110,6 +110,7 @@ export interface IndividualToolCallDisplay { approvalMode?: ApprovalMode; progressMessage?: string; progressPercent?: number; + originalRequestName?: string; } export interface CompressionProps { diff --git a/packages/core/src/core/coreToolHookTriggers.ts b/packages/core/src/core/coreToolHookTriggers.ts index cb98d3af20..9c83253903 100644 --- a/packages/core/src/core/coreToolHookTriggers.ts +++ b/packages/core/src/core/coreToolHookTriggers.ts @@ -75,6 +75,7 @@ export async function executeToolWithHooks( shellExecutionConfig?: ShellExecutionConfig, setPidCallback?: (pid: number) => void, config?: Config, + originalRequestName?: string, ): Promise { // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const toolInput = (invocation.params || {}) as Record; @@ -90,6 +91,7 @@ export async function executeToolWithHooks( toolName, toolInput, mcpContext, + originalRequestName, ); // Check if hook requested to stop entire agent execution @@ -196,6 +198,7 @@ export async function executeToolWithHooks( error: toolResult.error, }, mcpContext, + originalRequestName, ); // Check if hook requested to stop entire agent execution @@ -242,6 +245,12 @@ export async function executeToolWithHooks( toolResult.llmContent = wrappedContext; } } + + // Check if the hook requested a tail tool call + const tailToolCallRequest = afterOutput?.getTailToolCallRequest(); + if (tailToolCallRequest) { + toolResult.tailToolCallRequest = tailToolCallRequest; + } } return toolResult; diff --git a/packages/core/src/hooks/hookEventHandler.ts b/packages/core/src/hooks/hookEventHandler.ts index 3301ffb69d..0e744c3be7 100644 --- a/packages/core/src/hooks/hookEventHandler.ts +++ b/packages/core/src/hooks/hookEventHandler.ts @@ -76,12 +76,16 @@ export class HookEventHandler { toolName: string, toolInput: Record, mcpContext?: McpToolContext, + originalRequestName?: string, ): Promise { const input: BeforeToolInput = { ...this.createBaseInput(HookEventName.BeforeTool), tool_name: toolName, tool_input: toolInput, ...(mcpContext && { mcp_context: mcpContext }), + ...(originalRequestName && { + original_request_name: originalRequestName, + }), }; const context: HookEventContext = { toolName }; @@ -97,6 +101,7 @@ export class HookEventHandler { toolInput: Record, toolResponse: Record, mcpContext?: McpToolContext, + originalRequestName?: string, ): Promise { const input: AfterToolInput = { ...this.createBaseInput(HookEventName.AfterTool), @@ -104,6 +109,9 @@ export class HookEventHandler { tool_input: toolInput, tool_response: toolResponse, ...(mcpContext && { mcp_context: mcpContext }), + ...(originalRequestName && { + original_request_name: originalRequestName, + }), }; const context: HookEventContext = { toolName }; diff --git a/packages/core/src/hooks/hookSystem.ts b/packages/core/src/hooks/hookSystem.ts index 1d5f346210..56eb10b015 100644 --- a/packages/core/src/hooks/hookSystem.ts +++ b/packages/core/src/hooks/hookSystem.ts @@ -368,12 +368,14 @@ export class HookSystem { toolName: string, toolInput: Record, mcpContext?: McpToolContext, + originalRequestName?: string, ): Promise { try { const result = await this.hookEventHandler.fireBeforeToolEvent( toolName, toolInput, mcpContext, + originalRequestName, ); return result.finalOutput; } catch (error) { @@ -391,6 +393,7 @@ export class HookSystem { error: unknown; }, mcpContext?: McpToolContext, + originalRequestName?: string, ): Promise { try { const result = await this.hookEventHandler.fireAfterToolEvent( @@ -398,6 +401,7 @@ export class HookSystem { toolInput, toolResponse as Record, mcpContext, + originalRequestName, ); return result.finalOutput; } catch (error) { diff --git a/packages/core/src/hooks/types.ts b/packages/core/src/hooks/types.ts index b4a8ce27e8..ba579d81e6 100644 --- a/packages/core/src/hooks/types.ts +++ b/packages/core/src/hooks/types.ts @@ -253,6 +253,33 @@ export class DefaultHookOutput implements HookOutput { shouldClearContext(): boolean { return false; } + + /** + * Optional request to execute another tool immediately after this one. + * The result of this tail call will replace the original tool's response. + */ + getTailToolCallRequest(): + | { + name: string; + args: Record; + } + | undefined { + if ( + this.hookSpecificOutput && + 'tailToolCallRequest' in this.hookSpecificOutput + ) { + const request = this.hookSpecificOutput['tailToolCallRequest']; + if ( + typeof request === 'object' && + request !== null && + !Array.isArray(request) + ) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + return request as { name: string; args: Record }; + } + } + return undefined; + } } /** @@ -430,6 +457,7 @@ export interface BeforeToolInput extends HookInput { tool_name: string; tool_input: Record; mcp_context?: McpToolContext; // Only present for MCP tools + original_request_name?: string; } /** @@ -450,6 +478,7 @@ export interface AfterToolInput extends HookInput { tool_input: Record; tool_response: Record; mcp_context?: McpToolContext; // Only present for MCP tools + original_request_name?: string; } /** @@ -459,6 +488,14 @@ export interface AfterToolOutput extends HookOutput { hookSpecificOutput?: { hookEventName: 'AfterTool'; additionalContext?: string; + /** + * Optional request to execute another tool immediately after this one. + * The result of this tail call will replace the original tool's response. + */ + tailToolCallRequest?: { + name: string; + args: Record; + }; }; } diff --git a/packages/core/src/scheduler/scheduler.test.ts b/packages/core/src/scheduler/scheduler.test.ts index 61699d07a6..97ab4bfcd4 100644 --- a/packages/core/src/scheduler/scheduler.test.ts +++ b/packages/core/src/scheduler/scheduler.test.ts @@ -201,6 +201,12 @@ describe('Scheduler (Orchestrator)', () => { mockQueue.length = 0; }), clearBatch: vi.fn(), + replaceActiveCallWithTailCall: vi.fn((id: string, nextCall: ToolCall) => { + if (mockActiveCallsMap.has(id)) { + mockActiveCallsMap.delete(id); + mockQueue.unshift(nextCall); + } + }), } as unknown as Mocked; // Define getters for accessors idiomatically @@ -1006,6 +1012,113 @@ describe('Scheduler (Orchestrator)', () => { const result = await (scheduler as any)._processNextItem(signal); expect(result).toBe(false); }); + + describe('Tail Calls', () => { + it('should replace the active call with a new tool call and re-run the loop when tail call is requested', async () => { + // Setup: Tool A will return a success with a tail call request to Tool B + const mockResponse = { + callId: 'call-1', + responseParts: [], + } as unknown as ToolCallResponseInfo; + + mockExecutor.execute + .mockResolvedValueOnce({ + status: 'success', + response: mockResponse, + tailToolCallRequest: { + name: 'tool-b', + args: { key: 'value' }, + }, + request: req1, + } as unknown as SuccessfulToolCall) + .mockResolvedValueOnce({ + status: 'success', + response: mockResponse, + request: { + ...req1, + name: 'tool-b', + args: { key: 'value' }, + originalRequestName: 'test-tool', + }, + } as unknown as SuccessfulToolCall); + + const mockToolB = { + name: 'tool-b', + build: vi.fn().mockReturnValue({}), + } as unknown as AnyDeclarativeTool; + + vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockToolB); + + await scheduler.schedule(req1, signal); + + // Assert: The state manager is instructed to replace the call + expect( + mockStateManager.replaceActiveCallWithTailCall, + ).toHaveBeenCalledWith( + 'call-1', + expect.objectContaining({ + request: expect.objectContaining({ + callId: 'call-1', + name: 'tool-b', + args: { key: 'value' }, + originalRequestName: 'test-tool', // Preserves original name + }), + tool: mockToolB, + }), + ); + + // Assert: The executor should be called twice (once for Tool A, once for Tool B) + expect(mockExecutor.execute).toHaveBeenCalledTimes(2); + }); + + it('should inject an errored tool call if the tail tool is not found', async () => { + const mockResponse = { + callId: 'call-1', + responseParts: [], + } as unknown as ToolCallResponseInfo; + + mockExecutor.execute.mockResolvedValue({ + status: 'success', + response: mockResponse, + tailToolCallRequest: { + name: 'missing-tool', + args: {}, + }, + request: req1, + } as unknown as SuccessfulToolCall); + + // Tool registry returns undefined for missing-tool, but valid tool for test-tool + vi.mocked(mockToolRegistry.getTool).mockImplementation((name) => { + if (name === 'test-tool') { + return { + name: 'test-tool', + build: vi.fn().mockReturnValue({}), + } as unknown as AnyDeclarativeTool; + } + return undefined; + }); + + await scheduler.schedule(req1, signal); + + // Assert: Replaces active call with an errored call + expect( + mockStateManager.replaceActiveCallWithTailCall, + ).toHaveBeenCalledWith( + 'call-1', + expect.objectContaining({ + status: 'error', + request: expect.objectContaining({ + callId: 'call-1', + name: 'missing-tool', // Name of the failed tail call + originalRequestName: 'test-tool', + }), + response: expect.objectContaining({ + errorType: ToolErrorType.TOOL_NOT_REGISTERED, + }), + }), + ); + }); + }); }); describe('Tool Call Context Propagation', () => { diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index 3ee55975f1..0733370645 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -19,6 +19,7 @@ import { type ExecutingToolCall, type ValidatingToolCall, type ErroredToolCall, + type SuccessfulToolCall, CoreToolCallStatus, type ScheduledToolCall, } from './types.js'; @@ -446,13 +447,16 @@ export class Scheduler { c.status === CoreToolCallStatus.Scheduled || this.isTerminal(c.status), ); + let madeProgress = false; if (allReady && scheduledCalls.length > 0) { - await Promise.all(scheduledCalls.map((c) => this._execute(c, signal))); + const execResults = await Promise.all( + scheduledCalls.map((c) => this._execute(c, signal)), + ); + madeProgress = execResults.some((res) => res); } // 3. Finalize terminal calls activeCalls = this.state.allActiveCalls; - let madeProgress = false; for (const call of activeCalls) { if (this.isTerminal(call.status)) { this.state.finalizeCall(call.request.callId); @@ -595,12 +599,12 @@ export class Scheduler { // --- Sub-phase Handlers --- /** - * Executes the tool and records the result. + * Executes the tool and records the result. Returns true if a new tool call was added. */ private async _execute( toolCall: ScheduledToolCall, signal: AbortSignal, - ): Promise { + ): Promise { const callId = toolCall.request.callId; if (signal.aborted) { this.state.updateStatus( @@ -608,7 +612,7 @@ export class Scheduler { CoreToolCallStatus.Cancelled, 'Operation cancelled', ); - return; + return false; } this.state.updateStatus(callId, CoreToolCallStatus.Executing); @@ -642,6 +646,64 @@ export class Scheduler { }), ); + if ( + (result.status === CoreToolCallStatus.Success || + result.status === CoreToolCallStatus.Error) && + result.tailToolCallRequest + ) { + // Log the intermediate tool call before it gets replaced. + const intermediateCall: SuccessfulToolCall | ErroredToolCall = { + request: activeCall.request, + tool: activeCall.tool, + invocation: activeCall.invocation, + status: result.status, + response: result.response, + durationMs: activeCall.startTime + ? Date.now() - activeCall.startTime + : undefined, + outcome: activeCall.outcome, + schedulerId: this.schedulerId, + }; + logToolCall(this.config, new ToolCallEvent(intermediateCall)); + + const tailRequest = result.tailToolCallRequest; + const originalCallId = result.request.callId; + const originalRequestName = + result.request.originalRequestName || result.request.name; + + const newTool = this.config.getToolRegistry().getTool(tailRequest.name); + + const newRequest: ToolCallRequestInfo = { + callId: originalCallId, + name: tailRequest.name, + args: tailRequest.args, + originalRequestName, + isClientInitiated: result.request.isClientInitiated, + prompt_id: result.request.prompt_id, + schedulerId: this.schedulerId, + }; + + if (!newTool) { + // Enqueue an errored tool call + const errorCall = this._createToolNotFoundErroredToolCall( + newRequest, + this.config.getToolRegistry().getAllToolNames(), + ); + this.state.replaceActiveCallWithTailCall(callId, errorCall); + } else { + // Enqueue a validating tool call for the new tail tool + const validatingCall = this._validateAndCreateToolCall( + newRequest, + newTool, + activeCall.approvalMode ?? this.config.getApprovalMode(), + ); + this.state.replaceActiveCallWithTailCall(callId, validatingCall); + } + + // Loop continues, picking up the new tail call at the front of the queue. + return true; + } + if (result.status === CoreToolCallStatus.Success) { this.state.updateStatus( callId, @@ -661,6 +723,7 @@ export class Scheduler { result.response, ); } + return false; } private _processNextInRequestQueue() { diff --git a/packages/core/src/scheduler/state-manager.ts b/packages/core/src/scheduler/state-manager.ts index fb16125340..fe727f6dd3 100644 --- a/packages/core/src/scheduler/state-manager.ts +++ b/packages/core/src/scheduler/state-manager.ts @@ -187,6 +187,19 @@ export class SchedulerStateManager { this.emitUpdate(); } + /** + * Replaces the currently active call with a new call, placing the new call + * at the front of the queue to be processed immediately in the next tick. + * Used for Tail Calls to chain execution without finalizing the original call. + */ + replaceActiveCallWithTailCall(callId: string, nextCall: ToolCall): void { + if (this.activeCalls.has(callId)) { + this.activeCalls.delete(callId); + this.queue.unshift(nextCall); + this.emitUpdate(); + } + } + cancelAllQueued(reason: string): void { if (this.queue.length === 0) { return; diff --git a/packages/core/src/scheduler/tool-executor.test.ts b/packages/core/src/scheduler/tool-executor.test.ts index 1cbee019c6..29db841aac 100644 --- a/packages/core/src/scheduler/tool-executor.test.ts +++ b/packages/core/src/scheduler/tool-executor.test.ts @@ -252,7 +252,17 @@ describe('ToolExecutor', () => { // 2. Mock executeToolWithHooks to trigger the PID callback const testPid = 12345; vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation( - async (_inv, _name, _sig, _tool, _liveCb, _shellCfg, setPidCallback) => { + async ( + _inv, + _name, + _sig, + _tool, + _liveCb, + _shellCfg, + setPidCallback, + _config, + _originalRequestName, + ) => { // Simulate the shell tool reporting a PID if (setPidCallback) { setPidCallback(testPid); diff --git a/packages/core/src/scheduler/tool-executor.ts b/packages/core/src/scheduler/tool-executor.ts index b94b0e5184..9ae00b24a7 100644 --- a/packages/core/src/scheduler/tool-executor.ts +++ b/packages/core/src/scheduler/tool-executor.ts @@ -99,6 +99,7 @@ export class ToolExecutor { shellExecutionConfig, setPidCallback, this.config, + request.originalRequestName, ); } else { promise = executeToolWithHooks( @@ -110,6 +111,7 @@ export class ToolExecutor { shellExecutionConfig, undefined, this.config, + request.originalRequestName, ); } @@ -133,6 +135,7 @@ export class ToolExecutor { new Error(toolResult.error.message), toolResult.error.type, displayText, + toolResult.tailToolCallRequest, ); } } catch (executionError: unknown) { @@ -204,7 +207,7 @@ export class ToolExecutor { ): Promise { let content = toolResult.llmContent; let outputFile: string | undefined; - const toolName = call.request.name; + const toolName = call.request.originalRequestName || call.request.name; const callId = call.request.callId; if (typeof content === 'string' && toolName === SHELL_TOOL_NAME) { @@ -268,6 +271,7 @@ export class ToolExecutor { startTime, endTime: Date.now(), outcome: call.outcome, + tailToolCallRequest: toolResult.tailToolCallRequest, }; } @@ -276,6 +280,7 @@ export class ToolExecutor { error: Error, errorType?: ToolErrorType, returnDisplay?: string, + tailToolCallRequest?: { name: string; args: Record }, ): ErroredToolCall { const response = this.createErrorResponse( call.request, @@ -289,11 +294,12 @@ export class ToolExecutor { status: CoreToolCallStatus.Error, request: call.request, response, - tool: call.tool, + tool: 'tool' in call ? call.tool : undefined, durationMs: startTime ? Date.now() - startTime : undefined, startTime, endTime: Date.now(), outcome: call.outcome, + tailToolCallRequest, }; } @@ -311,7 +317,7 @@ export class ToolExecutor { { functionResponse: { id: request.callId, - name: request.name, + name: request.originalRequestName || request.name, response: { error: error.message }, }, }, diff --git a/packages/core/src/scheduler/types.ts b/packages/core/src/scheduler/types.ts index 5fe6028bac..6486c04997 100644 --- a/packages/core/src/scheduler/types.ts +++ b/packages/core/src/scheduler/types.ts @@ -36,6 +36,11 @@ export interface ToolCallRequestInfo { callId: string; name: string; args: Record; + /** + * The original name of the tool requested by the model. + * This is used for tail calls to ensure the final response retains the original name. + */ + originalRequestName?: string; isClientInitiated: boolean; prompt_id: string; checkpoint?: string; @@ -58,6 +63,12 @@ export interface ToolCallResponseInfo { data?: Record; } +/** Request to execute another tool immediately after a completed one. */ +export interface TailToolCallRequest { + name: string; + args: Record; +} + export type ValidatingToolCall = { status: CoreToolCallStatus.Validating; request: ToolCallRequestInfo; @@ -91,6 +102,7 @@ export type ErroredToolCall = { outcome?: ToolConfirmationOutcome; schedulerId?: string; approvalMode?: ApprovalMode; + tailToolCallRequest?: TailToolCallRequest; }; export type SuccessfulToolCall = { @@ -105,6 +117,7 @@ export type SuccessfulToolCall = { outcome?: ToolConfirmationOutcome; schedulerId?: string; approvalMode?: ApprovalMode; + tailToolCallRequest?: TailToolCallRequest; }; export type ExecutingToolCall = { @@ -120,6 +133,7 @@ export type ExecutingToolCall = { pid?: number; schedulerId?: string; approvalMode?: ApprovalMode; + tailToolCallRequest?: TailToolCallRequest; }; export type CancelledToolCall = { diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 94188deca0..e06dff160e 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -579,6 +579,15 @@ export interface ToolResult { * Optional data payload for passing structured information back to the caller. */ data?: Record; + + /** + * Optional request to execute another tool immediately after this one. + * The result of this tail call will replace the original tool's response. + */ + tailToolCallRequest?: { + name: string; + args: Record; + }; } /**