refactor(core): Extract and integrate ToolExecutor (#15900)

This commit is contained in:
Abhi
2026-01-05 00:48:41 -05:00
committed by GitHub
parent 615b218ff7
commit b4b49e7029
4 changed files with 657 additions and 193 deletions

View File

@@ -5,7 +5,6 @@
*/
import {
type ToolResult,
type ToolResultDisplay,
type AnyDeclarativeTool,
type AnyToolInvocation,
@@ -15,13 +14,11 @@ import {
} from '../tools/tools.js';
import type { EditorType } from '../utils/editor.js';
import type { Config } from '../config/config.js';
import type { AnsiOutput } from '../utils/terminalSerializer.js';
import { ApprovalMode } from '../policy/types.js';
import { logToolCall, logToolOutputTruncated } from '../telemetry/loggers.js';
import { logToolCall } from '../telemetry/loggers.js';
import { ToolErrorType } from '../tools/tool-error.js';
import { ToolCallEvent, ToolOutputTruncatedEvent } from '../telemetry/types.js';
import { ToolCallEvent } from '../telemetry/types.js';
import { runInDevTraceSpan } from '../telemetry/trace.js';
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
import type { ModifyContext } from '../tools/modifiable-tool.js';
import {
isModifiableDeclarativeTool,
@@ -34,14 +31,10 @@ import {
getToolSuggestion,
} from '../utils/tool-utils.js';
import { isShellInvocationAllowlisted } from '../utils/shell-permissions.js';
import { ShellToolInvocation } from '../tools/shell.js';
import type { ToolConfirmationRequest } from '../confirmation-bus/types.js';
import { MessageBusType } from '../confirmation-bus/types.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import {
fireToolNotificationHook,
executeToolWithHooks,
} from './coreToolHookTriggers.js';
import { fireToolNotificationHook } from './coreToolHookTriggers.js';
import {
type ToolCall,
type ValidatingToolCall,
@@ -60,8 +53,7 @@ import {
type ToolCallRequestInfo,
type ToolCallResponseInfo,
} from '../scheduler/types.js';
import { saveTruncatedContent } from '../utils/fileUtils.js';
import { convertToFunctionResponse } from '../utils/generateContentResponseUtilities.js';
import { ToolExecutor } from '../scheduler/tool-executor.js';
export type {
ToolCall,
@@ -136,6 +128,7 @@ export class CoreToolScheduler {
}> = [];
private toolCallQueue: ToolCall[] = [];
private completedToolCallsForBatch: CompletedToolCall[] = [];
private toolExecutor: ToolExecutor;
constructor(options: CoreToolSchedulerOptions) {
this.config = options.config;
@@ -143,6 +136,7 @@ export class CoreToolScheduler {
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
this.onToolCallsUpdate = options.onToolCallsUpdate;
this.getPreferredEditor = options.getPreferredEditor;
this.toolExecutor = new ToolExecutor(this.config);
// Subscribe to message bus for ASK_USER policy decisions
// Use a static WeakMap to ensure we only subscribe ONCE per MessageBus instance
@@ -847,188 +841,48 @@ export class CoreToolScheduler {
for (const toolCall of callsToExecute) {
if (toolCall.status !== 'scheduled') continue;
const scheduledCall = toolCall;
const { callId, name: toolName } = scheduledCall.request;
const invocation = scheduledCall.invocation;
this.setStatusInternal(callId, 'executing', signal);
const liveOutputCallback =
scheduledCall.tool.canUpdateOutput && this.outputUpdateHandler
? (outputChunk: string | AnsiOutput) => {
if (this.outputUpdateHandler) {
this.outputUpdateHandler(callId, outputChunk);
}
this.toolCalls = this.toolCalls.map((tc) =>
tc.request.callId === callId && tc.status === 'executing'
? { ...tc, liveOutput: outputChunk }
: tc,
);
this.notifyToolCallsUpdate();
}
: undefined;
const shellExecutionConfig = this.config.getShellExecutionConfig();
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
await runInDevTraceSpan(
{
name: toolCall.tool.name,
attributes: { type: 'tool-call' },
},
async ({ metadata: spanMetadata }) => {
spanMetadata.input = {
request: toolCall.request,
};
// TODO: Refactor to remove special casing for ShellToolInvocation.
// Introduce a generic callbacks object for the execute method to handle
// things like `onPid` and `onLiveOutput`. This will make the scheduler
// agnostic to the invocation type.
let promise: Promise<ToolResult>;
if (invocation instanceof ShellToolInvocation) {
const setPidCallback = (pid: number) => {
this.toolCalls = this.toolCalls.map((tc) =>
tc.request.callId === callId && tc.status === 'executing'
? { ...tc, pid }
: tc,
);
this.notifyToolCallsUpdate();
};
promise = executeToolWithHooks(
invocation,
toolName,
signal,
messageBus,
hooksEnabled,
toolCall.tool,
liveOutputCallback,
shellExecutionConfig,
setPidCallback,
);
} else {
promise = executeToolWithHooks(
invocation,
toolName,
signal,
messageBus,
hooksEnabled,
toolCall.tool,
liveOutputCallback,
shellExecutionConfig,
);
}
try {
const toolResult: ToolResult = await promise;
spanMetadata.output = toolResult;
if (signal.aborted) {
this.setStatusInternal(
callId,
'cancelled',
signal,
'User cancelled tool execution.',
);
} else if (toolResult.error === undefined) {
let content = toolResult.llmContent;
let outputFile: string | undefined = undefined;
const contentLength =
typeof content === 'string' ? content.length : undefined;
if (
typeof content === 'string' &&
toolName === SHELL_TOOL_NAME &&
this.config.getEnableToolOutputTruncation() &&
this.config.getTruncateToolOutputThreshold() > 0 &&
this.config.getTruncateToolOutputLines() > 0
) {
const originalContentLength = content.length;
const threshold =
this.config.getTruncateToolOutputThreshold();
const lines = this.config.getTruncateToolOutputLines();
const truncatedResult = await saveTruncatedContent(
content,
callId,
this.config.storage.getProjectTempDir(),
threshold,
lines,
);
content = truncatedResult.content;
outputFile = truncatedResult.outputFile;
if (outputFile) {
logToolOutputTruncated(
this.config,
new ToolOutputTruncatedEvent(
scheduledCall.request.prompt_id,
{
toolName,
originalContentLength,
truncatedContentLength: content.length,
threshold,
lines,
},
),
);
}
}
const response = convertToFunctionResponse(
toolName,
callId,
content,
this.config.getActiveModel(),
);
const successResponse: ToolCallResponseInfo = {
callId,
responseParts: response,
resultDisplay: toolResult.returnDisplay,
error: undefined,
errorType: undefined,
outputFile,
contentLength,
};
this.setStatusInternal(
callId,
'success',
signal,
successResponse,
);
} else {
// It is a failure
const error = new Error(toolResult.error.message);
const errorResponse = createErrorResponse(
scheduledCall.request,
error,
toolResult.error.type,
);
this.setStatusInternal(callId, 'error', signal, errorResponse);
}
} catch (executionError: unknown) {
spanMetadata.error = executionError;
if (signal.aborted) {
this.setStatusInternal(
callId,
'cancelled',
signal,
'User cancelled tool execution.',
);
} else {
this.setStatusInternal(
callId,
'error',
signal,
createErrorResponse(
scheduledCall.request,
executionError instanceof Error
? executionError
: new Error(String(executionError)),
ToolErrorType.UNHANDLED_EXCEPTION,
),
);
}
}
await this.checkAndNotifyCompletion(signal);
},
this.setStatusInternal(toolCall.request.callId, 'executing', signal);
const executingCall = this.toolCalls.find(
(c) => c.request.callId === toolCall.request.callId,
);
if (!executingCall) {
// Should not happen, but safe guard
continue;
}
const completedCall = await this.toolExecutor.execute({
call: executingCall,
signal,
outputUpdateHandler: (callId, output) => {
if (this.outputUpdateHandler) {
this.outputUpdateHandler(callId, output);
}
this.toolCalls = this.toolCalls.map((tc) =>
tc.request.callId === callId && tc.status === 'executing'
? { ...tc, liveOutput: output }
: tc,
);
this.notifyToolCallsUpdate();
},
onUpdateToolCall: (updatedCall) => {
this.toolCalls = this.toolCalls.map((tc) =>
tc.request.callId === updatedCall.request.callId
? updatedCall
: tc,
);
this.notifyToolCallsUpdate();
},
});
this.toolCalls = this.toolCalls.map((tc) =>
tc.request.callId === completedCall.request.callId
? completedCall
: tc,
);
this.notifyToolCallsUpdate();
await this.checkAndNotifyCompletion(signal);
}
}
}

View File

@@ -36,6 +36,7 @@ export * from './core/turn.js';
export * from './core/geminiRequest.js';
export * from './core/coreToolScheduler.js';
export * from './scheduler/types.js';
export * from './scheduler/tool-executor.js';
export * from './core/nonInteractiveToolExecutor.js';
export * from './core/recordingContentGenerator.js';

View File

@@ -0,0 +1,299 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { ToolExecutor } from './tool-executor.js';
import type { Config } from '../index.js';
import type { ToolResult } from '../tools/tools.js';
import { makeFakeConfig } from '../test-utils/config.js';
import { MockTool } from '../test-utils/mock-tool.js';
import type { ScheduledToolCall } from './types.js';
import type { AnyToolInvocation } from '../index.js';
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
import * as fileUtils from '../utils/fileUtils.js';
import * as coreToolHookTriggers from '../core/coreToolHookTriggers.js';
import { ShellToolInvocation } from '../tools/shell.js';
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
// Stub out the truncation helper; individual tests decide what it returns.
vi.mock('../utils/fileUtils.js', () => ({
  saveTruncatedContent: vi.fn(),
}));

// Stub out the hook-aware runner so no real tool code is ever executed.
vi.mock('../core/coreToolHookTriggers.js', () => ({
  executeToolWithHooks: vi.fn(),
}));

describe('ToolExecutor', () => {
  let config: Config;
  let executor: ToolExecutor;

  // Typed handle on the mocked runner, shared by every test below.
  const hookRunner = vi.mocked(coreToolHookTriggers.executeToolWithHooks);

  /** Returns a signal that is never aborted. */
  const idleSignal = (): AbortSignal => new AbortController().signal;

  /**
   * Assembles a tool call in the `scheduled` state, stamped with the current
   * time, for the given request identifiers, tool and invocation.
   */
  const scheduledCallFor = (
    callId: string,
    name: string,
    promptId: string,
    tool: MockTool,
    invocation: AnyToolInvocation,
    args: Record<string, unknown> = {},
  ): ScheduledToolCall => ({
    status: 'scheduled',
    request: {
      callId,
      name,
      args,
      isClientInitiated: false,
      prompt_id: promptId,
    },
    tool,
    invocation,
    startTime: Date.now(),
  });

  beforeEach(() => {
    // Build the executor on top of the standard fake config.
    config = makeFakeConfig();
    executor = new ToolExecutor(config);

    // Start every test from a clean slate of mocks.
    vi.resetAllMocks();

    // By default, pretend the content was truncated and written to a file.
    vi.mocked(fileUtils.saveTruncatedContent).mockResolvedValue({
      content: 'TruncatedContent...',
      outputFile: '/tmp/truncated_output.txt',
    });
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  it('should execute a tool successfully', async () => {
    const tool = new MockTool({
      name: 'testTool',
      execute: async () => ({
        llmContent: 'Tool output',
        returnDisplay: 'Tool output',
      }),
    });
    const invocation = tool.build({}) as unknown as AnyToolInvocation;

    // The runner reports a plain successful result.
    hookRunner.mockResolvedValue({
      llmContent: 'Tool output',
      returnDisplay: 'Tool output',
    } as ToolResult);

    const call = scheduledCallFor('call-1', 'testTool', 'prompt-1', tool, invocation);
    const onUpdateToolCall = vi.fn();

    const result = await executor.execute({
      call,
      signal: idleSignal(),
      onUpdateToolCall,
    });

    expect(result.status).toBe('success');
    if (result.status === 'success') {
      const payload = result.response.responseParts[0]?.functionResponse
        ?.response as Record<string, unknown>;
      expect(payload).toEqual({ output: 'Tool output' });
    }
  });

  it('should handle execution errors', async () => {
    const tool = new MockTool({ name: 'failTool' });
    const invocation = tool.build({}) as unknown as AnyToolInvocation;

    // The runner rejects, as if the tool had thrown.
    hookRunner.mockRejectedValue(new Error('Tool Failed'));

    const call = scheduledCallFor('call-2', 'failTool', 'prompt-2', tool, invocation);

    const result = await executor.execute({
      call,
      signal: idleSignal(),
      onUpdateToolCall: vi.fn(),
    });

    expect(result.status).toBe('error');
    if (result.status === 'error') {
      expect(result.response.error?.message).toBe('Tool Failed');
    }
  });

  it('should return cancelled result when signal is aborted', async () => {
    const tool = new MockTool({ name: 'slowTool' });
    const invocation = tool.build({}) as unknown as AnyToolInvocation;

    // The runner takes a while, leaving room to abort before it settles.
    hookRunner.mockImplementation(async () => {
      await new Promise((resolve) => setTimeout(resolve, 100));
      return { llmContent: 'Done', returnDisplay: 'Done' };
    });

    const call = scheduledCallFor('call-3', 'slowTool', 'prompt-3', tool, invocation);
    const abortController = new AbortController();

    const pending = executor.execute({
      call,
      signal: abortController.signal,
      onUpdateToolCall: vi.fn(),
    });
    abortController.abort();

    const result = await pending;
    expect(result.status).toBe('cancelled');
  });

  it('should truncate large shell output', async () => {
    // 1. Turn truncation on with a tiny threshold and line budget.
    vi.spyOn(config, 'getEnableToolOutputTruncation').mockReturnValue(true);
    vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10);
    vi.spyOn(config, 'getTruncateToolOutputLines').mockReturnValue(5);

    const tool = new MockTool({ name: SHELL_TOOL_NAME });
    const invocation = tool.build({}) as unknown as AnyToolInvocation;
    const longOutput = 'This is a very long output that should be truncated.';

    // 2. The runner yields content longer than the threshold.
    hookRunner.mockResolvedValue({
      llmContent: longOutput,
      returnDisplay: longOutput,
    });

    const call = scheduledCallFor(
      'call-trunc',
      SHELL_TOOL_NAME,
      'prompt-trunc',
      tool,
      invocation,
      { command: 'echo long' },
    );

    // 3. Run it.
    const result = await executor.execute({
      call,
      signal: idleSignal(),
      onUpdateToolCall: vi.fn(),
    });

    // 4. The helper must have been handed the raw output and the limits.
    expect(fileUtils.saveTruncatedContent).toHaveBeenCalledWith(
      longOutput,
      'call-trunc',
      expect.any(String), // temp dir
      10, // threshold
      5, // lines
    );

    expect(result.status).toBe('success');
    if (result.status === 'success') {
      const payload = result.response.responseParts[0]?.functionResponse
        ?.response as Record<string, unknown>;
      // The response carries what the mocked helper returned, not the raw text.
      expect(payload).toEqual({ output: 'TruncatedContent...' });
      expect(result.response.outputFile).toBe('/tmp/truncated_output.txt');
    }
  });

  it('should report PID updates for shell tools', async () => {
    // 1. A genuine ShellToolInvocation is needed to take the shell code path.
    const messageBus = createMockMessageBus();
    const shellInvocation = new ShellToolInvocation(
      config,
      { command: 'sleep 10' },
      messageBus,
    );
    // The tool itself is only a structural placeholder for the call record.
    const tool = new MockTool({ name: SHELL_TOOL_NAME });

    // 2. The runner invokes the PID callback (its ninth argument) if given.
    const testPid = 12345;
    hookRunner.mockImplementation(async (...runnerArgs) => {
      const reportPid = runnerArgs[8];
      reportPid?.(testPid);
      return { llmContent: 'done', returnDisplay: 'done' };
    });

    const call = scheduledCallFor(
      'call-pid',
      SHELL_TOOL_NAME,
      'prompt-pid',
      tool,
      shellInvocation,
      { command: 'sleep 10' },
    );
    const onUpdateToolCall = vi.fn();

    // 3. Run it.
    await executor.execute({
      call,
      signal: idleSignal(),
      onUpdateToolCall,
    });

    // 4. The PID must have surfaced through the update callback.
    expect(onUpdateToolCall).toHaveBeenCalledWith(
      expect.objectContaining({
        status: 'executing',
        pid: testPid,
      }),
    );
  });
});

View File

@@ -0,0 +1,310 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
ToolCallRequestInfo,
ToolCallResponseInfo,
ToolResult,
Config,
AnsiOutput,
} from '../index.js';
import {
ToolErrorType,
ToolOutputTruncatedEvent,
logToolOutputTruncated,
runInDevTraceSpan,
} from '../index.js';
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
import { ShellToolInvocation } from '../tools/shell.js';
import { executeToolWithHooks } from '../core/coreToolHookTriggers.js';
import { saveTruncatedContent } from '../utils/fileUtils.js';
import { convertToFunctionResponse } from '../utils/generateContentResponseUtilities.js';
import type {
CompletedToolCall,
ToolCall,
ExecutingToolCall,
ErroredToolCall,
SuccessfulToolCall,
CancelledToolCall,
} from './types.js';
/**
 * Everything `ToolExecutor.execute` needs to run one tool call.
 */
export interface ToolExecutionContext {
/**
 * The tool call to run. `execute` throws unless it carries both a `tool`
 * and an `invocation`.
 */
call: ToolCall;
/**
 * Abort signal forwarded to the tool run. If it is aborted by the time the
 * run settles (whether it resolved or threw), the call is reported as
 * `cancelled`.
 */
signal: AbortSignal;
/**
 * Optional sink for streamed output. Receives the call id and each output
 * chunk, and is only wired up when the tool's `canUpdateOutput` is truthy.
 */
outputUpdateHandler?: (callId: string, output: string | AnsiOutput) => void;
/**
 * Receives an `executing` snapshot of the call (including `pid`) when a
 * `ShellToolInvocation` reports its process id.
 */
onUpdateToolCall: (updatedCall: ToolCall) => void;
}
/**
 * Runs a single tool call and converts whatever happens (success, tool-reported
 * error, thrown exception, or cancellation) into a completed tool-call record.
 *
 * The executor owns no scheduling state: progress is pushed to the caller via
 * the callbacks on {@link ToolExecutionContext}, and the final record is
 * returned from {@link ToolExecutor.execute}.
 */
export class ToolExecutor {
  constructor(private readonly config: Config) {}

  /**
   * Executes the call's invocation inside a dev trace span.
   *
   * @param context The call to run, its abort signal and progress callbacks.
   * @returns A `success`, `error` or `cancelled` record. Exceptions raised
   *   while running the tool or building the success record are reported as
   *   an `error` record (`ToolErrorType.UNHANDLED_EXCEPTION`), or as
   *   `cancelled` if the signal is aborted.
   * @throws Error if the call has no tool or no invocation attached.
   */
  async execute(context: ToolExecutionContext): Promise<CompletedToolCall> {
    const { call, signal, outputUpdateHandler, onUpdateToolCall } = context;
    const { request } = call;
    const toolName = request.name;
    const callId = request.callId;

    if (!('tool' in call) || !call.tool || !('invocation' in call)) {
      throw new Error(
        `Cannot execute tool call ${callId}: Tool or Invocation missing.`,
      );
    }
    const { tool, invocation } = call;

    // Most recent streamed chunk. `call` is a snapshot taken before execution
    // started, so this is the only place the executor can learn what has
    // already been streamed when it later rebuilds the call for a PID update.
    let latestLiveOutput: string | AnsiOutput | undefined;

    // Live output is forwarded only when the tool supports it and the caller
    // supplied a handler.
    const liveOutputCallback =
      tool.canUpdateOutput && outputUpdateHandler
        ? (outputChunk: string | AnsiOutput) => {
            latestLiveOutput = outputChunk;
            outputUpdateHandler(callId, outputChunk);
          }
        : undefined;

    const shellExecutionConfig = this.config.getShellExecutionConfig();
    const hooksEnabled = this.config.getEnableHooks();
    const messageBus = this.config.getMessageBus();

    return runInDevTraceSpan(
      {
        name: tool.name,
        attributes: { type: 'tool-call' },
      },
      async ({ metadata: spanMetadata }) => {
        spanMetadata.input = { request };

        try {
          let promise: Promise<ToolResult>;
          if (invocation instanceof ShellToolInvocation) {
            // Shell invocations additionally report the spawned process id.
            const setPidCallback = (pid: number) => {
              const executingCall: ExecutingToolCall = {
                ...call,
                status: 'executing',
                tool,
                invocation,
                pid,
                startTime: 'startTime' in call ? call.startTime : undefined,
                // Carry over output streamed so far so that a caller which
                // replaces its record with this one does not lose it.
                ...(latestLiveOutput !== undefined
                  ? { liveOutput: latestLiveOutput }
                  : {}),
              };
              onUpdateToolCall(executingCall);
            };
            promise = executeToolWithHooks(
              invocation,
              toolName,
              signal,
              messageBus,
              hooksEnabled,
              tool,
              liveOutputCallback,
              shellExecutionConfig,
              setPidCallback,
            );
          } else {
            promise = executeToolWithHooks(
              invocation,
              toolName,
              signal,
              messageBus,
              hooksEnabled,
              tool,
              liveOutputCallback,
              shellExecutionConfig,
            );
          }

          const toolResult: ToolResult = await promise;
          spanMetadata.output = toolResult;

          // Cancellation wins over whatever the tool returned.
          if (signal.aborted) {
            return this.createCancelledResult(
              call,
              'User cancelled tool execution.',
            );
          } else if (toolResult.error === undefined) {
            // Awaited here so a failure while building the success record is
            // handled by the catch block below.
            return await this.createSuccessResult(call, toolResult);
          } else {
            return this.createErrorResult(
              call,
              new Error(toolResult.error.message),
              toolResult.error.type,
            );
          }
        } catch (executionError: unknown) {
          spanMetadata.error = executionError;
          if (signal.aborted) {
            return this.createCancelledResult(
              call,
              'User cancelled tool execution.',
            );
          }
          const error =
            executionError instanceof Error
              ? executionError
              : new Error(String(executionError));
          return this.createErrorResult(
            call,
            error,
            ToolErrorType.UNHANDLED_EXCEPTION,
          );
        }
      },
    );
  }

  /**
   * Milliseconds elapsed since the call's `startTime`, or `undefined` when the
   * call carries no numeric start time. A start time of `0` is a valid
   * timestamp (e.g. under a mocked clock), so this deliberately avoids a
   * truthiness test.
   */
  private computeDurationMs(call: ToolCall): number | undefined {
    const startTime = 'startTime' in call ? call.startTime : undefined;
    return typeof startTime === 'number' ? Date.now() - startTime : undefined;
  }

  /**
   * Builds the `cancelled` record for a call, with a function response whose
   * error text is the reason prefixed by `[Operation Cancelled]`.
   *
   * @throws Error if the call has no tool/invocation properties.
   */
  private createCancelledResult(
    call: ToolCall,
    reason: string,
  ): CancelledToolCall {
    if (!('tool' in call) || !('invocation' in call)) {
      // This should effectively never happen in execution phase, but we handle
      // it safely
      throw new Error('Cancelled tool call missing tool/invocation references');
    }

    const errorMessage = `[Operation Cancelled] ${reason}`;

    return {
      status: 'cancelled',
      request: call.request,
      response: {
        callId: call.request.callId,
        responseParts: [
          {
            functionResponse: {
              id: call.request.callId,
              name: call.request.name,
              response: { error: errorMessage },
            },
          },
        ],
        resultDisplay: undefined,
        error: undefined,
        errorType: undefined,
        contentLength: errorMessage.length,
      },
      tool: call.tool,
      invocation: call.invocation,
      durationMs: this.computeDurationMs(call),
      outcome: call.outcome,
    };
  }

  /**
   * Builds the `success` record for a call.
   *
   * String content produced by the shell tool is passed through
   * `saveTruncatedContent` when output truncation is enabled and both the
   * threshold and line limits are positive; a telemetry event is logged when
   * that helper reports an output file.
   *
   * @throws Error if the call has no tool/invocation properties.
   */
  private async createSuccessResult(
    call: ToolCall,
    toolResult: ToolResult,
  ): Promise<SuccessfulToolCall> {
    // Checked up front so nothing is written or logged for an unusable call.
    if (!('tool' in call) || !('invocation' in call)) {
      throw new Error('Successful tool call missing tool or invocation');
    }

    let content = toolResult.llmContent;
    let outputFile: string | undefined;
    const toolName = call.request.name;
    const callId = call.request.callId;

    if (
      typeof content === 'string' &&
      toolName === SHELL_TOOL_NAME &&
      this.config.getEnableToolOutputTruncation() &&
      this.config.getTruncateToolOutputThreshold() > 0 &&
      this.config.getTruncateToolOutputLines() > 0
    ) {
      const originalContentLength = content.length;
      const threshold = this.config.getTruncateToolOutputThreshold();
      const lines = this.config.getTruncateToolOutputLines();
      const truncatedResult = await saveTruncatedContent(
        content,
        callId,
        this.config.storage.getProjectTempDir(),
        threshold,
        lines,
      );
      content = truncatedResult.content;
      outputFile = truncatedResult.outputFile;

      if (outputFile) {
        logToolOutputTruncated(
          this.config,
          new ToolOutputTruncatedEvent(call.request.prompt_id, {
            toolName,
            originalContentLength,
            truncatedContentLength: content.length,
            threshold,
            lines,
          }),
        );
      }
    }

    const response = convertToFunctionResponse(
      toolName,
      callId,
      content,
      this.config.getActiveModel(),
    );

    const successResponse: ToolCallResponseInfo = {
      callId,
      responseParts: response,
      resultDisplay: toolResult.returnDisplay,
      error: undefined,
      errorType: undefined,
      outputFile,
      // Measured on the content placed in the response, i.e. after any
      // truncation applied above.
      contentLength: typeof content === 'string' ? content.length : undefined,
    };

    return {
      status: 'success',
      request: call.request,
      tool: call.tool,
      response: successResponse,
      invocation: call.invocation,
      durationMs: this.computeDurationMs(call),
      outcome: call.outcome,
    };
  }

  /**
   * Builds the `error` record for a call from an error and optional type.
   */
  private createErrorResult(
    call: ToolCall,
    error: Error,
    errorType?: ToolErrorType,
  ): ErroredToolCall {
    const response = this.createErrorResponse(call.request, error, errorType);

    return {
      status: 'error',
      request: call.request,
      response,
      tool: call.tool,
      durationMs: this.computeDurationMs(call),
      outcome: call.outcome,
    };
  }

  /**
   * Builds the response payload for a failed call: the error message is used
   * as the function-response error text, the display text and the basis of
   * `contentLength`.
   */
  private createErrorResponse(
    request: ToolCallRequestInfo,
    error: Error,
    errorType: ToolErrorType | undefined,
  ): ToolCallResponseInfo {
    return {
      callId: request.callId,
      error,
      responseParts: [
        {
          functionResponse: {
            id: request.callId,
            name: request.name,
            response: { error: error.message },
          },
        },
      ],
      resultDisplay: error.message,
      errorType,
      contentLength: error.message.length,
    };
  }
}