feat: implement AfterTool tail tool calls (#18486)

This commit is contained in:
Steven Robertson
2026-02-23 19:57:00 -08:00
committed by GitHub
parent ee5eb70070
commit b0ceb74462
23 changed files with 567 additions and 26 deletions
@@ -75,6 +75,7 @@ export async function executeToolWithHooks(
shellExecutionConfig?: ShellExecutionConfig,
setPidCallback?: (pid: number) => void,
config?: Config,
originalRequestName?: string,
): Promise<ToolResult> {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const toolInput = (invocation.params || {}) as Record<string, unknown>;
@@ -90,6 +91,7 @@ export async function executeToolWithHooks(
toolName,
toolInput,
mcpContext,
originalRequestName,
);
// Check if hook requested to stop entire agent execution
@@ -196,6 +198,7 @@ export async function executeToolWithHooks(
error: toolResult.error,
},
mcpContext,
originalRequestName,
);
// Check if hook requested to stop entire agent execution
@@ -242,6 +245,12 @@ export async function executeToolWithHooks(
toolResult.llmContent = wrappedContext;
}
}
// Check if the hook requested a tail tool call
const tailToolCallRequest = afterOutput?.getTailToolCallRequest();
if (tailToolCallRequest) {
toolResult.tailToolCallRequest = tailToolCallRequest;
}
}
return toolResult;
@@ -76,12 +76,16 @@ export class HookEventHandler {
toolName: string,
toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
originalRequestName?: string,
): Promise<AggregatedHookResult> {
const input: BeforeToolInput = {
...this.createBaseInput(HookEventName.BeforeTool),
tool_name: toolName,
tool_input: toolInput,
...(mcpContext && { mcp_context: mcpContext }),
...(originalRequestName && {
original_request_name: originalRequestName,
}),
};
const context: HookEventContext = { toolName };
@@ -97,6 +101,7 @@ export class HookEventHandler {
toolInput: Record<string, unknown>,
toolResponse: Record<string, unknown>,
mcpContext?: McpToolContext,
originalRequestName?: string,
): Promise<AggregatedHookResult> {
const input: AfterToolInput = {
...this.createBaseInput(HookEventName.AfterTool),
@@ -104,6 +109,9 @@ export class HookEventHandler {
tool_input: toolInput,
tool_response: toolResponse,
...(mcpContext && { mcp_context: mcpContext }),
...(originalRequestName && {
original_request_name: originalRequestName,
}),
};
const context: HookEventContext = { toolName };
+4
View File
@@ -368,12 +368,14 @@ export class HookSystem {
toolName: string,
toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
originalRequestName?: string,
): Promise<DefaultHookOutput | undefined> {
try {
const result = await this.hookEventHandler.fireBeforeToolEvent(
toolName,
toolInput,
mcpContext,
originalRequestName,
);
return result.finalOutput;
} catch (error) {
@@ -391,6 +393,7 @@ export class HookSystem {
error: unknown;
},
mcpContext?: McpToolContext,
originalRequestName?: string,
): Promise<DefaultHookOutput | undefined> {
try {
const result = await this.hookEventHandler.fireAfterToolEvent(
@@ -398,6 +401,7 @@ export class HookSystem {
toolInput,
toolResponse as Record<string, unknown>,
mcpContext,
originalRequestName,
);
return result.finalOutput;
} catch (error) {
+37
View File
@@ -253,6 +253,33 @@ export class DefaultHookOutput implements HookOutput {
shouldClearContext(): boolean {
return false;
}
/**
 * Optional request to execute another tool immediately after this one.
 * The result of this tail call will replace the original tool's response.
 *
 * The value stored under `hookSpecificOutput.tailToolCallRequest` is untyped,
 * so it is validated field by field instead of being asserted to the target
 * shape:
 * - the request itself must be a non-null, non-array object;
 * - `name` must be a non-empty string;
 * - `args` must be a non-null, non-array object; a missing or `null` `args`
 *   is normalised to an empty object.
 *
 * @returns A freshly built `{ name, args }` request, or `undefined` when no
 *   request is present or the supplied value is malformed.
 */
getTailToolCallRequest():
  | {
      name: string;
      args: Record<string, unknown>;
    }
  | undefined {
  // Local type guard: narrows without an unchecked `as` assertion.
  const isPlainObject = (value: unknown): value is Record<string, unknown> =>
    typeof value === 'object' && value !== null && !Array.isArray(value);

  if (
    !this.hookSpecificOutput ||
    !('tailToolCallRequest' in this.hookSpecificOutput)
  ) {
    return undefined;
  }

  const request: unknown = this.hookSpecificOutput['tailToolCallRequest'];
  if (!isPlainObject(request)) {
    return undefined;
  }

  const { name, args } = request;
  // Without a usable tool name there is nothing to look up downstream.
  if (typeof name !== 'string' || name.length === 0) {
    return undefined;
  }

  // A request that simply omits arguments means "call with no arguments".
  if (args === undefined || args === null) {
    return { name, args: {} };
  }
  if (!isPlainObject(args)) {
    return undefined;
  }

  return { name, args };
}
}
/**
@@ -430,6 +457,7 @@ export interface BeforeToolInput extends HookInput {
tool_name: string;
tool_input: Record<string, unknown>;
mcp_context?: McpToolContext; // Only present for MCP tools
original_request_name?: string;
}
/**
@@ -450,6 +478,7 @@ export interface AfterToolInput extends HookInput {
tool_input: Record<string, unknown>;
tool_response: Record<string, unknown>;
mcp_context?: McpToolContext; // Only present for MCP tools
original_request_name?: string;
}
/**
@@ -459,6 +488,14 @@ export interface AfterToolOutput extends HookOutput {
hookSpecificOutput?: {
hookEventName: 'AfterTool';
additionalContext?: string;
/**
* Optional request to execute another tool immediately after this one.
* The result of this tail call will replace the original tool's response.
*/
tailToolCallRequest?: {
name: string;
args: Record<string, unknown>;
};
};
}
@@ -201,6 +201,12 @@ describe('Scheduler (Orchestrator)', () => {
mockQueue.length = 0;
}),
clearBatch: vi.fn(),
replaceActiveCallWithTailCall: vi.fn((id: string, nextCall: ToolCall) => {
if (mockActiveCallsMap.has(id)) {
mockActiveCallsMap.delete(id);
mockQueue.unshift(nextCall);
}
}),
} as unknown as Mocked<SchedulerStateManager>;
// Define getters for accessors idiomatically
@@ -1006,6 +1012,113 @@ describe('Scheduler (Orchestrator)', () => {
const result = await (scheduler as any)._processNextItem(signal);
expect(result).toBe(false);
});
describe('Tail Calls', () => {
// Verifies the tail-call path: when the executor's result carries a
// `tailToolCallRequest`, the scheduler hands a replacement call to the state
// manager (same callId, new tool name/args) and the executor runs a second time.
it('should replace the active call with a new tool call and re-run the loop when tail call is requested', async () => {
// Setup: Tool A will return a success with a tail call request to Tool B
const mockResponse = {
callId: 'call-1',
responseParts: [],
} as unknown as ToolCallResponseInfo;
mockExecutor.execute
// First execution (Tool A): success result that asks for a tail call.
.mockResolvedValueOnce({
status: 'success',
response: mockResponse,
tailToolCallRequest: {
name: 'tool-b',
args: { key: 'value' },
},
request: req1,
} as unknown as SuccessfulToolCall)
// Second execution (Tool B): plain success with no further tail call.
.mockResolvedValueOnce({
status: 'success',
response: mockResponse,
request: {
...req1,
name: 'tool-b',
args: { key: 'value' },
originalRequestName: 'test-tool',
},
} as unknown as SuccessfulToolCall);
// Minimal stand-in for the tail tool; only `name` and `build` are stubbed.
const mockToolB = {
name: 'tool-b',
build: vi.fn().mockReturnValue({}),
} as unknown as AnyDeclarativeTool;
// The registry mock returns Tool B for every lookup in this test.
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockToolB);
await scheduler.schedule(req1, signal);
// Assert: The state manager is instructed to replace the call
expect(
mockStateManager.replaceActiveCallWithTailCall,
).toHaveBeenCalledWith(
'call-1',
expect.objectContaining({
request: expect.objectContaining({
callId: 'call-1',
name: 'tool-b',
args: { key: 'value' },
originalRequestName: 'test-tool', // Preserves original name
}),
tool: mockToolB,
}),
);
// Assert: The executor should be called twice (once for Tool A, once for Tool B)
expect(mockExecutor.execute).toHaveBeenCalledTimes(2);
});
// Verifies the failure path: when the requested tail tool is not in the
// registry, the active call is replaced by an errored call that carries the
// tail tool's name and a TOOL_NOT_REGISTERED error type.
it('should inject an errored tool call if the tail tool is not found', async () => {
const mockResponse = {
callId: 'call-1',
responseParts: [],
} as unknown as ToolCallResponseInfo;
// Every execution resolves to a success that requests an unregistered tool.
mockExecutor.execute.mockResolvedValue({
status: 'success',
response: mockResponse,
tailToolCallRequest: {
name: 'missing-tool',
args: {},
},
request: req1,
} as unknown as SuccessfulToolCall);
// Tool registry returns undefined for missing-tool, but valid tool for test-tool
vi.mocked(mockToolRegistry.getTool).mockImplementation((name) => {
if (name === 'test-tool') {
return {
name: 'test-tool',
build: vi.fn().mockReturnValue({}),
} as unknown as AnyDeclarativeTool;
}
return undefined;
});
await scheduler.schedule(req1, signal);
// Assert: Replaces active call with an errored call
expect(
mockStateManager.replaceActiveCallWithTailCall,
).toHaveBeenCalledWith(
'call-1',
expect.objectContaining({
status: 'error',
request: expect.objectContaining({
// The replacement keeps the original call id.
callId: 'call-1',
name: 'missing-tool', // Name of the failed tail call
originalRequestName: 'test-tool',
}),
response: expect.objectContaining({
errorType: ToolErrorType.TOOL_NOT_REGISTERED,
}),
}),
);
});
});
});
describe('Tool Call Context Propagation', () => {
+68 -5
View File
@@ -19,6 +19,7 @@ import {
type ExecutingToolCall,
type ValidatingToolCall,
type ErroredToolCall,
type SuccessfulToolCall,
CoreToolCallStatus,
type ScheduledToolCall,
} from './types.js';
@@ -446,13 +447,16 @@ export class Scheduler {
c.status === CoreToolCallStatus.Scheduled || this.isTerminal(c.status),
);
let madeProgress = false;
if (allReady && scheduledCalls.length > 0) {
await Promise.all(scheduledCalls.map((c) => this._execute(c, signal)));
const execResults = await Promise.all(
scheduledCalls.map((c) => this._execute(c, signal)),
);
madeProgress = execResults.some((res) => res);
}
// 3. Finalize terminal calls
activeCalls = this.state.allActiveCalls;
let madeProgress = false;
for (const call of activeCalls) {
if (this.isTerminal(call.status)) {
this.state.finalizeCall(call.request.callId);
@@ -595,12 +599,12 @@ export class Scheduler {
// --- Sub-phase Handlers ---
/**
* Executes the tool and records the result.
* Executes the tool and records the result. Returns true if a new tool call was added.
*/
private async _execute(
toolCall: ScheduledToolCall,
signal: AbortSignal,
): Promise<void> {
): Promise<boolean> {
const callId = toolCall.request.callId;
if (signal.aborted) {
this.state.updateStatus(
@@ -608,7 +612,7 @@ export class Scheduler {
CoreToolCallStatus.Cancelled,
'Operation cancelled',
);
return;
return false;
}
this.state.updateStatus(callId, CoreToolCallStatus.Executing);
@@ -642,6 +646,64 @@ export class Scheduler {
}),
);
if (
(result.status === CoreToolCallStatus.Success ||
result.status === CoreToolCallStatus.Error) &&
result.tailToolCallRequest
) {
// Log the intermediate tool call before it gets replaced.
const intermediateCall: SuccessfulToolCall | ErroredToolCall = {
request: activeCall.request,
tool: activeCall.tool,
invocation: activeCall.invocation,
status: result.status,
response: result.response,
durationMs: activeCall.startTime
? Date.now() - activeCall.startTime
: undefined,
outcome: activeCall.outcome,
schedulerId: this.schedulerId,
};
logToolCall(this.config, new ToolCallEvent(intermediateCall));
const tailRequest = result.tailToolCallRequest;
const originalCallId = result.request.callId;
const originalRequestName =
result.request.originalRequestName || result.request.name;
const newTool = this.config.getToolRegistry().getTool(tailRequest.name);
const newRequest: ToolCallRequestInfo = {
callId: originalCallId,
name: tailRequest.name,
args: tailRequest.args,
originalRequestName,
isClientInitiated: result.request.isClientInitiated,
prompt_id: result.request.prompt_id,
schedulerId: this.schedulerId,
};
if (!newTool) {
// Enqueue an errored tool call
const errorCall = this._createToolNotFoundErroredToolCall(
newRequest,
this.config.getToolRegistry().getAllToolNames(),
);
this.state.replaceActiveCallWithTailCall(callId, errorCall);
} else {
// Enqueue a validating tool call for the new tail tool
const validatingCall = this._validateAndCreateToolCall(
newRequest,
newTool,
activeCall.approvalMode ?? this.config.getApprovalMode(),
);
this.state.replaceActiveCallWithTailCall(callId, validatingCall);
}
// Loop continues, picking up the new tail call at the front of the queue.
return true;
}
if (result.status === CoreToolCallStatus.Success) {
this.state.updateStatus(
callId,
@@ -661,6 +723,7 @@ export class Scheduler {
result.response,
);
}
return false;
}
private _processNextInRequestQueue() {
@@ -187,6 +187,19 @@ export class SchedulerStateManager {
this.emitUpdate();
}
/**
 * Swaps an active call for its tail call.
 *
 * When `callId` is currently tracked in `activeCalls`, the entry is removed,
 * `nextCall` is pushed onto the front of the queue so it is the next item
 * taken from it, and an update is emitted. The original call is not finalized.
 * When `callId` is not active, nothing happens and no update is emitted.
 *
 * @param callId Identifier of the active call being replaced.
 * @param nextCall The tail call to enqueue at the head of the queue.
 */
replaceActiveCallWithTailCall(callId: string, nextCall: ToolCall): void {
  const isActive = this.activeCalls.has(callId);
  if (!isActive) {
    return;
  }

  this.activeCalls.delete(callId);
  this.queue.unshift(nextCall);
  this.emitUpdate();
}
cancelAllQueued(reason: string): void {
if (this.queue.length === 0) {
return;
@@ -252,7 +252,17 @@ describe('ToolExecutor', () => {
// 2. Mock executeToolWithHooks to trigger the PID callback
const testPid = 12345;
vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation(
async (_inv, _name, _sig, _tool, _liveCb, _shellCfg, setPidCallback) => {
async (
_inv,
_name,
_sig,
_tool,
_liveCb,
_shellCfg,
setPidCallback,
_config,
_originalRequestName,
) => {
// Simulate the shell tool reporting a PID
if (setPidCallback) {
setPidCallback(testPid);
+9 -3
View File
@@ -99,6 +99,7 @@ export class ToolExecutor {
shellExecutionConfig,
setPidCallback,
this.config,
request.originalRequestName,
);
} else {
promise = executeToolWithHooks(
@@ -110,6 +111,7 @@ export class ToolExecutor {
shellExecutionConfig,
undefined,
this.config,
request.originalRequestName,
);
}
@@ -133,6 +135,7 @@ export class ToolExecutor {
new Error(toolResult.error.message),
toolResult.error.type,
displayText,
toolResult.tailToolCallRequest,
);
}
} catch (executionError: unknown) {
@@ -204,7 +207,7 @@ export class ToolExecutor {
): Promise<SuccessfulToolCall> {
let content = toolResult.llmContent;
let outputFile: string | undefined;
const toolName = call.request.name;
const toolName = call.request.originalRequestName || call.request.name;
const callId = call.request.callId;
if (typeof content === 'string' && toolName === SHELL_TOOL_NAME) {
@@ -268,6 +271,7 @@ export class ToolExecutor {
startTime,
endTime: Date.now(),
outcome: call.outcome,
tailToolCallRequest: toolResult.tailToolCallRequest,
};
}
@@ -276,6 +280,7 @@ export class ToolExecutor {
error: Error,
errorType?: ToolErrorType,
returnDisplay?: string,
tailToolCallRequest?: { name: string; args: Record<string, unknown> },
): ErroredToolCall {
const response = this.createErrorResponse(
call.request,
@@ -289,11 +294,12 @@ export class ToolExecutor {
status: CoreToolCallStatus.Error,
request: call.request,
response,
tool: call.tool,
tool: 'tool' in call ? call.tool : undefined,
durationMs: startTime ? Date.now() - startTime : undefined,
startTime,
endTime: Date.now(),
outcome: call.outcome,
tailToolCallRequest,
};
}
@@ -311,7 +317,7 @@ export class ToolExecutor {
{
functionResponse: {
id: request.callId,
name: request.name,
name: request.originalRequestName || request.name,
response: { error: error.message },
},
},
+14
View File
@@ -36,6 +36,11 @@ export interface ToolCallRequestInfo {
callId: string;
name: string;
args: Record<string, unknown>;
/**
* The original name of the tool requested by the model.
* This is used for tail calls to ensure the final response retains the original name.
*/
originalRequestName?: string;
isClientInitiated: boolean;
prompt_id: string;
checkpoint?: string;
@@ -58,6 +63,12 @@ export interface ToolCallResponseInfo {
data?: Record<string, unknown>;
}
/**
 * Request to execute another tool immediately after a completed one.
 */
export interface TailToolCallRequest {
/** Name of the tool to run next. */
name: string;
/** Arguments to pass to that tool. */
args: Record<string, unknown>;
}
export type ValidatingToolCall = {
status: CoreToolCallStatus.Validating;
request: ToolCallRequestInfo;
@@ -91,6 +102,7 @@ export type ErroredToolCall = {
outcome?: ToolConfirmationOutcome;
schedulerId?: string;
approvalMode?: ApprovalMode;
tailToolCallRequest?: TailToolCallRequest;
};
export type SuccessfulToolCall = {
@@ -105,6 +117,7 @@ export type SuccessfulToolCall = {
outcome?: ToolConfirmationOutcome;
schedulerId?: string;
approvalMode?: ApprovalMode;
tailToolCallRequest?: TailToolCallRequest;
};
export type ExecutingToolCall = {
@@ -120,6 +133,7 @@ export type ExecutingToolCall = {
pid?: number;
schedulerId?: string;
approvalMode?: ApprovalMode;
tailToolCallRequest?: TailToolCallRequest;
};
export type CancelledToolCall = {
+9
View File
@@ -579,6 +579,15 @@ export interface ToolResult {
* Optional data payload for passing structured information back to the caller.
*/
data?: Record<string, unknown>;
/**
* Optional request to execute another tool immediately after this one.
* The result of this tail call will replace the original tool's response.
*/
tailToolCallRequest?: {
name: string;
args: Record<string, unknown>;
};
}
/**