diff --git a/packages/a2a-server/src/agent.test.ts b/packages/a2a-server/src/agent.test.ts index 04160d3f23..b396d58a12 100644 --- a/packages/a2a-server/src/agent.test.ts +++ b/packages/a2a-server/src/agent.test.ts @@ -33,6 +33,7 @@ import { assertTaskCreationAndWorkingStatus, createStreamMessageRequest, MockTool, + createMockConfig, } from './testing_utils.js'; const mockToolConfirmationFn = async () => @@ -68,26 +69,11 @@ vi.mock('./config.js', async () => { return { ...actual, loadConfig: vi.fn().mockImplementation(async () => { - config = { + const mockConfig = createMockConfig({ getToolRegistry: getToolRegistrySpy, getApprovalMode: getApprovalModeSpy, - getIdeMode: vi.fn().mockReturnValue(false), - getAllowedTools: vi.fn().mockReturnValue([]), - getIdeClient: vi.fn(), - getWorkspaceContext: vi.fn().mockReturnValue({ - isPathWithinWorkspace: () => true, - }), - getTargetDir: () => '/test', - getGeminiClient: vi.fn(), - getDebugMode: vi.fn().mockReturnValue(false), - getContentGeneratorConfig: vi - .fn() - .mockReturnValue({ model: 'gemini-pro' }), - getModel: vi.fn().mockReturnValue('gemini-pro'), - getUsageStatisticsEnabled: vi.fn().mockReturnValue(false), - setFlashFallbackHandler: vi.fn(), - initialize: vi.fn().mockResolvedValue(undefined), - } as unknown as Config; + }); + config = mockConfig as Config; return config; }), }; diff --git a/packages/a2a-server/src/task.test.ts b/packages/a2a-server/src/task.test.ts new file mode 100644 index 0000000000..6c14392f8c --- /dev/null +++ b/packages/a2a-server/src/task.test.ts @@ -0,0 +1,59 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi } from 'vitest'; +import { Task } from './task.js'; +import type { Config, ToolCallRequestInfo } from '@google/gemini-cli-core'; +import { createMockConfig } from './testing_utils.js'; +import type { ExecutionEventBus } from '@a2a-js/sdk/server'; + +describe('Task', () => { + it('scheduleToolCalls should not modify the input requests array', async () => { + const mockConfig = createMockConfig(); + + const mockEventBus: ExecutionEventBus = { + publish: vi.fn(), + on: vi.fn(), + off: vi.fn(), + once: vi.fn(), + removeAllListeners: vi.fn(), + finished: vi.fn(), + }; + + // The Task constructor is private. We'll bypass it for this unit test. + // @ts-expect-error - Calling private constructor for test purposes. + const task = new Task( + 'task-id', + 'context-id', + mockConfig as Config, + mockEventBus, + ); + + task['setTaskStateAndPublishUpdate'] = vi.fn(); + task['getProposedContent'] = vi.fn().mockResolvedValue('new content'); + + const requests: ToolCallRequestInfo[] = [ + { + callId: '1', + name: 'replace', + args: { + file_path: 'test.txt', + old_string: 'old', + new_string: 'new', + }, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + ]; + + const originalRequests = JSON.parse(JSON.stringify(requests)); + const abortController = new AbortController(); + + await task.scheduleToolCalls(requests, abortController.signal); + + expect(requests).toEqual(originalRequests); + }); +}); diff --git a/packages/a2a-server/src/task.ts b/packages/a2a-server/src/task.ts index dfcedb46c6..dbdbbf495b 100644 --- a/packages/a2a-server/src/task.ts +++ b/packages/a2a-server/src/task.ts @@ -520,30 +520,36 @@ export class Task { return; } - for (const request of requests) { - if ( - !request.args['newContent'] && - request.name === 'replace' && - request.args && - request.args['file_path'] && - request.args['old_string'] && - request.args['new_string'] - ) { - request.args['newContent'] = await this.getProposedContent( - request.args['file_path'] as string, - request.args['old_string'] as string, - request.args['new_string'] as string, - ); - } - } + const updatedRequests = await Promise.all( + requests.map(async (request) => { + if ( + request.name === 'replace' && + request.args && + !request.args['newContent'] && + request.args['file_path'] && + request.args['old_string'] && + request.args['new_string'] + ) { + const newContent = await this.getProposedContent( + request.args['file_path'] as string, + request.args['old_string'] as string, + request.args['new_string'] as string, + ); + return { ...request, args: { ...request.args, newContent } }; + } + return request; + }), + ); - logger.info(`[Task] Scheduling batch of ${requests.length} tool calls.`); + logger.info( + `[Task] Scheduling batch of ${updatedRequests.length} tool calls.`, + ); const stateChange: StateChange = { kind: CoderAgentEvent.StateChangeEvent, }; this.setTaskStateAndPublishUpdate('working', stateChange); - await this.scheduler.schedule(requests, abortSignal); + await this.scheduler.schedule(updatedRequests, abortSignal); } async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise { diff --git a/packages/a2a-server/src/testing_utils.ts b/packages/a2a-server/src/testing_utils.ts index bd7ddaaa87..9a7c8e023a 100644 --- a/packages/a2a-server/src/testing_utils.ts +++ b/packages/a2a-server/src/testing_utils.ts @@ -9,18 +9,52 @@ import type { TaskStatusUpdateEvent, SendStreamingMessageSuccessResponse, } from '@a2a-js/sdk'; +import { ApprovalMode } from '@google/gemini-cli-core'; import { BaseDeclarativeTool, BaseToolInvocation, Kind, } from '@google/gemini-cli-core'; import type { + Config, ToolCallConfirmationDetails, ToolResult, ToolInvocation, } from '@google/gemini-cli-core'; import { expect, vi } from 'vitest'; +export function createMockConfig( + overrides: Partial = {}, +): Partial { + const mockConfig = { + getToolRegistry: vi.fn().mockReturnValue({ + getTool: vi.fn(), + getAllToolNames: vi.fn().mockReturnValue([]), + }), + getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), + getIdeMode: vi.fn().mockReturnValue(false), + getAllowedTools: vi.fn().mockReturnValue([]), + getIdeClient: vi.fn(), + getWorkspaceContext: vi.fn().mockReturnValue({ + isPathWithinWorkspace: () => true, + }), + getTargetDir: () => '/test', + getGeminiClient: vi.fn(), + getDebugMode: vi.fn().mockReturnValue(false), + getContentGeneratorConfig: vi.fn().mockReturnValue({ model: 'gemini-pro' }), + getModel: vi.fn().mockReturnValue('gemini-pro'), + getUsageStatisticsEnabled: vi.fn().mockReturnValue(false), + setFlashFallbackHandler: vi.fn(), + initialize: vi.fn().mockResolvedValue(undefined), + getProxy: vi.fn().mockReturnValue(undefined), + getHistory: vi.fn().mockReturnValue([]), + getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'), + getSessionId: vi.fn().mockReturnValue('test-session-id'), + ...overrides, + }; + return mockConfig; +} + export const mockOnUserConfirmForToolConfirmation = vi.fn(); export class MockToolInvocation extends BaseToolInvocation {