mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 13:22:35 -07:00
fix(core): resolve parallel tool call streaming ID collision (#26646)
This commit is contained in:
@@ -551,6 +551,112 @@ describe('GeminiChat', () => {
|
|||||||
expect(modelTurn.parts![1].functionCall).toBeDefined();
|
expect(modelTurn.parts![1].functionCall).toBeDefined();
|
||||||
expect(modelTurn.parts![2].text).toBe('This is the second part.');
|
expect(modelTurn.parts![2].text).toBe('This is the second part.');
|
||||||
});
|
});
|
||||||
|
it('repro: should not overwrite parallel tool calls when they arrive in separate streaming chunks', async () => {
|
||||||
|
vi.mocked(mockConfig.isContextManagementEnabled).mockReturnValue(true);
|
||||||
|
|
||||||
|
// 1. Mock the API to return parallel tool calls in separate chunks.
|
||||||
|
const parallelCallsStream = (async function* () {
|
||||||
|
yield {
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ functionCall: { name: 'tool_A' } }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functionCalls: [{ name: 'tool_A' }],
|
||||||
|
} as unknown as GenerateContentResponse;
|
||||||
|
yield {
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ functionCall: { name: 'tool_B' } }],
|
||||||
|
},
|
||||||
|
finishReason: 'STOP',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functionCalls: [{ name: 'tool_B' }],
|
||||||
|
} as unknown as GenerateContentResponse;
|
||||||
|
})();
|
||||||
|
|
||||||
|
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||||
|
parallelCallsStream,
|
||||||
|
);
|
||||||
|
|
||||||
|
// 2. Action: Send a message and consume the stream to trigger history recording.
|
||||||
|
const stream = await chat.sendMessageStream(
|
||||||
|
{ model: 'test-model' },
|
||||||
|
'test parallel tools',
|
||||||
|
'prompt-parallel-tools',
|
||||||
|
new AbortController().signal,
|
||||||
|
LlmRole.MAIN,
|
||||||
|
);
|
||||||
|
for await (const _ of stream) {
|
||||||
|
// Consume
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Assert: Check that the final history contains both function calls.
|
||||||
|
const history = chat.getHistory();
|
||||||
|
expect(history.length).toBe(2);
|
||||||
|
|
||||||
|
const modelTurn = history[1];
|
||||||
|
expect(modelTurn.role).toBe('model');
|
||||||
|
expect(modelTurn.parts?.length).toBe(2);
|
||||||
|
expect(modelTurn.parts![0].functionCall?.name).toBe('tool_A');
|
||||||
|
expect(modelTurn.parts![1].functionCall?.name).toBe('tool_B');
|
||||||
|
});
|
||||||
|
it('repro: should not collide when multiple tool calls with the same name arrive in the same chunk', async () => {
|
||||||
|
vi.mocked(mockConfig.isContextManagementEnabled).mockReturnValue(true);
|
||||||
|
|
||||||
|
const sameNameStream = (async function* () {
|
||||||
|
yield {
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [
|
||||||
|
{ functionCall: { name: 'tool_X', args: { id: 1 } } },
|
||||||
|
{ functionCall: { name: 'tool_X', args: { id: 2 } } },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finishReason: 'STOP',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
functionCalls: [
|
||||||
|
{ name: 'tool_X', args: { id: 1 } },
|
||||||
|
{ name: 'tool_X', args: { id: 2 } },
|
||||||
|
],
|
||||||
|
} as unknown as GenerateContentResponse;
|
||||||
|
})();
|
||||||
|
|
||||||
|
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
||||||
|
sameNameStream,
|
||||||
|
);
|
||||||
|
|
||||||
|
const stream = await chat.sendMessageStream(
|
||||||
|
{ model: 'test-model' },
|
||||||
|
'test same name tools',
|
||||||
|
'prompt-same-name',
|
||||||
|
new AbortController().signal,
|
||||||
|
LlmRole.MAIN,
|
||||||
|
);
|
||||||
|
for await (const _ of stream) {
|
||||||
|
// Consume the stream to trigger history recording
|
||||||
|
}
|
||||||
|
|
||||||
|
const history = chat.getHistory();
|
||||||
|
const modelTurn = history[1];
|
||||||
|
expect(modelTurn.parts?.length).toBe(2);
|
||||||
|
expect(modelTurn.parts![0].functionCall?.name).toBe('tool_X');
|
||||||
|
expect(modelTurn.parts![0].functionCall?.args).toEqual({ id: 1 });
|
||||||
|
expect(modelTurn.parts![1].functionCall?.name).toBe('tool_X');
|
||||||
|
expect(modelTurn.parts![1].functionCall?.args).toEqual({ id: 2 });
|
||||||
|
|
||||||
|
// If findIndex was used, both would likely point to index 0, and the second one might overwrite the first if consolidated incorrectly,
|
||||||
|
// or they both might end up with the same callIndex and thus the same args in final assembly.
|
||||||
|
});
|
||||||
it('should preserve text parts that stream in the same chunk as a thought', async () => {
|
it('should preserve text parts that stream in the same chunk as a thought', async () => {
|
||||||
// 1. Mock the API to return a single chunk containing both a thought and visible text.
|
// 1. Mock the API to return a single chunk containing both a thought and visible text.
|
||||||
const mixedContentStream = (async function* () {
|
const mixedContentStream = (async function* () {
|
||||||
|
|||||||
@@ -988,8 +988,10 @@ export class GeminiChat {
|
|||||||
|
|
||||||
// Map to track synthetic IDs assigned to each call index across chunks
|
// Map to track synthetic IDs assigned to each call index across chunks
|
||||||
const callIndexToId = new Map<number, string>();
|
const callIndexToId = new Map<number, string>();
|
||||||
|
let runningFunctionCallCounter = 0;
|
||||||
|
|
||||||
for await (const chunk of streamResponse) {
|
for await (const chunk of streamResponse) {
|
||||||
|
const currentChunkStartCounter = runningFunctionCallCounter;
|
||||||
const candidateWithReason = chunk?.candidates?.find(
|
const candidateWithReason = chunk?.candidates?.find(
|
||||||
(candidate) => candidate.finishReason,
|
(candidate) => candidate.finishReason,
|
||||||
);
|
);
|
||||||
@@ -1002,19 +1004,21 @@ export class GeminiChat {
|
|||||||
if (this.context.config.isContextManagementEnabled()) {
|
if (this.context.config.isContextManagementEnabled()) {
|
||||||
for (let i = 0; i < chunk.functionCalls.length; i++) {
|
for (let i = 0; i < chunk.functionCalls.length; i++) {
|
||||||
const fnCall = chunk.functionCalls[i];
|
const fnCall = chunk.functionCalls[i];
|
||||||
|
const globalIndex = currentChunkStartCounter + i;
|
||||||
if (!fnCall.id) {
|
if (!fnCall.id) {
|
||||||
let id = callIndexToId.get(i);
|
let id = callIndexToId.get(globalIndex);
|
||||||
if (!id) {
|
if (!id) {
|
||||||
id = `synth_${this.context.promptId}_${Date.now()}_${this.callCounter++}`;
|
id = `synth_${this.context.promptId}_${Date.now()}_${this.callCounter++}`;
|
||||||
callIndexToId.set(i, id);
|
callIndexToId.set(globalIndex, id);
|
||||||
debugLogger.log(
|
debugLogger.log(
|
||||||
`[GeminiChat] Assigned synthetic ID: ${id} to tool at index ${i}: ${fnCall.name}`,
|
`[GeminiChat] Assigned synthetic ID: ${id} to tool at index ${globalIndex}: ${fnCall.name}`,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
fnCall.id = id;
|
fnCall.id = id;
|
||||||
}
|
}
|
||||||
finalFunctionCallsMap.set(fnCall.id, fnCall);
|
finalFunctionCallsMap.set(fnCall.id, fnCall);
|
||||||
}
|
}
|
||||||
|
runningFunctionCallCounter += chunk.functionCalls.length;
|
||||||
} else {
|
} else {
|
||||||
legacyFunctionCalls.push(...chunk.functionCalls);
|
legacyFunctionCalls.push(...chunk.functionCalls);
|
||||||
}
|
}
|
||||||
@@ -1031,6 +1035,7 @@ export class GeminiChat {
|
|||||||
hasToolCall = true;
|
hasToolCall = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let localFunctionCallCounter = 0;
|
||||||
modelResponseParts.push(
|
modelResponseParts.push(
|
||||||
...content.parts
|
...content.parts
|
||||||
.filter((part) => !part.thought)
|
.filter((part) => !part.thought)
|
||||||
@@ -1038,11 +1043,14 @@ export class GeminiChat {
|
|||||||
if (!this.context.config.isContextManagementEnabled()) {
|
if (!this.context.config.isContextManagementEnabled()) {
|
||||||
return part;
|
return part;
|
||||||
}
|
}
|
||||||
|
let callIndex: number | undefined;
|
||||||
|
if (part.functionCall) {
|
||||||
|
callIndex =
|
||||||
|
currentChunkStartCounter + localFunctionCallCounter++;
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
...part,
|
...part,
|
||||||
callIndex: chunk.functionCalls?.findIndex(
|
callIndex,
|
||||||
(fc) => fc.name === part.functionCall?.name,
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user