mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-24 12:04:56 -07:00
feat: implement AfterTool tail tool calls (#18486)
This commit is contained in:
@@ -75,6 +75,7 @@ export async function executeToolWithHooks(
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
setPidCallback?: (pid: number) => void,
|
||||
config?: Config,
|
||||
originalRequestName?: string,
|
||||
): Promise<ToolResult> {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const toolInput = (invocation.params || {}) as Record<string, unknown>;
|
||||
@@ -90,6 +91,7 @@ export async function executeToolWithHooks(
|
||||
toolName,
|
||||
toolInput,
|
||||
mcpContext,
|
||||
originalRequestName,
|
||||
);
|
||||
|
||||
// Check if hook requested to stop entire agent execution
|
||||
@@ -196,6 +198,7 @@ export async function executeToolWithHooks(
|
||||
error: toolResult.error,
|
||||
},
|
||||
mcpContext,
|
||||
originalRequestName,
|
||||
);
|
||||
|
||||
// Check if hook requested to stop entire agent execution
|
||||
@@ -242,6 +245,12 @@ export async function executeToolWithHooks(
|
||||
toolResult.llmContent = wrappedContext;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the hook requested a tail tool call
|
||||
const tailToolCallRequest = afterOutput?.getTailToolCallRequest();
|
||||
if (tailToolCallRequest) {
|
||||
toolResult.tailToolCallRequest = tailToolCallRequest;
|
||||
}
|
||||
}
|
||||
|
||||
return toolResult;
|
||||
|
||||
@@ -76,12 +76,16 @@ export class HookEventHandler {
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
mcpContext?: McpToolContext,
|
||||
originalRequestName?: string,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: BeforeToolInput = {
|
||||
...this.createBaseInput(HookEventName.BeforeTool),
|
||||
tool_name: toolName,
|
||||
tool_input: toolInput,
|
||||
...(mcpContext && { mcp_context: mcpContext }),
|
||||
...(originalRequestName && {
|
||||
original_request_name: originalRequestName,
|
||||
}),
|
||||
};
|
||||
|
||||
const context: HookEventContext = { toolName };
|
||||
@@ -97,6 +101,7 @@ export class HookEventHandler {
|
||||
toolInput: Record<string, unknown>,
|
||||
toolResponse: Record<string, unknown>,
|
||||
mcpContext?: McpToolContext,
|
||||
originalRequestName?: string,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: AfterToolInput = {
|
||||
...this.createBaseInput(HookEventName.AfterTool),
|
||||
@@ -104,6 +109,9 @@ export class HookEventHandler {
|
||||
tool_input: toolInput,
|
||||
tool_response: toolResponse,
|
||||
...(mcpContext && { mcp_context: mcpContext }),
|
||||
...(originalRequestName && {
|
||||
original_request_name: originalRequestName,
|
||||
}),
|
||||
};
|
||||
|
||||
const context: HookEventContext = { toolName };
|
||||
|
||||
@@ -368,12 +368,14 @@ export class HookSystem {
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
mcpContext?: McpToolContext,
|
||||
originalRequestName?: string,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
try {
|
||||
const result = await this.hookEventHandler.fireBeforeToolEvent(
|
||||
toolName,
|
||||
toolInput,
|
||||
mcpContext,
|
||||
originalRequestName,
|
||||
);
|
||||
return result.finalOutput;
|
||||
} catch (error) {
|
||||
@@ -391,6 +393,7 @@ export class HookSystem {
|
||||
error: unknown;
|
||||
},
|
||||
mcpContext?: McpToolContext,
|
||||
originalRequestName?: string,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
try {
|
||||
const result = await this.hookEventHandler.fireAfterToolEvent(
|
||||
@@ -398,6 +401,7 @@ export class HookSystem {
|
||||
toolInput,
|
||||
toolResponse as Record<string, unknown>,
|
||||
mcpContext,
|
||||
originalRequestName,
|
||||
);
|
||||
return result.finalOutput;
|
||||
} catch (error) {
|
||||
|
||||
@@ -253,6 +253,33 @@ export class DefaultHookOutput implements HookOutput {
|
||||
shouldClearContext(): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Optional request to execute another tool immediately after this one.
|
||||
* The result of this tail call will replace the original tool's response.
|
||||
*/
|
||||
getTailToolCallRequest():
|
||||
| {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}
|
||||
| undefined {
|
||||
if (
|
||||
this.hookSpecificOutput &&
|
||||
'tailToolCallRequest' in this.hookSpecificOutput
|
||||
) {
|
||||
const request = this.hookSpecificOutput['tailToolCallRequest'];
|
||||
if (
|
||||
typeof request === 'object' &&
|
||||
request !== null &&
|
||||
!Array.isArray(request)
|
||||
) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
return request as { name: string; args: Record<string, unknown> };
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -430,6 +457,7 @@ export interface BeforeToolInput extends HookInput {
|
||||
tool_name: string;
|
||||
tool_input: Record<string, unknown>;
|
||||
mcp_context?: McpToolContext; // Only present for MCP tools
|
||||
original_request_name?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -450,6 +478,7 @@ export interface AfterToolInput extends HookInput {
|
||||
tool_input: Record<string, unknown>;
|
||||
tool_response: Record<string, unknown>;
|
||||
mcp_context?: McpToolContext; // Only present for MCP tools
|
||||
original_request_name?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -459,6 +488,14 @@ export interface AfterToolOutput extends HookOutput {
|
||||
hookSpecificOutput?: {
|
||||
hookEventName: 'AfterTool';
|
||||
additionalContext?: string;
|
||||
/**
|
||||
* Optional request to execute another tool immediately after this one.
|
||||
* The result of this tail call will replace the original tool's response.
|
||||
*/
|
||||
tailToolCallRequest?: {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -201,6 +201,12 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
mockQueue.length = 0;
|
||||
}),
|
||||
clearBatch: vi.fn(),
|
||||
replaceActiveCallWithTailCall: vi.fn((id: string, nextCall: ToolCall) => {
|
||||
if (mockActiveCallsMap.has(id)) {
|
||||
mockActiveCallsMap.delete(id);
|
||||
mockQueue.unshift(nextCall);
|
||||
}
|
||||
}),
|
||||
} as unknown as Mocked<SchedulerStateManager>;
|
||||
|
||||
// Define getters for accessors idiomatically
|
||||
@@ -1006,6 +1012,113 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
const result = await (scheduler as any)._processNextItem(signal);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
describe('Tail Calls', () => {
|
||||
it('should replace the active call with a new tool call and re-run the loop when tail call is requested', async () => {
|
||||
// Setup: Tool A will return a success with a tail call request to Tool B
|
||||
const mockResponse = {
|
||||
callId: 'call-1',
|
||||
responseParts: [],
|
||||
} as unknown as ToolCallResponseInfo;
|
||||
|
||||
mockExecutor.execute
|
||||
.mockResolvedValueOnce({
|
||||
status: 'success',
|
||||
response: mockResponse,
|
||||
tailToolCallRequest: {
|
||||
name: 'tool-b',
|
||||
args: { key: 'value' },
|
||||
},
|
||||
request: req1,
|
||||
} as unknown as SuccessfulToolCall)
|
||||
.mockResolvedValueOnce({
|
||||
status: 'success',
|
||||
response: mockResponse,
|
||||
request: {
|
||||
...req1,
|
||||
name: 'tool-b',
|
||||
args: { key: 'value' },
|
||||
originalRequestName: 'test-tool',
|
||||
},
|
||||
} as unknown as SuccessfulToolCall);
|
||||
|
||||
const mockToolB = {
|
||||
name: 'tool-b',
|
||||
build: vi.fn().mockReturnValue({}),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
|
||||
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockToolB);
|
||||
|
||||
await scheduler.schedule(req1, signal);
|
||||
|
||||
// Assert: The state manager is instructed to replace the call
|
||||
expect(
|
||||
mockStateManager.replaceActiveCallWithTailCall,
|
||||
).toHaveBeenCalledWith(
|
||||
'call-1',
|
||||
expect.objectContaining({
|
||||
request: expect.objectContaining({
|
||||
callId: 'call-1',
|
||||
name: 'tool-b',
|
||||
args: { key: 'value' },
|
||||
originalRequestName: 'test-tool', // Preserves original name
|
||||
}),
|
||||
tool: mockToolB,
|
||||
}),
|
||||
);
|
||||
|
||||
// Assert: The executor should be called twice (once for Tool A, once for Tool B)
|
||||
expect(mockExecutor.execute).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should inject an errored tool call if the tail tool is not found', async () => {
|
||||
const mockResponse = {
|
||||
callId: 'call-1',
|
||||
responseParts: [],
|
||||
} as unknown as ToolCallResponseInfo;
|
||||
|
||||
mockExecutor.execute.mockResolvedValue({
|
||||
status: 'success',
|
||||
response: mockResponse,
|
||||
tailToolCallRequest: {
|
||||
name: 'missing-tool',
|
||||
args: {},
|
||||
},
|
||||
request: req1,
|
||||
} as unknown as SuccessfulToolCall);
|
||||
|
||||
// Tool registry returns undefined for missing-tool, but valid tool for test-tool
|
||||
vi.mocked(mockToolRegistry.getTool).mockImplementation((name) => {
|
||||
if (name === 'test-tool') {
|
||||
return {
|
||||
name: 'test-tool',
|
||||
build: vi.fn().mockReturnValue({}),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
|
||||
await scheduler.schedule(req1, signal);
|
||||
|
||||
// Assert: Replaces active call with an errored call
|
||||
expect(
|
||||
mockStateManager.replaceActiveCallWithTailCall,
|
||||
).toHaveBeenCalledWith(
|
||||
'call-1',
|
||||
expect.objectContaining({
|
||||
status: 'error',
|
||||
request: expect.objectContaining({
|
||||
callId: 'call-1',
|
||||
name: 'missing-tool', // Name of the failed tail call
|
||||
originalRequestName: 'test-tool',
|
||||
}),
|
||||
response: expect.objectContaining({
|
||||
errorType: ToolErrorType.TOOL_NOT_REGISTERED,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tool Call Context Propagation', () => {
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
type ExecutingToolCall,
|
||||
type ValidatingToolCall,
|
||||
type ErroredToolCall,
|
||||
type SuccessfulToolCall,
|
||||
CoreToolCallStatus,
|
||||
type ScheduledToolCall,
|
||||
} from './types.js';
|
||||
@@ -446,13 +447,16 @@ export class Scheduler {
|
||||
c.status === CoreToolCallStatus.Scheduled || this.isTerminal(c.status),
|
||||
);
|
||||
|
||||
let madeProgress = false;
|
||||
if (allReady && scheduledCalls.length > 0) {
|
||||
await Promise.all(scheduledCalls.map((c) => this._execute(c, signal)));
|
||||
const execResults = await Promise.all(
|
||||
scheduledCalls.map((c) => this._execute(c, signal)),
|
||||
);
|
||||
madeProgress = execResults.some((res) => res);
|
||||
}
|
||||
|
||||
// 3. Finalize terminal calls
|
||||
activeCalls = this.state.allActiveCalls;
|
||||
let madeProgress = false;
|
||||
for (const call of activeCalls) {
|
||||
if (this.isTerminal(call.status)) {
|
||||
this.state.finalizeCall(call.request.callId);
|
||||
@@ -595,12 +599,12 @@ export class Scheduler {
|
||||
// --- Sub-phase Handlers ---
|
||||
|
||||
/**
|
||||
* Executes the tool and records the result.
|
||||
* Executes the tool and records the result. Returns true if a new tool call was added.
|
||||
*/
|
||||
private async _execute(
|
||||
toolCall: ScheduledToolCall,
|
||||
signal: AbortSignal,
|
||||
): Promise<void> {
|
||||
): Promise<boolean> {
|
||||
const callId = toolCall.request.callId;
|
||||
if (signal.aborted) {
|
||||
this.state.updateStatus(
|
||||
@@ -608,7 +612,7 @@ export class Scheduler {
|
||||
CoreToolCallStatus.Cancelled,
|
||||
'Operation cancelled',
|
||||
);
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
this.state.updateStatus(callId, CoreToolCallStatus.Executing);
|
||||
|
||||
@@ -642,6 +646,64 @@ export class Scheduler {
|
||||
}),
|
||||
);
|
||||
|
||||
if (
|
||||
(result.status === CoreToolCallStatus.Success ||
|
||||
result.status === CoreToolCallStatus.Error) &&
|
||||
result.tailToolCallRequest
|
||||
) {
|
||||
// Log the intermediate tool call before it gets replaced.
|
||||
const intermediateCall: SuccessfulToolCall | ErroredToolCall = {
|
||||
request: activeCall.request,
|
||||
tool: activeCall.tool,
|
||||
invocation: activeCall.invocation,
|
||||
status: result.status,
|
||||
response: result.response,
|
||||
durationMs: activeCall.startTime
|
||||
? Date.now() - activeCall.startTime
|
||||
: undefined,
|
||||
outcome: activeCall.outcome,
|
||||
schedulerId: this.schedulerId,
|
||||
};
|
||||
logToolCall(this.config, new ToolCallEvent(intermediateCall));
|
||||
|
||||
const tailRequest = result.tailToolCallRequest;
|
||||
const originalCallId = result.request.callId;
|
||||
const originalRequestName =
|
||||
result.request.originalRequestName || result.request.name;
|
||||
|
||||
const newTool = this.config.getToolRegistry().getTool(tailRequest.name);
|
||||
|
||||
const newRequest: ToolCallRequestInfo = {
|
||||
callId: originalCallId,
|
||||
name: tailRequest.name,
|
||||
args: tailRequest.args,
|
||||
originalRequestName,
|
||||
isClientInitiated: result.request.isClientInitiated,
|
||||
prompt_id: result.request.prompt_id,
|
||||
schedulerId: this.schedulerId,
|
||||
};
|
||||
|
||||
if (!newTool) {
|
||||
// Enqueue an errored tool call
|
||||
const errorCall = this._createToolNotFoundErroredToolCall(
|
||||
newRequest,
|
||||
this.config.getToolRegistry().getAllToolNames(),
|
||||
);
|
||||
this.state.replaceActiveCallWithTailCall(callId, errorCall);
|
||||
} else {
|
||||
// Enqueue a validating tool call for the new tail tool
|
||||
const validatingCall = this._validateAndCreateToolCall(
|
||||
newRequest,
|
||||
newTool,
|
||||
activeCall.approvalMode ?? this.config.getApprovalMode(),
|
||||
);
|
||||
this.state.replaceActiveCallWithTailCall(callId, validatingCall);
|
||||
}
|
||||
|
||||
// Loop continues, picking up the new tail call at the front of the queue.
|
||||
return true;
|
||||
}
|
||||
|
||||
if (result.status === CoreToolCallStatus.Success) {
|
||||
this.state.updateStatus(
|
||||
callId,
|
||||
@@ -661,6 +723,7 @@ export class Scheduler {
|
||||
result.response,
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private _processNextInRequestQueue() {
|
||||
|
||||
@@ -187,6 +187,19 @@ export class SchedulerStateManager {
|
||||
this.emitUpdate();
|
||||
}
|
||||
|
||||
/**
|
||||
* Replaces the currently active call with a new call, placing the new call
|
||||
* at the front of the queue to be processed immediately in the next tick.
|
||||
* Used for Tail Calls to chain execution without finalizing the original call.
|
||||
*/
|
||||
replaceActiveCallWithTailCall(callId: string, nextCall: ToolCall): void {
|
||||
if (this.activeCalls.has(callId)) {
|
||||
this.activeCalls.delete(callId);
|
||||
this.queue.unshift(nextCall);
|
||||
this.emitUpdate();
|
||||
}
|
||||
}
|
||||
|
||||
cancelAllQueued(reason: string): void {
|
||||
if (this.queue.length === 0) {
|
||||
return;
|
||||
|
||||
@@ -252,7 +252,17 @@ describe('ToolExecutor', () => {
|
||||
// 2. Mock executeToolWithHooks to trigger the PID callback
|
||||
const testPid = 12345;
|
||||
vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation(
|
||||
async (_inv, _name, _sig, _tool, _liveCb, _shellCfg, setPidCallback) => {
|
||||
async (
|
||||
_inv,
|
||||
_name,
|
||||
_sig,
|
||||
_tool,
|
||||
_liveCb,
|
||||
_shellCfg,
|
||||
setPidCallback,
|
||||
_config,
|
||||
_originalRequestName,
|
||||
) => {
|
||||
// Simulate the shell tool reporting a PID
|
||||
if (setPidCallback) {
|
||||
setPidCallback(testPid);
|
||||
|
||||
@@ -99,6 +99,7 @@ export class ToolExecutor {
|
||||
shellExecutionConfig,
|
||||
setPidCallback,
|
||||
this.config,
|
||||
request.originalRequestName,
|
||||
);
|
||||
} else {
|
||||
promise = executeToolWithHooks(
|
||||
@@ -110,6 +111,7 @@ export class ToolExecutor {
|
||||
shellExecutionConfig,
|
||||
undefined,
|
||||
this.config,
|
||||
request.originalRequestName,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -133,6 +135,7 @@ export class ToolExecutor {
|
||||
new Error(toolResult.error.message),
|
||||
toolResult.error.type,
|
||||
displayText,
|
||||
toolResult.tailToolCallRequest,
|
||||
);
|
||||
}
|
||||
} catch (executionError: unknown) {
|
||||
@@ -204,7 +207,7 @@ export class ToolExecutor {
|
||||
): Promise<SuccessfulToolCall> {
|
||||
let content = toolResult.llmContent;
|
||||
let outputFile: string | undefined;
|
||||
const toolName = call.request.name;
|
||||
const toolName = call.request.originalRequestName || call.request.name;
|
||||
const callId = call.request.callId;
|
||||
|
||||
if (typeof content === 'string' && toolName === SHELL_TOOL_NAME) {
|
||||
@@ -268,6 +271,7 @@ export class ToolExecutor {
|
||||
startTime,
|
||||
endTime: Date.now(),
|
||||
outcome: call.outcome,
|
||||
tailToolCallRequest: toolResult.tailToolCallRequest,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -276,6 +280,7 @@ export class ToolExecutor {
|
||||
error: Error,
|
||||
errorType?: ToolErrorType,
|
||||
returnDisplay?: string,
|
||||
tailToolCallRequest?: { name: string; args: Record<string, unknown> },
|
||||
): ErroredToolCall {
|
||||
const response = this.createErrorResponse(
|
||||
call.request,
|
||||
@@ -289,11 +294,12 @@ export class ToolExecutor {
|
||||
status: CoreToolCallStatus.Error,
|
||||
request: call.request,
|
||||
response,
|
||||
tool: call.tool,
|
||||
tool: 'tool' in call ? call.tool : undefined,
|
||||
durationMs: startTime ? Date.now() - startTime : undefined,
|
||||
startTime,
|
||||
endTime: Date.now(),
|
||||
outcome: call.outcome,
|
||||
tailToolCallRequest,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -311,7 +317,7 @@ export class ToolExecutor {
|
||||
{
|
||||
functionResponse: {
|
||||
id: request.callId,
|
||||
name: request.name,
|
||||
name: request.originalRequestName || request.name,
|
||||
response: { error: error.message },
|
||||
},
|
||||
},
|
||||
|
||||
@@ -36,6 +36,11 @@ export interface ToolCallRequestInfo {
|
||||
callId: string;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
/**
|
||||
* The original name of the tool requested by the model.
|
||||
* This is used for tail calls to ensure the final response retains the original name.
|
||||
*/
|
||||
originalRequestName?: string;
|
||||
isClientInitiated: boolean;
|
||||
prompt_id: string;
|
||||
checkpoint?: string;
|
||||
@@ -58,6 +63,12 @@ export interface ToolCallResponseInfo {
|
||||
data?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
/** Request to execute another tool immediately after a completed one. */
|
||||
export interface TailToolCallRequest {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export type ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating;
|
||||
request: ToolCallRequestInfo;
|
||||
@@ -91,6 +102,7 @@ export type ErroredToolCall = {
|
||||
outcome?: ToolConfirmationOutcome;
|
||||
schedulerId?: string;
|
||||
approvalMode?: ApprovalMode;
|
||||
tailToolCallRequest?: TailToolCallRequest;
|
||||
};
|
||||
|
||||
export type SuccessfulToolCall = {
|
||||
@@ -105,6 +117,7 @@ export type SuccessfulToolCall = {
|
||||
outcome?: ToolConfirmationOutcome;
|
||||
schedulerId?: string;
|
||||
approvalMode?: ApprovalMode;
|
||||
tailToolCallRequest?: TailToolCallRequest;
|
||||
};
|
||||
|
||||
export type ExecutingToolCall = {
|
||||
@@ -120,6 +133,7 @@ export type ExecutingToolCall = {
|
||||
pid?: number;
|
||||
schedulerId?: string;
|
||||
approvalMode?: ApprovalMode;
|
||||
tailToolCallRequest?: TailToolCallRequest;
|
||||
};
|
||||
|
||||
export type CancelledToolCall = {
|
||||
|
||||
@@ -579,6 +579,15 @@ export interface ToolResult {
|
||||
* Optional data payload for passing structured information back to the caller.
|
||||
*/
|
||||
data?: Record<string, unknown>;
|
||||
|
||||
/**
|
||||
* Optional request to execute another tool immediately after this one.
|
||||
* The result of this tail call will replace the original tool's response.
|
||||
*/
|
||||
tailToolCallRequest?: {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user