mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-11 11:57:03 -07:00
fix: Chat logs and errors handle tail tool calls correctly (#22460)
Co-authored-by: Abhi <43648792+abhipatel12@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,11 @@ import {
|
||||
SYNTHETIC_THOUGHT_SIGNATURE,
|
||||
type StreamEvent,
|
||||
} from './geminiChat.js';
|
||||
import {
|
||||
type CompletedToolCall,
|
||||
CoreToolCallStatus,
|
||||
} from '../scheduler/types.js';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { setSimulate429 } from '../utils/testUtils.js';
|
||||
import { DEFAULT_THINKING_MODE } from '../config/models.js';
|
||||
@@ -165,6 +170,9 @@ describe('GeminiChat', () => {
|
||||
getToolRegistry: vi.fn().mockReturnValue({
|
||||
getTool: vi.fn(),
|
||||
}),
|
||||
toolRegistry: {
|
||||
getTool: vi.fn(),
|
||||
},
|
||||
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
|
||||
getRetryFetchErrors: vi.fn().mockReturnValue(false),
|
||||
getMaxAttempts: vi.fn().mockReturnValue(10),
|
||||
@@ -2569,4 +2577,78 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('recordCompletedToolCalls', () => {
|
||||
it('should use originalRequestName and originalRequestArgs if present', () => {
|
||||
const completedCall: CompletedToolCall = {
|
||||
status: CoreToolCallStatus.Success,
|
||||
request: {
|
||||
callId: 'call-1',
|
||||
name: 'tail-tool',
|
||||
args: { tail: 'args' },
|
||||
originalRequestName: 'original-tool',
|
||||
originalRequestArgs: { original: 'args' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
},
|
||||
response: {
|
||||
callId: 'call-1',
|
||||
responseParts: [{ text: 'response' }],
|
||||
resultDisplay: undefined,
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
},
|
||||
tool: new MockTool({ name: 'mock-tool' }),
|
||||
invocation: new MockTool({ name: 'mock-tool' }).build({ key: 'value' }),
|
||||
};
|
||||
|
||||
const spy = vi.spyOn(chat.getChatRecordingService(), 'recordToolCalls');
|
||||
|
||||
chat.recordCompletedToolCalls('test-model', [completedCall]);
|
||||
|
||||
expect(spy).toHaveBeenCalledWith('test-model', [
|
||||
expect.objectContaining({
|
||||
id: 'call-1',
|
||||
name: 'original-tool',
|
||||
args: { original: 'args' },
|
||||
result: [{ text: 'response' }],
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
it('should fall back to request name and args if original are not present', () => {
|
||||
const completedCall: CompletedToolCall = {
|
||||
status: CoreToolCallStatus.Success,
|
||||
request: {
|
||||
callId: 'call-1',
|
||||
name: 'tool-name',
|
||||
args: { key: 'value' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
},
|
||||
response: {
|
||||
callId: 'call-1',
|
||||
responseParts: [{ text: 'response' }],
|
||||
resultDisplay: undefined,
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
},
|
||||
tool: new MockTool({ name: 'mock-tool' }),
|
||||
invocation: new MockTool({ name: 'mock-tool' }).build({ key: 'value' }),
|
||||
};
|
||||
|
||||
const spy = vi.spyOn(chat.getChatRecordingService(), 'recordToolCalls');
|
||||
|
||||
chat.recordCompletedToolCalls('test-model', [completedCall]);
|
||||
|
||||
expect(spy).toHaveBeenCalledWith('test-model', [
|
||||
expect.objectContaining({
|
||||
id: 'call-1',
|
||||
name: 'tool-name',
|
||||
args: { key: 'value' },
|
||||
result: [{ text: 'response' }],
|
||||
}),
|
||||
]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1032,8 +1032,8 @@ export class GeminiChat {
|
||||
|
||||
return {
|
||||
id: call.request.callId,
|
||||
name: call.request.name,
|
||||
args: call.request.args,
|
||||
name: call.request.originalRequestName ?? call.request.name,
|
||||
args: call.request.originalRequestArgs ?? call.request.args,
|
||||
result: call.response?.responseParts || null,
|
||||
status: call.status,
|
||||
timestamp: new Date().toISOString(),
|
||||
|
||||
@@ -669,6 +669,30 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should use originalRequestName when generating an error response', async () => {
|
||||
const error = new Error('Some error');
|
||||
vi.mocked(checkPolicy).mockRejectedValue(error);
|
||||
|
||||
const tailReq = { ...req1, originalRequestName: 'original-tool-name' };
|
||||
await scheduler.schedule(tailReq, signal);
|
||||
|
||||
expect(mockStateManager.updateStatus).toHaveBeenCalledWith(
|
||||
'call-1',
|
||||
CoreToolCallStatus.Error,
|
||||
expect.objectContaining({
|
||||
errorType: ToolErrorType.UNHANDLED_EXCEPTION,
|
||||
responseParts: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
name: 'original-tool-name',
|
||||
response: { error: 'Some error' },
|
||||
}),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle errors from checkPolicy (e.g. non-interactive ASK_USER)', async () => {
|
||||
const error = new Error('Not interactive');
|
||||
vi.mocked(checkPolicy).mockRejectedValue(error);
|
||||
@@ -1131,6 +1155,7 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
name: 'tool-b',
|
||||
args: { key: 'value' },
|
||||
originalRequestName: 'test-tool', // Preserves original name
|
||||
originalRequestArgs: req1.args, // Preserves original args
|
||||
}),
|
||||
tool: mockToolB,
|
||||
}),
|
||||
|
||||
@@ -77,7 +77,7 @@ const createErrorResponse = (
|
||||
{
|
||||
functionResponse: {
|
||||
id: request.callId,
|
||||
name: request.name,
|
||||
name: request.originalRequestName ?? request.name,
|
||||
response: { error: error.message },
|
||||
},
|
||||
},
|
||||
@@ -766,6 +766,8 @@ export class Scheduler {
|
||||
name: tailRequest.name,
|
||||
args: tailRequest.args,
|
||||
originalRequestName,
|
||||
originalRequestArgs:
|
||||
result.request.originalRequestArgs ?? result.request.args,
|
||||
isClientInitiated: result.request.isClientInitiated,
|
||||
prompt_id: result.request.prompt_id,
|
||||
schedulerId: this.schedulerId,
|
||||
|
||||
@@ -44,6 +44,8 @@ describe('SchedulerStateManager', () => {
|
||||
|
||||
const mockInvocation = {
|
||||
shouldConfirmExecute: vi.fn(),
|
||||
execute: vi.fn(),
|
||||
getDescription: vi.fn(),
|
||||
} as unknown as AnyToolInvocation;
|
||||
|
||||
const createValidatingCall = (
|
||||
@@ -610,6 +612,19 @@ describe('SchedulerStateManager', () => {
|
||||
expect(onUpdate).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should use originalRequestName when cancelling queued calls', () => {
|
||||
const call = createValidatingCall('tail-1');
|
||||
call.request.originalRequestName = 'original-tool';
|
||||
stateManager.enqueue([call]);
|
||||
|
||||
stateManager.cancelAllQueued('Batch cancel');
|
||||
|
||||
const completed = stateManager.completedBatch[0] as CancelledToolCall;
|
||||
expect(completed.response.responseParts[0]?.functionResponse?.name).toBe(
|
||||
'original-tool',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not notify if cancelAllQueued is called on an empty queue', () => {
|
||||
vi.mocked(onUpdate).mockClear();
|
||||
stateManager.cancelAllQueued('Batch cancel');
|
||||
|
||||
@@ -517,7 +517,7 @@ export class SchedulerStateManager {
|
||||
{
|
||||
functionResponse: {
|
||||
id: call.request.callId,
|
||||
name: call.request.name,
|
||||
name: call.request.originalRequestName ?? call.request.name,
|
||||
response: { error: errorMessage },
|
||||
},
|
||||
},
|
||||
|
||||
@@ -332,6 +332,53 @@ describe('ToolExecutor', () => {
|
||||
expect(result.status).toBe(CoreToolCallStatus.Cancelled);
|
||||
});
|
||||
|
||||
it('should return cancelled result and use originalRequestName when signal is aborted', async () => {
|
||||
const mockTool = new MockTool({
|
||||
name: 'slowTool',
|
||||
});
|
||||
const invocation = mockTool.build({});
|
||||
|
||||
// Mock executeToolWithHooks to simulate slow execution
|
||||
vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation(
|
||||
async () => {
|
||||
await new Promise((r) => setTimeout(r, 100));
|
||||
return { llmContent: 'Done', returnDisplay: 'Done' };
|
||||
},
|
||||
);
|
||||
|
||||
const scheduledCall: ScheduledToolCall = {
|
||||
status: CoreToolCallStatus.Scheduled,
|
||||
request: {
|
||||
callId: 'call-4',
|
||||
name: 'actualToolName',
|
||||
originalRequestName: 'originalToolName',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-4',
|
||||
},
|
||||
tool: mockTool,
|
||||
invocation: invocation as unknown as AnyToolInvocation,
|
||||
startTime: Date.now(),
|
||||
};
|
||||
|
||||
const controller = new AbortController();
|
||||
const promise = executor.execute({
|
||||
call: scheduledCall,
|
||||
signal: controller.signal,
|
||||
onUpdateToolCall: vi.fn(),
|
||||
});
|
||||
|
||||
controller.abort();
|
||||
const result = await promise;
|
||||
|
||||
expect(result.status).toBe(CoreToolCallStatus.Cancelled);
|
||||
if (result.status === CoreToolCallStatus.Cancelled) {
|
||||
expect(result.response.responseParts[0]?.functionResponse?.name).toBe(
|
||||
'originalToolName',
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it('should truncate large shell output', async () => {
|
||||
// 1. Setup Config for Truncation
|
||||
vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10);
|
||||
|
||||
@@ -307,7 +307,7 @@ export class ToolExecutor {
|
||||
|
||||
outputFile = truncatedOutputFile;
|
||||
responseParts = convertToFunctionResponse(
|
||||
call.request.name,
|
||||
call.request.originalRequestName ?? call.request.name,
|
||||
call.request.callId,
|
||||
output,
|
||||
this.config.getActiveModel(),
|
||||
@@ -325,7 +325,7 @@ export class ToolExecutor {
|
||||
{
|
||||
functionResponse: {
|
||||
id: call.request.callId,
|
||||
name: call.request.name,
|
||||
name: call.request.originalRequestName ?? call.request.name,
|
||||
response: { error: errorMessage },
|
||||
},
|
||||
},
|
||||
|
||||
@@ -37,10 +37,12 @@ export interface ToolCallRequestInfo {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
/**
|
||||
* The original name of the tool requested by the model.
|
||||
* This is used for tail calls to ensure the final response retains the original name.
|
||||
* The original name and arguments of the tool requested by the model.
|
||||
* This is used for tail calls to ensure the final response and log retain
|
||||
* the original values.
|
||||
*/
|
||||
originalRequestName?: string;
|
||||
originalRequestArgs?: Record<string, unknown>;
|
||||
isClientInitiated: boolean;
|
||||
prompt_id: string;
|
||||
checkpoint?: string;
|
||||
|
||||
Reference in New Issue
Block a user