Serialize function calls that mutate state (#8513)

This commit is contained in:
Tommaso Sciortino
2025-09-17 11:45:04 -07:00
committed by GitHub
parent efb57e1cef
commit e76dda37ad
4 changed files with 331 additions and 8 deletions

View File

@@ -22,6 +22,8 @@ import { setSimulate429 } from '../utils/testUtils.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { AuthType } from './contentGenerator.js';
import { type RetryOptions } from '../utils/retry.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
import { Kind } from '../tools/tools.js';
// Mock fs module to prevent actual file system operations during tests
const mockFileSystem = new Map<string, string>();
@@ -973,6 +975,259 @@ describe('GeminiChat', () => {
expect(turn4.parts[0].text).toBe('second response');
});
describe('stopBeforeSecondMutator', () => {
beforeEach(() => {
// Common setup for these tests: mock the tool registry.
const mockToolRegistry = {
getTool: vi.fn((toolName: string) => {
if (toolName === 'edit') {
return { kind: Kind.Edit };
}
return { kind: Kind.Other };
}),
} as unknown as ToolRegistry;
vi.mocked(mockConfig.getToolRegistry).mockReturnValue(mockToolRegistry);
});
it('should stop streaming before a second mutator tool call', async () => {
const responses = [
{
candidates: [
{ content: { role: 'model', parts: [{ text: 'First part. ' }] } },
],
},
{
candidates: [
{
content: {
role: 'model',
parts: [{ functionCall: { name: 'edit', args: {} } }],
},
},
],
},
{
candidates: [
{
content: {
role: 'model',
parts: [{ functionCall: { name: 'fetch', args: {} } }],
},
},
],
},
// This chunk contains the second mutator and should be clipped.
{
candidates: [
{
content: {
role: 'model',
parts: [
{ functionCall: { name: 'edit', args: {} } },
{ text: 'some trailing text' },
],
},
},
],
},
// This chunk should never be reached.
{
candidates: [
{
content: {
role: 'model',
parts: [{ text: 'This should not appear.' }],
},
},
],
},
] as unknown as GenerateContentResponse[];
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
(async function* () {
for (const response of responses) {
yield response;
}
})(),
);
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test message' },
'prompt-id-mutator-test',
);
for await (const _ of stream) {
// Consume the stream to trigger history recording.
}
const history = chat.getHistory();
expect(history.length).toBe(2);
const modelTurn = history[1]!;
expect(modelTurn.role).toBe('model');
expect(modelTurn?.parts?.length).toBe(3);
expect(modelTurn?.parts![0]!.text).toBe('First part. ');
expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit');
expect(modelTurn.parts![2]!.functionCall?.name).toBe('fetch');
});
it('should not stop streaming if only one mutator is present', async () => {
const responses = [
{
candidates: [
{ content: { role: 'model', parts: [{ text: 'Part 1. ' }] } },
],
},
{
candidates: [
{
content: {
role: 'model',
parts: [{ functionCall: { name: 'edit', args: {} } }],
},
},
],
},
{
candidates: [
{
content: {
role: 'model',
parts: [{ text: 'Part 2.' }],
},
finishReason: 'STOP',
},
],
},
] as unknown as GenerateContentResponse[];
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
(async function* () {
for (const response of responses) {
yield response;
}
})(),
);
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test message' },
'prompt-id-one-mutator',
);
for await (const _ of stream) {
/* consume */
}
const history = chat.getHistory();
const modelTurn = history[1]!;
expect(modelTurn?.parts?.length).toBe(3);
expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit');
expect(modelTurn.parts![2]!.text).toBe('Part 2.');
});
it('should clip the chunk containing the second mutator, preserving prior parts', async () => {
const responses = [
{
candidates: [
{
content: {
role: 'model',
parts: [{ functionCall: { name: 'edit', args: {} } }],
},
},
],
},
// This chunk has a valid part before the second mutator.
// The valid part should be kept, the rest of the chunk discarded.
{
candidates: [
{
content: {
role: 'model',
parts: [
{ text: 'Keep this text. ' },
{ functionCall: { name: 'edit', args: {} } },
{ text: 'Discard this text.' },
],
},
finishReason: 'STOP',
},
],
},
] as unknown as GenerateContentResponse[];
const stream = (async function* () {
for (const response of responses) {
yield response;
}
})();
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
stream,
);
const resultStream = await chat.sendMessageStream(
'test-model',
{ message: 'test' },
'prompt-id-clip-chunk',
);
for await (const _ of resultStream) {
/* consume */
}
const history = chat.getHistory();
const modelTurn = history[1]!;
expect(modelTurn?.parts?.length).toBe(2);
expect(modelTurn.parts![0]!.functionCall?.name).toBe('edit');
expect(modelTurn.parts![1]!.text).toBe('Keep this text. ');
});
it('should handle two mutators in the same chunk (parallel call scenario)', async () => {
const responses = [
{
candidates: [
{
content: {
role: 'model',
parts: [
{ text: 'Some text. ' },
{ functionCall: { name: 'edit', args: {} } },
{ functionCall: { name: 'edit', args: {} } },
],
},
finishReason: 'STOP',
},
],
},
] as unknown as GenerateContentResponse[];
const stream = (async function* () {
for (const response of responses) {
yield response;
}
})();
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
stream,
);
const resultStream = await chat.sendMessageStream(
'test-model',
{ message: 'test' },
'prompt-id-parallel-mutators',
);
for await (const _ of resultStream) {
/* consume */
}
const history = chat.getHistory();
const modelTurn = history[1]!;
expect(modelTurn?.parts?.length).toBe(2);
expect(modelTurn.parts![0]!.text).toBe('Some text. ');
expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit');
});
});
describe('Model Resolution', () => {
const mockResponse = {
candidates: [

View File

@@ -7,13 +7,14 @@
// DISCLAIMER: This is a copied version of https://github.com/googleapis/js-genai/blob/main/src/chats.ts with the intention of working around a key bug
// where function responses are not treated as "valid" responses: https://b.corp.google.com/issues/420354090
import type {
import {
GenerateContentResponse,
Content,
GenerateContentConfig,
SendMessageParameters,
Part,
Tool,
type Content,
type GenerateContentConfig,
type SendMessageParameters,
type Part,
type Tool,
FinishReason,
} from '@google/genai';
import { toParts } from '../code_assist/converter.js';
import { createUserContent } from '@google/genai';
@@ -23,7 +24,7 @@ import {
DEFAULT_GEMINI_FLASH_MODEL,
getEffectiveModel,
} from '../config/models.js';
import { hasCycleInSchema } from '../tools/tools.js';
import { hasCycleInSchema, MUTATOR_KINDS } from '../tools/tools.js';
import type { StructuredError } from './turn.js';
import {
logContentRetry,
@@ -495,7 +496,7 @@ export class GeminiChat {
let lastChunk: GenerateContentResponse | null = null;
let lastChunkIsInvalid = false;
for await (const chunk of streamResponse) {
for await (const chunk of this.stopBeforeSecondMutator(streamResponse)) {
hasReceivedAnyChunk = true;
lastChunk = chunk;
@@ -621,6 +622,64 @@ export class GeminiChat {
});
}
}
/**
* Truncates the chunkStream right before the second function call to a
* function that mutates state. This may involve trimming parts from a chunk
* as well as omtting some chunks altogether.
*
* We do this because it improves tool call quality if the model gets
* feedback from one mutating function call before it makes the next one.
*/
private async *stopBeforeSecondMutator(
chunkStream: AsyncGenerator<GenerateContentResponse>,
): AsyncGenerator<GenerateContentResponse> {
let foundMutatorFunctionCall = false;
for await (const chunk of chunkStream) {
const candidate = chunk.candidates?.[0];
const content = candidate?.content;
if (!candidate || !content?.parts) {
yield chunk;
continue;
}
const truncatedParts: Part[] = [];
for (const part of content.parts) {
if (this.isMutatorFunctionCall(part)) {
if (foundMutatorFunctionCall) {
// This is the second mutator call.
// Truncate and return immedaitely.
const newChunk = new GenerateContentResponse();
newChunk.candidates = [
{
...candidate,
content: {
...content,
parts: truncatedParts,
},
finishReason: FinishReason.STOP,
},
];
yield newChunk;
return;
}
foundMutatorFunctionCall = true;
}
truncatedParts.push(part);
}
yield chunk;
}
}
private isMutatorFunctionCall(part: Part): boolean {
if (!part?.functionCall?.name) {
return false;
}
const tool = this.config.getToolRegistry().getTool(part.functionCall.name);
return !!tool && MUTATOR_KINDS.includes(tool.kind);
}
}
/** Visible for Testing */

View File

@@ -167,6 +167,7 @@ Signal: Signal number or \`(none)\` if no signal was received.
}
export class ToolRegistry {
// The tools keyed by tool name as seen by the LLM.
private tools: Map<string, AnyDeclarativeTool> = new Map();
private config: Config;
private mcpClientManager: McpClientManager;

View File

@@ -532,6 +532,14 @@ export enum Kind {
Other = 'other',
}
// Function kinds that have side effects
export const MUTATOR_KINDS: Kind[] = [
Kind.Edit,
Kind.Delete,
Kind.Move,
Kind.Execute,
] as const;
export interface ToolLocation {
// Absolute path to the file
path: string;