feat(core): Implement parallel function calling (FC) for read-only tools. (#18791)

This commit is contained in:
joshualitt
2026-02-19 16:38:22 -08:00
committed by GitHub
parent 2ac39b6acc
commit 6351352e54
11 changed files with 862 additions and 301 deletions
+81 -277
View File
@@ -27,7 +27,6 @@ vi.mock('../telemetry/trace.js', () => ({
}));
import { logToolCall } from '../telemetry/loggers.js';
import { ToolCallEvent } from '../telemetry/types.js';
vi.mock('../telemetry/loggers.js', () => ({
logToolCall: vi.fn(),
}));
@@ -76,6 +75,8 @@ import type {
CancelledToolCall,
CompletedToolCall,
ToolCallResponseInfo,
Status,
ToolCall,
} from './types.js';
import { CoreToolCallStatus, ROOT_SCHEDULER_ID } from './types.js';
import { ToolErrorType } from '../tools/tool-error.js';
@@ -168,29 +169,55 @@ describe('Scheduler (Orchestrator)', () => {
getPreferredEditor = vi.fn().mockReturnValue('vim');
// --- Setup Sub-component Mocks ---
const mockActiveCallsMap = new Map<string, ToolCall>();
const mockQueue: ToolCall[] = [];
mockStateManager = {
enqueue: vi.fn(),
dequeue: vi.fn(),
getToolCall: vi.fn(),
updateStatus: vi.fn(),
finalizeCall: vi.fn(),
enqueue: vi.fn((calls: ToolCall[]) => {
// Clone to preserve initial state for Phase 1 tests
mockQueue.push(...calls.map((c) => ({ ...c }) as ToolCall));
}),
dequeue: vi.fn(() => {
const next = mockQueue.shift();
if (next) mockActiveCallsMap.set(next.request.callId, next);
return next;
}),
peekQueue: vi.fn(() => mockQueue[0]),
getToolCall: vi.fn((id: string) => mockActiveCallsMap.get(id)),
updateStatus: vi.fn((id: string, status: Status) => {
const call = mockActiveCallsMap.get(id);
if (call) (call as unknown as { status: Status }).status = status;
}),
finalizeCall: vi.fn((id: string) => {
const call = mockActiveCallsMap.get(id);
if (call) {
mockActiveCallsMap.delete(id);
capturedTerminalHandler?.(call as CompletedToolCall);
}
}),
updateArgs: vi.fn(),
setOutcome: vi.fn(),
cancelAllQueued: vi.fn(),
cancelAllQueued: vi.fn(() => {
mockQueue.length = 0;
}),
clearBatch: vi.fn(),
} as unknown as Mocked<SchedulerStateManager>;
// Define getters for accessors idiomatically
Object.defineProperty(mockStateManager, 'isActive', {
get: vi.fn().mockReturnValue(false),
get: vi.fn(() => mockActiveCallsMap.size > 0),
configurable: true,
});
Object.defineProperty(mockStateManager, 'allActiveCalls', {
get: vi.fn(() => Array.from(mockActiveCallsMap.values())),
configurable: true,
});
Object.defineProperty(mockStateManager, 'queueLength', {
get: vi.fn().mockReturnValue(0),
get: vi.fn(() => mockQueue.length),
configurable: true,
});
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi.fn().mockReturnValue(undefined),
get: vi.fn(() => mockActiveCallsMap.values().next().value),
configurable: true,
});
Object.defineProperty(mockStateManager, 'completedBatch', {
@@ -227,8 +254,9 @@ describe('Scheduler (Orchestrator)', () => {
);
mockStateManager.finalizeCall.mockImplementation((callId: string) => {
const call = mockStateManager.getToolCall(callId);
const call = mockActiveCallsMap.get(callId);
if (call) {
mockActiveCallsMap.delete(callId);
capturedTerminalHandler?.(call as CompletedToolCall);
}
});
@@ -242,6 +270,13 @@ describe('Scheduler (Orchestrator)', () => {
vi.mocked(ToolExecutor).mockReturnValue(
mockExecutor as unknown as Mocked<ToolExecutor>,
);
mockExecutor.execute.mockResolvedValue({
status: 'success',
response: {
callId: 'default',
responseParts: [],
} as unknown as ToolCallResponseInfo,
} as unknown as SuccessfulToolCall);
vi.mocked(ToolModificationHandler).mockReturnValue(
mockModifier as unknown as Mocked<ToolModificationHandler>,
);
@@ -339,35 +374,6 @@ describe('Scheduler (Orchestrator)', () => {
describe('Phase 2: Queue Management', () => {
it('should drain the queue if multiple calls are scheduled', async () => {
const validatingCall: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
// Setup queue simulation: two items
Object.defineProperty(mockStateManager, 'queueLength', {
get: vi
.fn()
.mockReturnValueOnce(2)
.mockReturnValueOnce(1)
.mockReturnValue(0),
configurable: true,
});
Object.defineProperty(mockStateManager, 'isActive', {
get: vi.fn().mockReturnValue(false),
configurable: true,
});
mockStateManager.dequeue.mockReturnValue(validatingCall);
vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall);
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi.fn().mockReturnValue(validatingCall),
configurable: true,
});
// Execute is the end of the loop, stub it
mockExecutor.execute.mockResolvedValue({
status: CoreToolCallStatus.Success,
@@ -375,56 +381,12 @@ describe('Scheduler (Orchestrator)', () => {
await scheduler.schedule(req1, signal);
// Verify loop ran twice
expect(mockStateManager.dequeue).toHaveBeenCalledTimes(2);
expect(mockStateManager.finalizeCall).toHaveBeenCalledTimes(2);
// Verify loop ran once for this schedule call (which had 1 request)
// schedule(req1) enqueues 1 request.
expect(mockExecutor.execute).toHaveBeenCalledTimes(1);
});
it('should execute tool calls sequentially (first completes before second starts)', async () => {
// Setup queue simulation: two items
Object.defineProperty(mockStateManager, 'queueLength', {
get: vi
.fn()
.mockReturnValueOnce(2)
.mockReturnValueOnce(1)
.mockReturnValue(0),
configurable: true,
});
Object.defineProperty(mockStateManager, 'isActive', {
get: vi.fn().mockReturnValue(false),
configurable: true,
});
const validatingCall1: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
const validatingCall2: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req2,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
vi.mocked(mockStateManager.dequeue)
.mockReturnValueOnce(validatingCall1)
.mockReturnValueOnce(validatingCall2)
.mockReturnValue(undefined);
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi
.fn()
.mockReturnValueOnce(validatingCall1) // Used in loop check for call 1
.mockReturnValueOnce(validatingCall1) // Used in _execute for call 1
.mockReturnValueOnce(validatingCall2) // Used in loop check for call 2
.mockReturnValueOnce(validatingCall2), // Used in _execute for call 2
configurable: true,
});
const executionLog: string[] = [];
// Mock executor to push to log with a deterministic microtask delay
@@ -452,52 +414,6 @@ describe('Scheduler (Orchestrator)', () => {
});
it('should queue and process multiple schedule() calls made synchronously', async () => {
const validatingCall1: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
const validatingCall2: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req2, // Second request
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
// Mock state responses dynamically
Object.defineProperty(mockStateManager, 'isActive', {
get: vi.fn().mockReturnValue(false),
configurable: true,
});
// Queue state responses for the two batches:
// Batch 1: length 1 -> 0
// Batch 2: length 1 -> 0
Object.defineProperty(mockStateManager, 'queueLength', {
get: vi
.fn()
.mockReturnValueOnce(1)
.mockReturnValueOnce(0)
.mockReturnValueOnce(1)
.mockReturnValue(0),
configurable: true,
});
vi.mocked(mockStateManager.dequeue)
.mockReturnValueOnce(validatingCall1)
.mockReturnValueOnce(validatingCall2);
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi
.fn()
.mockReturnValueOnce(validatingCall1)
.mockReturnValueOnce(validatingCall1)
.mockReturnValueOnce(validatingCall2)
.mockReturnValueOnce(validatingCall2),
configurable: true,
});
// Executor succeeds instantly
mockExecutor.execute.mockResolvedValue({
status: CoreToolCallStatus.Success,
@@ -516,50 +432,6 @@ describe('Scheduler (Orchestrator)', () => {
});
it('should queue requests when scheduler is busy (overlapping batches)', async () => {
const validatingCall1: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
const validatingCall2: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req2, // Second request
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
// 1. Setup State Manager for 2 sequential batches
Object.defineProperty(mockStateManager, 'isActive', {
get: vi.fn().mockReturnValue(false),
configurable: true,
});
Object.defineProperty(mockStateManager, 'queueLength', {
get: vi
.fn()
.mockReturnValueOnce(1) // Batch 1
.mockReturnValueOnce(0)
.mockReturnValueOnce(1) // Batch 2
.mockReturnValue(0),
configurable: true,
});
vi.mocked(mockStateManager.dequeue)
.mockReturnValueOnce(validatingCall1)
.mockReturnValueOnce(validatingCall2);
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi
.fn()
.mockReturnValueOnce(validatingCall1)
.mockReturnValueOnce(validatingCall1)
.mockReturnValueOnce(validatingCall2)
.mockReturnValueOnce(validatingCall2),
configurable: true,
});
// 2. Setup Executor with a controllable lock for the first batch
const executionLog: string[] = [];
let finishFirstBatch: (value: unknown) => void;
@@ -635,10 +507,8 @@ describe('Scheduler (Orchestrator)', () => {
invocation: mockInvocation as unknown as AnyToolInvocation,
};
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi.fn().mockReturnValue(activeCall),
configurable: true,
});
mockStateManager.enqueue([activeCall]);
mockStateManager.dequeue();
scheduler.cancelAll();
@@ -676,24 +546,7 @@ describe('Scheduler (Orchestrator)', () => {
});
describe('Phase 3: Policy & Confirmation Loop', () => {
const validatingCall: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
beforeEach(() => {
Object.defineProperty(mockStateManager, 'queueLength', {
get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0),
configurable: true,
});
vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall);
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi.fn().mockReturnValue(validatingCall),
configurable: true,
});
});
// NOTE(review): this hook has an empty body and performs no setup; consider removing it.
beforeEach(() => {});
it('should update state to error with POLICY_VIOLATION if Policy returns DENY', async () => {
vi.mocked(checkPolicy).mockResolvedValue({
@@ -854,30 +707,6 @@ describe('Scheduler (Orchestrator)', () => {
});
it('should auto-approve remaining identical tools in batch after ProceedAlways', async () => {
// Setup: two identical tools
const validatingCall1: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
const validatingCall2: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req2,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
vi.mocked(mockStateManager.dequeue)
.mockReturnValueOnce(validatingCall1)
.mockReturnValueOnce(validatingCall2)
.mockReturnValue(undefined);
vi.spyOn(mockStateManager, 'queueLength', 'get')
.mockReturnValueOnce(2)
.mockReturnValueOnce(1)
.mockReturnValue(0);
// First call requires confirmation, second is auto-approved (simulating policy update)
vi.mocked(checkPolicy)
.mockResolvedValueOnce({
@@ -1045,21 +874,7 @@ describe('Scheduler (Orchestrator)', () => {
});
describe('Phase 4: Execution Outcomes', () => {
const validatingCall: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
beforeEach(() => {
vi.spyOn(mockStateManager, 'queueLength', 'get')
.mockReturnValueOnce(1)
.mockReturnValue(0);
mockStateManager.dequeue.mockReturnValue(validatingCall);
vi.spyOn(mockStateManager, 'firstActiveCall', 'get').mockReturnValue(
validatingCall,
);
mockPolicyEngine.check.mockResolvedValue({
decision: PolicyDecision.ALLOW,
}); // Bypass confirmation
@@ -1132,30 +947,12 @@ describe('Scheduler (Orchestrator)', () => {
response: mockResponse,
} as unknown as SuccessfulToolCall);
// Mock the state manager to return a SUCCESS state when getToolCall is
// called
const successfulCall: SuccessfulToolCall = {
status: CoreToolCallStatus.Success,
request: req1,
response: mockResponse,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
mockStateManager.getToolCall.mockReturnValue(successfulCall);
Object.defineProperty(mockStateManager, 'completedBatch', {
get: vi.fn().mockReturnValue([successfulCall]),
configurable: true,
});
await scheduler.schedule(req1, signal);
// Verify the finalizer and logger were called
expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1');
expect(ToolCallEvent).toHaveBeenCalledWith(successfulCall);
expect(logToolCall).toHaveBeenCalledWith(
mockConfig,
expect.objectContaining(successfulCall),
);
// We check that logToolCall was called (it's called via the state manager's terminal handler)
expect(logToolCall).toHaveBeenCalled();
});
it('should not double-report completed tools when concurrent completions occur', async () => {
@@ -1182,6 +979,33 @@ describe('Scheduler (Orchestrator)', () => {
expect(mockStateManager.finalizeCall).toHaveBeenCalledTimes(1);
expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1');
});
it('should break the loop if no progress is made (safeguard against stuck states)', async () => {
  // Setup: a tool call that is 'validating' and stays 'validating' even after
  // it has been processed. This simulates a bug in state management.
  const stuckCall: ValidatingToolCall = {
    status: CoreToolCallStatus.Validating,
    request: req1,
    tool: mockTool,
    invocation: mockInvocation as unknown as AnyToolInvocation,
  };

  // Route the call through the mock's own enqueue/dequeue so that it really
  // ends up in the active-call set. Stubbing dequeue()/isActive directly would
  // leave allActiveCalls empty, and the safeguard would be "verified" without
  // the scheduler ever seeing a stuck call.
  mockStateManager.enqueue([stuckCall]);
  mockStateManager.dequeue();

  // Policy allows the call, so the only way forward is a status transition...
  vi.mocked(checkPolicy).mockResolvedValue({
    decision: PolicyDecision.ALLOW,
    rule: undefined,
  });
  // ...which is swallowed here, so the call can never leave 'validating'.
  mockStateManager.updateStatus.mockImplementation(() => {});

  // This should return false (break loop) instead of hanging indefinitely
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const result = await (scheduler as any)._processNextItem(signal);

  expect(result).toBe(false);
  // The call must still be active and unchanged: nothing ran, nothing finalized.
  expect(mockStateManager.getToolCall(req1.callId)?.status).toBe(
    CoreToolCallStatus.Validating,
  );
  expect(mockExecutor.execute).not.toHaveBeenCalled();
  expect(mockStateManager.finalizeCall).not.toHaveBeenCalled();
});
});
describe('Tool Call Context Propagation', () => {
@@ -1196,26 +1020,6 @@ describe('Scheduler (Orchestrator)', () => {
parentCallId,
});
const validatingCall: ValidatingToolCall = {
status: CoreToolCallStatus.Validating,
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
// Mock queueLength to run the loop once
Object.defineProperty(mockStateManager, 'queueLength', {
get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0),
configurable: true,
});
vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall);
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi.fn().mockReturnValue(validatingCall),
configurable: true,
});
vi.mocked(mockStateManager.getToolCall).mockReturnValue(validatingCall);
mockToolRegistry.getTool.mockReturnValue(mockTool);
mockPolicyEngine.check.mockResolvedValue({
decision: PolicyDecision.ALLOW,
+110 -23
View File
@@ -20,6 +20,7 @@ import {
type ValidatingToolCall,
type ErroredToolCall,
CoreToolCallStatus,
type ScheduledToolCall,
} from './types.js';
import { ToolErrorType } from '../tools/tool-error.js';
import type { ApprovalMode } from '../policy/types.js';
@@ -231,14 +232,16 @@ export class Scheduler {
next?.reject(new Error('Operation cancelled by user'));
}
// Cancel active call
const activeCall = this.state.firstActiveCall;
if (activeCall && !this.isTerminal(activeCall.status)) {
this.state.updateStatus(
activeCall.request.callId,
CoreToolCallStatus.Cancelled,
'Operation cancelled by user',
);
// Cancel active calls
const activeCalls = this.state.allActiveCalls;
for (const activeCall of activeCalls) {
if (!this.isTerminal(activeCall.status)) {
this.state.updateStatus(
activeCall.request.callId,
CoreToolCallStatus.Cancelled,
'Operation cancelled by user',
);
}
}
// Clear queue
@@ -384,6 +387,10 @@ export class Scheduler {
return false;
}
const initialStatuses = new Map(
this.state.allActiveCalls.map((c) => [c.request.callId, c.status]),
);
if (!this.state.isActive) {
const next = this.state.dequeue();
if (!next) return false;
@@ -397,16 +404,91 @@ export class Scheduler {
this.state.finalizeCall(next.request.callId);
return true;
}
// If the first tool is read-only, batch all contiguous read-only tools.
if (next.tool?.isReadOnly) {
while (this.state.queueLength > 0) {
const peeked = this.state.peekQueue();
if (peeked && peeked.tool?.isReadOnly) {
this.state.dequeue();
} else {
break;
}
}
}
}
const active = this.state.firstActiveCall;
if (!active) return false;
// Now we have one or more active calls. Move them through the lifecycle
// as much as possible in this iteration.
if (active.status === CoreToolCallStatus.Validating) {
await this._processValidatingCall(active, signal);
// 1. Process all 'validating' calls (Policy & Confirmation)
let activeCalls = this.state.allActiveCalls;
const validatingCalls = activeCalls.filter(
(c): c is ValidatingToolCall =>
c.status === CoreToolCallStatus.Validating,
);
if (validatingCalls.length > 0) {
await Promise.all(
validatingCalls.map((c) => this._processValidatingCall(c, signal)),
);
}
return true;
// 2. Execute scheduled calls
// Refresh activeCalls as status might have changed to 'scheduled'
activeCalls = this.state.allActiveCalls;
const scheduledCalls = activeCalls.filter(
(c): c is ScheduledToolCall => c.status === CoreToolCallStatus.Scheduled,
);
// We only execute if ALL active calls are in a ready state (scheduled or terminal)
const allReady = activeCalls.every(
(c) =>
c.status === CoreToolCallStatus.Scheduled || this.isTerminal(c.status),
);
if (allReady && scheduledCalls.length > 0) {
await Promise.all(scheduledCalls.map((c) => this._execute(c, signal)));
}
// 3. Finalize terminal calls
activeCalls = this.state.allActiveCalls;
let madeProgress = false;
for (const call of activeCalls) {
if (this.isTerminal(call.status)) {
this.state.finalizeCall(call.request.callId);
madeProgress = true;
}
}
// Check if any calls changed status during this iteration (excluding terminal finalization)
const currentStatuses = new Map(
activeCalls.map((c) => [c.request.callId, c.status]),
);
const anyStatusChanged = Array.from(initialStatuses.entries()).some(
([id, status]) => currentStatuses.get(id) !== status,
);
if (madeProgress || anyStatusChanged) {
return true;
}
// If we have active calls but NONE of them progressed, check if we are waiting for external events.
// States that are 'waiting' from the loop's perspective: awaiting_approval, executing.
const isWaitingForExternal = activeCalls.some(
(c) =>
c.status === CoreToolCallStatus.AwaitingApproval ||
c.status === CoreToolCallStatus.Executing,
);
if (isWaitingForExternal && this.state.isActive) {
// Yield to the event loop to allow external events (tool completion, user input) to progress.
await new Promise((resolve) => queueMicrotask(() => resolve(true)));
return true;
}
// If we are here, we have active calls (likely Validating or Scheduled) but none progressed.
// This is a stuck state.
return false;
}
private async _processValidatingCall(
@@ -437,8 +519,6 @@ export class Scheduler {
);
}
}
this.state.finalizeCall(active.request.callId);
}
// --- Phase 3: Single Call Orchestration ---
@@ -467,7 +547,6 @@ export class Scheduler {
errorType,
),
);
this.state.finalizeCall(callId);
return;
}
@@ -506,13 +585,11 @@ export class Scheduler {
CoreToolCallStatus.Cancelled,
'User denied execution.',
);
this.state.finalizeCall(callId);
this.state.cancelAllQueued('User cancelled operation');
return; // Skip execution
}
// Execution
await this._execute(callId, signal);
this.state.updateStatus(callId, CoreToolCallStatus.Scheduled);
}
// --- Sub-phase Handlers ---
@@ -520,13 +597,23 @@ export class Scheduler {
/**
* Executes the tool and records the result.
*/
private async _execute(callId: string, signal: AbortSignal): Promise<void> {
this.state.updateStatus(callId, CoreToolCallStatus.Scheduled);
if (signal.aborted) throw new Error('Operation cancelled');
private async _execute(
toolCall: ScheduledToolCall,
signal: AbortSignal,
): Promise<void> {
const callId = toolCall.request.callId;
if (signal.aborted) {
this.state.updateStatus(
callId,
CoreToolCallStatus.Cancelled,
'Operation cancelled',
);
return;
}
this.state.updateStatus(callId, CoreToolCallStatus.Executing);
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const activeCall = this.state.firstActiveCall as ExecutingToolCall;
const activeCall = this.state.getToolCall(callId) as ExecutingToolCall;
const result = await runWithToolCallContext(
{
@@ -0,0 +1,397 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mock,
type Mocked,
} from 'vitest';
import { randomUUID } from 'node:crypto';
vi.mock('node:crypto', () => ({
randomUUID: vi.fn(),
}));
vi.mock('../telemetry/trace.js', () => ({
runInDevTraceSpan: vi.fn(async (_opts, fn) =>
fn({ metadata: { input: {}, output: {} } }),
),
}));
vi.mock('../telemetry/loggers.js', () => ({
logToolCall: vi.fn(),
}));
vi.mock('../telemetry/types.js', () => ({
ToolCallEvent: vi.fn().mockImplementation((call) => ({ ...call })),
}));
import {
SchedulerStateManager,
type TerminalCallHandler,
} from './state-manager.js';
import { checkPolicy, updatePolicy } from './policy.js';
import { ToolExecutor } from './tool-executor.js';
import { ToolModificationHandler } from './tool-modifier.js';
vi.mock('./state-manager.js');
vi.mock('./confirmation.js');
vi.mock('./policy.js', async (importOriginal) => {
const actual = await importOriginal<typeof import('./policy.js')>();
return {
...actual,
checkPolicy: vi.fn(),
updatePolicy: vi.fn(),
};
});
vi.mock('./tool-executor.js');
vi.mock('./tool-modifier.js');
import { Scheduler } from './scheduler.js';
import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { PolicyEngine } from '../policy/policy-engine.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
import {
type AnyDeclarativeTool,
type AnyToolInvocation,
} from '../tools/tools.js';
import type {
ToolCallRequestInfo,
CompletedToolCall,
SuccessfulToolCall,
Status,
ToolCall,
} from './types.js';
import { ROOT_SCHEDULER_ID } from './types.js';
import type { EditorType } from '../utils/editor.js';
describe('Scheduler Parallel Execution', () => {
let scheduler: Scheduler;
let signal: AbortSignal;
let abortController: AbortController;
let mockConfig: Mocked<Config>;
let mockMessageBus: Mocked<MessageBus>;
let mockPolicyEngine: Mocked<PolicyEngine>;
let mockToolRegistry: Mocked<ToolRegistry>;
let getPreferredEditor: Mock<() => EditorType | undefined>;
let mockStateManager: Mocked<SchedulerStateManager>;
let mockExecutor: Mocked<ToolExecutor>;
let mockModifier: Mocked<ToolModificationHandler>;
// --- Request fixtures ---
// req1 and req2 name the two tools declared below with `isReadOnly: true`;
// req3 names the tool declared with `isReadOnly: false`.
const req1: ToolCallRequestInfo = {
callId: 'call-1',
name: 'read-tool-1',
args: { path: 'a.txt' },
isClientInitiated: false,
prompt_id: 'p1',
schedulerId: ROOT_SCHEDULER_ID,
};
const req2: ToolCallRequestInfo = {
callId: 'call-2',
name: 'read-tool-2',
args: { path: 'b.txt' },
isClientInitiated: false,
prompt_id: 'p1',
schedulerId: ROOT_SCHEDULER_ID,
};
const req3: ToolCallRequestInfo = {
callId: 'call-3',
name: 'write-tool',
args: { path: 'c.txt', content: 'hi' },
isClientInitiated: false,
prompt_id: 'p1',
schedulerId: ROOT_SCHEDULER_ID,
};
// --- Tool fixtures ---
// Minimal stand-ins cast to AnyDeclarativeTool: only `name`, `isReadOnly` and a
// mockable `build` are provided.
const readTool1 = {
name: 'read-tool-1',
isReadOnly: true,
build: vi.fn(),
} as unknown as AnyDeclarativeTool;
const readTool2 = {
name: 'read-tool-2',
isReadOnly: true,
build: vi.fn(),
} as unknown as AnyDeclarativeTool;
const writeTool = {
name: 'write-tool',
isReadOnly: false,
build: vi.fn(),
} as unknown as AnyDeclarativeTool;
// Shared invocation stub returned by every tool's `build` mock (wired up in
// beforeEach). `shouldConfirmExecute` resolves to false.
// NOTE(review): presumably false means "no confirmation prompt" — confirm
// against the scheduler's confirmation handling.
const mockInvocation = {
shouldConfirmExecute: vi.fn().mockResolvedValue(false),
};
// Builds a fresh Scheduler for every test, wired to mocked collaborators and
// to an in-memory fake of SchedulerStateManager.
beforeEach(() => {
// Fixed UUID so any id generated through randomUUID is the same on every run.
vi.mocked(randomUUID).mockReturnValue(
'uuid' as unknown as `${string}-${string}-${string}-${string}-${string}`,
);
abortController = new AbortController();
signal = abortController.signal;
mockPolicyEngine = {
check: vi.fn().mockResolvedValue({ decision: PolicyDecision.ALLOW }),
} as unknown as Mocked<PolicyEngine>;
// Registry resolves exactly the three fixture tools by name; any other name
// yields undefined.
mockToolRegistry = {
getTool: vi.fn((name) => {
if (name === 'read-tool-1') return readTool1;
if (name === 'read-tool-2') return readTool2;
if (name === 'write-tool') return writeTool;
return undefined;
}),
getAllToolNames: vi
.fn()
.mockReturnValue(['read-tool-1', 'read-tool-2', 'write-tool']),
} as unknown as Mocked<ToolRegistry>;
mockConfig = {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
isInteractive: vi.fn().mockReturnValue(true),
getEnableHooks: vi.fn().mockReturnValue(true),
setApprovalMode: vi.fn(),
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
} as unknown as Mocked<Config>;
mockMessageBus = {
publish: vi.fn(),
subscribe: vi.fn(),
} as unknown as Mocked<MessageBus>;
getPreferredEditor = vi.fn().mockReturnValue('vim');
// checkPolicy is a module-level mock: reset it, then default it to ALLOW.
vi.mocked(checkPolicy).mockReset();
vi.mocked(checkPolicy).mockResolvedValue({
decision: PolicyDecision.ALLOW,
rule: undefined,
});
vi.mocked(updatePolicy).mockReset();
// Backing stores for the fake state manager: calls waiting in the queue and
// calls that have been dequeued but not yet finalized.
const mockActiveCallsMap = new Map<string, ToolCall>();
const mockQueue: ToolCall[] = [];
// Set when the Scheduler constructs its state manager (see mockImplementation below).
let capturedTerminalHandler: TerminalCallHandler | undefined;
mockStateManager = {
// Stores shallow copies, so later status writes do not touch the caller's objects.
enqueue: vi.fn((calls: ToolCall[]) => {
mockQueue.push(...calls.map((c) => ({ ...c }) as ToolCall));
}),
// Moves the head of the queue into the active map and returns it.
dequeue: vi.fn(() => {
const next = mockQueue.shift();
if (next) mockActiveCallsMap.set(next.request.callId, next);
return next;
}),
peekQueue: vi.fn(() => mockQueue[0]),
getToolCall: vi.fn((id: string) => mockActiveCallsMap.get(id)),
// Overwrites only the status field of an active call; any further arguments are ignored.
updateStatus: vi.fn((id: string, status: Status) => {
const call = mockActiveCallsMap.get(id);
if (call) (call as unknown as { status: Status }).status = status;
}),
// Removes the call from the active map and passes it to the captured terminal handler.
finalizeCall: vi.fn((id: string) => {
const call = mockActiveCallsMap.get(id);
if (call) {
mockActiveCallsMap.delete(id);
capturedTerminalHandler?.(call as CompletedToolCall);
}
}),
updateArgs: vi.fn(),
setOutcome: vi.fn(),
// Empties the queue in place; active calls are left untouched.
cancelAllQueued: vi.fn(() => {
mockQueue.length = 0;
}),
clearBatch: vi.fn(),
} as unknown as Mocked<SchedulerStateManager>;
// Accessor properties are derived live from the backing stores above.
Object.defineProperty(mockStateManager, 'isActive', {
get: vi.fn(() => mockActiveCallsMap.size > 0),
configurable: true,
});
Object.defineProperty(mockStateManager, 'allActiveCalls', {
get: vi.fn(() => Array.from(mockActiveCallsMap.values())),
configurable: true,
});
Object.defineProperty(mockStateManager, 'queueLength', {
get: vi.fn(() => mockQueue.length),
configurable: true,
});
// First entry of the active map in insertion order, or undefined when empty.
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi.fn(() => mockActiveCallsMap.values().next().value),
configurable: true,
});
Object.defineProperty(mockStateManager, 'completedBatch', {
get: vi.fn().mockReturnValue([]),
configurable: true,
});
// Every `new SchedulerStateManager(...)` returns the fake and records the
// terminal-call handler passed as the third constructor argument.
vi.mocked(SchedulerStateManager).mockImplementation(
(_bus, _id, onTerminal) => {
capturedTerminalHandler = onTerminal;
return mockStateManager as unknown as SchedulerStateManager;
},
);
// execute has no default behavior here; each test installs its own implementation.
mockExecutor = { execute: vi.fn() } as unknown as Mocked<ToolExecutor>;
vi.mocked(ToolExecutor).mockReturnValue(
mockExecutor as unknown as Mocked<ToolExecutor>,
);
mockModifier = {
handleModifyWithEditor: vi.fn(),
applyInlineModify: vi.fn(),
} as unknown as Mocked<ToolModificationHandler>;
vi.mocked(ToolModificationHandler).mockReturnValue(
mockModifier as unknown as Mocked<ToolModificationHandler>,
);
// Construct the scheduler only after the constructor mocks above are in place.
scheduler = new Scheduler({
config: mockConfig,
messageBus: mockMessageBus,
getPreferredEditor,
schedulerId: 'root',
});
// All three fixture tools build the same shared invocation stub.
vi.mocked(readTool1.build).mockReturnValue(
mockInvocation as unknown as AnyToolInvocation,
);
vi.mocked(readTool2.build).mockReturnValue(
mockInvocation as unknown as AnyToolInvocation,
);
vi.mocked(writeTool.build).mockReturnValue(
mockInvocation as unknown as AnyToolInvocation,
);
});
// Clear recorded calls on every mock so call counts do not carry over between tests.
afterEach(() => {
vi.clearAllMocks();
});
it('should execute contiguous read-only tools in parallel', async () => {
  // Each fake execution records when it starts and when it ends, with a short
  // timer in between so overlapping executions become visible in the log.
  const events: string[] = [];
  mockExecutor.execute.mockImplementation(async ({ call }) => {
    const { callId } = call.request;
    events.push(`start-${callId}`);
    await new Promise((done) => setTimeout(done, 10));
    events.push(`end-${callId}`);
    const outcome = {
      status: 'success',
      response: { callId, responseParts: [] },
    };
    return outcome as unknown as SuccessfulToolCall;
  });

  // Two read tools followed by one write tool.
  await scheduler.schedule([req1, req2, req3], signal);

  // Both reads are started before either of them has ended.
  const [firstEvent, secondEvent] = events;
  expect(firstEvent).toBe('start-call-1');
  expect(secondEvent).toBe('start-call-2');

  // The reads may end in any order, but the write starts only after both ended.
  const writeStart = events.indexOf('start-call-3');
  for (const readEnd of ['end-call-1', 'end-call-2']) {
    expect(writeStart).toBeGreaterThan(events.indexOf(readEnd));
  }

  expect(events).toContain('end-call-3');
});
it('should execute non-read-only tools sequentially', async () => {
  // Start/end markers are appended in the order the executor observes them.
  const events: string[] = [];
  mockExecutor.execute.mockImplementation(async ({ call }) => {
    const { callId } = call.request;
    events.push(`start-${callId}`);
    await new Promise((done) => setTimeout(done, 10));
    events.push(`end-${callId}`);
    const outcome = {
      status: 'success',
      response: { callId, responseParts: [] },
    };
    return outcome as unknown as SuccessfulToolCall;
  });

  // The write tool (req3) comes first, so nothing may be batched with it.
  await scheduler.schedule([req3, req1], signal);

  // Strictly one after the other: each call ends before the next one starts.
  const strictlySequential = ['call-3', 'call-1'].flatMap((id) => [
    `start-${id}`,
    `end-${id}`,
  ]);
  expect(events).toEqual(strictlySequential);
});
it('should execute [WRITE, READ, READ] as [sequential, parallel]', async () => {
  // Start/end markers are appended in the order the executor observes them.
  const events: string[] = [];
  mockExecutor.execute.mockImplementation(async ({ call }) => {
    const { callId } = call.request;
    events.push(`start-${callId}`);
    await new Promise((done) => setTimeout(done, 10));
    events.push(`end-${callId}`);
    const outcome = {
      status: 'success',
      response: { callId, responseParts: [] },
    };
    return outcome as unknown as SuccessfulToolCall;
  });

  // req3 is the write; req1 and req2 are the reads.
  await scheduler.schedule([req3, req1, req2], signal);

  // The write runs alone, start to finish...
  const [writeStart, writeEnd] = events;
  expect(writeStart).toBe('start-call-3');
  expect(writeEnd).toBe('end-call-3');

  // ...then both reads are started back-to-back, in either order.
  const nextTwo = events.slice(2, 4);
  expect(nextTwo).toContain('start-call-1');
  expect(nextTwo).toContain('start-call-2');
});
it('should execute [READ, READ, WRITE, READ, READ] in three waves', async () => {
  // Each fake execution logs a start marker, waits on a short timer, then logs
  // an end marker, so overlap between executions shows up in the log order.
  const executionLog: string[] = [];
  mockExecutor.execute.mockImplementation(async ({ call }) => {
    const id = call.request.callId;
    executionLog.push(`start-${id}`);
    await new Promise((resolve) => setTimeout(resolve, 10));
    executionLog.push(`end-${id}`);
    return {
      status: 'success',
      response: { callId: id, responseParts: [] },
    } as unknown as SuccessfulToolCall;
  });

  // Two extra read requests that reuse the read-tool fixtures under new ids.
  const req4: ToolCallRequestInfo = { ...req1, callId: 'call-4' };
  const req5: ToolCallRequestInfo = { ...req2, callId: 'call-5' };

  await scheduler.schedule([req1, req2, req3, req4, req5], signal);

  // Wave 1: call-1, call-2 (parallel)
  expect(executionLog.slice(0, 2)).toContain('start-call-1');
  expect(executionLog.slice(0, 2)).toContain('start-call-2');

  // Wave 2: call-3 (sequential)
  // Must start after both call-1 and call-2 end
  const start3 = executionLog.indexOf('start-call-3');
  expect(start3).toBeGreaterThan(executionLog.indexOf('end-call-1'));
  expect(start3).toBeGreaterThan(executionLog.indexOf('end-call-2'));
  const end3 = executionLog.indexOf('end-call-3');
  expect(end3).toBeGreaterThan(start3);

  // Wave 3: call-4, call-5 (parallel)
  // Must start after call-3 ends
  expect(executionLog.indexOf('start-call-4')).toBeGreaterThan(end3);
  expect(executionLog.indexOf('start-call-5')).toBeGreaterThan(end3);
  // ...and must start together: the two entries right after call-3's end are
  // both starts. Without this check, a scheduler that ran call-4 and call-5
  // one after the other would still pass.
  const wave3Starts = executionLog.slice(end3 + 1, end3 + 3);
  expect(wave3Starts).toContain('start-call-4');
  expect(wave3Starts).toContain('start-call-5');

  // Every call ran to completion: five starts and five ends.
  expect(executionLog).toHaveLength(10);
});
});
@@ -78,10 +78,18 @@ export class SchedulerStateManager {
return next;
}
/**
 * Returns the first entry of the queue without removing it.
 *
 * @returns The entry at index 0, or `undefined` if there is none.
 */
peekQueue(): ToolCall | undefined {
  const head = this.queue[0];
  return head;
}
/** True while at least one call is held in `activeCalls`. */
get isActive(): boolean {
  return this.activeCalls.size !== 0;
}
/**
 * Snapshot of every call currently held in `activeCalls`.
 *
 * A new array is built on each access, so mutating the returned array does
 * not add to or remove from `activeCalls`.
 */
get allActiveCalls(): ToolCall[] {
  return [...this.activeCalls.values()];
}
/** Number of calls currently held in `activeCalls`. */
get activeCallCount(): number {
  const { size } = this.activeCalls;
  return size;
}