feat: implement AfterTool tail tool calls (#18486)

This commit is contained in:
Steven Robertson
2026-02-23 19:57:00 -08:00
committed by GitHub
parent ee5eb70070
commit b0ceb74462
23 changed files with 567 additions and 26 deletions
+38 -14
View File
@@ -116,7 +116,9 @@ The manifest file defines the extension's behavior and configuration.
"description": "My awesome extension",
"mcpServers": {
"my-server": {
"command": "node my-server.js"
"command": "node",
"args": ["${extensionPath}/my-server.js"],
"cwd": "${extensionPath}"
}
},
"contextFileName": "GEMINI.md",
@@ -124,19 +126,41 @@ The manifest file defines the extension's behavior and configuration.
}
```
- `name`: A unique identifier for the extension. Use lowercase letters, numbers,
and dashes. This name must match the extension's directory name.
- `version`: The current version of the extension.
- `description`: A short summary shown in the extension gallery.
- <a id="mcp-servers"></a>`mcpServers`: A map of Model Context Protocol (MCP)
servers. Extension servers follow the same format as standard
[CLI configuration](../reference/configuration.md).
- `contextFileName`: The name of the context file (defaults to `GEMINI.md`). Can
also be an array of strings to load multiple context files.
- `excludeTools`: An array of tools to block from the model. You can restrict
specific arguments, such as `run_shell_command(rm -rf)`.
- `themes`: An optional list of themes provided by the extension. See
[Themes](../cli/themes.md) for more information.
- `name`: The name of the extension. This is used to uniquely identify the
extension and for conflict resolution when extension commands have the same
name as user or project commands. The name should be lowercase or numbers and
use dashes instead of underscores or spaces. This is how users will refer to
your extension in the CLI. Note that we expect this name to match the
extension directory name.
- `version`: The version of the extension.
- `description`: A short description of the extension. This will be displayed on
[geminicli.com/extensions](https://geminicli.com/extensions).
- `mcpServers`: A map of MCP servers to settings. The key is the name of the
server, and the value is the server configuration. These servers will be
loaded on startup just like MCP servers defined in a
[`settings.json` file](../reference/configuration.md). If both an extension
and a `settings.json` file define an MCP server with the same name, the server
defined in the `settings.json` file takes precedence.
- Note that all MCP server configuration options are supported except for
`trust`.
- For portability, you should use `${extensionPath}` to refer to files within
your extension directory.
- Separate your executable and its arguments using `command` and `args`
instead of putting them both in `command`.
- `contextFileName`: The name of the file that contains the context for the
extension. This will be used to load the context from the extension directory.
If this property is not used but a `GEMINI.md` file is present in your
extension directory, then that file will be loaded.
- `excludeTools`: An array of tool names to exclude from the model. You can also
specify command-specific restrictions for tools that support it, like the
`run_shell_command` tool. For example,
`"excludeTools": ["run_shell_command(rm -rf)"]` will block the `rm -rf`
command. Note that this differs from the MCP server `excludeTools`
functionality, which can be listed in the MCP server config.
When Gemini CLI starts, it loads all the extensions and merges their
configurations. If there are any conflicts, the workspace configuration takes
precedence.
### Extension settings
+8
View File
@@ -98,6 +98,8 @@ and parameter rewriting.
- `tool_name`: (`string`) The name of the tool being called.
- `tool_input`: (`object`) The raw arguments generated by the model.
- `mcp_context`: (`object`) Optional metadata for MCP-based tools.
- `original_request_name`: (`string`) The original name of the tool being
called, if this is a tail tool call.
- **Relevant Output Fields**:
- `decision`: Set to `"deny"` (or `"block"`) to prevent the tool from
executing.
@@ -120,12 +122,18 @@ hiding sensitive output from the agent.
- `tool_response`: (`object`) The result containing `llmContent`,
`returnDisplay`, and optional `error`.
- `mcp_context`: (`object`)
- `original_request_name`: (`string`) The original name of the tool being
called, if this is a tail tool call.
- **Relevant Output Fields**:
- `decision`: Set to `"deny"` to hide the real tool output from the agent.
- `reason`: Required if denied. This text **replaces** the tool result sent
back to the model.
- `hookSpecificOutput.additionalContext`: Text that is **appended** to the
tool result for the agent.
- `hookSpecificOutput.tailToolCallRequest`: (`{ name: string, args: object }`)
A request to execute another tool immediately after this one. The result of
this "tail call" will replace the original tool's response. Ideal for
programmatic tool routing.
- `continue`: Set to `false` to **kill the entire agent loop** immediately.
- **Exit Code 2 (Block Result)**: Hides the tool result. Uses `stderr` as the
replacement content sent to the agent. **The turn continues.**
@@ -0,0 +1,2 @@
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"functionCall":{"name":"read_file","args":{"file_path":"original.txt"}}}],"role":"model"},"finishReason":"STOP","index":0}]}]}
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"Tail call completed successfully."}],"role":"model"},"finishReason":"STOP","index":0}]}]}
+107
View File
@@ -286,6 +286,113 @@ describe('Hooks System Integration', () => {
});
});
describe('Command Hooks - Tail Tool Calls', () => {
it('should execute a tail tool call from AfterTool hooks and replace original response', async () => {
// Create a script that acts as the hook.
// It will trigger on "read_file" and issue a tail call to "write_file".
rig.setup('should execute a tail tool call from AfterTool hooks', {
fakeResponsesPath: join(
import.meta.dirname,
'hooks-system.tail-tool-call.responses',
),
});
const hookOutput = {
decision: 'allow',
hookSpecificOutput: {
hookEventName: 'AfterTool',
tailToolCallRequest: {
name: 'write_file',
args: {
file_path: 'tail-called-file.txt',
content: 'Content from tail call',
},
},
},
};
const hookScript = `console.log(JSON.stringify(${JSON.stringify(
hookOutput,
)})); process.exit(0);`;
const scriptPath = join(rig.testDir!, 'tail_call_hook.js');
writeFileSync(scriptPath, hookScript);
const commandPath = scriptPath.replace(/\\/g, '/');
rig.setup('should execute a tail tool call from AfterTool hooks', {
fakeResponsesPath: join(
import.meta.dirname,
'hooks-system.tail-tool-call.responses',
),
settings: {
hooksConfig: {
enabled: true,
},
hooks: {
AfterTool: [
{
matcher: 'read_file',
hooks: [
{
type: 'command',
command: `node "${commandPath}"`,
timeout: 5000,
},
],
},
],
},
},
});
// Create a test file to trigger the read_file tool
rig.createFile('original.txt', 'Original content');
const cliOutput = await rig.run({
args: 'Read original.txt', // Fake responses should trigger read_file on this
});
// 1. Verify that write_file was called (as a tail call replacing read_file)
// Since read_file was replaced before finalizing, it will not appear in the tool logs.
const foundWriteFile = await rig.waitForToolCall('write_file');
expect(foundWriteFile).toBeTruthy();
// Ensure hook logs are flushed and the final LLM response is received.
// The mock LLM is configured to respond with "Tail call completed successfully."
expect(cliOutput).toContain('Tail call completed successfully.');
// Ensure telemetry is written to disk
await rig.waitForTelemetryReady();
// Read hook logs to debug
const hookLogs = rig.readHookLogs();
const relevantHookLog = hookLogs.find(
(l) => l.hookCall.hook_event_name === 'AfterTool',
);
expect(relevantHookLog).toBeDefined();
// 2. Verify write_file was executed.
// In non-interactive mode, the CLI deduplicates tool execution logs by callId.
// Since a tail call reuses the original callId, "Tool: write_file" is not printed.
// Instead, we verify the side-effect (file creation) and the telemetry log.
// 3. Verify the tail-called tool actually wrote the file
const modifiedContent = rig.readFile('tail-called-file.txt');
expect(modifiedContent).toBe('Content from tail call');
// 4. Verify telemetry for the final tool call.
// The original 'read_file' call is replaced, so only 'write_file' is finalized and logged.
const toolLogs = rig.readToolLogs();
const successfulTools = toolLogs.filter((t) => t.toolRequest.success);
expect(
successfulTools.some((t) => t.toolRequest.name === 'write_file'),
).toBeTruthy();
// The original request name should be preserved in the log payload if possible,
// but the executed tool name is 'write_file'.
});
});
describe('BeforeModel Hooks - LLM Request Modification', () => {
it('should modify LLM requests with BeforeModel hooks', async () => {
// Create a hook script that replaces the LLM request with a modified version
@@ -58,7 +58,10 @@ export const ShellToolMessage: React.FC<ShellToolMessageProps> = ({
borderColor,
borderDimColor,
isExpandable,
originalRequestName,
}) => {
const {
activePtyId: activeShellPtyId,
@@ -129,6 +132,7 @@ export const ShellToolMessage: React.FC<ShellToolMessageProps> = ({
status={status}
description={description}
emphasis={emphasis}
originalRequestName={originalRequestName}
/>
<FocusHint
@@ -57,6 +57,7 @@ export const ToolMessage: React.FC<ToolMessageProps> = ({
config,
progressMessage,
progressPercent,
originalRequestName,
}) => {
const isThisShellFocused = checkIsShellFocused(
name,
@@ -93,6 +94,7 @@ export const ToolMessage: React.FC<ToolMessageProps> = ({
emphasis={emphasis}
progressMessage={progressMessage}
progressPercent={progressPercent}
originalRequestName={originalRequestName}
/>
<FocusHint
shouldShowFocusHint={shouldShowFocusHint}
@@ -189,6 +189,7 @@ type ToolInfoProps = {
emphasis: TextEmphasis;
progressMessage?: string;
progressPercent?: number;
originalRequestName?: string;
};
export const ToolInfo: React.FC<ToolInfoProps> = ({
@@ -198,6 +199,7 @@ export const ToolInfo: React.FC<ToolInfoProps> = ({
emphasis,
progressMessage,
progressPercent,
originalRequestName,
}) => {
const status = mapCoreStatusToDisplayStatus(coreStatus);
const nameColor = React.useMemo<string>(() => {
@@ -242,6 +244,12 @@ export const ToolInfo: React.FC<ToolInfoProps> = ({
<Text color={nameColor} bold>
{name}
</Text>
{originalRequestName && originalRequestName !== name && (
<Text color={theme.text.secondary} italic>
{' '}
(redirection from {originalRequestName})
</Text>
)}
{!isCompletedAskUser && (
<>
{' '}
@@ -275,5 +275,20 @@ describe('toolMapping', () => {
expect(result.tools[0].resultDisplay).toBeUndefined();
expect(result.tools[0].status).toBe(CoreToolCallStatus.Scheduled);
});
it('propagates originalRequestName correctly', () => {
const toolCall: ScheduledToolCall = {
status: CoreToolCallStatus.Scheduled,
request: {
...mockRequest,
originalRequestName: 'original_tool',
},
tool: mockTool,
invocation: mockInvocation,
};
const result = mapToDisplay(toolCall);
expect(result.tools[0].originalRequestName).toBe('original_tool');
});
});
});
+1
View File
@@ -107,6 +107,7 @@ export function mapToDisplay(
progressMessage,
progressPercent,
approvalMode: call.approvalMode,
originalRequestName: call.request.originalRequestName,
};
});
@@ -13,6 +13,7 @@ import {
Scheduler,
type Config,
type MessageBus,
type ExecutingToolCall,
type CompletedToolCall,
type ToolCallsUpdateMessage,
type AnyDeclarativeTool,
@@ -110,7 +111,7 @@ describe('useToolScheduler', () => {
tool: createMockTool(),
invocation: createMockInvocation(),
liveOutput: 'Loading...',
};
} as ExecutingToolCall;
act(() => {
void mockMessageBus.publish({
@@ -405,4 +406,62 @@ describe('useToolScheduler', () => {
toolCalls.find((t) => t.request.callId === 'call-sub')?.schedulerId,
).toBe('subagent-1');
});
it('adapts success/error status to executing when a tail call is present', () => {
vi.useFakeTimers();
const { result } = renderHook(() =>
useToolScheduler(
vi.fn().mockResolvedValue(undefined),
mockConfig,
() => undefined,
),
);
const startTime = Date.now();
vi.advanceTimersByTime(1000);
const mockToolCall = {
status: CoreToolCallStatus.Success as const,
request: {
callId: 'call-1',
name: 'test_tool',
args: {},
isClientInitiated: false,
prompt_id: 'p1',
},
tool: createMockTool(),
invocation: createMockInvocation(),
response: {
callId: 'call-1',
resultDisplay: 'OK',
responseParts: [],
error: undefined,
errorType: undefined,
},
tailToolCallRequest: {
name: 'tail_tool',
args: {},
isClientInitiated: false,
prompt_id: '123',
},
};
act(() => {
void mockMessageBus.publish({
type: MessageBusType.TOOL_CALLS_UPDATE,
toolCalls: [mockToolCall],
schedulerId: ROOT_SCHEDULER_ID,
} as ToolCallsUpdateMessage);
});
const [toolCalls, , , , , lastOutputTime] = result.current;
// Check if status has been adapted to 'executing'
expect(toolCalls[0].status).toBe(CoreToolCallStatus.Executing);
// Check if lastOutputTime was updated due to the transitional state
expect(lastOutputTime).toBeGreaterThan(startTime);
vi.useRealTimers();
});
});
+26 -2
View File
@@ -14,6 +14,7 @@ import {
Scheduler,
type EditorType,
type ToolCallsUpdateMessage,
CoreToolCallStatus,
} from '@google/gemini-cli-core';
import { useCallback, useState, useMemo, useEffect, useRef } from 'react';
@@ -115,7 +116,16 @@ export function useToolScheduler(
useEffect(() => {
const handler = (event: ToolCallsUpdateMessage) => {
// Update output timer for UI spinners (Side Effect)
if (event.toolCalls.some((tc) => tc.status === 'executing')) {
const hasExecuting = event.toolCalls.some(
(tc) =>
tc.status === CoreToolCallStatus.Executing ||
((tc.status === CoreToolCallStatus.Success ||
tc.status === CoreToolCallStatus.Error) &&
'tailToolCallRequest' in tc &&
tc.tailToolCallRequest != null),
);
if (hasExecuting) {
setLastToolOutputTime(Date.now());
}
@@ -238,9 +248,23 @@ function adaptToolCalls(
const prev = prevMap.get(coreCall.request.callId);
const responseSubmittedToGemini = prev?.responseSubmittedToGemini ?? false;
let status = coreCall.status;
// If a tool call has completed but scheduled a tail call, it is in a transitional
// state. Force the UI to render it as "executing".
if (
(status === CoreToolCallStatus.Success ||
status === CoreToolCallStatus.Error) &&
'tailToolCallRequest' in coreCall &&
coreCall.tailToolCallRequest != null
) {
status = CoreToolCallStatus.Executing;
}
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
return {
...coreCall,
status,
responseSubmittedToGemini,
};
} as TrackedToolCall;
});
}
+1
View File
@@ -110,6 +110,7 @@ export interface IndividualToolCallDisplay {
approvalMode?: ApprovalMode;
progressMessage?: string;
progressPercent?: number;
originalRequestName?: string;
}
export interface CompressionProps {
@@ -75,6 +75,7 @@ export async function executeToolWithHooks(
shellExecutionConfig?: ShellExecutionConfig,
setPidCallback?: (pid: number) => void,
config?: Config,
originalRequestName?: string,
): Promise<ToolResult> {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const toolInput = (invocation.params || {}) as Record<string, unknown>;
@@ -90,6 +91,7 @@ export async function executeToolWithHooks(
toolName,
toolInput,
mcpContext,
originalRequestName,
);
// Check if hook requested to stop entire agent execution
@@ -196,6 +198,7 @@ export async function executeToolWithHooks(
error: toolResult.error,
},
mcpContext,
originalRequestName,
);
// Check if hook requested to stop entire agent execution
@@ -242,6 +245,12 @@ export async function executeToolWithHooks(
toolResult.llmContent = wrappedContext;
}
}
// Check if the hook requested a tail tool call
const tailToolCallRequest = afterOutput?.getTailToolCallRequest();
if (tailToolCallRequest) {
toolResult.tailToolCallRequest = tailToolCallRequest;
}
}
return toolResult;
@@ -76,12 +76,16 @@ export class HookEventHandler {
toolName: string,
toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
originalRequestName?: string,
): Promise<AggregatedHookResult> {
const input: BeforeToolInput = {
...this.createBaseInput(HookEventName.BeforeTool),
tool_name: toolName,
tool_input: toolInput,
...(mcpContext && { mcp_context: mcpContext }),
...(originalRequestName && {
original_request_name: originalRequestName,
}),
};
const context: HookEventContext = { toolName };
@@ -97,6 +101,7 @@ export class HookEventHandler {
toolInput: Record<string, unknown>,
toolResponse: Record<string, unknown>,
mcpContext?: McpToolContext,
originalRequestName?: string,
): Promise<AggregatedHookResult> {
const input: AfterToolInput = {
...this.createBaseInput(HookEventName.AfterTool),
@@ -104,6 +109,9 @@ export class HookEventHandler {
tool_input: toolInput,
tool_response: toolResponse,
...(mcpContext && { mcp_context: mcpContext }),
...(originalRequestName && {
original_request_name: originalRequestName,
}),
};
const context: HookEventContext = { toolName };
+4
View File
@@ -368,12 +368,14 @@ export class HookSystem {
toolName: string,
toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
originalRequestName?: string,
): Promise<DefaultHookOutput | undefined> {
try {
const result = await this.hookEventHandler.fireBeforeToolEvent(
toolName,
toolInput,
mcpContext,
originalRequestName,
);
return result.finalOutput;
} catch (error) {
@@ -391,6 +393,7 @@ export class HookSystem {
error: unknown;
},
mcpContext?: McpToolContext,
originalRequestName?: string,
): Promise<DefaultHookOutput | undefined> {
try {
const result = await this.hookEventHandler.fireAfterToolEvent(
@@ -398,6 +401,7 @@ export class HookSystem {
toolInput,
toolResponse as Record<string, unknown>,
mcpContext,
originalRequestName,
);
return result.finalOutput;
} catch (error) {
+37
View File
@@ -253,6 +253,33 @@ export class DefaultHookOutput implements HookOutput {
shouldClearContext(): boolean {
return false;
}
/**
* Optional request to execute another tool immediately after this one.
* The result of this tail call will replace the original tool's response.
*/
getTailToolCallRequest():
| {
name: string;
args: Record<string, unknown>;
}
| undefined {
if (
this.hookSpecificOutput &&
'tailToolCallRequest' in this.hookSpecificOutput
) {
const request = this.hookSpecificOutput['tailToolCallRequest'];
if (
typeof request === 'object' &&
request !== null &&
!Array.isArray(request)
) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
return request as { name: string; args: Record<string, unknown> };
}
}
return undefined;
}
}
/**
@@ -430,6 +457,7 @@ export interface BeforeToolInput extends HookInput {
tool_name: string;
tool_input: Record<string, unknown>;
mcp_context?: McpToolContext; // Only present for MCP tools
original_request_name?: string;
}
/**
@@ -450,6 +478,7 @@ export interface AfterToolInput extends HookInput {
tool_input: Record<string, unknown>;
tool_response: Record<string, unknown>;
mcp_context?: McpToolContext; // Only present for MCP tools
original_request_name?: string;
}
/**
@@ -459,6 +488,14 @@ export interface AfterToolOutput extends HookOutput {
hookSpecificOutput?: {
hookEventName: 'AfterTool';
additionalContext?: string;
/**
* Optional request to execute another tool immediately after this one.
* The result of this tail call will replace the original tool's response.
*/
tailToolCallRequest?: {
name: string;
args: Record<string, unknown>;
};
};
}
@@ -201,6 +201,12 @@ describe('Scheduler (Orchestrator)', () => {
mockQueue.length = 0;
}),
clearBatch: vi.fn(),
replaceActiveCallWithTailCall: vi.fn((id: string, nextCall: ToolCall) => {
if (mockActiveCallsMap.has(id)) {
mockActiveCallsMap.delete(id);
mockQueue.unshift(nextCall);
}
}),
} as unknown as Mocked<SchedulerStateManager>;
// Define getters for accessors idiomatically
@@ -1006,6 +1012,113 @@ describe('Scheduler (Orchestrator)', () => {
const result = await (scheduler as any)._processNextItem(signal);
expect(result).toBe(false);
});
describe('Tail Calls', () => {
it('should replace the active call with a new tool call and re-run the loop when tail call is requested', async () => {
// Setup: Tool A will return a success with a tail call request to Tool B
const mockResponse = {
callId: 'call-1',
responseParts: [],
} as unknown as ToolCallResponseInfo;
mockExecutor.execute
.mockResolvedValueOnce({
status: 'success',
response: mockResponse,
tailToolCallRequest: {
name: 'tool-b',
args: { key: 'value' },
},
request: req1,
} as unknown as SuccessfulToolCall)
.mockResolvedValueOnce({
status: 'success',
response: mockResponse,
request: {
...req1,
name: 'tool-b',
args: { key: 'value' },
originalRequestName: 'test-tool',
},
} as unknown as SuccessfulToolCall);
const mockToolB = {
name: 'tool-b',
build: vi.fn().mockReturnValue({}),
} as unknown as AnyDeclarativeTool;
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockToolB);
await scheduler.schedule(req1, signal);
// Assert: The state manager is instructed to replace the call
expect(
mockStateManager.replaceActiveCallWithTailCall,
).toHaveBeenCalledWith(
'call-1',
expect.objectContaining({
request: expect.objectContaining({
callId: 'call-1',
name: 'tool-b',
args: { key: 'value' },
originalRequestName: 'test-tool', // Preserves original name
}),
tool: mockToolB,
}),
);
// Assert: The executor should be called twice (once for Tool A, once for Tool B)
expect(mockExecutor.execute).toHaveBeenCalledTimes(2);
});
it('should inject an errored tool call if the tail tool is not found', async () => {
const mockResponse = {
callId: 'call-1',
responseParts: [],
} as unknown as ToolCallResponseInfo;
mockExecutor.execute.mockResolvedValue({
status: 'success',
response: mockResponse,
tailToolCallRequest: {
name: 'missing-tool',
args: {},
},
request: req1,
} as unknown as SuccessfulToolCall);
// Tool registry returns undefined for missing-tool, but valid tool for test-tool
vi.mocked(mockToolRegistry.getTool).mockImplementation((name) => {
if (name === 'test-tool') {
return {
name: 'test-tool',
build: vi.fn().mockReturnValue({}),
} as unknown as AnyDeclarativeTool;
}
return undefined;
});
await scheduler.schedule(req1, signal);
// Assert: Replaces active call with an errored call
expect(
mockStateManager.replaceActiveCallWithTailCall,
).toHaveBeenCalledWith(
'call-1',
expect.objectContaining({
status: 'error',
request: expect.objectContaining({
callId: 'call-1',
name: 'missing-tool', // Name of the failed tail call
originalRequestName: 'test-tool',
}),
response: expect.objectContaining({
errorType: ToolErrorType.TOOL_NOT_REGISTERED,
}),
}),
);
});
});
});
describe('Tool Call Context Propagation', () => {
+68 -5
View File
@@ -19,6 +19,7 @@ import {
type ExecutingToolCall,
type ValidatingToolCall,
type ErroredToolCall,
type SuccessfulToolCall,
CoreToolCallStatus,
type ScheduledToolCall,
} from './types.js';
@@ -446,13 +447,16 @@ export class Scheduler {
c.status === CoreToolCallStatus.Scheduled || this.isTerminal(c.status),
);
let madeProgress = false;
if (allReady && scheduledCalls.length > 0) {
await Promise.all(scheduledCalls.map((c) => this._execute(c, signal)));
const execResults = await Promise.all(
scheduledCalls.map((c) => this._execute(c, signal)),
);
madeProgress = execResults.some((res) => res);
}
// 3. Finalize terminal calls
activeCalls = this.state.allActiveCalls;
let madeProgress = false;
for (const call of activeCalls) {
if (this.isTerminal(call.status)) {
this.state.finalizeCall(call.request.callId);
@@ -595,12 +599,12 @@ export class Scheduler {
// --- Sub-phase Handlers ---
/**
* Executes the tool and records the result.
* Executes the tool and records the result. Returns true if a new tool call was added.
*/
private async _execute(
toolCall: ScheduledToolCall,
signal: AbortSignal,
): Promise<void> {
): Promise<boolean> {
const callId = toolCall.request.callId;
if (signal.aborted) {
this.state.updateStatus(
@@ -608,7 +612,7 @@ export class Scheduler {
CoreToolCallStatus.Cancelled,
'Operation cancelled',
);
return;
return false;
}
this.state.updateStatus(callId, CoreToolCallStatus.Executing);
@@ -642,6 +646,64 @@ export class Scheduler {
}),
);
if (
(result.status === CoreToolCallStatus.Success ||
result.status === CoreToolCallStatus.Error) &&
result.tailToolCallRequest
) {
// Log the intermediate tool call before it gets replaced.
const intermediateCall: SuccessfulToolCall | ErroredToolCall = {
request: activeCall.request,
tool: activeCall.tool,
invocation: activeCall.invocation,
status: result.status,
response: result.response,
durationMs: activeCall.startTime
? Date.now() - activeCall.startTime
: undefined,
outcome: activeCall.outcome,
schedulerId: this.schedulerId,
};
logToolCall(this.config, new ToolCallEvent(intermediateCall));
const tailRequest = result.tailToolCallRequest;
const originalCallId = result.request.callId;
const originalRequestName =
result.request.originalRequestName || result.request.name;
const newTool = this.config.getToolRegistry().getTool(tailRequest.name);
const newRequest: ToolCallRequestInfo = {
callId: originalCallId,
name: tailRequest.name,
args: tailRequest.args,
originalRequestName,
isClientInitiated: result.request.isClientInitiated,
prompt_id: result.request.prompt_id,
schedulerId: this.schedulerId,
};
if (!newTool) {
// Enqueue an errored tool call
const errorCall = this._createToolNotFoundErroredToolCall(
newRequest,
this.config.getToolRegistry().getAllToolNames(),
);
this.state.replaceActiveCallWithTailCall(callId, errorCall);
} else {
// Enqueue a validating tool call for the new tail tool
const validatingCall = this._validateAndCreateToolCall(
newRequest,
newTool,
activeCall.approvalMode ?? this.config.getApprovalMode(),
);
this.state.replaceActiveCallWithTailCall(callId, validatingCall);
}
// Loop continues, picking up the new tail call at the front of the queue.
return true;
}
if (result.status === CoreToolCallStatus.Success) {
this.state.updateStatus(
callId,
@@ -661,6 +723,7 @@ export class Scheduler {
result.response,
);
}
return false;
}
private _processNextInRequestQueue() {
@@ -187,6 +187,19 @@ export class SchedulerStateManager {
this.emitUpdate();
}
/**
* Replaces the currently active call with a new call, placing the new call
* at the front of the queue to be processed immediately in the next tick.
* Used for Tail Calls to chain execution without finalizing the original call.
*/
replaceActiveCallWithTailCall(callId: string, nextCall: ToolCall): void {
if (this.activeCalls.has(callId)) {
this.activeCalls.delete(callId);
this.queue.unshift(nextCall);
this.emitUpdate();
}
}
cancelAllQueued(reason: string): void {
if (this.queue.length === 0) {
return;
@@ -252,7 +252,17 @@ describe('ToolExecutor', () => {
// 2. Mock executeToolWithHooks to trigger the PID callback
const testPid = 12345;
vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation(
async (_inv, _name, _sig, _tool, _liveCb, _shellCfg, setPidCallback) => {
async (
_inv,
_name,
_sig,
_tool,
_liveCb,
_shellCfg,
setPidCallback,
_config,
_originalRequestName,
) => {
// Simulate the shell tool reporting a PID
if (setPidCallback) {
setPidCallback(testPid);
+9 -3
View File
@@ -99,6 +99,7 @@ export class ToolExecutor {
shellExecutionConfig,
setPidCallback,
this.config,
request.originalRequestName,
);
} else {
promise = executeToolWithHooks(
@@ -110,6 +111,7 @@ export class ToolExecutor {
shellExecutionConfig,
undefined,
this.config,
request.originalRequestName,
);
}
@@ -133,6 +135,7 @@ export class ToolExecutor {
new Error(toolResult.error.message),
toolResult.error.type,
displayText,
toolResult.tailToolCallRequest,
);
}
} catch (executionError: unknown) {
@@ -204,7 +207,7 @@ export class ToolExecutor {
): Promise<SuccessfulToolCall> {
let content = toolResult.llmContent;
let outputFile: string | undefined;
const toolName = call.request.name;
const toolName = call.request.originalRequestName || call.request.name;
const callId = call.request.callId;
if (typeof content === 'string' && toolName === SHELL_TOOL_NAME) {
@@ -268,6 +271,7 @@ export class ToolExecutor {
startTime,
endTime: Date.now(),
outcome: call.outcome,
tailToolCallRequest: toolResult.tailToolCallRequest,
};
}
@@ -276,6 +280,7 @@ export class ToolExecutor {
error: Error,
errorType?: ToolErrorType,
returnDisplay?: string,
tailToolCallRequest?: { name: string; args: Record<string, unknown> },
): ErroredToolCall {
const response = this.createErrorResponse(
call.request,
@@ -289,11 +294,12 @@ export class ToolExecutor {
status: CoreToolCallStatus.Error,
request: call.request,
response,
tool: call.tool,
tool: 'tool' in call ? call.tool : undefined,
durationMs: startTime ? Date.now() - startTime : undefined,
startTime,
endTime: Date.now(),
outcome: call.outcome,
tailToolCallRequest,
};
}
@@ -311,7 +317,7 @@ export class ToolExecutor {
{
functionResponse: {
id: request.callId,
name: request.name,
name: request.originalRequestName || request.name,
response: { error: error.message },
},
},
+14
View File
@@ -36,6 +36,11 @@ export interface ToolCallRequestInfo {
callId: string;
name: string;
args: Record<string, unknown>;
/**
* The original name of the tool requested by the model.
* This is used for tail calls to ensure the final response retains the original name.
*/
originalRequestName?: string;
isClientInitiated: boolean;
prompt_id: string;
checkpoint?: string;
@@ -58,6 +63,12 @@ export interface ToolCallResponseInfo {
data?: Record<string, unknown>;
}
/** Request to execute another tool immediately after a completed one. */
export interface TailToolCallRequest {
name: string;
args: Record<string, unknown>;
}
export type ValidatingToolCall = {
status: CoreToolCallStatus.Validating;
request: ToolCallRequestInfo;
@@ -91,6 +102,7 @@ export type ErroredToolCall = {
outcome?: ToolConfirmationOutcome;
schedulerId?: string;
approvalMode?: ApprovalMode;
tailToolCallRequest?: TailToolCallRequest;
};
export type SuccessfulToolCall = {
@@ -105,6 +117,7 @@ export type SuccessfulToolCall = {
outcome?: ToolConfirmationOutcome;
schedulerId?: string;
approvalMode?: ApprovalMode;
tailToolCallRequest?: TailToolCallRequest;
};
export type ExecutingToolCall = {
@@ -120,6 +133,7 @@ export type ExecutingToolCall = {
pid?: number;
schedulerId?: string;
approvalMode?: ApprovalMode;
tailToolCallRequest?: TailToolCallRequest;
};
export type CancelledToolCall = {
+9
View File
@@ -579,6 +579,15 @@ export interface ToolResult {
* Optional data payload for passing structured information back to the caller.
*/
data?: Record<string, unknown>;
/**
* Optional request to execute another tool immediately after this one.
* The result of this tail call will replace the original tool's response.
*/
tailToolCallRequest?: {
name: string;
args: Record<string, unknown>;
};
}
/**