mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-02 16:04:38 -07:00
fix(a2a): Don't mutate 'replace' tool args in scheduleToolCalls (#7369)
Co-authored-by: Abhi <43648792+abhipatel12@users.noreply.github.com>
This commit is contained in:
@@ -33,6 +33,7 @@ import {
|
|||||||
assertTaskCreationAndWorkingStatus,
|
assertTaskCreationAndWorkingStatus,
|
||||||
createStreamMessageRequest,
|
createStreamMessageRequest,
|
||||||
MockTool,
|
MockTool,
|
||||||
|
createMockConfig,
|
||||||
} from './testing_utils.js';
|
} from './testing_utils.js';
|
||||||
|
|
||||||
const mockToolConfirmationFn = async () =>
|
const mockToolConfirmationFn = async () =>
|
||||||
@@ -68,26 +69,11 @@ vi.mock('./config.js', async () => {
|
|||||||
return {
|
return {
|
||||||
...actual,
|
...actual,
|
||||||
loadConfig: vi.fn().mockImplementation(async () => {
|
loadConfig: vi.fn().mockImplementation(async () => {
|
||||||
config = {
|
const mockConfig = createMockConfig({
|
||||||
getToolRegistry: getToolRegistrySpy,
|
getToolRegistry: getToolRegistrySpy,
|
||||||
getApprovalMode: getApprovalModeSpy,
|
getApprovalMode: getApprovalModeSpy,
|
||||||
getIdeMode: vi.fn().mockReturnValue(false),
|
});
|
||||||
getAllowedTools: vi.fn().mockReturnValue([]),
|
config = mockConfig as Config;
|
||||||
getIdeClient: vi.fn(),
|
|
||||||
getWorkspaceContext: vi.fn().mockReturnValue({
|
|
||||||
isPathWithinWorkspace: () => true,
|
|
||||||
}),
|
|
||||||
getTargetDir: () => '/test',
|
|
||||||
getGeminiClient: vi.fn(),
|
|
||||||
getDebugMode: vi.fn().mockReturnValue(false),
|
|
||||||
getContentGeneratorConfig: vi
|
|
||||||
.fn()
|
|
||||||
.mockReturnValue({ model: 'gemini-pro' }),
|
|
||||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
|
||||||
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
|
|
||||||
setFlashFallbackHandler: vi.fn(),
|
|
||||||
initialize: vi.fn().mockResolvedValue(undefined),
|
|
||||||
} as unknown as Config;
|
|
||||||
return config;
|
return config;
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -0,0 +1,59 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, vi } from 'vitest';
|
||||||
|
import { Task } from './task.js';
|
||||||
|
import type { Config, ToolCallRequestInfo } from '@google/gemini-cli-core';
|
||||||
|
import { createMockConfig } from './testing_utils.js';
|
||||||
|
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
|
||||||
|
|
||||||
|
describe('Task', () => {
|
||||||
|
it('scheduleToolCalls should not modify the input requests array', async () => {
|
||||||
|
const mockConfig = createMockConfig();
|
||||||
|
|
||||||
|
const mockEventBus: ExecutionEventBus = {
|
||||||
|
publish: vi.fn(),
|
||||||
|
on: vi.fn(),
|
||||||
|
off: vi.fn(),
|
||||||
|
once: vi.fn(),
|
||||||
|
removeAllListeners: vi.fn(),
|
||||||
|
finished: vi.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// The Task constructor is private. We'll bypass it for this unit test.
|
||||||
|
// @ts-expect-error - Calling private constructor for test purposes.
|
||||||
|
const task = new Task(
|
||||||
|
'task-id',
|
||||||
|
'context-id',
|
||||||
|
mockConfig as Config,
|
||||||
|
mockEventBus,
|
||||||
|
);
|
||||||
|
|
||||||
|
task['setTaskStateAndPublishUpdate'] = vi.fn();
|
||||||
|
task['getProposedContent'] = vi.fn().mockResolvedValue('new content');
|
||||||
|
|
||||||
|
const requests: ToolCallRequestInfo[] = [
|
||||||
|
{
|
||||||
|
callId: '1',
|
||||||
|
name: 'replace',
|
||||||
|
args: {
|
||||||
|
file_path: 'test.txt',
|
||||||
|
old_string: 'old',
|
||||||
|
new_string: 'new',
|
||||||
|
},
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'prompt-id-1',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const originalRequests = JSON.parse(JSON.stringify(requests));
|
||||||
|
const abortController = new AbortController();
|
||||||
|
|
||||||
|
await task.scheduleToolCalls(requests, abortController.signal);
|
||||||
|
|
||||||
|
expect(requests).toEqual(originalRequests);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -520,30 +520,36 @@ export class Task {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const request of requests) {
|
const updatedRequests = await Promise.all(
|
||||||
if (
|
requests.map(async (request) => {
|
||||||
!request.args['newContent'] &&
|
if (
|
||||||
request.name === 'replace' &&
|
request.name === 'replace' &&
|
||||||
request.args &&
|
request.args &&
|
||||||
request.args['file_path'] &&
|
!request.args['newContent'] &&
|
||||||
request.args['old_string'] &&
|
request.args['file_path'] &&
|
||||||
request.args['new_string']
|
request.args['old_string'] &&
|
||||||
) {
|
request.args['new_string']
|
||||||
request.args['newContent'] = await this.getProposedContent(
|
) {
|
||||||
request.args['file_path'] as string,
|
const newContent = await this.getProposedContent(
|
||||||
request.args['old_string'] as string,
|
request.args['file_path'] as string,
|
||||||
request.args['new_string'] as string,
|
request.args['old_string'] as string,
|
||||||
);
|
request.args['new_string'] as string,
|
||||||
}
|
);
|
||||||
}
|
return { ...request, args: { ...request.args, newContent } };
|
||||||
|
}
|
||||||
|
return request;
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
logger.info(`[Task] Scheduling batch of ${requests.length} tool calls.`);
|
logger.info(
|
||||||
|
`[Task] Scheduling batch of ${updatedRequests.length} tool calls.`,
|
||||||
|
);
|
||||||
const stateChange: StateChange = {
|
const stateChange: StateChange = {
|
||||||
kind: CoderAgentEvent.StateChangeEvent,
|
kind: CoderAgentEvent.StateChangeEvent,
|
||||||
};
|
};
|
||||||
this.setTaskStateAndPublishUpdate('working', stateChange);
|
this.setTaskStateAndPublishUpdate('working', stateChange);
|
||||||
|
|
||||||
await this.scheduler.schedule(requests, abortSignal);
|
await this.scheduler.schedule(updatedRequests, abortSignal);
|
||||||
}
|
}
|
||||||
|
|
||||||
async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise<void> {
|
async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise<void> {
|
||||||
|
|||||||
@@ -9,18 +9,52 @@ import type {
|
|||||||
TaskStatusUpdateEvent,
|
TaskStatusUpdateEvent,
|
||||||
SendStreamingMessageSuccessResponse,
|
SendStreamingMessageSuccessResponse,
|
||||||
} from '@a2a-js/sdk';
|
} from '@a2a-js/sdk';
|
||||||
|
import { ApprovalMode } from '@google/gemini-cli-core';
|
||||||
import {
|
import {
|
||||||
BaseDeclarativeTool,
|
BaseDeclarativeTool,
|
||||||
BaseToolInvocation,
|
BaseToolInvocation,
|
||||||
Kind,
|
Kind,
|
||||||
} from '@google/gemini-cli-core';
|
} from '@google/gemini-cli-core';
|
||||||
import type {
|
import type {
|
||||||
|
Config,
|
||||||
ToolCallConfirmationDetails,
|
ToolCallConfirmationDetails,
|
||||||
ToolResult,
|
ToolResult,
|
||||||
ToolInvocation,
|
ToolInvocation,
|
||||||
} from '@google/gemini-cli-core';
|
} from '@google/gemini-cli-core';
|
||||||
import { expect, vi } from 'vitest';
|
import { expect, vi } from 'vitest';
|
||||||
|
|
||||||
|
export function createMockConfig(
|
||||||
|
overrides: Partial<Config> = {},
|
||||||
|
): Partial<Config> {
|
||||||
|
const mockConfig = {
|
||||||
|
getToolRegistry: vi.fn().mockReturnValue({
|
||||||
|
getTool: vi.fn(),
|
||||||
|
getAllToolNames: vi.fn().mockReturnValue([]),
|
||||||
|
}),
|
||||||
|
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||||
|
getIdeMode: vi.fn().mockReturnValue(false),
|
||||||
|
getAllowedTools: vi.fn().mockReturnValue([]),
|
||||||
|
getIdeClient: vi.fn(),
|
||||||
|
getWorkspaceContext: vi.fn().mockReturnValue({
|
||||||
|
isPathWithinWorkspace: () => true,
|
||||||
|
}),
|
||||||
|
getTargetDir: () => '/test',
|
||||||
|
getGeminiClient: vi.fn(),
|
||||||
|
getDebugMode: vi.fn().mockReturnValue(false),
|
||||||
|
getContentGeneratorConfig: vi.fn().mockReturnValue({ model: 'gemini-pro' }),
|
||||||
|
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||||
|
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
|
||||||
|
setFlashFallbackHandler: vi.fn(),
|
||||||
|
initialize: vi.fn().mockResolvedValue(undefined),
|
||||||
|
getProxy: vi.fn().mockReturnValue(undefined),
|
||||||
|
getHistory: vi.fn().mockReturnValue([]),
|
||||||
|
getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'),
|
||||||
|
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||||
|
...overrides,
|
||||||
|
};
|
||||||
|
return mockConfig;
|
||||||
|
}
|
||||||
|
|
||||||
export const mockOnUserConfirmForToolConfirmation = vi.fn();
|
export const mockOnUserConfirmForToolConfirmation = vi.fn();
|
||||||
|
|
||||||
export class MockToolInvocation extends BaseToolInvocation<object, ToolResult> {
|
export class MockToolInvocation extends BaseToolInvocation<object, ToolResult> {
|
||||||
|
|||||||
Reference in New Issue
Block a user