feat(core, cli): Implement sequential approval. (#11593)

This commit is contained in:
joshualitt
2025-10-27 09:59:08 -07:00
committed by GitHub
parent 23c906b085
commit 541eeb7a50
9 changed files with 1272 additions and 339 deletions
+120 -1
View File
@@ -4,11 +4,12 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { Task } from './task.js';
import type { Config, ToolCallRequestInfo } from '@google/gemini-cli-core';
import { createMockConfig } from '../utils/testing_utils.js';
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
import type { ToolCall } from '@google/gemini-cli-core';
describe('Task', () => {
it('scheduleToolCalls should not modify the input requests array', async () => {
@@ -94,4 +95,122 @@ describe('Task', () => {
);
});
});
describe('_schedulerToolCallsUpdate', () => {
let task: Task;
type SpyInstance = ReturnType<typeof vi.spyOn>;
let setTaskStateAndPublishUpdateSpy: SpyInstance;
beforeEach(() => {
const mockConfig = createMockConfig();
const mockEventBus: ExecutionEventBus = {
publish: vi.fn(),
on: vi.fn(),
off: vi.fn(),
once: vi.fn(),
removeAllListeners: vi.fn(),
finished: vi.fn(),
};
// @ts-expect-error - Calling private constructor
task = new Task(
'task-id',
'context-id',
mockConfig as Config,
mockEventBus,
);
// Spy on the method we want to check calls for
setTaskStateAndPublishUpdateSpy = vi.spyOn(
task,
'setTaskStateAndPublishUpdate',
);
});
afterEach(() => {
vi.restoreAllMocks();
});
it('should set state to input-required when a tool is awaiting approval and none are executing', () => {
const toolCalls = [
{ request: { callId: '1' }, status: 'awaiting_approval' },
] as ToolCall[];
// @ts-expect-error - Calling private method
task._schedulerToolCallsUpdate(toolCalls);
// The last call should be the final state update
expect(setTaskStateAndPublishUpdateSpy).toHaveBeenLastCalledWith(
'input-required',
{ kind: 'state-change' },
undefined,
undefined,
true, // final: true
);
});
it('should NOT set state to input-required if a tool is awaiting approval but another is executing', () => {
const toolCalls = [
{ request: { callId: '1' }, status: 'awaiting_approval' },
{ request: { callId: '2' }, status: 'executing' },
] as ToolCall[];
// @ts-expect-error - Calling private method
task._schedulerToolCallsUpdate(toolCalls);
// It will be called for status updates, but not with final: true
const finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find(
(call) => call[4] === true,
);
expect(finalCall).toBeUndefined();
});
it('should set state to input-required once an executing tool finishes, leaving one awaiting approval', () => {
const initialToolCalls = [
{ request: { callId: '1' }, status: 'awaiting_approval' },
{ request: { callId: '2' }, status: 'executing' },
] as ToolCall[];
// @ts-expect-error - Calling private method
task._schedulerToolCallsUpdate(initialToolCalls);
// No final call yet
let finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find(
(call) => call[4] === true,
);
expect(finalCall).toBeUndefined();
// Now, the executing tool finishes. The scheduler would call _resolveToolCall for it.
// @ts-expect-error - Calling private method
task._resolveToolCall('2');
// Then another update comes in for the awaiting tool (e.g., a re-check)
const subsequentToolCalls = [
{ request: { callId: '1' }, status: 'awaiting_approval' },
] as ToolCall[];
// @ts-expect-error - Calling private method
task._schedulerToolCallsUpdate(subsequentToolCalls);
// NOW we should get the final call
finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find(
(call) => call[4] === true,
);
expect(finalCall).toBeDefined();
expect(finalCall?.[0]).toBe('input-required');
});
it('should NOT set state to input-required if skipFinalTrueAfterInlineEdit is true', () => {
task.skipFinalTrueAfterInlineEdit = true;
const toolCalls = [
{ request: { callId: '1' }, status: 'awaiting_approval' },
] as ToolCall[];
// @ts-expect-error - Calling private method
task._schedulerToolCallsUpdate(toolCalls);
const finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find(
(call) => call[4] === true,
);
expect(finalCall).toBeUndefined();
});
});
});
+7 -12
View File
@@ -40,7 +40,6 @@ import type {
import { v4 as uuidv4 } from 'uuid';
import { logger } from '../utils/logger.js';
import * as fs from 'node:fs';
import { CoderAgentEvent } from '../types.js';
import type {
CoderAgentMessage,
@@ -373,11 +372,11 @@ export class Task {
// Only send an update if the status has actually changed.
if (hasChanged) {
const message = this.toolStatusMessage(tc, this.id, this.contextId);
const coderAgentMessage: CoderAgentMessage =
tc.status === 'awaiting_approval'
? { kind: CoderAgentEvent.ToolCallConfirmationEvent }
: { kind: CoderAgentEvent.ToolCallUpdateEvent };
const message = this.toolStatusMessage(tc, this.id, this.contextId);
const event = this._createStatusUpdateEvent(
this.taskState,
@@ -404,20 +403,16 @@ export class Task {
const isAwaitingApproval = allPendingStatuses.some(
(status) => status === 'awaiting_approval',
);
const allPendingAreStable = allPendingStatuses.every(
(status) =>
status === 'awaiting_approval' ||
status === 'success' ||
status === 'error' ||
status === 'cancelled',
const isExecuting = allPendingStatuses.some(
(status) => status === 'executing',
);
// 1. Are any pending tool calls awaiting_approval
// 2. Are all pending tool calls in a stable state (i.e. not in validing or executing)
// 3. After an inline edit, the edited tool call will send awaiting_approval THEN scheduled. We wait for the next update in this case.
// The turn is complete and requires user input if at least one tool
// is waiting for the user's decision, and no other tool is actively
// running in the background.
if (
isAwaitingApproval &&
allPendingAreStable &&
!isExecuting &&
!this.skipFinalTrueAfterInlineEdit
) {
this.skipFinalTrueAfterInlineEdit = false;
+195 -24
View File
@@ -313,7 +313,7 @@ describe('E2E Tests', () => {
expect(workingEvent.kind).toBe('status-update');
expect(workingEvent.status.state).toBe('working');
// State Update: Validate each tool call
// State Update: Validate the first tool call
const toolCallValidateEvent1 = events[3].result as TaskStatusUpdateEvent;
expect(toolCallValidateEvent1.metadata?.['coderAgent']).toMatchObject({
kind: 'tool-call-update',
@@ -326,47 +326,218 @@ describe('E2E Tests', () => {
},
},
]);
const toolCallValidateEvent2 = events[4].result as TaskStatusUpdateEvent;
expect(toolCallValidateEvent2.metadata?.['coderAgent']).toMatchObject({
// --- Assert the event stream ---
// 1. Initial "submitted" status.
expect((events[0].result as TaskStatusUpdateEvent).status.state).toBe(
'submitted',
);
// 2. "working" status after receiving the user prompt.
expect((events[1].result as TaskStatusUpdateEvent).status.state).toBe(
'working',
);
// 3. A "state-change" event from the agent.
expect(events[2].result.metadata?.['coderAgent']).toMatchObject({
kind: 'state-change',
});
// 4. Tool 1 is validating.
const toolCallUpdate1 = events[3].result as TaskStatusUpdateEvent;
expect(toolCallUpdate1.metadata?.['coderAgent']).toMatchObject({
kind: 'tool-call-update',
});
expect(toolCallValidateEvent2.status.message?.parts).toMatchObject([
expect(toolCallUpdate1.status.message?.parts).toMatchObject([
{
data: {
request: { callId: 'test-call-id-1' },
status: 'validating',
request: { callId: 'test-call-id-2' },
},
},
]);
// State Update: Set each tool call to awaiting
const toolCallAwaitEvent1 = events[5].result as TaskStatusUpdateEvent;
expect(toolCallAwaitEvent1.metadata?.['coderAgent']).toMatchObject({
kind: 'tool-call-confirmation',
// 5. Tool 2 is validating.
const toolCallUpdate2 = events[4].result as TaskStatusUpdateEvent;
expect(toolCallUpdate2.metadata?.['coderAgent']).toMatchObject({
kind: 'tool-call-update',
});
expect(toolCallAwaitEvent1.status.message?.parts).toMatchObject([
expect(toolCallUpdate2.status.message?.parts).toMatchObject([
{
data: {
status: 'awaiting_approval',
request: { callId: 'test-call-id-1' },
},
},
]);
const toolCallAwaitEvent2 = events[6].result as TaskStatusUpdateEvent;
expect(toolCallAwaitEvent2.metadata?.['coderAgent']).toMatchObject({
kind: 'tool-call-confirmation',
});
expect(toolCallAwaitEvent2.status.message?.parts).toMatchObject([
{
data: {
status: 'awaiting_approval',
request: { callId: 'test-call-id-2' },
status: 'validating',
},
},
]);
// 6. Tool 1 is awaiting approval.
const toolCallAwaitEvent = events[5].result as TaskStatusUpdateEvent;
expect(toolCallAwaitEvent.metadata?.['coderAgent']).toMatchObject({
kind: 'tool-call-confirmation',
});
expect(toolCallAwaitEvent.status.message?.parts).toMatchObject([
{
data: {
request: { callId: 'test-call-id-1' },
status: 'awaiting_approval',
},
},
]);
// 7. The final event is "input-required".
const finalEvent = events[6].result as TaskStatusUpdateEvent;
expect(finalEvent.final).toBe(true);
expect(finalEvent.status.state).toBe('input-required');
// The scheduler now waits for approval, so no more events are sent.
assertUniqueFinalEventIsLast(events);
expect(events.length).toBe(7);
});
it('should handle multiple tool calls sequentially in YOLO mode', async () => {
// Set YOLO mode to auto-approve tools and test sequential execution.
getApprovalModeSpy.mockReturnValue(ApprovalMode.YOLO);
// First call yields the tool request
sendMessageStreamSpy.mockImplementationOnce(async function* () {
yield* [
{
type: GeminiEventType.ToolCallRequest,
value: {
callId: 'test-call-id-1',
name: 'test-tool-1',
args: {},
},
},
{
type: GeminiEventType.ToolCallRequest,
value: {
callId: 'test-call-id-2',
name: 'test-tool-2',
args: {},
},
},
];
});
// Subsequent calls yield nothing, as the tools will "succeed".
sendMessageStreamSpy.mockImplementation(async function* () {
yield* [{ type: 'content', value: 'All tools executed.' }];
});
const mockTool1 = new MockTool({
name: 'test-tool-1',
displayName: 'Test Tool 1',
shouldConfirmExecute: vi.fn(mockToolConfirmationFn),
execute: vi
.fn()
.mockResolvedValue({ llmContent: 'tool 1 done', returnDisplay: '' }),
});
const mockTool2 = new MockTool({
name: 'test-tool-2',
displayName: 'Test Tool 2',
shouldConfirmExecute: vi.fn(mockToolConfirmationFn),
execute: vi
.fn()
.mockResolvedValue({ llmContent: 'tool 2 done', returnDisplay: '' }),
});
getToolRegistrySpy.mockReturnValue({
getAllTools: vi.fn().mockReturnValue([mockTool1, mockTool2]),
getToolsByServer: vi.fn().mockReturnValue([]),
getTool: vi.fn().mockImplementation((name: string) => {
if (name === 'test-tool-1') return mockTool1;
if (name === 'test-tool-2') return mockTool2;
return undefined;
}),
});
const agent = request.agent(app);
const res = await agent
.post('/')
.send(
createStreamMessageRequest(
'run two tools',
'a2a-multi-tool-test-message',
),
)
.set('Content-Type', 'application/json')
.expect(200);
const events = streamToSSEEvents(res.text);
assertTaskCreationAndWorkingStatus(events);
// --- Assert the sequential execution flow ---
const eventStream = events.slice(2).map((e) => {
const update = e.result as TaskStatusUpdateEvent;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const agentData = update.metadata?.['coderAgent'] as any;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const toolData = update.status.message?.parts[0] as any;
if (!toolData) {
return { kind: agentData.kind };
}
return {
kind: agentData.kind,
status: toolData.data?.status,
callId: toolData.data?.request.callId,
};
});
const expectedFlow = [
// Initial state change
{ kind: 'state-change', status: undefined, callId: undefined },
// Tool 1 Lifecycle
{
kind: 'tool-call-update',
status: 'validating',
callId: 'test-call-id-1',
},
{
kind: 'tool-call-update',
status: 'scheduled',
callId: 'test-call-id-1',
},
{
kind: 'tool-call-update',
status: 'executing',
callId: 'test-call-id-1',
},
{
kind: 'tool-call-update',
status: 'success',
callId: 'test-call-id-1',
},
// Tool 2 Lifecycle
{
kind: 'tool-call-update',
status: 'validating',
callId: 'test-call-id-2',
},
{
kind: 'tool-call-update',
status: 'scheduled',
callId: 'test-call-id-2',
},
{
kind: 'tool-call-update',
status: 'executing',
callId: 'test-call-id-2',
},
{
kind: 'tool-call-update',
status: 'success',
callId: 'test-call-id-2',
},
// Final updates
{ kind: 'state-change', status: undefined, callId: undefined },
{ kind: 'text-content', status: undefined, callId: undefined },
];
// Use `toContainEqual` for flexibility if other events are interspersed.
expect(eventStream).toEqual(expect.arrayContaining(expectedFlow));
assertUniqueFinalEventIsLast(events);
expect(events.length).toBe(8);
});
it('should handle tool calls that do not require approval', async () => {
@@ -37,7 +37,7 @@ import {
} from '@google/gemini-cli-core';
import type { Part, PartListUnion } from '@google/genai';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
import type { HistoryItem, SlashCommandProcessorResult } from '../types.js';
import type { SlashCommandProcessorResult } from '../types.js';
import { MessageType, StreamingState } from '../types.js';
import type { LoadedSettings } from '../../config/settings.js';
@@ -231,8 +231,9 @@ describe('useGeminiStream', () => {
mockUseReactToolScheduler.mockReturnValue([
[], // Default to empty array for toolCalls
mockScheduleToolCalls,
mockCancelAllToolCalls,
mockMarkToolsAsSubmitted,
vi.fn(), // setToolCallsForDisplay
mockCancelAllToolCalls,
]);
// Reset mocks for GeminiClient instance methods (startChat and sendMessageStream)
@@ -259,38 +260,71 @@ describe('useGeminiStream', () => {
initialToolCalls: TrackedToolCall[] = [],
geminiClient?: any,
) => {
let currentToolCalls = initialToolCalls;
const setToolCalls = (newToolCalls: TrackedToolCall[]) => {
currentToolCalls = newToolCalls;
};
mockUseReactToolScheduler.mockImplementation(() => [
currentToolCalls,
mockScheduleToolCalls,
mockCancelAllToolCalls,
mockMarkToolsAsSubmitted,
]);
const client = geminiClient || mockConfig.getGeminiClient();
const initialProps = {
client,
history: [],
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand: mockHandleSlashCommand as unknown as (
cmd: PartListUnion,
) => Promise<SlashCommandProcessorResult | false>,
shellModeActive: false,
loadedSettings: mockLoadedSettings,
toolCalls: initialToolCalls,
};
const { result, rerender } = renderHook(
(props: {
client: any;
history: HistoryItem[];
addItem: UseHistoryManagerReturn['addItem'];
config: Config;
onDebugMessage: (message: string) => void;
handleSlashCommand: (
cmd: PartListUnion,
) => Promise<SlashCommandProcessorResult | false>;
shellModeActive: boolean;
loadedSettings: LoadedSettings;
toolCalls?: TrackedToolCall[]; // Allow passing updated toolCalls
}) => {
// Update the mock's return value if new toolCalls are passed in props
if (props.toolCalls) {
setToolCalls(props.toolCalls);
}
(props: typeof initialProps) => {
// This mock needs to be stateful. When setToolCallsForDisplay is called,
// it should trigger a rerender with the new state.
const mockSetToolCallsForDisplay = vi.fn((updater) => {
const newToolCalls =
typeof updater === 'function' ? updater(props.toolCalls) : updater;
rerender({ ...props, toolCalls: newToolCalls });
});
// Create a stateful mock for cancellation that updates the toolCalls state.
const statefulCancelAllToolCalls = vi.fn((...args) => {
// Call the original spy so `toHaveBeenCalled` checks still work.
mockCancelAllToolCalls(...args);
const newToolCalls = props.toolCalls.map((tc) => {
// Only cancel tools that are in a cancellable state.
if (
tc.status === 'awaiting_approval' ||
tc.status === 'executing' ||
tc.status === 'scheduled' ||
tc.status === 'validating'
) {
// A real cancelled tool call has a response object.
// We need to simulate this to avoid type errors downstream.
return {
...tc,
status: 'cancelled',
response: {
callId: tc.request.callId,
responseParts: [],
resultDisplay: 'Request cancelled.',
},
responseSubmittedToGemini: true, // Mark as "processed"
} as any as TrackedCancelledToolCall;
}
return tc;
});
rerender({ ...props, toolCalls: newToolCalls });
});
mockUseReactToolScheduler.mockImplementation(() => [
props.toolCalls,
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
mockSetToolCallsForDisplay,
statefulCancelAllToolCalls, // Use the stateful mock
]);
return useGeminiStream(
props.client,
props.history,
@@ -313,19 +347,7 @@ describe('useGeminiStream', () => {
);
},
{
initialProps: {
client,
history: [],
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand: mockHandleSlashCommand as unknown as (
cmd: PartListUnion,
) => Promise<SlashCommandProcessorResult | false>,
shellModeActive: false,
loadedSettings: mockLoadedSettings,
toolCalls: initialToolCalls,
},
initialProps,
},
);
return {
@@ -452,7 +474,7 @@ describe('useGeminiStream', () => {
mockUseReactToolScheduler.mockImplementation((onComplete) => {
capturedOnComplete = onComplete;
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
});
renderHook(() =>
@@ -535,7 +557,7 @@ describe('useGeminiStream', () => {
mockUseReactToolScheduler.mockImplementation((onComplete) => {
capturedOnComplete = onComplete;
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
});
renderHook(() =>
@@ -647,7 +669,7 @@ describe('useGeminiStream', () => {
mockUseReactToolScheduler.mockImplementation((onComplete) => {
capturedOnComplete = onComplete;
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
});
renderHook(() =>
@@ -760,6 +782,7 @@ describe('useGeminiStream', () => {
currentToolCalls,
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
vi.fn(), // setToolCallsForDisplay
];
});
@@ -797,6 +820,7 @@ describe('useGeminiStream', () => {
completedToolCalls,
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
vi.fn(), // setToolCallsForDisplay
];
});
@@ -1031,7 +1055,7 @@ describe('useGeminiStream', () => {
expect(result.current.streamingState).toBe(StreamingState.Idle);
});
it('should not cancel if a tool call is in progress (not just responding)', async () => {
it('should cancel if a tool call is in progress', async () => {
const toolCalls: TrackedToolCall[] = [
{
request: { callId: 'call1', name: 'tool1', args: {} },
@@ -1052,7 +1076,6 @@ describe('useGeminiStream', () => {
} as TrackedExecutingToolCall,
];
const abortSpy = vi.spyOn(AbortController.prototype, 'abort');
const { result } = renderTestHook(toolCalls);
// State is `Responding` because a tool is running
@@ -1061,8 +1084,71 @@ describe('useGeminiStream', () => {
// Try to cancel
simulateEscapeKeyPress();
// Nothing should happen because the state is not `Responding`
expect(abortSpy).not.toHaveBeenCalled();
// The cancel function should be called
expect(mockCancelAllToolCalls).toHaveBeenCalled();
});
it('should cancel a request when a tool is awaiting confirmation', async () => {
const mockOnConfirm = vi.fn().mockResolvedValue(undefined);
const toolCalls: TrackedToolCall[] = [
{
request: {
callId: 'confirm-call',
name: 'some_tool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-1',
},
status: 'awaiting_approval',
responseSubmittedToGemini: false,
tool: {
name: 'some_tool',
description: 'a tool',
build: vi.fn().mockImplementation((_) => ({
getDescription: () => `Mock description`,
})),
} as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
confirmationDetails: {
type: 'edit',
title: 'Confirm Edit',
onConfirm: mockOnConfirm,
fileName: 'file.txt',
filePath: '/test/file.txt',
fileDiff: 'fake diff',
originalContent: 'old',
newContent: 'new',
},
} as TrackedWaitingToolCall,
];
const { result } = renderTestHook(toolCalls);
// State is `WaitingForConfirmation` because a tool is awaiting approval
expect(result.current.streamingState).toBe(
StreamingState.WaitingForConfirmation,
);
// Try to cancel
simulateEscapeKeyPress();
// The imperative cancel function should be called on the scheduler
expect(mockCancelAllToolCalls).toHaveBeenCalled();
// A cancellation message should be added to history
await waitFor(() => {
expect(mockAddItem).toHaveBeenCalledWith(
expect.objectContaining({
text: 'Request cancelled.',
}),
expect.any(Number),
);
});
// The final state should be idle
expect(result.current.streamingState).toBe(StreamingState.Idle);
});
});
@@ -1282,7 +1368,7 @@ describe('useGeminiStream', () => {
mockUseReactToolScheduler.mockImplementation((onComplete) => {
capturedOnComplete = onComplete;
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
});
renderHook(() =>
+110 -55
View File
@@ -111,6 +111,7 @@ export const useGeminiStream = (
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const turnCancelledRef = useRef(false);
const activeQueryIdRef = useRef<string | null>(null);
const [isResponding, setIsResponding] = useState<boolean>(false);
const [thought, setThought] = useState<ThoughtSummary | null>(null);
const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] =
@@ -126,47 +127,55 @@ export const useGeminiStream = (
return new GitService(config.getProjectRoot(), storage);
}, [config, storage]);
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
useReactToolScheduler(
async (completedToolCallsFromScheduler) => {
// This onComplete is called when ALL scheduled tools for a given batch are done.
if (completedToolCallsFromScheduler.length > 0) {
// Add the final state of these tools to the history for display.
addItem(
mapTrackedToolCallsToDisplay(
completedToolCallsFromScheduler as TrackedToolCall[],
),
Date.now(),
);
// Record tool calls with full metadata before sending responses.
try {
const currentModel =
config.getGeminiClient().getCurrentSequenceModel() ??
config.getModel();
config
.getGeminiClient()
.getChat()
.recordCompletedToolCalls(
currentModel,
completedToolCallsFromScheduler,
);
} catch (error) {
console.error(
`Error recording completed tool call information: ${error}`,
);
}
// Handle tool response submission immediately when tools complete
await handleCompletedTools(
const [
toolCalls,
scheduleToolCalls,
markToolsAsSubmitted,
setToolCallsForDisplay,
cancelAllToolCalls,
] = useReactToolScheduler(
async (completedToolCallsFromScheduler) => {
// This onComplete is called when ALL scheduled tools for a given batch are done.
if (completedToolCallsFromScheduler.length > 0) {
// Add the final state of these tools to the history for display.
addItem(
mapTrackedToolCallsToDisplay(
completedToolCallsFromScheduler as TrackedToolCall[],
),
Date.now(),
);
// Clear the live-updating display now that the final state is in history.
setToolCallsForDisplay([]);
// Record tool calls with full metadata before sending responses.
try {
const currentModel =
config.getGeminiClient().getCurrentSequenceModel() ??
config.getModel();
config
.getGeminiClient()
.getChat()
.recordCompletedToolCalls(
currentModel,
completedToolCallsFromScheduler,
);
} catch (error) {
console.error(
`Error recording completed tool call information: ${error}`,
);
}
},
config,
getPreferredEditor,
onEditorClose,
);
// Handle tool response submission immediately when tools complete
await handleCompletedTools(
completedToolCallsFromScheduler as TrackedToolCall[],
);
}
},
config,
getPreferredEditor,
onEditorClose,
);
const pendingToolCallGroupDisplay = useMemo(
() =>
@@ -265,27 +274,54 @@ export const useGeminiStream = (
}, [streamingState, config, history]);
const cancelOngoingRequest = useCallback(() => {
if (streamingState !== StreamingState.Responding) {
if (
streamingState !== StreamingState.Responding &&
streamingState !== StreamingState.WaitingForConfirmation
) {
return;
}
if (turnCancelledRef.current) {
return;
}
turnCancelledRef.current = true;
abortControllerRef.current?.abort();
// A full cancellation means no tools have produced a final result yet.
// This determines if we show a generic "Request cancelled" message.
const isFullCancellation = !toolCalls.some(
(tc) => tc.status === 'success' || tc.status === 'error',
);
// Ensure we have an abort controller, creating one if it doesn't exist.
if (!abortControllerRef.current) {
abortControllerRef.current = new AbortController();
}
// The order is important here.
// 1. Fire the signal to interrupt any active async operations.
abortControllerRef.current.abort();
// 2. Call the imperative cancel to clear the queue of pending tools.
cancelAllToolCalls(abortControllerRef.current.signal);
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, Date.now());
}
addItem(
{
type: MessageType.INFO,
text: 'Request cancelled.',
},
Date.now(),
);
setPendingHistoryItem(null);
// If it was a full cancellation, add the info message now.
// Otherwise, we let handleCompletedTools figure out the next step,
// which might involve sending partial results back to the model.
if (isFullCancellation) {
addItem(
{
type: MessageType.INFO,
text: 'Request cancelled.',
},
Date.now(),
);
setIsResponding(false);
}
onCancelSubmit();
setIsResponding(false);
setShellInputFocused(false);
}, [
streamingState,
@@ -294,6 +330,8 @@ export const useGeminiStream = (
onCancelSubmit,
pendingHistoryItemRef,
setShellInputFocused,
cancelAllToolCalls,
toolCalls,
]);
useKeypress(
@@ -302,7 +340,11 @@ export const useGeminiStream = (
cancelOngoingRequest();
}
},
{ isActive: streamingState === StreamingState.Responding },
{
isActive:
streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation,
},
);
const prepareQueryForGemini = useCallback(
@@ -764,6 +806,8 @@ export const useGeminiStream = (
options?: { isContinuation: boolean },
prompt_id?: string,
) => {
const queryId = `${Date.now()}-${Math.random()}`;
activeQueryIdRef.current = queryId;
if (
(streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation) &&
@@ -901,7 +945,9 @@ export const useGeminiStream = (
);
}
} finally {
setIsResponding(false);
if (activeQueryIdRef.current === queryId) {
setIsResponding(false);
}
}
});
},
@@ -963,10 +1009,6 @@ export const useGeminiStream = (
const handleCompletedTools = useCallback(
async (completedToolCallsFromScheduler: TrackedToolCall[]) => {
if (isResponding) {
return;
}
const completedAndReadyToSubmitTools =
completedToolCallsFromScheduler.filter(
(
@@ -1028,6 +1070,19 @@ export const useGeminiStream = (
);
if (allToolsCancelled) {
// If the turn was cancelled via the imperative escape key flow,
// the cancellation message is added there. We check the ref to avoid duplication.
if (!turnCancelledRef.current) {
addItem(
{
type: MessageType.INFO,
text: 'Request cancelled.',
},
Date.now(),
);
}
setIsResponding(false);
if (geminiClient) {
// We need to manually add the function responses to the history
// so the model knows the tools were cancelled.
@@ -1074,12 +1129,12 @@ export const useGeminiStream = (
);
},
[
isResponding,
submitQuery,
markToolsAsSubmitted,
geminiClient,
performMemoryRefresh,
modelSwitchedFromQuotaError,
addItem,
],
);
@@ -62,12 +62,20 @@ export type TrackedToolCall =
| TrackedCompletedToolCall
| TrackedCancelledToolCall;
export type CancelAllFn = (signal: AbortSignal) => void;
export function useReactToolScheduler(
onComplete: (tools: CompletedToolCall[]) => Promise<void>,
config: Config,
getPreferredEditor: () => EditorType | undefined,
onEditorClose: () => void,
): [TrackedToolCall[], ScheduleFn, MarkToolsAsSubmittedFn] {
): [
TrackedToolCall[],
ScheduleFn,
MarkToolsAsSubmittedFn,
React.Dispatch<React.SetStateAction<TrackedToolCall[]>>,
CancelAllFn,
] {
const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
TrackedToolCall[]
>([]);
@@ -112,37 +120,36 @@ export function useReactToolScheduler(
);
const toolCallsUpdateHandler: ToolCallsUpdateHandler = useCallback(
(updatedCoreToolCalls: ToolCall[]) => {
setToolCallsForDisplay((prevTrackedCalls) =>
updatedCoreToolCalls.map((coreTc) => {
const existingTrackedCall = prevTrackedCalls.find(
(ptc) => ptc.request.callId === coreTc.request.callId,
);
// Start with the new core state, then layer on the existing UI state
// to ensure UI-only properties like pid are preserved.
(allCoreToolCalls: ToolCall[]) => {
setToolCallsForDisplay((prevTrackedCalls) => {
const prevCallsMap = new Map(
prevTrackedCalls.map((c) => [c.request.callId, c]),
);
return allCoreToolCalls.map((coreTc): TrackedToolCall => {
const existingTrackedCall = prevCallsMap.get(coreTc.request.callId);
const responseSubmittedToGemini =
existingTrackedCall?.responseSubmittedToGemini ?? false;
if (coreTc.status === 'executing') {
// Preserve live output if it exists from a previous render.
const liveOutput = (existingTrackedCall as TrackedExecutingToolCall)
?.liveOutput;
return {
...coreTc,
responseSubmittedToGemini,
liveOutput: (existingTrackedCall as TrackedExecutingToolCall)
?.liveOutput,
liveOutput,
pid: (coreTc as ExecutingToolCall).pid,
};
} else {
return {
...coreTc,
responseSubmittedToGemini,
};
}
// For other statuses, explicitly set liveOutput and pid to undefined
// to ensure they are not carried over from a previous executing state.
return {
...coreTc,
responseSubmittedToGemini,
liveOutput: undefined,
pid: undefined,
};
}),
);
});
});
},
[setToolCallsForDisplay],
);
@@ -178,9 +185,10 @@ export function useReactToolScheduler(
request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal,
) => {
setToolCallsForDisplay([]);
void scheduler.schedule(request, signal);
},
[scheduler],
[scheduler, setToolCallsForDisplay],
);
const markToolsAsSubmitted: MarkToolsAsSubmittedFn = useCallback(
@@ -196,7 +204,20 @@ export function useReactToolScheduler(
[],
);
return [toolCallsForDisplay, schedule, markToolsAsSubmitted];
const cancelAllToolCalls = useCallback(
(signal: AbortSignal) => {
scheduler.cancelAll(signal);
},
[scheduler],
);
return [
toolCallsForDisplay,
schedule,
markToolsAsSubmitted,
setToolCallsForDisplay,
cancelAllToolCalls,
];
}
/**
@@ -260,9 +260,15 @@ describe('useReactToolScheduler', () => {
args: { param: 'value' },
} as any;
let completedToolCalls: ToolCall[] = [];
onComplete.mockImplementation((calls) => {
completedToolCalls = calls;
});
act(() => {
schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
});
@@ -292,7 +298,110 @@ describe('useReactToolScheduler', () => {
}),
}),
]);
expect(result.current[0]).toEqual([]);
expect(completedToolCalls).toHaveLength(1);
expect(completedToolCalls[0].status).toBe('success');
expect(completedToolCalls[0].request).toBe(request);
});
it('should clear previous tool calls when scheduling new ones', async () => {
mockToolRegistry.getTool.mockReturnValue(mockTool);
(mockTool.execute as Mock).mockResolvedValue({
llmContent: 'Tool output',
returnDisplay: 'Formatted tool output',
} as ToolResult);
const { result } = renderScheduler();
const schedule = result.current[1];
const setToolCallsForDisplay = result.current[3];
// Manually set a tool call in the display.
const oldToolCall = {
request: { callId: 'oldCall' },
status: 'success',
} as any;
act(() => {
setToolCallsForDisplay([oldToolCall]);
});
expect(result.current[0]).toEqual([oldToolCall]);
const newRequest: ToolCallRequestInfo = {
callId: 'newCall',
name: 'mockTool',
args: {},
} as any;
act(() => {
schedule(newRequest, new AbortController().signal);
});
// After scheduling, the old call should be gone,
// and the new one should be in the display in its initial state.
expect(result.current[0].length).toBe(1);
expect(result.current[0][0].request.callId).toBe('newCall');
expect(result.current[0][0].request.callId).not.toBe('oldCall');
// Let the new call finish.
await act(async () => {
await vi.runAllTimersAsync();
});
await act(async () => {
await vi.runAllTimersAsync();
});
await act(async () => {
await vi.runAllTimersAsync();
});
expect(onComplete).toHaveBeenCalled();
});
it('should cancel all running tool calls', async () => {
mockToolRegistry.getTool.mockReturnValue(mockTool);
let resolveExecute: (value: ToolResult) => void = () => {};
const executePromise = new Promise<ToolResult>((resolve) => {
resolveExecute = resolve;
});
(mockTool.execute as Mock).mockReturnValue(executePromise);
(mockTool.shouldConfirmExecute as Mock).mockResolvedValue(null);
const { result } = renderScheduler();
const schedule = result.current[1];
const cancelAllToolCalls = result.current[4];
const request: ToolCallRequestInfo = {
callId: 'cancelCall',
name: 'mockTool',
args: {},
} as any;
act(() => {
schedule(request, new AbortController().signal);
});
await act(async () => {
await vi.runAllTimersAsync();
}); // validation
await act(async () => {
await vi.runAllTimersAsync();
}); // scheduling
// At this point, the tool is 'executing' and waiting on the promise.
expect(result.current[0][0].status).toBe('executing');
const cancelController = new AbortController();
act(() => {
cancelAllToolCalls(cancelController.signal);
});
await act(async () => {
await vi.runAllTimersAsync();
});
expect(onComplete).toHaveBeenCalledWith([
expect.objectContaining({
status: 'cancelled',
request,
}),
]);
// Clean up the pending promise to avoid open handles.
resolveExecute({ llmContent: 'output', returnDisplay: 'display' });
});
it('should handle tool not found', async () => {
@@ -305,6 +414,11 @@ describe('useReactToolScheduler', () => {
args: {},
} as any;
let completedToolCalls: ToolCall[] = [];
onComplete.mockImplementation((calls) => {
completedToolCalls = calls;
});
act(() => {
schedule(request, new AbortController().signal);
});
@@ -315,24 +429,15 @@ describe('useReactToolScheduler', () => {
await vi.runAllTimersAsync();
});
expect(onComplete).toHaveBeenCalledWith([
expect.objectContaining({
status: 'error',
request,
response: expect.objectContaining({
error: expect.objectContaining({
message: expect.stringMatching(
/Tool "nonexistentTool" not found in registry/,
),
}),
}),
}),
]);
const errorMessage = onComplete.mock.calls[0][0][0].response.error.message;
expect(errorMessage).toContain('Did you mean one of:');
expect(errorMessage).toContain('"mockTool"');
expect(errorMessage).toContain('"anotherTool"');
expect(result.current[0]).toEqual([]);
expect(completedToolCalls).toHaveLength(1);
expect(completedToolCalls[0].status).toBe('error');
expect(completedToolCalls[0].request).toBe(request);
expect((completedToolCalls[0] as any).response.error.message).toContain(
'Tool "nonexistentTool" not found in registry',
);
expect((completedToolCalls[0] as any).response.error.message).toContain(
'Did you mean one of:',
);
});
it('should handle error during shouldConfirmExecute', async () => {
@@ -348,6 +453,11 @@ describe('useReactToolScheduler', () => {
args: {},
} as any;
let completedToolCalls: ToolCall[] = [];
onComplete.mockImplementation((calls) => {
completedToolCalls = calls;
});
act(() => {
schedule(request, new AbortController().signal);
});
@@ -358,16 +468,10 @@ describe('useReactToolScheduler', () => {
await vi.runAllTimersAsync();
});
expect(onComplete).toHaveBeenCalledWith([
expect.objectContaining({
status: 'error',
request,
response: expect.objectContaining({
error: confirmError,
}),
}),
]);
expect(result.current[0]).toEqual([]);
expect(completedToolCalls).toHaveLength(1);
expect(completedToolCalls[0].status).toBe('error');
expect(completedToolCalls[0].request).toBe(request);
expect((completedToolCalls[0] as any).response.error).toBe(confirmError);
});
it('should handle error during execute', async () => {
@@ -384,6 +488,11 @@ describe('useReactToolScheduler', () => {
args: {},
} as any;
let completedToolCalls: ToolCall[] = [];
onComplete.mockImplementation((calls) => {
completedToolCalls = calls;
});
act(() => {
schedule(request, new AbortController().signal);
});
@@ -397,16 +506,10 @@ describe('useReactToolScheduler', () => {
await vi.runAllTimersAsync();
});
expect(onComplete).toHaveBeenCalledWith([
expect.objectContaining({
status: 'error',
request,
response: expect.objectContaining({
error: execError,
}),
}),
]);
expect(result.current[0]).toEqual([]);
expect(completedToolCalls).toHaveLength(1);
expect(completedToolCalls[0].status).toBe('error');
expect(completedToolCalls[0].request).toBe(request);
expect((completedToolCalls[0] as any).response.error).toBe(execError);
});
it('should handle tool requiring confirmation - approved', async () => {
@@ -518,7 +621,7 @@ describe('useReactToolScheduler', () => {
functionResponse: expect.objectContaining({
response: expect.objectContaining({
error:
'[Operation Cancelled] Reason: User did not allow tool call',
'[Operation Cancelled] Reason: User cancelled the operation.',
}),
}),
}),
@@ -705,7 +808,9 @@ describe('useReactToolScheduler', () => {
],
}),
});
expect(result.current[0]).toEqual([]);
expect(completedCalls).toHaveLength(2);
expect(completedCalls.every((t) => t.status === 'success')).toBe(true);
});
it('should queue if scheduling while already running', async () => {
@@ -774,7 +879,8 @@ describe('useReactToolScheduler', () => {
response: expect.objectContaining({ resultDisplay: 'done display' }),
}),
]);
expect(result.current[0]).toEqual([]);
const toolCalls = result.current[0];
expect(toolCalls).toHaveLength(0);
});
});
+275 -15
View File
@@ -288,6 +288,263 @@ describe('CoreToolScheduler', () => {
expect(completedCalls[0].status).toBe('cancelled');
});
it('should cancel all tools when cancelAll is called', async () => {
const mockTool1 = new MockTool({
name: 'mockTool1',
shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE,
});
const mockTool2 = new MockTool({ name: 'mockTool2' });
const mockTool3 = new MockTool({ name: 'mockTool3' });
const mockToolRegistry = {
getTool: (name: string) => {
if (name === 'mockTool1') return mockTool1;
if (name === 'mockTool2') return mockTool2;
if (name === 'mockTool3') return mockTool3;
return undefined;
},
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {},
registerTool: () => {},
getToolByName: (name: string) => {
if (name === 'mockTool1') return mockTool1;
if (name === 'mockTool2') return mockTool2;
if (name === 'mockTool3') return mockTool3;
return undefined;
},
getToolByDisplayName: () => undefined,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
} as unknown as ToolRegistry;
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
const mockConfig = {
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.DEFAULT,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
terminalHeight: 30,
}),
storage: {
getProjectTempDir: () => '/tmp',
},
getTruncateToolOutputThreshold: () =>
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
getToolRegistry: () => mockToolRegistry,
getUseSmartEdit: () => false,
getUseModelRouter: () => false,
getGeminiClient: () => null, // No client needed for these tests
getEnableMessageBusIntegration: () => false,
getMessageBus: () => null,
getPolicyEngine: () => null,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();
const requests = [
{
callId: '1',
name: 'mockTool1',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-1',
},
{
callId: '2',
name: 'mockTool2',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-1',
},
{
callId: '3',
name: 'mockTool3',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-1',
},
];
// Don't await, let it run in the background
void scheduler.schedule(requests, abortController.signal);
// Wait for the first tool to be awaiting approval
await waitForStatus(onToolCallsUpdate, 'awaiting_approval');
// Cancel all operations
scheduler.cancelAll(abortController.signal);
abortController.abort(); // Also fire the signal
await vi.waitFor(() => {
expect(onAllToolCallsComplete).toHaveBeenCalled();
});
const completedCalls = onAllToolCallsComplete.mock
.calls[0][0] as ToolCall[];
expect(completedCalls).toHaveLength(3);
expect(completedCalls.find((c) => c.request.callId === '1')?.status).toBe(
'cancelled',
);
expect(completedCalls.find((c) => c.request.callId === '2')?.status).toBe(
'cancelled',
);
expect(completedCalls.find((c) => c.request.callId === '3')?.status).toBe(
'cancelled',
);
});
it('should cancel all tools in a batch when one is cancelled via confirmation', async () => {
const mockTool1 = new MockTool({
name: 'mockTool1',
shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE,
});
const mockTool2 = new MockTool({ name: 'mockTool2' });
const mockTool3 = new MockTool({ name: 'mockTool3' });
const mockToolRegistry = {
getTool: (name: string) => {
if (name === 'mockTool1') return mockTool1;
if (name === 'mockTool2') return mockTool2;
if (name === 'mockTool3') return mockTool3;
return undefined;
},
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {},
registerTool: () => {},
getToolByName: (name: string) => {
if (name === 'mockTool1') return mockTool1;
if (name === 'mockTool2') return mockTool2;
if (name === 'mockTool3') return mockTool3;
return undefined;
},
getToolByDisplayName: () => undefined,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
} as unknown as ToolRegistry;
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
const mockConfig = {
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.DEFAULT,
getAllowedTools: () => [],
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
}),
getShellExecutionConfig: () => ({
terminalWidth: 90,
terminalHeight: 30,
}),
storage: {
getProjectTempDir: () => '/tmp',
},
getTruncateToolOutputThreshold: () =>
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
getToolRegistry: () => mockToolRegistry,
getUseSmartEdit: () => false,
getUseModelRouter: () => false,
getGeminiClient: () => null, // No client needed for these tests
getEnableMessageBusIntegration: () => false,
getMessageBus: () => null,
getPolicyEngine: () => null,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();
const requests = [
{
callId: '1',
name: 'mockTool1',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-1',
},
{
callId: '2',
name: 'mockTool2',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-1',
},
{
callId: '3',
name: 'mockTool3',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-1',
},
];
// Don't await, let it run in the background
void scheduler.schedule(requests, abortController.signal);
// Wait for the first tool to be awaiting approval
const awaitingCall = (await waitForStatus(
onToolCallsUpdate,
'awaiting_approval',
)) as WaitingToolCall;
// Cancel the first tool via its confirmation handler
await awaitingCall.confirmationDetails.onConfirm(
ToolConfirmationOutcome.Cancel,
);
abortController.abort(); // User cancelling often involves an abort signal
await vi.waitFor(() => {
expect(onAllToolCallsComplete).toHaveBeenCalled();
});
const completedCalls = onAllToolCallsComplete.mock
.calls[0][0] as ToolCall[];
expect(completedCalls).toHaveLength(3);
expect(completedCalls.find((c) => c.request.callId === '1')?.status).toBe(
'cancelled',
);
expect(completedCalls.find((c) => c.request.callId === '2')?.status).toBe(
'cancelled',
);
expect(completedCalls.find((c) => c.request.callId === '3')?.status).toBe(
'cancelled',
);
});
it('should mark tool call as cancelled when abort happens during confirmation error', async () => {
const abortController = new AbortController();
const abortError = new Error('Abort requested during confirmation');
@@ -1510,16 +1767,19 @@ describe('CoreToolScheduler request queueing', () => {
await scheduler.schedule(requests, abortController.signal);
// Wait for all tools to be awaiting approval
// Wait for the FIRST tool to be awaiting approval
await vi.waitFor(() => {
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
// With the sequential scheduler, the update includes the active call and the queue.
expect(calls?.length).toBe(3);
expect(calls?.every((call) => call.status === 'awaiting_approval')).toBe(
true,
);
expect(calls?.[0].status).toBe('awaiting_approval');
expect(calls?.[0].request.callId).toBe('1');
// Check that the other two are in the queue (still in 'validating' state)
expect(calls?.[1].status).toBe('validating');
expect(calls?.[2].status).toBe('validating');
});
expect(pendingConfirmations.length).toBe(3);
expect(pendingConfirmations.length).toBe(1);
// Approve the first tool with ProceedAlways
const firstConfirmation = pendingConfirmations[0];
@@ -1528,15 +1788,16 @@ describe('CoreToolScheduler request queueing', () => {
// Wait for all tools to be completed
await vi.waitFor(() => {
expect(onAllToolCallsComplete).toHaveBeenCalled();
const completedCalls = onAllToolCallsComplete.mock.calls.at(
-1,
)?.[0] as ToolCall[];
expect(completedCalls?.length).toBe(3);
expect(completedCalls?.every((call) => call.status === 'success')).toBe(
true,
);
});
const completedCalls = onAllToolCallsComplete.mock.calls.at(
-1,
)?.[0] as ToolCall[];
expect(completedCalls?.length).toBe(3);
expect(completedCalls?.every((call) => call.status === 'success')).toBe(
true,
);
// Verify approval mode was changed
expect(approvalMode).toBe(ApprovalMode.AUTO_EDIT);
});
@@ -1788,11 +2049,10 @@ describe('CoreToolScheduler Sequential Execution', () => {
expect(onAllToolCallsComplete).toHaveBeenCalled();
});
// Check that execute was called for all three tools initially
expect(executeFn).toHaveBeenCalledTimes(3);
// Check that execute was called for the first two tools only
expect(executeFn).toHaveBeenCalledTimes(2);
expect(executeFn).toHaveBeenCalledWith({ call: 1 });
expect(executeFn).toHaveBeenCalledWith({ call: 2 });
expect(executeFn).toHaveBeenCalledWith({ call: 3 });
const completedCalls = onAllToolCallsComplete.mock
.calls[0][0] as ToolCall[];
+234 -114
View File
@@ -348,12 +348,15 @@ export class CoreToolScheduler {
private onEditorClose: () => void;
private isFinalizingToolCalls = false;
private isScheduling = false;
private isCancelling = false;
private requestQueue: Array<{
request: ToolCallRequestInfo | ToolCallRequestInfo[];
signal: AbortSignal;
resolve: () => void;
reject: (reason?: Error) => void;
}> = [];
private toolCallQueue: ToolCall[] = [];
private completedToolCallsForBatch: CompletedToolCall[] = [];
constructor(options: CoreToolSchedulerOptions) {
this.config = options.config;
@@ -398,30 +401,36 @@ export class CoreToolScheduler {
private setStatusInternal(
targetCallId: string,
status: 'success',
signal: AbortSignal,
response: ToolCallResponseInfo,
): void;
private setStatusInternal(
targetCallId: string,
status: 'awaiting_approval',
signal: AbortSignal,
confirmationDetails: ToolCallConfirmationDetails,
): void;
private setStatusInternal(
targetCallId: string,
status: 'error',
signal: AbortSignal,
response: ToolCallResponseInfo,
): void;
private setStatusInternal(
targetCallId: string,
status: 'cancelled',
signal: AbortSignal,
reason: string,
): void;
private setStatusInternal(
targetCallId: string,
status: 'executing' | 'scheduled' | 'validating',
signal: AbortSignal,
): void;
private setStatusInternal(
targetCallId: string,
newStatus: Status,
signal: AbortSignal,
auxiliaryData?: unknown,
): void {
this.toolCalls = this.toolCalls.map((currentCall) => {
@@ -561,7 +570,6 @@ export class CoreToolScheduler {
}
});
this.notifyToolCallsUpdate();
this.checkAndNotifyCompletion();
}
private setArgsInternal(targetCallId: string, args: unknown): void {
@@ -692,11 +700,43 @@ export class CoreToolScheduler {
return this._schedule(request, signal);
}
cancelAll(signal: AbortSignal): void {
if (this.isCancelling) {
return;
}
this.isCancelling = true;
// Cancel the currently active tool call, if there is one.
if (this.toolCalls.length > 0) {
const activeCall = this.toolCalls[0];
// Only cancel if it's in a cancellable state.
if (
activeCall.status === 'awaiting_approval' ||
activeCall.status === 'executing' ||
activeCall.status === 'scheduled' ||
activeCall.status === 'validating'
) {
this.setStatusInternal(
activeCall.request.callId,
'cancelled',
signal,
'User cancelled the operation.',
);
}
}
// Clear the queue and mark all queued items as cancelled for completion reporting.
this._cancelAllQueuedCalls();
// Finalize the batch immediately.
void this.checkAndNotifyCompletion(signal);
}
private async _schedule(
request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal,
): Promise<void> {
this.isScheduling = true;
this.isCancelling = false;
try {
if (this.isRunning()) {
throw new Error(
@@ -704,6 +744,7 @@ export class CoreToolScheduler {
);
}
const requestsToProcess = Array.isArray(request) ? request : [request];
this.completedToolCallsForBatch = [];
const newToolCalls: ToolCall[] = requestsToProcess.map(
(reqInfo): ToolCall => {
@@ -753,45 +794,74 @@ export class CoreToolScheduler {
},
);
this.toolCalls = this.toolCalls.concat(newToolCalls);
this.notifyToolCallsUpdate();
this.toolCallQueue.push(...newToolCalls);
await this._processNextInQueue(signal);
} finally {
this.isScheduling = false;
}
}
for (const toolCall of newToolCalls) {
if (toolCall.status !== 'validating') {
continue;
private async _processNextInQueue(signal: AbortSignal): Promise<void> {
// If there's already a tool being processed, or the queue is empty, stop.
if (this.toolCalls.length > 0 || this.toolCallQueue.length === 0) {
return;
}
// If cancellation happened between steps, handle it.
if (signal.aborted) {
this._cancelAllQueuedCalls();
// Finalize the batch.
await this.checkAndNotifyCompletion(signal);
return;
}
const toolCall = this.toolCallQueue.shift()!;
// This is now the single active tool call.
this.toolCalls = [toolCall];
this.notifyToolCallsUpdate();
// Handle tools that were already errored during creation.
if (toolCall.status === 'error') {
// An error during validation means this "active" tool is already complete.
// We need to check for batch completion to either finish or process the next in queue.
await this.checkAndNotifyCompletion(signal);
return;
}
// This logic is moved from the old `for` loop in `_schedule`.
if (toolCall.status === 'validating') {
const { request: reqInfo, invocation } = toolCall;
try {
if (signal.aborted) {
this.setStatusInternal(
reqInfo.callId,
'cancelled',
signal,
'Tool call cancelled by user.',
);
// The completion check will handle the cascade.
await this.checkAndNotifyCompletion(signal);
return;
}
const validatingCall = toolCall as ValidatingToolCall;
const { request: reqInfo, invocation } = validatingCall;
const confirmationDetails =
await invocation.shouldConfirmExecute(signal);
try {
if (signal.aborted) {
this.setStatusInternal(
reqInfo.callId,
'cancelled',
'Tool call cancelled by user.',
);
continue;
}
const confirmationDetails =
await invocation.shouldConfirmExecute(signal);
if (!confirmationDetails) {
if (!confirmationDetails) {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled', signal);
} else {
if (this.isAutoApproved(toolCall)) {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
continue;
}
if (this.isAutoApproved(validatingCall)) {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
this.setStatusInternal(reqInfo.callId, 'scheduled', signal);
} else {
// Allow IDE to resolve confirmation
if (
@@ -835,35 +905,36 @@ export class CoreToolScheduler {
this.setStatusInternal(
reqInfo.callId,
'awaiting_approval',
signal,
wrappedConfirmationDetails,
);
}
} catch (error) {
if (signal.aborted) {
this.setStatusInternal(
reqInfo.callId,
'cancelled',
'Tool call cancelled by user.',
);
continue;
}
}
} catch (error) {
if (signal.aborted) {
this.setStatusInternal(
reqInfo.callId,
'cancelled',
signal,
'Tool call cancelled by user.',
);
await this.checkAndNotifyCompletion(signal);
} else {
this.setStatusInternal(
reqInfo.callId,
'error',
signal,
createErrorResponse(
reqInfo,
error instanceof Error ? error : new Error(String(error)),
ToolErrorType.UNHANDLED_EXCEPTION,
),
);
await this.checkAndNotifyCompletion(signal);
}
}
await this.attemptExecutionOfScheduledCalls(signal);
void this.checkAndNotifyCompletion();
} finally {
this.isScheduling = false;
}
await this.attemptExecutionOfScheduledCalls(signal);
}
async handleConfirmationResponse(
@@ -881,18 +952,12 @@ export class CoreToolScheduler {
await originalOnConfirm(outcome);
}
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
await this.autoApproveCompatiblePendingTools(signal, callId);
}
this.setToolCallOutcome(callId, outcome);
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
this.setStatusInternal(
callId,
'cancelled',
'User did not allow tool call',
);
// Instead of just cancelling one tool, trigger the full cancel cascade.
this.cancelAll(signal);
return; // `cancelAll` calls `checkAndNotifyCompletion`, so we can exit here.
} else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
const waitingToolCall = toolCall as WaitingToolCall;
if (isModifiableDeclarativeTool(waitingToolCall.tool)) {
@@ -902,7 +967,7 @@ export class CoreToolScheduler {
return;
}
this.setStatusInternal(callId, 'awaiting_approval', {
this.setStatusInternal(callId, 'awaiting_approval', signal, {
...waitingToolCall.confirmationDetails,
isModifying: true,
} as ToolCallConfirmationDetails);
@@ -917,7 +982,7 @@ export class CoreToolScheduler {
this.onEditorClose,
);
this.setArgsInternal(callId, updatedParams);
this.setStatusInternal(callId, 'awaiting_approval', {
this.setStatusInternal(callId, 'awaiting_approval', signal, {
...waitingToolCall.confirmationDetails,
fileDiff: updatedDiff,
isModifying: false,
@@ -932,7 +997,7 @@ export class CoreToolScheduler {
signal,
);
}
this.setStatusInternal(callId, 'scheduled');
this.setStatusInternal(callId, 'scheduled', signal);
}
await this.attemptExecutionOfScheduledCalls(signal);
}
@@ -974,10 +1039,15 @@ export class CoreToolScheduler {
);
this.setArgsInternal(toolCall.request.callId, updatedParams);
this.setStatusInternal(toolCall.request.callId, 'awaiting_approval', {
...toolCall.confirmationDetails,
fileDiff: updatedDiff,
});
this.setStatusInternal(
toolCall.request.callId,
'awaiting_approval',
signal,
{
...toolCall.confirmationDetails,
fileDiff: updatedDiff,
},
);
}
private async attemptExecutionOfScheduledCalls(
@@ -1002,7 +1072,7 @@ export class CoreToolScheduler {
const scheduledCall = toolCall;
const { callId, name: toolName } = scheduledCall.request;
const invocation = scheduledCall.invocation;
this.setStatusInternal(callId, 'executing');
this.setStatusInternal(callId, 'executing', signal);
const liveOutputCallback =
scheduledCall.tool.canUpdateOutput && this.outputUpdateHandler
@@ -1055,12 +1125,10 @@ export class CoreToolScheduler {
this.setStatusInternal(
callId,
'cancelled',
signal,
'User cancelled tool execution.',
);
continue;
}
if (toolResult.error === undefined) {
} else if (toolResult.error === undefined) {
let content = toolResult.llmContent;
let outputFile: string | undefined = undefined;
const contentLength =
@@ -1116,7 +1184,7 @@ export class CoreToolScheduler {
outputFile,
contentLength,
};
this.setStatusInternal(callId, 'success', successResponse);
this.setStatusInternal(callId, 'success', signal, successResponse);
} else {
// It is a failure
const error = new Error(toolResult.error.message);
@@ -1125,19 +1193,21 @@ export class CoreToolScheduler {
error,
toolResult.error.type,
);
this.setStatusInternal(callId, 'error', errorResponse);
this.setStatusInternal(callId, 'error', signal, errorResponse);
}
} catch (executionError: unknown) {
if (signal.aborted) {
this.setStatusInternal(
callId,
'cancelled',
signal,
'User cancelled tool execution.',
);
} else {
this.setStatusInternal(
callId,
'error',
signal,
createErrorResponse(
scheduledCall.request,
executionError instanceof Error
@@ -1148,45 +1218,126 @@ export class CoreToolScheduler {
);
}
}
await this.checkAndNotifyCompletion(signal);
}
}
}
private async checkAndNotifyCompletion(): Promise<void> {
const allCallsAreTerminal = this.toolCalls.every(
(call) =>
call.status === 'success' ||
call.status === 'error' ||
call.status === 'cancelled',
);
private async checkAndNotifyCompletion(signal: AbortSignal): Promise<void> {
// This method is now only concerned with the single active tool call.
if (this.toolCalls.length === 0) {
// It's possible to be called when a batch is cancelled before any tool has started.
if (signal.aborted && this.toolCallQueue.length > 0) {
this._cancelAllQueuedCalls();
}
} else {
const activeCall = this.toolCalls[0];
const isTerminal =
activeCall.status === 'success' ||
activeCall.status === 'error' ||
activeCall.status === 'cancelled';
if (this.toolCalls.length > 0 && allCallsAreTerminal) {
const completedCalls = [...this.toolCalls] as CompletedToolCall[];
// If the active tool is not in a terminal state (e.g., it's 'executing' or 'awaiting_approval'),
// then the scheduler is still busy or paused. We should not proceed.
if (!isTerminal) {
return;
}
// The active tool is finished. Move it to the completed batch.
const completedCall = activeCall as CompletedToolCall;
this.completedToolCallsForBatch.push(completedCall);
logToolCall(this.config, new ToolCallEvent(completedCall));
// Clear the active tool slot. This is crucial for the sequential processing.
this.toolCalls = [];
}
for (const call of completedCalls) {
logToolCall(this.config, new ToolCallEvent(call));
// Now, check if the entire batch is complete.
// The batch is complete if the queue is empty or the operation was cancelled.
if (this.toolCallQueue.length === 0 || signal.aborted) {
if (signal.aborted) {
this._cancelAllQueuedCalls();
}
// If there's nothing to report and we weren't cancelled, we can stop.
// But if we were cancelled, we must proceed to potentially start the next queued request.
if (this.completedToolCallsForBatch.length === 0 && !signal.aborted) {
return;
}
if (this.onAllToolCallsComplete) {
this.isFinalizingToolCalls = true;
await this.onAllToolCallsComplete(completedCalls);
// Use the batch array, not the (now empty) active array.
await this.onAllToolCallsComplete(this.completedToolCallsForBatch);
this.completedToolCallsForBatch = []; // Clear after reporting.
this.isFinalizingToolCalls = false;
}
this.isCancelling = false;
this.notifyToolCallsUpdate();
// After completion, process the next item in the queue.
// After completion of the entire batch, process the next item in the main request queue.
if (this.requestQueue.length > 0) {
const next = this.requestQueue.shift()!;
this._schedule(next.request, next.signal)
.then(next.resolve)
.catch(next.reject);
}
} else {
// The batch is not yet complete, so continue processing the current batch sequence.
await this._processNextInQueue(signal);
}
}
private _cancelAllQueuedCalls(): void {
while (this.toolCallQueue.length > 0) {
const queuedCall = this.toolCallQueue.shift()!;
// Don't cancel tools that already errored during validation.
if (queuedCall.status === 'error') {
this.completedToolCallsForBatch.push(queuedCall);
continue;
}
const durationMs =
'startTime' in queuedCall && queuedCall.startTime
? Date.now() - queuedCall.startTime
: undefined;
const errorMessage =
'[Operation Cancelled] User cancelled the operation.';
this.completedToolCallsForBatch.push({
request: queuedCall.request,
tool: queuedCall.tool,
invocation: queuedCall.invocation,
status: 'cancelled',
response: {
callId: queuedCall.request.callId,
responseParts: [
{
functionResponse: {
id: queuedCall.request.callId,
name: queuedCall.request.name,
response: {
error: errorMessage,
},
},
},
],
resultDisplay: undefined,
error: undefined,
errorType: undefined,
contentLength: errorMessage.length,
},
durationMs,
outcome: ToolConfirmationOutcome.Cancel,
});
}
}
private notifyToolCallsUpdate(): void {
if (this.onToolCallsUpdate) {
this.onToolCallsUpdate([...this.toolCalls]);
this.onToolCallsUpdate([
...this.completedToolCallsForBatch,
...this.toolCalls,
...this.toolCallQueue,
]);
}
}
@@ -1215,35 +1366,4 @@ export class CoreToolScheduler {
return doesToolInvocationMatch(tool, invocation, allowedTools);
}
private async autoApproveCompatiblePendingTools(
signal: AbortSignal,
triggeringCallId: string,
): Promise<void> {
const pendingTools = this.toolCalls.filter(
(call) =>
call.status === 'awaiting_approval' &&
call.request.callId !== triggeringCallId,
) as WaitingToolCall[];
for (const pendingTool of pendingTools) {
try {
const stillNeedsConfirmation =
await pendingTool.invocation.shouldConfirmExecute(signal);
if (!stillNeedsConfirmation) {
this.setToolCallOutcome(
pendingTool.request.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(pendingTool.request.callId, 'scheduled');
}
} catch (error) {
console.error(
`Error checking confirmation for tool ${pendingTool.request.callId}:`,
error,
);
}
}
}
}