From e7a4142b2aaa4e115ac3a2ddad6316f7a3b398e5 Mon Sep 17 00:00:00 2001 From: Victor May Date: Wed, 3 Sep 2025 22:00:16 -0400 Subject: [PATCH] Handle cleaning up the response text in the UI when a response stream retry occurs (#7416) --- packages/cli/src/ui/hooks/useGeminiStream.ts | 3 + .../cli/src/zed-integration/zedIntegration.ts | 13 +- packages/core/src/core/geminiChat.test.ts | 125 ++++- packages/core/src/core/geminiChat.ts | 20 +- packages/core/src/core/subagent.test.ts | 32 +- packages/core/src/core/subagent.ts | 11 +- packages/core/src/core/turn.test.ts | 468 ++++++++++-------- packages/core/src/core/turn.ts | 39 +- 8 files changed, 455 insertions(+), 256 deletions(-) diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index d09d3fcdea..7e325c3c1b 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -650,6 +650,9 @@ export const useGeminiStream = ( // before we add loop detected message to history loopDetectedRef.current = true; break; + case ServerGeminiEventType.Retry: + // TODO: discard the partial response text shown for the failed attempt; not handled yet. + break; default: { // enforces exhaustive switch-case const unreachable: never = event; diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 7762225555..a6a502e4a4 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -24,6 +24,7 @@ import { getErrorStatus, MCPServerConfig, DiscoveredMCPTool, + StreamEventType, } from '@google/gemini-cli-core'; import * as acp from './acp.js'; import { AcpFileSystemService } from './fileSystemService.js'; @@ -269,8 +270,12 @@ class Session { return { stopReason: 'cancelled' }; } - if (resp.candidates && resp.candidates.length > 0) { - const candidate = resp.candidates[0]; + if ( + resp.type === StreamEventType.CHUNK && + resp.value.candidates && +
resp.value.candidates.length > 0 + ) { + const candidate = resp.value.candidates[0]; for (const part of candidate.content?.parts ?? []) { if (!part.text) { continue; @@ -290,8 +295,8 @@ class Session { } } - if (resp.functionCalls) { - functionCalls.push(...resp.functionCalls); + if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) { + functionCalls.push(...resp.value.functionCalls); } } } catch (error) { diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 156b284237..8592c56091 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -12,7 +12,12 @@ import type { Part, GenerateContentResponse, } from '@google/genai'; -import { GeminiChat, EmptyStreamError } from './geminiChat.js'; +import { + GeminiChat, + EmptyStreamError, + StreamEventType, + type StreamEvent, +} from './geminiChat.js'; import type { Config } from '../config/config.js'; import { setSimulate429 } from '../utils/testUtils.js'; @@ -955,6 +960,42 @@ describe('GeminiChat', () => { }); describe('sendMessageStream with retries', () => { + it('should yield a RETRY event when an invalid stream is encountered', async () => { + // ARRANGE: Mock the stream to fail once, then succeed. + vi.mocked(mockModelsModule.generateContentStream) + .mockImplementationOnce(async () => + // First attempt: An invalid stream with an empty text part. + (async function* () { + yield { + candidates: [{ content: { parts: [{ text: '' }] } }], + } as unknown as GenerateContentResponse; + })(), + ) + .mockImplementationOnce(async () => + // Second attempt (the retry): A minimal valid stream. + (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'Success' }] } }], + } as unknown as GenerateContentResponse; + })(), + ); + + // ACT: Send a message and collect all events from the stream. 
+ const stream = await chat.sendMessageStream( + { message: 'test' }, + 'prompt-id-yield-retry', + ); + const events: StreamEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + // ASSERT: Check that a RETRY event was present in the stream's output. + const retryEvent = events.find((e) => e.type === StreamEventType.RETRY); + + expect(retryEvent).toBeDefined(); + expect(retryEvent?.type).toBe(StreamEventType.RETRY); + }); it('should retry on invalid content, succeed, and report metrics', async () => { // Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt. vi.mocked(mockModelsModule.generateContentStream) @@ -981,7 +1022,7 @@ describe('GeminiChat', () => { { message: 'test' }, 'prompt-id-retry-success', ); - const chunks = []; + const chunks: StreamEvent[] = []; for await (const chunk of stream) { chunks.push(chunk); } @@ -991,11 +1032,17 @@ describe('GeminiChat', () => { expect(mockLogContentRetry).toHaveBeenCalledTimes(1); expect(mockLogContentRetryFailure).not.toHaveBeenCalled(); expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + + // Check for a retry event + expect(chunks.some((c) => c.type === StreamEventType.RETRY)).toBe(true); + + // Check for the successful content chunk expect( chunks.some( (c) => - c.candidates?.[0]?.content?.parts?.[0]?.text === - 'Successful response', + c.type === StreamEventType.CHUNK && + c.value.candidates?.[0]?.content?.parts?.[0]?.text === + 'Successful response', ), ).toBe(true); @@ -1236,7 +1283,7 @@ describe('GeminiChat', () => { { message: 'test empty stream' }, 'prompt-id-empty-stream', ); - const chunks = []; + const chunks: StreamEvent[] = []; for await (const chunk of stream) { chunks.push(chunk); } @@ -1246,8 +1293,9 @@ describe('GeminiChat', () => { expect( chunks.some( (c) => - c.candidates?.[0]?.content?.parts?.[0]?.text === - 'Successful response after empty', + c.type === StreamEventType.CHUNK && + 
c.value.candidates?.[0]?.content?.parts?.[0]?.text === + 'Successful response after empty', ), ).toBe(true); @@ -1346,4 +1394,67 @@ describe('GeminiChat', () => { } expect(turn4.parts[0].text).toBe('second response'); }); + + it('should discard valid partial content from a failed attempt upon retry', async () => { + // ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content. + vi.mocked(mockModelsModule.generateContentStream) + .mockImplementationOnce(async () => + // First attempt: yields one valid chunk, then one invalid chunk + (async function* () { + yield { + candidates: [ + { + content: { + parts: [{ text: 'This valid part should be discarded' }], + }, + }, + ], + } as unknown as GenerateContentResponse; + yield { + candidates: [{ content: { parts: [{ text: '' }] } }], // Invalid chunk triggers retry + } as unknown as GenerateContentResponse; + })(), + ) + .mockImplementationOnce(async () => + // Second attempt (the retry): succeeds + (async function* () { + yield { + candidates: [ + { + content: { + parts: [{ text: 'Successful final response' }], + }, + }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + // ACT: Send a message and consume the stream + const stream = await chat.sendMessageStream( + { message: 'test' }, + 'prompt-id-discard-test', + ); + const events: StreamEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + // ASSERT + // Check that a retry happened + expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true); + + // Check the final recorded history + const history = chat.getHistory(); + expect(history.length).toBe(2); // user turn + final model turn + + const modelTurn = history[1]!; + // The model turn should only contain the text from the successful attempt + expect(modelTurn!.parts![0]!.text).toBe('Successful final response'); + // It should NOT contain any text from the 
failed attempt + expect(modelTurn!.parts![0]!.text).not.toContain( + 'This valid part should be discarded', + ); + }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index af08d0c0a4..17da50078f 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -39,6 +39,18 @@ import { import { isFunctionResponse } from '../utils/messageInspectors.js'; import { partListUnionToString } from './geminiRequest.js'; +export enum StreamEventType { + /** A regular content chunk from the API. */ + CHUNK = 'chunk', + /** A signal that a retry is about to happen. The UI should discard any partial + * content from the attempt that just failed. */ + RETRY = 'retry', +} + +export type StreamEvent = + | { type: StreamEventType.CHUNK; value: GenerateContentResponse } + | { type: StreamEventType.RETRY }; + /** * Options for retrying due to invalid content from the model. */ @@ -360,7 +372,7 @@ export class GeminiChat { async sendMessageStream( params: SendMessageParameters, prompt_id: string, - ): Promise> { + ): Promise> { await this.sendPromise; let streamDoneResolver: () => void; @@ -400,6 +412,10 @@ export class GeminiChat { attempt++ ) { try { + if (attempt > 0) { + yield { type: StreamEventType.RETRY }; + } + const stream = await self.makeApiCallAndProcessStream( requestContents, params, @@ -408,7 +424,7 @@ export class GeminiChat { ); for await (const chunk of stream) { - yield chunk; + yield { type: StreamEventType.CHUNK, value: chunk }; } lastError = null; diff --git a/packages/core/src/core/subagent.test.ts b/packages/core/src/core/subagent.test.ts index f2073306b2..cc54037bad 100644 --- a/packages/core/src/core/subagent.test.ts +++ b/packages/core/src/core/subagent.test.ts @@ -21,7 +21,7 @@ import type { } from './subagent.js'; import { Config } from '../config/config.js'; import type { ConfigParameters } from '../config/config.js'; -import { GeminiChat } from './geminiChat.js'; +import { 
GeminiChat, StreamEventType } from './geminiChat.js'; import { createContentGenerator } from './contentGenerator.js'; import { getEnvironmentContext } from '../utils/environmentContext.js'; import { executeToolCall } from './nonInteractiveToolExecutor.js'; @@ -33,6 +33,7 @@ import type { FunctionCall, FunctionDeclaration, GenerateContentConfig, + GenerateContentResponse, } from '@google/genai'; import { ToolErrorType } from '../tools/tool-error.js'; @@ -73,18 +74,33 @@ const createMockStream = ( functionCallsList: Array, ) => { let index = 0; - return vi.fn().mockImplementation(() => { + // This mock now returns a Promise that resolves to the async generator, + // matching the new signature for sendMessageStream. + return vi.fn().mockImplementation(async () => { const response = functionCallsList[index] || 'stop'; index++; + return (async function* () { - if (response === 'stop') { - // When stopping, the model might return text, but the subagent logic primarily cares about the absence of functionCalls. - yield { text: 'Done.' }; - } else if (response.length > 0) { - yield { functionCalls: response }; + let mockResponseValue: Partial; + + if (response === 'stop' || response.length === 0) { + // Simulate a text response for stop/empty conditions. + mockResponseValue = { + candidates: [{ content: { parts: [{ text: 'Done.' }] } }], + }; } else { - yield { text: 'Done.' }; // Handle empty array also as stop + // Simulate a tool call response. + mockResponseValue = { + candidates: [], // Good practice to include for safety. + functionCalls: response, + }; } + + // The stream must now yield a StreamEvent object of type CHUNK. 
+ yield { + type: StreamEventType.CHUNK, + value: mockResponseValue as GenerateContentResponse, + }; })(); }); }; diff --git a/packages/core/src/core/subagent.ts b/packages/core/src/core/subagent.ts index 75f77c572c..41de5978a1 100644 --- a/packages/core/src/core/subagent.ts +++ b/packages/core/src/core/subagent.ts @@ -20,7 +20,7 @@ import type { FunctionDeclaration, } from '@google/genai'; import { Type } from '@google/genai'; -import { GeminiChat } from './geminiChat.js'; +import { GeminiChat, StreamEventType } from './geminiChat.js'; /** * @fileoverview Defines the configuration interfaces for a subagent. @@ -439,12 +439,11 @@ export class SubAgentScope { let textResponse = ''; for await (const resp of responseStream) { if (abortController.signal.aborted) return; - if (resp.functionCalls) { - functionCalls.push(...resp.functionCalls); + if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) { + functionCalls.push(...resp.value.functionCalls); } - const text = resp.text; - if (text) { - textResponse += text; + if (resp.type === StreamEventType.CHUNK && resp.value.text) { + textResponse += resp.value.text; } } diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index 081a4c63eb..16fdd90fd9 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -13,6 +13,7 @@ import { Turn, GeminiEventType } from './turn.js'; import type { GenerateContentResponse, Part, Content } from '@google/genai'; import { reportError } from '../utils/errorReporting.js'; import type { GeminiChat } from './geminiChat.js'; +import { StreamEventType } from './geminiChat.js'; const mockSendMessageStream = vi.fn(); const mockGetHistory = vi.fn(); @@ -35,6 +36,7 @@ vi.mock('../utils/errorReporting', () => ({ reportError: vi.fn(), })); +// Mock getResponseText with an inline implementation that joins the text of the first candidate's parts.
vi.mock('../utils/generateContentResponseUtilities', () => ({ getResponseText: (resp: GenerateContentResponse) => resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') || @@ -78,11 +80,17 @@ describe('Turn', () => { it('should yield content events for text parts', async () => { const mockResponseStream = (async function* () { yield { - candidates: [{ content: { parts: [{ text: 'Hello' }] } }], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'Hello' }] } }], + } as GenerateContentResponse, + }; yield { - candidates: [{ content: { parts: [{ text: ' world' }] } }], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: ' world' }] } }], + } as GenerateContentResponse, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -105,21 +113,7 @@ describe('Turn', () => { expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Hello' }, - { - type: GeminiEventType.Finished, - value: { - reason: undefined, - usageMetadata: undefined, - }, - }, { type: GeminiEventType.Content, value: ' world' }, - { - type: GeminiEventType.Finished, - value: { - reason: undefined, - usageMetadata: undefined, - }, - }, ]); expect(turn.getDebugResponses().length).toBe(2); }); @@ -127,16 +121,23 @@ describe('Turn', () => { it('should yield tool_call_request events for function calls', async () => { const mockResponseStream = (async function* () { yield { - functionCalls: [ - { - id: 'fc1', - name: 'tool1', - args: { arg1: 'val1' }, - isClientInitiated: false, - }, - { name: 'tool2', args: { arg2: 'val2' }, isClientInitiated: false }, // No ID - ], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + functionCalls: [ + { + id: 'fc1', + name: 'tool1', + args: { arg1: 'val1' }, + isClientInitiated: false, + }, + { + name: 'tool2', + args: { arg2: 'val2' }, + 
isClientInitiated: false, + }, // No ID + ], + } as unknown as GenerateContentResponse, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -149,7 +150,7 @@ describe('Turn', () => { events.push(event); } - expect(events.length).toBe(3); + expect(events.length).toBe(2); const event1 = events[0] as ServerGeminiToolCallRequestEvent; expect(event1.type).toBe(GeminiEventType.ToolCallRequest); expect(event1.value).toEqual( @@ -182,18 +183,24 @@ describe('Turn', () => { const abortController = new AbortController(); const mockResponseStream = (async function* () { yield { - candidates: [{ content: { parts: [{ text: 'First part' }] } }], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'First part' }] } }], + } as GenerateContentResponse, + }; abortController.abort(); yield { - candidates: [ - { - content: { - parts: [{ text: 'Second part - should not be processed' }], + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { + parts: [{ text: 'Second part - should not be processed' }], + }, }, - }, - ], - } as unknown as GenerateContentResponse; + ], + } as GenerateContentResponse, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -204,13 +211,6 @@ describe('Turn', () => { } expect(events).toEqual([ { type: GeminiEventType.Content, value: 'First part' }, - { - type: GeminiEventType.Finished, - value: { - reason: undefined, - usageMetadata: undefined, - }, - }, { type: GeminiEventType.UserCancelled }, ]); expect(turn.getDebugResponses().length).toBe(1); @@ -251,86 +251,79 @@ describe('Turn', () => { it('should handle function calls with undefined name or args', async () => { const mockResponseStream = (async function* () { yield { - functionCalls: [ - { id: 'fc1', name: undefined, args: { arg1: 'val1' } }, - { id: 'fc2', name: 'tool2', args: undefined }, - { id: 'fc3', name: undefined, args: undefined }, - ], - } as unknown as 
GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [], + functionCalls: [ + // Add `id` back to the mock to match what the code expects + { id: 'fc1', name: undefined, args: { arg1: 'val1' } }, + { id: 'fc2', name: 'tool2', args: undefined }, + { id: 'fc3', name: undefined, args: undefined }, + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); + const events = []; - const reqParts: Part[] = [{ text: 'Test undefined tool parts' }]; for await (const event of turn.run( - reqParts, + [{ text: 'Test undefined tool parts' }], new AbortController().signal, )) { events.push(event); } - expect(events.length).toBe(4); + expect(events.length).toBe(3); + + // Assertions for each specific tool call event const event1 = events[0] as ServerGeminiToolCallRequestEvent; - expect(event1.type).toBe(GeminiEventType.ToolCallRequest); - expect(event1.value).toEqual( - expect.objectContaining({ - callId: 'fc1', - name: 'undefined_tool_name', - args: { arg1: 'val1' }, - isClientInitiated: false, - }), - ); - expect(turn.pendingToolCalls[0]).toEqual(event1.value); + expect(event1.value).toMatchObject({ + callId: 'fc1', + name: 'undefined_tool_name', + args: { arg1: 'val1' }, + }); const event2 = events[1] as ServerGeminiToolCallRequestEvent; - expect(event2.type).toBe(GeminiEventType.ToolCallRequest); - expect(event2.value).toEqual( - expect.objectContaining({ - callId: 'fc2', - name: 'tool2', - args: {}, - isClientInitiated: false, - }), - ); - expect(turn.pendingToolCalls[1]).toEqual(event2.value); + expect(event2.value).toMatchObject({ + callId: 'fc2', + name: 'tool2', + args: {}, + }); const event3 = events[2] as ServerGeminiToolCallRequestEvent; - expect(event3.type).toBe(GeminiEventType.ToolCallRequest); - expect(event3.value).toEqual( - expect.objectContaining({ - callId: 'fc3', - name: 'undefined_tool_name', - args: {}, - isClientInitiated: false, - }), - ); - expect(turn.pendingToolCalls[2]).toEqual(event3.value); - 
expect(turn.getDebugResponses().length).toBe(1); + expect(event3.value).toMatchObject({ + callId: 'fc3', + name: 'undefined_tool_name', + args: {}, + }); }); it('should yield finished event when response has finish reason', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'Partial response' }] }, - finishReason: 'STOP', + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'Partial response' }] }, + finishReason: 'STOP', + }, + ], + usageMetadata: { + promptTokenCount: 17, + candidatesTokenCount: 50, + cachedContentTokenCount: 10, + thoughtsTokenCount: 5, + toolUsePromptTokenCount: 2, }, - ], - usageMetadata: { - promptTokenCount: 17, - candidatesTokenCount: 50, - cachedContentTokenCount: 10, - thoughtsTokenCount: 5, - toolUsePromptTokenCount: 2, - }, - } as unknown as GenerateContentResponse; + } as GenerateContentResponse, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); const events = []; - const reqParts: Part[] = [{ text: 'Test finish reason' }]; for await (const event of turn.run( - reqParts, + [{ text: 'Test finish reason' }], new AbortController().signal, )) { events.push(event); @@ -357,17 +350,20 @@ describe('Turn', () => { it('should yield finished event for MAX_TOKENS finish reason', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { - parts: [ - { text: 'This is a long response that was cut off...' }, - ], + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { + parts: [ + { text: 'This is a long response that was cut off...' 
}, + ], + }, + finishReason: 'MAX_TOKENS', }, - finishReason: 'MAX_TOKENS', - }, - ], - } as unknown as GenerateContentResponse; + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -395,13 +391,16 @@ describe('Turn', () => { it('should yield finished event for SAFETY finish reason', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'Content blocked' }] }, - finishReason: 'SAFETY', - }, - ], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'Content blocked' }] }, + finishReason: 'SAFETY', + }, + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -426,13 +425,18 @@ describe('Turn', () => { it('should yield finished event with undefined reason when there is no finish reason', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'Response without finish reason' }] }, - // No finishReason property - }, - ], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { + parts: [{ text: 'Response without finish reason' }], + }, + // No finishReason property + }, + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -450,31 +454,33 @@ describe('Turn', () => { type: GeminiEventType.Content, value: 'Response without finish reason', }, - { - type: GeminiEventType.Finished, - value: { reason: undefined, usageMetadata: undefined }, - }, ]); }); it('should handle multiple responses with different finish reasons', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'First part' }] }, - // No finish reason on first response - }, - ], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + 
content: { parts: [{ text: 'First part' }] }, + // No finish reason on first response + }, + ], + }, + }; yield { - candidates: [ - { - content: { parts: [{ text: 'Second part' }] }, - finishReason: 'OTHER', - }, - ], - } as unknown as GenerateContentResponse; + value: { + type: StreamEventType.CHUNK, + candidates: [ + { + content: { parts: [{ text: 'Second part' }] }, + finishReason: 'OTHER', + }, + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -489,13 +495,6 @@ describe('Turn', () => { expect(events).toEqual([ { type: GeminiEventType.Content, value: 'First part' }, - { - type: GeminiEventType.Finished, - value: { - reason: undefined, - usageMetadata: undefined, - }, - }, { type: GeminiEventType.Content, value: 'Second part' }, { type: GeminiEventType.Finished, @@ -507,21 +506,24 @@ describe('Turn', () => { it('should yield citation and finished events when response has citationMetadata', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'Some text.' }] }, - citationMetadata: { - citations: [ - { - uri: 'https://example.com/source1', - title: 'Source 1 Title', - }, - ], + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'Some text.' }] }, + citationMetadata: { + citations: [ + { + uri: 'https://example.com/source1', + title: 'Source 1 Title', + }, + ], + }, + finishReason: 'STOP', }, - finishReason: 'STOP', - }, - ], - } as unknown as GenerateContentResponse; + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -549,25 +551,28 @@ describe('Turn', () => { it('should yield a single citation event for multiple citations in one response', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'Some text.' 
}] }, - citationMetadata: { - citations: [ - { - uri: 'https://example.com/source2', - title: 'Title2', - }, - { - uri: 'https://example.com/source1', - title: 'Title1', - }, - ], + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'Some text.' }] }, + citationMetadata: { + citations: [ + { + uri: 'https://example.com/source2', + title: 'Title2', + }, + { + uri: 'https://example.com/source1', + title: 'Title1', + }, + ], + }, + finishReason: 'STOP', }, - finishReason: 'STOP', - }, - ], - } as unknown as GenerateContentResponse; + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -596,21 +601,24 @@ describe('Turn', () => { it('should not yield citation event if there is no finish reason', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'Some text.' }] }, - citationMetadata: { - citations: [ - { - uri: 'https://example.com/source1', - title: 'Source 1 Title', - }, - ], + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'Some text.' }] }, + citationMetadata: { + citations: [ + { + uri: 'https://example.com/source1', + title: 'Source 1 Title', + }, + ], + }, + // No finishReason }, - // No finishReason - }, - ], - } as unknown as GenerateContentResponse; + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -624,10 +632,6 @@ describe('Turn', () => { expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Some text.' 
}, - { - type: GeminiEventType.Finished, - value: { reason: undefined, usageMetadata: undefined }, - }, ]); // No Citation event (but we do get a Finished event with undefined reason) expect(events.some((e) => e.type === GeminiEventType.Citation)).toBe( @@ -638,25 +642,28 @@ describe('Turn', () => { it('should ignore citations without a URI', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'Some text.' }] }, - citationMetadata: { - citations: [ - { - uri: 'https://example.com/source1', - title: 'Good Source', - }, - { - // uri is undefined - title: 'Bad Source', - }, - ], + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'Some text.' }] }, + citationMetadata: { + citations: [ + { + uri: 'https://example.com/source1', + title: 'Good Source', + }, + { + // uri is undefined + title: 'Bad Source', + }, + ], + }, + finishReason: 'STOP', }, - finishReason: 'STOP', - }, - ], - } as unknown as GenerateContentResponse; + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -706,6 +713,29 @@ describe('Turn', () => { expect(reportError).not.toHaveBeenCalled(); }); + + it('should yield a Retry event when it receives one from the chat stream', async () => { + const mockResponseStream = (async function* () { + yield { type: StreamEventType.RETRY }; + yield { + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'Success' }] } }], + }, + }; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + for await (const event of turn.run([], new AbortController().signal)) { + events.push(event); + } + + expect(events).toEqual([ + { type: GeminiEventType.Retry }, + { type: GeminiEventType.Content, value: 'Success' }, + ]); + }); }); describe('getDebugResponses', () => { @@ -717,8 +747,8 @@ describe('Turn', () => { functionCalls: [{ name: 'debugTool' }], } as unknown as 
GenerateContentResponse; const mockResponseStream = (async function* () { - yield resp1; - yield resp2; + yield { type: StreamEventType.CHUNK, value: resp1 }; + yield { type: StreamEventType.CHUNK, value: resp2 }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); const reqParts: Part[] = [{ text: 'Hi' }]; diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 69118d2a9a..95ffa8e761 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -56,8 +56,13 @@ export enum GeminiEventType { Finished = 'finished', LoopDetected = 'loop_detected', Citation = 'citation', + Retry = 'retry', } +export type ServerGeminiRetryEvent = { + type: GeminiEventType.Retry; +}; + export interface StructuredError { message: string; status?: number; @@ -188,7 +193,8 @@ export type ServerGeminiStreamEvent = | ServerGeminiToolCallConfirmationEvent | ServerGeminiToolCallRequestEvent | ServerGeminiToolCallResponseEvent - | ServerGeminiUserCancelledEvent; + | ServerGeminiUserCancelledEvent + | ServerGeminiRetryEvent; // A turn manages the agentic loop turn within the server context. 
export class Turn { @@ -207,6 +213,8 @@ export class Turn { signal: AbortSignal, ): AsyncGenerator { try { + // Note: This assumes `sendMessageStream` yields events like + // { type: StreamEventType.RETRY } or { type: StreamEventType.CHUNK, value: GenerateContentResponse } const responseStream = await this.chat.sendMessageStream( { message: req, @@ -217,12 +225,22 @@ export class Turn { this.prompt_id, ); - for await (const resp of responseStream) { + for await (const streamEvent of responseStream) { if (signal?.aborted) { yield { type: GeminiEventType.UserCancelled }; - // Do not add resp to debugResponses if aborted before processing return; } + + // Handle the new RETRY event + if (streamEvent.type === 'retry') { + yield { type: GeminiEventType.Retry }; + continue; // Skip to the next event in the stream + } + + // Assuming other events are chunks with a `value` property + const resp = streamEvent.value as GenerateContentResponse; + if (!resp) continue; // Skip if there's no response body + this.debugResponses.push(resp); const thoughtPart = resp.candidates?.[0]?.content?.parts?.[0]; @@ -268,6 +286,7 @@ export class Turn { // Check if response was truncated or stopped for various reasons const finishReason = resp.candidates?.[0]?.finishReason; + // This is the key change: Only yield 'Finished' if there is a finishReason. if (finishReason) { if (this.pendingCitations.size > 0) { yield { @@ -278,14 +297,14 @@ export class Turn { } this.finishReason = finishReason; + yield { + type: GeminiEventType.Finished, + value: { + reason: finishReason, + usageMetadata: resp.usageMetadata, + }, + }; } - yield { - type: GeminiEventType.Finished, - value: { - reason: finishReason ? finishReason : undefined, - usageMetadata: resp.usageMetadata, - }, - }; } } catch (e) { if (signal.aborted) {