mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-04 17:04:04 -07:00
bug(core): Avoid stateful tool use in executor. (#14305)
This commit is contained in:
@@ -65,13 +65,11 @@ const {
|
|||||||
mockExecuteToolCall,
|
mockExecuteToolCall,
|
||||||
mockSetSystemInstruction,
|
mockSetSystemInstruction,
|
||||||
mockCompress,
|
mockCompress,
|
||||||
mockSetTools,
|
|
||||||
} = vi.hoisted(() => ({
|
} = vi.hoisted(() => ({
|
||||||
mockSendMessageStream: vi.fn(),
|
mockSendMessageStream: vi.fn(),
|
||||||
mockExecuteToolCall: vi.fn(),
|
mockExecuteToolCall: vi.fn(),
|
||||||
mockSetSystemInstruction: vi.fn(),
|
mockSetSystemInstruction: vi.fn(),
|
||||||
mockCompress: vi.fn(),
|
mockCompress: vi.fn(),
|
||||||
mockSetTools: vi.fn(),
|
|
||||||
}));
|
}));
|
||||||
|
|
||||||
let mockChatHistory: Content[] = [];
|
let mockChatHistory: Content[] = [];
|
||||||
@@ -94,7 +92,6 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
|||||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||||
setHistory: mockSetHistory,
|
setHistory: mockSetHistory,
|
||||||
setSystemInstruction: mockSetSystemInstruction,
|
setSystemInstruction: mockSetSystemInstruction,
|
||||||
setTools: mockSetTools,
|
|
||||||
})),
|
})),
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
@@ -238,7 +235,6 @@ describe('AgentExecutor', () => {
|
|||||||
mockSetHistory.mockClear();
|
mockSetHistory.mockClear();
|
||||||
mockSendMessageStream.mockReset();
|
mockSendMessageStream.mockReset();
|
||||||
mockSetSystemInstruction.mockReset();
|
mockSetSystemInstruction.mockReset();
|
||||||
mockSetTools.mockReset();
|
|
||||||
mockExecuteToolCall.mockReset();
|
mockExecuteToolCall.mockReset();
|
||||||
mockedLogAgentStart.mockReset();
|
mockedLogAgentStart.mockReset();
|
||||||
mockedLogAgentFinish.mockReset();
|
mockedLogAgentFinish.mockReset();
|
||||||
@@ -258,7 +254,6 @@ describe('AgentExecutor', () => {
|
|||||||
({
|
({
|
||||||
sendMessageStream: mockSendMessageStream,
|
sendMessageStream: mockSendMessageStream,
|
||||||
setSystemInstruction: mockSetSystemInstruction,
|
setSystemInstruction: mockSetSystemInstruction,
|
||||||
setTools: mockSetTools,
|
|
||||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||||
getLastPromptTokenCount: vi.fn(() => 100),
|
getLastPromptTokenCount: vi.fn(() => 100),
|
||||||
setHistory: mockSetHistory,
|
setHistory: mockSetHistory,
|
||||||
@@ -490,8 +485,10 @@ describe('AgentExecutor', () => {
|
|||||||
const { modelConfigKey } = getMockMessageParams(0);
|
const { modelConfigKey } = getMockMessageParams(0);
|
||||||
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
|
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
|
||||||
|
|
||||||
const call = mockSetTools.mock.calls[0];
|
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
|
||||||
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
|
// tools are the 3rd argument (index 2), passed as [{ functionDeclarations: [...] }]
|
||||||
|
const passedToolsArg = chatConstructorArgs[2] as Tool[];
|
||||||
|
const sentTools = passedToolsArg[0].functionDeclarations;
|
||||||
expect(sentTools).toBeDefined();
|
expect(sentTools).toBeDefined();
|
||||||
|
|
||||||
expect(sentTools).toEqual(
|
expect(sentTools).toEqual(
|
||||||
@@ -615,8 +612,9 @@ describe('AgentExecutor', () => {
|
|||||||
const { modelConfigKey } = getMockMessageParams(0);
|
const { modelConfigKey } = getMockMessageParams(0);
|
||||||
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
|
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
|
||||||
|
|
||||||
const call = mockSetTools.mock.calls[0];
|
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
|
||||||
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
|
const passedToolsArg = chatConstructorArgs[2] as Tool[];
|
||||||
|
const sentTools = passedToolsArg[0].functionDeclarations;
|
||||||
expect(sentTools).toBeDefined();
|
expect(sentTools).toBeDefined();
|
||||||
|
|
||||||
const completeToolDef = sentTools!.find(
|
const completeToolDef = sentTools!.find(
|
||||||
|
|||||||
@@ -182,7 +182,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
private async executeTurn(
|
private async executeTurn(
|
||||||
chat: GeminiChat,
|
chat: GeminiChat,
|
||||||
currentMessage: Content,
|
currentMessage: Content,
|
||||||
tools: FunctionDeclaration[],
|
|
||||||
turnCounter: number,
|
turnCounter: number,
|
||||||
combinedSignal: AbortSignal,
|
combinedSignal: AbortSignal,
|
||||||
timeoutSignal: AbortSignal, // Pass the timeout controller's signal
|
timeoutSignal: AbortSignal, // Pass the timeout controller's signal
|
||||||
@@ -192,7 +191,7 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
await this.tryCompressChat(chat, promptId);
|
await this.tryCompressChat(chat, promptId);
|
||||||
|
|
||||||
const { functionCalls } = await promptIdContext.run(promptId, async () =>
|
const { functionCalls } = await promptIdContext.run(promptId, async () =>
|
||||||
this.callModel(chat, currentMessage, tools, combinedSignal, promptId),
|
this.callModel(chat, currentMessage, combinedSignal, promptId),
|
||||||
);
|
);
|
||||||
|
|
||||||
if (combinedSignal.aborted) {
|
if (combinedSignal.aborted) {
|
||||||
@@ -272,7 +271,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
*/
|
*/
|
||||||
private async executeFinalWarningTurn(
|
private async executeFinalWarningTurn(
|
||||||
chat: GeminiChat,
|
chat: GeminiChat,
|
||||||
tools: FunctionDeclaration[],
|
|
||||||
turnCounter: number,
|
turnCounter: number,
|
||||||
reason:
|
reason:
|
||||||
| AgentTerminateMode.TIMEOUT
|
| AgentTerminateMode.TIMEOUT
|
||||||
@@ -309,7 +307,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
const turnResult = await this.executeTurn(
|
const turnResult = await this.executeTurn(
|
||||||
chat,
|
chat,
|
||||||
recoveryMessage,
|
recoveryMessage,
|
||||||
tools,
|
|
||||||
turnCounter, // This will be the "last" turn number
|
turnCounter, // This will be the "last" turn number
|
||||||
combinedSignal,
|
combinedSignal,
|
||||||
graceTimeoutController.signal, // Pass grace signal to identify a *grace* timeout
|
graceTimeoutController.signal, // Pass grace signal to identify a *grace* timeout
|
||||||
@@ -387,8 +384,8 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
let chat: GeminiChat | undefined;
|
let chat: GeminiChat | undefined;
|
||||||
let tools: FunctionDeclaration[] | undefined;
|
let tools: FunctionDeclaration[] | undefined;
|
||||||
try {
|
try {
|
||||||
chat = await this.createChatObject(inputs);
|
|
||||||
tools = this.prepareToolsList();
|
tools = this.prepareToolsList();
|
||||||
|
chat = await this.createChatObject(inputs, tools);
|
||||||
const query = this.definition.promptConfig.query
|
const query = this.definition.promptConfig.query
|
||||||
? templateString(this.definition.promptConfig.query, inputs)
|
? templateString(this.definition.promptConfig.query, inputs)
|
||||||
: 'Get Started!';
|
: 'Get Started!';
|
||||||
@@ -414,7 +411,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
const turnResult = await this.executeTurn(
|
const turnResult = await this.executeTurn(
|
||||||
chat,
|
chat,
|
||||||
currentMessage,
|
currentMessage,
|
||||||
tools,
|
|
||||||
turnCounter++,
|
turnCounter++,
|
||||||
combinedSignal,
|
combinedSignal,
|
||||||
timeoutController.signal,
|
timeoutController.signal,
|
||||||
@@ -443,7 +439,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
) {
|
) {
|
||||||
const recoveryResult = await this.executeFinalWarningTurn(
|
const recoveryResult = await this.executeFinalWarningTurn(
|
||||||
chat,
|
chat,
|
||||||
tools,
|
|
||||||
turnCounter, // Use current turnCounter for the recovery attempt
|
turnCounter, // Use current turnCounter for the recovery attempt
|
||||||
terminateReason,
|
terminateReason,
|
||||||
signal, // Pass the external signal
|
signal, // Pass the external signal
|
||||||
@@ -509,7 +504,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
if (chat && tools) {
|
if (chat && tools) {
|
||||||
const recoveryResult = await this.executeFinalWarningTurn(
|
const recoveryResult = await this.executeFinalWarningTurn(
|
||||||
chat,
|
chat,
|
||||||
tools,
|
|
||||||
turnCounter, // Use current turnCounter
|
turnCounter, // Use current turnCounter
|
||||||
AgentTerminateMode.TIMEOUT,
|
AgentTerminateMode.TIMEOUT,
|
||||||
signal,
|
signal,
|
||||||
@@ -591,15 +585,9 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
private async callModel(
|
private async callModel(
|
||||||
chat: GeminiChat,
|
chat: GeminiChat,
|
||||||
message: Content,
|
message: Content,
|
||||||
tools: FunctionDeclaration[],
|
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
promptId: string,
|
promptId: string,
|
||||||
): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> {
|
): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> {
|
||||||
if (tools.length > 0) {
|
|
||||||
// TODO(12622): Move tools back to config.
|
|
||||||
chat.setTools([{ functionDeclarations: tools }]);
|
|
||||||
}
|
|
||||||
|
|
||||||
const responseStream = await chat.sendMessageStream(
|
const responseStream = await chat.sendMessageStream(
|
||||||
{
|
{
|
||||||
model: getModelConfigAlias(this.definition),
|
model: getModelConfigAlias(this.definition),
|
||||||
@@ -650,7 +638,10 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Initializes a `GeminiChat` instance for the agent run. */
|
/** Initializes a `GeminiChat` instance for the agent run. */
|
||||||
private async createChatObject(inputs: AgentInputs): Promise<GeminiChat> {
|
private async createChatObject(
|
||||||
|
inputs: AgentInputs,
|
||||||
|
tools: FunctionDeclaration[],
|
||||||
|
): Promise<GeminiChat> {
|
||||||
const { promptConfig } = this.definition;
|
const { promptConfig } = this.definition;
|
||||||
|
|
||||||
if (!promptConfig.systemPrompt && !promptConfig.initialMessages) {
|
if (!promptConfig.systemPrompt && !promptConfig.initialMessages) {
|
||||||
@@ -673,7 +664,7 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
return new GeminiChat(
|
return new GeminiChat(
|
||||||
this.runtimeContext,
|
this.runtimeContext,
|
||||||
systemInstruction,
|
systemInstruction,
|
||||||
[], // set in `callModel`,
|
[{ functionDeclarations: tools }],
|
||||||
startHistory,
|
startHistory,
|
||||||
);
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user