refactor(cli): migrate non-interactive flow to event-driven scheduler (#17572)

This commit is contained in:
Abhi
2026-01-26 22:11:29 -05:00
committed by GitHub
parent a79051d9f8
commit 67b00252d3
6 changed files with 383 additions and 211 deletions

View File

@@ -14,7 +14,6 @@ import type {
UserFeedbackPayload,
} from '@google/gemini-cli-core';
import {
executeToolCall,
ToolErrorType,
GeminiEventType,
OutputFormat,
@@ -48,6 +47,8 @@ const mockCoreEvents = vi.hoisted(() => ({
drainBacklogs: vi.fn(),
}));
const mockSchedulerSchedule = vi.hoisted(() => vi.fn());
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const original =
await importOriginal<typeof import('@google/gemini-cli-core')>();
@@ -61,7 +62,10 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
return {
...original,
executeToolCall: vi.fn(),
Scheduler: class {
schedule = mockSchedulerSchedule;
cancelAll = vi.fn();
},
isTelemetrySdkInitialized: vi.fn().mockReturnValue(true),
ChatRecordingService: MockChatRecordingService,
uiTelemetryService: {
@@ -91,7 +95,6 @@ describe('runNonInteractive', () => {
let mockConfig: Config;
let mockSettings: LoadedSettings;
let mockToolRegistry: ToolRegistry;
let mockCoreExecuteToolCall: Mock;
let consoleErrorSpy: MockInstance;
let processStdoutSpy: MockInstance;
let processStderrSpy: MockInstance;
@@ -122,7 +125,7 @@ describe('runNonInteractive', () => {
};
beforeEach(async () => {
mockCoreExecuteToolCall = vi.mocked(executeToolCall);
mockSchedulerSchedule.mockReset();
mockCommandServiceCreate.mockResolvedValue({
getCommands: mockGetCommands,
@@ -158,6 +161,11 @@ describe('runNonInteractive', () => {
mockConfig = {
initialize: vi.fn().mockResolvedValue(undefined),
getMessageBus: vi.fn().mockReturnValue({
subscribe: vi.fn(),
unsubscribe: vi.fn(),
publish: vi.fn(),
}),
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getMaxSessionTurns: vi.fn().mockReturnValue(10),
@@ -263,25 +271,27 @@ describe('runNonInteractive', () => {
},
};
const toolResponse: Part[] = [{ text: 'Tool response' }];
mockCoreExecuteToolCall.mockResolvedValue({
status: 'success',
request: {
callId: 'tool-1',
name: 'testTool',
args: { arg1: 'value1' },
isClientInitiated: false,
prompt_id: 'prompt-id-2',
mockSchedulerSchedule.mockResolvedValue([
{
status: 'success',
request: {
callId: 'tool-1',
name: 'testTool',
args: { arg1: 'value1' },
isClientInitiated: false,
prompt_id: 'prompt-id-2',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: toolResponse,
callId: 'tool-1',
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: toolResponse,
callId: 'tool-1',
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
]);
const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent];
const secondCallEvents: ServerGeminiStreamEvent[] = [
@@ -304,9 +314,8 @@ describe('runNonInteractive', () => {
});
expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
mockConfig,
expect.objectContaining({ name: 'testTool' }),
expect(mockSchedulerSchedule).toHaveBeenCalledWith(
[expect.objectContaining({ name: 'testTool' })],
expect.any(AbortSignal),
);
expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith(
@@ -335,16 +344,18 @@ describe('runNonInteractive', () => {
};
// 2. Mock the execution of the tools. We just need them to succeed.
mockCoreExecuteToolCall.mockResolvedValue({
status: 'success',
request: toolCallEvent.value, // This is generic enough for both calls
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: [],
callId: 'mock-tool',
mockSchedulerSchedule.mockResolvedValue([
{
status: 'success',
request: toolCallEvent.value, // This is generic enough for both calls
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: [],
callId: 'mock-tool',
},
},
});
]);
// 3. Define the sequence of events streamed from the mock model.
// Turn 1: Model outputs text, then requests a tool call.
@@ -385,7 +396,7 @@ describe('runNonInteractive', () => {
expect(getWrittenOutput()).toMatchSnapshot();
// Also verify the tools were called as expected.
expect(mockCoreExecuteToolCall).toHaveBeenCalledTimes(2);
expect(mockSchedulerSchedule).toHaveBeenCalledTimes(2);
});
it('should handle error during tool execution and should send error back to the model', async () => {
@@ -399,34 +410,36 @@ describe('runNonInteractive', () => {
prompt_id: 'prompt-id-3',
},
};
mockCoreExecuteToolCall.mockResolvedValue({
status: 'error',
request: {
callId: 'tool-1',
name: 'errorTool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-3',
},
tool: {} as AnyDeclarativeTool,
response: {
callId: 'tool-1',
error: new Error('Execution failed'),
errorType: ToolErrorType.EXECUTION_FAILED,
responseParts: [
{
functionResponse: {
name: 'errorTool',
response: {
output: 'Error: Execution failed',
mockSchedulerSchedule.mockResolvedValue([
{
status: 'error',
request: {
callId: 'tool-1',
name: 'errorTool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-3',
},
tool: {} as AnyDeclarativeTool,
response: {
callId: 'tool-1',
error: new Error('Execution failed'),
errorType: ToolErrorType.EXECUTION_FAILED,
responseParts: [
{
functionResponse: {
name: 'errorTool',
response: {
output: 'Error: Execution failed',
},
},
},
},
],
resultDisplay: 'Execution failed',
contentLength: undefined,
],
resultDisplay: 'Execution failed',
contentLength: undefined,
},
},
});
]);
const finalResponse: ServerGeminiStreamEvent[] = [
{
type: GeminiEventType.Content,
@@ -448,7 +461,7 @@ describe('runNonInteractive', () => {
prompt_id: 'prompt-id-3',
});
expect(mockCoreExecuteToolCall).toHaveBeenCalled();
expect(mockSchedulerSchedule).toHaveBeenCalled();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Error executing tool errorTool: Execution failed',
);
@@ -498,24 +511,26 @@ describe('runNonInteractive', () => {
prompt_id: 'prompt-id-5',
},
};
mockCoreExecuteToolCall.mockResolvedValue({
status: 'error',
request: {
callId: 'tool-1',
name: 'nonexistentTool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-5',
mockSchedulerSchedule.mockResolvedValue([
{
status: 'error',
request: {
callId: 'tool-1',
name: 'nonexistentTool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-5',
},
response: {
callId: 'tool-1',
error: new Error('Tool "nonexistentTool" not found in registry.'),
resultDisplay: 'Tool "nonexistentTool" not found in registry.',
responseParts: [],
errorType: undefined,
contentLength: undefined,
},
},
response: {
callId: 'tool-1',
error: new Error('Tool "nonexistentTool" not found in registry.'),
resultDisplay: 'Tool "nonexistentTool" not found in registry.',
responseParts: [],
errorType: undefined,
contentLength: undefined,
},
});
]);
const finalResponse: ServerGeminiStreamEvent[] = [
{
type: GeminiEventType.Content,
@@ -538,7 +553,7 @@ describe('runNonInteractive', () => {
prompt_id: 'prompt-id-5',
});
expect(mockCoreExecuteToolCall).toHaveBeenCalled();
expect(mockSchedulerSchedule).toHaveBeenCalled();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Error executing tool nonexistentTool: Tool "nonexistentTool" not found in registry.',
);
@@ -665,25 +680,27 @@ describe('runNonInteractive', () => {
},
};
const toolResponse: Part[] = [{ text: 'Tool executed successfully' }];
mockCoreExecuteToolCall.mockResolvedValue({
status: 'success',
request: {
callId: 'tool-1',
name: 'testTool',
args: { arg1: 'value1' },
isClientInitiated: false,
prompt_id: 'prompt-id-tool-only',
mockSchedulerSchedule.mockResolvedValue([
{
status: 'success',
request: {
callId: 'tool-1',
name: 'testTool',
args: { arg1: 'value1' },
isClientInitiated: false,
prompt_id: 'prompt-id-tool-only',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: toolResponse,
callId: 'tool-1',
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: toolResponse,
callId: 'tool-1',
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
]);
// First call returns only tool call, no content
const firstCallEvents: ServerGeminiStreamEvent[] = [
@@ -719,9 +736,8 @@ describe('runNonInteractive', () => {
});
expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
mockConfig,
expect.objectContaining({ name: 'testTool' }),
expect(mockSchedulerSchedule).toHaveBeenCalledWith(
[expect.objectContaining({ name: 'testTool' })],
expect.any(AbortSignal),
);
@@ -1248,25 +1264,27 @@ describe('runNonInteractive', () => {
},
};
const toolResponse: Part[] = [{ text: 'file.txt' }];
mockCoreExecuteToolCall.mockResolvedValue({
status: 'success',
request: {
callId: 'tool-shell-1',
name: 'ShellTool',
args: { command: 'ls' },
isClientInitiated: false,
prompt_id: 'prompt-id-allowed',
mockSchedulerSchedule.mockResolvedValue([
{
status: 'success',
request: {
callId: 'tool-shell-1',
name: 'ShellTool',
args: { command: 'ls' },
isClientInitiated: false,
prompt_id: 'prompt-id-allowed',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: toolResponse,
callId: 'tool-shell-1',
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: toolResponse,
callId: 'tool-shell-1',
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
]);
const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent];
const secondCallEvents: ServerGeminiStreamEvent[] = [
@@ -1288,9 +1306,8 @@ describe('runNonInteractive', () => {
prompt_id: 'prompt-id-allowed',
});
expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
mockConfig,
expect.objectContaining({ name: 'ShellTool' }),
expect(mockSchedulerSchedule).toHaveBeenCalledWith(
[expect.objectContaining({ name: 'ShellTool' })],
expect.any(AbortSignal),
);
expect(getWrittenOutput()).toBe('file.txt\n');
@@ -1446,20 +1463,22 @@ describe('runNonInteractive', () => {
},
};
mockCoreExecuteToolCall.mockResolvedValue({
status: 'success',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: [{ text: 'Tool response' }],
callId: 'tool-1',
error: undefined,
errorType: undefined,
contentLength: undefined,
resultDisplay: 'Tool executed successfully',
mockSchedulerSchedule.mockResolvedValue([
{
status: 'success',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: [{ text: 'Tool response' }],
callId: 'tool-1',
error: undefined,
errorType: undefined,
contentLength: undefined,
resultDisplay: 'Tool executed successfully',
},
},
});
]);
const firstCallEvents: ServerGeminiStreamEvent[] = [
{ type: GeminiEventType.Content, value: 'Thinking...' },
@@ -1636,19 +1655,21 @@ describe('runNonInteractive', () => {
prompt_id: 'prompt-id-tool-error',
},
};
mockCoreExecuteToolCall.mockResolvedValue({
status: 'success',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: [],
callId: 'tool-1',
error: undefined,
errorType: undefined,
contentLength: undefined,
mockSchedulerSchedule.mockResolvedValue([
{
status: 'success',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
responseParts: [],
callId: 'tool-1',
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
});
]);
const events: ServerGeminiStreamEvent[] = [
toolCallEvent,
@@ -1717,19 +1738,21 @@ describe('runNonInteractive', () => {
};
// Mock tool execution returning STOP_EXECUTION
mockCoreExecuteToolCall.mockResolvedValue({
status: 'error',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'stop-call',
responseParts: [{ text: 'error occurred' }],
errorType: ToolErrorType.STOP_EXECUTION,
error: new Error('Stop reason from hook'),
resultDisplay: undefined,
mockSchedulerSchedule.mockResolvedValue([
{
status: 'error',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'stop-call',
responseParts: [{ text: 'error occurred' }],
errorType: ToolErrorType.STOP_EXECUTION,
error: new Error('Stop reason from hook'),
resultDisplay: undefined,
},
},
});
]);
const firstCallEvents: ServerGeminiStreamEvent[] = [
{ type: GeminiEventType.Content, value: 'Executing tool...' },
@@ -1750,7 +1773,7 @@ describe('runNonInteractive', () => {
prompt_id: 'prompt-id-stop',
});
expect(mockCoreExecuteToolCall).toHaveBeenCalled();
expect(mockSchedulerSchedule).toHaveBeenCalled();
// The key assertion: sendMessageStream should have been called ONLY ONCE (initial user input).
expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(1);
@@ -1777,19 +1800,21 @@ describe('runNonInteractive', () => {
},
};
mockCoreExecuteToolCall.mockResolvedValue({
status: 'error',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'stop-call',
responseParts: [{ text: 'error occurred' }],
errorType: ToolErrorType.STOP_EXECUTION,
error: new Error('Stop reason'),
resultDisplay: undefined,
mockSchedulerSchedule.mockResolvedValue([
{
status: 'error',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'stop-call',
responseParts: [{ text: 'error occurred' }],
errorType: ToolErrorType.STOP_EXECUTION,
error: new Error('Stop reason'),
resultDisplay: undefined,
},
},
});
]);
const firstCallEvents: ServerGeminiStreamEvent[] = [
{ type: GeminiEventType.Content, value: 'Partial content' },
@@ -1839,19 +1864,21 @@ describe('runNonInteractive', () => {
},
};
mockCoreExecuteToolCall.mockResolvedValue({
status: 'error',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'stop-call',
responseParts: [{ text: 'error occurred' }],
errorType: ToolErrorType.STOP_EXECUTION,
error: new Error('Stop reason'),
resultDisplay: undefined,
mockSchedulerSchedule.mockResolvedValue([
{
status: 'error',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'stop-call',
responseParts: [{ text: 'error occurred' }],
errorType: ToolErrorType.STOP_EXECUTION,
error: new Error('Stop reason'),
resultDisplay: undefined,
},
},
});
]);
const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent];
@@ -2066,5 +2093,63 @@ describe('runNonInteractive', () => {
expect.stringContaining('[WARNING] --raw-output is enabled'),
);
});
it('should report cancelled tool calls as success in stream-json mode (legacy parity)', async () => {
const toolCallEvent: ServerGeminiStreamEvent = {
type: GeminiEventType.ToolCallRequest,
value: {
callId: 'tool-1',
name: 'testTool',
args: { arg1: 'value1' },
isClientInitiated: false,
prompt_id: 'prompt-id-cancel',
},
};
// Mock the scheduler to return a cancelled status
mockSchedulerSchedule.mockResolvedValue([
{
status: 'cancelled',
request: toolCallEvent.value,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'tool-1',
responseParts: [{ text: 'Operation cancelled' }],
resultDisplay: 'Cancelled',
},
},
]);
const events: ServerGeminiStreamEvent[] = [
toolCallEvent,
{
type: GeminiEventType.Content,
value: 'Model continues...',
},
];
mockGeminiClient.sendMessageStream.mockReturnValue(
createStreamFromEvents(events),
);
vi.mocked(mockConfig.getOutputFormat).mockReturnValue(
OutputFormat.STREAM_JSON,
);
vi.mocked(uiTelemetryService.getMetrics).mockReturnValue(
MOCK_SESSION_METRICS,
);
await runNonInteractive({
config: mockConfig,
settings: mockSettings,
input: 'Test input',
prompt_id: 'prompt-id-cancel',
});
const output = getWrittenOutput();
expect(output).toContain('"type":"tool_result"');
expect(output).toContain('"status":"success"');
});
});
});

View File

@@ -8,13 +8,11 @@ import type {
Config,
ToolCallRequestInfo,
ResumedSessionData,
CompletedToolCall,
UserFeedbackPayload,
} from '@google/gemini-cli-core';
import { isSlashCommand } from './ui/utils/commandUtils.js';
import type { LoadedSettings } from './config/settings.js';
import {
executeToolCall,
GeminiEventType,
FatalInputError,
promptIdContext,
@@ -29,6 +27,8 @@ import {
createWorkingStdio,
recordToolCallInteractions,
ToolErrorType,
Scheduler,
ROOT_SCHEDULER_ID,
} from '@google/gemini-cli-core';
import type { Content, Part } from '@google/genai';
@@ -202,6 +202,12 @@ export async function runNonInteractive({
});
const geminiClient = config.getGeminiClient();
const scheduler = new Scheduler({
config,
messageBus: config.getMessageBus(),
getPreferredEditor: () => undefined,
schedulerId: ROOT_SCHEDULER_ID,
});
// Initialize chat. Resume if resume data is passed.
if (resumedSessionData) {
@@ -375,25 +381,23 @@ export async function runNonInteractive({
if (toolCallRequests.length > 0) {
textOutput.ensureTrailingNewline();
const completedToolCalls = await scheduler.schedule(
toolCallRequests,
abortController.signal,
);
const toolResponseParts: Part[] = [];
const completedToolCalls: CompletedToolCall[] = [];
for (const requestInfo of toolCallRequests) {
const completedToolCall = await executeToolCall(
config,
requestInfo,
abortController.signal,
);
for (const completedToolCall of completedToolCalls) {
const toolResponse = completedToolCall.response;
completedToolCalls.push(completedToolCall);
const requestInfo = completedToolCall.request;
if (streamFormatter) {
streamFormatter.emitEvent({
type: JsonStreamEventType.TOOL_RESULT,
timestamp: new Date().toISOString(),
tool_id: requestInfo.callId,
status: toolResponse.error ? 'error' : 'success',
status:
completedToolCall.status === 'error' ? 'error' : 'success',
output:
typeof toolResponse.resultDisplay === 'string'
? toolResponse.resultDisplay

View File

@@ -35,7 +35,10 @@ vi.mock('../telemetry/types.js', () => ({
ToolCallEvent: vi.fn().mockImplementation((call) => ({ ...call })),
}));
import { SchedulerStateManager } from './state-manager.js';
import {
SchedulerStateManager,
type TerminalCallHandler,
} from './state-manager.js';
import { resolveConfirmation } from './confirmation.js';
import { checkPolicy, updatePolicy } from './policy.js';
import { ToolExecutor } from './tool-executor.js';
@@ -64,6 +67,7 @@ import type {
SuccessfulToolCall,
ErroredToolCall,
CancelledToolCall,
CompletedToolCall,
ToolCallResponseInfo,
} from './types.js';
import { ROOT_SCHEDULER_ID } from './types.js';
@@ -201,10 +205,27 @@ describe('Scheduler (Orchestrator)', () => {
applyInlineModify: vi.fn(),
} as unknown as Mocked<ToolModificationHandler>;
// Wire up class constructors to return our mock instances
vi.mocked(SchedulerStateManager).mockReturnValue(
mockStateManager as unknown as Mocked<SchedulerStateManager>,
let capturedTerminalHandler: TerminalCallHandler | undefined;
vi.mocked(SchedulerStateManager).mockImplementation(
(_messageBus, _schedulerId, onTerminalCall) => {
capturedTerminalHandler = onTerminalCall;
return mockStateManager as unknown as SchedulerStateManager;
},
);
mockStateManager.finalizeCall.mockImplementation((callId: string) => {
const call = mockStateManager.getToolCall(callId);
if (call) {
capturedTerminalHandler?.(call as CompletedToolCall);
}
});
mockStateManager.cancelAllQueued.mockImplementation((_reason: string) => {
// In tests, we usually mock the queue or completed batch.
// For the sake of telemetry tests, we manually trigger if needed,
// but most tests here check if finalizing is called.
});
vi.mocked(ToolExecutor).mockReturnValue(
mockExecutor as unknown as Mocked<ToolExecutor>,
);

View File

@@ -101,7 +101,11 @@ export class Scheduler {
this.getPreferredEditor = options.getPreferredEditor;
this.schedulerId = options.schedulerId;
this.parentCallId = options.parentCallId;
this.state = new SchedulerStateManager(this.messageBus, this.schedulerId);
this.state = new SchedulerStateManager(
this.messageBus,
this.schedulerId,
(call) => logToolCall(this.config, new ToolCallEvent(call)),
);
this.executor = new ToolExecutor(this.config);
this.modifier = new ToolModificationHandler();
@@ -388,16 +392,6 @@ export class Scheduler {
}
}
// Fetch the updated call from state before finalizing to capture the
// terminal status.
const terminalCall = this.state.getToolCall(active.request.callId);
if (terminalCall && this.isTerminal(terminalCall.status)) {
logToolCall(
this.config,
new ToolCallEvent(terminalCall as CompletedToolCall),
);
}
this.state.finalizeCall(active.request.callId);
}
@@ -422,6 +416,7 @@ export class Scheduler {
ToolErrorType.POLICY_VIOLATION,
),
);
this.state.finalizeCall(callId);
return;
}
@@ -453,6 +448,7 @@ export class Scheduler {
// Handle cancellation (cascades to entire batch)
if (outcome === ToolConfirmationOutcome.Cancel) {
this.state.updateStatus(callId, 'cancelled', 'User denied execution.');
this.state.finalizeCall(callId);
this.state.cancelAllQueued('User cancelled operation');
return; // Skip execution
}

View File

@@ -23,6 +23,7 @@ import {
} from '../tools/tools.js';
import { MessageBusType } from '../confirmation-bus/types.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { ROOT_SCHEDULER_ID } from './types.js';
describe('SchedulerStateManager', () => {
const mockRequest: ToolCallRequestInfo = {
@@ -83,6 +84,60 @@ describe('SchedulerStateManager', () => {
stateManager = new SchedulerStateManager(mockMessageBus);
});
describe('Observer Callback', () => {
it('should trigger onTerminalCall when finalizing a call', () => {
const onTerminalCall = vi.fn();
const manager = new SchedulerStateManager(
mockMessageBus,
ROOT_SCHEDULER_ID,
onTerminalCall,
);
const call = createValidatingCall();
manager.enqueue([call]);
manager.dequeue();
manager.updateStatus(
call.request.callId,
'success',
createMockResponse(call.request.callId),
);
manager.finalizeCall(call.request.callId);
expect(onTerminalCall).toHaveBeenCalledTimes(1);
expect(onTerminalCall).toHaveBeenCalledWith(
expect.objectContaining({
status: 'success',
request: expect.objectContaining({ callId: call.request.callId }),
}),
);
});
it('should trigger onTerminalCall for every call in cancelAllQueued', () => {
const onTerminalCall = vi.fn();
const manager = new SchedulerStateManager(
mockMessageBus,
ROOT_SCHEDULER_ID,
onTerminalCall,
);
manager.enqueue([createValidatingCall('1'), createValidatingCall('2')]);
manager.cancelAllQueued('Test cancel');
expect(onTerminalCall).toHaveBeenCalledTimes(2);
expect(onTerminalCall).toHaveBeenCalledWith(
expect.objectContaining({
status: 'cancelled',
request: expect.objectContaining({ callId: '1' }),
}),
);
expect(onTerminalCall).toHaveBeenCalledWith(
expect.objectContaining({
status: 'cancelled',
request: expect.objectContaining({ callId: '2' }),
}),
);
});
});
describe('Initialization', () => {
it('should start with empty state', () => {
expect(stateManager.isActive).toBe(false);

View File

@@ -31,6 +31,11 @@ import {
type SerializableConfirmationDetails,
} from '../confirmation-bus/types.js';
/**
* Handler for terminal tool calls.
*/
export type TerminalCallHandler = (call: CompletedToolCall) => void;
/**
* Manages the state of tool calls.
* Publishes state changes to the MessageBus via TOOL_CALLS_UPDATE events.
@@ -43,6 +48,7 @@ export class SchedulerStateManager {
constructor(
private readonly messageBus: MessageBus,
private readonly schedulerId: string = ROOT_SCHEDULER_ID,
private readonly onTerminalCall?: TerminalCallHandler,
) {}
addToolCalls(calls: ToolCall[]): void {
@@ -134,6 +140,8 @@ export class SchedulerStateManager {
if (this.isTerminalCall(call)) {
this._completedBatch.push(call);
this.activeCalls.delete(callId);
this.onTerminalCall?.(call);
this.emitUpdate();
}
}
@@ -173,9 +181,12 @@ export class SchedulerStateManager {
const queuedCall = this.queue.shift()!;
if (queuedCall.status === 'error') {
this._completedBatch.push(queuedCall);
this.onTerminalCall?.(queuedCall);
continue;
}
this._completedBatch.push(this.toCancelled(queuedCall, reason));
const cancelledCall = this.toCancelled(queuedCall, reason);
this._completedBatch.push(cancelledCall);
this.onTerminalCall?.(cancelledCall);
}
this.emitUpdate();
}