diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index e8fd45ed2e..7ab94f5bf1 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -14,7 +14,6 @@ import type { UserFeedbackPayload, } from '@google/gemini-cli-core'; import { - executeToolCall, ToolErrorType, GeminiEventType, OutputFormat, @@ -48,6 +47,8 @@ const mockCoreEvents = vi.hoisted(() => ({ drainBacklogs: vi.fn(), })); +const mockSchedulerSchedule = vi.hoisted(() => vi.fn()); + vi.mock('@google/gemini-cli-core', async (importOriginal) => { const original = await importOriginal(); @@ -61,7 +62,10 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { return { ...original, - executeToolCall: vi.fn(), + Scheduler: class { + schedule = mockSchedulerSchedule; + cancelAll = vi.fn(); + }, isTelemetrySdkInitialized: vi.fn().mockReturnValue(true), ChatRecordingService: MockChatRecordingService, uiTelemetryService: { @@ -91,7 +95,6 @@ describe('runNonInteractive', () => { let mockConfig: Config; let mockSettings: LoadedSettings; let mockToolRegistry: ToolRegistry; - let mockCoreExecuteToolCall: Mock; let consoleErrorSpy: MockInstance; let processStdoutSpy: MockInstance; let processStderrSpy: MockInstance; @@ -122,7 +125,7 @@ describe('runNonInteractive', () => { }; beforeEach(async () => { - mockCoreExecuteToolCall = vi.mocked(executeToolCall); + mockSchedulerSchedule.mockReset(); mockCommandServiceCreate.mockResolvedValue({ getCommands: mockGetCommands, @@ -158,6 +161,11 @@ describe('runNonInteractive', () => { mockConfig = { initialize: vi.fn().mockResolvedValue(undefined), + getMessageBus: vi.fn().mockReturnValue({ + subscribe: vi.fn(), + unsubscribe: vi.fn(), + publish: vi.fn(), + }), getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), getMaxSessionTurns: vi.fn().mockReturnValue(10), @@ -263,25 +271,27 @@ describe('runNonInteractive', () => { }, }; const toolResponse: Part[] = [{ text: 'Tool response' }]; - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'success', - request: { - callId: 'tool-1', - name: 'testTool', - args: { arg1: 'value1' }, - isClientInitiated: false, - prompt_id: 'prompt-id-2', + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'success', + request: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-2', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: toolResponse, + callId: 'tool-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - responseParts: toolResponse, - callId: 'tool-1', - error: undefined, - errorType: undefined, - contentLength: undefined, - }, - }); + ]); const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent]; const secondCallEvents: ServerGeminiStreamEvent[] = [ @@ -304,9 +314,8 @@ describe('runNonInteractive', () => { }); expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); - expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( - mockConfig, - expect.objectContaining({ name: 'testTool' }), + expect(mockSchedulerSchedule).toHaveBeenCalledWith( + [expect.objectContaining({ name: 'testTool' })], expect.any(AbortSignal), ); expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith( @@ -335,16 +344,18 @@ describe('runNonInteractive', () => { }; // 2. Mock the execution of the tools. We just need them to succeed. - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'success', - request: toolCallEvent.value, // This is generic enough for both calls - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - responseParts: [], - callId: 'mock-tool', + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'success', + request: toolCallEvent.value, // This is generic enough for both calls + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: [], + callId: 'mock-tool', + }, }, - }); + ]); // 3. Define the sequence of events streamed from the mock model. // Turn 1: Model outputs text, then requests a tool call. @@ -385,7 +396,7 @@ describe('runNonInteractive', () => { expect(getWrittenOutput()).toMatchSnapshot(); // Also verify the tools were called as expected. - expect(mockCoreExecuteToolCall).toHaveBeenCalledTimes(2); + expect(mockSchedulerSchedule).toHaveBeenCalledTimes(2); }); it('should handle error during tool execution and should send error back to the model', async () => { @@ -399,34 +410,36 @@ describe('runNonInteractive', () => { prompt_id: 'prompt-id-3', }, }; - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'error', - request: { - callId: 'tool-1', - name: 'errorTool', - args: {}, - isClientInitiated: false, - prompt_id: 'prompt-id-3', - }, - tool: {} as AnyDeclarativeTool, - response: { - callId: 'tool-1', - error: new Error('Execution failed'), - errorType: ToolErrorType.EXECUTION_FAILED, - responseParts: [ - { - functionResponse: { - name: 'errorTool', - response: { - output: 'Error: Execution failed', + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'error', + request: { + callId: 'tool-1', + name: 'errorTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-3', + }, + tool: {} as AnyDeclarativeTool, + response: { + callId: 'tool-1', + error: new Error('Execution failed'), + errorType: ToolErrorType.EXECUTION_FAILED, + responseParts: [ + { + functionResponse: { + name: 'errorTool', + response: { + output: 'Error: Execution failed', + }, }, }, - }, - ], - resultDisplay: 'Execution failed', - contentLength: undefined, + ], + resultDisplay: 'Execution failed', + contentLength: undefined, + }, }, - }); + ]); const finalResponse: ServerGeminiStreamEvent[] = [ { type: GeminiEventType.Content, @@ -448,7 +461,7 @@ describe('runNonInteractive', () => { prompt_id: 'prompt-id-3', }); - expect(mockCoreExecuteToolCall).toHaveBeenCalled(); + expect(mockSchedulerSchedule).toHaveBeenCalled(); expect(consoleErrorSpy).toHaveBeenCalledWith( 'Error executing tool errorTool: Execution failed', ); @@ -498,24 +511,26 @@ describe('runNonInteractive', () => { prompt_id: 'prompt-id-5', }, }; - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'error', - request: { - callId: 'tool-1', - name: 'nonexistentTool', - args: {}, - isClientInitiated: false, - prompt_id: 'prompt-id-5', + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'error', + request: { + callId: 'tool-1', + name: 'nonexistentTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-5', + }, + response: { + callId: 'tool-1', + error: new Error('Tool "nonexistentTool" not found in registry.'), + resultDisplay: 'Tool "nonexistentTool" not found in registry.', + responseParts: [], + errorType: undefined, + contentLength: undefined, + }, }, - response: { - callId: 'tool-1', - error: new Error('Tool "nonexistentTool" not found in registry.'), - resultDisplay: 'Tool "nonexistentTool" not found in registry.', - responseParts: [], - errorType: undefined, - contentLength: undefined, - }, - }); + ]); const finalResponse: ServerGeminiStreamEvent[] = [ { type: GeminiEventType.Content, @@ -538,7 +553,7 @@ describe('runNonInteractive', () => { prompt_id: 'prompt-id-5', }); - expect(mockCoreExecuteToolCall).toHaveBeenCalled(); + expect(mockSchedulerSchedule).toHaveBeenCalled(); expect(consoleErrorSpy).toHaveBeenCalledWith( 'Error executing tool nonexistentTool: Tool "nonexistentTool" not found in registry.', ); @@ -665,25 +680,27 @@ describe('runNonInteractive', () => { }, }; const toolResponse: Part[] = [{ text: 'Tool executed successfully' }]; - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'success', - request: { - callId: 'tool-1', - name: 'testTool', - args: { arg1: 'value1' }, - isClientInitiated: false, - prompt_id: 'prompt-id-tool-only', + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'success', + request: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-tool-only', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: toolResponse, + callId: 'tool-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - responseParts: toolResponse, - callId: 'tool-1', - error: undefined, - errorType: undefined, - contentLength: undefined, - }, - }); + ]); // First call returns only tool call, no content const firstCallEvents: ServerGeminiStreamEvent[] = [ @@ -719,9 +736,8 @@ describe('runNonInteractive', () => { }); expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); - expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( - mockConfig, - expect.objectContaining({ name: 'testTool' }), + expect(mockSchedulerSchedule).toHaveBeenCalledWith( + [expect.objectContaining({ name: 'testTool' })], expect.any(AbortSignal), ); @@ -1248,25 +1264,27 @@ describe('runNonInteractive', () => { }, }; const toolResponse: Part[] = [{ text: 'file.txt' }]; - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'success', - request: { - callId: 'tool-shell-1', - name: 'ShellTool', - args: { command: 'ls' }, - isClientInitiated: false, - prompt_id: 'prompt-id-allowed', + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'success', + request: { + callId: 'tool-shell-1', + name: 'ShellTool', + args: { command: 'ls' }, + isClientInitiated: false, + prompt_id: 'prompt-id-allowed', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: toolResponse, + callId: 'tool-shell-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - responseParts: toolResponse, - callId: 'tool-shell-1', - error: undefined, - errorType: undefined, - contentLength: undefined, - }, - }); + ]); const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent]; const secondCallEvents: ServerGeminiStreamEvent[] = [ @@ -1288,9 +1306,8 @@ describe('runNonInteractive', () => { prompt_id: 'prompt-id-allowed', }); - expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( - mockConfig, - expect.objectContaining({ name: 'ShellTool' }), + expect(mockSchedulerSchedule).toHaveBeenCalledWith( + [expect.objectContaining({ name: 'ShellTool' })], expect.any(AbortSignal), ); expect(getWrittenOutput()).toBe('file.txt\n'); @@ -1446,20 +1463,22 @@ describe('runNonInteractive', () => { }, }; - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'success', - request: toolCallEvent.value, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - responseParts: [{ text: 'Tool response' }], - callId: 'tool-1', - error: undefined, - errorType: undefined, - contentLength: undefined, - resultDisplay: 'Tool executed successfully', + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'success', + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: [{ text: 'Tool response' }], + callId: 'tool-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + resultDisplay: 'Tool executed successfully', + }, }, - }); + ]); const firstCallEvents: ServerGeminiStreamEvent[] = [ { type: GeminiEventType.Content, value: 'Thinking...' }, @@ -1636,19 +1655,21 @@ describe('runNonInteractive', () => { prompt_id: 'prompt-id-tool-error', }, }; - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'success', - request: toolCallEvent.value, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - responseParts: [], - callId: 'tool-1', - error: undefined, - errorType: undefined, - contentLength: undefined, + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'success', + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: [], + callId: 'tool-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - }); + ]); const events: ServerGeminiStreamEvent[] = [ toolCallEvent, @@ -1717,19 +1738,21 @@ describe('runNonInteractive', () => { }; // Mock tool execution returning STOP_EXECUTION - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'error', - request: toolCallEvent.value, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: 'stop-call', - responseParts: [{ text: 'error occurred' }], - errorType: ToolErrorType.STOP_EXECUTION, - error: new Error('Stop reason from hook'), - resultDisplay: undefined, + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'error', + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'stop-call', + responseParts: [{ text: 'error occurred' }], + errorType: ToolErrorType.STOP_EXECUTION, + error: new Error('Stop reason from hook'), + resultDisplay: undefined, + }, }, - }); + ]); const firstCallEvents: ServerGeminiStreamEvent[] = [ { type: GeminiEventType.Content, value: 'Executing tool...' }, @@ -1750,7 +1773,7 @@ describe('runNonInteractive', () => { prompt_id: 'prompt-id-stop', }); - expect(mockCoreExecuteToolCall).toHaveBeenCalled(); + expect(mockSchedulerSchedule).toHaveBeenCalled(); // The key assertion: sendMessageStream should have been called ONLY ONCE (initial user input). expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(1); @@ -1777,19 +1800,21 @@ describe('runNonInteractive', () => { }, }; - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'error', - request: toolCallEvent.value, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: 'stop-call', - responseParts: [{ text: 'error occurred' }], - errorType: ToolErrorType.STOP_EXECUTION, - error: new Error('Stop reason'), - resultDisplay: undefined, + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'error', + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'stop-call', + responseParts: [{ text: 'error occurred' }], + errorType: ToolErrorType.STOP_EXECUTION, + error: new Error('Stop reason'), + resultDisplay: undefined, + }, }, - }); + ]); const firstCallEvents: ServerGeminiStreamEvent[] = [ { type: GeminiEventType.Content, value: 'Partial content' }, @@ -1839,19 +1864,21 @@ describe('runNonInteractive', () => { }, }; - mockCoreExecuteToolCall.mockResolvedValue({ - status: 'error', - request: toolCallEvent.value, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: 'stop-call', - responseParts: [{ text: 'error occurred' }], - errorType: ToolErrorType.STOP_EXECUTION, - error: new Error('Stop reason'), - resultDisplay: undefined, + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'error', + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'stop-call', + responseParts: [{ text: 'error occurred' }], + errorType: ToolErrorType.STOP_EXECUTION, + error: new Error('Stop reason'), + resultDisplay: undefined, + }, }, - }); + ]); const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent]; @@ -2066,5 +2093,63 @@ describe('runNonInteractive', () => { expect.stringContaining('[WARNING] --raw-output is enabled'), ); }); + + it('should report cancelled tool calls as success in stream-json mode (legacy parity)', async () => { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-cancel', + }, + }; + + // Mock the scheduler to return a cancelled status + mockSchedulerSchedule.mockResolvedValue([ + { + status: 'cancelled', + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'tool-1', + responseParts: [{ text: 'Operation cancelled' }], + resultDisplay: 'Cancelled', + }, + }, + ]); + + const events: ServerGeminiStreamEvent[] = [ + toolCallEvent, + { + type: GeminiEventType.Content, + value: 'Model continues...', + }, + ]; + + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + vi.mocked(mockConfig.getOutputFormat).mockReturnValue( + OutputFormat.STREAM_JSON, + ); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-cancel', + }); + + const output = getWrittenOutput(); + expect(output).toContain('"type":"tool_result"'); + expect(output).toContain('"status":"success"'); + }); }); }); diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 50ba2235c4..17d2537624 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -8,13 +8,11 @@ import type { Config, ToolCallRequestInfo, ResumedSessionData, - CompletedToolCall, UserFeedbackPayload, } from '@google/gemini-cli-core'; import { isSlashCommand } from './ui/utils/commandUtils.js'; import type { LoadedSettings } from './config/settings.js'; import { - executeToolCall, GeminiEventType, FatalInputError, promptIdContext, @@ -29,6 +27,8 @@ import { createWorkingStdio, recordToolCallInteractions, ToolErrorType, + Scheduler, + ROOT_SCHEDULER_ID, } from '@google/gemini-cli-core'; import type { Content, Part } from '@google/genai'; @@ -202,6 +202,12 @@ export async function runNonInteractive({ }); const geminiClient = config.getGeminiClient(); + const scheduler = new Scheduler({ + config, + messageBus: config.getMessageBus(), + getPreferredEditor: () => undefined, + schedulerId: ROOT_SCHEDULER_ID, + }); // Initialize chat. Resume if resume data is passed. if (resumedSessionData) { @@ -375,25 +381,23 @@ export async function runNonInteractive({ if (toolCallRequests.length > 0) { textOutput.ensureTrailingNewline(); + const completedToolCalls = await scheduler.schedule( + toolCallRequests, + abortController.signal, + ); const toolResponseParts: Part[] = []; - const completedToolCalls: CompletedToolCall[] = []; - for (const requestInfo of toolCallRequests) { - const completedToolCall = await executeToolCall( - config, - requestInfo, - abortController.signal, - ); + for (const completedToolCall of completedToolCalls) { const toolResponse = completedToolCall.response; - - completedToolCalls.push(completedToolCall); + const requestInfo = completedToolCall.request; if (streamFormatter) { streamFormatter.emitEvent({ type: JsonStreamEventType.TOOL_RESULT, timestamp: new Date().toISOString(), tool_id: requestInfo.callId, - status: toolResponse.error ? 'error' : 'success', + status: + completedToolCall.status === 'error' ? 'error' : 'success', output: typeof toolResponse.resultDisplay === 'string' ? toolResponse.resultDisplay diff --git a/packages/core/src/scheduler/scheduler.test.ts b/packages/core/src/scheduler/scheduler.test.ts index 45884f1de0..3456b5ec26 100644 --- a/packages/core/src/scheduler/scheduler.test.ts +++ b/packages/core/src/scheduler/scheduler.test.ts @@ -35,7 +35,10 @@ vi.mock('../telemetry/types.js', () => ({ ToolCallEvent: vi.fn().mockImplementation((call) => ({ ...call })), })); -import { SchedulerStateManager } from './state-manager.js'; +import { + SchedulerStateManager, + type TerminalCallHandler, +} from './state-manager.js'; import { resolveConfirmation } from './confirmation.js'; import { checkPolicy, updatePolicy } from './policy.js'; import { ToolExecutor } from './tool-executor.js'; @@ -64,6 +67,7 @@ import type { SuccessfulToolCall, ErroredToolCall, CancelledToolCall, + CompletedToolCall, ToolCallResponseInfo, } from './types.js'; import { ROOT_SCHEDULER_ID } from './types.js'; @@ -201,10 +205,27 @@ describe('Scheduler (Orchestrator)', () => { applyInlineModify: vi.fn(), } as unknown as Mocked; - // Wire up class constructors to return our mock instances - vi.mocked(SchedulerStateManager).mockReturnValue( - mockStateManager as unknown as Mocked, + let capturedTerminalHandler: TerminalCallHandler | undefined; + vi.mocked(SchedulerStateManager).mockImplementation( + (_messageBus, _schedulerId, onTerminalCall) => { + capturedTerminalHandler = onTerminalCall; + return mockStateManager as unknown as SchedulerStateManager; + }, ); + + mockStateManager.finalizeCall.mockImplementation((callId: string) => { + const call = mockStateManager.getToolCall(callId); + if (call) { + capturedTerminalHandler?.(call as CompletedToolCall); + } + }); + + mockStateManager.cancelAllQueued.mockImplementation((_reason: string) => { + // In tests, we usually mock the queue or completed batch. + // For the sake of telemetry tests, we manually trigger if needed, + // but most tests here check if finalizing is called. + }); + vi.mocked(ToolExecutor).mockReturnValue( mockExecutor as unknown as Mocked, ); diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index 5853736a01..fc159633d9 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -101,7 +101,11 @@ export class Scheduler { this.getPreferredEditor = options.getPreferredEditor; this.schedulerId = options.schedulerId; this.parentCallId = options.parentCallId; - this.state = new SchedulerStateManager(this.messageBus, this.schedulerId); + this.state = new SchedulerStateManager( + this.messageBus, + this.schedulerId, + (call) => logToolCall(this.config, new ToolCallEvent(call)), + ); this.executor = new ToolExecutor(this.config); this.modifier = new ToolModificationHandler(); @@ -388,16 +392,6 @@ export class Scheduler { } } - // Fetch the updated call from state before finalizing to capture the - // terminal status. - const terminalCall = this.state.getToolCall(active.request.callId); - if (terminalCall && this.isTerminal(terminalCall.status)) { - logToolCall( - this.config, - new ToolCallEvent(terminalCall as CompletedToolCall), - ); - } - this.state.finalizeCall(active.request.callId); } @@ -422,6 +416,7 @@ export class Scheduler { ToolErrorType.POLICY_VIOLATION, ), ); + this.state.finalizeCall(callId); return; } @@ -453,6 +448,7 @@ export class Scheduler { // Handle cancellation (cascades to entire batch) if (outcome === ToolConfirmationOutcome.Cancel) { this.state.updateStatus(callId, 'cancelled', 'User denied execution.'); + this.state.finalizeCall(callId); this.state.cancelAllQueued('User cancelled operation'); return; // Skip execution } diff --git a/packages/core/src/scheduler/state-manager.test.ts b/packages/core/src/scheduler/state-manager.test.ts index 3a6d535d9b..d0369fdcb1 100644 --- a/packages/core/src/scheduler/state-manager.test.ts +++ b/packages/core/src/scheduler/state-manager.test.ts @@ -23,6 +23,7 @@ import { } from '../tools/tools.js'; import { MessageBusType } from '../confirmation-bus/types.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { ROOT_SCHEDULER_ID } from './types.js'; describe('SchedulerStateManager', () => { const mockRequest: ToolCallRequestInfo = { @@ -83,6 +84,60 @@ describe('SchedulerStateManager', () => { stateManager = new SchedulerStateManager(mockMessageBus); }); + describe('Observer Callback', () => { + it('should trigger onTerminalCall when finalizing a call', () => { + const onTerminalCall = vi.fn(); + const manager = new SchedulerStateManager( + mockMessageBus, + ROOT_SCHEDULER_ID, + onTerminalCall, + ); + const call = createValidatingCall(); + manager.enqueue([call]); + manager.dequeue(); + manager.updateStatus( + call.request.callId, + 'success', + createMockResponse(call.request.callId), + ); + manager.finalizeCall(call.request.callId); + + expect(onTerminalCall).toHaveBeenCalledTimes(1); + expect(onTerminalCall).toHaveBeenCalledWith( + expect.objectContaining({ + status: 'success', + request: expect.objectContaining({ callId: call.request.callId }), + }), + ); + }); + + it('should trigger onTerminalCall for every call in cancelAllQueued', () => { + const onTerminalCall = vi.fn(); + const manager = new SchedulerStateManager( + mockMessageBus, + ROOT_SCHEDULER_ID, + onTerminalCall, + ); + manager.enqueue([createValidatingCall('1'), createValidatingCall('2')]); + + manager.cancelAllQueued('Test cancel'); + + expect(onTerminalCall).toHaveBeenCalledTimes(2); + expect(onTerminalCall).toHaveBeenCalledWith( + expect.objectContaining({ + status: 'cancelled', + request: expect.objectContaining({ callId: '1' }), + }), + ); + expect(onTerminalCall).toHaveBeenCalledWith( + expect.objectContaining({ + status: 'cancelled', + request: expect.objectContaining({ callId: '2' }), + }), + ); + }); + }); + describe('Initialization', () => { it('should start with empty state', () => { expect(stateManager.isActive).toBe(false); diff --git a/packages/core/src/scheduler/state-manager.ts b/packages/core/src/scheduler/state-manager.ts index 519bdb3ee3..625d58a463 100644 --- a/packages/core/src/scheduler/state-manager.ts +++ b/packages/core/src/scheduler/state-manager.ts @@ -31,6 +31,11 @@ import { type SerializableConfirmationDetails, } from '../confirmation-bus/types.js'; +/** + * Handler for terminal tool calls. + */ +export type TerminalCallHandler = (call: CompletedToolCall) => void; + /** * Manages the state of tool calls. * Publishes state changes to the MessageBus via TOOL_CALLS_UPDATE events. @@ -43,6 +48,7 @@ export class SchedulerStateManager { constructor( private readonly messageBus: MessageBus, private readonly schedulerId: string = ROOT_SCHEDULER_ID, + private readonly onTerminalCall?: TerminalCallHandler, ) {} addToolCalls(calls: ToolCall[]): void { @@ -134,6 +140,8 @@ export class SchedulerStateManager { if (this.isTerminalCall(call)) { this._completedBatch.push(call); this.activeCalls.delete(callId); + + this.onTerminalCall?.(call); this.emitUpdate(); } } @@ -173,9 +181,12 @@ export class SchedulerStateManager { const queuedCall = this.queue.shift()!; if (queuedCall.status === 'error') { this._completedBatch.push(queuedCall); + this.onTerminalCall?.(queuedCall); continue; } - this._completedBatch.push(this.toCancelled(queuedCall, reason)); + const cancelledCall = this.toCancelled(queuedCall, reason); + this._completedBatch.push(cancelledCall); + this.onTerminalCall?.(cancelledCall); } this.emitUpdate(); }