mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 14:10:37 -07:00
feat(core): Implement parallel FC for read only tools. (#18791)
This commit is contained in:
1
integration-tests/parallel-tools.responses
Normal file
1
integration-tests/parallel-tools.responses
Normal file
@@ -0,0 +1 @@
|
||||
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"functionCall":{"name":"read_file","args":{"file_path":"file1.txt"}}},{"functionCall":{"name":"read_file","args":{"file_path":"file2.txt"}}},{"functionCall":{"name":"write_file","args":{"file_path":"output.txt","content":"wave2"}}},{"functionCall":{"name":"read_file","args":{"file_path":"file3.txt"}}},{"functionCall":{"name":"read_file","args":{"file_path":"file4.txt"}}}, {"text":"All waves completed successfully."}]},"finishReason":"STOP","index":0}]}]}
|
||||
77
integration-tests/parallel-tools.test.ts
Normal file
77
integration-tests/parallel-tools.test.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { TestRig } from './test-helper.js';
|
||||
import { join } from 'node:path';
|
||||
import fs from 'node:fs';
|
||||
|
||||
describe('Parallel Tool Execution Integration', () => {
|
||||
let rig: TestRig;
|
||||
|
||||
beforeEach(() => {
|
||||
rig = new TestRig();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await rig.cleanup();
|
||||
});
|
||||
|
||||
it('should execute [read, read, write, read, read] in correct waves with user approval', async () => {
|
||||
rig.setup('parallel-wave-execution', {
|
||||
fakeResponsesPath: join(import.meta.dirname, 'parallel-tools.responses'),
|
||||
settings: {
|
||||
tools: {
|
||||
core: ['read_file', 'write_file'],
|
||||
approval: 'ASK', // Disable YOLO mode to show permission prompts
|
||||
confirmationRequired: ['write_file'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
rig.createFile('file1.txt', 'c1');
|
||||
rig.createFile('file2.txt', 'c2');
|
||||
rig.createFile('file3.txt', 'c3');
|
||||
rig.createFile('file4.txt', 'c4');
|
||||
rig.sync();
|
||||
|
||||
const run = await rig.runInteractive({ approvalMode: 'default' });
|
||||
|
||||
// 1. Trigger the wave
|
||||
await run.type('ok');
|
||||
await run.type('\r');
|
||||
|
||||
// 3. Wait for the write_file prompt.
|
||||
await run.expectText('Allow', 5000);
|
||||
|
||||
// 4. Press Enter to approve the write_file.
|
||||
await run.type('y');
|
||||
await run.type('\r');
|
||||
|
||||
// 5. Wait for the final model response
|
||||
await run.expectText('All waves completed successfully.', 5000);
|
||||
|
||||
// Verify all tool calls were made and succeeded in the logs
|
||||
await rig.expectToolCallSuccess(['write_file']);
|
||||
const toolLogs = rig.readToolLogs();
|
||||
|
||||
const readFiles = toolLogs.filter(
|
||||
(l) => l.toolRequest.name === 'read_file',
|
||||
);
|
||||
const writeFiles = toolLogs.filter(
|
||||
(l) => l.toolRequest.name === 'write_file',
|
||||
);
|
||||
|
||||
expect(readFiles.length).toBe(4);
|
||||
expect(writeFiles.length).toBe(1);
|
||||
expect(toolLogs.every((l) => l.toolRequest.success)).toBe(true);
|
||||
|
||||
// Check that output.txt was actually written
|
||||
expect(fs.readFileSync(join(rig.testDir!, 'output.txt'), 'utf8')).toBe(
|
||||
'wave2',
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -17,10 +17,12 @@ import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import type {
|
||||
DeclarativeTool,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolInvocation,
|
||||
ToolResult,
|
||||
} from '../tools/tools.js';
|
||||
import type { ToolRegistry } from 'src/tools/tool-registry.js';
|
||||
|
||||
vi.mock('./subagent-tool-wrapper.js');
|
||||
|
||||
@@ -274,3 +276,85 @@ describe('SubAgentInvocation', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('SubagentTool Read-Only logic', () => {
|
||||
let mockConfig: Config;
|
||||
let mockMessageBus: MessageBus;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockConfig = makeFakeConfig();
|
||||
mockMessageBus = createMockMessageBus();
|
||||
});
|
||||
|
||||
it('should be false for remote agents', () => {
|
||||
const tool = new SubagentTool(
|
||||
testRemoteDefinition,
|
||||
mockConfig,
|
||||
mockMessageBus,
|
||||
);
|
||||
expect(tool.isReadOnly).toBe(false);
|
||||
});
|
||||
|
||||
it('should be true for local agent with only read-only tools', () => {
|
||||
const readOnlyTool = {
|
||||
name: 'read',
|
||||
isReadOnly: true,
|
||||
} as unknown as DeclarativeTool<object, ToolResult>;
|
||||
const registry = {
|
||||
getTool: (name: string) => (name === 'read' ? readOnlyTool : undefined),
|
||||
};
|
||||
vi.spyOn(mockConfig, 'getToolRegistry').mockReturnValue(
|
||||
registry as unknown as ToolRegistry,
|
||||
);
|
||||
|
||||
const defWithTools: LocalAgentDefinition = {
|
||||
...testDefinition,
|
||||
toolConfig: { tools: ['read'] },
|
||||
};
|
||||
const tool = new SubagentTool(defWithTools, mockConfig, mockMessageBus);
|
||||
expect(tool.isReadOnly).toBe(true);
|
||||
});
|
||||
|
||||
it('should be false for local agent with at least one non-read-only tool', () => {
|
||||
const readOnlyTool = {
|
||||
name: 'read',
|
||||
isReadOnly: true,
|
||||
} as unknown as DeclarativeTool<object, ToolResult>;
|
||||
const mutatorTool = {
|
||||
name: 'write',
|
||||
isReadOnly: false,
|
||||
} as unknown as DeclarativeTool<object, ToolResult>;
|
||||
const registry = {
|
||||
getTool: (name: string) => {
|
||||
if (name === 'read') return readOnlyTool;
|
||||
if (name === 'write') return mutatorTool;
|
||||
return undefined;
|
||||
},
|
||||
};
|
||||
vi.spyOn(mockConfig, 'getToolRegistry').mockReturnValue(
|
||||
registry as unknown as ToolRegistry,
|
||||
);
|
||||
|
||||
const defWithTools: LocalAgentDefinition = {
|
||||
...testDefinition,
|
||||
toolConfig: { tools: ['read', 'write'] },
|
||||
};
|
||||
const tool = new SubagentTool(defWithTools, mockConfig, mockMessageBus);
|
||||
expect(tool.isReadOnly).toBe(false);
|
||||
});
|
||||
|
||||
it('should be true for local agent with no tools', () => {
|
||||
const registry = { getTool: () => undefined };
|
||||
vi.spyOn(mockConfig, 'getToolRegistry').mockReturnValue(
|
||||
registry as unknown as ToolRegistry,
|
||||
);
|
||||
|
||||
const defNoTools: LocalAgentDefinition = {
|
||||
...testDefinition,
|
||||
toolConfig: { tools: [] },
|
||||
};
|
||||
const tool = new SubagentTool(defNoTools, mockConfig, mockMessageBus);
|
||||
expect(tool.isReadOnly).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
type ToolResult,
|
||||
BaseToolInvocation,
|
||||
type ToolCallConfirmationDetails,
|
||||
isTool,
|
||||
} from '../tools/tools.js';
|
||||
import type { AnsiOutput } from '../utils/terminalSerializer.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
@@ -48,6 +49,53 @@ export class SubagentTool extends BaseDeclarativeTool<AgentInputs, ToolResult> {
|
||||
);
|
||||
}
|
||||
|
||||
private _memoizedIsReadOnly: boolean | undefined;
|
||||
|
||||
override get isReadOnly(): boolean {
|
||||
if (this._memoizedIsReadOnly !== undefined) {
|
||||
return this._memoizedIsReadOnly;
|
||||
}
|
||||
// No try-catch here. If getToolRegistry() throws, we let it throw.
|
||||
// This is an invariant: you can't check read-only status if the system isn't initialized.
|
||||
this._memoizedIsReadOnly = SubagentTool.checkIsReadOnly(
|
||||
this.definition,
|
||||
this.config,
|
||||
);
|
||||
return this._memoizedIsReadOnly;
|
||||
}
|
||||
|
||||
private static checkIsReadOnly(
|
||||
definition: AgentDefinition,
|
||||
config: Config,
|
||||
): boolean {
|
||||
if (definition.kind === 'remote') {
|
||||
return false;
|
||||
}
|
||||
const tools = definition.toolConfig?.tools ?? [];
|
||||
const registry = config.getToolRegistry();
|
||||
|
||||
if (!registry) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const tool of tools) {
|
||||
if (typeof tool === 'string') {
|
||||
const resolvedTool = registry.getTool(tool);
|
||||
if (!resolvedTool || !resolvedTool.isReadOnly) {
|
||||
return false;
|
||||
}
|
||||
} else if (isTool(tool)) {
|
||||
if (!tool.isReadOnly) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// FunctionDeclaration - we don't know, so assume NOT read-only
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: AgentInputs,
|
||||
messageBus: MessageBus,
|
||||
|
||||
@@ -27,7 +27,6 @@ vi.mock('../telemetry/trace.js', () => ({
|
||||
}));
|
||||
|
||||
import { logToolCall } from '../telemetry/loggers.js';
|
||||
import { ToolCallEvent } from '../telemetry/types.js';
|
||||
vi.mock('../telemetry/loggers.js', () => ({
|
||||
logToolCall: vi.fn(),
|
||||
}));
|
||||
@@ -76,6 +75,8 @@ import type {
|
||||
CancelledToolCall,
|
||||
CompletedToolCall,
|
||||
ToolCallResponseInfo,
|
||||
Status,
|
||||
ToolCall,
|
||||
} from './types.js';
|
||||
import { CoreToolCallStatus, ROOT_SCHEDULER_ID } from './types.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
@@ -168,29 +169,55 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
getPreferredEditor = vi.fn().mockReturnValue('vim');
|
||||
|
||||
// --- Setup Sub-component Mocks ---
|
||||
const mockActiveCallsMap = new Map<string, ToolCall>();
|
||||
const mockQueue: ToolCall[] = [];
|
||||
|
||||
mockStateManager = {
|
||||
enqueue: vi.fn(),
|
||||
dequeue: vi.fn(),
|
||||
getToolCall: vi.fn(),
|
||||
updateStatus: vi.fn(),
|
||||
finalizeCall: vi.fn(),
|
||||
enqueue: vi.fn((calls: ToolCall[]) => {
|
||||
// Clone to preserve initial state for Phase 1 tests
|
||||
mockQueue.push(...calls.map((c) => ({ ...c }) as ToolCall));
|
||||
}),
|
||||
dequeue: vi.fn(() => {
|
||||
const next = mockQueue.shift();
|
||||
if (next) mockActiveCallsMap.set(next.request.callId, next);
|
||||
return next;
|
||||
}),
|
||||
peekQueue: vi.fn(() => mockQueue[0]),
|
||||
getToolCall: vi.fn((id: string) => mockActiveCallsMap.get(id)),
|
||||
updateStatus: vi.fn((id: string, status: Status) => {
|
||||
const call = mockActiveCallsMap.get(id);
|
||||
if (call) (call as unknown as { status: Status }).status = status;
|
||||
}),
|
||||
finalizeCall: vi.fn((id: string) => {
|
||||
const call = mockActiveCallsMap.get(id);
|
||||
if (call) {
|
||||
mockActiveCallsMap.delete(id);
|
||||
capturedTerminalHandler?.(call as CompletedToolCall);
|
||||
}
|
||||
}),
|
||||
updateArgs: vi.fn(),
|
||||
setOutcome: vi.fn(),
|
||||
cancelAllQueued: vi.fn(),
|
||||
cancelAllQueued: vi.fn(() => {
|
||||
mockQueue.length = 0;
|
||||
}),
|
||||
clearBatch: vi.fn(),
|
||||
} as unknown as Mocked<SchedulerStateManager>;
|
||||
|
||||
// Define getters for accessors idiomatically
|
||||
Object.defineProperty(mockStateManager, 'isActive', {
|
||||
get: vi.fn().mockReturnValue(false),
|
||||
get: vi.fn(() => mockActiveCallsMap.size > 0),
|
||||
configurable: true,
|
||||
});
|
||||
Object.defineProperty(mockStateManager, 'allActiveCalls', {
|
||||
get: vi.fn(() => Array.from(mockActiveCallsMap.values())),
|
||||
configurable: true,
|
||||
});
|
||||
Object.defineProperty(mockStateManager, 'queueLength', {
|
||||
get: vi.fn().mockReturnValue(0),
|
||||
get: vi.fn(() => mockQueue.length),
|
||||
configurable: true,
|
||||
});
|
||||
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||
get: vi.fn().mockReturnValue(undefined),
|
||||
get: vi.fn(() => mockActiveCallsMap.values().next().value),
|
||||
configurable: true,
|
||||
});
|
||||
Object.defineProperty(mockStateManager, 'completedBatch', {
|
||||
@@ -227,8 +254,9 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
);
|
||||
|
||||
mockStateManager.finalizeCall.mockImplementation((callId: string) => {
|
||||
const call = mockStateManager.getToolCall(callId);
|
||||
const call = mockActiveCallsMap.get(callId);
|
||||
if (call) {
|
||||
mockActiveCallsMap.delete(callId);
|
||||
capturedTerminalHandler?.(call as CompletedToolCall);
|
||||
}
|
||||
});
|
||||
@@ -242,6 +270,13 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
vi.mocked(ToolExecutor).mockReturnValue(
|
||||
mockExecutor as unknown as Mocked<ToolExecutor>,
|
||||
);
|
||||
mockExecutor.execute.mockResolvedValue({
|
||||
status: 'success',
|
||||
response: {
|
||||
callId: 'default',
|
||||
responseParts: [],
|
||||
} as unknown as ToolCallResponseInfo,
|
||||
} as unknown as SuccessfulToolCall);
|
||||
vi.mocked(ToolModificationHandler).mockReturnValue(
|
||||
mockModifier as unknown as Mocked<ToolModificationHandler>,
|
||||
);
|
||||
@@ -339,35 +374,6 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
|
||||
describe('Phase 2: Queue Management', () => {
|
||||
it('should drain the queue if multiple calls are scheduled', async () => {
|
||||
const validatingCall: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
// Setup queue simulation: two items
|
||||
Object.defineProperty(mockStateManager, 'queueLength', {
|
||||
get: vi
|
||||
.fn()
|
||||
.mockReturnValueOnce(2)
|
||||
.mockReturnValueOnce(1)
|
||||
.mockReturnValue(0),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
Object.defineProperty(mockStateManager, 'isActive', {
|
||||
get: vi.fn().mockReturnValue(false),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
mockStateManager.dequeue.mockReturnValue(validatingCall);
|
||||
vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall);
|
||||
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||
get: vi.fn().mockReturnValue(validatingCall),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
// Execute is the end of the loop, stub it
|
||||
mockExecutor.execute.mockResolvedValue({
|
||||
status: CoreToolCallStatus.Success,
|
||||
@@ -375,56 +381,12 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
|
||||
await scheduler.schedule(req1, signal);
|
||||
|
||||
// Verify loop ran twice
|
||||
expect(mockStateManager.dequeue).toHaveBeenCalledTimes(2);
|
||||
expect(mockStateManager.finalizeCall).toHaveBeenCalledTimes(2);
|
||||
// Verify loop ran once for this schedule call (which had 1 request)
|
||||
// schedule(req1) enqueues 1 request.
|
||||
expect(mockExecutor.execute).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should execute tool calls sequentially (first completes before second starts)', async () => {
|
||||
// Setup queue simulation: two items
|
||||
Object.defineProperty(mockStateManager, 'queueLength', {
|
||||
get: vi
|
||||
.fn()
|
||||
.mockReturnValueOnce(2)
|
||||
.mockReturnValueOnce(1)
|
||||
.mockReturnValue(0),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
Object.defineProperty(mockStateManager, 'isActive', {
|
||||
get: vi.fn().mockReturnValue(false),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
const validatingCall1: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
const validatingCall2: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req2,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
vi.mocked(mockStateManager.dequeue)
|
||||
.mockReturnValueOnce(validatingCall1)
|
||||
.mockReturnValueOnce(validatingCall2)
|
||||
.mockReturnValue(undefined);
|
||||
|
||||
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||
get: vi
|
||||
.fn()
|
||||
.mockReturnValueOnce(validatingCall1) // Used in loop check for call 1
|
||||
.mockReturnValueOnce(validatingCall1) // Used in _execute for call 1
|
||||
.mockReturnValueOnce(validatingCall2) // Used in loop check for call 2
|
||||
.mockReturnValueOnce(validatingCall2), // Used in _execute for call 2
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
const executionLog: string[] = [];
|
||||
|
||||
// Mock executor to push to log with a deterministic microtask delay
|
||||
@@ -452,52 +414,6 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
});
|
||||
|
||||
it('should queue and process multiple schedule() calls made synchronously', async () => {
|
||||
const validatingCall1: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
const validatingCall2: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req2, // Second request
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
// Mock state responses dynamically
|
||||
Object.defineProperty(mockStateManager, 'isActive', {
|
||||
get: vi.fn().mockReturnValue(false),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
// Queue state responses for the two batches:
|
||||
// Batch 1: length 1 -> 0
|
||||
// Batch 2: length 1 -> 0
|
||||
Object.defineProperty(mockStateManager, 'queueLength', {
|
||||
get: vi
|
||||
.fn()
|
||||
.mockReturnValueOnce(1)
|
||||
.mockReturnValueOnce(0)
|
||||
.mockReturnValueOnce(1)
|
||||
.mockReturnValue(0),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
vi.mocked(mockStateManager.dequeue)
|
||||
.mockReturnValueOnce(validatingCall1)
|
||||
.mockReturnValueOnce(validatingCall2);
|
||||
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||
get: vi
|
||||
.fn()
|
||||
.mockReturnValueOnce(validatingCall1)
|
||||
.mockReturnValueOnce(validatingCall1)
|
||||
.mockReturnValueOnce(validatingCall2)
|
||||
.mockReturnValueOnce(validatingCall2),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
// Executor succeeds instantly
|
||||
mockExecutor.execute.mockResolvedValue({
|
||||
status: CoreToolCallStatus.Success,
|
||||
@@ -516,50 +432,6 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
});
|
||||
|
||||
it('should queue requests when scheduler is busy (overlapping batches)', async () => {
|
||||
const validatingCall1: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
const validatingCall2: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req2, // Second request
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
// 1. Setup State Manager for 2 sequential batches
|
||||
Object.defineProperty(mockStateManager, 'isActive', {
|
||||
get: vi.fn().mockReturnValue(false),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
Object.defineProperty(mockStateManager, 'queueLength', {
|
||||
get: vi
|
||||
.fn()
|
||||
.mockReturnValueOnce(1) // Batch 1
|
||||
.mockReturnValueOnce(0)
|
||||
.mockReturnValueOnce(1) // Batch 2
|
||||
.mockReturnValue(0),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
vi.mocked(mockStateManager.dequeue)
|
||||
.mockReturnValueOnce(validatingCall1)
|
||||
.mockReturnValueOnce(validatingCall2);
|
||||
|
||||
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||
get: vi
|
||||
.fn()
|
||||
.mockReturnValueOnce(validatingCall1)
|
||||
.mockReturnValueOnce(validatingCall1)
|
||||
.mockReturnValueOnce(validatingCall2)
|
||||
.mockReturnValueOnce(validatingCall2),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
// 2. Setup Executor with a controllable lock for the first batch
|
||||
const executionLog: string[] = [];
|
||||
let finishFirstBatch: (value: unknown) => void;
|
||||
@@ -635,10 +507,8 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||
get: vi.fn().mockReturnValue(activeCall),
|
||||
configurable: true,
|
||||
});
|
||||
mockStateManager.enqueue([activeCall]);
|
||||
mockStateManager.dequeue();
|
||||
|
||||
scheduler.cancelAll();
|
||||
|
||||
@@ -676,24 +546,7 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
});
|
||||
|
||||
describe('Phase 3: Policy & Confirmation Loop', () => {
|
||||
const validatingCall: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
Object.defineProperty(mockStateManager, 'queueLength', {
|
||||
get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0),
|
||||
configurable: true,
|
||||
});
|
||||
vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall);
|
||||
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||
get: vi.fn().mockReturnValue(validatingCall),
|
||||
configurable: true,
|
||||
});
|
||||
});
|
||||
beforeEach(() => {});
|
||||
|
||||
it('should update state to error with POLICY_VIOLATION if Policy returns DENY', async () => {
|
||||
vi.mocked(checkPolicy).mockResolvedValue({
|
||||
@@ -854,30 +707,6 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
});
|
||||
|
||||
it('should auto-approve remaining identical tools in batch after ProceedAlways', async () => {
|
||||
// Setup: two identical tools
|
||||
const validatingCall1: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
const validatingCall2: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req2,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
vi.mocked(mockStateManager.dequeue)
|
||||
.mockReturnValueOnce(validatingCall1)
|
||||
.mockReturnValueOnce(validatingCall2)
|
||||
.mockReturnValue(undefined);
|
||||
|
||||
vi.spyOn(mockStateManager, 'queueLength', 'get')
|
||||
.mockReturnValueOnce(2)
|
||||
.mockReturnValueOnce(1)
|
||||
.mockReturnValue(0);
|
||||
|
||||
// First call requires confirmation, second is auto-approved (simulating policy update)
|
||||
vi.mocked(checkPolicy)
|
||||
.mockResolvedValueOnce({
|
||||
@@ -1045,21 +874,7 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
});
|
||||
|
||||
describe('Phase 4: Execution Outcomes', () => {
|
||||
const validatingCall: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.spyOn(mockStateManager, 'queueLength', 'get')
|
||||
.mockReturnValueOnce(1)
|
||||
.mockReturnValue(0);
|
||||
mockStateManager.dequeue.mockReturnValue(validatingCall);
|
||||
vi.spyOn(mockStateManager, 'firstActiveCall', 'get').mockReturnValue(
|
||||
validatingCall,
|
||||
);
|
||||
mockPolicyEngine.check.mockResolvedValue({
|
||||
decision: PolicyDecision.ALLOW,
|
||||
}); // Bypass confirmation
|
||||
@@ -1132,30 +947,12 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
response: mockResponse,
|
||||
} as unknown as SuccessfulToolCall);
|
||||
|
||||
// Mock the state manager to return a SUCCESS state when getToolCall is
|
||||
// called
|
||||
const successfulCall: SuccessfulToolCall = {
|
||||
status: CoreToolCallStatus.Success,
|
||||
request: req1,
|
||||
response: mockResponse,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
mockStateManager.getToolCall.mockReturnValue(successfulCall);
|
||||
Object.defineProperty(mockStateManager, 'completedBatch', {
|
||||
get: vi.fn().mockReturnValue([successfulCall]),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
await scheduler.schedule(req1, signal);
|
||||
|
||||
// Verify the finalizer and logger were called
|
||||
expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1');
|
||||
expect(ToolCallEvent).toHaveBeenCalledWith(successfulCall);
|
||||
expect(logToolCall).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
expect.objectContaining(successfulCall),
|
||||
);
|
||||
// We check that logToolCall was called (it's called via the state manager's terminal handler)
|
||||
expect(logToolCall).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not double-report completed tools when concurrent completions occur', async () => {
|
||||
@@ -1182,6 +979,33 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
expect(mockStateManager.finalizeCall).toHaveBeenCalledTimes(1);
|
||||
expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1');
|
||||
});
|
||||
|
||||
it('should break the loop if no progress is made (safeguard against stuck states)', async () => {
|
||||
// Setup: A tool that is 'validating' but stays 'validating' even after processing
|
||||
// This simulates a bug in state management or a weird edge case.
|
||||
const stuckCall: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
// Mock dequeue to keep returning the same stuck call
|
||||
mockStateManager.dequeue.mockReturnValue(stuckCall);
|
||||
// Mock isActive to be true
|
||||
Object.defineProperty(mockStateManager, 'isActive', {
|
||||
get: vi.fn().mockReturnValue(true),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
// Mock updateStatus to do NOTHING (simulating no progress)
|
||||
mockStateManager.updateStatus.mockImplementation(() => {});
|
||||
|
||||
// This should return false (break loop) instead of hanging indefinitely
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const result = await (scheduler as any)._processNextItem(signal);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tool Call Context Propagation', () => {
|
||||
@@ -1196,26 +1020,6 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
parentCallId,
|
||||
});
|
||||
|
||||
const validatingCall: ValidatingToolCall = {
|
||||
status: CoreToolCallStatus.Validating,
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
// Mock queueLength to run the loop once
|
||||
Object.defineProperty(mockStateManager, 'queueLength', {
|
||||
get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall);
|
||||
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||
get: vi.fn().mockReturnValue(validatingCall),
|
||||
configurable: true,
|
||||
});
|
||||
vi.mocked(mockStateManager.getToolCall).mockReturnValue(validatingCall);
|
||||
|
||||
mockToolRegistry.getTool.mockReturnValue(mockTool);
|
||||
mockPolicyEngine.check.mockResolvedValue({
|
||||
decision: PolicyDecision.ALLOW,
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
type ValidatingToolCall,
|
||||
type ErroredToolCall,
|
||||
CoreToolCallStatus,
|
||||
type ScheduledToolCall,
|
||||
} from './types.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import type { ApprovalMode } from '../policy/types.js';
|
||||
@@ -231,14 +232,16 @@ export class Scheduler {
|
||||
next?.reject(new Error('Operation cancelled by user'));
|
||||
}
|
||||
|
||||
// Cancel active call
|
||||
const activeCall = this.state.firstActiveCall;
|
||||
if (activeCall && !this.isTerminal(activeCall.status)) {
|
||||
this.state.updateStatus(
|
||||
activeCall.request.callId,
|
||||
CoreToolCallStatus.Cancelled,
|
||||
'Operation cancelled by user',
|
||||
);
|
||||
// Cancel active calls
|
||||
const activeCalls = this.state.allActiveCalls;
|
||||
for (const activeCall of activeCalls) {
|
||||
if (!this.isTerminal(activeCall.status)) {
|
||||
this.state.updateStatus(
|
||||
activeCall.request.callId,
|
||||
CoreToolCallStatus.Cancelled,
|
||||
'Operation cancelled by user',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear queue
|
||||
@@ -384,6 +387,10 @@ export class Scheduler {
|
||||
return false;
|
||||
}
|
||||
|
||||
const initialStatuses = new Map(
|
||||
this.state.allActiveCalls.map((c) => [c.request.callId, c.status]),
|
||||
);
|
||||
|
||||
if (!this.state.isActive) {
|
||||
const next = this.state.dequeue();
|
||||
if (!next) return false;
|
||||
@@ -397,16 +404,91 @@ export class Scheduler {
|
||||
this.state.finalizeCall(next.request.callId);
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the first tool is read-only, batch all contiguous read-only tools.
|
||||
if (next.tool?.isReadOnly) {
|
||||
while (this.state.queueLength > 0) {
|
||||
const peeked = this.state.peekQueue();
|
||||
if (peeked && peeked.tool?.isReadOnly) {
|
||||
this.state.dequeue();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const active = this.state.firstActiveCall;
|
||||
if (!active) return false;
|
||||
// Now we have one or more active calls. Move them through the lifecycle
|
||||
// as much as possible in this iteration.
|
||||
|
||||
if (active.status === CoreToolCallStatus.Validating) {
|
||||
await this._processValidatingCall(active, signal);
|
||||
// 1. Process all 'validating' calls (Policy & Confirmation)
|
||||
let activeCalls = this.state.allActiveCalls;
|
||||
const validatingCalls = activeCalls.filter(
|
||||
(c): c is ValidatingToolCall =>
|
||||
c.status === CoreToolCallStatus.Validating,
|
||||
);
|
||||
if (validatingCalls.length > 0) {
|
||||
await Promise.all(
|
||||
validatingCalls.map((c) => this._processValidatingCall(c, signal)),
|
||||
);
|
||||
}
|
||||
|
||||
return true;
|
||||
// 2. Execute scheduled calls
|
||||
// Refresh activeCalls as status might have changed to 'scheduled'
|
||||
activeCalls = this.state.allActiveCalls;
|
||||
const scheduledCalls = activeCalls.filter(
|
||||
(c): c is ScheduledToolCall => c.status === CoreToolCallStatus.Scheduled,
|
||||
);
|
||||
|
||||
// We only execute if ALL active calls are in a ready state (scheduled or terminal)
|
||||
const allReady = activeCalls.every(
|
||||
(c) =>
|
||||
c.status === CoreToolCallStatus.Scheduled || this.isTerminal(c.status),
|
||||
);
|
||||
|
||||
if (allReady && scheduledCalls.length > 0) {
|
||||
await Promise.all(scheduledCalls.map((c) => this._execute(c, signal)));
|
||||
}
|
||||
|
||||
// 3. Finalize terminal calls
|
||||
activeCalls = this.state.allActiveCalls;
|
||||
let madeProgress = false;
|
||||
for (const call of activeCalls) {
|
||||
if (this.isTerminal(call.status)) {
|
||||
this.state.finalizeCall(call.request.callId);
|
||||
madeProgress = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if any calls changed status during this iteration (excluding terminal finalization)
|
||||
const currentStatuses = new Map(
|
||||
activeCalls.map((c) => [c.request.callId, c.status]),
|
||||
);
|
||||
const anyStatusChanged = Array.from(initialStatuses.entries()).some(
|
||||
([id, status]) => currentStatuses.get(id) !== status,
|
||||
);
|
||||
|
||||
if (madeProgress || anyStatusChanged) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If we have active calls but NONE of them progressed, check if we are waiting for external events.
|
||||
// States that are 'waiting' from the loop's perspective: awaiting_approval, executing.
|
||||
const isWaitingForExternal = activeCalls.some(
|
||||
(c) =>
|
||||
c.status === CoreToolCallStatus.AwaitingApproval ||
|
||||
c.status === CoreToolCallStatus.Executing,
|
||||
);
|
||||
|
||||
if (isWaitingForExternal && this.state.isActive) {
|
||||
// Yield to the event loop to allow external events (tool completion, user input) to progress.
|
||||
await new Promise((resolve) => queueMicrotask(() => resolve(true)));
|
||||
return true;
|
||||
}
|
||||
|
||||
// If we are here, we have active calls (likely Validating or Scheduled) but none progressed.
|
||||
// This is a stuck state.
|
||||
return false;
|
||||
}
|
||||
|
||||
private async _processValidatingCall(
|
||||
@@ -437,8 +519,6 @@ export class Scheduler {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
this.state.finalizeCall(active.request.callId);
|
||||
}
|
||||
|
||||
// --- Phase 3: Single Call Orchestration ---
|
||||
@@ -467,7 +547,6 @@ export class Scheduler {
|
||||
errorType,
|
||||
),
|
||||
);
|
||||
this.state.finalizeCall(callId);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -506,13 +585,11 @@ export class Scheduler {
|
||||
CoreToolCallStatus.Cancelled,
|
||||
'User denied execution.',
|
||||
);
|
||||
this.state.finalizeCall(callId);
|
||||
this.state.cancelAllQueued('User cancelled operation');
|
||||
return; // Skip execution
|
||||
}
|
||||
|
||||
// Execution
|
||||
await this._execute(callId, signal);
|
||||
this.state.updateStatus(callId, CoreToolCallStatus.Scheduled);
|
||||
}
|
||||
|
||||
// --- Sub-phase Handlers ---
|
||||
@@ -520,13 +597,23 @@ export class Scheduler {
|
||||
/**
|
||||
* Executes the tool and records the result.
|
||||
*/
|
||||
private async _execute(callId: string, signal: AbortSignal): Promise<void> {
|
||||
this.state.updateStatus(callId, CoreToolCallStatus.Scheduled);
|
||||
if (signal.aborted) throw new Error('Operation cancelled');
|
||||
private async _execute(
|
||||
toolCall: ScheduledToolCall,
|
||||
signal: AbortSignal,
|
||||
): Promise<void> {
|
||||
const callId = toolCall.request.callId;
|
||||
if (signal.aborted) {
|
||||
this.state.updateStatus(
|
||||
callId,
|
||||
CoreToolCallStatus.Cancelled,
|
||||
'Operation cancelled',
|
||||
);
|
||||
return;
|
||||
}
|
||||
this.state.updateStatus(callId, CoreToolCallStatus.Executing);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const activeCall = this.state.firstActiveCall as ExecutingToolCall;
|
||||
const activeCall = this.state.getToolCall(callId) as ExecutingToolCall;
|
||||
|
||||
const result = await runWithToolCallContext(
|
||||
{
|
||||
|
||||
397
packages/core/src/scheduler/scheduler_parallel.test.ts
Normal file
397
packages/core/src/scheduler/scheduler_parallel.test.ts
Normal file
@@ -0,0 +1,397 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
type Mocked,
|
||||
} from 'vitest';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
vi.mock('node:crypto', () => ({
|
||||
randomUUID: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../telemetry/trace.js', () => ({
|
||||
runInDevTraceSpan: vi.fn(async (_opts, fn) =>
|
||||
fn({ metadata: { input: {}, output: {} } }),
|
||||
),
|
||||
}));
|
||||
vi.mock('../telemetry/loggers.js', () => ({
|
||||
logToolCall: vi.fn(),
|
||||
}));
|
||||
vi.mock('../telemetry/types.js', () => ({
|
||||
ToolCallEvent: vi.fn().mockImplementation((call) => ({ ...call })),
|
||||
}));
|
||||
|
||||
import {
|
||||
SchedulerStateManager,
|
||||
type TerminalCallHandler,
|
||||
} from './state-manager.js';
|
||||
import { checkPolicy, updatePolicy } from './policy.js';
|
||||
import { ToolExecutor } from './tool-executor.js';
|
||||
import { ToolModificationHandler } from './tool-modifier.js';
|
||||
|
||||
vi.mock('./state-manager.js');
|
||||
vi.mock('./confirmation.js');
|
||||
vi.mock('./policy.js', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('./policy.js')>();
|
||||
return {
|
||||
...actual,
|
||||
checkPolicy: vi.fn(),
|
||||
updatePolicy: vi.fn(),
|
||||
};
|
||||
});
|
||||
vi.mock('./tool-executor.js');
|
||||
vi.mock('./tool-modifier.js');
|
||||
|
||||
import { Scheduler } from './scheduler.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import type { PolicyEngine } from '../policy/policy-engine.js';
|
||||
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
|
||||
import {
|
||||
type AnyDeclarativeTool,
|
||||
type AnyToolInvocation,
|
||||
} from '../tools/tools.js';
|
||||
import type {
|
||||
ToolCallRequestInfo,
|
||||
CompletedToolCall,
|
||||
SuccessfulToolCall,
|
||||
Status,
|
||||
ToolCall,
|
||||
} from './types.js';
|
||||
import { ROOT_SCHEDULER_ID } from './types.js';
|
||||
import type { EditorType } from '../utils/editor.js';
|
||||
|
||||
describe('Scheduler Parallel Execution', () => {
|
||||
let scheduler: Scheduler;
|
||||
let signal: AbortSignal;
|
||||
let abortController: AbortController;
|
||||
|
||||
let mockConfig: Mocked<Config>;
|
||||
let mockMessageBus: Mocked<MessageBus>;
|
||||
let mockPolicyEngine: Mocked<PolicyEngine>;
|
||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||
let getPreferredEditor: Mock<() => EditorType | undefined>;
|
||||
|
||||
let mockStateManager: Mocked<SchedulerStateManager>;
|
||||
let mockExecutor: Mocked<ToolExecutor>;
|
||||
let mockModifier: Mocked<ToolModificationHandler>;
|
||||
|
||||
const req1: ToolCallRequestInfo = {
|
||||
callId: 'call-1',
|
||||
name: 'read-tool-1',
|
||||
args: { path: 'a.txt' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
schedulerId: ROOT_SCHEDULER_ID,
|
||||
};
|
||||
|
||||
const req2: ToolCallRequestInfo = {
|
||||
callId: 'call-2',
|
||||
name: 'read-tool-2',
|
||||
args: { path: 'b.txt' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
schedulerId: ROOT_SCHEDULER_ID,
|
||||
};
|
||||
|
||||
const req3: ToolCallRequestInfo = {
|
||||
callId: 'call-3',
|
||||
name: 'write-tool',
|
||||
args: { path: 'c.txt', content: 'hi' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
schedulerId: ROOT_SCHEDULER_ID,
|
||||
};
|
||||
|
||||
const readTool1 = {
|
||||
name: 'read-tool-1',
|
||||
isReadOnly: true,
|
||||
build: vi.fn(),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
const readTool2 = {
|
||||
name: 'read-tool-2',
|
||||
isReadOnly: true,
|
||||
build: vi.fn(),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
const writeTool = {
|
||||
name: 'write-tool',
|
||||
isReadOnly: false,
|
||||
build: vi.fn(),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
|
||||
const mockInvocation = {
|
||||
shouldConfirmExecute: vi.fn().mockResolvedValue(false),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.mocked(randomUUID).mockReturnValue(
|
||||
'uuid' as unknown as `${string}-${string}-${string}-${string}-${string}`,
|
||||
);
|
||||
abortController = new AbortController();
|
||||
signal = abortController.signal;
|
||||
|
||||
mockPolicyEngine = {
|
||||
check: vi.fn().mockResolvedValue({ decision: PolicyDecision.ALLOW }),
|
||||
} as unknown as Mocked<PolicyEngine>;
|
||||
|
||||
mockToolRegistry = {
|
||||
getTool: vi.fn((name) => {
|
||||
if (name === 'read-tool-1') return readTool1;
|
||||
if (name === 'read-tool-2') return readTool2;
|
||||
if (name === 'write-tool') return writeTool;
|
||||
return undefined;
|
||||
}),
|
||||
getAllToolNames: vi
|
||||
.fn()
|
||||
.mockReturnValue(['read-tool-1', 'read-tool-2', 'write-tool']),
|
||||
} as unknown as Mocked<ToolRegistry>;
|
||||
|
||||
mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(true),
|
||||
setApprovalMode: vi.fn(),
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
subscribe: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
getPreferredEditor = vi.fn().mockReturnValue('vim');
|
||||
|
||||
vi.mocked(checkPolicy).mockReset();
|
||||
vi.mocked(checkPolicy).mockResolvedValue({
|
||||
decision: PolicyDecision.ALLOW,
|
||||
rule: undefined,
|
||||
});
|
||||
vi.mocked(updatePolicy).mockReset();
|
||||
|
||||
const mockActiveCallsMap = new Map<string, ToolCall>();
|
||||
const mockQueue: ToolCall[] = [];
|
||||
let capturedTerminalHandler: TerminalCallHandler | undefined;
|
||||
|
||||
mockStateManager = {
|
||||
enqueue: vi.fn((calls: ToolCall[]) => {
|
||||
mockQueue.push(...calls.map((c) => ({ ...c }) as ToolCall));
|
||||
}),
|
||||
dequeue: vi.fn(() => {
|
||||
const next = mockQueue.shift();
|
||||
if (next) mockActiveCallsMap.set(next.request.callId, next);
|
||||
return next;
|
||||
}),
|
||||
peekQueue: vi.fn(() => mockQueue[0]),
|
||||
getToolCall: vi.fn((id: string) => mockActiveCallsMap.get(id)),
|
||||
updateStatus: vi.fn((id: string, status: Status) => {
|
||||
const call = mockActiveCallsMap.get(id);
|
||||
if (call) (call as unknown as { status: Status }).status = status;
|
||||
}),
|
||||
finalizeCall: vi.fn((id: string) => {
|
||||
const call = mockActiveCallsMap.get(id);
|
||||
if (call) {
|
||||
mockActiveCallsMap.delete(id);
|
||||
capturedTerminalHandler?.(call as CompletedToolCall);
|
||||
}
|
||||
}),
|
||||
updateArgs: vi.fn(),
|
||||
setOutcome: vi.fn(),
|
||||
cancelAllQueued: vi.fn(() => {
|
||||
mockQueue.length = 0;
|
||||
}),
|
||||
clearBatch: vi.fn(),
|
||||
} as unknown as Mocked<SchedulerStateManager>;
|
||||
|
||||
Object.defineProperty(mockStateManager, 'isActive', {
|
||||
get: vi.fn(() => mockActiveCallsMap.size > 0),
|
||||
configurable: true,
|
||||
});
|
||||
Object.defineProperty(mockStateManager, 'allActiveCalls', {
|
||||
get: vi.fn(() => Array.from(mockActiveCallsMap.values())),
|
||||
configurable: true,
|
||||
});
|
||||
Object.defineProperty(mockStateManager, 'queueLength', {
|
||||
get: vi.fn(() => mockQueue.length),
|
||||
configurable: true,
|
||||
});
|
||||
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||
get: vi.fn(() => mockActiveCallsMap.values().next().value),
|
||||
configurable: true,
|
||||
});
|
||||
Object.defineProperty(mockStateManager, 'completedBatch', {
|
||||
get: vi.fn().mockReturnValue([]),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
vi.mocked(SchedulerStateManager).mockImplementation(
|
||||
(_bus, _id, onTerminal) => {
|
||||
capturedTerminalHandler = onTerminal;
|
||||
return mockStateManager as unknown as SchedulerStateManager;
|
||||
},
|
||||
);
|
||||
|
||||
mockExecutor = { execute: vi.fn() } as unknown as Mocked<ToolExecutor>;
|
||||
vi.mocked(ToolExecutor).mockReturnValue(
|
||||
mockExecutor as unknown as Mocked<ToolExecutor>,
|
||||
);
|
||||
mockModifier = {
|
||||
handleModifyWithEditor: vi.fn(),
|
||||
applyInlineModify: vi.fn(),
|
||||
} as unknown as Mocked<ToolModificationHandler>;
|
||||
vi.mocked(ToolModificationHandler).mockReturnValue(
|
||||
mockModifier as unknown as Mocked<ToolModificationHandler>,
|
||||
);
|
||||
|
||||
scheduler = new Scheduler({
|
||||
config: mockConfig,
|
||||
messageBus: mockMessageBus,
|
||||
getPreferredEditor,
|
||||
schedulerId: 'root',
|
||||
});
|
||||
|
||||
vi.mocked(readTool1.build).mockReturnValue(
|
||||
mockInvocation as unknown as AnyToolInvocation,
|
||||
);
|
||||
vi.mocked(readTool2.build).mockReturnValue(
|
||||
mockInvocation as unknown as AnyToolInvocation,
|
||||
);
|
||||
vi.mocked(writeTool.build).mockReturnValue(
|
||||
mockInvocation as unknown as AnyToolInvocation,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should execute contiguous read-only tools in parallel', async () => {
|
||||
const executionLog: string[] = [];
|
||||
|
||||
mockExecutor.execute.mockImplementation(async ({ call }) => {
|
||||
const id = call.request.callId;
|
||||
executionLog.push(`start-${id}`);
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
executionLog.push(`end-${id}`);
|
||||
return {
|
||||
status: 'success',
|
||||
response: { callId: id, responseParts: [] },
|
||||
} as unknown as SuccessfulToolCall;
|
||||
});
|
||||
|
||||
// Schedule 2 read tools and 1 write tool
|
||||
await scheduler.schedule([req1, req2, req3], signal);
|
||||
|
||||
// Parallel read tools should start together
|
||||
expect(executionLog[0]).toBe('start-call-1');
|
||||
expect(executionLog[1]).toBe('start-call-2');
|
||||
|
||||
// They can finish in any order, but both must finish before call-3 starts
|
||||
expect(executionLog.indexOf('start-call-3')).toBeGreaterThan(
|
||||
executionLog.indexOf('end-call-1'),
|
||||
);
|
||||
expect(executionLog.indexOf('start-call-3')).toBeGreaterThan(
|
||||
executionLog.indexOf('end-call-2'),
|
||||
);
|
||||
|
||||
expect(executionLog).toContain('end-call-3');
|
||||
});
|
||||
|
||||
it('should execute non-read-only tools sequentially', async () => {
|
||||
const executionLog: string[] = [];
|
||||
|
||||
mockExecutor.execute.mockImplementation(async ({ call }) => {
|
||||
const id = call.request.callId;
|
||||
executionLog.push(`start-${id}`);
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
executionLog.push(`end-${id}`);
|
||||
return {
|
||||
status: 'success',
|
||||
response: { callId: id, responseParts: [] },
|
||||
} as unknown as SuccessfulToolCall;
|
||||
});
|
||||
|
||||
// req3 is NOT read-only
|
||||
await scheduler.schedule([req3, req1], signal);
|
||||
|
||||
// Should be strictly sequential
|
||||
expect(executionLog).toEqual([
|
||||
'start-call-3',
|
||||
'end-call-3',
|
||||
'start-call-1',
|
||||
'end-call-1',
|
||||
]);
|
||||
});
|
||||
|
||||
it('should execute [WRITE, READ, READ] as [sequential, parallel]', async () => {
|
||||
const executionLog: string[] = [];
|
||||
mockExecutor.execute.mockImplementation(async ({ call }) => {
|
||||
const id = call.request.callId;
|
||||
executionLog.push(`start-${id}`);
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
executionLog.push(`end-${id}`);
|
||||
return {
|
||||
status: 'success',
|
||||
response: { callId: id, responseParts: [] },
|
||||
} as unknown as SuccessfulToolCall;
|
||||
});
|
||||
|
||||
// req3 (WRITE), req1 (READ), req2 (READ)
|
||||
await scheduler.schedule([req3, req1, req2], signal);
|
||||
|
||||
// Order should be:
|
||||
// 1. write starts and ends
|
||||
// 2. read1 and read2 start together (parallel)
|
||||
expect(executionLog[0]).toBe('start-call-3');
|
||||
expect(executionLog[1]).toBe('end-call-3');
|
||||
expect(executionLog.slice(2, 4)).toContain('start-call-1');
|
||||
expect(executionLog.slice(2, 4)).toContain('start-call-2');
|
||||
});
|
||||
|
||||
it('should execute [READ, READ, WRITE, READ, READ] in three waves', async () => {
|
||||
const executionLog: string[] = [];
|
||||
mockExecutor.execute.mockImplementation(async ({ call }) => {
|
||||
const id = call.request.callId;
|
||||
executionLog.push(`start-${id}`);
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
executionLog.push(`end-${id}`);
|
||||
return {
|
||||
status: 'success',
|
||||
response: { callId: id, responseParts: [] },
|
||||
} as unknown as SuccessfulToolCall;
|
||||
});
|
||||
|
||||
const req4: ToolCallRequestInfo = { ...req1, callId: 'call-4' };
|
||||
const req5: ToolCallRequestInfo = { ...req2, callId: 'call-5' };
|
||||
|
||||
await scheduler.schedule([req1, req2, req3, req4, req5], signal);
|
||||
|
||||
// Wave 1: call-1, call-2 (parallel)
|
||||
expect(executionLog.slice(0, 2)).toContain('start-call-1');
|
||||
expect(executionLog.slice(0, 2)).toContain('start-call-2');
|
||||
|
||||
// Wave 2: call-3 (sequential)
|
||||
// Must start after both call-1 and call-2 end
|
||||
const start3 = executionLog.indexOf('start-call-3');
|
||||
expect(start3).toBeGreaterThan(executionLog.indexOf('end-call-1'));
|
||||
expect(start3).toBeGreaterThan(executionLog.indexOf('end-call-2'));
|
||||
const end3 = executionLog.indexOf('end-call-3');
|
||||
expect(end3).toBeGreaterThan(start3);
|
||||
|
||||
// Wave 3: call-4, call-5 (parallel)
|
||||
// Must start after call-3 ends
|
||||
expect(executionLog.indexOf('start-call-4')).toBeGreaterThan(end3);
|
||||
expect(executionLog.indexOf('start-call-5')).toBeGreaterThan(end3);
|
||||
});
|
||||
});
|
||||
@@ -78,10 +78,18 @@ export class SchedulerStateManager {
|
||||
return next;
|
||||
}
|
||||
|
||||
peekQueue(): ToolCall | undefined {
|
||||
return this.queue[0];
|
||||
}
|
||||
|
||||
get isActive(): boolean {
|
||||
return this.activeCalls.size > 0;
|
||||
}
|
||||
|
||||
get allActiveCalls(): ToolCall[] {
|
||||
return Array.from(this.activeCalls.values());
|
||||
}
|
||||
|
||||
get activeCallCount(): number {
|
||||
return this.activeCalls.size;
|
||||
}
|
||||
|
||||
@@ -247,7 +247,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
override readonly parameterSchema: unknown,
|
||||
messageBus: MessageBus,
|
||||
readonly trust?: boolean,
|
||||
readonly isReadOnly?: boolean,
|
||||
isReadOnly?: boolean,
|
||||
nameOverride?: string,
|
||||
private readonly cliConfig?: Config,
|
||||
override readonly extensionName?: string,
|
||||
@@ -265,6 +265,16 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
extensionName,
|
||||
extensionId,
|
||||
);
|
||||
this._isReadOnly = isReadOnly;
|
||||
}
|
||||
|
||||
private readonly _isReadOnly?: boolean;
|
||||
|
||||
override get isReadOnly(): boolean {
|
||||
if (this._isReadOnly !== undefined) {
|
||||
return this._isReadOnly;
|
||||
}
|
||||
return super.isReadOnly;
|
||||
}
|
||||
|
||||
getFullyQualifiedPrefix(): string {
|
||||
|
||||
@@ -9,6 +9,8 @@ import type { ToolInvocation, ToolResult } from './tools.js';
|
||||
import { DeclarativeTool, hasCycleInSchema, Kind } from './tools.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
import { ReadFileTool } from './read-file.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
|
||||
class TestToolInvocation implements ToolInvocation<object, ToolResult> {
|
||||
constructor(
|
||||
@@ -238,3 +240,30 @@ describe('hasCycleInSchema', () => {
|
||||
expect(hasCycleInSchema({})).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tools Read-Only property', () => {
|
||||
it('should have isReadOnly true for ReadFileTool', () => {
|
||||
const config = makeFakeConfig();
|
||||
const bus = createMockMessageBus();
|
||||
const tool = new ReadFileTool(config, bus);
|
||||
expect(tool.isReadOnly).toBe(true);
|
||||
});
|
||||
|
||||
it('should derive isReadOnly from Kind', () => {
|
||||
const bus = createMockMessageBus();
|
||||
class MyTool extends DeclarativeTool<object, ToolResult> {
|
||||
build(_params: object): ToolInvocation<object, ToolResult> {
|
||||
throw new Error('Not implemented');
|
||||
}
|
||||
}
|
||||
|
||||
const mutator = new MyTool('m', 'M', 'd', Kind.Edit, {}, bus);
|
||||
expect(mutator.isReadOnly).toBe(false);
|
||||
|
||||
const reader = new MyTool('r', 'R', 'd', Kind.Read, {}, bus);
|
||||
expect(reader.isReadOnly).toBe(true);
|
||||
|
||||
const searcher = new MyTool('s', 'S', 'd', Kind.Search, {}, bus);
|
||||
expect(searcher.isReadOnly).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -333,6 +333,11 @@ export interface ToolBuilder<
|
||||
*/
|
||||
canUpdateOutput: boolean;
|
||||
|
||||
/**
|
||||
* Whether the tool is read-only (has no side effects).
|
||||
*/
|
||||
isReadOnly: boolean;
|
||||
|
||||
/**
|
||||
* Validates raw parameters and builds a ready-to-execute invocation.
|
||||
* @param params The raw, untrusted parameters from the model.
|
||||
@@ -363,6 +368,10 @@ export abstract class DeclarativeTool<
|
||||
readonly extensionId?: string,
|
||||
) {}
|
||||
|
||||
get isReadOnly(): boolean {
|
||||
return READ_ONLY_KINDS.includes(this.kind);
|
||||
}
|
||||
|
||||
getSchema(_modelId?: string): FunctionDeclaration {
|
||||
return {
|
||||
name: this.name,
|
||||
@@ -819,6 +828,13 @@ export const MUTATOR_KINDS: Kind[] = [
|
||||
Kind.Execute,
|
||||
] as const;
|
||||
|
||||
// Function kinds that are safe to run in parallel
|
||||
export const READ_ONLY_KINDS: Kind[] = [
|
||||
Kind.Read,
|
||||
Kind.Search,
|
||||
Kind.Fetch,
|
||||
] as const;
|
||||
|
||||
export interface ToolLocation {
|
||||
// Absolute path to the file
|
||||
path: string;
|
||||
|
||||
Reference in New Issue
Block a user