mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 21:03:05 -07:00
fix: Chat logs and errors handle tail tool calls correctly (#22460)
Co-authored-by: Abhi <43648792+abhipatel12@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,11 @@ import {
|
|||||||
SYNTHETIC_THOUGHT_SIGNATURE,
|
SYNTHETIC_THOUGHT_SIGNATURE,
|
||||||
type StreamEvent,
|
type StreamEvent,
|
||||||
} from './geminiChat.js';
|
} from './geminiChat.js';
|
||||||
|
import {
|
||||||
|
type CompletedToolCall,
|
||||||
|
CoreToolCallStatus,
|
||||||
|
} from '../scheduler/types.js';
|
||||||
|
import { MockTool } from '../test-utils/mock-tool.js';
|
||||||
import type { Config } from '../config/config.js';
|
import type { Config } from '../config/config.js';
|
||||||
import { setSimulate429 } from '../utils/testUtils.js';
|
import { setSimulate429 } from '../utils/testUtils.js';
|
||||||
import { DEFAULT_THINKING_MODE } from '../config/models.js';
|
import { DEFAULT_THINKING_MODE } from '../config/models.js';
|
||||||
@@ -165,6 +170,9 @@ describe('GeminiChat', () => {
|
|||||||
getToolRegistry: vi.fn().mockReturnValue({
|
getToolRegistry: vi.fn().mockReturnValue({
|
||||||
getTool: vi.fn(),
|
getTool: vi.fn(),
|
||||||
}),
|
}),
|
||||||
|
toolRegistry: {
|
||||||
|
getTool: vi.fn(),
|
||||||
|
},
|
||||||
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
|
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
|
||||||
getRetryFetchErrors: vi.fn().mockReturnValue(false),
|
getRetryFetchErrors: vi.fn().mockReturnValue(false),
|
||||||
getMaxAttempts: vi.fn().mockReturnValue(10),
|
getMaxAttempts: vi.fn().mockReturnValue(10),
|
||||||
@@ -2569,4 +2577,78 @@ describe('GeminiChat', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('recordCompletedToolCalls', () => {
|
||||||
|
it('should use originalRequestName and originalRequestArgs if present', () => {
|
||||||
|
const completedCall: CompletedToolCall = {
|
||||||
|
status: CoreToolCallStatus.Success,
|
||||||
|
request: {
|
||||||
|
callId: 'call-1',
|
||||||
|
name: 'tail-tool',
|
||||||
|
args: { tail: 'args' },
|
||||||
|
originalRequestName: 'original-tool',
|
||||||
|
originalRequestArgs: { original: 'args' },
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'p1',
|
||||||
|
},
|
||||||
|
response: {
|
||||||
|
callId: 'call-1',
|
||||||
|
responseParts: [{ text: 'response' }],
|
||||||
|
resultDisplay: undefined,
|
||||||
|
error: undefined,
|
||||||
|
errorType: undefined,
|
||||||
|
},
|
||||||
|
tool: new MockTool({ name: 'mock-tool' }),
|
||||||
|
invocation: new MockTool({ name: 'mock-tool' }).build({ key: 'value' }),
|
||||||
|
};
|
||||||
|
|
||||||
|
const spy = vi.spyOn(chat.getChatRecordingService(), 'recordToolCalls');
|
||||||
|
|
||||||
|
chat.recordCompletedToolCalls('test-model', [completedCall]);
|
||||||
|
|
||||||
|
expect(spy).toHaveBeenCalledWith('test-model', [
|
||||||
|
expect.objectContaining({
|
||||||
|
id: 'call-1',
|
||||||
|
name: 'original-tool',
|
||||||
|
args: { original: 'args' },
|
||||||
|
result: [{ text: 'response' }],
|
||||||
|
}),
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should fall back to request name and args if original are not present', () => {
|
||||||
|
const completedCall: CompletedToolCall = {
|
||||||
|
status: CoreToolCallStatus.Success,
|
||||||
|
request: {
|
||||||
|
callId: 'call-1',
|
||||||
|
name: 'tool-name',
|
||||||
|
args: { key: 'value' },
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'p1',
|
||||||
|
},
|
||||||
|
response: {
|
||||||
|
callId: 'call-1',
|
||||||
|
responseParts: [{ text: 'response' }],
|
||||||
|
resultDisplay: undefined,
|
||||||
|
error: undefined,
|
||||||
|
errorType: undefined,
|
||||||
|
},
|
||||||
|
tool: new MockTool({ name: 'mock-tool' }),
|
||||||
|
invocation: new MockTool({ name: 'mock-tool' }).build({ key: 'value' }),
|
||||||
|
};
|
||||||
|
|
||||||
|
const spy = vi.spyOn(chat.getChatRecordingService(), 'recordToolCalls');
|
||||||
|
|
||||||
|
chat.recordCompletedToolCalls('test-model', [completedCall]);
|
||||||
|
|
||||||
|
expect(spy).toHaveBeenCalledWith('test-model', [
|
||||||
|
expect.objectContaining({
|
||||||
|
id: 'call-1',
|
||||||
|
name: 'tool-name',
|
||||||
|
args: { key: 'value' },
|
||||||
|
result: [{ text: 'response' }],
|
||||||
|
}),
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1032,8 +1032,8 @@ export class GeminiChat {
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
id: call.request.callId,
|
id: call.request.callId,
|
||||||
name: call.request.name,
|
name: call.request.originalRequestName ?? call.request.name,
|
||||||
args: call.request.args,
|
args: call.request.originalRequestArgs ?? call.request.args,
|
||||||
result: call.response?.responseParts || null,
|
result: call.response?.responseParts || null,
|
||||||
status: call.status,
|
status: call.status,
|
||||||
timestamp: new Date().toISOString(),
|
timestamp: new Date().toISOString(),
|
||||||
|
|||||||
@@ -669,6 +669,30 @@ describe('Scheduler (Orchestrator)', () => {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should use originalRequestName when generating an error response', async () => {
|
||||||
|
const error = new Error('Some error');
|
||||||
|
vi.mocked(checkPolicy).mockRejectedValue(error);
|
||||||
|
|
||||||
|
const tailReq = { ...req1, originalRequestName: 'original-tool-name' };
|
||||||
|
await scheduler.schedule(tailReq, signal);
|
||||||
|
|
||||||
|
expect(mockStateManager.updateStatus).toHaveBeenCalledWith(
|
||||||
|
'call-1',
|
||||||
|
CoreToolCallStatus.Error,
|
||||||
|
expect.objectContaining({
|
||||||
|
errorType: ToolErrorType.UNHANDLED_EXCEPTION,
|
||||||
|
responseParts: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
functionResponse: expect.objectContaining({
|
||||||
|
name: 'original-tool-name',
|
||||||
|
response: { error: 'Some error' },
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
it('should handle errors from checkPolicy (e.g. non-interactive ASK_USER)', async () => {
|
it('should handle errors from checkPolicy (e.g. non-interactive ASK_USER)', async () => {
|
||||||
const error = new Error('Not interactive');
|
const error = new Error('Not interactive');
|
||||||
vi.mocked(checkPolicy).mockRejectedValue(error);
|
vi.mocked(checkPolicy).mockRejectedValue(error);
|
||||||
@@ -1131,6 +1155,7 @@ describe('Scheduler (Orchestrator)', () => {
|
|||||||
name: 'tool-b',
|
name: 'tool-b',
|
||||||
args: { key: 'value' },
|
args: { key: 'value' },
|
||||||
originalRequestName: 'test-tool', // Preserves original name
|
originalRequestName: 'test-tool', // Preserves original name
|
||||||
|
originalRequestArgs: req1.args, // Preserves original args
|
||||||
}),
|
}),
|
||||||
tool: mockToolB,
|
tool: mockToolB,
|
||||||
}),
|
}),
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ const createErrorResponse = (
|
|||||||
{
|
{
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
id: request.callId,
|
id: request.callId,
|
||||||
name: request.name,
|
name: request.originalRequestName ?? request.name,
|
||||||
response: { error: error.message },
|
response: { error: error.message },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -766,6 +766,8 @@ export class Scheduler {
|
|||||||
name: tailRequest.name,
|
name: tailRequest.name,
|
||||||
args: tailRequest.args,
|
args: tailRequest.args,
|
||||||
originalRequestName,
|
originalRequestName,
|
||||||
|
originalRequestArgs:
|
||||||
|
result.request.originalRequestArgs ?? result.request.args,
|
||||||
isClientInitiated: result.request.isClientInitiated,
|
isClientInitiated: result.request.isClientInitiated,
|
||||||
prompt_id: result.request.prompt_id,
|
prompt_id: result.request.prompt_id,
|
||||||
schedulerId: this.schedulerId,
|
schedulerId: this.schedulerId,
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ describe('SchedulerStateManager', () => {
|
|||||||
|
|
||||||
const mockInvocation = {
|
const mockInvocation = {
|
||||||
shouldConfirmExecute: vi.fn(),
|
shouldConfirmExecute: vi.fn(),
|
||||||
|
execute: vi.fn(),
|
||||||
|
getDescription: vi.fn(),
|
||||||
} as unknown as AnyToolInvocation;
|
} as unknown as AnyToolInvocation;
|
||||||
|
|
||||||
const createValidatingCall = (
|
const createValidatingCall = (
|
||||||
@@ -610,6 +612,19 @@ describe('SchedulerStateManager', () => {
|
|||||||
expect(onUpdate).toHaveBeenCalledTimes(1);
|
expect(onUpdate).toHaveBeenCalledTimes(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should use originalRequestName when cancelling queued calls', () => {
|
||||||
|
const call = createValidatingCall('tail-1');
|
||||||
|
call.request.originalRequestName = 'original-tool';
|
||||||
|
stateManager.enqueue([call]);
|
||||||
|
|
||||||
|
stateManager.cancelAllQueued('Batch cancel');
|
||||||
|
|
||||||
|
const completed = stateManager.completedBatch[0] as CancelledToolCall;
|
||||||
|
expect(completed.response.responseParts[0]?.functionResponse?.name).toBe(
|
||||||
|
'original-tool',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
it('should not notify if cancelAllQueued is called on an empty queue', () => {
|
it('should not notify if cancelAllQueued is called on an empty queue', () => {
|
||||||
vi.mocked(onUpdate).mockClear();
|
vi.mocked(onUpdate).mockClear();
|
||||||
stateManager.cancelAllQueued('Batch cancel');
|
stateManager.cancelAllQueued('Batch cancel');
|
||||||
|
|||||||
@@ -517,7 +517,7 @@ export class SchedulerStateManager {
|
|||||||
{
|
{
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
id: call.request.callId,
|
id: call.request.callId,
|
||||||
name: call.request.name,
|
name: call.request.originalRequestName ?? call.request.name,
|
||||||
response: { error: errorMessage },
|
response: { error: errorMessage },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -332,6 +332,53 @@ describe('ToolExecutor', () => {
|
|||||||
expect(result.status).toBe(CoreToolCallStatus.Cancelled);
|
expect(result.status).toBe(CoreToolCallStatus.Cancelled);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return cancelled result and use originalRequestName when signal is aborted', async () => {
|
||||||
|
const mockTool = new MockTool({
|
||||||
|
name: 'slowTool',
|
||||||
|
});
|
||||||
|
const invocation = mockTool.build({});
|
||||||
|
|
||||||
|
// Mock executeToolWithHooks to simulate slow execution
|
||||||
|
vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation(
|
||||||
|
async () => {
|
||||||
|
await new Promise((r) => setTimeout(r, 100));
|
||||||
|
return { llmContent: 'Done', returnDisplay: 'Done' };
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
const scheduledCall: ScheduledToolCall = {
|
||||||
|
status: CoreToolCallStatus.Scheduled,
|
||||||
|
request: {
|
||||||
|
callId: 'call-4',
|
||||||
|
name: 'actualToolName',
|
||||||
|
originalRequestName: 'originalToolName',
|
||||||
|
args: {},
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'prompt-4',
|
||||||
|
},
|
||||||
|
tool: mockTool,
|
||||||
|
invocation: invocation as unknown as AnyToolInvocation,
|
||||||
|
startTime: Date.now(),
|
||||||
|
};
|
||||||
|
|
||||||
|
const controller = new AbortController();
|
||||||
|
const promise = executor.execute({
|
||||||
|
call: scheduledCall,
|
||||||
|
signal: controller.signal,
|
||||||
|
onUpdateToolCall: vi.fn(),
|
||||||
|
});
|
||||||
|
|
||||||
|
controller.abort();
|
||||||
|
const result = await promise;
|
||||||
|
|
||||||
|
expect(result.status).toBe(CoreToolCallStatus.Cancelled);
|
||||||
|
if (result.status === CoreToolCallStatus.Cancelled) {
|
||||||
|
expect(result.response.responseParts[0]?.functionResponse?.name).toBe(
|
||||||
|
'originalToolName',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
it('should truncate large shell output', async () => {
|
it('should truncate large shell output', async () => {
|
||||||
// 1. Setup Config for Truncation
|
// 1. Setup Config for Truncation
|
||||||
vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10);
|
vi.spyOn(config, 'getTruncateToolOutputThreshold').mockReturnValue(10);
|
||||||
|
|||||||
@@ -307,7 +307,7 @@ export class ToolExecutor {
|
|||||||
|
|
||||||
outputFile = truncatedOutputFile;
|
outputFile = truncatedOutputFile;
|
||||||
responseParts = convertToFunctionResponse(
|
responseParts = convertToFunctionResponse(
|
||||||
call.request.name,
|
call.request.originalRequestName ?? call.request.name,
|
||||||
call.request.callId,
|
call.request.callId,
|
||||||
output,
|
output,
|
||||||
this.config.getActiveModel(),
|
this.config.getActiveModel(),
|
||||||
@@ -325,7 +325,7 @@ export class ToolExecutor {
|
|||||||
{
|
{
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
id: call.request.callId,
|
id: call.request.callId,
|
||||||
name: call.request.name,
|
name: call.request.originalRequestName ?? call.request.name,
|
||||||
response: { error: errorMessage },
|
response: { error: errorMessage },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -37,10 +37,12 @@ export interface ToolCallRequestInfo {
|
|||||||
name: string;
|
name: string;
|
||||||
args: Record<string, unknown>;
|
args: Record<string, unknown>;
|
||||||
/**
|
/**
|
||||||
* The original name of the tool requested by the model.
|
* The original name and arguments of the tool requested by the model.
|
||||||
* This is used for tail calls to ensure the final response retains the original name.
|
* This is used for tail calls to ensure the final response and log retain
|
||||||
|
* the original values.
|
||||||
*/
|
*/
|
||||||
originalRequestName?: string;
|
originalRequestName?: string;
|
||||||
|
originalRequestArgs?: Record<string, unknown>;
|
||||||
isClientInitiated: boolean;
|
isClientInitiated: boolean;
|
||||||
prompt_id: string;
|
prompt_id: string;
|
||||||
checkpoint?: string;
|
checkpoint?: string;
|
||||||
|
|||||||
Reference in New Issue
Block a user