fix(a2a): Don't mutate 'replace' tool args in scheduleToolCalls (#7369)

Co-authored-by: Abhi <43648792+abhipatel12@users.noreply.github.com>
This commit is contained in:
Victor Miura
2025-08-29 09:32:34 -07:00
committed by GitHub
parent f2bddfe054
commit 6868cbe7b7
4 changed files with 121 additions and 36 deletions
+4 -18
View File
@@ -33,6 +33,7 @@ import {
assertTaskCreationAndWorkingStatus, assertTaskCreationAndWorkingStatus,
createStreamMessageRequest, createStreamMessageRequest,
MockTool, MockTool,
createMockConfig,
} from './testing_utils.js'; } from './testing_utils.js';
const mockToolConfirmationFn = async () => const mockToolConfirmationFn = async () =>
@@ -68,26 +69,11 @@ vi.mock('./config.js', async () => {
return { return {
...actual, ...actual,
loadConfig: vi.fn().mockImplementation(async () => { loadConfig: vi.fn().mockImplementation(async () => {
config = { const mockConfig = createMockConfig({
getToolRegistry: getToolRegistrySpy, getToolRegistry: getToolRegistrySpy,
getApprovalMode: getApprovalModeSpy, getApprovalMode: getApprovalModeSpy,
getIdeMode: vi.fn().mockReturnValue(false), });
getAllowedTools: vi.fn().mockReturnValue([]), config = mockConfig as Config;
getIdeClient: vi.fn(),
getWorkspaceContext: vi.fn().mockReturnValue({
isPathWithinWorkspace: () => true,
}),
getTargetDir: () => '/test',
getGeminiClient: vi.fn(),
getDebugMode: vi.fn().mockReturnValue(false),
getContentGeneratorConfig: vi
.fn()
.mockReturnValue({ model: 'gemini-pro' }),
getModel: vi.fn().mockReturnValue('gemini-pro'),
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
setFlashFallbackHandler: vi.fn(),
initialize: vi.fn().mockResolvedValue(undefined),
} as unknown as Config;
return config; return config;
}), }),
}; };
+59
View File
@@ -0,0 +1,59 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import { Task } from './task.js';
import type { Config, ToolCallRequestInfo } from '@google/gemini-cli-core';
import { createMockConfig } from './testing_utils.js';
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
describe('Task', () => {
it('scheduleToolCalls should not modify the input requests array', async () => {
const mockConfig = createMockConfig();
const mockEventBus: ExecutionEventBus = {
publish: vi.fn(),
on: vi.fn(),
off: vi.fn(),
once: vi.fn(),
removeAllListeners: vi.fn(),
finished: vi.fn(),
};
// The Task constructor is private. We'll bypass it for this unit test.
// @ts-expect-error - Calling private constructor for test purposes.
const task = new Task(
'task-id',
'context-id',
mockConfig as Config,
mockEventBus,
);
task['setTaskStateAndPublishUpdate'] = vi.fn();
task['getProposedContent'] = vi.fn().mockResolvedValue('new content');
const requests: ToolCallRequestInfo[] = [
{
callId: '1',
name: 'replace',
args: {
file_path: 'test.txt',
old_string: 'old',
new_string: 'new',
},
isClientInitiated: false,
prompt_id: 'prompt-id-1',
},
];
const originalRequests = JSON.parse(JSON.stringify(requests));
const abortController = new AbortController();
await task.scheduleToolCalls(requests, abortController.signal);
expect(requests).toEqual(originalRequests);
});
});
+24 -18
View File
@@ -520,30 +520,36 @@ export class Task {
return; return;
} }
for (const request of requests) { const updatedRequests = await Promise.all(
if ( requests.map(async (request) => {
!request.args['newContent'] && if (
request.name === 'replace' && request.name === 'replace' &&
request.args && request.args &&
request.args['file_path'] && !request.args['newContent'] &&
request.args['old_string'] && request.args['file_path'] &&
request.args['new_string'] request.args['old_string'] &&
) { request.args['new_string']
request.args['newContent'] = await this.getProposedContent( ) {
request.args['file_path'] as string, const newContent = await this.getProposedContent(
request.args['old_string'] as string, request.args['file_path'] as string,
request.args['new_string'] as string, request.args['old_string'] as string,
); request.args['new_string'] as string,
} );
} return { ...request, args: { ...request.args, newContent } };
}
return request;
}),
);
logger.info(`[Task] Scheduling batch of ${requests.length} tool calls.`); logger.info(
`[Task] Scheduling batch of ${updatedRequests.length} tool calls.`,
);
const stateChange: StateChange = { const stateChange: StateChange = {
kind: CoderAgentEvent.StateChangeEvent, kind: CoderAgentEvent.StateChangeEvent,
}; };
this.setTaskStateAndPublishUpdate('working', stateChange); this.setTaskStateAndPublishUpdate('working', stateChange);
await this.scheduler.schedule(requests, abortSignal); await this.scheduler.schedule(updatedRequests, abortSignal);
} }
async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise<void> { async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise<void> {
+34
View File
@@ -9,18 +9,52 @@ import type {
TaskStatusUpdateEvent, TaskStatusUpdateEvent,
SendStreamingMessageSuccessResponse, SendStreamingMessageSuccessResponse,
} from '@a2a-js/sdk'; } from '@a2a-js/sdk';
import { ApprovalMode } from '@google/gemini-cli-core';
import { import {
BaseDeclarativeTool, BaseDeclarativeTool,
BaseToolInvocation, BaseToolInvocation,
Kind, Kind,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import type { import type {
Config,
ToolCallConfirmationDetails, ToolCallConfirmationDetails,
ToolResult, ToolResult,
ToolInvocation, ToolInvocation,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { expect, vi } from 'vitest'; import { expect, vi } from 'vitest';
export function createMockConfig(
overrides: Partial<Config> = {},
): Partial<Config> {
const mockConfig = {
getToolRegistry: vi.fn().mockReturnValue({
getTool: vi.fn(),
getAllToolNames: vi.fn().mockReturnValue([]),
}),
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
getIdeMode: vi.fn().mockReturnValue(false),
getAllowedTools: vi.fn().mockReturnValue([]),
getIdeClient: vi.fn(),
getWorkspaceContext: vi.fn().mockReturnValue({
isPathWithinWorkspace: () => true,
}),
getTargetDir: () => '/test',
getGeminiClient: vi.fn(),
getDebugMode: vi.fn().mockReturnValue(false),
getContentGeneratorConfig: vi.fn().mockReturnValue({ model: 'gemini-pro' }),
getModel: vi.fn().mockReturnValue('gemini-pro'),
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
setFlashFallbackHandler: vi.fn(),
initialize: vi.fn().mockResolvedValue(undefined),
getProxy: vi.fn().mockReturnValue(undefined),
getHistory: vi.fn().mockReturnValue([]),
getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
...overrides,
};
return mockConfig;
}
export const mockOnUserConfirmForToolConfirmation = vi.fn(); export const mockOnUserConfirmForToolConfirmation = vi.fn();
export class MockToolInvocation extends BaseToolInvocation<object, ToolResult> { export class MockToolInvocation extends BaseToolInvocation<object, ToolResult> {