mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-03 00:14:28 -07:00
Serialize function calls that mutate state (#8513)
This commit is contained in:
committed by
GitHub
parent
efb57e1cef
commit
e76dda37ad
@@ -22,6 +22,8 @@ import { setSimulate429 } from '../utils/testUtils.js';
|
|||||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||||
import { AuthType } from './contentGenerator.js';
|
import { AuthType } from './contentGenerator.js';
|
||||||
import { type RetryOptions } from '../utils/retry.js';
|
import { type RetryOptions } from '../utils/retry.js';
|
||||||
|
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||||
|
import { Kind } from '../tools/tools.js';
|
||||||
|
|
||||||
// Mock fs module to prevent actual file system operations during tests
|
// Mock fs module to prevent actual file system operations during tests
|
||||||
const mockFileSystem = new Map<string, string>();
|
const mockFileSystem = new Map<string, string>();
|
||||||
@@ -973,6 +975,259 @@ describe('GeminiChat', () => {
|
|||||||
expect(turn4.parts[0].text).toBe('second response');
|
expect(turn4.parts[0].text).toBe('second response');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('stopBeforeSecondMutator', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
// Common setup for these tests: mock the tool registry.
|
||||||
|
const mockToolRegistry = {
|
||||||
|
getTool: vi.fn((toolName: string) => {
|
||||||
|
if (toolName === 'edit') {
|
||||||
|
return { kind: Kind.Edit };
|
||||||
|
}
|
||||||
|
return { kind: Kind.Other };
|
||||||
|
}),
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
vi.mocked(mockConfig.getToolRegistry).mockReturnValue(mockToolRegistry);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should stop streaming before a second mutator tool call', async () => {
|
||||||
|
const responses = [
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{ content: { role: 'model', parts: [{ text: 'First part. ' }] } },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ functionCall: { name: 'edit', args: {} } }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ functionCall: { name: 'fetch', args: {} } }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
// This chunk contains the second mutator and should be clipped.
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [
|
||||||
|
{ functionCall: { name: 'edit', args: {} } },
|
||||||
|
{ text: 'some trailing text' },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
// This chunk should never be reached.
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'This should not appear.' }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
] as unknown as GenerateContentResponse[];
|
||||||
|
|
||||||
|
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||||
|
(async function* () {
|
||||||
|
for (const response of responses) {
|
||||||
|
yield response;
|
||||||
|
}
|
||||||
|
})(),
|
||||||
|
);
|
||||||
|
|
||||||
|
const stream = await chat.sendMessageStream(
|
||||||
|
'test-model',
|
||||||
|
{ message: 'test message' },
|
||||||
|
'prompt-id-mutator-test',
|
||||||
|
);
|
||||||
|
for await (const _ of stream) {
|
||||||
|
// Consume the stream to trigger history recording.
|
||||||
|
}
|
||||||
|
|
||||||
|
const history = chat.getHistory();
|
||||||
|
expect(history.length).toBe(2);
|
||||||
|
|
||||||
|
const modelTurn = history[1]!;
|
||||||
|
expect(modelTurn.role).toBe('model');
|
||||||
|
expect(modelTurn?.parts?.length).toBe(3);
|
||||||
|
expect(modelTurn?.parts![0]!.text).toBe('First part. ');
|
||||||
|
expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit');
|
||||||
|
expect(modelTurn.parts![2]!.functionCall?.name).toBe('fetch');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not stop streaming if only one mutator is present', async () => {
|
||||||
|
const responses = [
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{ content: { role: 'model', parts: [{ text: 'Part 1. ' }] } },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ functionCall: { name: 'edit', args: {} } }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Part 2.' }],
|
||||||
|
},
|
||||||
|
finishReason: 'STOP',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
] as unknown as GenerateContentResponse[];
|
||||||
|
|
||||||
|
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||||
|
(async function* () {
|
||||||
|
for (const response of responses) {
|
||||||
|
yield response;
|
||||||
|
}
|
||||||
|
})(),
|
||||||
|
);
|
||||||
|
|
||||||
|
const stream = await chat.sendMessageStream(
|
||||||
|
'test-model',
|
||||||
|
{ message: 'test message' },
|
||||||
|
'prompt-id-one-mutator',
|
||||||
|
);
|
||||||
|
for await (const _ of stream) {
|
||||||
|
/* consume */
|
||||||
|
}
|
||||||
|
|
||||||
|
const history = chat.getHistory();
|
||||||
|
const modelTurn = history[1]!;
|
||||||
|
expect(modelTurn?.parts?.length).toBe(3);
|
||||||
|
expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit');
|
||||||
|
expect(modelTurn.parts![2]!.text).toBe('Part 2.');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should clip the chunk containing the second mutator, preserving prior parts', async () => {
|
||||||
|
const responses = [
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ functionCall: { name: 'edit', args: {} } }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
// This chunk has a valid part before the second mutator.
|
||||||
|
// The valid part should be kept, the rest of the chunk discarded.
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [
|
||||||
|
{ text: 'Keep this text. ' },
|
||||||
|
{ functionCall: { name: 'edit', args: {} } },
|
||||||
|
{ text: 'Discard this text.' },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finishReason: 'STOP',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
] as unknown as GenerateContentResponse[];
|
||||||
|
|
||||||
|
const stream = (async function* () {
|
||||||
|
for (const response of responses) {
|
||||||
|
yield response;
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
|
||||||
|
const resultStream = await chat.sendMessageStream(
|
||||||
|
'test-model',
|
||||||
|
{ message: 'test' },
|
||||||
|
'prompt-id-clip-chunk',
|
||||||
|
);
|
||||||
|
for await (const _ of resultStream) {
|
||||||
|
/* consume */
|
||||||
|
}
|
||||||
|
|
||||||
|
const history = chat.getHistory();
|
||||||
|
const modelTurn = history[1]!;
|
||||||
|
expect(modelTurn?.parts?.length).toBe(2);
|
||||||
|
expect(modelTurn.parts![0]!.functionCall?.name).toBe('edit');
|
||||||
|
expect(modelTurn.parts![1]!.text).toBe('Keep this text. ');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle two mutators in the same chunk (parallel call scenario)', async () => {
|
||||||
|
const responses = [
|
||||||
|
{
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [
|
||||||
|
{ text: 'Some text. ' },
|
||||||
|
{ functionCall: { name: 'edit', args: {} } },
|
||||||
|
{ functionCall: { name: 'edit', args: {} } },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finishReason: 'STOP',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
] as unknown as GenerateContentResponse[];
|
||||||
|
|
||||||
|
const stream = (async function* () {
|
||||||
|
for (const response of responses) {
|
||||||
|
yield response;
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||||
|
stream,
|
||||||
|
);
|
||||||
|
|
||||||
|
const resultStream = await chat.sendMessageStream(
|
||||||
|
'test-model',
|
||||||
|
{ message: 'test' },
|
||||||
|
'prompt-id-parallel-mutators',
|
||||||
|
);
|
||||||
|
for await (const _ of resultStream) {
|
||||||
|
/* consume */
|
||||||
|
}
|
||||||
|
|
||||||
|
const history = chat.getHistory();
|
||||||
|
const modelTurn = history[1]!;
|
||||||
|
expect(modelTurn?.parts?.length).toBe(2);
|
||||||
|
expect(modelTurn.parts![0]!.text).toBe('Some text. ');
|
||||||
|
expect(modelTurn.parts![1]!.functionCall?.name).toBe('edit');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('Model Resolution', () => {
|
describe('Model Resolution', () => {
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
candidates: [
|
candidates: [
|
||||||
|
|||||||
@@ -7,13 +7,14 @@
|
|||||||
// DISCLAIMER: This is a copied version of https://github.com/googleapis/js-genai/blob/main/src/chats.ts with the intention of working around a key bug
|
// DISCLAIMER: This is a copied version of https://github.com/googleapis/js-genai/blob/main/src/chats.ts with the intention of working around a key bug
|
||||||
// where function responses are not treated as "valid" responses: https://b.corp.google.com/issues/420354090
|
// where function responses are not treated as "valid" responses: https://b.corp.google.com/issues/420354090
|
||||||
|
|
||||||
import type {
|
import {
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
Content,
|
type Content,
|
||||||
GenerateContentConfig,
|
type GenerateContentConfig,
|
||||||
SendMessageParameters,
|
type SendMessageParameters,
|
||||||
Part,
|
type Part,
|
||||||
Tool,
|
type Tool,
|
||||||
|
FinishReason,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { toParts } from '../code_assist/converter.js';
|
import { toParts } from '../code_assist/converter.js';
|
||||||
import { createUserContent } from '@google/genai';
|
import { createUserContent } from '@google/genai';
|
||||||
@@ -23,7 +24,7 @@ import {
|
|||||||
DEFAULT_GEMINI_FLASH_MODEL,
|
DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
getEffectiveModel,
|
getEffectiveModel,
|
||||||
} from '../config/models.js';
|
} from '../config/models.js';
|
||||||
import { hasCycleInSchema } from '../tools/tools.js';
|
import { hasCycleInSchema, MUTATOR_KINDS } from '../tools/tools.js';
|
||||||
import type { StructuredError } from './turn.js';
|
import type { StructuredError } from './turn.js';
|
||||||
import {
|
import {
|
||||||
logContentRetry,
|
logContentRetry,
|
||||||
@@ -495,7 +496,7 @@ export class GeminiChat {
|
|||||||
let lastChunk: GenerateContentResponse | null = null;
|
let lastChunk: GenerateContentResponse | null = null;
|
||||||
let lastChunkIsInvalid = false;
|
let lastChunkIsInvalid = false;
|
||||||
|
|
||||||
for await (const chunk of streamResponse) {
|
for await (const chunk of this.stopBeforeSecondMutator(streamResponse)) {
|
||||||
hasReceivedAnyChunk = true;
|
hasReceivedAnyChunk = true;
|
||||||
lastChunk = chunk;
|
lastChunk = chunk;
|
||||||
|
|
||||||
@@ -621,6 +622,64 @@ export class GeminiChat {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Truncates the chunkStream right before the second function call to a
|
||||||
|
* function that mutates state. This may involve trimming parts from a chunk
|
||||||
|
* as well as omitting some chunks altogether.
|
||||||
|
*
|
||||||
|
* We do this because it improves tool call quality if the model gets
|
||||||
|
* feedback from one mutating function call before it makes the next one.
|
||||||
|
*/
|
||||||
|
private async *stopBeforeSecondMutator(
|
||||||
|
chunkStream: AsyncGenerator<GenerateContentResponse>,
|
||||||
|
): AsyncGenerator<GenerateContentResponse> {
|
||||||
|
let foundMutatorFunctionCall = false;
|
||||||
|
|
||||||
|
for await (const chunk of chunkStream) {
|
||||||
|
const candidate = chunk.candidates?.[0];
|
||||||
|
const content = candidate?.content;
|
||||||
|
if (!candidate || !content?.parts) {
|
||||||
|
yield chunk;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const truncatedParts: Part[] = [];
|
||||||
|
for (const part of content.parts) {
|
||||||
|
if (this.isMutatorFunctionCall(part)) {
|
||||||
|
if (foundMutatorFunctionCall) {
|
||||||
|
// This is the second mutator call.
|
||||||
|
// Truncate and return immediately.
|
||||||
|
const newChunk = new GenerateContentResponse();
|
||||||
|
newChunk.candidates = [
|
||||||
|
{
|
||||||
|
...candidate,
|
||||||
|
content: {
|
||||||
|
...content,
|
||||||
|
parts: truncatedParts,
|
||||||
|
},
|
||||||
|
finishReason: FinishReason.STOP,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
yield newChunk;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
foundMutatorFunctionCall = true;
|
||||||
|
}
|
||||||
|
truncatedParts.push(part);
|
||||||
|
}
|
||||||
|
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private isMutatorFunctionCall(part: Part): boolean {
|
||||||
|
if (!part?.functionCall?.name) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const tool = this.config.getToolRegistry().getTool(part.functionCall.name);
|
||||||
|
return !!tool && MUTATOR_KINDS.includes(tool.kind);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Visible for Testing */
|
/** Visible for Testing */
|
||||||
|
|||||||
@@ -167,6 +167,7 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
|||||||
}
|
}
|
||||||
|
|
||||||
export class ToolRegistry {
|
export class ToolRegistry {
|
||||||
|
// The tools keyed by tool name as seen by the LLM.
|
||||||
private tools: Map<string, AnyDeclarativeTool> = new Map();
|
private tools: Map<string, AnyDeclarativeTool> = new Map();
|
||||||
private config: Config;
|
private config: Config;
|
||||||
private mcpClientManager: McpClientManager;
|
private mcpClientManager: McpClientManager;
|
||||||
|
|||||||
@@ -532,6 +532,14 @@ export enum Kind {
|
|||||||
Other = 'other',
|
Other = 'other',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function kinds that have side effects
|
||||||
|
export const MUTATOR_KINDS: Kind[] = [
|
||||||
|
Kind.Edit,
|
||||||
|
Kind.Delete,
|
||||||
|
Kind.Move,
|
||||||
|
Kind.Execute,
|
||||||
|
] as const;
|
||||||
|
|
||||||
export interface ToolLocation {
|
export interface ToolLocation {
|
||||||
// Absolute path to the file
|
// Absolute path to the file
|
||||||
path: string;
|
path: string;
|
||||||
|
|||||||
Reference in New Issue
Block a user