diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index fa3b4526ad..58e4586887 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -29,6 +29,7 @@ import { PolicyDecision } from '../policy/types.js'; import { ToolConfirmationOutcome, type AnyDeclarativeTool, + Kind, } from '../tools/tools.js'; import { getToolSuggestion } from '../utils/tool-utils.js'; import { runInDevTraceSpan } from '../telemetry/trace.js'; @@ -427,11 +428,11 @@ export class Scheduler { return true; } - // If the first tool is read-only, batch all contiguous read-only tools. - if (next.tool?.isReadOnly) { + // If the first tool is parallelizable, batch all contiguous parallelizable tools. + if (this._isParallelizable(next.tool)) { while (this.state.queueLength > 0) { const peeked = this.state.peekQueue(); - if (peeked && peeked.tool?.isReadOnly) { + if (peeked && this._isParallelizable(peeked.tool)) { this.state.dequeue(); } else { break; @@ -516,6 +517,11 @@ export class Scheduler { return false; } + private _isParallelizable(tool?: AnyDeclarativeTool): boolean { + if (!tool) return false; + return tool.isReadOnly || tool.kind === Kind.Agent; + } + private async _processValidatingCall( active: ValidatingToolCall, signal: AbortSignal, diff --git a/packages/core/src/scheduler/scheduler_parallel.test.ts b/packages/core/src/scheduler/scheduler_parallel.test.ts index 9febf494c0..9633784323 100644 --- a/packages/core/src/scheduler/scheduler_parallel.test.ts +++ b/packages/core/src/scheduler/scheduler_parallel.test.ts @@ -70,6 +70,7 @@ import { ApprovalMode, PolicyDecision } from '../policy/types.js'; import { type AnyDeclarativeTool, type AnyToolInvocation, + Kind, } from '../tools/tools.js'; import type { ToolCallRequestInfo, @@ -124,18 +125,51 @@ describe('Scheduler Parallel Execution', () => { schedulerId: ROOT_SCHEDULER_ID, }; + const agentReq1: ToolCallRequestInfo = { + callId: 'agent-1', + name: 'agent-tool-1', + args: { query: 'do thing 1' }, + isClientInitiated: false, + prompt_id: 'p1', + schedulerId: ROOT_SCHEDULER_ID, + }; + + const agentReq2: ToolCallRequestInfo = { + callId: 'agent-2', + name: 'agent-tool-2', + args: { query: 'do thing 2' }, + isClientInitiated: false, + prompt_id: 'p1', + schedulerId: ROOT_SCHEDULER_ID, + }; + const readTool1 = { name: 'read-tool-1', + kind: Kind.Read, isReadOnly: true, build: vi.fn(), } as unknown as AnyDeclarativeTool; const readTool2 = { name: 'read-tool-2', + kind: Kind.Read, isReadOnly: true, build: vi.fn(), } as unknown as AnyDeclarativeTool; const writeTool = { name: 'write-tool', + kind: Kind.Execute, + isReadOnly: false, + build: vi.fn(), + } as unknown as AnyDeclarativeTool; + const agentTool1 = { + name: 'agent-tool-1', + kind: Kind.Agent, + isReadOnly: false, + build: vi.fn(), + } as unknown as AnyDeclarativeTool; + const agentTool2 = { + name: 'agent-tool-2', + kind: Kind.Agent, isReadOnly: false, build: vi.fn(), } as unknown as AnyDeclarativeTool; @@ -160,11 +194,19 @@ describe('Scheduler Parallel Execution', () => { if (name === 'read-tool-1') return readTool1; if (name === 'read-tool-2') return readTool2; if (name === 'write-tool') return writeTool; + if (name === 'agent-tool-1') return agentTool1; + if (name === 'agent-tool-2') return agentTool2; return undefined; }), getAllToolNames: vi .fn() - .mockReturnValue(['read-tool-1', 'read-tool-2', 'write-tool']), + .mockReturnValue([ + 'read-tool-1', + 'read-tool-2', + 'write-tool', + 'agent-tool-1', + 'agent-tool-2', + ]), } as unknown as Mocked; mockConfig = { @@ -279,6 +321,12 @@ describe('Scheduler Parallel Execution', () => { vi.mocked(writeTool.build).mockReturnValue( mockInvocation as unknown as AnyToolInvocation, ); + vi.mocked(agentTool1.build).mockReturnValue( + mockInvocation as unknown as AnyToolInvocation, + ); + vi.mocked(agentTool2.build).mockReturnValue( + mockInvocation as unknown as AnyToolInvocation, + ); }); afterEach(() => { @@ -418,4 +466,41 @@ describe('Scheduler Parallel Execution', () => { expect(executionLog.indexOf('start-call-4')).toBeGreaterThan(end3); expect(executionLog.indexOf('start-call-5')).toBeGreaterThan(end3); }); + + it('should execute [Agent, Agent, Sequential, Parallelizable] in three waves', async () => { + const executionLog: string[] = []; + + mockExecutor.execute.mockImplementation(async ({ call }) => { + const id = call.request.callId; + executionLog.push(`start-${id}`); + await new Promise((resolve) => setTimeout(resolve, 10)); + executionLog.push(`end-${id}`); + return { + status: 'success', + response: { callId: id, responseParts: [] }, + } as unknown as SuccessfulToolCall; + }); + + // Schedule: agentReq1 (Parallel), agentReq2 (Parallel), req3 (Sequential/Write), req1 (Parallel/Read) + await scheduler.schedule([agentReq1, agentReq2, req3, req1], signal); + + // Wave 1: agent-1, agent-2 (parallel) + expect(executionLog.slice(0, 2)).toContain('start-agent-1'); + expect(executionLog.slice(0, 2)).toContain('start-agent-2'); + + // Both agents must end before anything else starts + const endAgent1 = executionLog.indexOf('end-agent-1'); + const endAgent2 = executionLog.indexOf('end-agent-2'); + const wave1End = Math.max(endAgent1, endAgent2); + + // Wave 2: call-3 (sequential/write) + const start3 = executionLog.indexOf('start-call-3'); + const end3 = executionLog.indexOf('end-call-3'); + expect(start3).toBeGreaterThan(wave1End); + expect(end3).toBeGreaterThan(start3); + + // Wave 3: call-1 (parallelizable/read) + const start1 = executionLog.indexOf('start-call-1'); + expect(start1).toBeGreaterThan(end3); + }); });