Handle cleaning up the response text in the UI when a response stream retry occurs (#7416)

This commit is contained in:
Victor May
2025-09-03 22:00:16 -04:00
committed by GitHub
parent b49410e1d0
commit e7a4142b2a
8 changed files with 455 additions and 256 deletions
@@ -650,6 +650,9 @@ export const useGeminiStream = (
// before we add loop detected message to history // before we add loop detected message to history
loopDetectedRef.current = true; loopDetectedRef.current = true;
break; break;
case ServerGeminiEventType.Retry:
// Will add the missing logic later
break;
default: { default: {
// enforces exhaustive switch-case // enforces exhaustive switch-case
const unreachable: never = event; const unreachable: never = event;
@@ -24,6 +24,7 @@ import {
getErrorStatus, getErrorStatus,
MCPServerConfig, MCPServerConfig,
DiscoveredMCPTool, DiscoveredMCPTool,
StreamEventType,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import * as acp from './acp.js'; import * as acp from './acp.js';
import { AcpFileSystemService } from './fileSystemService.js'; import { AcpFileSystemService } from './fileSystemService.js';
@@ -269,8 +270,12 @@ class Session {
return { stopReason: 'cancelled' }; return { stopReason: 'cancelled' };
} }
if (resp.candidates && resp.candidates.length > 0) { if (
const candidate = resp.candidates[0]; resp.type === StreamEventType.CHUNK &&
resp.value.candidates &&
resp.value.candidates.length > 0
) {
const candidate = resp.value.candidates[0];
for (const part of candidate.content?.parts ?? []) { for (const part of candidate.content?.parts ?? []) {
if (!part.text) { if (!part.text) {
continue; continue;
@@ -290,8 +295,8 @@ class Session {
} }
} }
if (resp.functionCalls) { if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) {
functionCalls.push(...resp.functionCalls); functionCalls.push(...resp.value.functionCalls);
} }
} }
} catch (error) { } catch (error) {
+118 -7
View File
@@ -12,7 +12,12 @@ import type {
Part, Part,
GenerateContentResponse, GenerateContentResponse,
} from '@google/genai'; } from '@google/genai';
import { GeminiChat, EmptyStreamError } from './geminiChat.js'; import {
GeminiChat,
EmptyStreamError,
StreamEventType,
type StreamEvent,
} from './geminiChat.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import { setSimulate429 } from '../utils/testUtils.js'; import { setSimulate429 } from '../utils/testUtils.js';
@@ -955,6 +960,42 @@ describe('GeminiChat', () => {
}); });
describe('sendMessageStream with retries', () => { describe('sendMessageStream with retries', () => {
it('should yield a RETRY event when an invalid stream is encountered', async () => {
// ARRANGE: Mock the stream to fail once, then succeed.
vi.mocked(mockModelsModule.generateContentStream)
.mockImplementationOnce(async () =>
// First attempt: An invalid stream with an empty text part.
(async function* () {
yield {
candidates: [{ content: { parts: [{ text: '' }] } }],
} as unknown as GenerateContentResponse;
})(),
)
.mockImplementationOnce(async () =>
// Second attempt (the retry): A minimal valid stream.
(async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'Success' }] } }],
} as unknown as GenerateContentResponse;
})(),
);
// ACT: Send a message and collect all events from the stream.
const stream = await chat.sendMessageStream(
{ message: 'test' },
'prompt-id-yield-retry',
);
const events: StreamEvent[] = [];
for await (const event of stream) {
events.push(event);
}
// ASSERT: Check that a RETRY event was present in the stream's output.
const retryEvent = events.find((e) => e.type === StreamEventType.RETRY);
expect(retryEvent).toBeDefined();
expect(retryEvent?.type).toBe(StreamEventType.RETRY);
});
it('should retry on invalid content, succeed, and report metrics', async () => { it('should retry on invalid content, succeed, and report metrics', async () => {
// Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt. // Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt.
vi.mocked(mockModelsModule.generateContentStream) vi.mocked(mockModelsModule.generateContentStream)
@@ -981,7 +1022,7 @@ describe('GeminiChat', () => {
{ message: 'test' }, { message: 'test' },
'prompt-id-retry-success', 'prompt-id-retry-success',
); );
const chunks = []; const chunks: StreamEvent[] = [];
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk);
} }
@@ -991,11 +1032,17 @@ describe('GeminiChat', () => {
expect(mockLogContentRetry).toHaveBeenCalledTimes(1); expect(mockLogContentRetry).toHaveBeenCalledTimes(1);
expect(mockLogContentRetryFailure).not.toHaveBeenCalled(); expect(mockLogContentRetryFailure).not.toHaveBeenCalled();
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
// Check for a retry event
expect(chunks.some((c) => c.type === StreamEventType.RETRY)).toBe(true);
// Check for the successful content chunk
expect( expect(
chunks.some( chunks.some(
(c) => (c) =>
c.candidates?.[0]?.content?.parts?.[0]?.text === c.type === StreamEventType.CHUNK &&
'Successful response', c.value.candidates?.[0]?.content?.parts?.[0]?.text ===
'Successful response',
), ),
).toBe(true); ).toBe(true);
@@ -1236,7 +1283,7 @@ describe('GeminiChat', () => {
{ message: 'test empty stream' }, { message: 'test empty stream' },
'prompt-id-empty-stream', 'prompt-id-empty-stream',
); );
const chunks = []; const chunks: StreamEvent[] = [];
for await (const chunk of stream) { for await (const chunk of stream) {
chunks.push(chunk); chunks.push(chunk);
} }
@@ -1246,8 +1293,9 @@ describe('GeminiChat', () => {
expect( expect(
chunks.some( chunks.some(
(c) => (c) =>
c.candidates?.[0]?.content?.parts?.[0]?.text === c.type === StreamEventType.CHUNK &&
'Successful response after empty', c.value.candidates?.[0]?.content?.parts?.[0]?.text ===
'Successful response after empty',
), ),
).toBe(true); ).toBe(true);
@@ -1346,4 +1394,67 @@ describe('GeminiChat', () => {
} }
expect(turn4.parts[0].text).toBe('second response'); expect(turn4.parts[0].text).toBe('second response');
}); });
it('should discard valid partial content from a failed attempt upon retry', async () => {
// ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content.
vi.mocked(mockModelsModule.generateContentStream)
.mockImplementationOnce(async () =>
// First attempt: yields one valid chunk, then one invalid chunk
(async function* () {
yield {
candidates: [
{
content: {
parts: [{ text: 'This valid part should be discarded' }],
},
},
],
} as unknown as GenerateContentResponse;
yield {
candidates: [{ content: { parts: [{ text: '' }] } }], // Invalid chunk triggers retry
} as unknown as GenerateContentResponse;
})(),
)
.mockImplementationOnce(async () =>
// Second attempt (the retry): succeeds
(async function* () {
yield {
candidates: [
{
content: {
parts: [{ text: 'Successful final response' }],
},
},
],
} as unknown as GenerateContentResponse;
})(),
);
// ACT: Send a message and consume the stream
const stream = await chat.sendMessageStream(
{ message: 'test' },
'prompt-id-discard-test',
);
const events: StreamEvent[] = [];
for await (const event of stream) {
events.push(event);
}
// ASSERT
// Check that a retry happened
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true);
// Check the final recorded history
const history = chat.getHistory();
expect(history.length).toBe(2); // user turn + final model turn
const modelTurn = history[1]!;
// The model turn should only contain the text from the successful attempt
expect(modelTurn!.parts![0]!.text).toBe('Successful final response');
// It should NOT contain any text from the failed attempt
expect(modelTurn!.parts![0]!.text).not.toContain(
'This valid part should be discarded',
);
});
}); });
+18 -2
View File
@@ -39,6 +39,18 @@ import {
import { isFunctionResponse } from '../utils/messageInspectors.js'; import { isFunctionResponse } from '../utils/messageInspectors.js';
import { partListUnionToString } from './geminiRequest.js'; import { partListUnionToString } from './geminiRequest.js';
export enum StreamEventType {
/** A regular content chunk from the API. */
CHUNK = 'chunk',
/** A signal that a retry is about to happen. The UI should discard any partial
* content from the attempt that just failed. */
RETRY = 'retry',
}
export type StreamEvent =
| { type: StreamEventType.CHUNK; value: GenerateContentResponse }
| { type: StreamEventType.RETRY };
/** /**
* Options for retrying due to invalid content from the model. * Options for retrying due to invalid content from the model.
*/ */
@@ -360,7 +372,7 @@ export class GeminiChat {
async sendMessageStream( async sendMessageStream(
params: SendMessageParameters, params: SendMessageParameters,
prompt_id: string, prompt_id: string,
): Promise<AsyncGenerator<GenerateContentResponse>> { ): Promise<AsyncGenerator<StreamEvent>> {
await this.sendPromise; await this.sendPromise;
let streamDoneResolver: () => void; let streamDoneResolver: () => void;
@@ -400,6 +412,10 @@ export class GeminiChat {
attempt++ attempt++
) { ) {
try { try {
if (attempt > 0) {
yield { type: StreamEventType.RETRY };
}
const stream = await self.makeApiCallAndProcessStream( const stream = await self.makeApiCallAndProcessStream(
requestContents, requestContents,
params, params,
@@ -408,7 +424,7 @@ export class GeminiChat {
); );
for await (const chunk of stream) { for await (const chunk of stream) {
yield chunk; yield { type: StreamEventType.CHUNK, value: chunk };
} }
lastError = null; lastError = null;
+24 -8
View File
@@ -21,7 +21,7 @@ import type {
} from './subagent.js'; } from './subagent.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import type { ConfigParameters } from '../config/config.js'; import type { ConfigParameters } from '../config/config.js';
import { GeminiChat } from './geminiChat.js'; import { GeminiChat, StreamEventType } from './geminiChat.js';
import { createContentGenerator } from './contentGenerator.js'; import { createContentGenerator } from './contentGenerator.js';
import { getEnvironmentContext } from '../utils/environmentContext.js'; import { getEnvironmentContext } from '../utils/environmentContext.js';
import { executeToolCall } from './nonInteractiveToolExecutor.js'; import { executeToolCall } from './nonInteractiveToolExecutor.js';
@@ -33,6 +33,7 @@ import type {
FunctionCall, FunctionCall,
FunctionDeclaration, FunctionDeclaration,
GenerateContentConfig, GenerateContentConfig,
GenerateContentResponse,
} from '@google/genai'; } from '@google/genai';
import { ToolErrorType } from '../tools/tool-error.js'; import { ToolErrorType } from '../tools/tool-error.js';
@@ -73,18 +74,33 @@ const createMockStream = (
functionCallsList: Array<FunctionCall[] | 'stop'>, functionCallsList: Array<FunctionCall[] | 'stop'>,
) => { ) => {
let index = 0; let index = 0;
return vi.fn().mockImplementation(() => { // This mock now returns a Promise that resolves to the async generator,
// matching the new signature for sendMessageStream.
return vi.fn().mockImplementation(async () => {
const response = functionCallsList[index] || 'stop'; const response = functionCallsList[index] || 'stop';
index++; index++;
return (async function* () { return (async function* () {
if (response === 'stop') { let mockResponseValue: Partial<GenerateContentResponse>;
// When stopping, the model might return text, but the subagent logic primarily cares about the absence of functionCalls.
yield { text: 'Done.' }; if (response === 'stop' || response.length === 0) {
} else if (response.length > 0) { // Simulate a text response for stop/empty conditions.
yield { functionCalls: response }; mockResponseValue = {
candidates: [{ content: { parts: [{ text: 'Done.' }] } }],
};
} else { } else {
yield { text: 'Done.' }; // Handle empty array also as stop // Simulate a tool call response.
mockResponseValue = {
candidates: [], // Good practice to include for safety.
functionCalls: response,
};
} }
// The stream must now yield a StreamEvent object of type CHUNK.
yield {
type: StreamEventType.CHUNK,
value: mockResponseValue as GenerateContentResponse,
};
})(); })();
}); });
}; };
+5 -6
View File
@@ -20,7 +20,7 @@ import type {
FunctionDeclaration, FunctionDeclaration,
} from '@google/genai'; } from '@google/genai';
import { Type } from '@google/genai'; import { Type } from '@google/genai';
import { GeminiChat } from './geminiChat.js'; import { GeminiChat, StreamEventType } from './geminiChat.js';
/** /**
* @fileoverview Defines the configuration interfaces for a subagent. * @fileoverview Defines the configuration interfaces for a subagent.
@@ -439,12 +439,11 @@ export class SubAgentScope {
let textResponse = ''; let textResponse = '';
for await (const resp of responseStream) { for await (const resp of responseStream) {
if (abortController.signal.aborted) return; if (abortController.signal.aborted) return;
if (resp.functionCalls) { if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) {
functionCalls.push(...resp.functionCalls); functionCalls.push(...resp.value.functionCalls);
} }
const text = resp.text; if (resp.type === StreamEventType.CHUNK && resp.value.text) {
if (text) { textResponse += resp.value.text;
textResponse += text;
} }
} }
+249 -219
View File
@@ -13,6 +13,7 @@ import { Turn, GeminiEventType } from './turn.js';
import type { GenerateContentResponse, Part, Content } from '@google/genai'; import type { GenerateContentResponse, Part, Content } from '@google/genai';
import { reportError } from '../utils/errorReporting.js'; import { reportError } from '../utils/errorReporting.js';
import type { GeminiChat } from './geminiChat.js'; import type { GeminiChat } from './geminiChat.js';
import { StreamEventType } from './geminiChat.js';
const mockSendMessageStream = vi.fn(); const mockSendMessageStream = vi.fn();
const mockGetHistory = vi.fn(); const mockGetHistory = vi.fn();
@@ -35,6 +36,7 @@ vi.mock('../utils/errorReporting', () => ({
reportError: vi.fn(), reportError: vi.fn(),
})); }));
// Use the actual implementation from partUtils now that it's provided.
vi.mock('../utils/generateContentResponseUtilities', () => ({ vi.mock('../utils/generateContentResponseUtilities', () => ({
getResponseText: (resp: GenerateContentResponse) => getResponseText: (resp: GenerateContentResponse) =>
resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') || resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
@@ -78,11 +80,17 @@ describe('Turn', () => {
it('should yield content events for text parts', async () => { it('should yield content events for text parts', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [{ content: { parts: [{ text: 'Hello' }] } }], type: StreamEventType.CHUNK,
} as unknown as GenerateContentResponse; value: {
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
} as GenerateContentResponse,
};
yield { yield {
candidates: [{ content: { parts: [{ text: ' world' }] } }], type: StreamEventType.CHUNK,
} as unknown as GenerateContentResponse; value: {
candidates: [{ content: { parts: [{ text: ' world' }] } }],
} as GenerateContentResponse,
};
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -105,21 +113,7 @@ describe('Turn', () => {
expect(events).toEqual([ expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Hello' }, { type: GeminiEventType.Content, value: 'Hello' },
{
type: GeminiEventType.Finished,
value: {
reason: undefined,
usageMetadata: undefined,
},
},
{ type: GeminiEventType.Content, value: ' world' }, { type: GeminiEventType.Content, value: ' world' },
{
type: GeminiEventType.Finished,
value: {
reason: undefined,
usageMetadata: undefined,
},
},
]); ]);
expect(turn.getDebugResponses().length).toBe(2); expect(turn.getDebugResponses().length).toBe(2);
}); });
@@ -127,16 +121,23 @@ describe('Turn', () => {
it('should yield tool_call_request events for function calls', async () => { it('should yield tool_call_request events for function calls', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
functionCalls: [ type: StreamEventType.CHUNK,
{ value: {
id: 'fc1', functionCalls: [
name: 'tool1', {
args: { arg1: 'val1' }, id: 'fc1',
isClientInitiated: false, name: 'tool1',
}, args: { arg1: 'val1' },
{ name: 'tool2', args: { arg2: 'val2' }, isClientInitiated: false }, // No ID isClientInitiated: false,
], },
} as unknown as GenerateContentResponse; {
name: 'tool2',
args: { arg2: 'val2' },
isClientInitiated: false,
}, // No ID
],
} as unknown as GenerateContentResponse,
};
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -149,7 +150,7 @@ describe('Turn', () => {
events.push(event); events.push(event);
} }
expect(events.length).toBe(3); expect(events.length).toBe(2);
const event1 = events[0] as ServerGeminiToolCallRequestEvent; const event1 = events[0] as ServerGeminiToolCallRequestEvent;
expect(event1.type).toBe(GeminiEventType.ToolCallRequest); expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
expect(event1.value).toEqual( expect(event1.value).toEqual(
@@ -182,18 +183,24 @@ describe('Turn', () => {
const abortController = new AbortController(); const abortController = new AbortController();
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [{ content: { parts: [{ text: 'First part' }] } }], type: StreamEventType.CHUNK,
} as unknown as GenerateContentResponse; value: {
candidates: [{ content: { parts: [{ text: 'First part' }] } }],
} as GenerateContentResponse,
};
abortController.abort(); abortController.abort();
yield { yield {
candidates: [ type: StreamEventType.CHUNK,
{ value: {
content: { candidates: [
parts: [{ text: 'Second part - should not be processed' }], {
content: {
parts: [{ text: 'Second part - should not be processed' }],
},
}, },
}, ],
], } as GenerateContentResponse,
} as unknown as GenerateContentResponse; };
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -204,13 +211,6 @@ describe('Turn', () => {
} }
expect(events).toEqual([ expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'First part' }, { type: GeminiEventType.Content, value: 'First part' },
{
type: GeminiEventType.Finished,
value: {
reason: undefined,
usageMetadata: undefined,
},
},
{ type: GeminiEventType.UserCancelled }, { type: GeminiEventType.UserCancelled },
]); ]);
expect(turn.getDebugResponses().length).toBe(1); expect(turn.getDebugResponses().length).toBe(1);
@@ -251,86 +251,79 @@ describe('Turn', () => {
it('should handle function calls with undefined name or args', async () => { it('should handle function calls with undefined name or args', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
functionCalls: [ type: StreamEventType.CHUNK,
{ id: 'fc1', name: undefined, args: { arg1: 'val1' } }, value: {
{ id: 'fc2', name: 'tool2', args: undefined }, candidates: [],
{ id: 'fc3', name: undefined, args: undefined }, functionCalls: [
], // Add `id` back to the mock to match what the code expects
} as unknown as GenerateContentResponse; { id: 'fc1', name: undefined, args: { arg1: 'val1' } },
{ id: 'fc2', name: 'tool2', args: undefined },
{ id: 'fc3', name: undefined, args: undefined },
],
},
};
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = []; const events = [];
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
for await (const event of turn.run( for await (const event of turn.run(
reqParts, [{ text: 'Test undefined tool parts' }],
new AbortController().signal, new AbortController().signal,
)) { )) {
events.push(event); events.push(event);
} }
expect(events.length).toBe(4); expect(events.length).toBe(3);
// Assertions for each specific tool call event
const event1 = events[0] as ServerGeminiToolCallRequestEvent; const event1 = events[0] as ServerGeminiToolCallRequestEvent;
expect(event1.type).toBe(GeminiEventType.ToolCallRequest); expect(event1.value).toMatchObject({
expect(event1.value).toEqual( callId: 'fc1',
expect.objectContaining({ name: 'undefined_tool_name',
callId: 'fc1', args: { arg1: 'val1' },
name: 'undefined_tool_name', });
args: { arg1: 'val1' },
isClientInitiated: false,
}),
);
expect(turn.pendingToolCalls[0]).toEqual(event1.value);
const event2 = events[1] as ServerGeminiToolCallRequestEvent; const event2 = events[1] as ServerGeminiToolCallRequestEvent;
expect(event2.type).toBe(GeminiEventType.ToolCallRequest); expect(event2.value).toMatchObject({
expect(event2.value).toEqual( callId: 'fc2',
expect.objectContaining({ name: 'tool2',
callId: 'fc2', args: {},
name: 'tool2', });
args: {},
isClientInitiated: false,
}),
);
expect(turn.pendingToolCalls[1]).toEqual(event2.value);
const event3 = events[2] as ServerGeminiToolCallRequestEvent; const event3 = events[2] as ServerGeminiToolCallRequestEvent;
expect(event3.type).toBe(GeminiEventType.ToolCallRequest); expect(event3.value).toMatchObject({
expect(event3.value).toEqual( callId: 'fc3',
expect.objectContaining({ name: 'undefined_tool_name',
callId: 'fc3', args: {},
name: 'undefined_tool_name', });
args: {},
isClientInitiated: false,
}),
);
expect(turn.pendingToolCalls[2]).toEqual(event3.value);
expect(turn.getDebugResponses().length).toBe(1);
}); });
it('should yield finished event when response has finish reason', async () => { it('should yield finished event when response has finish reason', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [ type: StreamEventType.CHUNK,
{ value: {
content: { parts: [{ text: 'Partial response' }] }, candidates: [
finishReason: 'STOP', {
content: { parts: [{ text: 'Partial response' }] },
finishReason: 'STOP',
},
],
usageMetadata: {
promptTokenCount: 17,
candidatesTokenCount: 50,
cachedContentTokenCount: 10,
thoughtsTokenCount: 5,
toolUsePromptTokenCount: 2,
}, },
], } as GenerateContentResponse,
usageMetadata: { };
promptTokenCount: 17,
candidatesTokenCount: 50,
cachedContentTokenCount: 10,
thoughtsTokenCount: 5,
toolUsePromptTokenCount: 2,
},
} as unknown as GenerateContentResponse;
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = []; const events = [];
const reqParts: Part[] = [{ text: 'Test finish reason' }];
for await (const event of turn.run( for await (const event of turn.run(
reqParts, [{ text: 'Test finish reason' }],
new AbortController().signal, new AbortController().signal,
)) { )) {
events.push(event); events.push(event);
@@ -357,17 +350,20 @@ describe('Turn', () => {
it('should yield finished event for MAX_TOKENS finish reason', async () => { it('should yield finished event for MAX_TOKENS finish reason', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [ type: StreamEventType.CHUNK,
{ value: {
content: { candidates: [
parts: [ {
{ text: 'This is a long response that was cut off...' }, content: {
], parts: [
{ text: 'This is a long response that was cut off...' },
],
},
finishReason: 'MAX_TOKENS',
}, },
finishReason: 'MAX_TOKENS', ],
}, },
], };
} as unknown as GenerateContentResponse;
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -395,13 +391,16 @@ describe('Turn', () => {
it('should yield finished event for SAFETY finish reason', async () => { it('should yield finished event for SAFETY finish reason', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [ type: StreamEventType.CHUNK,
{ value: {
content: { parts: [{ text: 'Content blocked' }] }, candidates: [
finishReason: 'SAFETY', {
}, content: { parts: [{ text: 'Content blocked' }] },
], finishReason: 'SAFETY',
} as unknown as GenerateContentResponse; },
],
},
};
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -426,13 +425,18 @@ describe('Turn', () => {
it('should yield finished event with undefined reason when there is no finish reason', async () => { it('should yield finished event with undefined reason when there is no finish reason', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [ type: StreamEventType.CHUNK,
{ value: {
content: { parts: [{ text: 'Response without finish reason' }] }, candidates: [
// No finishReason property {
}, content: {
], parts: [{ text: 'Response without finish reason' }],
} as unknown as GenerateContentResponse; },
// No finishReason property
},
],
},
};
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -450,31 +454,33 @@ describe('Turn', () => {
type: GeminiEventType.Content, type: GeminiEventType.Content,
value: 'Response without finish reason', value: 'Response without finish reason',
}, },
{
type: GeminiEventType.Finished,
value: { reason: undefined, usageMetadata: undefined },
},
]); ]);
}); });
it('should handle multiple responses with different finish reasons', async () => { it('should handle multiple responses with different finish reasons', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [ type: StreamEventType.CHUNK,
{ value: {
content: { parts: [{ text: 'First part' }] }, candidates: [
// No finish reason on first response {
}, content: { parts: [{ text: 'First part' }] },
], // No finish reason on first response
} as unknown as GenerateContentResponse; },
],
},
};
yield { yield {
candidates: [ value: {
{ type: StreamEventType.CHUNK,
content: { parts: [{ text: 'Second part' }] }, candidates: [
finishReason: 'OTHER', {
}, content: { parts: [{ text: 'Second part' }] },
], finishReason: 'OTHER',
} as unknown as GenerateContentResponse; },
],
},
};
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -489,13 +495,6 @@ describe('Turn', () => {
expect(events).toEqual([ expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'First part' }, { type: GeminiEventType.Content, value: 'First part' },
{
type: GeminiEventType.Finished,
value: {
reason: undefined,
usageMetadata: undefined,
},
},
{ type: GeminiEventType.Content, value: 'Second part' }, { type: GeminiEventType.Content, value: 'Second part' },
{ {
type: GeminiEventType.Finished, type: GeminiEventType.Finished,
@@ -507,21 +506,24 @@ describe('Turn', () => {
it('should yield citation and finished events when response has citationMetadata', async () => { it('should yield citation and finished events when response has citationMetadata', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [ type: StreamEventType.CHUNK,
{ value: {
content: { parts: [{ text: 'Some text.' }] }, candidates: [
citationMetadata: { {
citations: [ content: { parts: [{ text: 'Some text.' }] },
{ citationMetadata: {
uri: 'https://example.com/source1', citations: [
title: 'Source 1 Title', {
}, uri: 'https://example.com/source1',
], title: 'Source 1 Title',
},
],
},
finishReason: 'STOP',
}, },
finishReason: 'STOP', ],
}, },
], };
} as unknown as GenerateContentResponse;
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -549,25 +551,28 @@ describe('Turn', () => {
it('should yield a single citation event for multiple citations in one response', async () => { it('should yield a single citation event for multiple citations in one response', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [ type: StreamEventType.CHUNK,
{ value: {
content: { parts: [{ text: 'Some text.' }] }, candidates: [
citationMetadata: { {
citations: [ content: { parts: [{ text: 'Some text.' }] },
{ citationMetadata: {
uri: 'https://example.com/source2', citations: [
title: 'Title2', {
}, uri: 'https://example.com/source2',
{ title: 'Title2',
uri: 'https://example.com/source1', },
title: 'Title1', {
}, uri: 'https://example.com/source1',
], title: 'Title1',
},
],
},
finishReason: 'STOP',
}, },
finishReason: 'STOP', ],
}, },
], };
} as unknown as GenerateContentResponse;
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -596,21 +601,24 @@ describe('Turn', () => {
it('should not yield citation event if there is no finish reason', async () => { it('should not yield citation event if there is no finish reason', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [ type: StreamEventType.CHUNK,
{ value: {
content: { parts: [{ text: 'Some text.' }] }, candidates: [
citationMetadata: { {
citations: [ content: { parts: [{ text: 'Some text.' }] },
{ citationMetadata: {
uri: 'https://example.com/source1', citations: [
title: 'Source 1 Title', {
}, uri: 'https://example.com/source1',
], title: 'Source 1 Title',
},
],
},
// No finishReason
}, },
// No finishReason ],
}, },
], };
} as unknown as GenerateContentResponse;
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -624,10 +632,6 @@ describe('Turn', () => {
expect(events).toEqual([ expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Some text.' }, { type: GeminiEventType.Content, value: 'Some text.' },
{
type: GeminiEventType.Finished,
value: { reason: undefined, usageMetadata: undefined },
},
]); ]);
// No Citation event (but we do get a Finished event with undefined reason) // No Citation event (but we do get a Finished event with undefined reason)
expect(events.some((e) => e.type === GeminiEventType.Citation)).toBe( expect(events.some((e) => e.type === GeminiEventType.Citation)).toBe(
@@ -638,25 +642,28 @@ describe('Turn', () => {
it('should ignore citations without a URI', async () => { it('should ignore citations without a URI', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
candidates: [ type: StreamEventType.CHUNK,
{ value: {
content: { parts: [{ text: 'Some text.' }] }, candidates: [
citationMetadata: { {
citations: [ content: { parts: [{ text: 'Some text.' }] },
{ citationMetadata: {
uri: 'https://example.com/source1', citations: [
title: 'Good Source', {
}, uri: 'https://example.com/source1',
{ title: 'Good Source',
// uri is undefined },
title: 'Bad Source', {
}, // uri is undefined
], title: 'Bad Source',
},
],
},
finishReason: 'STOP',
}, },
finishReason: 'STOP', ],
}, },
], };
} as unknown as GenerateContentResponse;
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
@@ -706,6 +713,29 @@ describe('Turn', () => {
expect(reportError).not.toHaveBeenCalled(); expect(reportError).not.toHaveBeenCalled();
}); });
it('should yield a Retry event when it receives one from the chat stream', async () => {
const mockResponseStream = (async function* () {
yield { type: StreamEventType.RETRY };
yield {
type: StreamEventType.CHUNK,
value: {
candidates: [{ content: { parts: [{ text: 'Success' }] } }],
},
};
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
for await (const event of turn.run([], new AbortController().signal)) {
events.push(event);
}
expect(events).toEqual([
{ type: GeminiEventType.Retry },
{ type: GeminiEventType.Content, value: 'Success' },
]);
});
}); });
describe('getDebugResponses', () => { describe('getDebugResponses', () => {
@@ -717,8 +747,8 @@ describe('Turn', () => {
functionCalls: [{ name: 'debugTool' }], functionCalls: [{ name: 'debugTool' }],
} as unknown as GenerateContentResponse; } as unknown as GenerateContentResponse;
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield resp1; yield { type: StreamEventType.CHUNK, value: resp1 };
yield resp2; yield { type: StreamEventType.CHUNK, value: resp2 };
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
const reqParts: Part[] = [{ text: 'Hi' }]; const reqParts: Part[] = [{ text: 'Hi' }];
+29 -10
View File
@@ -56,8 +56,13 @@ export enum GeminiEventType {
Finished = 'finished', Finished = 'finished',
LoopDetected = 'loop_detected', LoopDetected = 'loop_detected',
Citation = 'citation', Citation = 'citation',
Retry = 'retry',
} }
export type ServerGeminiRetryEvent = {
type: GeminiEventType.Retry;
};
export interface StructuredError { export interface StructuredError {
message: string; message: string;
status?: number; status?: number;
@@ -188,7 +193,8 @@ export type ServerGeminiStreamEvent =
| ServerGeminiToolCallConfirmationEvent | ServerGeminiToolCallConfirmationEvent
| ServerGeminiToolCallRequestEvent | ServerGeminiToolCallRequestEvent
| ServerGeminiToolCallResponseEvent | ServerGeminiToolCallResponseEvent
| ServerGeminiUserCancelledEvent; | ServerGeminiUserCancelledEvent
| ServerGeminiRetryEvent;
// A turn manages the agentic loop turn within the server context. // A turn manages the agentic loop turn within the server context.
export class Turn { export class Turn {
@@ -207,6 +213,8 @@ export class Turn {
signal: AbortSignal, signal: AbortSignal,
): AsyncGenerator<ServerGeminiStreamEvent> { ): AsyncGenerator<ServerGeminiStreamEvent> {
try { try {
// Note: This assumes `sendMessageStream` yields events like
// { type: StreamEventType.RETRY } or { type: StreamEventType.CHUNK, value: GenerateContentResponse }
const responseStream = await this.chat.sendMessageStream( const responseStream = await this.chat.sendMessageStream(
{ {
message: req, message: req,
@@ -217,12 +225,22 @@ export class Turn {
this.prompt_id, this.prompt_id,
); );
for await (const resp of responseStream) { for await (const streamEvent of responseStream) {
if (signal?.aborted) { if (signal?.aborted) {
yield { type: GeminiEventType.UserCancelled }; yield { type: GeminiEventType.UserCancelled };
// Do not add resp to debugResponses if aborted before processing
return; return;
} }
// Handle the new RETRY event
if (streamEvent.type === 'retry') {
yield { type: GeminiEventType.Retry };
continue; // Skip to the next event in the stream
}
// Assuming other events are chunks with a `value` property
const resp = streamEvent.value as GenerateContentResponse;
if (!resp) continue; // Skip if there's no response body
this.debugResponses.push(resp); this.debugResponses.push(resp);
const thoughtPart = resp.candidates?.[0]?.content?.parts?.[0]; const thoughtPart = resp.candidates?.[0]?.content?.parts?.[0];
@@ -268,6 +286,7 @@ export class Turn {
// Check if response was truncated or stopped for various reasons // Check if response was truncated or stopped for various reasons
const finishReason = resp.candidates?.[0]?.finishReason; const finishReason = resp.candidates?.[0]?.finishReason;
// This is the key change: Only yield 'Finished' if there is a finishReason.
if (finishReason) { if (finishReason) {
if (this.pendingCitations.size > 0) { if (this.pendingCitations.size > 0) {
yield { yield {
@@ -278,14 +297,14 @@ export class Turn {
} }
this.finishReason = finishReason; this.finishReason = finishReason;
yield {
type: GeminiEventType.Finished,
value: {
reason: finishReason,
usageMetadata: resp.usageMetadata,
},
};
} }
yield {
type: GeminiEventType.Finished,
value: {
reason: finishReason ? finishReason : undefined,
usageMetadata: resp.usageMetadata,
},
};
} }
} catch (e) { } catch (e) {
if (signal.aborted) { if (signal.aborted) {