diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts
index 837e0d32c0..7403f78890 100644
--- a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts
+++ b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts
@@ -98,6 +98,10 @@ describe('handleAtCommand', () => {
       processedQuery: [{ text: query }],
       shouldProceed: true,
     });
+    expect(mockAddItem).toHaveBeenCalledWith(
+      { type: 'user', text: query },
+      123,
+    );
   });
 
   it('should pass through original query if only a lone @ symbol is present', async () => {
@@ -116,6 +120,10 @@ describe('handleAtCommand', () => {
       processedQuery: [{ text: queryWithSpaces }],
       shouldProceed: true,
     });
+    expect(mockAddItem).toHaveBeenCalledWith(
+      { type: 'user', text: queryWithSpaces },
+      124,
+    );
     expect(mockOnDebugMessage).toHaveBeenCalledWith(
       'Lone @ detected, will be treated as text in the modified query.',
     );
@@ -148,6 +156,10 @@ describe('handleAtCommand', () => {
       ],
       shouldProceed: true,
     });
+    expect(mockAddItem).toHaveBeenCalledWith(
+      { type: 'user', text: query },
+      125,
+    );
     expect(mockAddItem).toHaveBeenCalledWith(
       expect.objectContaining({
         type: 'tool_group',
@@ -186,6 +198,10 @@ describe('handleAtCommand', () => {
       ],
       shouldProceed: true,
     });
+    expect(mockAddItem).toHaveBeenCalledWith(
+      { type: 'user', text: query },
+      126,
+    );
     expect(mockOnDebugMessage).toHaveBeenCalledWith(
       `Path ${dirPath} resolved to directory, using glob: ${resolvedGlob}`,
     );
@@ -220,6 +236,10 @@ describe('handleAtCommand', () => {
       ],
       shouldProceed: true,
     });
+    expect(mockAddItem).toHaveBeenCalledWith(
+      { type: 'user', text: query },
+      128,
+    );
   });
 
   it('should correctly unescape paths with escaped spaces', async () => {
@@ -250,6 +270,10 @@ describe('handleAtCommand', () => {
       ],
       shouldProceed: true,
     });
+    expect(mockAddItem).toHaveBeenCalledWith(
+      { type: 'user', text: query },
+      125,
+    );
     expect(mockAddItem).toHaveBeenCalledWith(
       expect.objectContaining({
         type: 'tool_group',
@@ -1066,37 +1090,4 @@ describe('handleAtCommand', () => {
       });
     });
   });
-
-  it("should not add the user's turn to history, as that is the caller's responsibility", async () => {
-    // Arrange
-    const fileContent = 'This is the file content.';
-    const filePath = await createTestFile(
-      path.join(testRootDir, 'path', 'to', 'another-file.txt'),
-      fileContent,
-    );
-    const query = `A query with @${filePath}`;
-
-    // Act
-    await handleAtCommand({
-      query,
-      config: mockConfig,
-      addItem: mockAddItem,
-      onDebugMessage: mockOnDebugMessage,
-      messageId: 999,
-      signal: abortController.signal,
-    });
-
-    // Assert
-    // It SHOULD be called for the tool_group
-    expect(mockAddItem).toHaveBeenCalledWith(
-      expect.objectContaining({ type: 'tool_group' }),
-      999,
-    );
-
-    // It should NOT have been called for the user turn
-    const userTurnCalls = mockAddItem.mock.calls.filter(
-      (call) => call[0].type === 'user',
-    );
-    expect(userTurnCalls).toHaveLength(0);
-  });
 });
diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts
index 3d139db86b..85ad6f6f58 100644
--- a/packages/cli/src/ui/hooks/atCommandProcessor.ts
+++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts
@@ -137,9 +137,12 @@ export async function handleAtCommand({
   );
 
   if (atPathCommandParts.length === 0) {
+    addItem({ type: 'user', text: query }, userMessageTimestamp);
     return { processedQuery: [{ text: query }], shouldProceed: true };
   }
 
+  addItem({ type: 'user', text: query }, userMessageTimestamp);
+
   // Get centralized file discovery service
   const fileDiscovery = config.getFileService();
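Taken together with the test changes above, the new contract is that `handleAtCommand` records the user's turn itself, in both the plain-text and the @-path branches. A minimal caller-side sketch, reusing the test doubles named in the spec file (`mockConfig`, `mockAddItem`); `sendToModel` is a hypothetical stand-in for whatever the caller does next:

```ts
import { handleAtCommand } from './atCommandProcessor.js';

const abortController = new AbortController();

// handleAtCommand now emits the user turn before any @path resolution output.
const result = await handleAtCommand({
  query: 'Explain @src/index.ts',
  config: mockConfig, // test double from the spec above
  addItem: mockAddItem, // receives { type: 'user', text: query } exactly once
  onDebugMessage: console.debug,
  messageId: 123, // the specs reuse this id as the history timestamp
  signal: abortController.signal,
});

if (result.shouldProceed) {
  // processedQuery carries the original text plus any resolved file content.
  sendToModel(result.processedQuery); // hypothetical next step, not in this patch
}
```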
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
index 5218843648..9eed09124e 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
+++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
@@ -5,19 +5,10 @@
  */
 
 /* eslint-disable @typescript-eslint/no-explicit-any */
-import {
-  describe,
-  it,
-  expect,
-  vi,
-  beforeEach,
-  Mock,
-  MockInstance,
-} from 'vitest';
+import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
 import { renderHook, act, waitFor } from '@testing-library/react';
 import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js';
 import { useKeypress } from './useKeypress.js';
-import * as atCommandProcessor from './atCommandProcessor.js';
 import {
   useReactToolScheduler,
   TrackedToolCall,
@@ -29,10 +20,8 @@ import {
   Config,
   EditorType,
   AuthType,
-  GeminiClient,
   GeminiEventType as ServerGeminiEventType,
   AnyToolInvocation,
-  ToolErrorType, // <-- Import ToolErrorType
 } from '@google/gemini-cli-core';
 import { Part, PartListUnion } from '@google/genai';
 import { UseHistoryManagerReturn } from './useHistoryManager.js';
@@ -94,7 +83,11 @@ vi.mock('./shellCommandProcessor.js', () => ({
   }),
 }));
 
-vi.mock('./atCommandProcessor.js');
+vi.mock('./atCommandProcessor.js', () => ({
+  handleAtCommand: vi
+    .fn()
+    .mockResolvedValue({ shouldProceed: true, processedQuery: 'mocked' }),
+}));
 
 vi.mock('../utils/markdownUtilities.js', () => ({
   findLastSafeSplitPoint: vi.fn((s: string) => s.length),
@@ -266,7 +259,6 @@ describe('useGeminiStream', () => {
   let mockScheduleToolCalls: Mock;
   let mockCancelAllToolCalls: Mock;
   let mockMarkToolsAsSubmitted: Mock;
-  let handleAtCommandSpy: MockInstance;
 
   beforeEach(() => {
     vi.clearAllMocks(); // Clear mocks before each test
@@ -350,7 +342,6 @@ describe('useGeminiStream', () => {
     mockSendMessageStream
       .mockClear()
       .mockReturnValue((async function* () {})());
-    handleAtCommandSpy = vi.spyOn(atCommandProcessor, 'handleAtCommand');
   });
 
   const mockLoadedSettings: LoadedSettings = {
@@ -456,7 +447,6 @@ describe('useGeminiStream', () => {
           callId: 'call1',
           responseParts: [{ text: 'tool 1 response' }],
           error: undefined,
-          errorType: undefined, // FIX: Added missing property
          resultDisplay: 'Tool 1 success display',
        },
        tool: {
@@ -522,11 +512,7 @@ describe('useGeminiStream', () => {
         },
         status: 'success',
         responseSubmittedToGemini: false,
-        response: {
-          callId: 'call1',
-          responseParts: toolCall1ResponseParts,
-          errorType: undefined, // FIX: Added missing property
-        },
+        response: { callId: 'call1', responseParts: toolCall1ResponseParts },
         tool: {
           displayName: 'MockTool',
         },
@@ -544,11 +530,7 @@ describe('useGeminiStream', () => {
         },
         status: 'error',
         responseSubmittedToGemini: false,
-        response: {
-          callId: 'call2',
-          responseParts: toolCall2ResponseParts,
-          errorType: ToolErrorType.UNHANDLED_EXCEPTION, // FIX: Added missing property
-        },
+        response: { callId: 'call2', responseParts: toolCall2ResponseParts },
       } as TrackedCompletedToolCall, // Treat error as a form of completion for submission
     ];
@@ -615,11 +597,7 @@ describe('useGeminiStream', () => {
           prompt_id: 'prompt-id-3',
         },
         status: 'cancelled',
-        response: {
-          callId: '1',
-          responseParts: [{ text: 'cancelled' }],
-          errorType: undefined, // FIX: Added missing property
-        },
+        response: { callId: '1', responseParts: [{ text: 'cancelled' }] },
         responseSubmittedToGemini: false,
         tool: {
           displayName: 'mock tool',
@@ -704,7 +682,6 @@ describe('useGeminiStream', () => {
       ],
       resultDisplay: undefined,
       error: undefined,
-      errorType: undefined, // FIX: Added missing property
     },
     responseSubmittedToGemini: false,
   };
@@ -733,7 +710,6 @@ describe('useGeminiStream', () => {
       ],
       resultDisplay: undefined,
       error: undefined,
-      errorType: undefined, // FIX: Added missing property
     },
     responseSubmittedToGemini: false,
   };
@@ -836,7 +812,6 @@ describe('useGeminiStream', () => {
         callId: 'call1',
         responseParts: toolCallResponseParts,
         error: undefined,
-        errorType: undefined, // FIX: Added missing property
         resultDisplay: 'Tool 1 success display',
       },
       endTime: Date.now(),
@@ -1239,7 +1214,6 @@ describe('useGeminiStream', () => {
         responseParts: [{ text: 'Memory saved' }],
         resultDisplay: 'Success: Memory saved',
         error: undefined,
-        errorType: undefined, // FIX: Added missing property
       },
       tool: {
         name: 'save_memory',
@@ -1783,301 +1757,4 @@ describe('useGeminiStream', () => {
       );
     });
   });
-
-  it('should process @include commands, adding user turn after processing to prevent race conditions', async () => {
-    const rawQuery = '@include file.txt Summarize this.';
-    const processedQueryParts = [
-      { text: 'Summarize this with content from @file.txt' },
-      { text: 'File content...' },
-    ];
-    const userMessageTimestamp = Date.now();
-    vi.spyOn(Date, 'now').mockReturnValue(userMessageTimestamp);
-
-    handleAtCommandSpy.mockResolvedValue({
-      processedQuery: processedQueryParts,
-      shouldProceed: true,
-    });
-
-    const { result } = renderHook(() =>
-      useGeminiStream(
-        mockConfig.getGeminiClient() as GeminiClient,
-        [],
-        mockAddItem,
-        mockConfig,
-        mockOnDebugMessage,
-        mockHandleSlashCommand,
-        false, // shellModeActive
-        vi.fn(), // getPreferredEditor
-        vi.fn(), // onAuthError
-        vi.fn(), // performMemoryRefresh
-        false, // modelSwitched
-        vi.fn(), // setModelSwitched
-        vi.fn(), // onEditorClose
-        vi.fn(), // onCancelSubmit
-      ),
-    );
-
-    await act(async () => {
-      await result.current.submitQuery(rawQuery);
-    });
-
-    expect(handleAtCommandSpy).toHaveBeenCalledWith(
-      expect.objectContaining({
-        query: rawQuery,
-      }),
-    );
-
-    expect(mockAddItem).toHaveBeenCalledWith(
-      {
-        type: MessageType.USER,
-        text: rawQuery,
-      },
-      userMessageTimestamp,
-    );
-
-    // FIX: The expectation now matches the actual call signature.
-    expect(mockSendMessageStream).toHaveBeenCalledWith(
-      processedQueryParts, // Argument 1: The parts array directly
-      expect.any(AbortSignal), // Argument 2: An AbortSignal
-      expect.any(String), // Argument 3: The prompt_id string
-    );
-  });
-
-  describe('Thought Reset', () => {
-    it('should reset thought to null when starting a new prompt', async () => {
-      mockSendMessageStream.mockReturnValue(
-        (async function* () {
-          yield {
-            type: ServerGeminiEventType.Thought,
-            value: {
-              subject: 'Previous thought',
-              description: 'Old description',
-            },
-          };
-          yield {
-            type: ServerGeminiEventType.Content,
-            value: 'Some response content',
-          };
-          yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
-        })(),
-      );
-
-      const { result } = renderHook(() =>
-        useGeminiStream(
-          new MockedGeminiClientClass(mockConfig),
-          [],
-          mockAddItem,
-          mockConfig,
-          mockOnDebugMessage,
-          mockHandleSlashCommand,
-          false,
-          () => 'vscode' as EditorType,
-          () => {},
-          () => Promise.resolve(),
-          false,
-          () => {},
-          () => {},
-          () => {},
-        ),
-      );
-
-      await act(async () => {
-        await result.current.submitQuery('First query');
-      });
-
-      await waitFor(() => {
-        expect(mockAddItem).toHaveBeenCalledWith(
-          expect.objectContaining({
-            type: 'gemini',
-            text: 'Some response content',
-          }),
-          expect.any(Number),
-        );
-      });
-
-      mockSendMessageStream.mockReturnValue(
-        (async function* () {
-          yield {
-            type: ServerGeminiEventType.Content,
-            value: 'New response content',
-          };
-          yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
-        })(),
-      );
-
-      await act(async () => {
-        await result.current.submitQuery('Second query');
-      });
-
-      await waitFor(() => {
-        expect(mockAddItem).toHaveBeenCalledWith(
-          expect.objectContaining({
-            type: 'gemini',
-            text: 'New response content',
-          }),
-          expect.any(Number),
-        );
-      });
-    });
-
-    it('should reset thought to null when user cancels', async () => {
-      mockSendMessageStream.mockReturnValue(
-        (async function* () {
-          yield {
-            type: ServerGeminiEventType.Thought,
-            value: { subject: 'Some thought', description: 'Description' },
-          };
-          yield { type: ServerGeminiEventType.UserCancelled };
-        })(),
-      );
-
-      const { result } = renderHook(() =>
-        useGeminiStream(
-          new MockedGeminiClientClass(mockConfig),
-          [],
-          mockAddItem,
-          mockConfig,
-          mockOnDebugMessage,
-          mockHandleSlashCommand,
-          false,
-          () => 'vscode' as EditorType,
-          () => {},
-          () => Promise.resolve(),
-          false,
-          () => {},
-          () => {},
-          () => {},
-        ),
-      );
-
-      await act(async () => {
-        await result.current.submitQuery('Test query');
-      });
-
-      await waitFor(() => {
-        expect(mockAddItem).toHaveBeenCalledWith(
-          expect.objectContaining({
-            type: 'info',
-            text: 'User cancelled the request.',
-          }),
-          expect.any(Number),
-        );
-      });
-
-      expect(result.current.streamingState).toBe(StreamingState.Idle);
-    });
-
-    it('should reset thought to null when there is an error', async () => {
-      mockSendMessageStream.mockReturnValue(
-        (async function* () {
-          yield {
-            type: ServerGeminiEventType.Thought,
-            value: { subject: 'Some thought', description: 'Description' },
-          };
-          yield {
-            type: ServerGeminiEventType.Error,
-            value: { error: { message: 'Test error' } },
-          };
-        })(),
-      );
-
-      const { result } = renderHook(() =>
-        useGeminiStream(
-          new MockedGeminiClientClass(mockConfig),
-          [],
-          mockAddItem,
-          mockConfig,
-          mockOnDebugMessage,
-          mockHandleSlashCommand,
-          false,
-          () => 'vscode' as EditorType,
-          () => {},
-          () => Promise.resolve(),
-          false,
-          () => {},
-          () => {},
-          () => {},
-        ),
-      );
-
-      await act(async () => {
-        await result.current.submitQuery('Test query');
-      });
-
-      await waitFor(() => {
-        expect(mockAddItem).toHaveBeenCalledWith(
-          expect.objectContaining({
-            type: 'error',
-          }),
-          expect.any(Number),
-        );
-      });
-
-      expect(mockParseAndFormatApiError).toHaveBeenCalledWith(
-        { message: 'Test error' },
-        expect.any(String),
-        undefined,
-        'gemini-2.5-pro',
-        'gemini-2.5-flash',
-      );
-    });
-  });
-
-  it('should process @include commands, adding user turn after processing to prevent race conditions', async () => {
-    const rawQuery = '@include file.txt Summarize this.';
-    const processedQueryParts = [
-      { text: 'Summarize this with content from @file.txt' },
-      { text: 'File content...' },
-    ];
-    const userMessageTimestamp = Date.now();
-    vi.spyOn(Date, 'now').mockReturnValue(userMessageTimestamp);
-
-    handleAtCommandSpy.mockResolvedValue({
-      processedQuery: processedQueryParts,
-      shouldProceed: true,
-    });
-
-    const { result } = renderHook(() =>
-      useGeminiStream(
-        mockConfig.getGeminiClient() as GeminiClient,
-        [],
-        mockAddItem,
-        mockConfig,
-        mockOnDebugMessage,
-        mockHandleSlashCommand,
-        false,
-        vi.fn(),
-        vi.fn(),
-        vi.fn(),
-        false,
-        vi.fn(),
-        vi.fn(),
-        vi.fn(),
-      ),
-    );
-
-    await act(async () => {
-      await result.current.submitQuery(rawQuery);
-    });
-
-    expect(handleAtCommandSpy).toHaveBeenCalledWith(
-      expect.objectContaining({
-        query: rawQuery,
-      }),
-    );
-
-    expect(mockAddItem).toHaveBeenCalledWith(
-      {
-        type: MessageType.USER,
-        text: rawQuery,
-      },
-      userMessageTimestamp,
-    );
-
-    // FIX: This expectation now correctly matches the actual function call signature.
-    expect(mockSendMessageStream).toHaveBeenCalledWith(
-      processedQueryParts, // Argument 1: The parts array directly
-      expect.any(AbortSignal), // Argument 2: An AbortSignal
-      expect.any(String), // Argument 3: The prompt_id string
-    );
-  });
 });
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index ad9f99a7b2..99b727b66b 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -306,13 +306,6 @@ export const useGeminiStream = (
         messageId: userMessageTimestamp,
         signal: abortSignal,
       });
-
-      // Add user's turn after @ command processing is done.
-      addItem(
-        { type: MessageType.USER, text: trimmedQuery },
-        userMessageTimestamp,
-      );
-
       if (!atCommandResult.shouldProceed) {
         return { queryToSend: null, shouldProceed: false };
       }
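With the caller-side `addItem` removed above, ordering is now guaranteed inside `handleAtCommand`: the user message lands in history before any `tool_group` item produced while resolving @paths. A hedged sketch of the resulting call order, again using the spec file's mocks and an illustrative message id:

```ts
await handleAtCommand({
  query: 'Summarize @notes.txt',
  config: mockConfig,
  addItem: mockAddItem,
  onDebugMessage: mockOnDebugMessage,
  messageId: 456, // illustrative id
  signal: new AbortController().signal,
});

// Expected addItem order for a single submission under the new contract:
expect(mockAddItem.mock.calls.map(([item]) => item.type)).toEqual([
  'user', // added first, inside handleAtCommand
  'tool_group', // added afterwards, while @notes.txt is read
]);
```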
diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts
index c4fb7f0f6f..cd5e38412d 100644
--- a/packages/core/src/core/geminiChat.test.ts
+++ b/packages/core/src/core/geminiChat.test.ts
@@ -12,7 +12,7 @@ import {
   Part,
   GenerateContentResponse,
 } from '@google/genai';
-import { GeminiChat, EmptyStreamError } from './geminiChat.js';
+import { GeminiChat } from './geminiChat.js';
 import { Config } from '../config/config.js';
 import { setSimulate429 } from '../utils/testUtils.js';
 
@@ -112,13 +112,7 @@ describe('GeminiChat', () => {
       response,
     );
 
-    const stream = await chat.sendMessageStream(
-      { message: 'hello' },
-      'prompt-id-1',
-    );
-    for await (const _ of stream) {
-      // consume stream to trigger internal logic
-    }
+    await chat.sendMessageStream({ message: 'hello' }, 'prompt-id-1');
 
     expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith(
       {
@@ -481,387 +475,4 @@ describe('GeminiChat', () => {
       expect(history[1]).toEqual(content2);
     });
   });
-
-  describe('sendMessageStream with retries', () => {
-    it('should retry on invalid content and succeed on the second attempt', async () => {
-      // Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt.
-      vi.mocked(mockModelsModule.generateContentStream)
-        .mockImplementationOnce(async () =>
-          // First call returns an invalid stream
-          (async function* () {
-            yield {
-              candidates: [{ content: { parts: [{ text: '' }] } }], // Invalid empty text part
-            } as unknown as GenerateContentResponse;
-          })(),
-        )
-        .mockImplementationOnce(async () =>
-          // Second call returns a valid stream
-          (async function* () {
-            yield {
-              candidates: [
-                { content: { parts: [{ text: 'Successful response' }] } },
-              ],
-            } as unknown as GenerateContentResponse;
-          })(),
-        );
-
-      const stream = await chat.sendMessageStream(
-        { message: 'test' },
-        'prompt-id-retry-success',
-      );
-      const chunks = [];
-      for await (const chunk of stream) {
-        chunks.push(chunk);
-      }
-
-      // Assertions
-      expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
-      expect(
-        chunks.some(
-          (c) =>
-            c.candidates?.[0]?.content?.parts?.[0]?.text ===
-            'Successful response',
-        ),
-      ).toBe(true);
-
-      // Check that history was recorded correctly once, with no duplicates.
-      const history = chat.getHistory();
-      expect(history.length).toBe(2);
-      expect(history[0]).toEqual({
-        role: 'user',
-        parts: [{ text: 'test' }],
-      });
-      expect(history[1]).toEqual({
-        role: 'model',
-        parts: [{ text: 'Successful response' }],
-      });
-    });
-
-    it('should fail after all retries on persistent invalid content', async () => {
-      vi.mocked(mockModelsModule.generateContentStream).mockImplementation(
-        async () =>
-          (async function* () {
-            yield {
-              candidates: [
-                {
-                  content: {
-                    parts: [{ text: '' }],
-                    role: 'model',
-                  },
-                },
-              ],
-            } as unknown as GenerateContentResponse;
-          })(),
-      );
-
-      // This helper function consumes the stream and allows us to test for rejection.
-      async function consumeStreamAndExpectError() {
-        const stream = await chat.sendMessageStream(
-          { message: 'test' },
-          'prompt-id-retry-fail',
-        );
-        for await (const _ of stream) {
-          // Must loop to trigger the internal logic that throws.
-        }
-      }
-
-      await expect(consumeStreamAndExpectError()).rejects.toThrow(
-        EmptyStreamError,
-      );
-
-      // Should be called 3 times (initial + 2 retries)
-      expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(3);
-
-      // History should be clean, as if the failed turn never happened.
-      const history = chat.getHistory();
-      expect(history.length).toBe(0);
-    });
-  });
-
-  it('should correctly retry and append to an existing history mid-conversation', async () => {
-    // 1. Setup
-    const initialHistory: Content[] = [
-      { role: 'user', parts: [{ text: 'First question' }] },
-      { role: 'model', parts: [{ text: 'First answer' }] },
-    ];
-    chat.setHistory(initialHistory);
-
-    // 2. Mock the API
-    vi.mocked(mockModelsModule.generateContentStream)
-      .mockImplementationOnce(async () =>
-        (async function* () {
-          yield {
-            candidates: [{ content: { parts: [{ text: '' }] } }],
-          } as unknown as GenerateContentResponse;
-        })(),
-      )
-      .mockImplementationOnce(async () =>
-        (async function* () {
-          yield {
-            candidates: [{ content: { parts: [{ text: 'Second answer' }] } }],
-          } as unknown as GenerateContentResponse;
-        })(),
-      );
-
-    // 3. Send a new message
-    const stream = await chat.sendMessageStream(
-      { message: 'Second question' },
-      'prompt-id-retry-existing',
-    );
-    for await (const _ of stream) {
-      // consume stream
-    }
-
-    // 4. Assert the final history
-    const history = chat.getHistory();
-    expect(history.length).toBe(4);
-
-    // Explicitly verify the structure of each part to satisfy TypeScript
-    const turn1 = history[0];
-    if (!turn1?.parts?.[0] || !('text' in turn1.parts[0])) {
-      throw new Error('Test setup error: First turn is not a valid text part.');
-    }
-    expect(turn1.parts[0].text).toBe('First question');
-
-    const turn2 = history[1];
-    if (!turn2?.parts?.[0] || !('text' in turn2.parts[0])) {
-      throw new Error(
-        'Test setup error: Second turn is not a valid text part.',
-      );
-    }
-    expect(turn2.parts[0].text).toBe('First answer');
-
-    const turn3 = history[2];
-    if (!turn3?.parts?.[0] || !('text' in turn3.parts[0])) {
-      throw new Error('Test setup error: Third turn is not a valid text part.');
-    }
-    expect(turn3.parts[0].text).toBe('Second question');
-
-    const turn4 = history[3];
-    if (!turn4?.parts?.[0] || !('text' in turn4.parts[0])) {
-      throw new Error(
-        'Test setup error: Fourth turn is not a valid text part.',
-      );
-    }
-    expect(turn4.parts[0].text).toBe('Second answer');
-  });
-
-  describe('concurrency control', () => {
-    it('should queue a subsequent sendMessage call until the first one completes', async () => {
-      // 1. Create promises to manually control when the API calls resolve
-      let firstCallResolver: (value: GenerateContentResponse) => void;
-      const firstCallPromise = new Promise<GenerateContentResponse>(
-        (resolve) => {
-          firstCallResolver = resolve;
-        },
-      );
-
-      let secondCallResolver: (value: GenerateContentResponse) => void;
-      const secondCallPromise = new Promise<GenerateContentResponse>(
-        (resolve) => {
-          secondCallResolver = resolve;
-        },
-      );
-
-      // A standard response body for the mock
-      const mockResponse = {
-        candidates: [
-          {
-            content: { parts: [{ text: 'response' }], role: 'model' },
-          },
-        ],
-      } as unknown as GenerateContentResponse;
-
-      // 2. Mock the API to return our controllable promises in order
-      vi.mocked(mockModelsModule.generateContent)
-        .mockReturnValueOnce(firstCallPromise)
-        .mockReturnValueOnce(secondCallPromise);
-
-      // 3. Start the first message call. Do not await it yet.
-      const firstMessagePromise = chat.sendMessage(
-        { message: 'first' },
-        'prompt-1',
-      );
-
-      // Give the event loop a chance to run the async call up to the `await`
-      await new Promise(process.nextTick);
-
-      // 4. While the first call is "in-flight", start the second message call.
-      const secondMessagePromise = chat.sendMessage(
-        { message: 'second' },
-        'prompt-2',
-      );
-
-      // 5. CRUCIAL CHECK: At this point, only the first API call should have been made.
-      // The second call should be waiting on `sendPromise`.
-      expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(1);
-      expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
-        expect.objectContaining({
-          contents: expect.arrayContaining([
-            expect.objectContaining({ parts: [{ text: 'first' }] }),
-          ]),
-        }),
-        'prompt-1',
-      );
-
-      // 6. Unblock the first API call and wait for the first message to fully complete.
-      firstCallResolver!(mockResponse);
-      await firstMessagePromise;
-
-      // Give the event loop a chance to unblock and run the second call.
-      await new Promise(process.nextTick);
-
-      // 7. CRUCIAL CHECK: Now, the second API call should have been made.
-      expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(2);
-      expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
-        expect.objectContaining({
-          contents: expect.arrayContaining([
-            expect.objectContaining({ parts: [{ text: 'second' }] }),
-          ]),
-        }),
-        'prompt-2',
-      );
-
-      // 8. Clean up by resolving the second call.
-      secondCallResolver!(mockResponse);
-      await secondMessagePromise;
-    });
-  });
-
-  it('should retry if the model returns a completely empty stream (no chunks)', async () => {
-    // 1. Mock the API to return an empty stream first, then a valid one.
-    vi.mocked(mockModelsModule.generateContentStream)
-      .mockImplementationOnce(
-        // First call resolves to an async generator that yields nothing.
-        async () => (async function* () {})(),
-      )
-      .mockImplementationOnce(
-        // Second call returns a valid stream.
-        async () =>
-          (async function* () {
-            yield {
-              candidates: [
-                {
-                  content: {
-                    parts: [{ text: 'Successful response after empty' }],
-                  },
-                },
-              ],
-            } as unknown as GenerateContentResponse;
-          })(),
-      );
-
-    // 2. Call the method and consume the stream.
-    const stream = await chat.sendMessageStream(
-      { message: 'test empty stream' },
-      'prompt-id-empty-stream',
-    );
-    const chunks = [];
-    for await (const chunk of stream) {
-      chunks.push(chunk);
-    }
-
-    // 3. Assert the results.
-    expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
-    expect(
-      chunks.some(
-        (c) =>
-          c.candidates?.[0]?.content?.parts?.[0]?.text ===
-          'Successful response after empty',
-      ),
-    ).toBe(true);
-
-    const history = chat.getHistory();
-    expect(history.length).toBe(2);
-
-    // Explicitly verify the structure of each part to satisfy TypeScript
-    const turn1 = history[0];
-    if (!turn1?.parts?.[0] || !('text' in turn1.parts[0])) {
-      throw new Error('Test setup error: First turn is not a valid text part.');
-    }
-    expect(turn1.parts[0].text).toBe('test empty stream');
-
-    const turn2 = history[1];
-    if (!turn2?.parts?.[0] || !('text' in turn2.parts[0])) {
-      throw new Error(
-        'Test setup error: Second turn is not a valid text part.',
-      );
-    }
-    expect(turn2.parts[0].text).toBe('Successful response after empty');
-  });
-
-  it('should queue a subsequent sendMessageStream call until the first stream is fully consumed', async () => {
-    // 1. Create a promise to manually control the stream's lifecycle
-    let continueFirstStream: () => void;
-    const firstStreamContinuePromise = new Promise<void>((resolve) => {
-      continueFirstStream = resolve;
-    });
-
-    // 2. Mock the API to return controllable async generators
-    const firstStreamGenerator = (async function* () {
-      yield {
-        candidates: [
-          { content: { parts: [{ text: 'first response part 1' }] } },
-        ],
-      } as unknown as GenerateContentResponse;
-      await firstStreamContinuePromise; // Pause the stream
-      yield {
-        candidates: [{ content: { parts: [{ text: ' part 2' }] } }],
-      } as unknown as GenerateContentResponse;
-    })();
-
-    const secondStreamGenerator = (async function* () {
-      yield {
-        candidates: [{ content: { parts: [{ text: 'second response' }] } }],
-      } as unknown as GenerateContentResponse;
-    })();
-
-    vi.mocked(mockModelsModule.generateContentStream)
-      .mockResolvedValueOnce(firstStreamGenerator)
-      .mockResolvedValueOnce(secondStreamGenerator);
-
-    // 3. Start the first stream and consume only the first chunk to pause it
-    const firstStream = await chat.sendMessageStream(
-      { message: 'first' },
-      'prompt-1',
-    );
-    const firstStreamIterator = firstStream[Symbol.asyncIterator]();
-    await firstStreamIterator.next();
-
-    // 4. While the first stream is paused, start the second call. It will block.
-    const secondStreamPromise = chat.sendMessageStream(
-      { message: 'second' },
-      'prompt-2',
-    );
-
-    // 5. Assert that only one API call has been made so far.
-    expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(1);
-
-    // 6. Unblock and fully consume the first stream to completion.
-    continueFirstStream!();
-    await firstStreamIterator.next(); // Consume the rest of the stream
-    await firstStreamIterator.next(); // Finish the iterator
-
-    // 7. Now that the first stream is done, await the second promise to get its generator.
-    const secondStream = await secondStreamPromise;
-
-    // 8. Start consuming the second stream, which triggers its internal API call.
-    const secondStreamIterator = secondStream[Symbol.asyncIterator]();
-    await secondStreamIterator.next();
-
-    // 9. The second API call should now have been made.
-    expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
-
-    // 10. FIX: Fully consume the second stream to ensure recordHistory is called.
-    await secondStreamIterator.next(); // This finishes the iterator.
-
-    // 11. Final check on history.
-    const history = chat.getHistory();
-    expect(history.length).toBe(4);
-
-    const turn4 = history[3];
-    if (!turn4?.parts?.[0] || !('text' in turn4.parts[0])) {
-      throw new Error(
-        'Test setup error: Fourth turn is not a valid text part.',
-      );
-    }
-    expect(turn4.parts[0].text).toBe('second response');
-  });
 });
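The first modified test above stops consuming the stream because, with the retry wrapper reverted, `sendMessageStream` issues its API request when the call is awaited rather than when the returned generator is first iterated. A self-contained illustration of that distinction (not library code):

```ts
// An async generator body does not run until the first next() call, which is
// why the retrying implementation forced tests to iterate the stream.
async function* lazy(): AsyncGenerator<number> {
  console.log('request fired (lazy)'); // runs only on first iteration
  yield 1;
}

// The restored implementation is eager: the request happens before the
// generator is handed back, so awaiting the call is enough to trigger it.
async function eager(): Promise<AsyncGenerator<number>> {
  console.log('request fired (eager)'); // runs when eager() is awaited
  return lazy();
}

const gen = lazy(); // nothing logged yet
await gen.next(); // logs now
await eager(); // logs immediately, before any consumption
```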
diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts
index 93428684ba..dfcf425a27 100644
--- a/packages/core/src/core/geminiChat.ts
+++ b/packages/core/src/core/geminiChat.ts
@@ -24,21 +24,6 @@ import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
 import { hasCycleInSchema } from '../tools/tools.js';
 import { StructuredError } from './turn.js';
 
-/**
- * Options for retrying due to invalid content from the model.
- */
-interface ContentRetryOptions {
-  /** Total number of attempts to make (1 initial + N retries). */
-  maxAttempts: number;
-  /** The base delay in milliseconds for linear backoff. */
-  initialDelayMs: number;
-}
-
-const INVALID_CONTENT_RETRY_OPTIONS: ContentRetryOptions = {
-  maxAttempts: 3, // 1 initial call + 2 retries
-  initialDelayMs: 500,
-};
-
 /**
  * Returns true if the response is valid, false otherwise.
  */
@@ -113,23 +98,15 @@ function extractCuratedHistory(comprehensiveHistory: Content[]): Content[] {
       }
       if (isValid) {
         curatedHistory.push(...modelOutput);
+      } else {
+        // Remove the last user input when model content is invalid.
+        curatedHistory.pop();
       }
     }
   }
   return curatedHistory;
 }
 
-/**
- * Custom error to signal that a stream completed without valid content,
- * which should trigger a retry.
- */
-export class EmptyStreamError extends Error {
-  constructor(message: string) {
-    super(message);
-    this.name = 'EmptyStreamError';
-  }
-}
-
 /**
  * Chat session that enables sending messages to the model with previous
  * conversation context.
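The restored `else` branch changes the curated view of history: an invalid model turn now also drops the user turn that produced it, preserving strict user/model alternation. A small data-only illustration of the intended mapping (not the function itself):

```ts
import { Content } from '@google/genai';

const comprehensive: Content[] = [
  { role: 'user', parts: [{ text: 'q1' }] },
  { role: 'model', parts: [{ text: 'a1' }] },
  { role: 'user', parts: [{ text: 'q2' }] },
  { role: 'model', parts: [{ text: '' }] }, // invalid: empty text part
];

// Under the restored rule, extractCuratedHistory(comprehensive) keeps only the
// first pair; 'q2' is popped along with its invalid model output:
// [{ role: 'user', ...'q1' }, { role: 'model', ...'a1' }]
```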
@@ -328,121 +305,65 @@ export class GeminiChat {
     prompt_id: string,
   ): Promise<AsyncGenerator<GenerateContentResponse>> {
     await this.sendPromise;
-
-    let streamDoneResolver: () => void;
-    const streamDonePromise = new Promise<void>((resolve) => {
-      streamDoneResolver = resolve;
-    });
-    this.sendPromise = streamDonePromise;
-
     const userContent = createUserContent(params.message);
+    const requestContents = this.getHistory(true).concat(userContent);
 
-    // Add user content to history ONCE before any attempts.
-    this.history.push(userContent);
-    const requestContents = this.getHistory(true);
-
-    // eslint-disable-next-line @typescript-eslint/no-this-alias
-    const self = this;
-    return (async function* () {
-      try {
-        let lastError: unknown = new Error('Request failed after all retries.');
-
-        for (
-          let attempt = 0;
-          attempt <= INVALID_CONTENT_RETRY_OPTIONS.maxAttempts;
-          attempt++
-        ) {
-          try {
-            const stream = await self.makeApiCallAndProcessStream(
-              requestContents,
-              params,
-              prompt_id,
-              userContent,
-            );
-
-            for await (const chunk of stream) {
-              yield chunk;
-            }
-
-            lastError = null;
-            break;
-          } catch (error) {
-            lastError = error;
-            const isContentError = error instanceof EmptyStreamError;
-
-            if (isContentError) {
-              // Check if we have more attempts left.
-              if (attempt < INVALID_CONTENT_RETRY_OPTIONS.maxAttempts - 1) {
-                await new Promise((res) =>
-                  setTimeout(
-                    res,
-                    INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs *
-                      (attempt + 1),
-                  ),
-                );
-                continue;
-              }
-            }
-            break;
-          }
-        }
-
-        if (lastError) {
-          // If the stream fails, remove the user message that was added.
-          if (self.history[self.history.length - 1] === userContent) {
-            self.history.pop();
-          }
-          throw lastError;
-        }
-      } finally {
-        streamDoneResolver!();
-      }
-    })();
-  }
-
-  private async makeApiCallAndProcessStream(
-    requestContents: Content[],
-    params: SendMessageParameters,
-    prompt_id: string,
-    userContent: Content,
-  ): Promise<AsyncGenerator<GenerateContentResponse>> {
-    const apiCall = () => {
-      const modelToUse = this.config.getModel();
-
-      if (
-        this.config.getQuotaErrorOccurred() &&
-        modelToUse === DEFAULT_GEMINI_FLASH_MODEL
-      ) {
-        throw new Error(
-          'Please submit a new query to continue with the Flash model.',
-        );
-      }
-
-      return this.contentGenerator.generateContentStream(
-        {
-          model: modelToUse,
-          contents: requestContents,
-          config: { ...this.generationConfig, ...params.config },
-        },
-        prompt_id,
-      );
-    };
-
-    const streamResponse = await retryWithBackoff(apiCall, {
-      shouldRetry: (error: unknown) => {
-        if (error instanceof Error && error.message) {
-          if (isSchemaDepthError(error.message)) return false;
-          if (error.message.includes('429')) return true;
-          if (error.message.match(/5\d{2}/)) return true;
-        }
-        return false;
-      },
-      onPersistent429: async (authType?: string, error?: unknown) =>
-        await this.handleFlashFallback(authType, error),
-      authType: this.config.getContentGeneratorConfig()?.authType,
-    });
-
-    return this.processStreamResponse(streamResponse, userContent);
+    try {
+      const apiCall = () => {
+        const modelToUse = this.config.getModel();
+
+        // Prevent Flash model calls immediately after quota error
+        if (
+          this.config.getQuotaErrorOccurred() &&
+          modelToUse === DEFAULT_GEMINI_FLASH_MODEL
+        ) {
+          throw new Error(
+            'Please submit a new query to continue with the Flash model.',
+          );
+        }
+
+        return this.contentGenerator.generateContentStream(
+          {
+            model: modelToUse,
+            contents: requestContents,
+            config: { ...this.generationConfig, ...params.config },
+          },
+          prompt_id,
+        );
+      };
+
+      // Note: Retrying streams can be complex. If generateContentStream itself doesn't handle retries
+      // for transient issues internally before yielding the async generator, this retry will re-initiate
+      // the stream. For simple 429/500 errors on initial call, this is fine.
+      // If errors occur mid-stream, this setup won't resume the stream; it will restart it.
+      const streamResponse = await retryWithBackoff(apiCall, {
+        shouldRetry: (error: unknown) => {
+          // Check for known error messages and codes.
+          if (error instanceof Error && error.message) {
+            if (isSchemaDepthError(error.message)) return false;
+            if (error.message.includes('429')) return true;
+            if (error.message.match(/5\d{2}/)) return true;
+          }
+          return false; // Don't retry other errors by default
+        },
+        onPersistent429: async (authType?: string, error?: unknown) =>
+          await this.handleFlashFallback(authType, error),
+        authType: this.config.getContentGeneratorConfig()?.authType,
+      });
+
+      // Resolve the internal tracking of send completion promise - `sendPromise`
+      // for both success and failure response. The actual failure is still
+      // propagated by the `await streamResponse`.
+      this.sendPromise = Promise.resolve(streamResponse)
+        .then(() => undefined)
+        .catch(() => undefined);
+
+      const result = this.processStreamResponse(streamResponse, userContent);
+      return result;
+    } catch (error) {
+      this.sendPromise = Promise.resolve();
+      throw error;
+    }
   }
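The restored body also reinstates the `sendPromise` gate: each call waits for the previous send, then parks an always-resolving marker so a failed send cannot block the next caller, while the real rejection still reaches the current caller. A reduced, standalone sketch of the pattern (simplified; not the class above):

```ts
class SendGate {
  private sendPromise: Promise<void> = Promise.resolve();

  async send<T>(work: () => Promise<T>): Promise<T> {
    await this.sendPromise; // queue behind the in-flight send, if any
    const current = work();
    // Track completion for ordering only; swallow the outcome so one
    // failure does not wedge every later send (mirrors .catch(() => undefined)).
    this.sendPromise = current.then(
      () => undefined,
      () => undefined,
    );
    return current;
  }
}
```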
@@ -486,6 +407,8 @@ export class GeminiChat {
   /**
    * Adds a new entry to the chat history.
+   *
+   * @param content - The content to add to the history.
    */
   addHistory(content: Content): void {
     this.history.push(content);
   }
@@ -528,41 +451,41 @@ export class GeminiChat {
   private async *processStreamResponse(
     streamResponse: AsyncGenerator<GenerateContentResponse>,
-    userInput: Content,
-  ): AsyncGenerator<GenerateContentResponse> {
-    const modelResponseParts: Part[] = [];
-    let isStreamInvalid = false;
-    let hasReceivedAnyChunk = false;
+    inputContent: Content,
+  ) {
+    const outputContent: Content[] = [];
+    const chunks: GenerateContentResponse[] = [];
+    let errorOccurred = false;
 
-    for await (const chunk of streamResponse) {
-      hasReceivedAnyChunk = true;
-      if (isValidResponse(chunk)) {
-        const content = chunk.candidates?.[0]?.content;
-        if (content) {
-          // Filter out thought parts from being added to history.
-          if (!this.isThoughtContent(content) && content.parts) {
-            modelResponseParts.push(...content.parts);
+    try {
+      for await (const chunk of streamResponse) {
+        if (isValidResponse(chunk)) {
+          chunks.push(chunk);
+          const content = chunk.candidates?.[0]?.content;
+          if (content !== undefined) {
+            if (this.isThoughtContent(content)) {
+              yield chunk;
+              continue;
+            }
+            outputContent.push(content);
          }
        }
-      } else {
-        isStreamInvalid = true;
+        yield chunk;
       }
-      yield chunk; // Yield every chunk to the UI immediately.
+    } catch (error) {
+      errorOccurred = true;
+      throw error;
     }
 
-    // Now that the stream is finished, make a decision.
-    // Throw an error if the stream was invalid OR if it was completely empty.
-    if (isStreamInvalid || !hasReceivedAnyChunk) {
-      throw new EmptyStreamError(
-        'Model stream was invalid or completed without valid content.',
-      );
+    if (!errorOccurred) {
+      const allParts: Part[] = [];
+      for (const content of outputContent) {
+        if (content.parts) {
+          allParts.push(...content.parts);
+        }
+      }
     }
-
-    // Use recordHistory to correctly save the conversation turn.
-    const modelOutput: Content[] = [
-      { role: 'model', parts: modelResponseParts },
-    ];
-    this.recordHistory(userInput, modelOutput);
+    this.recordHistory(inputContent, outputContent);
   }
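In the restored `processStreamResponse`, thought content is still yielded to the caller (so the UI can render it) but skipped when building `outputContent` for history. This patch does not show what `isThoughtContent` inspects; the chunk shape below is an assumption based on the `thought` flag on `Part` in `@google/genai`:

```ts
import { GenerateContentResponse } from '@google/genai';

// Hypothetical thought chunk: yielded to the UI, excluded from recorded history.
const thoughtChunk = {
  candidates: [
    {
      content: {
        role: 'model',
        parts: [{ thought: true, text: 'Planning the answer...' }],
      },
    },
  ],
} as unknown as GenerateContentResponse;
```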
 
   private recordHistory(
@@ -570,65 +493,88 @@ export class GeminiChat {
     userInput: Content,
     modelOutput: Content[],
     automaticFunctionCallingHistory?: Content[],
   ) {
-    const newHistoryEntries: Content[] = [];
-
-    // Part 1: Handle the user's part of the turn.
-    if (
-      automaticFunctionCallingHistory &&
-      automaticFunctionCallingHistory.length > 0
-    ) {
-      newHistoryEntries.push(
-        ...extractCuratedHistory(automaticFunctionCallingHistory),
-      );
-    } else {
-      // Guard for streaming calls where the user input might already be in the history.
-      if (
-        this.history.length === 0 ||
-        this.history[this.history.length - 1] !== userInput
-      ) {
-        newHistoryEntries.push(userInput);
-      }
-    }
-
-    // Part 2: Handle the model's part of the turn, filtering out thoughts.
     const nonThoughtModelOutput = modelOutput.filter(
       (content) => !this.isThoughtContent(content),
     );
 
     let outputContents: Content[] = [];
-    if (nonThoughtModelOutput.length > 0) {
-      outputContents = nonThoughtModelOutput;
-    } else if (
-      modelOutput.length === 0 &&
-      !isFunctionResponse(userInput) &&
-      !automaticFunctionCallingHistory
+    if (
+      nonThoughtModelOutput.length > 0 &&
+      nonThoughtModelOutput.every((content) => content.role !== undefined)
     ) {
-      // Add an empty model response if the model truly returned nothing.
-      outputContents.push({ role: 'model', parts: [] } as Content);
+      outputContents = nonThoughtModelOutput;
+    } else if (nonThoughtModelOutput.length === 0 && modelOutput.length > 0) {
+      // This case handles when the model returns only a thought.
+      // We don't want to add an empty model response in this case.
+    } else {
+      // When the user input is not a function response, append an empty model
+      // content when the model returns an empty response, so that the history
+      // always alternates between user and model.
+      // Workaround for: https://b.corp.google.com/issues/420354090
+      if (!isFunctionResponse(userInput)) {
+        outputContents.push({
+          role: 'model',
+          parts: [],
+        } as Content);
+      }
+    }
+
+    if (
+      automaticFunctionCallingHistory &&
+      automaticFunctionCallingHistory.length > 0
+    ) {
+      this.history.push(
+        ...extractCuratedHistory(automaticFunctionCallingHistory),
+      );
+    } else {
+      this.history.push(userInput);
     }
 
-    // Part 3: Consolidate the parts of this turn's model response.
+    // Consolidate adjacent model roles in outputContents
     const consolidatedOutputContents: Content[] = [];
-    if (outputContents.length > 0) {
-      for (const content of outputContents) {
-        const lastContent =
-          consolidatedOutputContents[consolidatedOutputContents.length - 1];
-        if (this.hasTextContent(lastContent) && this.hasTextContent(content)) {
-          lastContent.parts[0].text += content.parts[0].text || '';
-          if (content.parts.length > 1) {
-            lastContent.parts.push(...content.parts.slice(1));
-          }
-        } else {
-          consolidatedOutputContents.push(content);
+    for (const content of outputContents) {
+      if (this.isThoughtContent(content)) {
+        continue;
+      }
+      const lastContent =
+        consolidatedOutputContents[consolidatedOutputContents.length - 1];
+      if (this.isTextContent(lastContent) && this.isTextContent(content)) {
+        // If both current and last are text, combine their text into the lastContent's first part
+        // and append any other parts from the current content.
+        lastContent.parts[0].text += content.parts[0].text || '';
+        if (content.parts.length > 1) {
+          lastContent.parts.push(...content.parts.slice(1));
        }
+      } else {
+        consolidatedOutputContents.push(content);
       }
     }
 
-    // Part 4: Add the new turn (user and model parts) to the main history.
-    this.history.push(...newHistoryEntries, ...consolidatedOutputContents);
+    if (consolidatedOutputContents.length > 0) {
+      const lastHistoryEntry = this.history[this.history.length - 1];
+      const canMergeWithLastHistory =
+        !automaticFunctionCallingHistory ||
+        automaticFunctionCallingHistory.length === 0;
+
+      if (
+        canMergeWithLastHistory &&
+        this.isTextContent(lastHistoryEntry) &&
+        this.isTextContent(consolidatedOutputContents[0])
+      ) {
+        // If both current and last are text, combine their text into the lastHistoryEntry's first part
+        // and append any other parts from the current content.
+        lastHistoryEntry.parts[0].text +=
+          consolidatedOutputContents[0].parts[0].text || '';
+        if (consolidatedOutputContents[0].parts.length > 1) {
+          lastHistoryEntry.parts.push(
+            ...consolidatedOutputContents[0].parts.slice(1),
+          );
+        }
+        consolidatedOutputContents.shift(); // Remove the first element as it's merged
+      }
+      this.history.push(...consolidatedOutputContents);
+    }
   }
 
-  private hasTextContent(
+  private isTextContent(
     content: Content | undefined,
   ): content is Content & { parts: [{ text: string }, ...Part[]] } {
     return !!(