mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-15 00:21:09 -07:00
feat(core): enable contiguous parallel admission for Kind.Agent tools (#20583)
This commit is contained in:
@@ -29,6 +29,7 @@ import { PolicyDecision } from '../policy/types.js';
|
||||
import {
|
||||
ToolConfirmationOutcome,
|
||||
type AnyDeclarativeTool,
|
||||
Kind,
|
||||
} from '../tools/tools.js';
|
||||
import { getToolSuggestion } from '../utils/tool-utils.js';
|
||||
import { runInDevTraceSpan } from '../telemetry/trace.js';
|
||||
@@ -427,11 +428,11 @@ export class Scheduler {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the first tool is read-only, batch all contiguous read-only tools.
|
||||
if (next.tool?.isReadOnly) {
|
||||
// If the first tool is parallelizable, batch all contiguous parallelizable tools.
|
||||
if (this._isParallelizable(next.tool)) {
|
||||
while (this.state.queueLength > 0) {
|
||||
const peeked = this.state.peekQueue();
|
||||
if (peeked && peeked.tool?.isReadOnly) {
|
||||
if (peeked && this._isParallelizable(peeked.tool)) {
|
||||
this.state.dequeue();
|
||||
} else {
|
||||
break;
|
||||
@@ -516,6 +517,11 @@ export class Scheduler {
|
||||
return false;
|
||||
}
|
||||
|
||||
private _isParallelizable(tool?: AnyDeclarativeTool): boolean {
|
||||
if (!tool) return false;
|
||||
return tool.isReadOnly || tool.kind === Kind.Agent;
|
||||
}
|
||||
|
||||
private async _processValidatingCall(
|
||||
active: ValidatingToolCall,
|
||||
signal: AbortSignal,
|
||||
|
||||
@@ -70,6 +70,7 @@ import { ApprovalMode, PolicyDecision } from '../policy/types.js';
|
||||
import {
|
||||
type AnyDeclarativeTool,
|
||||
type AnyToolInvocation,
|
||||
Kind,
|
||||
} from '../tools/tools.js';
|
||||
import type {
|
||||
ToolCallRequestInfo,
|
||||
@@ -124,18 +125,51 @@ describe('Scheduler Parallel Execution', () => {
|
||||
schedulerId: ROOT_SCHEDULER_ID,
|
||||
};
|
||||
|
||||
const agentReq1: ToolCallRequestInfo = {
|
||||
callId: 'agent-1',
|
||||
name: 'agent-tool-1',
|
||||
args: { query: 'do thing 1' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
schedulerId: ROOT_SCHEDULER_ID,
|
||||
};
|
||||
|
||||
const agentReq2: ToolCallRequestInfo = {
|
||||
callId: 'agent-2',
|
||||
name: 'agent-tool-2',
|
||||
args: { query: 'do thing 2' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
schedulerId: ROOT_SCHEDULER_ID,
|
||||
};
|
||||
|
||||
const readTool1 = {
|
||||
name: 'read-tool-1',
|
||||
kind: Kind.Read,
|
||||
isReadOnly: true,
|
||||
build: vi.fn(),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
const readTool2 = {
|
||||
name: 'read-tool-2',
|
||||
kind: Kind.Read,
|
||||
isReadOnly: true,
|
||||
build: vi.fn(),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
const writeTool = {
|
||||
name: 'write-tool',
|
||||
kind: Kind.Execute,
|
||||
isReadOnly: false,
|
||||
build: vi.fn(),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
const agentTool1 = {
|
||||
name: 'agent-tool-1',
|
||||
kind: Kind.Agent,
|
||||
isReadOnly: false,
|
||||
build: vi.fn(),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
const agentTool2 = {
|
||||
name: 'agent-tool-2',
|
||||
kind: Kind.Agent,
|
||||
isReadOnly: false,
|
||||
build: vi.fn(),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
@@ -160,11 +194,19 @@ describe('Scheduler Parallel Execution', () => {
|
||||
if (name === 'read-tool-1') return readTool1;
|
||||
if (name === 'read-tool-2') return readTool2;
|
||||
if (name === 'write-tool') return writeTool;
|
||||
if (name === 'agent-tool-1') return agentTool1;
|
||||
if (name === 'agent-tool-2') return agentTool2;
|
||||
return undefined;
|
||||
}),
|
||||
getAllToolNames: vi
|
||||
.fn()
|
||||
.mockReturnValue(['read-tool-1', 'read-tool-2', 'write-tool']),
|
||||
.mockReturnValue([
|
||||
'read-tool-1',
|
||||
'read-tool-2',
|
||||
'write-tool',
|
||||
'agent-tool-1',
|
||||
'agent-tool-2',
|
||||
]),
|
||||
} as unknown as Mocked<ToolRegistry>;
|
||||
|
||||
mockConfig = {
|
||||
@@ -279,6 +321,12 @@ describe('Scheduler Parallel Execution', () => {
|
||||
vi.mocked(writeTool.build).mockReturnValue(
|
||||
mockInvocation as unknown as AnyToolInvocation,
|
||||
);
|
||||
vi.mocked(agentTool1.build).mockReturnValue(
|
||||
mockInvocation as unknown as AnyToolInvocation,
|
||||
);
|
||||
vi.mocked(agentTool2.build).mockReturnValue(
|
||||
mockInvocation as unknown as AnyToolInvocation,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -418,4 +466,41 @@ describe('Scheduler Parallel Execution', () => {
|
||||
expect(executionLog.indexOf('start-call-4')).toBeGreaterThan(end3);
|
||||
expect(executionLog.indexOf('start-call-5')).toBeGreaterThan(end3);
|
||||
});
|
||||
|
||||
it('should execute [Agent, Agent, Sequential, Parallelizable] in three waves', async () => {
|
||||
const executionLog: string[] = [];
|
||||
|
||||
mockExecutor.execute.mockImplementation(async ({ call }) => {
|
||||
const id = call.request.callId;
|
||||
executionLog.push(`start-${id}`);
|
||||
await new Promise<void>((resolve) => setTimeout(resolve, 10));
|
||||
executionLog.push(`end-${id}`);
|
||||
return {
|
||||
status: 'success',
|
||||
response: { callId: id, responseParts: [] },
|
||||
} as unknown as SuccessfulToolCall;
|
||||
});
|
||||
|
||||
// Schedule: agentReq1 (Parallel), agentReq2 (Parallel), req3 (Sequential/Write), req1 (Parallel/Read)
|
||||
await scheduler.schedule([agentReq1, agentReq2, req3, req1], signal);
|
||||
|
||||
// Wave 1: agent-1, agent-2 (parallel)
|
||||
expect(executionLog.slice(0, 2)).toContain('start-agent-1');
|
||||
expect(executionLog.slice(0, 2)).toContain('start-agent-2');
|
||||
|
||||
// Both agents must end before anything else starts
|
||||
const endAgent1 = executionLog.indexOf('end-agent-1');
|
||||
const endAgent2 = executionLog.indexOf('end-agent-2');
|
||||
const wave1End = Math.max(endAgent1, endAgent2);
|
||||
|
||||
// Wave 2: call-3 (sequential/write)
|
||||
const start3 = executionLog.indexOf('start-call-3');
|
||||
const end3 = executionLog.indexOf('end-call-3');
|
||||
expect(start3).toBeGreaterThan(wave1End);
|
||||
expect(end3).toBeGreaterThan(start3);
|
||||
|
||||
// Wave 3: call-1 (parallelizable/read)
|
||||
const start1 = executionLog.indexOf('start-call-1');
|
||||
expect(start1).toBeGreaterThan(end3);
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user