/**
 * @license
 * Copyright 2026 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import {
  describe,
  it,
  expect,
  vi,
  beforeEach,
  afterEach,
  type Mock,
  type Mocked,
} from 'vitest';
import { randomUUID } from 'node:crypto';

// randomUUID is replaced by a bare mock; tests pin its return value.
vi.mock('node:crypto', () => ({
  randomUUID: vi.fn(),
}));

// Tracing stub: builds a metadata object from the span options' attributes
// and invokes the wrapped callback with it plus a no-op endSpan. It is created
// through vi.hoisted so the vi.mock factory below (which Vitest hoists above
// the imports) can reference it.
const runInDevTraceSpan = vi.hoisted(() =>
  vi.fn(async (opts, fn) => {
    const metadata = { name: '', attributes: opts.attributes || {} };
    return fn({
      metadata,
      endSpan: vi.fn(),
    });
  }),
);

vi.mock('../telemetry/trace.js', () => ({
  runInDevTraceSpan,
}));

vi.mock('../telemetry/loggers.js', () => ({
  logToolCall: vi.fn(),
}));

// The ToolCallEvent stub simply returns a shallow copy of its argument.
vi.mock('../telemetry/types.js', () => ({
  ToolCallEvent: vi.fn().mockImplementation((call) => ({ ...call })),
}));

import {
  SchedulerStateManager,
  type TerminalCallHandler,
} from './state-manager.js';
import { checkPolicy, updatePolicy } from './policy.js';
import { ToolExecutor } from './tool-executor.js';
import { ToolModificationHandler } from './tool-modifier.js';

vi.mock('./state-manager.js');
vi.mock('./confirmation.js');
// Keep the real policy module but swap checkPolicy/updatePolicy for mocks.
vi.mock('./policy.js', async (importOriginal) => {
  // Without the type argument importOriginal() resolves to `unknown`, which
  // cannot be spread into the returned module object.
  const actual = await importOriginal<typeof import('./policy.js')>();
  return {
    ...actual,
    checkPolicy: vi.fn(),
    updatePolicy: vi.fn(),
  };
});
vi.mock('./tool-executor.js');
vi.mock('./tool-modifier.js');

import { Scheduler } from './scheduler.js';
import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { PolicyEngine } from '../policy/policy-engine.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
import {
  type AnyDeclarativeTool,
  type AnyToolInvocation,
  Kind,
} from '../tools/tools.js';
import {
  ROOT_SCHEDULER_ID,
  type ToolCallRequestInfo,
  type CompletedToolCall,
  type SuccessfulToolCall,
  type Status,
  type ToolCall,
} from './types.js';
import { GeminiCliOperation } from '../telemetry/constants.js';
import type { EditorType } from '../utils/editor.js';
describe('Scheduler Parallel Execution', () => {
  let scheduler: Scheduler;
  let signal: AbortSignal;
  let abortController: AbortController;
  let mockConfig: Mocked<Config>;
  let mockMessageBus: Mocked<MessageBus>;
  let mockPolicyEngine: Mocked<PolicyEngine>;
  let mockToolRegistry: Mocked<ToolRegistry>;
  let getPreferredEditor: Mock<() => EditorType | undefined>;
  let mockStateManager: Mocked<SchedulerStateManager>;
  let mockExecutor: Mocked<ToolExecutor>;
  let mockModifier: Mocked<ToolModificationHandler>;

  /** Simulated duration of every mocked tool execution. */
  const EXECUTION_DELAY_MS = 10;

  /** Builds a non-client-initiated root-scheduler request for prompt 'p1'. */
  const makeRequest = (
    callId: string,
    name: string,
    args: ToolCallRequestInfo['args'],
  ): ToolCallRequestInfo => ({
    callId,
    name,
    args,
    isClientInitiated: false,
    prompt_id: 'p1',
    schedulerId: ROOT_SCHEDULER_ID,
  });

  const req1 = makeRequest('call-1', 'read-tool-1', { path: 'a.txt' });
  const req2 = makeRequest('call-2', 'read-tool-2', { path: 'b.txt' });
  const req3 = makeRequest('call-3', 'write-tool', {
    path: 'c.txt',
    content: 'hi',
  });
  const agentReq1 = makeRequest('agent-1', 'agent-tool-1', {
    query: 'do thing 1',
  });
  const agentReq2 = makeRequest('agent-2', 'agent-tool-2', {
    query: 'do thing 2',
  });

  /** Builds a declarative-tool stub whose `build` is a fresh vi.fn(). */
  const fakeTool = (
    name: string,
    kind: AnyDeclarativeTool['kind'],
    isReadOnly: boolean,
  ): AnyDeclarativeTool =>
    ({ name, kind, isReadOnly, build: vi.fn() }) as unknown as AnyDeclarativeTool;

  const readTool1 = fakeTool('read-tool-1', Kind.Read, true);
  const readTool2 = fakeTool('read-tool-2', Kind.Read, true);
  const writeTool = fakeTool('write-tool', Kind.Execute, false);
  const agentTool1 = fakeTool('agent-tool-1', Kind.Agent, false);
  const agentTool2 = fakeTool('agent-tool-2', Kind.Agent, false);
  /** Every tool the mocked registry knows about, in registration order. */
  const allTools = [readTool1, readTool2, writeTool, agentTool1, agentTool2];

  const mockInvocation = {
    shouldConfirmExecute: vi.fn().mockResolvedValue(false),
  };

  /**
   * Makes the mocked executor log `start-<callId>`, wait EXECUTION_DELAY_MS,
   * log `end-<callId>`, and report success. Returns the log it appends to.
   */
  function recordExecutions(): string[] {
    const log: string[] = [];
    mockExecutor.execute.mockImplementation(async ({ call }) => {
      const id = call.request.callId;
      log.push(`start-${id}`);
      await new Promise((resolve) => setTimeout(resolve, EXECUTION_DELAY_MS));
      log.push(`end-${id}`);
      return {
        status: 'success',
        response: { callId: id, responseParts: [] },
      } as unknown as SuccessfulToolCall;
    });
    return log;
  }

  /**
   * Asserts that each id logged exactly one start and one end and that nothing
   * else was logged. Run before index comparisons so a missing entry (indexOf
   * === -1) cannot make an ordering assertion pass vacuously.
   */
  function expectEachRanOnce(log: string[], ids: string[]): void {
    const expected = ids.flatMap((id) => [`start-${id}`, `end-${id}`]);
    expect([...log].sort()).toEqual([...expected].sort());
  }

  beforeEach(() => {
    vi.mocked(randomUUID).mockReturnValue(
      'uuid' as unknown as ReturnType<typeof randomUUID>,
    );
    abortController = new AbortController();
    signal = abortController.signal;

    mockPolicyEngine = {
      check: vi.fn().mockResolvedValue({ decision: PolicyDecision.ALLOW }),
    } as unknown as Mocked<PolicyEngine>;

    mockToolRegistry = {
      getTool: vi.fn((name: string) =>
        allTools.find((tool) => tool.name === name),
      ),
      getAllToolNames: vi
        .fn()
        .mockReturnValue(allTools.map((tool) => tool.name)),
    } as unknown as Mocked<ToolRegistry>;

    mockConfig = {
      getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
      getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
      isInteractive: vi.fn().mockReturnValue(true),
      getEnableHooks: vi.fn().mockReturnValue(true),
      setApprovalMode: vi.fn(),
      getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
    } as unknown as Mocked<Config>;

    mockMessageBus = {
      publish: vi.fn(),
      subscribe: vi.fn(),
    } as unknown as Mocked<MessageBus>;

    getPreferredEditor = vi.fn().mockReturnValue('vim');

    vi.mocked(checkPolicy).mockReset();
    vi.mocked(checkPolicy).mockResolvedValue({
      decision: PolicyDecision.ALLOW,
      rule: undefined,
    });
    vi.mocked(updatePolicy).mockReset();

    // In-memory stand-in for SchedulerStateManager: a FIFO queue of pending
    // calls plus a map of dequeued (active) calls. finalizeCall removes the
    // call and forwards it to the terminal handler the Scheduler passed to
    // the mocked constructor.
    const mockActiveCallsMap = new Map<string, ToolCall>();
    const mockQueue: ToolCall[] = [];
    let capturedTerminalHandler: TerminalCallHandler | undefined;

    mockStateManager = {
      enqueue: vi.fn((calls: ToolCall[]) => {
        mockQueue.push(...calls.map((c) => ({ ...c }) as ToolCall));
      }),
      dequeue: vi.fn(() => {
        const next = mockQueue.shift();
        if (next) mockActiveCallsMap.set(next.request.callId, next);
        return next;
      }),
      peekQueue: vi.fn(() => mockQueue[0]),
      getToolCall: vi.fn((id: string) => mockActiveCallsMap.get(id)),
      updateStatus: vi.fn((id: string, status: Status) => {
        const call = mockActiveCallsMap.get(id);
        if (call) (call as unknown as { status: Status }).status = status;
      }),
      finalizeCall: vi.fn((id: string) => {
        const call = mockActiveCallsMap.get(id);
        if (call) {
          mockActiveCallsMap.delete(id);
          capturedTerminalHandler?.(call as CompletedToolCall);
        }
      }),
      updateArgs: vi.fn(),
      setOutcome: vi.fn(),
      cancelAllQueued: vi.fn(() => {
        mockQueue.length = 0;
      }),
      clearBatch: vi.fn(),
    } as unknown as Mocked<SchedulerStateManager>;

    // Read-only properties are exposed as live getters over the local state.
    Object.defineProperty(mockStateManager, 'isActive', {
      get: vi.fn(() => mockActiveCallsMap.size > 0),
      configurable: true,
    });
    Object.defineProperty(mockStateManager, 'allActiveCalls', {
      get: vi.fn(() => Array.from(mockActiveCallsMap.values())),
      configurable: true,
    });
    Object.defineProperty(mockStateManager, 'queueLength', {
      get: vi.fn(() => mockQueue.length),
      configurable: true,
    });
    Object.defineProperty(mockStateManager, 'firstActiveCall', {
      get: vi.fn(() => mockActiveCallsMap.values().next().value),
      configurable: true,
    });
    Object.defineProperty(mockStateManager, 'completedBatch', {
      get: vi.fn().mockReturnValue([]),
      configurable: true,
    });

    vi.mocked(SchedulerStateManager).mockImplementation(
      (_bus, _id, onTerminal) => {
        capturedTerminalHandler = onTerminal;
        return mockStateManager as unknown as SchedulerStateManager;
      },
    );

    mockExecutor = { execute: vi.fn() } as unknown as Mocked<ToolExecutor>;
    vi.mocked(ToolExecutor).mockReturnValue(
      mockExecutor as unknown as Mocked<ToolExecutor>,
    );

    mockModifier = {
      handleModifyWithEditor: vi.fn(),
      applyInlineModify: vi.fn(),
    } as unknown as Mocked<ToolModificationHandler>;
    vi.mocked(ToolModificationHandler).mockReturnValue(
      mockModifier as unknown as Mocked<ToolModificationHandler>,
    );

    scheduler = new Scheduler({
      config: mockConfig,
      messageBus: mockMessageBus,
      getPreferredEditor,
      schedulerId: 'root',
    });

    // Every tool builds the same invocation, which never asks for confirmation.
    for (const tool of allTools) {
      vi.mocked(tool.build).mockReturnValue(
        mockInvocation as unknown as AnyToolInvocation,
      );
    }
  });

  afterEach(() => {
    vi.clearAllMocks();
  });

  it('should execute contiguous read-only tools in parallel', async () => {
    const executionLog = recordExecutions();

    // Capture the metadata handed to the first traced span during the real
    // run, rather than re-invoking the span callback afterwards (which would
    // execute the whole batch a second time).
    const spanMetadata: object[] = [];
    runInDevTraceSpan.mockImplementationOnce(async (opts, fn) => {
      const metadata = { name: '', attributes: opts.attributes || {} };
      spanMetadata.push(metadata);
      return fn({ metadata, endSpan: vi.fn() });
    });

    // Schedule 2 read tools and 1 write tool
    await scheduler.schedule([req1, req2, req3], signal);

    expectEachRanOnce(executionLog, ['call-1', 'call-2', 'call-3']);

    // Parallel read tools should start together
    expect(executionLog[0]).toBe('start-call-1');
    expect(executionLog[1]).toBe('start-call-2');

    // They can finish in any order, but both must finish before call-3 starts
    const start3 = executionLog.indexOf('start-call-3');
    expect(start3).toBeGreaterThan(executionLog.indexOf('end-call-1'));
    expect(start3).toBeGreaterThan(executionLog.indexOf('end-call-2'));

    expect(runInDevTraceSpan).toHaveBeenCalledWith(
      expect.objectContaining({
        operation: GeminiCliOperation.ScheduleToolCalls,
      }),
      expect.any(Function),
    );
    expect(spanMetadata[0]).toMatchObject({
      input: [req1, req2, req3],
    });
  });

  it('should execute non-read-only tools sequentially', async () => {
    const executionLog = recordExecutions();

    // req3 is NOT read-only
    await scheduler.schedule([req3, req1], signal);

    // Should be strictly sequential
    expect(executionLog).toEqual([
      'start-call-3',
      'end-call-3',
      'start-call-1',
      'end-call-1',
    ]);
  });

  it('should execute [WRITE, READ, READ] as [sequential, parallel]', async () => {
    const executionLog = recordExecutions();

    // req3 (WRITE), req1 (READ), req2 (READ)
    await scheduler.schedule([req3, req1, req2], signal);

    expectEachRanOnce(executionLog, ['call-3', 'call-1', 'call-2']);

    // Order should be:
    // 1. write starts and ends
    // 2. read1 and read2 start together (parallel)
    expect(executionLog[0]).toBe('start-call-3');
    expect(executionLog[1]).toBe('end-call-3');
    expect(executionLog.slice(2, 4)).toContain('start-call-1');
    expect(executionLog.slice(2, 4)).toContain('start-call-2');
  });

  it('should execute [READ, READ, WRITE, READ, READ] in three waves', async () => {
    const executionLog = recordExecutions();

    const req4: ToolCallRequestInfo = { ...req1, callId: 'call-4' };
    const req5: ToolCallRequestInfo = { ...req2, callId: 'call-5' };

    await scheduler.schedule([req1, req2, req3, req4, req5], signal);

    expectEachRanOnce(executionLog, [
      'call-1',
      'call-2',
      'call-3',
      'call-4',
      'call-5',
    ]);

    // Wave 1: call-1, call-2 (parallel)
    expect(executionLog.slice(0, 2)).toContain('start-call-1');
    expect(executionLog.slice(0, 2)).toContain('start-call-2');

    // Wave 2: call-3 (sequential)
    // Must start after both call-1 and call-2 end
    const start3 = executionLog.indexOf('start-call-3');
    expect(start3).toBeGreaterThan(executionLog.indexOf('end-call-1'));
    expect(start3).toBeGreaterThan(executionLog.indexOf('end-call-2'));
    const end3 = executionLog.indexOf('end-call-3');
    expect(end3).toBeGreaterThan(start3);

    // Wave 3: call-4, call-5 (parallel)
    // Must start after call-3 ends, and both must start before either ends
    const start4 = executionLog.indexOf('start-call-4');
    const start5 = executionLog.indexOf('start-call-5');
    expect(start4).toBeGreaterThan(end3);
    expect(start5).toBeGreaterThan(end3);
    expect(Math.max(start4, start5)).toBeLessThan(
      Math.min(
        executionLog.indexOf('end-call-4'),
        executionLog.indexOf('end-call-5'),
      ),
    );
  });

  it('should execute [Agent, Agent, Sequential, Parallelizable] in three waves', async () => {
    const executionLog = recordExecutions();

    // Schedule: agentReq1 (Parallel), agentReq2 (Parallel), req3 (Sequential/Write), req1 (Parallel/Read)
    await scheduler.schedule([agentReq1, agentReq2, req3, req1], signal);

    expectEachRanOnce(executionLog, ['agent-1', 'agent-2', 'call-3', 'call-1']);

    // Wave 1: agent-1, agent-2 (parallel)
    expect(executionLog.slice(0, 2)).toContain('start-agent-1');
    expect(executionLog.slice(0, 2)).toContain('start-agent-2');

    // Both agents must end before anything else starts
    const wave1End = Math.max(
      executionLog.indexOf('end-agent-1'),
      executionLog.indexOf('end-agent-2'),
    );

    // Wave 2: call-3 (sequential/write)
    const start3 = executionLog.indexOf('start-call-3');
    const end3 = executionLog.indexOf('end-call-3');
    expect(start3).toBeGreaterThan(wave1End);
    expect(end3).toBeGreaterThan(start3);

    // Wave 3: call-1 (parallelizable/read)
    expect(executionLog.indexOf('start-call-1')).toBeGreaterThan(end3);
  });
});