From e76dda37add503d239856ba8fb421f07455bd3bb Mon Sep 17 00:00:00 2001 From: Tommaso Sciortino Date: Wed, 17 Sep 2025 11:45:04 -0700 Subject: [PATCH] Serialize function calls that mutate state (#8513) --- packages/core/src/core/geminiChat.test.ts | 255 ++++++++++++++++++++++ packages/core/src/core/geminiChat.ts | 75 ++++++- packages/core/src/tools/tool-registry.ts | 1 + packages/core/src/tools/tools.ts | 8 + 4 files changed, 331 insertions(+), 8 deletions(-) diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index ea783d9211..94db6f05c6 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -22,6 +22,8 @@ import { setSimulate429 } from '../utils/testUtils.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { AuthType } from './contentGenerator.js'; import { type RetryOptions } from '../utils/retry.js'; +import type { ToolRegistry } from '../tools/tool-registry.js'; +import { Kind } from '../tools/tools.js'; // Mock fs module to prevent actual file system operations during tests const mockFileSystem = new Map(); @@ -973,6 +975,259 @@ describe('GeminiChat', () => { expect(turn4.parts[0].text).toBe('second response'); }); + describe('stopBeforeSecondMutator', () => { + beforeEach(() => { + // Common setup for these tests: mock the tool registry. + const mockToolRegistry = { + getTool: vi.fn((toolName: string) => { + if (toolName === 'edit') { + return { kind: Kind.Edit }; + } + return { kind: Kind.Other }; + }), + } as unknown as ToolRegistry; + vi.mocked(mockConfig.getToolRegistry).mockReturnValue(mockToolRegistry); + }); + + it('should stop streaming before a second mutator tool call', async () => { + const responses = [ + { + candidates: [ + { content: { role: 'model', parts: [{ text: 'First part. ' }] } }, + ], + }, + { + candidates: [ + { + content: { + role: 'model', + parts: [{ functionCall: { name: 'edit', args: {} } }], + }, + }, + ], + }, + { + candidates: [ + { + content: { + role: 'model', + parts: [{ functionCall: { name: 'fetch', args: {} } }], + }, + }, + ], + }, + // This chunk contains the second mutator and should be clipped. + { + candidates: [ + { + content: { + role: 'model', + parts: [ + { functionCall: { name: 'edit', args: {} } }, + { text: 'some trailing text' }, + ], + }, + }, + ], + }, + // This chunk should never be reached. + { + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'This should not appear.' }], + }, + }, + ], + }, + ] as unknown as GenerateContentResponse[]; + + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + (async function* () { + for (const response of responses) { + yield response; + } + })(), + ); + + const stream = await chat.sendMessageStream( + 'test-model', + { message: 'test message' }, + 'prompt-id-mutator-test', + ); + for await (const _ of stream) { + // Consume the stream to trigger history recording. + } + + const history = chat.getHistory(); + expect(history.length).toBe(2); + + const modelTurn = history[1]!; + expect(modelTurn.role).toBe('model'); + expect(modelTurn?.parts?.length).toBe(3); + expect(modelTurn?.parts![0]!.text).toBe('First part. '); + expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit'); + expect(modelTurn.parts![2]!.functionCall?.name).toBe('fetch'); + }); + + it('should not stop streaming if only one mutator is present', async () => { + const responses = [ + { + candidates: [ + { content: { role: 'model', parts: [{ text: 'Part 1. ' }] } }, + ], + }, + { + candidates: [ + { + content: { + role: 'model', + parts: [{ functionCall: { name: 'edit', args: {} } }], + }, + }, + ], + }, + { + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'Part 2.' }], + }, + finishReason: 'STOP', + }, + ], + }, + ] as unknown as GenerateContentResponse[]; + + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + (async function* () { + for (const response of responses) { + yield response; + } + })(), + ); + + const stream = await chat.sendMessageStream( + 'test-model', + { message: 'test message' }, + 'prompt-id-one-mutator', + ); + for await (const _ of stream) { + /* consume */ + } + + const history = chat.getHistory(); + const modelTurn = history[1]!; + expect(modelTurn?.parts?.length).toBe(3); + expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit'); + expect(modelTurn.parts![2]!.text).toBe('Part 2.'); + }); + + it('should clip the chunk containing the second mutator, preserving prior parts', async () => { + const responses = [ + { + candidates: [ + { + content: { + role: 'model', + parts: [{ functionCall: { name: 'edit', args: {} } }], + }, + }, + ], + }, + // This chunk has a valid part before the second mutator. + // The valid part should be kept, the rest of the chunk discarded. + { + candidates: [ + { + content: { + role: 'model', + parts: [ + { text: 'Keep this text. ' }, + { functionCall: { name: 'edit', args: {} } }, + { text: 'Discard this text.' }, + ], + }, + finishReason: 'STOP', + }, + ], + }, + ] as unknown as GenerateContentResponse[]; + + const stream = (async function* () { + for (const response of responses) { + yield response; + } + })(); + + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + stream, + ); + + const resultStream = await chat.sendMessageStream( + 'test-model', + { message: 'test' }, + 'prompt-id-clip-chunk', + ); + for await (const _ of resultStream) { + /* consume */ + } + + const history = chat.getHistory(); + const modelTurn = history[1]!; + expect(modelTurn?.parts?.length).toBe(2); + expect(modelTurn.parts![0]!.functionCall?.name).toBe('edit'); + expect(modelTurn.parts![1]!.text).toBe('Keep this text. '); + }); + + it('should handle two mutators in the same chunk (parallel call scenario)', async () => { + const responses = [ + { + candidates: [ + { + content: { + role: 'model', + parts: [ + { text: 'Some text. ' }, + { functionCall: { name: 'edit', args: {} } }, + { functionCall: { name: 'edit', args: {} } }, + ], + }, + finishReason: 'STOP', + }, + ], + }, + ] as unknown as GenerateContentResponse[]; + + const stream = (async function* () { + for (const response of responses) { + yield response; + } + })(); + + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + stream, + ); + + const resultStream = await chat.sendMessageStream( + 'test-model', + { message: 'test' }, + 'prompt-id-parallel-mutators', + ); + for await (const _ of resultStream) { + /* consume */ + } + + const history = chat.getHistory(); + const modelTurn = history[1]!; + expect(modelTurn?.parts?.length).toBe(2); + expect(modelTurn.parts![0]!.text).toBe('Some text. '); + expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit'); + }); + }); + describe('Model Resolution', () => { const mockResponse = { candidates: [ diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 7db86ace83..71bc04d6bd 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -7,13 +7,14 @@ // DISCLAIMER: This is a copied version of https://github.com/googleapis/js-genai/blob/main/src/chats.ts with the intention of working around a key bug // where function responses are not treated as "valid" responses: https://b.corp.google.com/issues/420354090 -import type { +import { GenerateContentResponse, - Content, - GenerateContentConfig, - SendMessageParameters, - Part, - Tool, + type Content, + type GenerateContentConfig, + type SendMessageParameters, + type Part, + type Tool, + FinishReason, } from '@google/genai'; import { toParts } from '../code_assist/converter.js'; import { createUserContent } from '@google/genai'; @@ -23,7 +24,7 @@ import { DEFAULT_GEMINI_FLASH_MODEL, getEffectiveModel, } from '../config/models.js'; -import { hasCycleInSchema } from '../tools/tools.js'; +import { hasCycleInSchema, MUTATOR_KINDS } from '../tools/tools.js'; import type { StructuredError } from './turn.js'; import { logContentRetry, @@ -495,7 +496,7 @@ export class GeminiChat { let lastChunk: GenerateContentResponse | null = null; let lastChunkIsInvalid = false; - for await (const chunk of streamResponse) { + for await (const chunk of this.stopBeforeSecondMutator(streamResponse)) { hasReceivedAnyChunk = true; lastChunk = chunk; @@ -621,6 +622,64 @@ export class GeminiChat { }); } } + + /** + * Truncates the chunkStream right before the second function call to a + * function that mutates state. This may involve trimming parts from a chunk + * as well as omtting some chunks altogether. + * + * We do this because it improves tool call quality if the model gets + * feedback from one mutating function call before it makes the next one. + */ + private async *stopBeforeSecondMutator( + chunkStream: AsyncGenerator, + ): AsyncGenerator { + let foundMutatorFunctionCall = false; + + for await (const chunk of chunkStream) { + const candidate = chunk.candidates?.[0]; + const content = candidate?.content; + if (!candidate || !content?.parts) { + yield chunk; + continue; + } + + const truncatedParts: Part[] = []; + for (const part of content.parts) { + if (this.isMutatorFunctionCall(part)) { + if (foundMutatorFunctionCall) { + // This is the second mutator call. + // Truncate and return immedaitely. + const newChunk = new GenerateContentResponse(); + newChunk.candidates = [ + { + ...candidate, + content: { + ...content, + parts: truncatedParts, + }, + finishReason: FinishReason.STOP, + }, + ]; + yield newChunk; + return; + } + foundMutatorFunctionCall = true; + } + truncatedParts.push(part); + } + + yield chunk; + } + } + + private isMutatorFunctionCall(part: Part): boolean { + if (!part?.functionCall?.name) { + return false; + } + const tool = this.config.getToolRegistry().getTool(part.functionCall.name); + return !!tool && MUTATOR_KINDS.includes(tool.kind); + } } /** Visible for Testing */ diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index c4cb46e9b2..627412f38a 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -167,6 +167,7 @@ Signal: Signal number or \`(none)\` if no signal was received. } export class ToolRegistry { + // The tools keyed by tool name as seen by the LLM. private tools: Map = new Map(); private config: Config; private mcpClientManager: McpClientManager; diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 5c51dba755..48cf2d2d1f 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -532,6 +532,14 @@ export enum Kind { Other = 'other', } +// Function kinds that have side effects +export const MUTATOR_KINDS: Kind[] = [ + Kind.Edit, + Kind.Delete, + Kind.Move, + Kind.Execute, +] as const; + export interface ToolLocation { // Absolute path to the file path: string;