mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-14 16:10:59 -07:00
refactor(a2a): remove legacy CoreToolScheduler (#21955)
This commit is contained in:
@@ -4,25 +4,14 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||
import { Task } from './task.js';
|
||||
import {
|
||||
GeminiEventType,
|
||||
ApprovalMode,
|
||||
ToolConfirmationOutcome,
|
||||
type Config,
|
||||
type ToolCallRequestInfo,
|
||||
type GitService,
|
||||
type CompletedToolCall,
|
||||
type ToolCall,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { createMockConfig } from '../utils/testing_utils.js';
|
||||
import type { ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server';
|
||||
@@ -389,214 +378,6 @@ describe('Task', () => {
|
||||
);
|
||||
});
|
||||
|
||||
describe('_schedulerToolCallsUpdate', () => {
|
||||
let task: Task;
|
||||
type SpyInstance = ReturnType<typeof vi.spyOn>;
|
||||
let setTaskStateAndPublishUpdateSpy: SpyInstance;
|
||||
let mockConfig: Config;
|
||||
let mockEventBus: ExecutionEventBus;
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig = createMockConfig() as Config;
|
||||
mockEventBus = {
|
||||
publish: vi.fn(),
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
once: vi.fn(),
|
||||
removeAllListeners: vi.fn(),
|
||||
finished: vi.fn(),
|
||||
};
|
||||
|
||||
// @ts-expect-error - Calling private constructor
|
||||
task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
|
||||
|
||||
// Spy on the method we want to check calls for
|
||||
setTaskStateAndPublishUpdateSpy = vi.spyOn(
|
||||
task,
|
||||
'setTaskStateAndPublishUpdate',
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should set state to input-required when a tool is awaiting approval and none are executing', () => {
|
||||
const toolCalls = [
|
||||
{ request: { callId: '1' }, status: 'awaiting_approval' },
|
||||
] as ToolCall[];
|
||||
|
||||
// @ts-expect-error - Calling private method
|
||||
task._schedulerToolCallsUpdate(toolCalls);
|
||||
|
||||
// The last call should be the final state update
|
||||
expect(setTaskStateAndPublishUpdateSpy).toHaveBeenLastCalledWith(
|
||||
'input-required',
|
||||
{ kind: 'state-change' },
|
||||
undefined,
|
||||
undefined,
|
||||
true, // final: true
|
||||
);
|
||||
});
|
||||
|
||||
it('should NOT set state to input-required if a tool is awaiting approval but another is executing', () => {
|
||||
const toolCalls = [
|
||||
{ request: { callId: '1' }, status: 'awaiting_approval' },
|
||||
{ request: { callId: '2' }, status: 'executing' },
|
||||
] as ToolCall[];
|
||||
|
||||
// @ts-expect-error - Calling private method
|
||||
task._schedulerToolCallsUpdate(toolCalls);
|
||||
|
||||
// It will be called for status updates, but not with final: true
|
||||
const finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find(
|
||||
(call) => call[4] === true,
|
||||
);
|
||||
expect(finalCall).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should set state to input-required once an executing tool finishes, leaving one awaiting approval', () => {
|
||||
const initialToolCalls = [
|
||||
{ request: { callId: '1' }, status: 'awaiting_approval' },
|
||||
{ request: { callId: '2' }, status: 'executing' },
|
||||
] as ToolCall[];
|
||||
// @ts-expect-error - Calling private method
|
||||
task._schedulerToolCallsUpdate(initialToolCalls);
|
||||
|
||||
// No final call yet
|
||||
let finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find(
|
||||
(call) => call[4] === true,
|
||||
);
|
||||
expect(finalCall).toBeUndefined();
|
||||
|
||||
// Now, the executing tool finishes. The scheduler would call _resolveToolCall for it.
|
||||
// @ts-expect-error - Calling private method
|
||||
task._resolveToolCall('2');
|
||||
|
||||
// Then another update comes in for the awaiting tool (e.g., a re-check)
|
||||
const subsequentToolCalls = [
|
||||
{ request: { callId: '1' }, status: 'awaiting_approval' },
|
||||
] as ToolCall[];
|
||||
// @ts-expect-error - Calling private method
|
||||
task._schedulerToolCallsUpdate(subsequentToolCalls);
|
||||
|
||||
// NOW we should get the final call
|
||||
finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find(
|
||||
(call) => call[4] === true,
|
||||
);
|
||||
expect(finalCall).toBeDefined();
|
||||
expect(finalCall?.[0]).toBe('input-required');
|
||||
});
|
||||
|
||||
it('should NOT set state to input-required if skipFinalTrueAfterInlineEdit is true', () => {
|
||||
task.skipFinalTrueAfterInlineEdit = true;
|
||||
const toolCalls = [
|
||||
{ request: { callId: '1' }, status: 'awaiting_approval' },
|
||||
] as ToolCall[];
|
||||
|
||||
// @ts-expect-error - Calling private method
|
||||
task._schedulerToolCallsUpdate(toolCalls);
|
||||
|
||||
const finalCall = setTaskStateAndPublishUpdateSpy.mock.calls.find(
|
||||
(call) => call[4] === true,
|
||||
);
|
||||
expect(finalCall).toBeUndefined();
|
||||
});
|
||||
|
||||
describe('auto-approval', () => {
|
||||
it('should NOT publish ToolCallConfirmationEvent when autoExecute is true', () => {
|
||||
task.autoExecute = true;
|
||||
const onConfirmSpy = vi.fn();
|
||||
const toolCalls = [
|
||||
{
|
||||
request: { callId: '1' },
|
||||
status: 'awaiting_approval',
|
||||
correlationId: 'test-corr-id',
|
||||
confirmationDetails: {
|
||||
type: 'edit',
|
||||
onConfirm: onConfirmSpy,
|
||||
},
|
||||
},
|
||||
] as unknown as ToolCall[];
|
||||
|
||||
// @ts-expect-error - Calling private method
|
||||
task._schedulerToolCallsUpdate(toolCalls);
|
||||
|
||||
expect(onConfirmSpy).toHaveBeenCalledWith(
|
||||
ToolConfirmationOutcome.ProceedOnce,
|
||||
);
|
||||
const calls = (mockEventBus.publish as Mock).mock.calls;
|
||||
// Search if ToolCallConfirmationEvent was published
|
||||
const confEvent = calls.find(
|
||||
(call) =>
|
||||
call[0].metadata?.coderAgent?.kind ===
|
||||
CoderAgentEvent.ToolCallConfirmationEvent,
|
||||
);
|
||||
expect(confEvent).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should NOT publish ToolCallConfirmationEvent when approval mode is YOLO', () => {
|
||||
(mockConfig.getApprovalMode as Mock).mockReturnValue(ApprovalMode.YOLO);
|
||||
task.autoExecute = false;
|
||||
const onConfirmSpy = vi.fn();
|
||||
const toolCalls = [
|
||||
{
|
||||
request: { callId: '1' },
|
||||
status: 'awaiting_approval',
|
||||
correlationId: 'test-corr-id',
|
||||
confirmationDetails: {
|
||||
type: 'edit',
|
||||
onConfirm: onConfirmSpy,
|
||||
},
|
||||
},
|
||||
] as unknown as ToolCall[];
|
||||
|
||||
// @ts-expect-error - Calling private method
|
||||
task._schedulerToolCallsUpdate(toolCalls);
|
||||
|
||||
expect(onConfirmSpy).toHaveBeenCalledWith(
|
||||
ToolConfirmationOutcome.ProceedOnce,
|
||||
);
|
||||
const calls = (mockEventBus.publish as Mock).mock.calls;
|
||||
// Search if ToolCallConfirmationEvent was published
|
||||
const confEvent = calls.find(
|
||||
(call) =>
|
||||
call[0].metadata?.coderAgent?.kind ===
|
||||
CoderAgentEvent.ToolCallConfirmationEvent,
|
||||
);
|
||||
expect(confEvent).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should NOT auto-approve when autoExecute is false and mode is not YOLO', () => {
|
||||
task.autoExecute = false;
|
||||
(mockConfig.getApprovalMode as Mock).mockReturnValue(
|
||||
ApprovalMode.DEFAULT,
|
||||
);
|
||||
const onConfirmSpy = vi.fn();
|
||||
const toolCalls = [
|
||||
{
|
||||
request: { callId: '1' },
|
||||
status: 'awaiting_approval',
|
||||
confirmationDetails: { onConfirm: onConfirmSpy },
|
||||
},
|
||||
] as unknown as ToolCall[];
|
||||
|
||||
// @ts-expect-error - Calling private method
|
||||
task._schedulerToolCallsUpdate(toolCalls);
|
||||
|
||||
expect(onConfirmSpy).not.toHaveBeenCalled();
|
||||
const calls = (mockEventBus.publish as Mock).mock.calls;
|
||||
// Search if ToolCallConfirmationEvent was published
|
||||
const confEvent = calls.find(
|
||||
(call) =>
|
||||
call[0].metadata?.coderAgent?.kind ===
|
||||
CoderAgentEvent.ToolCallConfirmationEvent,
|
||||
);
|
||||
expect(confEvent).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('currentPromptId and promptCount', () => {
|
||||
it('should correctly initialize and update promptId and promptCount', async () => {
|
||||
const mockConfig = createMockConfig();
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
import {
|
||||
Scheduler,
|
||||
CoreToolScheduler,
|
||||
type GeminiClient,
|
||||
GeminiEventType,
|
||||
ToolConfirmationOutcome,
|
||||
@@ -69,37 +68,10 @@ import type { PartUnion, Part as genAiPart } from '@google/genai';
|
||||
|
||||
type UnionKeys<T> = T extends T ? keyof T : never;
|
||||
|
||||
type ConfirmationType = ToolCallConfirmationDetails['type'];
|
||||
|
||||
const VALID_CONFIRMATION_TYPES: readonly ConfirmationType[] = [
|
||||
'edit',
|
||||
'exec',
|
||||
'mcp',
|
||||
'info',
|
||||
'ask_user',
|
||||
'exit_plan_mode',
|
||||
] as const;
|
||||
|
||||
function isToolCallConfirmationDetails(
|
||||
value: unknown,
|
||||
): value is ToolCallConfirmationDetails {
|
||||
if (
|
||||
typeof value !== 'object' ||
|
||||
value === null ||
|
||||
!('onConfirm' in value) ||
|
||||
typeof value.onConfirm !== 'function' ||
|
||||
!('type' in value) ||
|
||||
typeof value.type !== 'string'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
return (VALID_CONFIRMATION_TYPES as readonly string[]).includes(value.type);
|
||||
}
|
||||
|
||||
export class Task {
|
||||
id: string;
|
||||
contextId: string;
|
||||
scheduler: Scheduler | CoreToolScheduler;
|
||||
scheduler: Scheduler;
|
||||
config: Config;
|
||||
geminiClient: GeminiClient;
|
||||
pendingToolConfirmationDetails: Map<string, ToolCallConfirmationDetails>;
|
||||
@@ -140,11 +112,7 @@ export class Task {
|
||||
this.contextId = contextId;
|
||||
this.config = config;
|
||||
|
||||
if (this.config.isEventDrivenSchedulerEnabled()) {
|
||||
this.scheduler = this.setupEventDrivenScheduler();
|
||||
} else {
|
||||
this.scheduler = this.createLegacyScheduler();
|
||||
}
|
||||
this.scheduler = this.setupEventDrivenScheduler();
|
||||
|
||||
this.geminiClient = this.config.getGeminiClient();
|
||||
this.pendingToolConfirmationDetails = new Map();
|
||||
@@ -260,11 +228,7 @@ export class Task {
|
||||
this.pendingToolCalls.clear();
|
||||
this.pendingCorrelationIds.clear();
|
||||
|
||||
if (this.scheduler instanceof Scheduler) {
|
||||
this.scheduler.cancelAll();
|
||||
} else {
|
||||
this.scheduler.cancelAll(new AbortController().signal);
|
||||
}
|
||||
this.scheduler.cancelAll();
|
||||
// Reset the promise for any future operations, ensuring it's in a clean state.
|
||||
this._resetToolCompletionPromise();
|
||||
}
|
||||
@@ -409,133 +373,6 @@ export class Task {
|
||||
this.eventBus?.publish(artifactEvent);
|
||||
}
|
||||
|
||||
private async _schedulerAllToolCallsComplete(
|
||||
completedToolCalls: CompletedToolCall[],
|
||||
): Promise<void> {
|
||||
logger.info(
|
||||
'[Task] All tool calls completed by scheduler (batch):',
|
||||
completedToolCalls.map((tc) => tc.request.callId),
|
||||
);
|
||||
this.completedToolCalls.push(...completedToolCalls);
|
||||
completedToolCalls.forEach((tc) => {
|
||||
this._resolveToolCall(tc.request.callId);
|
||||
});
|
||||
}
|
||||
|
||||
private _schedulerToolCallsUpdate(toolCalls: ToolCall[]): void {
|
||||
logger.info(
|
||||
'[Task] Scheduler tool calls updated:',
|
||||
toolCalls.map((tc) => `${tc.request.callId} (${tc.status})`),
|
||||
);
|
||||
|
||||
// Update state and send continuous, non-final updates
|
||||
toolCalls.forEach((tc) => {
|
||||
const previousStatus = this.pendingToolCalls.get(tc.request.callId);
|
||||
const hasChanged = previousStatus !== tc.status;
|
||||
|
||||
// Resolve tool call if it has reached a terminal state
|
||||
if (['success', 'error', 'cancelled'].includes(tc.status)) {
|
||||
this._resolveToolCall(tc.request.callId);
|
||||
} else {
|
||||
// This will update the map
|
||||
this._registerToolCall(tc.request.callId, tc.status);
|
||||
}
|
||||
|
||||
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
|
||||
const details = tc.confirmationDetails;
|
||||
if (isToolCallConfirmationDetails(details)) {
|
||||
this.pendingToolConfirmationDetails.set(tc.request.callId, details);
|
||||
}
|
||||
}
|
||||
|
||||
// Only send an update if the status has actually changed.
|
||||
if (hasChanged) {
|
||||
// Skip sending confirmation event if we are going to auto-approve it anyway
|
||||
if (
|
||||
tc.status === 'awaiting_approval' &&
|
||||
tc.confirmationDetails &&
|
||||
this.isYoloMatch
|
||||
) {
|
||||
logger.info(
|
||||
`[Task] Skipping ToolCallConfirmationEvent for ${tc.request.callId} due to YOLO mode.`,
|
||||
);
|
||||
} else {
|
||||
const coderAgentMessage: CoderAgentMessage =
|
||||
tc.status === 'awaiting_approval'
|
||||
? { kind: CoderAgentEvent.ToolCallConfirmationEvent }
|
||||
: { kind: CoderAgentEvent.ToolCallUpdateEvent };
|
||||
const message = this.toolStatusMessage(tc, this.id, this.contextId);
|
||||
|
||||
const event = this._createStatusUpdateEvent(
|
||||
this.taskState,
|
||||
coderAgentMessage,
|
||||
message,
|
||||
false, // Always false for these continuous updates
|
||||
);
|
||||
this.eventBus?.publish(event);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (this.isYoloMatch) {
|
||||
logger.info(
|
||||
'[Task] ' +
|
||||
(this.autoExecute ? '' : 'YOLO mode enabled. ') +
|
||||
'Auto-approving all tool calls.',
|
||||
);
|
||||
toolCalls.forEach((tc: ToolCall) => {
|
||||
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
|
||||
const details = tc.confirmationDetails;
|
||||
if (isToolCallConfirmationDetails(details)) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
details.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||
this.pendingToolConfirmationDetails.delete(tc.request.callId);
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const allPendingStatuses = Array.from(this.pendingToolCalls.values());
|
||||
const isAwaitingApproval = allPendingStatuses.some(
|
||||
(status) => status === 'awaiting_approval',
|
||||
);
|
||||
const isExecuting = allPendingStatuses.some(
|
||||
(status) => status === 'executing',
|
||||
);
|
||||
|
||||
// The turn is complete and requires user input if at least one tool
|
||||
// is waiting for the user's decision, and no other tool is actively
|
||||
// running in the background.
|
||||
if (
|
||||
isAwaitingApproval &&
|
||||
!isExecuting &&
|
||||
!this.skipFinalTrueAfterInlineEdit
|
||||
) {
|
||||
this.skipFinalTrueAfterInlineEdit = false;
|
||||
|
||||
// We don't need to send another message, just a final status update.
|
||||
this.setTaskStateAndPublishUpdate(
|
||||
'input-required',
|
||||
{ kind: CoderAgentEvent.StateChangeEvent },
|
||||
undefined,
|
||||
undefined,
|
||||
/*final*/ true,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private createLegacyScheduler(): CoreToolScheduler {
|
||||
const scheduler = new CoreToolScheduler({
|
||||
outputUpdateHandler: this._schedulerOutputUpdate.bind(this),
|
||||
onAllToolCallsComplete: this._schedulerAllToolCallsComplete.bind(this),
|
||||
onToolCallsUpdate: this._schedulerToolCallsUpdate.bind(this),
|
||||
getPreferredEditor: () => DEFAULT_GUI_EDITOR,
|
||||
config: this.config,
|
||||
});
|
||||
return scheduler;
|
||||
}
|
||||
|
||||
private messageBusListener?: (message: ToolCallsUpdateMessage) => void;
|
||||
|
||||
private setupEventDrivenScheduler(): Scheduler {
|
||||
@@ -564,9 +401,7 @@ export class Task {
|
||||
this.messageBusListener = undefined;
|
||||
}
|
||||
|
||||
if (this.scheduler instanceof Scheduler) {
|
||||
this.scheduler.dispose();
|
||||
}
|
||||
this.scheduler.dispose();
|
||||
}
|
||||
|
||||
private handleEventDrivenToolCallsUpdate(
|
||||
|
||||
@@ -106,8 +106,6 @@ export async function loadConfig(
|
||||
trustedFolder: true,
|
||||
extensionLoader,
|
||||
checkpointing,
|
||||
enableEventDrivenScheduler:
|
||||
settings.experimental?.enableEventDrivenScheduler ?? true,
|
||||
interactive: !isHeadlessMode(),
|
||||
enableInteractiveShell: !isHeadlessMode(),
|
||||
ptyInfo: 'auto',
|
||||
|
||||
@@ -40,9 +40,6 @@ export interface Settings {
|
||||
general?: {
|
||||
previewFeatures?: boolean;
|
||||
};
|
||||
experimental?: {
|
||||
enableEventDrivenScheduler?: boolean;
|
||||
};
|
||||
|
||||
// Git-aware file filtering settings
|
||||
fileFiltering?: {
|
||||
|
||||
@@ -65,7 +65,12 @@ vi.mock('../utils/logger.js', () => ({
|
||||
}));
|
||||
|
||||
let config: Config;
|
||||
const getToolRegistrySpy = vi.fn().mockReturnValue(ApprovalMode.DEFAULT);
|
||||
const getToolRegistrySpy = vi.fn().mockReturnValue({
|
||||
getTool: vi.fn(),
|
||||
getAllToolNames: vi.fn().mockReturnValue([]),
|
||||
getAllTools: vi.fn().mockReturnValue([]),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
});
|
||||
const getApprovalModeSpy = vi.fn();
|
||||
const getShellExecutionConfigSpy = vi.fn();
|
||||
const getExtensionsSpy = vi.fn();
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
tmpdir,
|
||||
type Config,
|
||||
type Storage,
|
||||
type ToolRegistry,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { createMockMessageBus } from '@google/gemini-cli-core/src/test-utils/mock-message-bus.js';
|
||||
import { expect, vi } from 'vitest';
|
||||
@@ -30,6 +31,10 @@ export function createMockConfig(
|
||||
const tmpDir = tmpdir();
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const mockConfig = {
|
||||
get toolRegistry(): ToolRegistry {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
return (this as unknown as Config).getToolRegistry();
|
||||
},
|
||||
getToolRegistry: vi.fn().mockReturnValue({
|
||||
getTool: vi.fn(),
|
||||
getAllToolNames: vi.fn().mockReturnValue([]),
|
||||
@@ -64,7 +69,6 @@ export function createMockConfig(
|
||||
getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'),
|
||||
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||
getUserTier: vi.fn(),
|
||||
isEventDrivenSchedulerEnabled: vi.fn().mockReturnValue(false),
|
||||
getMessageBus: vi.fn(),
|
||||
getPolicyEngine: vi.fn(),
|
||||
getEnableExtensionReloading: vi.fn().mockReturnValue(false),
|
||||
|
||||
Reference in New Issue
Block a user