bug(core): Avoid stateful tool use in executor. (#14305)

This commit is contained in:
joshualitt
2025-12-01 10:54:28 -08:00
committed by GitHub
parent db027dd95b
commit 62f890b5aa
2 changed files with 14 additions and 25 deletions
+7 -9
View File
@@ -65,13 +65,11 @@ const {
mockExecuteToolCall, mockExecuteToolCall,
mockSetSystemInstruction, mockSetSystemInstruction,
mockCompress, mockCompress,
mockSetTools,
} = vi.hoisted(() => ({ } = vi.hoisted(() => ({
mockSendMessageStream: vi.fn(), mockSendMessageStream: vi.fn(),
mockExecuteToolCall: vi.fn(), mockExecuteToolCall: vi.fn(),
mockSetSystemInstruction: vi.fn(), mockSetSystemInstruction: vi.fn(),
mockCompress: vi.fn(), mockCompress: vi.fn(),
mockSetTools: vi.fn(),
})); }));
let mockChatHistory: Content[] = []; let mockChatHistory: Content[] = [];
@@ -94,7 +92,6 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]), getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
setHistory: mockSetHistory, setHistory: mockSetHistory,
setSystemInstruction: mockSetSystemInstruction, setSystemInstruction: mockSetSystemInstruction,
setTools: mockSetTools,
})), })),
}; };
}); });
@@ -238,7 +235,6 @@ describe('AgentExecutor', () => {
mockSetHistory.mockClear(); mockSetHistory.mockClear();
mockSendMessageStream.mockReset(); mockSendMessageStream.mockReset();
mockSetSystemInstruction.mockReset(); mockSetSystemInstruction.mockReset();
mockSetTools.mockReset();
mockExecuteToolCall.mockReset(); mockExecuteToolCall.mockReset();
mockedLogAgentStart.mockReset(); mockedLogAgentStart.mockReset();
mockedLogAgentFinish.mockReset(); mockedLogAgentFinish.mockReset();
@@ -258,7 +254,6 @@ describe('AgentExecutor', () => {
({ ({
sendMessageStream: mockSendMessageStream, sendMessageStream: mockSendMessageStream,
setSystemInstruction: mockSetSystemInstruction, setSystemInstruction: mockSetSystemInstruction,
setTools: mockSetTools,
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]), getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
getLastPromptTokenCount: vi.fn(() => 100), getLastPromptTokenCount: vi.fn(() => 100),
setHistory: mockSetHistory, setHistory: mockSetHistory,
@@ -490,8 +485,10 @@ describe('AgentExecutor', () => {
const { modelConfigKey } = getMockMessageParams(0); const { modelConfigKey } = getMockMessageParams(0);
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition)); expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
const call = mockSetTools.mock.calls[0]; const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
const sentTools = (call[0] as Tool[])[0].functionDeclarations; // tools are the 3rd argument (index 2), passed as [{ functionDeclarations: [...] }]
const passedToolsArg = chatConstructorArgs[2] as Tool[];
const sentTools = passedToolsArg[0].functionDeclarations;
expect(sentTools).toBeDefined(); expect(sentTools).toBeDefined();
expect(sentTools).toEqual( expect(sentTools).toEqual(
@@ -615,8 +612,9 @@ describe('AgentExecutor', () => {
const { modelConfigKey } = getMockMessageParams(0); const { modelConfigKey } = getMockMessageParams(0);
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition)); expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
const call = mockSetTools.mock.calls[0]; const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
const sentTools = (call[0] as Tool[])[0].functionDeclarations; const passedToolsArg = chatConstructorArgs[2] as Tool[];
const sentTools = passedToolsArg[0].functionDeclarations;
expect(sentTools).toBeDefined(); expect(sentTools).toBeDefined();
const completeToolDef = sentTools!.find( const completeToolDef = sentTools!.find(
+7 -16
View File
@@ -182,7 +182,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
private async executeTurn( private async executeTurn(
chat: GeminiChat, chat: GeminiChat,
currentMessage: Content, currentMessage: Content,
tools: FunctionDeclaration[],
turnCounter: number, turnCounter: number,
combinedSignal: AbortSignal, combinedSignal: AbortSignal,
timeoutSignal: AbortSignal, // Pass the timeout controller's signal timeoutSignal: AbortSignal, // Pass the timeout controller's signal
@@ -192,7 +191,7 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
await this.tryCompressChat(chat, promptId); await this.tryCompressChat(chat, promptId);
const { functionCalls } = await promptIdContext.run(promptId, async () => const { functionCalls } = await promptIdContext.run(promptId, async () =>
this.callModel(chat, currentMessage, tools, combinedSignal, promptId), this.callModel(chat, currentMessage, combinedSignal, promptId),
); );
if (combinedSignal.aborted) { if (combinedSignal.aborted) {
@@ -272,7 +271,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
*/ */
private async executeFinalWarningTurn( private async executeFinalWarningTurn(
chat: GeminiChat, chat: GeminiChat,
tools: FunctionDeclaration[],
turnCounter: number, turnCounter: number,
reason: reason:
| AgentTerminateMode.TIMEOUT | AgentTerminateMode.TIMEOUT
@@ -309,7 +307,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
const turnResult = await this.executeTurn( const turnResult = await this.executeTurn(
chat, chat,
recoveryMessage, recoveryMessage,
tools,
turnCounter, // This will be the "last" turn number turnCounter, // This will be the "last" turn number
combinedSignal, combinedSignal,
graceTimeoutController.signal, // Pass grace signal to identify a *grace* timeout graceTimeoutController.signal, // Pass grace signal to identify a *grace* timeout
@@ -387,8 +384,8 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
let chat: GeminiChat | undefined; let chat: GeminiChat | undefined;
let tools: FunctionDeclaration[] | undefined; let tools: FunctionDeclaration[] | undefined;
try { try {
chat = await this.createChatObject(inputs);
tools = this.prepareToolsList(); tools = this.prepareToolsList();
chat = await this.createChatObject(inputs, tools);
const query = this.definition.promptConfig.query const query = this.definition.promptConfig.query
? templateString(this.definition.promptConfig.query, inputs) ? templateString(this.definition.promptConfig.query, inputs)
: 'Get Started!'; : 'Get Started!';
@@ -414,7 +411,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
const turnResult = await this.executeTurn( const turnResult = await this.executeTurn(
chat, chat,
currentMessage, currentMessage,
tools,
turnCounter++, turnCounter++,
combinedSignal, combinedSignal,
timeoutController.signal, timeoutController.signal,
@@ -443,7 +439,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
) { ) {
const recoveryResult = await this.executeFinalWarningTurn( const recoveryResult = await this.executeFinalWarningTurn(
chat, chat,
tools,
turnCounter, // Use current turnCounter for the recovery attempt turnCounter, // Use current turnCounter for the recovery attempt
terminateReason, terminateReason,
signal, // Pass the external signal signal, // Pass the external signal
@@ -509,7 +504,6 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
if (chat && tools) { if (chat && tools) {
const recoveryResult = await this.executeFinalWarningTurn( const recoveryResult = await this.executeFinalWarningTurn(
chat, chat,
tools,
turnCounter, // Use current turnCounter turnCounter, // Use current turnCounter
AgentTerminateMode.TIMEOUT, AgentTerminateMode.TIMEOUT,
signal, signal,
@@ -591,15 +585,9 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
private async callModel( private async callModel(
chat: GeminiChat, chat: GeminiChat,
message: Content, message: Content,
tools: FunctionDeclaration[],
signal: AbortSignal, signal: AbortSignal,
promptId: string, promptId: string,
): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> { ): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> {
if (tools.length > 0) {
// TODO(12622): Move tools back to config.
chat.setTools([{ functionDeclarations: tools }]);
}
const responseStream = await chat.sendMessageStream( const responseStream = await chat.sendMessageStream(
{ {
model: getModelConfigAlias(this.definition), model: getModelConfigAlias(this.definition),
@@ -650,7 +638,10 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
} }
/** Initializes a `GeminiChat` instance for the agent run. */ /** Initializes a `GeminiChat` instance for the agent run. */
private async createChatObject(inputs: AgentInputs): Promise<GeminiChat> { private async createChatObject(
inputs: AgentInputs,
tools: FunctionDeclaration[],
): Promise<GeminiChat> {
const { promptConfig } = this.definition; const { promptConfig } = this.definition;
if (!promptConfig.systemPrompt && !promptConfig.initialMessages) { if (!promptConfig.systemPrompt && !promptConfig.initialMessages) {
@@ -673,7 +664,7 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
return new GeminiChat( return new GeminiChat(
this.runtimeContext, this.runtimeContext,
systemInstruction, systemInstruction,
[], // set in `callModel`, [{ functionDeclarations: tools }],
startHistory, startHistory,
); );
} catch (error) { } catch (error) {