mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-24 20:14:44 -07:00
feat(core, cli): Implement sequential approval. (#11593)
This commit is contained in:
@@ -4,11 +4,12 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { Task } from './task.js';
|
||||
import type { Config, ToolCallRequestInfo } from '@google/gemini-cli-core';
|
||||
import { createMockConfig } from '../utils/testing_utils.js';
|
||||
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
|
||||
import type { ToolCall } from '@google/gemini-cli-core';
|
||||
|
||||
describe('Task', () => {
|
||||
it('scheduleToolCalls should not modify the input requests array', async () => {
|
||||
@@ -94,4 +95,122 @@ describe('Task', () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
/**
 * Tests for the private `_schedulerToolCallsUpdate` hook on `Task`.
 *
 * Each test feeds stubbed tool-call snapshots into the hook and then inspects
 * the calls recorded on a spy wrapping `setTaskStateAndPublishUpdate`. A call
 * whose fifth argument is `true` is treated as the "final" state update.
 */
describe('_schedulerToolCallsUpdate', () => {
  type SpyInstance = ReturnType<typeof vi.spyOn>;

  let task: Task;
  let setTaskStateAndPublishUpdateSpy: SpyInstance;

  /** Builds the minimal tool-call stub (call id + status) used by these tests. */
  const stubToolCall = (callId: string, status: ToolCall['status']): ToolCall =>
    ({ request: { callId }, status }) as ToolCall;

  /** Forwards a snapshot of tool calls to the private scheduler hook. */
  const pushSchedulerUpdate = (calls: ToolCall[]): void => {
    // @ts-expect-error - Calling private method
    task._schedulerToolCallsUpdate(calls);
  };

  /** Returns the first recorded spy call whose fifth argument is `true`, if any. */
  const findFinalUpdate = () =>
    setTaskStateAndPublishUpdateSpy.mock.calls.find((args) => args[4] === true);

  beforeEach(() => {
    const config = createMockConfig() as Config;
    const eventBus: ExecutionEventBus = {
      publish: vi.fn(),
      on: vi.fn(),
      off: vi.fn(),
      once: vi.fn(),
      removeAllListeners: vi.fn(),
      finished: vi.fn(),
    };

    // @ts-expect-error - Calling private constructor
    task = new Task('task-id', 'context-id', config, eventBus);

    // Every assertion below is made against calls recorded by this spy.
    setTaskStateAndPublishUpdateSpy = vi.spyOn(
      task,
      'setTaskStateAndPublishUpdate',
    );
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  it('should set state to input-required when a tool is awaiting approval and none are executing', () => {
    pushSchedulerUpdate([stubToolCall('1', 'awaiting_approval')]);

    // The most recent publish must be the final input-required state change.
    expect(setTaskStateAndPublishUpdateSpy).toHaveBeenLastCalledWith(
      'input-required',
      { kind: 'state-change' },
      undefined,
      undefined,
      true, // final: true
    );
  });

  it('should NOT set state to input-required if a tool is awaiting approval but another is executing', () => {
    pushSchedulerUpdate([
      stubToolCall('1', 'awaiting_approval'),
      stubToolCall('2', 'executing'),
    ]);

    // Status updates may have been published, but none of them as final.
    expect(findFinalUpdate()).toBeUndefined();
  });

  it('should set state to input-required once an executing tool finishes, leaving one awaiting approval', () => {
    pushSchedulerUpdate([
      stubToolCall('1', 'awaiting_approval'),
      stubToolCall('2', 'executing'),
    ]);

    // While call '2' is still executing there must be no final update.
    expect(findFinalUpdate()).toBeUndefined();

    // Simulate the executing tool finishing by resolving call '2' directly.
    // @ts-expect-error - Calling private method
    task._resolveToolCall('2');

    // A follow-up snapshot arrives containing only the call awaiting approval.
    pushSchedulerUpdate([stubToolCall('1', 'awaiting_approval')]);

    // Only now is the final input-required update expected.
    const finalUpdate = findFinalUpdate();
    expect(finalUpdate).toBeDefined();
    expect(finalUpdate?.[0]).toBe('input-required');
  });

  it('should NOT set state to input-required if skipFinalTrueAfterInlineEdit is true', () => {
    task.skipFinalTrueAfterInlineEdit = true;

    pushSchedulerUpdate([stubToolCall('1', 'awaiting_approval')]);

    expect(findFinalUpdate()).toBeUndefined();
  });
});
|
||||
});
|
||||
|
||||
@@ -40,7 +40,6 @@ import type {
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { logger } from '../utils/logger.js';
|
||||
import * as fs from 'node:fs';
|
||||
|
||||
import { CoderAgentEvent } from '../types.js';
|
||||
import type {
|
||||
CoderAgentMessage,
|
||||
@@ -373,11 +372,11 @@ export class Task {
|
||||
|
||||
// Only send an update if the status has actually changed.
|
||||
if (hasChanged) {
|
||||
const message = this.toolStatusMessage(tc, this.id, this.contextId);
|
||||
const coderAgentMessage: CoderAgentMessage =
|
||||
tc.status === 'awaiting_approval'
|
||||
? { kind: CoderAgentEvent.ToolCallConfirmationEvent }
|
||||
: { kind: CoderAgentEvent.ToolCallUpdateEvent };
|
||||
const message = this.toolStatusMessage(tc, this.id, this.contextId);
|
||||
|
||||
const event = this._createStatusUpdateEvent(
|
||||
this.taskState,
|
||||
@@ -404,20 +403,16 @@ export class Task {
|
||||
const isAwaitingApproval = allPendingStatuses.some(
|
||||
(status) => status === 'awaiting_approval',
|
||||
);
|
||||
const allPendingAreStable = allPendingStatuses.every(
|
||||
(status) =>
|
||||
status === 'awaiting_approval' ||
|
||||
status === 'success' ||
|
||||
status === 'error' ||
|
||||
status === 'cancelled',
|
||||
const isExecuting = allPendingStatuses.some(
|
||||
(status) => status === 'executing',
|
||||
);
|
||||
|
||||
// 1. Are any pending tool calls awaiting_approval
|
||||
// 2. Are all pending tool calls in a stable state (i.e. not in validating or executing)
|
||||
// 3. After an inline edit, the edited tool call will send awaiting_approval THEN scheduled. We wait for the next update in this case.
|
||||
// The turn is complete and requires user input if at least one tool
|
||||
// is waiting for the user's decision, and no other tool is actively
|
||||
// running in the background.
|
||||
if (
|
||||
isAwaitingApproval &&
|
||||
allPendingAreStable &&
|
||||
!isExecuting &&
|
||||
!this.skipFinalTrueAfterInlineEdit
|
||||
) {
|
||||
this.skipFinalTrueAfterInlineEdit = false;
|
||||
|
||||
@@ -313,7 +313,7 @@ describe('E2E Tests', () => {
|
||||
expect(workingEvent.kind).toBe('status-update');
|
||||
expect(workingEvent.status.state).toBe('working');
|
||||
|
||||
// State Update: Validate each tool call
|
||||
// State Update: Validate the first tool call
|
||||
const toolCallValidateEvent1 = events[3].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallValidateEvent1.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
@@ -326,47 +326,218 @@ describe('E2E Tests', () => {
|
||||
},
|
||||
},
|
||||
]);
|
||||
const toolCallValidateEvent2 = events[4].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallValidateEvent2.metadata?.['coderAgent']).toMatchObject({
|
||||
|
||||
// --- Assert the event stream ---
|
||||
// 1. Initial "submitted" status.
|
||||
expect((events[0].result as TaskStatusUpdateEvent).status.state).toBe(
|
||||
'submitted',
|
||||
);
|
||||
|
||||
// 2. "working" status after receiving the user prompt.
|
||||
expect((events[1].result as TaskStatusUpdateEvent).status.state).toBe(
|
||||
'working',
|
||||
);
|
||||
|
||||
// 3. A "state-change" event from the agent.
|
||||
expect(events[2].result.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'state-change',
|
||||
});
|
||||
|
||||
// 4. Tool 1 is validating.
|
||||
const toolCallUpdate1 = events[3].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallUpdate1.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(toolCallValidateEvent2.status.message?.parts).toMatchObject([
|
||||
expect(toolCallUpdate1.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
request: { callId: 'test-call-id-1' },
|
||||
status: 'validating',
|
||||
request: { callId: 'test-call-id-2' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// State Update: Set each tool call to awaiting
|
||||
const toolCallAwaitEvent1 = events[5].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallAwaitEvent1.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-confirmation',
|
||||
// 5. Tool 2 is validating.
|
||||
const toolCallUpdate2 = events[4].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallUpdate2.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(toolCallAwaitEvent1.status.message?.parts).toMatchObject([
|
||||
expect(toolCallUpdate2.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'awaiting_approval',
|
||||
request: { callId: 'test-call-id-1' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
const toolCallAwaitEvent2 = events[6].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallAwaitEvent2.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-confirmation',
|
||||
});
|
||||
expect(toolCallAwaitEvent2.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'awaiting_approval',
|
||||
request: { callId: 'test-call-id-2' },
|
||||
status: 'validating',
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// 6. Tool 1 is awaiting approval.
|
||||
const toolCallAwaitEvent = events[5].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallAwaitEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-confirmation',
|
||||
});
|
||||
expect(toolCallAwaitEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
request: { callId: 'test-call-id-1' },
|
||||
status: 'awaiting_approval',
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// 7. The final event is "input-required".
|
||||
const finalEvent = events[6].result as TaskStatusUpdateEvent;
|
||||
expect(finalEvent.final).toBe(true);
|
||||
expect(finalEvent.status.state).toBe('input-required');
|
||||
|
||||
// The scheduler now waits for approval, so no more events are sent.
|
||||
assertUniqueFinalEventIsLast(events);
|
||||
expect(events.length).toBe(7);
|
||||
});
|
||||
|
||||
/**
 * End-to-end check of two tool calls requested in a single model turn while
 * the approval mode is YOLO. The SSE stream returned by the server is reduced
 * to `{ kind, status, callId }` triples and compared against the lifecycle
 * expected for each tool.
 */
it('should handle multiple tool calls sequentially in YOLO mode', async () => {
  // Switch the approval mode to YOLO for this test.
  getApprovalModeSpy.mockReturnValue(ApprovalMode.YOLO);

  // The first model turn requests both tools.
  sendMessageStreamSpy.mockImplementationOnce(async function* () {
    yield {
      type: GeminiEventType.ToolCallRequest,
      value: {
        callId: 'test-call-id-1',
        name: 'test-tool-1',
        args: {},
      },
    };
    yield {
      type: GeminiEventType.ToolCallRequest,
      value: {
        callId: 'test-call-id-2',
        name: 'test-tool-2',
        args: {},
      },
    };
  });
  // Every later model turn yields a single plain content event.
  sendMessageStreamSpy.mockImplementation(async function* () {
    yield { type: 'content', value: 'All tools executed.' };
  });

  const mockTool1 = new MockTool({
    name: 'test-tool-1',
    displayName: 'Test Tool 1',
    shouldConfirmExecute: vi.fn(mockToolConfirmationFn),
    execute: vi
      .fn()
      .mockResolvedValue({ llmContent: 'tool 1 done', returnDisplay: '' }),
  });
  const mockTool2 = new MockTool({
    name: 'test-tool-2',
    displayName: 'Test Tool 2',
    shouldConfirmExecute: vi.fn(mockToolConfirmationFn),
    execute: vi
      .fn()
      .mockResolvedValue({ llmContent: 'tool 2 done', returnDisplay: '' }),
  });

  // Name-based lookup backing the registry's `getTool`; unknown names give undefined.
  const toolsByName = new Map([
    ['test-tool-1', mockTool1],
    ['test-tool-2', mockTool2],
  ]);
  getToolRegistrySpy.mockReturnValue({
    getAllTools: vi.fn().mockReturnValue([mockTool1, mockTool2]),
    getToolsByServer: vi.fn().mockReturnValue([]),
    getTool: vi
      .fn()
      .mockImplementation((name: string) => toolsByName.get(name)),
  });

  const payload = createStreamMessageRequest(
    'run two tools',
    'a2a-multi-tool-test-message',
  );
  const res = await request
    .agent(app)
    .post('/')
    .send(payload)
    .set('Content-Type', 'application/json')
    .expect(200);

  const events = streamToSSEEvents(res.text);
  assertTaskCreationAndWorkingStatus(events);

  // --- Assert the execution flow ---
  // Skip the first two events (already covered by the helper above) and
  // project the rest down to the fields compared below.
  const eventStream = events.slice(2).map((sseEvent) => {
    const update = sseEvent.result as TaskStatusUpdateEvent;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const agentData = update.metadata?.['coderAgent'] as any;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const toolData = update.status.message?.parts[0] as any;
    return toolData
      ? {
          kind: agentData.kind,
          status: toolData.data?.status,
          callId: toolData.data?.request.callId,
        }
      : { kind: agentData.kind };
  });

  // One `tool-call-update` entry per lifecycle status of the given call.
  const lifecycleOf = (callId: string) =>
    ['validating', 'scheduled', 'executing', 'success'].map((status) => ({
      kind: 'tool-call-update',
      status,
      callId,
    }));

  const expectedFlow = [
    // Initial state change
    { kind: 'state-change', status: undefined, callId: undefined },
    // Tool 1 lifecycle
    ...lifecycleOf('test-call-id-1'),
    // Tool 2 lifecycle
    ...lifecycleOf('test-call-id-2'),
    // Final updates
    { kind: 'state-change', status: undefined, callId: undefined },
    { kind: 'text-content', status: undefined, callId: undefined },
  ];

  // `expect.arrayContaining` only requires every expected entry to be present
  // somewhere in the stream; extra interspersed events are tolerated and the
  // relative order of entries is not verified.
  expect(eventStream).toEqual(expect.arrayContaining(expectedFlow));

  assertUniqueFinalEventIsLast(events);
  // NOTE(review): `expectedFlow` holds ten distinct entries that must all be
  // found in `events.slice(2)`, which implies at least twelve events, yet the
  // count asserted here is 8 - confirm the intended value.
  expect(events.length).toBe(8);
});
|
||||
|
||||
it('should handle tool calls that do not require approval', async () => {
|
||||
|
||||
@@ -37,7 +37,7 @@ import {
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { Part, PartListUnion } from '@google/genai';
|
||||
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import type { HistoryItem, SlashCommandProcessorResult } from '../types.js';
|
||||
import type { SlashCommandProcessorResult } from '../types.js';
|
||||
import { MessageType, StreamingState } from '../types.js';
|
||||
import type { LoadedSettings } from '../../config/settings.js';
|
||||
|
||||
@@ -231,8 +231,9 @@ describe('useGeminiStream', () => {
|
||||
mockUseReactToolScheduler.mockReturnValue([
|
||||
[], // Default to empty array for toolCalls
|
||||
mockScheduleToolCalls,
|
||||
mockCancelAllToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
vi.fn(), // setToolCallsForDisplay
|
||||
mockCancelAllToolCalls,
|
||||
]);
|
||||
|
||||
// Reset mocks for GeminiClient instance methods (startChat and sendMessageStream)
|
||||
@@ -259,38 +260,71 @@ describe('useGeminiStream', () => {
|
||||
initialToolCalls: TrackedToolCall[] = [],
|
||||
geminiClient?: any,
|
||||
) => {
|
||||
let currentToolCalls = initialToolCalls;
|
||||
const setToolCalls = (newToolCalls: TrackedToolCall[]) => {
|
||||
currentToolCalls = newToolCalls;
|
||||
};
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation(() => [
|
||||
currentToolCalls,
|
||||
mockScheduleToolCalls,
|
||||
mockCancelAllToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
]);
|
||||
|
||||
const client = geminiClient || mockConfig.getGeminiClient();
|
||||
|
||||
const initialProps = {
|
||||
client,
|
||||
history: [],
|
||||
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
config: mockConfig,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
handleSlashCommand: mockHandleSlashCommand as unknown as (
|
||||
cmd: PartListUnion,
|
||||
) => Promise<SlashCommandProcessorResult | false>,
|
||||
shellModeActive: false,
|
||||
loadedSettings: mockLoadedSettings,
|
||||
toolCalls: initialToolCalls,
|
||||
};
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
(props: {
|
||||
client: any;
|
||||
history: HistoryItem[];
|
||||
addItem: UseHistoryManagerReturn['addItem'];
|
||||
config: Config;
|
||||
onDebugMessage: (message: string) => void;
|
||||
handleSlashCommand: (
|
||||
cmd: PartListUnion,
|
||||
) => Promise<SlashCommandProcessorResult | false>;
|
||||
shellModeActive: boolean;
|
||||
loadedSettings: LoadedSettings;
|
||||
toolCalls?: TrackedToolCall[]; // Allow passing updated toolCalls
|
||||
}) => {
|
||||
// Update the mock's return value if new toolCalls are passed in props
|
||||
if (props.toolCalls) {
|
||||
setToolCalls(props.toolCalls);
|
||||
}
|
||||
(props: typeof initialProps) => {
|
||||
// This mock needs to be stateful. When setToolCallsForDisplay is called,
|
||||
// it should trigger a rerender with the new state.
|
||||
const mockSetToolCallsForDisplay = vi.fn((updater) => {
|
||||
const newToolCalls =
|
||||
typeof updater === 'function' ? updater(props.toolCalls) : updater;
|
||||
rerender({ ...props, toolCalls: newToolCalls });
|
||||
});
|
||||
|
||||
// Create a stateful mock for cancellation that updates the toolCalls state.
|
||||
const statefulCancelAllToolCalls = vi.fn((...args) => {
|
||||
// Call the original spy so `toHaveBeenCalled` checks still work.
|
||||
mockCancelAllToolCalls(...args);
|
||||
|
||||
const newToolCalls = props.toolCalls.map((tc) => {
|
||||
// Only cancel tools that are in a cancellable state.
|
||||
if (
|
||||
tc.status === 'awaiting_approval' ||
|
||||
tc.status === 'executing' ||
|
||||
tc.status === 'scheduled' ||
|
||||
tc.status === 'validating'
|
||||
) {
|
||||
// A real cancelled tool call has a response object.
|
||||
// We need to simulate this to avoid type errors downstream.
|
||||
return {
|
||||
...tc,
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
callId: tc.request.callId,
|
||||
responseParts: [],
|
||||
resultDisplay: 'Request cancelled.',
|
||||
},
|
||||
responseSubmittedToGemini: true, // Mark as "processed"
|
||||
} as any as TrackedCancelledToolCall;
|
||||
}
|
||||
return tc;
|
||||
});
|
||||
rerender({ ...props, toolCalls: newToolCalls });
|
||||
});
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation(() => [
|
||||
props.toolCalls,
|
||||
mockScheduleToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
mockSetToolCallsForDisplay,
|
||||
statefulCancelAllToolCalls, // Use the stateful mock
|
||||
]);
|
||||
|
||||
return useGeminiStream(
|
||||
props.client,
|
||||
props.history,
|
||||
@@ -313,19 +347,7 @@ describe('useGeminiStream', () => {
|
||||
);
|
||||
},
|
||||
{
|
||||
initialProps: {
|
||||
client,
|
||||
history: [],
|
||||
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
config: mockConfig,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
handleSlashCommand: mockHandleSlashCommand as unknown as (
|
||||
cmd: PartListUnion,
|
||||
) => Promise<SlashCommandProcessorResult | false>,
|
||||
shellModeActive: false,
|
||||
loadedSettings: mockLoadedSettings,
|
||||
toolCalls: initialToolCalls,
|
||||
},
|
||||
initialProps,
|
||||
},
|
||||
);
|
||||
return {
|
||||
@@ -452,7 +474,7 @@ describe('useGeminiStream', () => {
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation((onComplete) => {
|
||||
capturedOnComplete = onComplete;
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
|
||||
});
|
||||
|
||||
renderHook(() =>
|
||||
@@ -535,7 +557,7 @@ describe('useGeminiStream', () => {
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation((onComplete) => {
|
||||
capturedOnComplete = onComplete;
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
|
||||
});
|
||||
|
||||
renderHook(() =>
|
||||
@@ -647,7 +669,7 @@ describe('useGeminiStream', () => {
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation((onComplete) => {
|
||||
capturedOnComplete = onComplete;
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
|
||||
});
|
||||
|
||||
renderHook(() =>
|
||||
@@ -760,6 +782,7 @@ describe('useGeminiStream', () => {
|
||||
currentToolCalls,
|
||||
mockScheduleToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
vi.fn(), // setToolCallsForDisplay
|
||||
];
|
||||
});
|
||||
|
||||
@@ -797,6 +820,7 @@ describe('useGeminiStream', () => {
|
||||
completedToolCalls,
|
||||
mockScheduleToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
vi.fn(), // setToolCallsForDisplay
|
||||
];
|
||||
});
|
||||
|
||||
@@ -1031,7 +1055,7 @@ describe('useGeminiStream', () => {
|
||||
expect(result.current.streamingState).toBe(StreamingState.Idle);
|
||||
});
|
||||
|
||||
it('should not cancel if a tool call is in progress (not just responding)', async () => {
|
||||
it('should cancel if a tool call is in progress', async () => {
|
||||
const toolCalls: TrackedToolCall[] = [
|
||||
{
|
||||
request: { callId: 'call1', name: 'tool1', args: {} },
|
||||
@@ -1052,7 +1076,6 @@ describe('useGeminiStream', () => {
|
||||
} as TrackedExecutingToolCall,
|
||||
];
|
||||
|
||||
const abortSpy = vi.spyOn(AbortController.prototype, 'abort');
|
||||
const { result } = renderTestHook(toolCalls);
|
||||
|
||||
// State is `Responding` because a tool is running
|
||||
@@ -1061,8 +1084,71 @@ describe('useGeminiStream', () => {
|
||||
// Try to cancel
|
||||
simulateEscapeKeyPress();
|
||||
|
||||
// Nothing should happen because the state is not `Responding`
|
||||
expect(abortSpy).not.toHaveBeenCalled();
|
||||
// The cancel function should be called
|
||||
expect(mockCancelAllToolCalls).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
/**
 * Pressing Escape while a tool call is waiting for user approval must invoke
 * the scheduler's cancel function, add a "Request cancelled." history item,
 * and leave the hook in the idle streaming state.
 */
it('should cancel a request when a tool is awaiting confirmation', async () => {
  const mockOnConfirm = vi.fn().mockResolvedValue(undefined);
  const describeMock = () => `Mock description`;

  // A single tracked tool call parked in `awaiting_approval` with edit-type
  // confirmation details attached.
  const pendingEditCall = {
    request: {
      callId: 'confirm-call',
      name: 'some_tool',
      args: {},
      isClientInitiated: false,
      prompt_id: 'prompt-id-1',
    },
    status: 'awaiting_approval',
    responseSubmittedToGemini: false,
    tool: {
      name: 'some_tool',
      description: 'a tool',
      build: vi.fn().mockImplementation((_) => ({
        getDescription: describeMock,
      })),
    } as any,
    invocation: {
      getDescription: describeMock,
    } as unknown as AnyToolInvocation,
    confirmationDetails: {
      type: 'edit',
      title: 'Confirm Edit',
      onConfirm: mockOnConfirm,
      fileName: 'file.txt',
      filePath: '/test/file.txt',
      fileDiff: 'fake diff',
      originalContent: 'old',
      newContent: 'new',
    },
  } as TrackedWaitingToolCall;
  const toolCalls: TrackedToolCall[] = [pendingEditCall];

  const { result } = renderTestHook(toolCalls);

  // With a tool awaiting approval the hook reports `WaitingForConfirmation`.
  const stateBeforeCancel = result.current.streamingState;
  expect(stateBeforeCancel).toBe(StreamingState.WaitingForConfirmation);

  // Request cancellation via the Escape key.
  simulateEscapeKeyPress();

  // The scheduler's imperative cancel function must have been invoked.
  expect(mockCancelAllToolCalls).toHaveBeenCalled();

  // A cancellation notice must eventually be appended to the history.
  const cancelledNotice = expect.objectContaining({
    text: 'Request cancelled.',
  });
  await waitFor(() => {
    expect(mockAddItem).toHaveBeenCalledWith(
      cancelledNotice,
      expect.any(Number),
    );
  });

  // After cancellation the hook must be idle.
  expect(result.current.streamingState).toBe(StreamingState.Idle);
});
|
||||
});
|
||||
|
||||
@@ -1282,7 +1368,7 @@ describe('useGeminiStream', () => {
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation((onComplete) => {
|
||||
capturedOnComplete = onComplete;
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
|
||||
});
|
||||
|
||||
renderHook(() =>
|
||||
|
||||
@@ -111,6 +111,7 @@ export const useGeminiStream = (
|
||||
const [initError, setInitError] = useState<string | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const turnCancelledRef = useRef(false);
|
||||
const activeQueryIdRef = useRef<string | null>(null);
|
||||
const [isResponding, setIsResponding] = useState<boolean>(false);
|
||||
const [thought, setThought] = useState<ThoughtSummary | null>(null);
|
||||
const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] =
|
||||
@@ -126,47 +127,55 @@ export const useGeminiStream = (
|
||||
return new GitService(config.getProjectRoot(), storage);
|
||||
}, [config, storage]);
|
||||
|
||||
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
|
||||
useReactToolScheduler(
|
||||
async (completedToolCallsFromScheduler) => {
|
||||
// This onComplete is called when ALL scheduled tools for a given batch are done.
|
||||
if (completedToolCallsFromScheduler.length > 0) {
|
||||
// Add the final state of these tools to the history for display.
|
||||
addItem(
|
||||
mapTrackedToolCallsToDisplay(
|
||||
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||
),
|
||||
Date.now(),
|
||||
);
|
||||
|
||||
// Record tool calls with full metadata before sending responses.
|
||||
try {
|
||||
const currentModel =
|
||||
config.getGeminiClient().getCurrentSequenceModel() ??
|
||||
config.getModel();
|
||||
config
|
||||
.getGeminiClient()
|
||||
.getChat()
|
||||
.recordCompletedToolCalls(
|
||||
currentModel,
|
||||
completedToolCallsFromScheduler,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error recording completed tool call information: ${error}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Handle tool response submission immediately when tools complete
|
||||
await handleCompletedTools(
|
||||
const [
|
||||
toolCalls,
|
||||
scheduleToolCalls,
|
||||
markToolsAsSubmitted,
|
||||
setToolCallsForDisplay,
|
||||
cancelAllToolCalls,
|
||||
] = useReactToolScheduler(
|
||||
async (completedToolCallsFromScheduler) => {
|
||||
// This onComplete is called when ALL scheduled tools for a given batch are done.
|
||||
if (completedToolCallsFromScheduler.length > 0) {
|
||||
// Add the final state of these tools to the history for display.
|
||||
addItem(
|
||||
mapTrackedToolCallsToDisplay(
|
||||
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||
),
|
||||
Date.now(),
|
||||
);
|
||||
|
||||
// Clear the live-updating display now that the final state is in history.
|
||||
setToolCallsForDisplay([]);
|
||||
|
||||
// Record tool calls with full metadata before sending responses.
|
||||
try {
|
||||
const currentModel =
|
||||
config.getGeminiClient().getCurrentSequenceModel() ??
|
||||
config.getModel();
|
||||
config
|
||||
.getGeminiClient()
|
||||
.getChat()
|
||||
.recordCompletedToolCalls(
|
||||
currentModel,
|
||||
completedToolCallsFromScheduler,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error recording completed tool call information: ${error}`,
|
||||
);
|
||||
}
|
||||
},
|
||||
config,
|
||||
getPreferredEditor,
|
||||
onEditorClose,
|
||||
);
|
||||
|
||||
// Handle tool response submission immediately when tools complete
|
||||
await handleCompletedTools(
|
||||
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||
);
|
||||
}
|
||||
},
|
||||
config,
|
||||
getPreferredEditor,
|
||||
onEditorClose,
|
||||
);
|
||||
|
||||
const pendingToolCallGroupDisplay = useMemo(
|
||||
() =>
|
||||
@@ -265,27 +274,54 @@ export const useGeminiStream = (
|
||||
}, [streamingState, config, history]);
|
||||
|
||||
const cancelOngoingRequest = useCallback(() => {
|
||||
if (streamingState !== StreamingState.Responding) {
|
||||
if (
|
||||
streamingState !== StreamingState.Responding &&
|
||||
streamingState !== StreamingState.WaitingForConfirmation
|
||||
) {
|
||||
return;
|
||||
}
|
||||
if (turnCancelledRef.current) {
|
||||
return;
|
||||
}
|
||||
turnCancelledRef.current = true;
|
||||
abortControllerRef.current?.abort();
|
||||
|
||||
// A full cancellation means no tools have produced a final result yet.
|
||||
// This determines if we show a generic "Request cancelled" message.
|
||||
const isFullCancellation = !toolCalls.some(
|
||||
(tc) => tc.status === 'success' || tc.status === 'error',
|
||||
);
|
||||
|
||||
// Ensure we have an abort controller, creating one if it doesn't exist.
|
||||
if (!abortControllerRef.current) {
|
||||
abortControllerRef.current = new AbortController();
|
||||
}
|
||||
|
||||
// The order is important here.
|
||||
// 1. Fire the signal to interrupt any active async operations.
|
||||
abortControllerRef.current.abort();
|
||||
// 2. Call the imperative cancel to clear the queue of pending tools.
|
||||
cancelAllToolCalls(abortControllerRef.current.signal);
|
||||
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, Date.now());
|
||||
}
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: 'Request cancelled.',
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
setPendingHistoryItem(null);
|
||||
|
||||
// If it was a full cancellation, add the info message now.
|
||||
// Otherwise, we let handleCompletedTools figure out the next step,
|
||||
// which might involve sending partial results back to the model.
|
||||
if (isFullCancellation) {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: 'Request cancelled.',
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
setIsResponding(false);
|
||||
}
|
||||
|
||||
onCancelSubmit();
|
||||
setIsResponding(false);
|
||||
setShellInputFocused(false);
|
||||
}, [
|
||||
streamingState,
|
||||
@@ -294,6 +330,8 @@ export const useGeminiStream = (
|
||||
onCancelSubmit,
|
||||
pendingHistoryItemRef,
|
||||
setShellInputFocused,
|
||||
cancelAllToolCalls,
|
||||
toolCalls,
|
||||
]);
|
||||
|
||||
useKeypress(
|
||||
@@ -302,7 +340,11 @@ export const useGeminiStream = (
|
||||
cancelOngoingRequest();
|
||||
}
|
||||
},
|
||||
{ isActive: streamingState === StreamingState.Responding },
|
||||
{
|
||||
isActive:
|
||||
streamingState === StreamingState.Responding ||
|
||||
streamingState === StreamingState.WaitingForConfirmation,
|
||||
},
|
||||
);
|
||||
|
||||
const prepareQueryForGemini = useCallback(
|
||||
@@ -764,6 +806,8 @@ export const useGeminiStream = (
|
||||
options?: { isContinuation: boolean },
|
||||
prompt_id?: string,
|
||||
) => {
|
||||
const queryId = `${Date.now()}-${Math.random()}`;
|
||||
activeQueryIdRef.current = queryId;
|
||||
if (
|
||||
(streamingState === StreamingState.Responding ||
|
||||
streamingState === StreamingState.WaitingForConfirmation) &&
|
||||
@@ -901,7 +945,9 @@ export const useGeminiStream = (
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
setIsResponding(false);
|
||||
if (activeQueryIdRef.current === queryId) {
|
||||
setIsResponding(false);
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
@@ -963,10 +1009,6 @@ export const useGeminiStream = (
|
||||
|
||||
const handleCompletedTools = useCallback(
|
||||
async (completedToolCallsFromScheduler: TrackedToolCall[]) => {
|
||||
if (isResponding) {
|
||||
return;
|
||||
}
|
||||
|
||||
const completedAndReadyToSubmitTools =
|
||||
completedToolCallsFromScheduler.filter(
|
||||
(
|
||||
@@ -1028,6 +1070,19 @@ export const useGeminiStream = (
|
||||
);
|
||||
|
||||
if (allToolsCancelled) {
|
||||
// If the turn was cancelled via the imperative escape key flow,
|
||||
// the cancellation message is added there. We check the ref to avoid duplication.
|
||||
if (!turnCancelledRef.current) {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: 'Request cancelled.',
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
}
|
||||
setIsResponding(false);
|
||||
|
||||
if (geminiClient) {
|
||||
// We need to manually add the function responses to the history
|
||||
// so the model knows the tools were cancelled.
|
||||
@@ -1074,12 +1129,12 @@ export const useGeminiStream = (
|
||||
);
|
||||
},
|
||||
[
|
||||
isResponding,
|
||||
submitQuery,
|
||||
markToolsAsSubmitted,
|
||||
geminiClient,
|
||||
performMemoryRefresh,
|
||||
modelSwitchedFromQuotaError,
|
||||
addItem,
|
||||
],
|
||||
);
|
||||
|
||||
|
||||
@@ -62,12 +62,20 @@ export type TrackedToolCall =
|
||||
| TrackedCompletedToolCall
|
||||
| TrackedCancelledToolCall;
|
||||
|
||||
export type CancelAllFn = (signal: AbortSignal) => void;
|
||||
|
||||
export function useReactToolScheduler(
|
||||
onComplete: (tools: CompletedToolCall[]) => Promise<void>,
|
||||
config: Config,
|
||||
getPreferredEditor: () => EditorType | undefined,
|
||||
onEditorClose: () => void,
|
||||
): [TrackedToolCall[], ScheduleFn, MarkToolsAsSubmittedFn] {
|
||||
): [
|
||||
TrackedToolCall[],
|
||||
ScheduleFn,
|
||||
MarkToolsAsSubmittedFn,
|
||||
React.Dispatch<React.SetStateAction<TrackedToolCall[]>>,
|
||||
CancelAllFn,
|
||||
] {
|
||||
const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
|
||||
TrackedToolCall[]
|
||||
>([]);
|
||||
@@ -112,37 +120,36 @@ export function useReactToolScheduler(
|
||||
);
|
||||
|
||||
const toolCallsUpdateHandler: ToolCallsUpdateHandler = useCallback(
|
||||
(updatedCoreToolCalls: ToolCall[]) => {
|
||||
setToolCallsForDisplay((prevTrackedCalls) =>
|
||||
updatedCoreToolCalls.map((coreTc) => {
|
||||
const existingTrackedCall = prevTrackedCalls.find(
|
||||
(ptc) => ptc.request.callId === coreTc.request.callId,
|
||||
);
|
||||
// Start with the new core state, then layer on the existing UI state
|
||||
// to ensure UI-only properties like pid are preserved.
|
||||
(allCoreToolCalls: ToolCall[]) => {
|
||||
setToolCallsForDisplay((prevTrackedCalls) => {
|
||||
const prevCallsMap = new Map(
|
||||
prevTrackedCalls.map((c) => [c.request.callId, c]),
|
||||
);
|
||||
|
||||
return allCoreToolCalls.map((coreTc): TrackedToolCall => {
|
||||
const existingTrackedCall = prevCallsMap.get(coreTc.request.callId);
|
||||
|
||||
const responseSubmittedToGemini =
|
||||
existingTrackedCall?.responseSubmittedToGemini ?? false;
|
||||
|
||||
if (coreTc.status === 'executing') {
|
||||
// Preserve live output if it exists from a previous render.
|
||||
const liveOutput = (existingTrackedCall as TrackedExecutingToolCall)
|
||||
?.liveOutput;
|
||||
return {
|
||||
...coreTc,
|
||||
responseSubmittedToGemini,
|
||||
liveOutput: (existingTrackedCall as TrackedExecutingToolCall)
|
||||
?.liveOutput,
|
||||
liveOutput,
|
||||
pid: (coreTc as ExecutingToolCall).pid,
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
...coreTc,
|
||||
responseSubmittedToGemini,
|
||||
};
|
||||
}
|
||||
|
||||
// For other statuses, explicitly set liveOutput and pid to undefined
|
||||
// to ensure they are not carried over from a previous executing state.
|
||||
return {
|
||||
...coreTc,
|
||||
responseSubmittedToGemini,
|
||||
liveOutput: undefined,
|
||||
pid: undefined,
|
||||
};
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
},
|
||||
[setToolCallsForDisplay],
|
||||
);
|
||||
@@ -178,9 +185,10 @@ export function useReactToolScheduler(
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
) => {
|
||||
setToolCallsForDisplay([]);
|
||||
void scheduler.schedule(request, signal);
|
||||
},
|
||||
[scheduler],
|
||||
[scheduler, setToolCallsForDisplay],
|
||||
);
|
||||
|
||||
const markToolsAsSubmitted: MarkToolsAsSubmittedFn = useCallback(
|
||||
@@ -196,7 +204,20 @@ export function useReactToolScheduler(
|
||||
[],
|
||||
);
|
||||
|
||||
return [toolCallsForDisplay, schedule, markToolsAsSubmitted];
|
||||
// Imperative cancel entry point handed back to the hook's caller: it forwards
// the supplied abort signal straight to the core scheduler's `cancelAll`.
// The callback identity only changes when `scheduler` itself changes.
const cancelAllToolCalls = useCallback(
  (signal: AbortSignal): void => scheduler.cancelAll(signal),
  [scheduler],
);
|
||||
|
||||
return [
|
||||
toolCallsForDisplay,
|
||||
schedule,
|
||||
markToolsAsSubmitted,
|
||||
setToolCallsForDisplay,
|
||||
cancelAllToolCalls,
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -260,9 +260,15 @@ describe('useReactToolScheduler', () => {
|
||||
args: { param: 'value' },
|
||||
} as any;
|
||||
|
||||
let completedToolCalls: ToolCall[] = [];
|
||||
onComplete.mockImplementation((calls) => {
|
||||
completedToolCalls = calls;
|
||||
});
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
@@ -292,7 +298,110 @@ describe('useReactToolScheduler', () => {
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
expect(result.current[0]).toEqual([]);
|
||||
expect(completedToolCalls).toHaveLength(1);
|
||||
expect(completedToolCalls[0].status).toBe('success');
|
||||
expect(completedToolCalls[0].request).toBe(request);
|
||||
});
|
||||
|
||||
it('should clear previous tool calls when scheduling new ones', async () => {
  // Arrange: the registry resolves to the mock tool, which succeeds.
  mockToolRegistry.getTool.mockReturnValue(mockTool);
  const successfulResult = {
    llmContent: 'Tool output',
    returnDisplay: 'Formatted tool output',
  } as ToolResult;
  (mockTool.execute as Mock).mockResolvedValue(successfulResult);

  const { result } = renderScheduler();
  // Tuple slots: [1] is the schedule function, [3] is the display-state setter.
  const [, schedule, , setToolCallsForDisplay] = result.current;

  // Seed the display with a stale, already-finished call.
  const staleCall = {
    request: { callId: 'oldCall' },
    status: 'success',
  } as any;
  act(() => {
    setToolCallsForDisplay([staleCall]);
  });
  expect(result.current[0]).toEqual([staleCall]);

  // Act: schedule a fresh request.
  const freshRequest: ToolCallRequestInfo = {
    callId: 'newCall',
    name: 'mockTool',
    args: {},
  } as any;
  act(() => {
    schedule(freshRequest, new AbortController().signal);
  });

  // Assert: the stale entry was dropped and only the new call is displayed.
  const displayedCalls = result.current[0];
  expect(displayedCalls.length).toBe(1);
  expect(displayedCalls[0].request.callId).toBe('newCall');
  expect(displayedCalls[0].request.callId).not.toBe('oldCall');

  // Drain the fake timers three times so the new call can run to completion.
  for (let pass = 0; pass < 3; pass++) {
    await act(async () => {
      await vi.runAllTimersAsync();
    });
  }
  expect(onComplete).toHaveBeenCalled();
});
|
||||
|
||||
it('should cancel all running tool calls', async () => {
  // Advances the fake timers inside act() so React state updates are flushed.
  const flushTimers = async (): Promise<void> => {
    await act(async () => {
      await vi.runAllTimersAsync();
    });
  };

  mockToolRegistry.getTool.mockReturnValue(mockTool);

  // `execute` returns a promise that stays pending until the test releases it,
  // which keeps the tool call in the 'executing' state.
  let releaseExecute: (value: ToolResult) => void = () => {};
  const pendingExecution = new Promise<ToolResult>((resolve) => {
    releaseExecute = resolve;
  });
  (mockTool.execute as Mock).mockReturnValue(pendingExecution);
  // A null confirmation result means no approval prompt is needed.
  (mockTool.shouldConfirmExecute as Mock).mockResolvedValue(null);

  const { result } = renderScheduler();
  // Tuple slots: [1] is the schedule function, [4] is cancelAllToolCalls.
  const [, schedule, , , cancelAllToolCalls] = result.current;
  const request: ToolCallRequestInfo = {
    callId: 'cancelCall',
    name: 'mockTool',
    args: {},
  } as any;

  act(() => {
    schedule(request, new AbortController().signal);
  });
  await flushTimers(); // validation
  await flushTimers(); // scheduling

  // The tool is now 'executing' and blocked on the pending promise.
  expect(result.current[0][0].status).toBe('executing');

  const cancelController = new AbortController();
  act(() => {
    cancelAllToolCalls(cancelController.signal);
  });
  await flushTimers();

  const cancelledCallMatcher = expect.objectContaining({
    status: 'cancelled',
    request,
  });
  expect(onComplete).toHaveBeenCalledWith([cancelledCallMatcher]);

  // Settle the pending promise so no open handle outlives the test.
  releaseExecute({ llmContent: 'output', returnDisplay: 'display' });
});
|
||||
|
||||
it('should handle tool not found', async () => {
|
||||
@@ -305,6 +414,11 @@ describe('useReactToolScheduler', () => {
|
||||
args: {},
|
||||
} as any;
|
||||
|
||||
let completedToolCalls: ToolCall[] = [];
|
||||
onComplete.mockImplementation((calls) => {
|
||||
completedToolCalls = calls;
|
||||
});
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
@@ -315,24 +429,15 @@ describe('useReactToolScheduler', () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith([
|
||||
expect.objectContaining({
|
||||
status: 'error',
|
||||
request,
|
||||
response: expect.objectContaining({
|
||||
error: expect.objectContaining({
|
||||
message: expect.stringMatching(
|
||||
/Tool "nonexistentTool" not found in registry/,
|
||||
),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
const errorMessage = onComplete.mock.calls[0][0][0].response.error.message;
|
||||
expect(errorMessage).toContain('Did you mean one of:');
|
||||
expect(errorMessage).toContain('"mockTool"');
|
||||
expect(errorMessage).toContain('"anotherTool"');
|
||||
expect(result.current[0]).toEqual([]);
|
||||
expect(completedToolCalls).toHaveLength(1);
|
||||
expect(completedToolCalls[0].status).toBe('error');
|
||||
expect(completedToolCalls[0].request).toBe(request);
|
||||
expect((completedToolCalls[0] as any).response.error.message).toContain(
|
||||
'Tool "nonexistentTool" not found in registry',
|
||||
);
|
||||
expect((completedToolCalls[0] as any).response.error.message).toContain(
|
||||
'Did you mean one of:',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle error during shouldConfirmExecute', async () => {
|
||||
@@ -348,6 +453,11 @@ describe('useReactToolScheduler', () => {
|
||||
args: {},
|
||||
} as any;
|
||||
|
||||
let completedToolCalls: ToolCall[] = [];
|
||||
onComplete.mockImplementation((calls) => {
|
||||
completedToolCalls = calls;
|
||||
});
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
@@ -358,16 +468,10 @@ describe('useReactToolScheduler', () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith([
|
||||
expect.objectContaining({
|
||||
status: 'error',
|
||||
request,
|
||||
response: expect.objectContaining({
|
||||
error: confirmError,
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
expect(result.current[0]).toEqual([]);
|
||||
expect(completedToolCalls).toHaveLength(1);
|
||||
expect(completedToolCalls[0].status).toBe('error');
|
||||
expect(completedToolCalls[0].request).toBe(request);
|
||||
expect((completedToolCalls[0] as any).response.error).toBe(confirmError);
|
||||
});
|
||||
|
||||
it('should handle error during execute', async () => {
|
||||
@@ -384,6 +488,11 @@ describe('useReactToolScheduler', () => {
|
||||
args: {},
|
||||
} as any;
|
||||
|
||||
let completedToolCalls: ToolCall[] = [];
|
||||
onComplete.mockImplementation((calls) => {
|
||||
completedToolCalls = calls;
|
||||
});
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
@@ -397,16 +506,10 @@ describe('useReactToolScheduler', () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith([
|
||||
expect.objectContaining({
|
||||
status: 'error',
|
||||
request,
|
||||
response: expect.objectContaining({
|
||||
error: execError,
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
expect(result.current[0]).toEqual([]);
|
||||
expect(completedToolCalls).toHaveLength(1);
|
||||
expect(completedToolCalls[0].status).toBe('error');
|
||||
expect(completedToolCalls[0].request).toBe(request);
|
||||
expect((completedToolCalls[0] as any).response.error).toBe(execError);
|
||||
});
|
||||
|
||||
it('should handle tool requiring confirmation - approved', async () => {
|
||||
@@ -518,7 +621,7 @@ describe('useReactToolScheduler', () => {
|
||||
functionResponse: expect.objectContaining({
|
||||
response: expect.objectContaining({
|
||||
error:
|
||||
'[Operation Cancelled] Reason: User did not allow tool call',
|
||||
'[Operation Cancelled] Reason: User cancelled the operation.',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
@@ -705,7 +808,9 @@ describe('useReactToolScheduler', () => {
|
||||
],
|
||||
}),
|
||||
});
|
||||
expect(result.current[0]).toEqual([]);
|
||||
|
||||
expect(completedCalls).toHaveLength(2);
|
||||
expect(completedCalls.every((t) => t.status === 'success')).toBe(true);
|
||||
});
|
||||
|
||||
it('should queue if scheduling while already running', async () => {
|
||||
@@ -774,7 +879,8 @@ describe('useReactToolScheduler', () => {
|
||||
response: expect.objectContaining({ resultDisplay: 'done display' }),
|
||||
}),
|
||||
]);
|
||||
expect(result.current[0]).toEqual([]);
|
||||
const toolCalls = result.current[0];
|
||||
expect(toolCalls).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -288,6 +288,263 @@ describe('CoreToolScheduler', () => {
|
||||
expect(completedCalls[0].status).toBe('cancelled');
|
||||
});
|
||||
|
||||
it('should cancel all tools when cancelAll is called', async () => {
  // Three tools are registered; the first is built with
  // MOCK_TOOL_SHOULD_CONFIRM_EXECUTE so the batch pauses at 'awaiting_approval'.
  const registeredTools = new Map<string, MockTool>([
    [
      'mockTool1',
      new MockTool({
        name: 'mockTool1',
        shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE,
      }),
    ],
    ['mockTool2', new MockTool({ name: 'mockTool2' })],
    ['mockTool3', new MockTool({ name: 'mockTool3' })],
  ]);
  // Unknown names resolve to undefined.
  const lookupTool = (name: string) => registeredTools.get(name);

  const mockToolRegistry = {
    getTool: lookupTool,
    getFunctionDeclarations: () => [],
    tools: new Map(),
    discovery: {},
    registerTool: () => {},
    getToolByName: lookupTool,
    getToolByDisplayName: () => undefined,
    getTools: () => [],
    discoverTools: async () => {},
    getAllTools: () => [],
    getToolsByServer: () => [],
  } as unknown as ToolRegistry;

  const onAllToolCallsComplete = vi.fn();
  const onToolCallsUpdate = vi.fn();

  const mockConfig = {
    getSessionId: () => 'test-session-id',
    getUsageStatisticsEnabled: () => true,
    getDebugMode: () => false,
    getApprovalMode: () => ApprovalMode.DEFAULT,
    getAllowedTools: () => [],
    getContentGeneratorConfig: () => ({
      model: 'test-model',
      authType: 'oauth-personal',
    }),
    getShellExecutionConfig: () => ({
      terminalWidth: 90,
      terminalHeight: 30,
    }),
    storage: {
      getProjectTempDir: () => '/tmp',
    },
    getTruncateToolOutputThreshold: () =>
      DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
    getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
    getToolRegistry: () => mockToolRegistry,
    getUseSmartEdit: () => false,
    getUseModelRouter: () => false,
    getGeminiClient: () => null, // No client needed for these tests
    getEnableMessageBusIntegration: () => false,
    getMessageBus: () => null,
    getPolicyEngine: () => null,
  } as unknown as Config;

  const scheduler = new CoreToolScheduler({
    config: mockConfig,
    onAllToolCallsComplete,
    onToolCallsUpdate,
    getPreferredEditor: () => 'vscode',
    onEditorClose: vi.fn(),
  });

  const abortController = new AbortController();
  const batchCallIds = ['1', '2', '3'];
  // One request per tool: call id N targets mockToolN.
  const requests = batchCallIds.map((callId) => ({
    callId,
    name: `mockTool${callId}`,
    args: {},
    isClientInitiated: false,
    prompt_id: 'prompt-id-1',
  }));

  // Deliberately not awaited: the batch keeps running in the background.
  void scheduler.schedule(requests, abortController.signal);

  // Block until the first tool reaches 'awaiting_approval'.
  await waitForStatus(onToolCallsUpdate, 'awaiting_approval');

  // Cancel the whole batch imperatively, then fire the abort signal as well.
  scheduler.cancelAll(abortController.signal);
  abortController.abort();

  await vi.waitFor(() => {
    expect(onAllToolCallsComplete).toHaveBeenCalled();
  });

  const completedCalls = onAllToolCallsComplete.mock
    .calls[0][0] as ToolCall[];

  // Every call in the batch must be reported, and each one as cancelled.
  expect(completedCalls).toHaveLength(3);
  for (const callId of batchCallIds) {
    const reported = completedCalls.find((c) => c.request.callId === callId);
    expect(reported?.status).toBe('cancelled');
  }
});
|
||||
|
||||
it('should cancel all tools in a batch when one is cancelled via confirmation', async () => {
  // Three tools are registered; the first is built with
  // MOCK_TOOL_SHOULD_CONFIRM_EXECUTE so the batch pauses at 'awaiting_approval'.
  const registeredTools = new Map<string, MockTool>([
    [
      'mockTool1',
      new MockTool({
        name: 'mockTool1',
        shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE,
      }),
    ],
    ['mockTool2', new MockTool({ name: 'mockTool2' })],
    ['mockTool3', new MockTool({ name: 'mockTool3' })],
  ]);
  // Unknown names resolve to undefined.
  const lookupTool = (name: string) => registeredTools.get(name);

  const mockToolRegistry = {
    getTool: lookupTool,
    getFunctionDeclarations: () => [],
    tools: new Map(),
    discovery: {},
    registerTool: () => {},
    getToolByName: lookupTool,
    getToolByDisplayName: () => undefined,
    getTools: () => [],
    discoverTools: async () => {},
    getAllTools: () => [],
    getToolsByServer: () => [],
  } as unknown as ToolRegistry;

  const onAllToolCallsComplete = vi.fn();
  const onToolCallsUpdate = vi.fn();

  const mockConfig = {
    getSessionId: () => 'test-session-id',
    getUsageStatisticsEnabled: () => true,
    getDebugMode: () => false,
    getApprovalMode: () => ApprovalMode.DEFAULT,
    getAllowedTools: () => [],
    getContentGeneratorConfig: () => ({
      model: 'test-model',
      authType: 'oauth-personal',
    }),
    getShellExecutionConfig: () => ({
      terminalWidth: 90,
      terminalHeight: 30,
    }),
    storage: {
      getProjectTempDir: () => '/tmp',
    },
    getTruncateToolOutputThreshold: () =>
      DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
    getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
    getToolRegistry: () => mockToolRegistry,
    getUseSmartEdit: () => false,
    getUseModelRouter: () => false,
    getGeminiClient: () => null, // No client needed for these tests
    getEnableMessageBusIntegration: () => false,
    getMessageBus: () => null,
    getPolicyEngine: () => null,
  } as unknown as Config;

  const scheduler = new CoreToolScheduler({
    config: mockConfig,
    onAllToolCallsComplete,
    onToolCallsUpdate,
    getPreferredEditor: () => 'vscode',
    onEditorClose: vi.fn(),
  });

  const abortController = new AbortController();
  const batchCallIds = ['1', '2', '3'];
  // One request per tool: call id N targets mockToolN.
  const requests = batchCallIds.map((callId) => ({
    callId,
    name: `mockTool${callId}`,
    args: {},
    isClientInitiated: false,
    prompt_id: 'prompt-id-1',
  }));

  // Deliberately not awaited: the batch keeps running in the background.
  void scheduler.schedule(requests, abortController.signal);

  // Block until the first tool reaches 'awaiting_approval' and capture it.
  const pendingApproval = (await waitForStatus(
    onToolCallsUpdate,
    'awaiting_approval',
  )) as WaitingToolCall;

  // Reject that tool through its own confirmation callback, then fire the
  // abort signal, as a user-driven cancel often does.
  const { onConfirm } = pendingApproval.confirmationDetails;
  await onConfirm(ToolConfirmationOutcome.Cancel);
  abortController.abort();

  await vi.waitFor(() => {
    expect(onAllToolCallsComplete).toHaveBeenCalled();
  });

  const completedCalls = onAllToolCallsComplete.mock
    .calls[0][0] as ToolCall[];

  // Cancelling one call must cascade: all three are reported as cancelled.
  expect(completedCalls).toHaveLength(3);
  for (const callId of batchCallIds) {
    const reported = completedCalls.find((c) => c.request.callId === callId);
    expect(reported?.status).toBe('cancelled');
  }
});
|
||||
|
||||
it('should mark tool call as cancelled when abort happens during confirmation error', async () => {
|
||||
const abortController = new AbortController();
|
||||
const abortError = new Error('Abort requested during confirmation');
|
||||
@@ -1510,16 +1767,19 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
|
||||
await scheduler.schedule(requests, abortController.signal);
|
||||
|
||||
// Wait for all tools to be awaiting approval
|
||||
// Wait for the FIRST tool to be awaiting approval
|
||||
await vi.waitFor(() => {
|
||||
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
|
||||
// With the sequential scheduler, the update includes the active call and the queue.
|
||||
expect(calls?.length).toBe(3);
|
||||
expect(calls?.every((call) => call.status === 'awaiting_approval')).toBe(
|
||||
true,
|
||||
);
|
||||
expect(calls?.[0].status).toBe('awaiting_approval');
|
||||
expect(calls?.[0].request.callId).toBe('1');
|
||||
// Check that the other two are in the queue (still in 'validating' state)
|
||||
expect(calls?.[1].status).toBe('validating');
|
||||
expect(calls?.[2].status).toBe('validating');
|
||||
});
|
||||
|
||||
expect(pendingConfirmations.length).toBe(3);
|
||||
expect(pendingConfirmations.length).toBe(1);
|
||||
|
||||
// Approve the first tool with ProceedAlways
|
||||
const firstConfirmation = pendingConfirmations[0];
|
||||
@@ -1528,15 +1788,16 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
// Wait for all tools to be completed
|
||||
await vi.waitFor(() => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
const completedCalls = onAllToolCallsComplete.mock.calls.at(
|
||||
-1,
|
||||
)?.[0] as ToolCall[];
|
||||
expect(completedCalls?.length).toBe(3);
|
||||
expect(completedCalls?.every((call) => call.status === 'success')).toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
const completedCalls = onAllToolCallsComplete.mock.calls.at(
|
||||
-1,
|
||||
)?.[0] as ToolCall[];
|
||||
expect(completedCalls?.length).toBe(3);
|
||||
expect(completedCalls?.every((call) => call.status === 'success')).toBe(
|
||||
true,
|
||||
);
|
||||
|
||||
// Verify approval mode was changed
|
||||
expect(approvalMode).toBe(ApprovalMode.AUTO_EDIT);
|
||||
});
|
||||
@@ -1788,11 +2049,10 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// Check that execute was called for all three tools initially
|
||||
expect(executeFn).toHaveBeenCalledTimes(3);
|
||||
// Check that execute was called for the first two tools only
|
||||
expect(executeFn).toHaveBeenCalledTimes(2);
|
||||
expect(executeFn).toHaveBeenCalledWith({ call: 1 });
|
||||
expect(executeFn).toHaveBeenCalledWith({ call: 2 });
|
||||
expect(executeFn).toHaveBeenCalledWith({ call: 3 });
|
||||
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
|
||||
@@ -348,12 +348,15 @@ export class CoreToolScheduler {
|
||||
private onEditorClose: () => void;
|
||||
private isFinalizingToolCalls = false;
|
||||
private isScheduling = false;
|
||||
private isCancelling = false;
|
||||
private requestQueue: Array<{
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[];
|
||||
signal: AbortSignal;
|
||||
resolve: () => void;
|
||||
reject: (reason?: Error) => void;
|
||||
}> = [];
|
||||
private toolCallQueue: ToolCall[] = [];
|
||||
private completedToolCallsForBatch: CompletedToolCall[] = [];
|
||||
|
||||
constructor(options: CoreToolSchedulerOptions) {
|
||||
this.config = options.config;
|
||||
@@ -398,30 +401,36 @@ export class CoreToolScheduler {
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
status: 'success',
|
||||
signal: AbortSignal,
|
||||
response: ToolCallResponseInfo,
|
||||
): void;
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
status: 'awaiting_approval',
|
||||
signal: AbortSignal,
|
||||
confirmationDetails: ToolCallConfirmationDetails,
|
||||
): void;
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
status: 'error',
|
||||
signal: AbortSignal,
|
||||
response: ToolCallResponseInfo,
|
||||
): void;
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
status: 'cancelled',
|
||||
signal: AbortSignal,
|
||||
reason: string,
|
||||
): void;
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
status: 'executing' | 'scheduled' | 'validating',
|
||||
signal: AbortSignal,
|
||||
): void;
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
newStatus: Status,
|
||||
signal: AbortSignal,
|
||||
auxiliaryData?: unknown,
|
||||
): void {
|
||||
this.toolCalls = this.toolCalls.map((currentCall) => {
|
||||
@@ -561,7 +570,6 @@ export class CoreToolScheduler {
|
||||
}
|
||||
});
|
||||
this.notifyToolCallsUpdate();
|
||||
this.checkAndNotifyCompletion();
|
||||
}
|
||||
|
||||
private setArgsInternal(targetCallId: string, args: unknown): void {
|
||||
@@ -692,11 +700,43 @@ export class CoreToolScheduler {
|
||||
return this._schedule(request, signal);
|
||||
}
|
||||
|
||||
/**
 * Cancels the active tool call (when it is still cancellable) together with
 * everything waiting in the queue, then triggers batch finalization.
 *
 * Calls made while a cancellation is already in progress are ignored.
 *
 * @param signal Abort signal passed through to the status update and to the
 *   completion check.
 */
cancelAll(signal: AbortSignal): void {
  // Guard against repeated invocations during the same cancellation.
  if (this.isCancelling) {
    return;
  }
  this.isCancelling = true;

  // Statuses from which a call can still be moved to 'cancelled'.
  const cancellableStatuses: readonly string[] = [
    'awaiting_approval',
    'executing',
    'scheduled',
    'validating',
  ];

  // The first entry of `toolCalls` is treated as the currently active call.
  const activeCall = this.toolCalls.length > 0 ? this.toolCalls[0] : undefined;
  if (
    activeCall !== undefined &&
    cancellableStatuses.includes(activeCall.status)
  ) {
    this.setStatusInternal(
      activeCall.request.callId,
      'cancelled',
      signal,
      'User cancelled the operation.',
    );
  }

  // Empty the queue, marking the queued calls as cancelled so they are
  // included when completion is reported.
  this._cancelAllQueuedCalls();

  // Finalize the batch right away; the promise is intentionally not awaited.
  void this.checkAndNotifyCompletion(signal);
}
|
||||
|
||||
private async _schedule(
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
): Promise<void> {
|
||||
this.isScheduling = true;
|
||||
this.isCancelling = false;
|
||||
try {
|
||||
if (this.isRunning()) {
|
||||
throw new Error(
|
||||
@@ -704,6 +744,7 @@ export class CoreToolScheduler {
|
||||
);
|
||||
}
|
||||
const requestsToProcess = Array.isArray(request) ? request : [request];
|
||||
this.completedToolCallsForBatch = [];
|
||||
|
||||
const newToolCalls: ToolCall[] = requestsToProcess.map(
|
||||
(reqInfo): ToolCall => {
|
||||
@@ -753,45 +794,74 @@ export class CoreToolScheduler {
|
||||
},
|
||||
);
|
||||
|
||||
this.toolCalls = this.toolCalls.concat(newToolCalls);
|
||||
this.notifyToolCallsUpdate();
|
||||
this.toolCallQueue.push(...newToolCalls);
|
||||
await this._processNextInQueue(signal);
|
||||
} finally {
|
||||
this.isScheduling = false;
|
||||
}
|
||||
}
|
||||
|
||||
for (const toolCall of newToolCalls) {
|
||||
if (toolCall.status !== 'validating') {
|
||||
continue;
|
||||
private async _processNextInQueue(signal: AbortSignal): Promise<void> {
|
||||
// If there's already a tool being processed, or the queue is empty, stop.
|
||||
if (this.toolCalls.length > 0 || this.toolCallQueue.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If cancellation happened between steps, handle it.
|
||||
if (signal.aborted) {
|
||||
this._cancelAllQueuedCalls();
|
||||
// Finalize the batch.
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
return;
|
||||
}
|
||||
|
||||
const toolCall = this.toolCallQueue.shift()!;
|
||||
|
||||
// This is now the single active tool call.
|
||||
this.toolCalls = [toolCall];
|
||||
this.notifyToolCallsUpdate();
|
||||
|
||||
// Handle tools that were already errored during creation.
|
||||
if (toolCall.status === 'error') {
|
||||
// An error during validation means this "active" tool is already complete.
|
||||
// We need to check for batch completion to either finish or process the next in queue.
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
return;
|
||||
}
|
||||
|
||||
// This logic is moved from the old `for` loop in `_schedule`.
|
||||
if (toolCall.status === 'validating') {
|
||||
const { request: reqInfo, invocation } = toolCall;
|
||||
|
||||
try {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'cancelled',
|
||||
signal,
|
||||
'Tool call cancelled by user.',
|
||||
);
|
||||
// The completion check will handle the cascade.
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
return;
|
||||
}
|
||||
|
||||
const validatingCall = toolCall as ValidatingToolCall;
|
||||
const { request: reqInfo, invocation } = validatingCall;
|
||||
const confirmationDetails =
|
||||
await invocation.shouldConfirmExecute(signal);
|
||||
|
||||
try {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'cancelled',
|
||||
'Tool call cancelled by user.',
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const confirmationDetails =
|
||||
await invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (!confirmationDetails) {
|
||||
if (!confirmationDetails) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled', signal);
|
||||
} else {
|
||||
if (this.isAutoApproved(toolCall)) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
continue;
|
||||
}
|
||||
|
||||
if (this.isAutoApproved(validatingCall)) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled', signal);
|
||||
} else {
|
||||
// Allow IDE to resolve confirmation
|
||||
if (
|
||||
@@ -835,35 +905,36 @@ export class CoreToolScheduler {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'awaiting_approval',
|
||||
signal,
|
||||
wrappedConfirmationDetails,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'cancelled',
|
||||
'Tool call cancelled by user.',
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'cancelled',
|
||||
signal,
|
||||
'Tool call cancelled by user.',
|
||||
);
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
} else {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'error',
|
||||
signal,
|
||||
createErrorResponse(
|
||||
reqInfo,
|
||||
error instanceof Error ? error : new Error(String(error)),
|
||||
ToolErrorType.UNHANDLED_EXCEPTION,
|
||||
),
|
||||
);
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
}
|
||||
}
|
||||
await this.attemptExecutionOfScheduledCalls(signal);
|
||||
void this.checkAndNotifyCompletion();
|
||||
} finally {
|
||||
this.isScheduling = false;
|
||||
}
|
||||
await this.attemptExecutionOfScheduledCalls(signal);
|
||||
}
|
||||
|
||||
async handleConfirmationResponse(
|
||||
@@ -881,18 +952,12 @@ export class CoreToolScheduler {
|
||||
await originalOnConfirm(outcome);
|
||||
}
|
||||
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
await this.autoApproveCompatiblePendingTools(signal, callId);
|
||||
}
|
||||
|
||||
this.setToolCallOutcome(callId, outcome);
|
||||
|
||||
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'cancelled',
|
||||
'User did not allow tool call',
|
||||
);
|
||||
// Instead of just cancelling one tool, trigger the full cancel cascade.
|
||||
this.cancelAll(signal);
|
||||
return; // `cancelAll` calls `checkAndNotifyCompletion`, so we can exit here.
|
||||
} else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
|
||||
const waitingToolCall = toolCall as WaitingToolCall;
|
||||
if (isModifiableDeclarativeTool(waitingToolCall.tool)) {
|
||||
@@ -902,7 +967,7 @@ export class CoreToolScheduler {
|
||||
return;
|
||||
}
|
||||
|
||||
this.setStatusInternal(callId, 'awaiting_approval', {
|
||||
this.setStatusInternal(callId, 'awaiting_approval', signal, {
|
||||
...waitingToolCall.confirmationDetails,
|
||||
isModifying: true,
|
||||
} as ToolCallConfirmationDetails);
|
||||
@@ -917,7 +982,7 @@ export class CoreToolScheduler {
|
||||
this.onEditorClose,
|
||||
);
|
||||
this.setArgsInternal(callId, updatedParams);
|
||||
this.setStatusInternal(callId, 'awaiting_approval', {
|
||||
this.setStatusInternal(callId, 'awaiting_approval', signal, {
|
||||
...waitingToolCall.confirmationDetails,
|
||||
fileDiff: updatedDiff,
|
||||
isModifying: false,
|
||||
@@ -932,7 +997,7 @@ export class CoreToolScheduler {
|
||||
signal,
|
||||
);
|
||||
}
|
||||
this.setStatusInternal(callId, 'scheduled');
|
||||
this.setStatusInternal(callId, 'scheduled', signal);
|
||||
}
|
||||
await this.attemptExecutionOfScheduledCalls(signal);
|
||||
}
|
||||
@@ -974,10 +1039,15 @@ export class CoreToolScheduler {
|
||||
);
|
||||
|
||||
this.setArgsInternal(toolCall.request.callId, updatedParams);
|
||||
this.setStatusInternal(toolCall.request.callId, 'awaiting_approval', {
|
||||
...toolCall.confirmationDetails,
|
||||
fileDiff: updatedDiff,
|
||||
});
|
||||
this.setStatusInternal(
|
||||
toolCall.request.callId,
|
||||
'awaiting_approval',
|
||||
signal,
|
||||
{
|
||||
...toolCall.confirmationDetails,
|
||||
fileDiff: updatedDiff,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
private async attemptExecutionOfScheduledCalls(
|
||||
@@ -1002,7 +1072,7 @@ export class CoreToolScheduler {
|
||||
const scheduledCall = toolCall;
|
||||
const { callId, name: toolName } = scheduledCall.request;
|
||||
const invocation = scheduledCall.invocation;
|
||||
this.setStatusInternal(callId, 'executing');
|
||||
this.setStatusInternal(callId, 'executing', signal);
|
||||
|
||||
const liveOutputCallback =
|
||||
scheduledCall.tool.canUpdateOutput && this.outputUpdateHandler
|
||||
@@ -1055,12 +1125,10 @@ export class CoreToolScheduler {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'cancelled',
|
||||
signal,
|
||||
'User cancelled tool execution.',
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (toolResult.error === undefined) {
|
||||
} else if (toolResult.error === undefined) {
|
||||
let content = toolResult.llmContent;
|
||||
let outputFile: string | undefined = undefined;
|
||||
const contentLength =
|
||||
@@ -1116,7 +1184,7 @@ export class CoreToolScheduler {
|
||||
outputFile,
|
||||
contentLength,
|
||||
};
|
||||
this.setStatusInternal(callId, 'success', successResponse);
|
||||
this.setStatusInternal(callId, 'success', signal, successResponse);
|
||||
} else {
|
||||
// It is a failure
|
||||
const error = new Error(toolResult.error.message);
|
||||
@@ -1125,19 +1193,21 @@ export class CoreToolScheduler {
|
||||
error,
|
||||
toolResult.error.type,
|
||||
);
|
||||
this.setStatusInternal(callId, 'error', errorResponse);
|
||||
this.setStatusInternal(callId, 'error', signal, errorResponse);
|
||||
}
|
||||
} catch (executionError: unknown) {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'cancelled',
|
||||
signal,
|
||||
'User cancelled tool execution.',
|
||||
);
|
||||
} else {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'error',
|
||||
signal,
|
||||
createErrorResponse(
|
||||
scheduledCall.request,
|
||||
executionError instanceof Error
|
||||
@@ -1148,45 +1218,126 @@ export class CoreToolScheduler {
|
||||
);
|
||||
}
|
||||
}
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async checkAndNotifyCompletion(): Promise<void> {
|
||||
const allCallsAreTerminal = this.toolCalls.every(
|
||||
(call) =>
|
||||
call.status === 'success' ||
|
||||
call.status === 'error' ||
|
||||
call.status === 'cancelled',
|
||||
);
|
||||
private async checkAndNotifyCompletion(signal: AbortSignal): Promise<void> {
|
||||
// This method is now only concerned with the single active tool call.
|
||||
if (this.toolCalls.length === 0) {
|
||||
// It's possible to be called when a batch is cancelled before any tool has started.
|
||||
if (signal.aborted && this.toolCallQueue.length > 0) {
|
||||
this._cancelAllQueuedCalls();
|
||||
}
|
||||
} else {
|
||||
const activeCall = this.toolCalls[0];
|
||||
const isTerminal =
|
||||
activeCall.status === 'success' ||
|
||||
activeCall.status === 'error' ||
|
||||
activeCall.status === 'cancelled';
|
||||
|
||||
if (this.toolCalls.length > 0 && allCallsAreTerminal) {
|
||||
const completedCalls = [...this.toolCalls] as CompletedToolCall[];
|
||||
// If the active tool is not in a terminal state (e.g., it's 'executing' or 'awaiting_approval'),
|
||||
// then the scheduler is still busy or paused. We should not proceed.
|
||||
if (!isTerminal) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The active tool is finished. Move it to the completed batch.
|
||||
const completedCall = activeCall as CompletedToolCall;
|
||||
this.completedToolCallsForBatch.push(completedCall);
|
||||
logToolCall(this.config, new ToolCallEvent(completedCall));
|
||||
|
||||
// Clear the active tool slot. This is crucial for the sequential processing.
|
||||
this.toolCalls = [];
|
||||
}
|
||||
|
||||
for (const call of completedCalls) {
|
||||
logToolCall(this.config, new ToolCallEvent(call));
|
||||
// Now, check if the entire batch is complete.
|
||||
// The batch is complete if the queue is empty or the operation was cancelled.
|
||||
if (this.toolCallQueue.length === 0 || signal.aborted) {
|
||||
if (signal.aborted) {
|
||||
this._cancelAllQueuedCalls();
|
||||
}
|
||||
|
||||
// If there's nothing to report and we weren't cancelled, we can stop.
|
||||
// But if we were cancelled, we must proceed to potentially start the next queued request.
|
||||
if (this.completedToolCallsForBatch.length === 0 && !signal.aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.onAllToolCallsComplete) {
|
||||
this.isFinalizingToolCalls = true;
|
||||
await this.onAllToolCallsComplete(completedCalls);
|
||||
// Use the batch array, not the (now empty) active array.
|
||||
await this.onAllToolCallsComplete(this.completedToolCallsForBatch);
|
||||
this.completedToolCallsForBatch = []; // Clear after reporting.
|
||||
this.isFinalizingToolCalls = false;
|
||||
}
|
||||
this.isCancelling = false;
|
||||
this.notifyToolCallsUpdate();
|
||||
// After completion, process the next item in the queue.
|
||||
|
||||
// After completion of the entire batch, process the next item in the main request queue.
|
||||
if (this.requestQueue.length > 0) {
|
||||
const next = this.requestQueue.shift()!;
|
||||
this._schedule(next.request, next.signal)
|
||||
.then(next.resolve)
|
||||
.catch(next.reject);
|
||||
}
|
||||
} else {
|
||||
// The batch is not yet complete, so continue processing the current batch sequence.
|
||||
await this._processNextInQueue(signal);
|
||||
}
|
||||
}
|
||||
|
||||
private _cancelAllQueuedCalls(): void {
|
||||
while (this.toolCallQueue.length > 0) {
|
||||
const queuedCall = this.toolCallQueue.shift()!;
|
||||
// Don't cancel tools that already errored during validation.
|
||||
if (queuedCall.status === 'error') {
|
||||
this.completedToolCallsForBatch.push(queuedCall);
|
||||
continue;
|
||||
}
|
||||
const durationMs =
|
||||
'startTime' in queuedCall && queuedCall.startTime
|
||||
? Date.now() - queuedCall.startTime
|
||||
: undefined;
|
||||
const errorMessage =
|
||||
'[Operation Cancelled] User cancelled the operation.';
|
||||
this.completedToolCallsForBatch.push({
|
||||
request: queuedCall.request,
|
||||
tool: queuedCall.tool,
|
||||
invocation: queuedCall.invocation,
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
callId: queuedCall.request.callId,
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
id: queuedCall.request.callId,
|
||||
name: queuedCall.request.name,
|
||||
response: {
|
||||
error: errorMessage,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
resultDisplay: undefined,
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: errorMessage.length,
|
||||
},
|
||||
durationMs,
|
||||
outcome: ToolConfirmationOutcome.Cancel,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private notifyToolCallsUpdate(): void {
|
||||
if (this.onToolCallsUpdate) {
|
||||
this.onToolCallsUpdate([...this.toolCalls]);
|
||||
this.onToolCallsUpdate([
|
||||
...this.completedToolCallsForBatch,
|
||||
...this.toolCalls,
|
||||
...this.toolCallQueue,
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1215,35 +1366,4 @@ export class CoreToolScheduler {
|
||||
|
||||
return doesToolInvocationMatch(tool, invocation, allowedTools);
|
||||
}
|
||||
|
||||
private async autoApproveCompatiblePendingTools(
|
||||
signal: AbortSignal,
|
||||
triggeringCallId: string,
|
||||
): Promise<void> {
|
||||
const pendingTools = this.toolCalls.filter(
|
||||
(call) =>
|
||||
call.status === 'awaiting_approval' &&
|
||||
call.request.callId !== triggeringCallId,
|
||||
) as WaitingToolCall[];
|
||||
|
||||
for (const pendingTool of pendingTools) {
|
||||
try {
|
||||
const stillNeedsConfirmation =
|
||||
await pendingTool.invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (!stillNeedsConfirmation) {
|
||||
this.setToolCallOutcome(
|
||||
pendingTool.request.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(pendingTool.request.callId, 'scheduled');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error checking confirmation for tool ${pendingTool.request.callId}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user