mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-24 12:04:56 -07:00
feat(core, cli): Implement sequential approval. (#11593)
This commit is contained in:
@@ -288,6 +288,263 @@ describe('CoreToolScheduler', () => {
|
||||
expect(completedCalls[0].status).toBe('cancelled');
|
||||
});
|
||||
|
||||
it('should cancel all tools when cancelAll is called', async () => {
|
||||
const mockTool1 = new MockTool({
|
||||
name: 'mockTool1',
|
||||
shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE,
|
||||
});
|
||||
const mockTool2 = new MockTool({ name: 'mockTool2' });
|
||||
const mockTool3 = new MockTool({ name: 'mockTool3' });
|
||||
|
||||
const mockToolRegistry = {
|
||||
getTool: (name: string) => {
|
||||
if (name === 'mockTool1') return mockTool1;
|
||||
if (name === 'mockTool2') return mockTool2;
|
||||
if (name === 'mockTool3') return mockTool3;
|
||||
return undefined;
|
||||
},
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {},
|
||||
registerTool: () => {},
|
||||
getToolByName: (name: string) => {
|
||||
if (name === 'mockTool1') return mockTool1;
|
||||
if (name === 'mockTool2') return mockTool2;
|
||||
if (name === 'mockTool3') return mockTool3;
|
||||
return undefined;
|
||||
},
|
||||
getToolByDisplayName: () => undefined,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
const mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.DEFAULT,
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
terminalHeight: 30,
|
||||
}),
|
||||
storage: {
|
||||
getProjectTempDir: () => '/tmp',
|
||||
},
|
||||
getTruncateToolOutputThreshold: () =>
|
||||
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
|
||||
getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
const requests = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'mockTool1',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
{
|
||||
callId: '2',
|
||||
name: 'mockTool2',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
{
|
||||
callId: '3',
|
||||
name: 'mockTool3',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
];
|
||||
|
||||
// Don't await, let it run in the background
|
||||
void scheduler.schedule(requests, abortController.signal);
|
||||
|
||||
// Wait for the first tool to be awaiting approval
|
||||
await waitForStatus(onToolCallsUpdate, 'awaiting_approval');
|
||||
|
||||
// Cancel all operations
|
||||
scheduler.cancelAll(abortController.signal);
|
||||
abortController.abort(); // Also fire the signal
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
|
||||
expect(completedCalls).toHaveLength(3);
|
||||
expect(completedCalls.find((c) => c.request.callId === '1')?.status).toBe(
|
||||
'cancelled',
|
||||
);
|
||||
expect(completedCalls.find((c) => c.request.callId === '2')?.status).toBe(
|
||||
'cancelled',
|
||||
);
|
||||
expect(completedCalls.find((c) => c.request.callId === '3')?.status).toBe(
|
||||
'cancelled',
|
||||
);
|
||||
});
|
||||
|
||||
it('should cancel all tools in a batch when one is cancelled via confirmation', async () => {
|
||||
const mockTool1 = new MockTool({
|
||||
name: 'mockTool1',
|
||||
shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE,
|
||||
});
|
||||
const mockTool2 = new MockTool({ name: 'mockTool2' });
|
||||
const mockTool3 = new MockTool({ name: 'mockTool3' });
|
||||
|
||||
const mockToolRegistry = {
|
||||
getTool: (name: string) => {
|
||||
if (name === 'mockTool1') return mockTool1;
|
||||
if (name === 'mockTool2') return mockTool2;
|
||||
if (name === 'mockTool3') return mockTool3;
|
||||
return undefined;
|
||||
},
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {},
|
||||
registerTool: () => {},
|
||||
getToolByName: (name: string) => {
|
||||
if (name === 'mockTool1') return mockTool1;
|
||||
if (name === 'mockTool2') return mockTool2;
|
||||
if (name === 'mockTool3') return mockTool3;
|
||||
return undefined;
|
||||
},
|
||||
getToolByDisplayName: () => undefined,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
const mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.DEFAULT,
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
terminalHeight: 30,
|
||||
}),
|
||||
storage: {
|
||||
getProjectTempDir: () => '/tmp',
|
||||
},
|
||||
getTruncateToolOutputThreshold: () =>
|
||||
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
|
||||
getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
const requests = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'mockTool1',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
{
|
||||
callId: '2',
|
||||
name: 'mockTool2',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
{
|
||||
callId: '3',
|
||||
name: 'mockTool3',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
];
|
||||
|
||||
// Don't await, let it run in the background
|
||||
void scheduler.schedule(requests, abortController.signal);
|
||||
|
||||
// Wait for the first tool to be awaiting approval
|
||||
const awaitingCall = (await waitForStatus(
|
||||
onToolCallsUpdate,
|
||||
'awaiting_approval',
|
||||
)) as WaitingToolCall;
|
||||
|
||||
// Cancel the first tool via its confirmation handler
|
||||
await awaitingCall.confirmationDetails.onConfirm(
|
||||
ToolConfirmationOutcome.Cancel,
|
||||
);
|
||||
abortController.abort(); // User cancelling often involves an abort signal
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
|
||||
expect(completedCalls).toHaveLength(3);
|
||||
expect(completedCalls.find((c) => c.request.callId === '1')?.status).toBe(
|
||||
'cancelled',
|
||||
);
|
||||
expect(completedCalls.find((c) => c.request.callId === '2')?.status).toBe(
|
||||
'cancelled',
|
||||
);
|
||||
expect(completedCalls.find((c) => c.request.callId === '3')?.status).toBe(
|
||||
'cancelled',
|
||||
);
|
||||
});
|
||||
|
||||
it('should mark tool call as cancelled when abort happens during confirmation error', async () => {
|
||||
const abortController = new AbortController();
|
||||
const abortError = new Error('Abort requested during confirmation');
|
||||
@@ -1510,16 +1767,19 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
|
||||
await scheduler.schedule(requests, abortController.signal);
|
||||
|
||||
// Wait for all tools to be awaiting approval
|
||||
// Wait for the FIRST tool to be awaiting approval
|
||||
await vi.waitFor(() => {
|
||||
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
|
||||
// With the sequential scheduler, the update includes the active call and the queue.
|
||||
expect(calls?.length).toBe(3);
|
||||
expect(calls?.every((call) => call.status === 'awaiting_approval')).toBe(
|
||||
true,
|
||||
);
|
||||
expect(calls?.[0].status).toBe('awaiting_approval');
|
||||
expect(calls?.[0].request.callId).toBe('1');
|
||||
// Check that the other two are in the queue (still in 'validating' state)
|
||||
expect(calls?.[1].status).toBe('validating');
|
||||
expect(calls?.[2].status).toBe('validating');
|
||||
});
|
||||
|
||||
expect(pendingConfirmations.length).toBe(3);
|
||||
expect(pendingConfirmations.length).toBe(1);
|
||||
|
||||
// Approve the first tool with ProceedAlways
|
||||
const firstConfirmation = pendingConfirmations[0];
|
||||
@@ -1528,15 +1788,16 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
// Wait for all tools to be completed
|
||||
await vi.waitFor(() => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
const completedCalls = onAllToolCallsComplete.mock.calls.at(
|
||||
-1,
|
||||
)?.[0] as ToolCall[];
|
||||
expect(completedCalls?.length).toBe(3);
|
||||
expect(completedCalls?.every((call) => call.status === 'success')).toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
const completedCalls = onAllToolCallsComplete.mock.calls.at(
|
||||
-1,
|
||||
)?.[0] as ToolCall[];
|
||||
expect(completedCalls?.length).toBe(3);
|
||||
expect(completedCalls?.every((call) => call.status === 'success')).toBe(
|
||||
true,
|
||||
);
|
||||
|
||||
// Verify approval mode was changed
|
||||
expect(approvalMode).toBe(ApprovalMode.AUTO_EDIT);
|
||||
});
|
||||
@@ -1788,11 +2049,10 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// Check that execute was called for all three tools initially
|
||||
expect(executeFn).toHaveBeenCalledTimes(3);
|
||||
// Check that execute was called for the first two tools only
|
||||
expect(executeFn).toHaveBeenCalledTimes(2);
|
||||
expect(executeFn).toHaveBeenCalledWith({ call: 1 });
|
||||
expect(executeFn).toHaveBeenCalledWith({ call: 2 });
|
||||
expect(executeFn).toHaveBeenCalledWith({ call: 3 });
|
||||
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
|
||||
@@ -348,12 +348,15 @@ export class CoreToolScheduler {
|
||||
private onEditorClose: () => void;
|
||||
private isFinalizingToolCalls = false;
|
||||
private isScheduling = false;
|
||||
private isCancelling = false;
|
||||
private requestQueue: Array<{
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[];
|
||||
signal: AbortSignal;
|
||||
resolve: () => void;
|
||||
reject: (reason?: Error) => void;
|
||||
}> = [];
|
||||
private toolCallQueue: ToolCall[] = [];
|
||||
private completedToolCallsForBatch: CompletedToolCall[] = [];
|
||||
|
||||
constructor(options: CoreToolSchedulerOptions) {
|
||||
this.config = options.config;
|
||||
@@ -398,30 +401,36 @@ export class CoreToolScheduler {
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
status: 'success',
|
||||
signal: AbortSignal,
|
||||
response: ToolCallResponseInfo,
|
||||
): void;
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
status: 'awaiting_approval',
|
||||
signal: AbortSignal,
|
||||
confirmationDetails: ToolCallConfirmationDetails,
|
||||
): void;
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
status: 'error',
|
||||
signal: AbortSignal,
|
||||
response: ToolCallResponseInfo,
|
||||
): void;
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
status: 'cancelled',
|
||||
signal: AbortSignal,
|
||||
reason: string,
|
||||
): void;
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
status: 'executing' | 'scheduled' | 'validating',
|
||||
signal: AbortSignal,
|
||||
): void;
|
||||
private setStatusInternal(
|
||||
targetCallId: string,
|
||||
newStatus: Status,
|
||||
signal: AbortSignal,
|
||||
auxiliaryData?: unknown,
|
||||
): void {
|
||||
this.toolCalls = this.toolCalls.map((currentCall) => {
|
||||
@@ -561,7 +570,6 @@ export class CoreToolScheduler {
|
||||
}
|
||||
});
|
||||
this.notifyToolCallsUpdate();
|
||||
this.checkAndNotifyCompletion();
|
||||
}
|
||||
|
||||
private setArgsInternal(targetCallId: string, args: unknown): void {
|
||||
@@ -692,11 +700,43 @@ export class CoreToolScheduler {
|
||||
return this._schedule(request, signal);
|
||||
}
|
||||
|
||||
cancelAll(signal: AbortSignal): void {
|
||||
if (this.isCancelling) {
|
||||
return;
|
||||
}
|
||||
this.isCancelling = true;
|
||||
// Cancel the currently active tool call, if there is one.
|
||||
if (this.toolCalls.length > 0) {
|
||||
const activeCall = this.toolCalls[0];
|
||||
// Only cancel if it's in a cancellable state.
|
||||
if (
|
||||
activeCall.status === 'awaiting_approval' ||
|
||||
activeCall.status === 'executing' ||
|
||||
activeCall.status === 'scheduled' ||
|
||||
activeCall.status === 'validating'
|
||||
) {
|
||||
this.setStatusInternal(
|
||||
activeCall.request.callId,
|
||||
'cancelled',
|
||||
signal,
|
||||
'User cancelled the operation.',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the queue and mark all queued items as cancelled for completion reporting.
|
||||
this._cancelAllQueuedCalls();
|
||||
|
||||
// Finalize the batch immediately.
|
||||
void this.checkAndNotifyCompletion(signal);
|
||||
}
|
||||
|
||||
private async _schedule(
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
): Promise<void> {
|
||||
this.isScheduling = true;
|
||||
this.isCancelling = false;
|
||||
try {
|
||||
if (this.isRunning()) {
|
||||
throw new Error(
|
||||
@@ -704,6 +744,7 @@ export class CoreToolScheduler {
|
||||
);
|
||||
}
|
||||
const requestsToProcess = Array.isArray(request) ? request : [request];
|
||||
this.completedToolCallsForBatch = [];
|
||||
|
||||
const newToolCalls: ToolCall[] = requestsToProcess.map(
|
||||
(reqInfo): ToolCall => {
|
||||
@@ -753,45 +794,74 @@ export class CoreToolScheduler {
|
||||
},
|
||||
);
|
||||
|
||||
this.toolCalls = this.toolCalls.concat(newToolCalls);
|
||||
this.notifyToolCallsUpdate();
|
||||
this.toolCallQueue.push(...newToolCalls);
|
||||
await this._processNextInQueue(signal);
|
||||
} finally {
|
||||
this.isScheduling = false;
|
||||
}
|
||||
}
|
||||
|
||||
for (const toolCall of newToolCalls) {
|
||||
if (toolCall.status !== 'validating') {
|
||||
continue;
|
||||
private async _processNextInQueue(signal: AbortSignal): Promise<void> {
|
||||
// If there's already a tool being processed, or the queue is empty, stop.
|
||||
if (this.toolCalls.length > 0 || this.toolCallQueue.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If cancellation happened between steps, handle it.
|
||||
if (signal.aborted) {
|
||||
this._cancelAllQueuedCalls();
|
||||
// Finalize the batch.
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
return;
|
||||
}
|
||||
|
||||
const toolCall = this.toolCallQueue.shift()!;
|
||||
|
||||
// This is now the single active tool call.
|
||||
this.toolCalls = [toolCall];
|
||||
this.notifyToolCallsUpdate();
|
||||
|
||||
// Handle tools that were already errored during creation.
|
||||
if (toolCall.status === 'error') {
|
||||
// An error during validation means this "active" tool is already complete.
|
||||
// We need to check for batch completion to either finish or process the next in queue.
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
return;
|
||||
}
|
||||
|
||||
// This logic is moved from the old `for` loop in `_schedule`.
|
||||
if (toolCall.status === 'validating') {
|
||||
const { request: reqInfo, invocation } = toolCall;
|
||||
|
||||
try {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'cancelled',
|
||||
signal,
|
||||
'Tool call cancelled by user.',
|
||||
);
|
||||
// The completion check will handle the cascade.
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
return;
|
||||
}
|
||||
|
||||
const validatingCall = toolCall as ValidatingToolCall;
|
||||
const { request: reqInfo, invocation } = validatingCall;
|
||||
const confirmationDetails =
|
||||
await invocation.shouldConfirmExecute(signal);
|
||||
|
||||
try {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'cancelled',
|
||||
'Tool call cancelled by user.',
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const confirmationDetails =
|
||||
await invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (!confirmationDetails) {
|
||||
if (!confirmationDetails) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled', signal);
|
||||
} else {
|
||||
if (this.isAutoApproved(toolCall)) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
continue;
|
||||
}
|
||||
|
||||
if (this.isAutoApproved(validatingCall)) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled', signal);
|
||||
} else {
|
||||
// Allow IDE to resolve confirmation
|
||||
if (
|
||||
@@ -835,35 +905,36 @@ export class CoreToolScheduler {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'awaiting_approval',
|
||||
signal,
|
||||
wrappedConfirmationDetails,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'cancelled',
|
||||
'Tool call cancelled by user.',
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'cancelled',
|
||||
signal,
|
||||
'Tool call cancelled by user.',
|
||||
);
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
} else {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'error',
|
||||
signal,
|
||||
createErrorResponse(
|
||||
reqInfo,
|
||||
error instanceof Error ? error : new Error(String(error)),
|
||||
ToolErrorType.UNHANDLED_EXCEPTION,
|
||||
),
|
||||
);
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
}
|
||||
}
|
||||
await this.attemptExecutionOfScheduledCalls(signal);
|
||||
void this.checkAndNotifyCompletion();
|
||||
} finally {
|
||||
this.isScheduling = false;
|
||||
}
|
||||
await this.attemptExecutionOfScheduledCalls(signal);
|
||||
}
|
||||
|
||||
async handleConfirmationResponse(
|
||||
@@ -881,18 +952,12 @@ export class CoreToolScheduler {
|
||||
await originalOnConfirm(outcome);
|
||||
}
|
||||
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
await this.autoApproveCompatiblePendingTools(signal, callId);
|
||||
}
|
||||
|
||||
this.setToolCallOutcome(callId, outcome);
|
||||
|
||||
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'cancelled',
|
||||
'User did not allow tool call',
|
||||
);
|
||||
// Instead of just cancelling one tool, trigger the full cancel cascade.
|
||||
this.cancelAll(signal);
|
||||
return; // `cancelAll` calls `checkAndNotifyCompletion`, so we can exit here.
|
||||
} else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
|
||||
const waitingToolCall = toolCall as WaitingToolCall;
|
||||
if (isModifiableDeclarativeTool(waitingToolCall.tool)) {
|
||||
@@ -902,7 +967,7 @@ export class CoreToolScheduler {
|
||||
return;
|
||||
}
|
||||
|
||||
this.setStatusInternal(callId, 'awaiting_approval', {
|
||||
this.setStatusInternal(callId, 'awaiting_approval', signal, {
|
||||
...waitingToolCall.confirmationDetails,
|
||||
isModifying: true,
|
||||
} as ToolCallConfirmationDetails);
|
||||
@@ -917,7 +982,7 @@ export class CoreToolScheduler {
|
||||
this.onEditorClose,
|
||||
);
|
||||
this.setArgsInternal(callId, updatedParams);
|
||||
this.setStatusInternal(callId, 'awaiting_approval', {
|
||||
this.setStatusInternal(callId, 'awaiting_approval', signal, {
|
||||
...waitingToolCall.confirmationDetails,
|
||||
fileDiff: updatedDiff,
|
||||
isModifying: false,
|
||||
@@ -932,7 +997,7 @@ export class CoreToolScheduler {
|
||||
signal,
|
||||
);
|
||||
}
|
||||
this.setStatusInternal(callId, 'scheduled');
|
||||
this.setStatusInternal(callId, 'scheduled', signal);
|
||||
}
|
||||
await this.attemptExecutionOfScheduledCalls(signal);
|
||||
}
|
||||
@@ -974,10 +1039,15 @@ export class CoreToolScheduler {
|
||||
);
|
||||
|
||||
this.setArgsInternal(toolCall.request.callId, updatedParams);
|
||||
this.setStatusInternal(toolCall.request.callId, 'awaiting_approval', {
|
||||
...toolCall.confirmationDetails,
|
||||
fileDiff: updatedDiff,
|
||||
});
|
||||
this.setStatusInternal(
|
||||
toolCall.request.callId,
|
||||
'awaiting_approval',
|
||||
signal,
|
||||
{
|
||||
...toolCall.confirmationDetails,
|
||||
fileDiff: updatedDiff,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
private async attemptExecutionOfScheduledCalls(
|
||||
@@ -1002,7 +1072,7 @@ export class CoreToolScheduler {
|
||||
const scheduledCall = toolCall;
|
||||
const { callId, name: toolName } = scheduledCall.request;
|
||||
const invocation = scheduledCall.invocation;
|
||||
this.setStatusInternal(callId, 'executing');
|
||||
this.setStatusInternal(callId, 'executing', signal);
|
||||
|
||||
const liveOutputCallback =
|
||||
scheduledCall.tool.canUpdateOutput && this.outputUpdateHandler
|
||||
@@ -1055,12 +1125,10 @@ export class CoreToolScheduler {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'cancelled',
|
||||
signal,
|
||||
'User cancelled tool execution.',
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (toolResult.error === undefined) {
|
||||
} else if (toolResult.error === undefined) {
|
||||
let content = toolResult.llmContent;
|
||||
let outputFile: string | undefined = undefined;
|
||||
const contentLength =
|
||||
@@ -1116,7 +1184,7 @@ export class CoreToolScheduler {
|
||||
outputFile,
|
||||
contentLength,
|
||||
};
|
||||
this.setStatusInternal(callId, 'success', successResponse);
|
||||
this.setStatusInternal(callId, 'success', signal, successResponse);
|
||||
} else {
|
||||
// It is a failure
|
||||
const error = new Error(toolResult.error.message);
|
||||
@@ -1125,19 +1193,21 @@ export class CoreToolScheduler {
|
||||
error,
|
||||
toolResult.error.type,
|
||||
);
|
||||
this.setStatusInternal(callId, 'error', errorResponse);
|
||||
this.setStatusInternal(callId, 'error', signal, errorResponse);
|
||||
}
|
||||
} catch (executionError: unknown) {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'cancelled',
|
||||
signal,
|
||||
'User cancelled tool execution.',
|
||||
);
|
||||
} else {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'error',
|
||||
signal,
|
||||
createErrorResponse(
|
||||
scheduledCall.request,
|
||||
executionError instanceof Error
|
||||
@@ -1148,45 +1218,126 @@ export class CoreToolScheduler {
|
||||
);
|
||||
}
|
||||
}
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async checkAndNotifyCompletion(): Promise<void> {
|
||||
const allCallsAreTerminal = this.toolCalls.every(
|
||||
(call) =>
|
||||
call.status === 'success' ||
|
||||
call.status === 'error' ||
|
||||
call.status === 'cancelled',
|
||||
);
|
||||
private async checkAndNotifyCompletion(signal: AbortSignal): Promise<void> {
|
||||
// This method is now only concerned with the single active tool call.
|
||||
if (this.toolCalls.length === 0) {
|
||||
// It's possible to be called when a batch is cancelled before any tool has started.
|
||||
if (signal.aborted && this.toolCallQueue.length > 0) {
|
||||
this._cancelAllQueuedCalls();
|
||||
}
|
||||
} else {
|
||||
const activeCall = this.toolCalls[0];
|
||||
const isTerminal =
|
||||
activeCall.status === 'success' ||
|
||||
activeCall.status === 'error' ||
|
||||
activeCall.status === 'cancelled';
|
||||
|
||||
if (this.toolCalls.length > 0 && allCallsAreTerminal) {
|
||||
const completedCalls = [...this.toolCalls] as CompletedToolCall[];
|
||||
// If the active tool is not in a terminal state (e.g., it's 'executing' or 'awaiting_approval'),
|
||||
// then the scheduler is still busy or paused. We should not proceed.
|
||||
if (!isTerminal) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The active tool is finished. Move it to the completed batch.
|
||||
const completedCall = activeCall as CompletedToolCall;
|
||||
this.completedToolCallsForBatch.push(completedCall);
|
||||
logToolCall(this.config, new ToolCallEvent(completedCall));
|
||||
|
||||
// Clear the active tool slot. This is crucial for the sequential processing.
|
||||
this.toolCalls = [];
|
||||
}
|
||||
|
||||
for (const call of completedCalls) {
|
||||
logToolCall(this.config, new ToolCallEvent(call));
|
||||
// Now, check if the entire batch is complete.
|
||||
// The batch is complete if the queue is empty or the operation was cancelled.
|
||||
if (this.toolCallQueue.length === 0 || signal.aborted) {
|
||||
if (signal.aborted) {
|
||||
this._cancelAllQueuedCalls();
|
||||
}
|
||||
|
||||
// If there's nothing to report and we weren't cancelled, we can stop.
|
||||
// But if we were cancelled, we must proceed to potentially start the next queued request.
|
||||
if (this.completedToolCallsForBatch.length === 0 && !signal.aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.onAllToolCallsComplete) {
|
||||
this.isFinalizingToolCalls = true;
|
||||
await this.onAllToolCallsComplete(completedCalls);
|
||||
// Use the batch array, not the (now empty) active array.
|
||||
await this.onAllToolCallsComplete(this.completedToolCallsForBatch);
|
||||
this.completedToolCallsForBatch = []; // Clear after reporting.
|
||||
this.isFinalizingToolCalls = false;
|
||||
}
|
||||
this.isCancelling = false;
|
||||
this.notifyToolCallsUpdate();
|
||||
// After completion, process the next item in the queue.
|
||||
|
||||
// After completion of the entire batch, process the next item in the main request queue.
|
||||
if (this.requestQueue.length > 0) {
|
||||
const next = this.requestQueue.shift()!;
|
||||
this._schedule(next.request, next.signal)
|
||||
.then(next.resolve)
|
||||
.catch(next.reject);
|
||||
}
|
||||
} else {
|
||||
// The batch is not yet complete, so continue processing the current batch sequence.
|
||||
await this._processNextInQueue(signal);
|
||||
}
|
||||
}
|
||||
|
||||
private _cancelAllQueuedCalls(): void {
|
||||
while (this.toolCallQueue.length > 0) {
|
||||
const queuedCall = this.toolCallQueue.shift()!;
|
||||
// Don't cancel tools that already errored during validation.
|
||||
if (queuedCall.status === 'error') {
|
||||
this.completedToolCallsForBatch.push(queuedCall);
|
||||
continue;
|
||||
}
|
||||
const durationMs =
|
||||
'startTime' in queuedCall && queuedCall.startTime
|
||||
? Date.now() - queuedCall.startTime
|
||||
: undefined;
|
||||
const errorMessage =
|
||||
'[Operation Cancelled] User cancelled the operation.';
|
||||
this.completedToolCallsForBatch.push({
|
||||
request: queuedCall.request,
|
||||
tool: queuedCall.tool,
|
||||
invocation: queuedCall.invocation,
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
callId: queuedCall.request.callId,
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
id: queuedCall.request.callId,
|
||||
name: queuedCall.request.name,
|
||||
response: {
|
||||
error: errorMessage,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
resultDisplay: undefined,
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: errorMessage.length,
|
||||
},
|
||||
durationMs,
|
||||
outcome: ToolConfirmationOutcome.Cancel,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private notifyToolCallsUpdate(): void {
|
||||
if (this.onToolCallsUpdate) {
|
||||
this.onToolCallsUpdate([...this.toolCalls]);
|
||||
this.onToolCallsUpdate([
|
||||
...this.completedToolCallsForBatch,
|
||||
...this.toolCalls,
|
||||
...this.toolCallQueue,
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1215,35 +1366,4 @@ export class CoreToolScheduler {
|
||||
|
||||
return doesToolInvocationMatch(tool, invocation, allowedTools);
|
||||
}
|
||||
|
||||
private async autoApproveCompatiblePendingTools(
|
||||
signal: AbortSignal,
|
||||
triggeringCallId: string,
|
||||
): Promise<void> {
|
||||
const pendingTools = this.toolCalls.filter(
|
||||
(call) =>
|
||||
call.status === 'awaiting_approval' &&
|
||||
call.request.callId !== triggeringCallId,
|
||||
) as WaitingToolCall[];
|
||||
|
||||
for (const pendingTool of pendingTools) {
|
||||
try {
|
||||
const stillNeedsConfirmation =
|
||||
await pendingTool.invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (!stillNeedsConfirmation) {
|
||||
this.setToolCallOutcome(
|
||||
pendingTool.request.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(pendingTool.request.callId, 'scheduled');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error checking confirmation for tool ${pendingTool.request.callId}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user