fix(core): resolve parallel tool call streaming ID collision (#26646)

Aishanee Shah
2026-05-08 15:14:23 -04:00
committed by GitHub
parent 6b9b778d82
commit 5890f50496
2 changed files with 120 additions and 6 deletions
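
What the diff below addresses: when context management is enabled, GeminiChat assigns synthetic IDs to tool calls that arrive without one, tracked in a callIndexToId map. Before this fix the map was keyed by each call's index within its chunk, so parallel tool calls arriving in separate streaming chunks all landed on key 0 and shared one ID. A minimal standalone TypeScript sketch of that failure mode (hypothetical assignIdsPerChunk helper and simplified synth_N IDs; the real IDs also embed the prompt ID, a timestamp, and a call counter):

// Illustrative sketch only, not code from this repository.
const callIndexToId = new Map<number, string>();
let nextSynth = 0;

function assignIdsPerChunk(chunkCalls: Array<{ name: string; id?: string }>) {
  for (let i = 0; i < chunkCalls.length; i++) {
    // BUG: `i` restarts at 0 for every chunk, so distinct calls in
    // different chunks look up (and reuse) the same map entry.
    let id = callIndexToId.get(i);
    if (!id) {
      id = `synth_${nextSynth++}`;
      callIndexToId.set(i, id);
    }
    chunkCalls[i].id = id;
  }
}

assignIdsPerChunk([{ name: 'tool_A' }]); // tool_A -> synth_0
assignIdsPerChunk([{ name: 'tool_B' }]); // tool_B -> synth_0 as well: collision
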
+106
@@ -551,6 +551,112 @@ describe('GeminiChat', () => {
     expect(modelTurn.parts![1].functionCall).toBeDefined();
     expect(modelTurn.parts![2].text).toBe('This is the second part.');
   });
+  it('repro: should not overwrite parallel tool calls when they arrive in separate streaming chunks', async () => {
+    vi.mocked(mockConfig.isContextManagementEnabled).mockReturnValue(true);
+
+    // 1. Mock the API to return parallel tool calls in separate chunks.
+    const parallelCallsStream = (async function* () {
+      yield {
+        candidates: [
+          {
+            content: {
+              role: 'model',
+              parts: [{ functionCall: { name: 'tool_A' } }],
+            },
+          },
+        ],
+        functionCalls: [{ name: 'tool_A' }],
+      } as unknown as GenerateContentResponse;
+      yield {
+        candidates: [
+          {
+            content: {
+              role: 'model',
+              parts: [{ functionCall: { name: 'tool_B' } }],
+            },
+            finishReason: 'STOP',
+          },
+        ],
+        functionCalls: [{ name: 'tool_B' }],
+      } as unknown as GenerateContentResponse;
+    })();
+    vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
+      parallelCallsStream,
+    );
+
+    // 2. Action: Send a message and consume the stream to trigger history recording.
+    const stream = await chat.sendMessageStream(
+      { model: 'test-model' },
+      'test parallel tools',
+      'prompt-parallel-tools',
+      new AbortController().signal,
+      LlmRole.MAIN,
+    );
+    for await (const _ of stream) {
+      // Consume
+    }
+
+    // 3. Assert: Check that the final history contains both function calls.
+    const history = chat.getHistory();
+    expect(history.length).toBe(2);
+    const modelTurn = history[1];
+    expect(modelTurn.role).toBe('model');
+    expect(modelTurn.parts?.length).toBe(2);
+    expect(modelTurn.parts![0].functionCall?.name).toBe('tool_A');
+    expect(modelTurn.parts![1].functionCall?.name).toBe('tool_B');
+  });
+  it('repro: should not collide when multiple tool calls with the same name arrive in the same chunk', async () => {
+    vi.mocked(mockConfig.isContextManagementEnabled).mockReturnValue(true);
+
+    const sameNameStream = (async function* () {
+      yield {
+        candidates: [
+          {
+            content: {
+              role: 'model',
+              parts: [
+                { functionCall: { name: 'tool_X', args: { id: 1 } } },
+                { functionCall: { name: 'tool_X', args: { id: 2 } } },
+              ],
+            },
+            finishReason: 'STOP',
+          },
+        ],
+        functionCalls: [
+          { name: 'tool_X', args: { id: 1 } },
+          { name: 'tool_X', args: { id: 2 } },
+        ],
+      } as unknown as GenerateContentResponse;
+    })();
+    vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
+      sameNameStream,
+    );
+
+    const stream = await chat.sendMessageStream(
+      { model: 'test-model' },
+      'test same name tools',
+      'prompt-same-name',
+      new AbortController().signal,
+      LlmRole.MAIN,
+    );
+    for await (const _ of stream) {
+      // Consume the stream to trigger history recording
+    }
+
+    const history = chat.getHistory();
+    const modelTurn = history[1];
+    expect(modelTurn.parts?.length).toBe(2);
+    expect(modelTurn.parts![0].functionCall?.name).toBe('tool_X');
+    expect(modelTurn.parts![0].functionCall?.args).toEqual({ id: 1 });
+    expect(modelTurn.parts![1].functionCall?.name).toBe('tool_X');
+    expect(modelTurn.parts![1].functionCall?.args).toEqual({ id: 2 });
+    // With a name-based findIndex, both calls resolve to index 0, so they
+    // would share the same callIndex and the second call's args could
+    // overwrite the first's during final assembly.
+  });
   it('should preserve text parts that stream in the same chunk as a thought', async () => {
     // 1. Mock the API to return a single chunk containing both a thought and visible text.
     const mixedContentStream = (async function* () {
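
The trailing comment in the same-name test points at the second collision source: the old code derived each part's callIndex via a name-based findIndex, which cannot distinguish two calls that share a name. A tiny standalone TypeScript sketch of that behavior (illustrative data, not repository code):

// Illustrative sketch only.
const functionCalls = [
  { name: 'tool_X', args: { id: 1 } },
  { name: 'tool_X', args: { id: 2 } },
];
const parts = functionCalls.map((fc) => ({ functionCall: fc }));

// findIndex returns the index of the FIRST matching name, so both
// parts resolve to 0 and end up with the same callIndex.
const indices = parts.map((part) =>
  functionCalls.findIndex((fc) => fc.name === part.functionCall.name),
);
console.log(indices); // [0, 0]
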
+14 -6
@@ -988,8 +988,10 @@ export class GeminiChat {
     // Map to track synthetic IDs assigned to each call index across chunks
     const callIndexToId = new Map<number, string>();
+    let runningFunctionCallCounter = 0;
     for await (const chunk of streamResponse) {
+      const currentChunkStartCounter = runningFunctionCallCounter;
       const candidateWithReason = chunk?.candidates?.find(
         (candidate) => candidate.finishReason,
       );
@@ -1002,19 +1004,21 @@
       if (this.context.config.isContextManagementEnabled()) {
         for (let i = 0; i < chunk.functionCalls.length; i++) {
           const fnCall = chunk.functionCalls[i];
+          const globalIndex = currentChunkStartCounter + i;
           if (!fnCall.id) {
-            let id = callIndexToId.get(i);
+            let id = callIndexToId.get(globalIndex);
             if (!id) {
               id = `synth_${this.context.promptId}_${Date.now()}_${this.callCounter++}`;
-              callIndexToId.set(i, id);
+              callIndexToId.set(globalIndex, id);
               debugLogger.log(
-                `[GeminiChat] Assigned synthetic ID: ${id} to tool at index ${i}: ${fnCall.name}`,
+                `[GeminiChat] Assigned synthetic ID: ${id} to tool at index ${globalIndex}: ${fnCall.name}`,
               );
             }
             fnCall.id = id;
           }
           finalFunctionCallsMap.set(fnCall.id, fnCall);
         }
+        runningFunctionCallCounter += chunk.functionCalls.length;
       } else {
         legacyFunctionCalls.push(...chunk.functionCalls);
       }
@@ -1031,6 +1035,7 @@
         hasToolCall = true;
       }
+      let localFunctionCallCounter = 0;
       modelResponseParts.push(
         ...content.parts
           .filter((part) => !part.thought)
@@ -1038,11 +1043,14 @@
             if (!this.context.config.isContextManagementEnabled()) {
               return part;
             }
+            let callIndex: number | undefined;
+            if (part.functionCall) {
+              callIndex =
+                currentChunkStartCounter + localFunctionCallCounter++;
+            }
             return {
               ...part,
-              callIndex: chunk.functionCalls?.findIndex(
-                (fc) => fc.name === part.functionCall?.name,
-              ),
+              callIndex,
             };
           }),
       );
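
Taken together, the fix replaces per-chunk-index and per-name bookkeeping with a stream-global running counter: each chunk records its starting offset, every call gets globalIndex = chunkStart + positionInChunk, and the counter advances by the chunk's call count. A condensed standalone sketch of the scheme under the same simplifications as above (hypothetical assignIds helper and synth_N IDs):

// Illustrative sketch only.
const callIndexToId = new Map<number, string>();
let runningFunctionCallCounter = 0;
let nextSynth = 0;

function assignIds(chunkCalls: Array<{ name: string; id?: string }>) {
  const chunkStart = runningFunctionCallCounter;
  chunkCalls.forEach((call, i) => {
    const globalIndex = chunkStart + i; // never repeats across chunks
    let id = callIndexToId.get(globalIndex);
    if (!id) {
      id = `synth_${nextSynth++}`;
      callIndexToId.set(globalIndex, id);
    }
    call.id = id;
  });
  runningFunctionCallCounter += chunkCalls.length;
}

assignIds([{ name: 'tool_A' }]); // globalIndex 0 -> synth_0
assignIds([{ name: 'tool_B' }]); // globalIndex 1 -> synth_1, no collision
assignIds([{ name: 'tool_X' }, { name: 'tool_X' }]); // 2 -> synth_2, 3 -> synth_3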