From 541eeb7a50254b9bca0545992906a313d49d00c0 Mon Sep 17 00:00:00 2001 From: joshualitt Date: Mon, 27 Oct 2025 09:59:08 -0700 Subject: [PATCH] feat(core, cli): Implement sequential approval. (#11593) --- packages/a2a-server/src/agent/task.test.ts | 121 +++++- packages/a2a-server/src/agent/task.ts | 19 +- packages/a2a-server/src/http/app.test.ts | 219 +++++++++-- .../cli/src/ui/hooks/useGeminiStream.test.tsx | 190 +++++++--- packages/cli/src/ui/hooks/useGeminiStream.ts | 165 ++++++--- .../cli/src/ui/hooks/useReactToolScheduler.ts | 69 ++-- .../cli/src/ui/hooks/useToolScheduler.test.ts | 190 +++++++--- .../core/src/core/coreToolScheduler.test.ts | 290 ++++++++++++++- packages/core/src/core/coreToolScheduler.ts | 348 ++++++++++++------ 9 files changed, 1272 insertions(+), 339 deletions(-) diff --git a/packages/a2a-server/src/agent/task.test.ts b/packages/a2a-server/src/agent/task.test.ts index 513867f4e2..1bf26d8bc8 100644 --- a/packages/a2a-server/src/agent/task.test.ts +++ b/packages/a2a-server/src/agent/task.test.ts @@ -4,11 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi } from 'vitest'; +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { Task } from './task.js'; import type { Config, ToolCallRequestInfo } from '@google/gemini-cli-core'; import { createMockConfig } from '../utils/testing_utils.js'; import type { ExecutionEventBus } from '@a2a-js/sdk/server'; +import type { ToolCall } from '@google/gemini-cli-core'; describe('Task', () => { it('scheduleToolCalls should not modify the input requests array', async () => { @@ -94,4 +95,122 @@ describe('Task', () => { ); }); }); + + describe('_schedulerToolCallsUpdate', () => { + let task: Task; + type SpyInstance = ReturnType; + let setTaskStateAndPublishUpdateSpy: SpyInstance; + + beforeEach(() => { + const mockConfig = createMockConfig(); + const mockEventBus: ExecutionEventBus = { + publish: vi.fn(), + on: vi.fn(), + off: vi.fn(), + once: vi.fn(), 
+ removeAllListeners: vi.fn(), + finished: vi.fn(), + }; + + // @ts-expect-error - Calling private constructor + task = new Task( + 'task-id', + 'context-id', + mockConfig as Config, + mockEventBus, + ); + + // Spy on the method we want to check calls for + setTaskStateAndPublishUpdateSpy = vi.spyOn( + task, + 'setTaskStateAndPublishUpdate', + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should set state to input-required when a tool is awaiting approval and none are executing', () => { + const toolCalls = [ + { request: { callId: '1' }, status: 'awaiting_approval' }, + ] as ToolCall[]; + + // @ts-expect-error - Calling private method + task._schedulerToolCallsUpdate(toolCalls); + + // The last call should be the final state update + expect(setTaskStateAndPublishUpdateSpy).toHaveBeenLastCalledWith( + 'input-required', + { kind: 'state-change' }, + undefined, + undefined, + true, // final: true + ); + }); + + it('should NOT set state to input-required if a tool is awaiting approval but another is executing', () => { + const toolCalls = [ + { request: { callId: '1' }, status: 'awaiting_approval' }, + { request: { callId: '2' }, status: 'executing' }, + ] as ToolCall[]; + + // @ts-expect-error - Calling private method + task._schedulerToolCallsUpdate(toolCalls); + + // It will be called for status updates, but not with final: true + const finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find( + (call) => call[4] === true, + ); + expect(finalCall).toBeUndefined(); + }); + + it('should set state to input-required once an executing tool finishes, leaving one awaiting approval', () => { + const initialToolCalls = [ + { request: { callId: '1' }, status: 'awaiting_approval' }, + { request: { callId: '2' }, status: 'executing' }, + ] as ToolCall[]; + // @ts-expect-error - Calling private method + task._schedulerToolCallsUpdate(initialToolCalls); + + // No final call yet + let finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find( + (call) 
=> call[4] === true, + ); + expect(finalCall).toBeUndefined(); + + // Now, the executing tool finishes. The scheduler would call _resolveToolCall for it. + // @ts-expect-error - Calling private method + task._resolveToolCall('2'); + + // Then another update comes in for the awaiting tool (e.g., a re-check) + const subsequentToolCalls = [ + { request: { callId: '1' }, status: 'awaiting_approval' }, + ] as ToolCall[]; + // @ts-expect-error - Calling private method + task._schedulerToolCallsUpdate(subsequentToolCalls); + + // NOW we should get the final call + finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find( + (call) => call[4] === true, + ); + expect(finalCall).toBeDefined(); + expect(finalCall?.[0]).toBe('input-required'); + }); + + it('should NOT set state to input-required if skipFinalTrueAfterInlineEdit is true', () => { + task.skipFinalTrueAfterInlineEdit = true; + const toolCalls = [ + { request: { callId: '1' }, status: 'awaiting_approval' }, + ] as ToolCall[]; + + // @ts-expect-error - Calling private method + task._schedulerToolCallsUpdate(toolCalls); + + const finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find( + (call) => call[4] === true, + ); + expect(finalCall).toBeUndefined(); + }); + }); }); diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index a7b0e288c9..eee5e736d6 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -40,7 +40,6 @@ import type { import { v4 as uuidv4 } from 'uuid'; import { logger } from '../utils/logger.js'; import * as fs from 'node:fs'; - import { CoderAgentEvent } from '../types.js'; import type { CoderAgentMessage, @@ -373,11 +372,11 @@ export class Task { // Only send an update if the status has actually changed. if (hasChanged) { - const message = this.toolStatusMessage(tc, this.id, this.contextId); const coderAgentMessage: CoderAgentMessage = tc.status === 'awaiting_approval' ? 
{ kind: CoderAgentEvent.ToolCallConfirmationEvent } : { kind: CoderAgentEvent.ToolCallUpdateEvent }; + const message = this.toolStatusMessage(tc, this.id, this.contextId); const event = this._createStatusUpdateEvent( this.taskState, @@ -404,20 +403,16 @@ export class Task { const isAwaitingApproval = allPendingStatuses.some( (status) => status === 'awaiting_approval', ); - const allPendingAreStable = allPendingStatuses.every( - (status) => - status === 'awaiting_approval' || - status === 'success' || - status === 'error' || - status === 'cancelled', + const isExecuting = allPendingStatuses.some( + (status) => status === 'executing', ); - // 1. Are any pending tool calls awaiting_approval - // 2. Are all pending tool calls in a stable state (i.e. not in validing or executing) - // 3. After an inline edit, the edited tool call will send awaiting_approval THEN scheduled. We wait for the next update in this case. + // The turn is complete and requires user input if at least one tool + // is waiting for the user's decision, and no other tool is actively + // running in the background. 
if ( isAwaitingApproval && - allPendingAreStable && + !isExecuting && !this.skipFinalTrueAfterInlineEdit ) { this.skipFinalTrueAfterInlineEdit = false; diff --git a/packages/a2a-server/src/http/app.test.ts b/packages/a2a-server/src/http/app.test.ts index 70d90f78cb..15b386bd3d 100644 --- a/packages/a2a-server/src/http/app.test.ts +++ b/packages/a2a-server/src/http/app.test.ts @@ -313,7 +313,7 @@ describe('E2E Tests', () => { expect(workingEvent.kind).toBe('status-update'); expect(workingEvent.status.state).toBe('working'); - // State Update: Validate each tool call + // State Update: Validate the first tool call const toolCallValidateEvent1 = events[3].result as TaskStatusUpdateEvent; expect(toolCallValidateEvent1.metadata?.['coderAgent']).toMatchObject({ kind: 'tool-call-update', @@ -326,47 +326,218 @@ describe('E2E Tests', () => { }, }, ]); - const toolCallValidateEvent2 = events[4].result as TaskStatusUpdateEvent; - expect(toolCallValidateEvent2.metadata?.['coderAgent']).toMatchObject({ + + // --- Assert the event stream --- + // 1. Initial "submitted" status. + expect((events[0].result as TaskStatusUpdateEvent).status.state).toBe( + 'submitted', + ); + + // 2. "working" status after receiving the user prompt. + expect((events[1].result as TaskStatusUpdateEvent).status.state).toBe( + 'working', + ); + + // 3. A "state-change" event from the agent. + expect(events[2].result.metadata?.['coderAgent']).toMatchObject({ + kind: 'state-change', + }); + + // 4. Tool 1 is validating. 
+ const toolCallUpdate1 = events[3].result as TaskStatusUpdateEvent; + expect(toolCallUpdate1.metadata?.['coderAgent']).toMatchObject({ kind: 'tool-call-update', }); - expect(toolCallValidateEvent2.status.message?.parts).toMatchObject([ + expect(toolCallUpdate1.status.message?.parts).toMatchObject([ { data: { + request: { callId: 'test-call-id-1' }, status: 'validating', - request: { callId: 'test-call-id-2' }, }, }, ]); - // State Update: Set each tool call to awaiting - const toolCallAwaitEvent1 = events[5].result as TaskStatusUpdateEvent; - expect(toolCallAwaitEvent1.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-confirmation', + // 5. Tool 2 is validating. + const toolCallUpdate2 = events[4].result as TaskStatusUpdateEvent; + expect(toolCallUpdate2.metadata?.['coderAgent']).toMatchObject({ + kind: 'tool-call-update', }); - expect(toolCallAwaitEvent1.status.message?.parts).toMatchObject([ + expect(toolCallUpdate2.status.message?.parts).toMatchObject([ { data: { - status: 'awaiting_approval', - request: { callId: 'test-call-id-1' }, - }, - }, - ]); - const toolCallAwaitEvent2 = events[6].result as TaskStatusUpdateEvent; - expect(toolCallAwaitEvent2.metadata?.['coderAgent']).toMatchObject({ - kind: 'tool-call-confirmation', - }); - expect(toolCallAwaitEvent2.status.message?.parts).toMatchObject([ - { - data: { - status: 'awaiting_approval', request: { callId: 'test-call-id-2' }, + status: 'validating', }, }, ]); + // 6. Tool 1 is awaiting approval. + const toolCallAwaitEvent = events[5].result as TaskStatusUpdateEvent; + expect(toolCallAwaitEvent.metadata?.['coderAgent']).toMatchObject({ + kind: 'tool-call-confirmation', + }); + expect(toolCallAwaitEvent.status.message?.parts).toMatchObject([ + { + data: { + request: { callId: 'test-call-id-1' }, + status: 'awaiting_approval', + }, + }, + ]); + + // 7. The final event is "input-required". 
+ const finalEvent = events[6].result as TaskStatusUpdateEvent; + expect(finalEvent.final).toBe(true); + expect(finalEvent.status.state).toBe('input-required'); + + // The scheduler now waits for approval, so no more events are sent. + assertUniqueFinalEventIsLast(events); + expect(events.length).toBe(7); + }); + + it('should handle multiple tool calls sequentially in YOLO mode', async () => { + // Set YOLO mode to auto-approve tools and test sequential execution. + getApprovalModeSpy.mockReturnValue(ApprovalMode.YOLO); + + // First call yields both tool requests + sendMessageStreamSpy.mockImplementationOnce(async function* () { + yield* [ + { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'test-call-id-1', + name: 'test-tool-1', + args: {}, + }, + }, + { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'test-call-id-2', + name: 'test-tool-2', + args: {}, + }, + }, + ]; + }); + // Subsequent calls yield only a final content event, as the tools will "succeed". + sendMessageStreamSpy.mockImplementation(async function* () { + yield* [{ type: 'content', value: 'All tools executed.'
}]; + }); + + const mockTool1 = new MockTool({ + name: 'test-tool-1', + displayName: 'Test Tool 1', + shouldConfirmExecute: vi.fn(mockToolConfirmationFn), + execute: vi + .fn() + .mockResolvedValue({ llmContent: 'tool 1 done', returnDisplay: '' }), + }); + const mockTool2 = new MockTool({ + name: 'test-tool-2', + displayName: 'Test Tool 2', + shouldConfirmExecute: vi.fn(mockToolConfirmationFn), + execute: vi + .fn() + .mockResolvedValue({ llmContent: 'tool 2 done', returnDisplay: '' }), + }); + + getToolRegistrySpy.mockReturnValue({ + getAllTools: vi.fn().mockReturnValue([mockTool1, mockTool2]), + getToolsByServer: vi.fn().mockReturnValue([]), + getTool: vi.fn().mockImplementation((name: string) => { + if (name === 'test-tool-1') return mockTool1; + if (name === 'test-tool-2') return mockTool2; + return undefined; + }), + }); + + const agent = request.agent(app); + const res = await agent + .post('/') + .send( + createStreamMessageRequest( + 'run two tools', + 'a2a-multi-tool-test-message', + ), + ) + .set('Content-Type', 'application/json') + .expect(200); + + const events = streamToSSEEvents(res.text); + assertTaskCreationAndWorkingStatus(events); + + // --- Assert the sequential execution flow --- + const eventStream = events.slice(2).map((e) => { + const update = e.result as TaskStatusUpdateEvent; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const agentData = update.metadata?.['coderAgent'] as any; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const toolData = update.status.message?.parts[0] as any; + if (!toolData) { + return { kind: agentData.kind }; + } + return { + kind: agentData.kind, + status: toolData.data?.status, + callId: toolData.data?.request.callId, + }; + }); + + const expectedFlow = [ + // Initial state change + { kind: 'state-change', status: undefined, callId: undefined }, + // Tool 1 Lifecycle + { + kind: 'tool-call-update', + status: 'validating', + callId: 'test-call-id-1', + }, + { + kind: 
'tool-call-update', + status: 'scheduled', + callId: 'test-call-id-1', + }, + { + kind: 'tool-call-update', + status: 'executing', + callId: 'test-call-id-1', + }, + { + kind: 'tool-call-update', + status: 'success', + callId: 'test-call-id-1', + }, + // Tool 2 Lifecycle + { + kind: 'tool-call-update', + status: 'validating', + callId: 'test-call-id-2', + }, + { + kind: 'tool-call-update', + status: 'scheduled', + callId: 'test-call-id-2', + }, + { + kind: 'tool-call-update', + status: 'executing', + callId: 'test-call-id-2', + }, + { + kind: 'tool-call-update', + status: 'success', + callId: 'test-call-id-2', + }, + // Final updates + { kind: 'state-change', status: undefined, callId: undefined }, + { kind: 'text-content', status: undefined, callId: undefined }, + ]; + + // Use `expect.arrayContaining` for flexibility if other events are interspersed. + expect(eventStream).toEqual(expect.arrayContaining(expectedFlow)); + assertUniqueFinalEventIsLast(events); - expect(events.length).toBe(8); }); it('should handle tool calls that do not require approval', async () => { diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 14a596c9e1..37698a09b9 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -37,7 +37,7 @@ import { } from '@google/gemini-cli-core'; import type { Part, PartListUnion } from '@google/genai'; import type { UseHistoryManagerReturn } from './useHistoryManager.js'; -import type { HistoryItem, SlashCommandProcessorResult } from '../types.js'; +import type { SlashCommandProcessorResult } from '../types.js'; import { MessageType, StreamingState } from '../types.js'; import type { LoadedSettings } from '../../config/settings.js'; @@ -231,8 +231,9 @@ describe('useGeminiStream', () => { mockUseReactToolScheduler.mockReturnValue([ [], // Default to empty array for toolCalls mockScheduleToolCalls, - mockCancelAllToolCalls,
mockMarkToolsAsSubmitted, + vi.fn(), // setToolCallsForDisplay + mockCancelAllToolCalls, ]); // Reset mocks for GeminiClient instance methods (startChat and sendMessageStream) @@ -259,38 +260,71 @@ describe('useGeminiStream', () => { initialToolCalls: TrackedToolCall[] = [], geminiClient?: any, ) => { - let currentToolCalls = initialToolCalls; - const setToolCalls = (newToolCalls: TrackedToolCall[]) => { - currentToolCalls = newToolCalls; - }; - - mockUseReactToolScheduler.mockImplementation(() => [ - currentToolCalls, - mockScheduleToolCalls, - mockCancelAllToolCalls, - mockMarkToolsAsSubmitted, - ]); - const client = geminiClient || mockConfig.getGeminiClient(); + const initialProps = { + client, + history: [], + addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'], + config: mockConfig, + onDebugMessage: mockOnDebugMessage, + handleSlashCommand: mockHandleSlashCommand as unknown as ( + cmd: PartListUnion, + ) => Promise, + shellModeActive: false, + loadedSettings: mockLoadedSettings, + toolCalls: initialToolCalls, + }; + const { result, rerender } = renderHook( - (props: { - client: any; - history: HistoryItem[]; - addItem: UseHistoryManagerReturn['addItem']; - config: Config; - onDebugMessage: (message: string) => void; - handleSlashCommand: ( - cmd: PartListUnion, - ) => Promise; - shellModeActive: boolean; - loadedSettings: LoadedSettings; - toolCalls?: TrackedToolCall[]; // Allow passing updated toolCalls - }) => { - // Update the mock's return value if new toolCalls are passed in props - if (props.toolCalls) { - setToolCalls(props.toolCalls); - } + (props: typeof initialProps) => { + // This mock needs to be stateful. When setToolCallsForDisplay is called, + // it should trigger a rerender with the new state. + const mockSetToolCallsForDisplay = vi.fn((updater) => { + const newToolCalls = + typeof updater === 'function' ? 
updater(props.toolCalls) : updater; + rerender({ ...props, toolCalls: newToolCalls }); + }); + + // Create a stateful mock for cancellation that updates the toolCalls state. + const statefulCancelAllToolCalls = vi.fn((...args) => { + // Call the original spy so `toHaveBeenCalled` checks still work. + mockCancelAllToolCalls(...args); + + const newToolCalls = props.toolCalls.map((tc) => { + // Only cancel tools that are in a cancellable state. + if ( + tc.status === 'awaiting_approval' || + tc.status === 'executing' || + tc.status === 'scheduled' || + tc.status === 'validating' + ) { + // A real cancelled tool call has a response object. + // We need to simulate this to avoid type errors downstream. + return { + ...tc, + status: 'cancelled', + response: { + callId: tc.request.callId, + responseParts: [], + resultDisplay: 'Request cancelled.', + }, + responseSubmittedToGemini: true, // Mark as "processed" + } as any as TrackedCancelledToolCall; + } + return tc; + }); + rerender({ ...props, toolCalls: newToolCalls }); + }); + + mockUseReactToolScheduler.mockImplementation(() => [ + props.toolCalls, + mockScheduleToolCalls, + mockMarkToolsAsSubmitted, + mockSetToolCallsForDisplay, + statefulCancelAllToolCalls, // Use the stateful mock + ]); + return useGeminiStream( props.client, props.history, @@ -313,19 +347,7 @@ describe('useGeminiStream', () => { ); }, { - initialProps: { - client, - history: [], - addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'], - config: mockConfig, - onDebugMessage: mockOnDebugMessage, - handleSlashCommand: mockHandleSlashCommand as unknown as ( - cmd: PartListUnion, - ) => Promise, - shellModeActive: false, - loadedSettings: mockLoadedSettings, - toolCalls: initialToolCalls, - }, + initialProps, }, ); return { @@ -452,7 +474,7 @@ describe('useGeminiStream', () => { mockUseReactToolScheduler.mockImplementation((onComplete) => { capturedOnComplete = onComplete; - return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted]; + 
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()]; }); renderHook(() => @@ -535,7 +557,7 @@ describe('useGeminiStream', () => { mockUseReactToolScheduler.mockImplementation((onComplete) => { capturedOnComplete = onComplete; - return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted]; + return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()]; }); renderHook(() => @@ -647,7 +669,7 @@ describe('useGeminiStream', () => { mockUseReactToolScheduler.mockImplementation((onComplete) => { capturedOnComplete = onComplete; - return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted]; + return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()]; }); renderHook(() => @@ -760,6 +782,7 @@ describe('useGeminiStream', () => { currentToolCalls, mockScheduleToolCalls, mockMarkToolsAsSubmitted, + vi.fn(), // setToolCallsForDisplay ]; }); @@ -797,6 +820,7 @@ describe('useGeminiStream', () => { completedToolCalls, mockScheduleToolCalls, mockMarkToolsAsSubmitted, + vi.fn(), // setToolCallsForDisplay ]; }); @@ -1031,7 +1055,7 @@ describe('useGeminiStream', () => { expect(result.current.streamingState).toBe(StreamingState.Idle); }); - it('should not cancel if a tool call is in progress (not just responding)', async () => { + it('should cancel if a tool call is in progress', async () => { const toolCalls: TrackedToolCall[] = [ { request: { callId: 'call1', name: 'tool1', args: {} }, @@ -1052,7 +1076,6 @@ describe('useGeminiStream', () => { } as TrackedExecutingToolCall, ]; - const abortSpy = vi.spyOn(AbortController.prototype, 'abort'); const { result } = renderTestHook(toolCalls); // State is `Responding` because a tool is running @@ -1061,8 +1084,71 @@ describe('useGeminiStream', () => { // Try to cancel simulateEscapeKeyPress(); - // Nothing should happen because the state is not `Responding` - expect(abortSpy).not.toHaveBeenCalled(); + // The cancel function should be called + expect(mockCancelAllToolCalls).toHaveBeenCalled(); + }); + + 
it('should cancel a request when a tool is awaiting confirmation', async () => { + const mockOnConfirm = vi.fn().mockResolvedValue(undefined); + const toolCalls: TrackedToolCall[] = [ + { + request: { + callId: 'confirm-call', + name: 'some_tool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + status: 'awaiting_approval', + responseSubmittedToGemini: false, + tool: { + name: 'some_tool', + description: 'a tool', + build: vi.fn().mockImplementation((_) => ({ + getDescription: () => `Mock description`, + })), + } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, + confirmationDetails: { + type: 'edit', + title: 'Confirm Edit', + onConfirm: mockOnConfirm, + fileName: 'file.txt', + filePath: '/test/file.txt', + fileDiff: 'fake diff', + originalContent: 'old', + newContent: 'new', + }, + } as TrackedWaitingToolCall, + ]; + + const { result } = renderTestHook(toolCalls); + + // State is `WaitingForConfirmation` because a tool is awaiting approval + expect(result.current.streamingState).toBe( + StreamingState.WaitingForConfirmation, + ); + + // Try to cancel + simulateEscapeKeyPress(); + + // The imperative cancel function should be called on the scheduler + expect(mockCancelAllToolCalls).toHaveBeenCalled(); + + // A cancellation message should be added to history + await waitFor(() => { + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ + text: 'Request cancelled.', + }), + expect.any(Number), + ); + }); + + // The final state should be idle + expect(result.current.streamingState).toBe(StreamingState.Idle); }); }); @@ -1282,7 +1368,7 @@ describe('useGeminiStream', () => { mockUseReactToolScheduler.mockImplementation((onComplete) => { capturedOnComplete = onComplete; - return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted]; + return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()]; }); renderHook(() => diff --git 
a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index a0190a3c4b..ae3a23c7eb 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -111,6 +111,7 @@ export const useGeminiStream = ( const [initError, setInitError] = useState(null); const abortControllerRef = useRef(null); const turnCancelledRef = useRef(false); + const activeQueryIdRef = useRef(null); const [isResponding, setIsResponding] = useState(false); const [thought, setThought] = useState(null); const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] = @@ -126,47 +127,55 @@ export const useGeminiStream = ( return new GitService(config.getProjectRoot(), storage); }, [config, storage]); - const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] = - useReactToolScheduler( - async (completedToolCallsFromScheduler) => { - // This onComplete is called when ALL scheduled tools for a given batch are done. - if (completedToolCallsFromScheduler.length > 0) { - // Add the final state of these tools to the history for display. - addItem( - mapTrackedToolCallsToDisplay( - completedToolCallsFromScheduler as TrackedToolCall[], - ), - Date.now(), - ); - - // Record tool calls with full metadata before sending responses. - try { - const currentModel = - config.getGeminiClient().getCurrentSequenceModel() ?? 
- config.getModel(); - config - .getGeminiClient() - .getChat() - .recordCompletedToolCalls( - currentModel, - completedToolCallsFromScheduler, - ); - } catch (error) { - console.error( - `Error recording completed tool call information: ${error}`, - ); - } - - // Handle tool response submission immediately when tools complete - await handleCompletedTools( + const [ + toolCalls, + scheduleToolCalls, + markToolsAsSubmitted, + setToolCallsForDisplay, + cancelAllToolCalls, + ] = useReactToolScheduler( + async (completedToolCallsFromScheduler) => { + // This onComplete is called when ALL scheduled tools for a given batch are done. + if (completedToolCallsFromScheduler.length > 0) { + // Add the final state of these tools to the history for display. + addItem( + mapTrackedToolCallsToDisplay( completedToolCallsFromScheduler as TrackedToolCall[], + ), + Date.now(), + ); + + // Clear the live-updating display now that the final state is in history. + setToolCallsForDisplay([]); + + // Record tool calls with full metadata before sending responses. + try { + const currentModel = + config.getGeminiClient().getCurrentSequenceModel() ?? 
+ config.getModel(); + config + .getGeminiClient() + .getChat() + .recordCompletedToolCalls( + currentModel, + completedToolCallsFromScheduler, + ); + } catch (error) { + console.error( + `Error recording completed tool call information: ${error}`, ); } - }, - config, - getPreferredEditor, - onEditorClose, - ); + + // Handle tool response submission immediately when tools complete + await handleCompletedTools( + completedToolCallsFromScheduler as TrackedToolCall[], + ); + } + }, + config, + getPreferredEditor, + onEditorClose, + ); const pendingToolCallGroupDisplay = useMemo( () => @@ -265,27 +274,54 @@ export const useGeminiStream = ( }, [streamingState, config, history]); const cancelOngoingRequest = useCallback(() => { - if (streamingState !== StreamingState.Responding) { + if ( + streamingState !== StreamingState.Responding && + streamingState !== StreamingState.WaitingForConfirmation + ) { return; } if (turnCancelledRef.current) { return; } turnCancelledRef.current = true; - abortControllerRef.current?.abort(); + + // A full cancellation means no tools have produced a final result yet. + // This determines if we show a generic "Request cancelled" message. + const isFullCancellation = !toolCalls.some( + (tc) => tc.status === 'success' || tc.status === 'error', + ); + + // Ensure we have an abort controller, creating one if it doesn't exist. + if (!abortControllerRef.current) { + abortControllerRef.current = new AbortController(); + } + + // The order is important here. + // 1. Fire the signal to interrupt any active async operations. + abortControllerRef.current.abort(); + // 2. Call the imperative cancel to clear the queue of pending tools. 
+ cancelAllToolCalls(abortControllerRef.current.signal); + if (pendingHistoryItemRef.current) { addItem(pendingHistoryItemRef.current, Date.now()); } - addItem( - { - type: MessageType.INFO, - text: 'Request cancelled.', - }, - Date.now(), - ); setPendingHistoryItem(null); + + // If it was a full cancellation, add the info message now. + // Otherwise, we let handleCompletedTools figure out the next step, + // which might involve sending partial results back to the model. + if (isFullCancellation) { + addItem( + { + type: MessageType.INFO, + text: 'Request cancelled.', + }, + Date.now(), + ); + setIsResponding(false); + } + onCancelSubmit(); - setIsResponding(false); setShellInputFocused(false); }, [ streamingState, @@ -294,6 +330,8 @@ export const useGeminiStream = ( onCancelSubmit, pendingHistoryItemRef, setShellInputFocused, + cancelAllToolCalls, + toolCalls, ]); useKeypress( @@ -302,7 +340,11 @@ export const useGeminiStream = ( cancelOngoingRequest(); } }, - { isActive: streamingState === StreamingState.Responding }, + { + isActive: + streamingState === StreamingState.Responding || + streamingState === StreamingState.WaitingForConfirmation, + }, ); const prepareQueryForGemini = useCallback( @@ -764,6 +806,8 @@ export const useGeminiStream = ( options?: { isContinuation: boolean }, prompt_id?: string, ) => { + const queryId = `${Date.now()}-${Math.random()}`; + activeQueryIdRef.current = queryId; if ( (streamingState === StreamingState.Responding || streamingState === StreamingState.WaitingForConfirmation) && @@ -901,7 +945,9 @@ export const useGeminiStream = ( ); } } finally { - setIsResponding(false); + if (activeQueryIdRef.current === queryId) { + setIsResponding(false); + } } }); }, @@ -963,10 +1009,6 @@ export const useGeminiStream = ( const handleCompletedTools = useCallback( async (completedToolCallsFromScheduler: TrackedToolCall[]) => { - if (isResponding) { - return; - } - const completedAndReadyToSubmitTools = completedToolCallsFromScheduler.filter( ( 
@@ -1028,6 +1070,19 @@ export const useGeminiStream = ( ); if (allToolsCancelled) { + // If the turn was cancelled via the imperative escape key flow, + // the cancellation message is added there. We check the ref to avoid duplication. + if (!turnCancelledRef.current) { + addItem( + { + type: MessageType.INFO, + text: 'Request cancelled.', + }, + Date.now(), + ); + } + setIsResponding(false); + if (geminiClient) { // We need to manually add the function responses to the history // so the model knows the tools were cancelled. @@ -1074,12 +1129,12 @@ export const useGeminiStream = ( ); }, [ - isResponding, submitQuery, markToolsAsSubmitted, geminiClient, performMemoryRefresh, modelSwitchedFromQuotaError, + addItem, ], ); diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts index 883690d79a..2c7c8fc4df 100644 --- a/packages/cli/src/ui/hooks/useReactToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts @@ -62,12 +62,20 @@ export type TrackedToolCall = | TrackedCompletedToolCall | TrackedCancelledToolCall; +export type CancelAllFn = (signal: AbortSignal) => void; + export function useReactToolScheduler( onComplete: (tools: CompletedToolCall[]) => Promise, config: Config, getPreferredEditor: () => EditorType | undefined, onEditorClose: () => void, -): [TrackedToolCall[], ScheduleFn, MarkToolsAsSubmittedFn] { +): [ + TrackedToolCall[], + ScheduleFn, + MarkToolsAsSubmittedFn, + React.Dispatch>, + CancelAllFn, +] { const [toolCallsForDisplay, setToolCallsForDisplay] = useState< TrackedToolCall[] >([]); @@ -112,37 +120,36 @@ export function useReactToolScheduler( ); const toolCallsUpdateHandler: ToolCallsUpdateHandler = useCallback( - (updatedCoreToolCalls: ToolCall[]) => { - setToolCallsForDisplay((prevTrackedCalls) => - updatedCoreToolCalls.map((coreTc) => { - const existingTrackedCall = prevTrackedCalls.find( - (ptc) => ptc.request.callId === coreTc.request.callId, - ); - // Start with the 
new core state, then layer on the existing UI state - // to ensure UI-only properties like pid are preserved. + (allCoreToolCalls: ToolCall[]) => { + setToolCallsForDisplay((prevTrackedCalls) => { + const prevCallsMap = new Map( + prevTrackedCalls.map((c) => [c.request.callId, c]), + ); + + return allCoreToolCalls.map((coreTc): TrackedToolCall => { + const existingTrackedCall = prevCallsMap.get(coreTc.request.callId); + const responseSubmittedToGemini = existingTrackedCall?.responseSubmittedToGemini ?? false; if (coreTc.status === 'executing') { + // Preserve live output if it exists from a previous render. + const liveOutput = (existingTrackedCall as TrackedExecutingToolCall) + ?.liveOutput; return { ...coreTc, responseSubmittedToGemini, - liveOutput: (existingTrackedCall as TrackedExecutingToolCall) - ?.liveOutput, + liveOutput, pid: (coreTc as ExecutingToolCall).pid, }; + } else { + return { + ...coreTc, + responseSubmittedToGemini, + }; } - - // For other statuses, explicitly set liveOutput and pid to undefined - // to ensure they are not carried over from a previous executing state. 
- return { - ...coreTc, - responseSubmittedToGemini, - liveOutput: undefined, - pid: undefined, - }; - }), - ); + }); + }); }, [setToolCallsForDisplay], ); @@ -178,9 +185,10 @@ export function useReactToolScheduler( request: ToolCallRequestInfo | ToolCallRequestInfo[], signal: AbortSignal, ) => { + setToolCallsForDisplay([]); void scheduler.schedule(request, signal); }, - [scheduler], + [scheduler, setToolCallsForDisplay], ); const markToolsAsSubmitted: MarkToolsAsSubmittedFn = useCallback( @@ -196,7 +204,20 @@ export function useReactToolScheduler( [], ); - return [toolCallsForDisplay, schedule, markToolsAsSubmitted]; + const cancelAllToolCalls = useCallback( + (signal: AbortSignal) => { + scheduler.cancelAll(signal); + }, + [scheduler], + ); + + return [ + toolCallsForDisplay, + schedule, + markToolsAsSubmitted, + setToolCallsForDisplay, + cancelAllToolCalls, + ]; } /** diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index d80f8eceb2..11d1b7e7d8 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -260,9 +260,15 @@ describe('useReactToolScheduler', () => { args: { param: 'value' }, } as any; + let completedToolCalls: ToolCall[] = []; + onComplete.mockImplementation((calls) => { + completedToolCalls = calls; + }); + act(() => { schedule(request, new AbortController().signal); }); + await act(async () => { await vi.runAllTimersAsync(); }); @@ -292,7 +298,110 @@ describe('useReactToolScheduler', () => { }), }), ]); - expect(result.current[0]).toEqual([]); + expect(completedToolCalls).toHaveLength(1); + expect(completedToolCalls[0].status).toBe('success'); + expect(completedToolCalls[0].request).toBe(request); + }); + + it('should clear previous tool calls when scheduling new ones', async () => { + mockToolRegistry.getTool.mockReturnValue(mockTool); + (mockTool.execute as Mock).mockResolvedValue({ + llmContent: 'Tool output', + 
returnDisplay: 'Formatted tool output', + } as ToolResult); + + const { result } = renderScheduler(); + const schedule = result.current[1]; + const setToolCallsForDisplay = result.current[3]; + + // Manually set a tool call in the display. + const oldToolCall = { + request: { callId: 'oldCall' }, + status: 'success', + } as any; + act(() => { + setToolCallsForDisplay([oldToolCall]); + }); + expect(result.current[0]).toEqual([oldToolCall]); + + const newRequest: ToolCallRequestInfo = { + callId: 'newCall', + name: 'mockTool', + args: {}, + } as any; + act(() => { + schedule(newRequest, new AbortController().signal); + }); + + // After scheduling, the old call should be gone, + // and the new one should be in the display in its initial state. + expect(result.current[0].length).toBe(1); + expect(result.current[0][0].request.callId).toBe('newCall'); + expect(result.current[0][0].request.callId).not.toBe('oldCall'); + + // Let the new call finish. + await act(async () => { + await vi.runAllTimersAsync(); + }); + await act(async () => { + await vi.runAllTimersAsync(); + }); + await act(async () => { + await vi.runAllTimersAsync(); + }); + expect(onComplete).toHaveBeenCalled(); + }); + + it('should cancel all running tool calls', async () => { + mockToolRegistry.getTool.mockReturnValue(mockTool); + + let resolveExecute: (value: ToolResult) => void = () => {}; + const executePromise = new Promise((resolve) => { + resolveExecute = resolve; + }); + (mockTool.execute as Mock).mockReturnValue(executePromise); + (mockTool.shouldConfirmExecute as Mock).mockResolvedValue(null); + + const { result } = renderScheduler(); + const schedule = result.current[1]; + const cancelAllToolCalls = result.current[4]; + const request: ToolCallRequestInfo = { + callId: 'cancelCall', + name: 'mockTool', + args: {}, + } as any; + + act(() => { + schedule(request, new AbortController().signal); + }); + await act(async () => { + await vi.runAllTimersAsync(); + }); // validation + await act(async () 
=> { + await vi.runAllTimersAsync(); + }); // scheduling + + // At this point, the tool is 'executing' and waiting on the promise. + expect(result.current[0][0].status).toBe('executing'); + + const cancelController = new AbortController(); + act(() => { + cancelAllToolCalls(cancelController.signal); + }); + + await act(async () => { + await vi.runAllTimersAsync(); + }); + + expect(onComplete).toHaveBeenCalledWith([ + expect.objectContaining({ + status: 'cancelled', + request, + }), + ]); + + // Clean up the pending promise to avoid open handles. + resolveExecute({ llmContent: 'output', returnDisplay: 'display' }); }); it('should handle tool not found', async () => { @@ -305,6 +414,11 @@ describe('useReactToolScheduler', () => { args: {}, } as any; + let completedToolCalls: ToolCall[] = []; + onComplete.mockImplementation((calls) => { + completedToolCalls = calls; + }); + act(() => { schedule(request, new AbortController().signal); }); @@ -315,24 +429,15 @@ describe('useReactToolScheduler', () => { await vi.runAllTimersAsync(); }); - expect(onComplete).toHaveBeenCalledWith([ - expect.objectContaining({ - status: 'error', - request, - response: expect.objectContaining({ - error: expect.objectContaining({ - message: expect.stringMatching( - /Tool "nonexistentTool" not found in registry/, - ), - }), - }), - }), - ]); - const errorMessage = onComplete.mock.calls[0][0][0].response.error.message; - expect(errorMessage).toContain('Did you mean one of:'); - expect(errorMessage).toContain('"mockTool"'); - expect(errorMessage).toContain('"anotherTool"'); - expect(result.current[0]).toEqual([]); + expect(completedToolCalls).toHaveLength(1); + expect(completedToolCalls[0].status).toBe('error'); + expect(completedToolCalls[0].request).toBe(request); + expect((completedToolCalls[0] as any).response.error.message).toContain( + 'Tool "nonexistentTool" not found in registry', + ); + expect((completedToolCalls[0] as any).response.error.message).toContain( + 'Did you mean one of:', + 
); }); it('should handle error during shouldConfirmExecute', async () => { @@ -348,6 +453,11 @@ describe('useReactToolScheduler', () => { args: {}, } as any; + let completedToolCalls: ToolCall[] = []; + onComplete.mockImplementation((calls) => { + completedToolCalls = calls; + }); + act(() => { schedule(request, new AbortController().signal); }); @@ -358,16 +468,10 @@ describe('useReactToolScheduler', () => { await vi.runAllTimersAsync(); }); - expect(onComplete).toHaveBeenCalledWith([ - expect.objectContaining({ - status: 'error', - request, - response: expect.objectContaining({ - error: confirmError, - }), - }), - ]); - expect(result.current[0]).toEqual([]); + expect(completedToolCalls).toHaveLength(1); + expect(completedToolCalls[0].status).toBe('error'); + expect(completedToolCalls[0].request).toBe(request); + expect((completedToolCalls[0] as any).response.error).toBe(confirmError); }); it('should handle error during execute', async () => { @@ -384,6 +488,11 @@ describe('useReactToolScheduler', () => { args: {}, } as any; + let completedToolCalls: ToolCall[] = []; + onComplete.mockImplementation((calls) => { + completedToolCalls = calls; + }); + act(() => { schedule(request, new AbortController().signal); }); @@ -397,16 +506,10 @@ describe('useReactToolScheduler', () => { await vi.runAllTimersAsync(); }); - expect(onComplete).toHaveBeenCalledWith([ - expect.objectContaining({ - status: 'error', - request, - response: expect.objectContaining({ - error: execError, - }), - }), - ]); - expect(result.current[0]).toEqual([]); + expect(completedToolCalls).toHaveLength(1); + expect(completedToolCalls[0].status).toBe('error'); + expect(completedToolCalls[0].request).toBe(request); + expect((completedToolCalls[0] as any).response.error).toBe(execError); }); it('should handle tool requiring confirmation - approved', async () => { @@ -518,7 +621,7 @@ describe('useReactToolScheduler', () => { functionResponse: expect.objectContaining({ response: expect.objectContaining({ 
error: - '[Operation Cancelled] Reason: User did not allow tool call', + '[Operation Cancelled] Reason: User cancelled the operation.', }), }), }), @@ -705,7 +808,9 @@ describe('useReactToolScheduler', () => { ], }), }); - expect(result.current[0]).toEqual([]); + + expect(completedCalls).toHaveLength(2); + expect(completedCalls.every((t) => t.status === 'success')).toBe(true); }); it('should queue if scheduling while already running', async () => { @@ -774,7 +879,8 @@ describe('useReactToolScheduler', () => { response: expect.objectContaining({ resultDisplay: 'done display' }), }), ]); - expect(result.current[0]).toEqual([]); + const toolCalls = result.current[0]; + expect(toolCalls).toHaveLength(0); }); }); diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index e1e6aa2430..7dbf8021b8 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -288,6 +288,263 @@ describe('CoreToolScheduler', () => { expect(completedCalls[0].status).toBe('cancelled'); }); + it('should cancel all tools when cancelAll is called', async () => { + const mockTool1 = new MockTool({ + name: 'mockTool1', + shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE, + }); + const mockTool2 = new MockTool({ name: 'mockTool2' }); + const mockTool3 = new MockTool({ name: 'mockTool3' }); + + const mockToolRegistry = { + getTool: (name: string) => { + if (name === 'mockTool1') return mockTool1; + if (name === 'mockTool2') return mockTool2; + if (name === 'mockTool3') return mockTool3; + return undefined; + }, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByName: (name: string) => { + if (name === 'mockTool1') return mockTool1; + if (name === 'mockTool2') return mockTool2; + if (name === 'mockTool3') return mockTool3; + return undefined; + }, + getToolByDisplayName: () => undefined, + getTools: () => [], + 
discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => ApprovalMode.DEFAULT, + getAllowedTools: () => [], + getContentGeneratorConfig: () => ({ + model: 'test-model', + authType: 'oauth-personal', + }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), + storage: { + getProjectTempDir: () => '/tmp', + }, + getTruncateToolOutputThreshold: () => + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + getToolRegistry: () => mockToolRegistry, + getUseSmartEdit: () => false, + getUseModelRouter: () => false, + getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, + } as unknown as Config; + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + onAllToolCallsComplete, + onToolCallsUpdate, + getPreferredEditor: () => 'vscode', + onEditorClose: vi.fn(), + }); + + const abortController = new AbortController(); + const requests = [ + { + callId: '1', + name: 'mockTool1', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + { + callId: '2', + name: 'mockTool2', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + { + callId: '3', + name: 'mockTool3', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + ]; + + // Don't await, let it run in the background + void scheduler.schedule(requests, abortController.signal); + + // Wait for the first tool to be awaiting approval + await waitForStatus(onToolCallsUpdate, 'awaiting_approval'); + + // Cancel all operations + 
scheduler.cancelAll(abortController.signal); + abortController.abort(); // Also fire the signal + + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + }); + + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + + expect(completedCalls).toHaveLength(3); + expect(completedCalls.find((c) => c.request.callId === '1')?.status).toBe( + 'cancelled', + ); + expect(completedCalls.find((c) => c.request.callId === '2')?.status).toBe( + 'cancelled', + ); + expect(completedCalls.find((c) => c.request.callId === '3')?.status).toBe( + 'cancelled', + ); + }); + + it('should cancel all tools in a batch when one is cancelled via confirmation', async () => { + const mockTool1 = new MockTool({ + name: 'mockTool1', + shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE, + }); + const mockTool2 = new MockTool({ name: 'mockTool2' }); + const mockTool3 = new MockTool({ name: 'mockTool3' }); + + const mockToolRegistry = { + getTool: (name: string) => { + if (name === 'mockTool1') return mockTool1; + if (name === 'mockTool2') return mockTool2; + if (name === 'mockTool3') return mockTool3; + return undefined; + }, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByName: (name: string) => { + if (name === 'mockTool1') return mockTool1; + if (name === 'mockTool2') return mockTool2; + if (name === 'mockTool3') return mockTool3; + return undefined; + }, + getToolByDisplayName: () => undefined, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => ApprovalMode.DEFAULT, + getAllowedTools: () => [], + getContentGeneratorConfig: () => ({ + model: 
'test-model', + authType: 'oauth-personal', + }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), + storage: { + getProjectTempDir: () => '/tmp', + }, + getTruncateToolOutputThreshold: () => + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + getToolRegistry: () => mockToolRegistry, + getUseSmartEdit: () => false, + getUseModelRouter: () => false, + getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, + } as unknown as Config; + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + onAllToolCallsComplete, + onToolCallsUpdate, + getPreferredEditor: () => 'vscode', + onEditorClose: vi.fn(), + }); + + const abortController = new AbortController(); + const requests = [ + { + callId: '1', + name: 'mockTool1', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + { + callId: '2', + name: 'mockTool2', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + { + callId: '3', + name: 'mockTool3', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + ]; + + // Don't await, let it run in the background + void scheduler.schedule(requests, abortController.signal); + + // Wait for the first tool to be awaiting approval + const awaitingCall = (await waitForStatus( + onToolCallsUpdate, + 'awaiting_approval', + )) as WaitingToolCall; + + // Cancel the first tool via its confirmation handler + await awaitingCall.confirmationDetails.onConfirm( + ToolConfirmationOutcome.Cancel, + ); + abortController.abort(); // User cancelling often involves an abort signal + + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + }); + + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + + expect(completedCalls).toHaveLength(3); + 
expect(completedCalls.find((c) => c.request.callId === '1')?.status).toBe( + 'cancelled', + ); + expect(completedCalls.find((c) => c.request.callId === '2')?.status).toBe( + 'cancelled', + ); + expect(completedCalls.find((c) => c.request.callId === '3')?.status).toBe( + 'cancelled', + ); + }); + it('should mark tool call as cancelled when abort happens during confirmation error', async () => { const abortController = new AbortController(); const abortError = new Error('Abort requested during confirmation'); @@ -1510,16 +1767,19 @@ describe('CoreToolScheduler request queueing', () => { await scheduler.schedule(requests, abortController.signal); - // Wait for all tools to be awaiting approval + // Wait for the FIRST tool to be awaiting approval await vi.waitFor(() => { const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[]; + // With the sequential scheduler, the update includes the active call and the queue. expect(calls?.length).toBe(3); - expect(calls?.every((call) => call.status === 'awaiting_approval')).toBe( - true, - ); + expect(calls?.[0].status).toBe('awaiting_approval'); + expect(calls?.[0].request.callId).toBe('1'); + // Check that the other two are in the queue (still in 'validating' state) + expect(calls?.[1].status).toBe('validating'); + expect(calls?.[2].status).toBe('validating'); }); - expect(pendingConfirmations.length).toBe(3); + expect(pendingConfirmations.length).toBe(1); // Approve the first tool with ProceedAlways const firstConfirmation = pendingConfirmations[0]; @@ -1528,15 +1788,16 @@ describe('CoreToolScheduler request queueing', () => { // Wait for all tools to be completed await vi.waitFor(() => { expect(onAllToolCallsComplete).toHaveBeenCalled(); - const completedCalls = onAllToolCallsComplete.mock.calls.at( - -1, - )?.[0] as ToolCall[]; - expect(completedCalls?.length).toBe(3); - expect(completedCalls?.every((call) => call.status === 'success')).toBe( - true, - ); }); + const completedCalls = 
onAllToolCallsComplete.mock.calls.at( + -1, + )?.[0] as ToolCall[]; + expect(completedCalls?.length).toBe(3); + expect(completedCalls?.every((call) => call.status === 'success')).toBe( + true, + ); + // Verify approval mode was changed expect(approvalMode).toBe(ApprovalMode.AUTO_EDIT); }); @@ -1788,11 +2049,10 @@ describe('CoreToolScheduler Sequential Execution', () => { expect(onAllToolCallsComplete).toHaveBeenCalled(); }); - // Check that execute was called for all three tools initially - expect(executeFn).toHaveBeenCalledTimes(3); + // Check that execute was called for the first two tools only + expect(executeFn).toHaveBeenCalledTimes(2); expect(executeFn).toHaveBeenCalledWith({ call: 1 }); expect(executeFn).toHaveBeenCalledWith({ call: 2 }); - expect(executeFn).toHaveBeenCalledWith({ call: 3 }); const completedCalls = onAllToolCallsComplete.mock .calls[0][0] as ToolCall[]; diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 5c1cb58fb7..a59de8698e 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -348,12 +348,15 @@ export class CoreToolScheduler { private onEditorClose: () => void; private isFinalizingToolCalls = false; private isScheduling = false; + private isCancelling = false; private requestQueue: Array<{ request: ToolCallRequestInfo | ToolCallRequestInfo[]; signal: AbortSignal; resolve: () => void; reject: (reason?: Error) => void; }> = []; + private toolCallQueue: ToolCall[] = []; + private completedToolCallsForBatch: CompletedToolCall[] = []; constructor(options: CoreToolSchedulerOptions) { this.config = options.config; @@ -398,30 +401,36 @@ export class CoreToolScheduler { private setStatusInternal( targetCallId: string, status: 'success', + signal: AbortSignal, response: ToolCallResponseInfo, ): void; private setStatusInternal( targetCallId: string, status: 'awaiting_approval', + signal: AbortSignal, confirmationDetails: 
ToolCallConfirmationDetails, ): void; private setStatusInternal( targetCallId: string, status: 'error', + signal: AbortSignal, response: ToolCallResponseInfo, ): void; private setStatusInternal( targetCallId: string, status: 'cancelled', + signal: AbortSignal, reason: string, ): void; private setStatusInternal( targetCallId: string, status: 'executing' | 'scheduled' | 'validating', + signal: AbortSignal, ): void; private setStatusInternal( targetCallId: string, newStatus: Status, + signal: AbortSignal, auxiliaryData?: unknown, ): void { this.toolCalls = this.toolCalls.map((currentCall) => { @@ -561,7 +570,6 @@ export class CoreToolScheduler { } }); this.notifyToolCallsUpdate(); - this.checkAndNotifyCompletion(); } private setArgsInternal(targetCallId: string, args: unknown): void { @@ -692,11 +700,43 @@ export class CoreToolScheduler { return this._schedule(request, signal); } + cancelAll(signal: AbortSignal): void { + if (this.isCancelling) { + return; + } + this.isCancelling = true; + // Cancel the currently active tool call, if there is one. + if (this.toolCalls.length > 0) { + const activeCall = this.toolCalls[0]; + // Only cancel if it's in a cancellable state. + if ( + activeCall.status === 'awaiting_approval' || + activeCall.status === 'executing' || + activeCall.status === 'scheduled' || + activeCall.status === 'validating' + ) { + this.setStatusInternal( + activeCall.request.callId, + 'cancelled', + signal, + 'User cancelled the operation.', + ); + } + } + + // Clear the queue and mark all queued items as cancelled for completion reporting. + this._cancelAllQueuedCalls(); + + // Finalize the batch immediately. 
+ void this.checkAndNotifyCompletion(signal); + } + private async _schedule( request: ToolCallRequestInfo | ToolCallRequestInfo[], signal: AbortSignal, ): Promise { this.isScheduling = true; + this.isCancelling = false; try { if (this.isRunning()) { throw new Error( @@ -704,6 +744,7 @@ export class CoreToolScheduler { ); } const requestsToProcess = Array.isArray(request) ? request : [request]; + this.completedToolCallsForBatch = []; const newToolCalls: ToolCall[] = requestsToProcess.map( (reqInfo): ToolCall => { @@ -753,45 +794,74 @@ export class CoreToolScheduler { }, ); - this.toolCalls = this.toolCalls.concat(newToolCalls); - this.notifyToolCallsUpdate(); + this.toolCallQueue.push(...newToolCalls); + await this._processNextInQueue(signal); + } finally { + this.isScheduling = false; + } + } - for (const toolCall of newToolCalls) { - if (toolCall.status !== 'validating') { - continue; + private async _processNextInQueue(signal: AbortSignal): Promise { + // If there's already a tool being processed, or the queue is empty, stop. + if (this.toolCalls.length > 0 || this.toolCallQueue.length === 0) { + return; + } + + // If cancellation happened between steps, handle it. + if (signal.aborted) { + this._cancelAllQueuedCalls(); + // Finalize the batch. + await this.checkAndNotifyCompletion(signal); + return; + } + + const toolCall = this.toolCallQueue.shift()!; + + // This is now the single active tool call. + this.toolCalls = [toolCall]; + this.notifyToolCallsUpdate(); + + // Handle tools that were already errored during creation. + if (toolCall.status === 'error') { + // An error during validation means this "active" tool is already complete. + // We need to check for batch completion to either finish or process the next in queue. + await this.checkAndNotifyCompletion(signal); + return; + } + + // This logic is moved from the old `for` loop in `_schedule`. 
+ if (toolCall.status === 'validating') { + const { request: reqInfo, invocation } = toolCall; + + try { + if (signal.aborted) { + this.setStatusInternal( + reqInfo.callId, + 'cancelled', + signal, + 'Tool call cancelled by user.', + ); + // The completion check will handle the cascade. + await this.checkAndNotifyCompletion(signal); + return; } - const validatingCall = toolCall as ValidatingToolCall; - const { request: reqInfo, invocation } = validatingCall; + const confirmationDetails = + await invocation.shouldConfirmExecute(signal); - try { - if (signal.aborted) { - this.setStatusInternal( - reqInfo.callId, - 'cancelled', - 'Tool call cancelled by user.', - ); - continue; - } - - const confirmationDetails = - await invocation.shouldConfirmExecute(signal); - - if (!confirmationDetails) { + if (!confirmationDetails) { + this.setToolCallOutcome( + reqInfo.callId, + ToolConfirmationOutcome.ProceedAlways, + ); + this.setStatusInternal(reqInfo.callId, 'scheduled', signal); + } else { + if (this.isAutoApproved(toolCall)) { this.setToolCallOutcome( reqInfo.callId, ToolConfirmationOutcome.ProceedAlways, ); - this.setStatusInternal(reqInfo.callId, 'scheduled'); - continue; - } - - if (this.isAutoApproved(validatingCall)) { - this.setToolCallOutcome( - reqInfo.callId, - ToolConfirmationOutcome.ProceedAlways, - ); - this.setStatusInternal(reqInfo.callId, 'scheduled'); + this.setStatusInternal(reqInfo.callId, 'scheduled', signal); } else { // Allow IDE to resolve confirmation if ( @@ -835,35 +905,36 @@ export class CoreToolScheduler { this.setStatusInternal( reqInfo.callId, 'awaiting_approval', + signal, wrappedConfirmationDetails, ); } - } catch (error) { - if (signal.aborted) { - this.setStatusInternal( - reqInfo.callId, - 'cancelled', - 'Tool call cancelled by user.', - ); - continue; - } - + } + } catch (error) { + if (signal.aborted) { + this.setStatusInternal( + reqInfo.callId, + 'cancelled', + signal, + 'Tool call cancelled by user.', + ); + await 
this.checkAndNotifyCompletion(signal); + } else { this.setStatusInternal( reqInfo.callId, 'error', + signal, createErrorResponse( reqInfo, error instanceof Error ? error : new Error(String(error)), ToolErrorType.UNHANDLED_EXCEPTION, ), ); + await this.checkAndNotifyCompletion(signal); } } - await this.attemptExecutionOfScheduledCalls(signal); - void this.checkAndNotifyCompletion(); - } finally { - this.isScheduling = false; } + await this.attemptExecutionOfScheduledCalls(signal); } async handleConfirmationResponse( @@ -881,18 +952,12 @@ export class CoreToolScheduler { await originalOnConfirm(outcome); } - if (outcome === ToolConfirmationOutcome.ProceedAlways) { - await this.autoApproveCompatiblePendingTools(signal, callId); - } - this.setToolCallOutcome(callId, outcome); if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) { - this.setStatusInternal( - callId, - 'cancelled', - 'User did not allow tool call', - ); + // Instead of just cancelling one tool, trigger the full cancel cascade. + this.cancelAll(signal); + return; // `cancelAll` calls `checkAndNotifyCompletion`, so we can exit here. 
} else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) { const waitingToolCall = toolCall as WaitingToolCall; if (isModifiableDeclarativeTool(waitingToolCall.tool)) { @@ -902,7 +967,7 @@ export class CoreToolScheduler { return; } - this.setStatusInternal(callId, 'awaiting_approval', { + this.setStatusInternal(callId, 'awaiting_approval', signal, { ...waitingToolCall.confirmationDetails, isModifying: true, } as ToolCallConfirmationDetails); @@ -917,7 +982,7 @@ export class CoreToolScheduler { this.onEditorClose, ); this.setArgsInternal(callId, updatedParams); - this.setStatusInternal(callId, 'awaiting_approval', { + this.setStatusInternal(callId, 'awaiting_approval', signal, { ...waitingToolCall.confirmationDetails, fileDiff: updatedDiff, isModifying: false, @@ -932,7 +997,7 @@ export class CoreToolScheduler { signal, ); } - this.setStatusInternal(callId, 'scheduled'); + this.setStatusInternal(callId, 'scheduled', signal); } await this.attemptExecutionOfScheduledCalls(signal); } @@ -974,10 +1039,15 @@ export class CoreToolScheduler { ); this.setArgsInternal(toolCall.request.callId, updatedParams); - this.setStatusInternal(toolCall.request.callId, 'awaiting_approval', { - ...toolCall.confirmationDetails, - fileDiff: updatedDiff, - }); + this.setStatusInternal( + toolCall.request.callId, + 'awaiting_approval', + signal, + { + ...toolCall.confirmationDetails, + fileDiff: updatedDiff, + }, + ); } private async attemptExecutionOfScheduledCalls( @@ -1002,7 +1072,7 @@ export class CoreToolScheduler { const scheduledCall = toolCall; const { callId, name: toolName } = scheduledCall.request; const invocation = scheduledCall.invocation; - this.setStatusInternal(callId, 'executing'); + this.setStatusInternal(callId, 'executing', signal); const liveOutputCallback = scheduledCall.tool.canUpdateOutput && this.outputUpdateHandler @@ -1055,12 +1125,10 @@ export class CoreToolScheduler { this.setStatusInternal( callId, 'cancelled', + signal, 'User cancelled tool 
execution.', ); - continue; - } - - if (toolResult.error === undefined) { + } else if (toolResult.error === undefined) { let content = toolResult.llmContent; let outputFile: string | undefined = undefined; const contentLength = @@ -1116,7 +1184,7 @@ export class CoreToolScheduler { outputFile, contentLength, }; - this.setStatusInternal(callId, 'success', successResponse); + this.setStatusInternal(callId, 'success', signal, successResponse); } else { // It is a failure const error = new Error(toolResult.error.message); @@ -1125,19 +1193,21 @@ export class CoreToolScheduler { error, toolResult.error.type, ); - this.setStatusInternal(callId, 'error', errorResponse); + this.setStatusInternal(callId, 'error', signal, errorResponse); } } catch (executionError: unknown) { if (signal.aborted) { this.setStatusInternal( callId, 'cancelled', + signal, 'User cancelled tool execution.', ); } else { this.setStatusInternal( callId, 'error', + signal, createErrorResponse( scheduledCall.request, executionError instanceof Error @@ -1148,45 +1218,126 @@ export class CoreToolScheduler { ); } } + await this.checkAndNotifyCompletion(signal); } } } - private async checkAndNotifyCompletion(): Promise { - const allCallsAreTerminal = this.toolCalls.every( - (call) => - call.status === 'success' || - call.status === 'error' || - call.status === 'cancelled', - ); + private async checkAndNotifyCompletion(signal: AbortSignal): Promise { + // This method is now only concerned with the single active tool call. + if (this.toolCalls.length === 0) { + // It's possible to be called when a batch is cancelled before any tool has started. 
+ if (signal.aborted && this.toolCallQueue.length > 0) { + this._cancelAllQueuedCalls(); + } + } else { + const activeCall = this.toolCalls[0]; + const isTerminal = + activeCall.status === 'success' || + activeCall.status === 'error' || + activeCall.status === 'cancelled'; - if (this.toolCalls.length > 0 && allCallsAreTerminal) { - const completedCalls = [...this.toolCalls] as CompletedToolCall[]; + // If the active tool is not in a terminal state (e.g., it's 'executing' or 'awaiting_approval'), + // then the scheduler is still busy or paused. We should not proceed. + if (!isTerminal) { + return; + } + + // The active tool is finished. Move it to the completed batch. + const completedCall = activeCall as CompletedToolCall; + this.completedToolCallsForBatch.push(completedCall); + logToolCall(this.config, new ToolCallEvent(completedCall)); + + // Clear the active tool slot. This is crucial for the sequential processing. this.toolCalls = []; + } - for (const call of completedCalls) { - logToolCall(this.config, new ToolCallEvent(call)); + // Now, check if the entire batch is complete. + // The batch is complete if the queue is empty or the operation was cancelled. + if (this.toolCallQueue.length === 0 || signal.aborted) { + if (signal.aborted) { + this._cancelAllQueuedCalls(); + } + + // If there's nothing to report and we weren't cancelled, we can stop. + // But if we were cancelled, we must proceed to potentially start the next queued request. + if (this.completedToolCallsForBatch.length === 0 && !signal.aborted) { + return; } if (this.onAllToolCallsComplete) { this.isFinalizingToolCalls = true; - await this.onAllToolCallsComplete(completedCalls); + // Use the batch array, not the (now empty) active array. + await this.onAllToolCallsComplete(this.completedToolCallsForBatch); + this.completedToolCallsForBatch = []; // Clear after reporting. 
this.isFinalizingToolCalls = false; } + this.isCancelling = false; this.notifyToolCallsUpdate(); - // After completion, process the next item in the queue. + + // After completion of the entire batch, process the next item in the main request queue. if (this.requestQueue.length > 0) { const next = this.requestQueue.shift()!; this._schedule(next.request, next.signal) .then(next.resolve) .catch(next.reject); } + } else { + // The batch is not yet complete, so continue processing the current batch sequence. + await this._processNextInQueue(signal); + } + } + + private _cancelAllQueuedCalls(): void { + while (this.toolCallQueue.length > 0) { + const queuedCall = this.toolCallQueue.shift()!; + // Don't cancel tools that already errored during validation. + if (queuedCall.status === 'error') { + this.completedToolCallsForBatch.push(queuedCall); + continue; + } + const durationMs = + 'startTime' in queuedCall && queuedCall.startTime + ? Date.now() - queuedCall.startTime + : undefined; + const errorMessage = + '[Operation Cancelled] User cancelled the operation.'; + this.completedToolCallsForBatch.push({ + request: queuedCall.request, + tool: queuedCall.tool, + invocation: queuedCall.invocation, + status: 'cancelled', + response: { + callId: queuedCall.request.callId, + responseParts: [ + { + functionResponse: { + id: queuedCall.request.callId, + name: queuedCall.request.name, + response: { + error: errorMessage, + }, + }, + }, + ], + resultDisplay: undefined, + error: undefined, + errorType: undefined, + contentLength: errorMessage.length, + }, + durationMs, + outcome: ToolConfirmationOutcome.Cancel, + }); } } private notifyToolCallsUpdate(): void { if (this.onToolCallsUpdate) { - this.onToolCallsUpdate([...this.toolCalls]); + this.onToolCallsUpdate([ + ...this.completedToolCallsForBatch, + ...this.toolCalls, + ...this.toolCallQueue, + ]); } } @@ -1215,35 +1366,4 @@ export class CoreToolScheduler { return doesToolInvocationMatch(tool, invocation, allowedTools); } - - 
private async autoApproveCompatiblePendingTools( - signal: AbortSignal, - triggeringCallId: string, - ): Promise { - const pendingTools = this.toolCalls.filter( - (call) => - call.status === 'awaiting_approval' && - call.request.callId !== triggeringCallId, - ) as WaitingToolCall[]; - - for (const pendingTool of pendingTools) { - try { - const stillNeedsConfirmation = - await pendingTool.invocation.shouldConfirmExecute(signal); - - if (!stillNeedsConfirmation) { - this.setToolCallOutcome( - pendingTool.request.callId, - ToolConfirmationOutcome.ProceedAlways, - ); - this.setStatusInternal(pendingTool.request.callId, 'scheduled'); - } - } catch (error) { - console.error( - `Error checking confirmation for tool ${pendingTool.request.callId}:`, - error, - ); - } - } - } }