Hotfix/retry stream 6777 (#6881)

Co-authored-by: Victor May <mayvic@google.com>
This commit is contained in:
Silvio Junior
2025-08-22 17:37:13 -04:00
committed by GitHub
parent dda33bb360
commit 86fd6419a3
6 changed files with 965 additions and 186 deletions

View File

@@ -98,10 +98,6 @@ describe('handleAtCommand', () => {
processedQuery: [{ text: query }],
shouldProceed: true,
});
expect(mockAddItem).toHaveBeenCalledWith(
{ type: 'user', text: query },
123,
);
});
it('should pass through original query if only a lone @ symbol is present', async () => {
@@ -120,10 +116,6 @@ describe('handleAtCommand', () => {
processedQuery: [{ text: queryWithSpaces }],
shouldProceed: true,
});
expect(mockAddItem).toHaveBeenCalledWith(
{ type: 'user', text: queryWithSpaces },
124,
);
expect(mockOnDebugMessage).toHaveBeenCalledWith(
'Lone @ detected, will be treated as text in the modified query.',
);
@@ -156,10 +148,6 @@ describe('handleAtCommand', () => {
],
shouldProceed: true,
});
expect(mockAddItem).toHaveBeenCalledWith(
{ type: 'user', text: query },
125,
);
expect(mockAddItem).toHaveBeenCalledWith(
expect.objectContaining({
type: 'tool_group',
@@ -198,10 +186,6 @@ describe('handleAtCommand', () => {
],
shouldProceed: true,
});
expect(mockAddItem).toHaveBeenCalledWith(
{ type: 'user', text: query },
126,
);
expect(mockOnDebugMessage).toHaveBeenCalledWith(
`Path ${dirPath} resolved to directory, using glob: ${resolvedGlob}`,
);
@@ -236,10 +220,6 @@ describe('handleAtCommand', () => {
],
shouldProceed: true,
});
expect(mockAddItem).toHaveBeenCalledWith(
{ type: 'user', text: query },
128,
);
});
it('should correctly unescape paths with escaped spaces', async () => {
@@ -270,10 +250,6 @@ describe('handleAtCommand', () => {
],
shouldProceed: true,
});
expect(mockAddItem).toHaveBeenCalledWith(
{ type: 'user', text: query },
125,
);
expect(mockAddItem).toHaveBeenCalledWith(
expect.objectContaining({
type: 'tool_group',
@@ -1090,4 +1066,37 @@ describe('handleAtCommand', () => {
});
});
});
// Regression test for the race-condition hotfix: handleAtCommand no longer
// emits the user's turn itself; the caller (useGeminiStream) does so after
// @-command processing completes.
it("should not add the user's turn to history, as that is the caller's responsibility", async () => {
// Arrange
const fileContent = 'This is the file content.';
// Create a real file on disk so the @-path resolves and triggers the
// read_many_files tool group.
const filePath = await createTestFile(
path.join(testRootDir, 'path', 'to', 'another-file.txt'),
fileContent,
);
const query = `A query with @${filePath}`;
// Act
await handleAtCommand({
query,
config: mockConfig,
addItem: mockAddItem,
onDebugMessage: mockOnDebugMessage,
messageId: 999,
signal: abortController.signal,
});
// Assert
// It SHOULD be called for the tool_group
expect(mockAddItem).toHaveBeenCalledWith(
expect.objectContaining({ type: 'tool_group' }),
999,
);
// It should NOT have been called for the user turn
// (inspect every recorded call; none may carry type 'user')
const userTurnCalls = mockAddItem.mock.calls.filter(
(call) => call[0].type === 'user',
);
expect(userTurnCalls).toHaveLength(0);
});
});

View File

@@ -137,12 +137,9 @@ export async function handleAtCommand({
);
if (atPathCommandParts.length === 0) {
addItem({ type: 'user', text: query }, userMessageTimestamp);
return { processedQuery: [{ text: query }], shouldProceed: true };
}
addItem({ type: 'user', text: query }, userMessageTimestamp);
// Get centralized file discovery service
const fileDiscovery = config.getFileService();

View File

@@ -5,10 +5,19 @@
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
import {
describe,
it,
expect,
vi,
beforeEach,
Mock,
MockInstance,
} from 'vitest';
import { renderHook, act, waitFor } from '@testing-library/react';
import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js';
import { useKeypress } from './useKeypress.js';
import * as atCommandProcessor from './atCommandProcessor.js';
import {
useReactToolScheduler,
TrackedToolCall,
@@ -20,8 +29,10 @@ import {
Config,
EditorType,
AuthType,
GeminiClient,
GeminiEventType as ServerGeminiEventType,
AnyToolInvocation,
ToolErrorType, // <-- Import ToolErrorType
} from '@google/gemini-cli-core';
import { Part, PartListUnion } from '@google/genai';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
@@ -83,11 +94,7 @@ vi.mock('./shellCommandProcessor.js', () => ({
}),
}));
vi.mock('./atCommandProcessor.js', () => ({
handleAtCommand: vi
.fn()
.mockResolvedValue({ shouldProceed: true, processedQuery: 'mocked' }),
}));
vi.mock('./atCommandProcessor.js');
vi.mock('../utils/markdownUtilities.js', () => ({
findLastSafeSplitPoint: vi.fn((s: string) => s.length),
@@ -259,6 +266,7 @@ describe('useGeminiStream', () => {
let mockScheduleToolCalls: Mock;
let mockCancelAllToolCalls: Mock;
let mockMarkToolsAsSubmitted: Mock;
let handleAtCommandSpy: MockInstance;
beforeEach(() => {
vi.clearAllMocks(); // Clear mocks before each test
@@ -342,6 +350,7 @@ describe('useGeminiStream', () => {
mockSendMessageStream
.mockClear()
.mockReturnValue((async function* () {})());
handleAtCommandSpy = vi.spyOn(atCommandProcessor, 'handleAtCommand');
});
const mockLoadedSettings: LoadedSettings = {
@@ -447,6 +456,7 @@ describe('useGeminiStream', () => {
callId: 'call1',
responseParts: [{ text: 'tool 1 response' }],
error: undefined,
errorType: undefined, // FIX: Added missing property
resultDisplay: 'Tool 1 success display',
},
tool: {
@@ -512,7 +522,11 @@ describe('useGeminiStream', () => {
},
status: 'success',
responseSubmittedToGemini: false,
response: { callId: 'call1', responseParts: toolCall1ResponseParts },
response: {
callId: 'call1',
responseParts: toolCall1ResponseParts,
errorType: undefined, // FIX: Added missing property
},
tool: {
displayName: 'MockTool',
},
@@ -530,7 +544,11 @@ describe('useGeminiStream', () => {
},
status: 'error',
responseSubmittedToGemini: false,
response: { callId: 'call2', responseParts: toolCall2ResponseParts },
response: {
callId: 'call2',
responseParts: toolCall2ResponseParts,
errorType: ToolErrorType.UNHANDLED_EXCEPTION, // FIX: Added missing property
},
} as TrackedCompletedToolCall, // Treat error as a form of completion for submission
];
@@ -597,7 +615,11 @@ describe('useGeminiStream', () => {
prompt_id: 'prompt-id-3',
},
status: 'cancelled',
response: { callId: '1', responseParts: [{ text: 'cancelled' }] },
response: {
callId: '1',
responseParts: [{ text: 'cancelled' }],
errorType: undefined, // FIX: Added missing property
},
responseSubmittedToGemini: false,
tool: {
displayName: 'mock tool',
@@ -682,6 +704,7 @@ describe('useGeminiStream', () => {
],
resultDisplay: undefined,
error: undefined,
errorType: undefined, // FIX: Added missing property
},
responseSubmittedToGemini: false,
};
@@ -710,6 +733,7 @@ describe('useGeminiStream', () => {
],
resultDisplay: undefined,
error: undefined,
errorType: undefined, // FIX: Added missing property
},
responseSubmittedToGemini: false,
};
@@ -812,6 +836,7 @@ describe('useGeminiStream', () => {
callId: 'call1',
responseParts: toolCallResponseParts,
error: undefined,
errorType: undefined, // FIX: Added missing property
resultDisplay: 'Tool 1 success display',
},
endTime: Date.now(),
@@ -1214,6 +1239,7 @@ describe('useGeminiStream', () => {
responseParts: [{ text: 'Memory saved' }],
resultDisplay: 'Success: Memory saved',
error: undefined,
errorType: undefined, // FIX: Added missing property
},
tool: {
name: 'save_memory',
@@ -1757,4 +1783,301 @@ describe('useGeminiStream', () => {
);
});
});
it('should process @include commands, adding user turn after processing to prevent race conditions', async () => {
const rawQuery = '@include file.txt Summarize this.';
const processedQueryParts = [
{ text: 'Summarize this with content from @file.txt' },
{ text: 'File content...' },
];
const userMessageTimestamp = Date.now();
vi.spyOn(Date, 'now').mockReturnValue(userMessageTimestamp);
handleAtCommandSpy.mockResolvedValue({
processedQuery: processedQueryParts,
shouldProceed: true,
});
const { result } = renderHook(() =>
useGeminiStream(
mockConfig.getGeminiClient() as GeminiClient,
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false, // shellModeActive
vi.fn(), // getPreferredEditor
vi.fn(), // onAuthError
vi.fn(), // performMemoryRefresh
false, // modelSwitched
vi.fn(), // setModelSwitched
vi.fn(), // onEditorClose
vi.fn(), // onCancelSubmit
),
);
await act(async () => {
await result.current.submitQuery(rawQuery);
});
expect(handleAtCommandSpy).toHaveBeenCalledWith(
expect.objectContaining({
query: rawQuery,
}),
);
expect(mockAddItem).toHaveBeenCalledWith(
{
type: MessageType.USER,
text: rawQuery,
},
userMessageTimestamp,
);
// FIX: The expectation now matches the actual call signature.
expect(mockSendMessageStream).toHaveBeenCalledWith(
processedQueryParts, // Argument 1: The parts array directly
expect.any(AbortSignal), // Argument 2: An AbortSignal
expect.any(String), // Argument 3: The prompt_id string
);
});
describe('Thought Reset', () => {
it('should reset thought to null when starting a new prompt', async () => {
mockSendMessageStream.mockReturnValue(
(async function* () {
yield {
type: ServerGeminiEventType.Thought,
value: {
subject: 'Previous thought',
description: 'Old description',
},
};
yield {
type: ServerGeminiEventType.Content,
value: 'Some response content',
};
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('First query');
});
await waitFor(() => {
expect(mockAddItem).toHaveBeenCalledWith(
expect.objectContaining({
type: 'gemini',
text: 'Some response content',
}),
expect.any(Number),
);
});
mockSendMessageStream.mockReturnValue(
(async function* () {
yield {
type: ServerGeminiEventType.Content,
value: 'New response content',
};
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})(),
);
await act(async () => {
await result.current.submitQuery('Second query');
});
await waitFor(() => {
expect(mockAddItem).toHaveBeenCalledWith(
expect.objectContaining({
type: 'gemini',
text: 'New response content',
}),
expect.any(Number),
);
});
});
it('should reset thought to null when user cancels', async () => {
mockSendMessageStream.mockReturnValue(
(async function* () {
yield {
type: ServerGeminiEventType.Thought,
value: { subject: 'Some thought', description: 'Description' },
};
yield { type: ServerGeminiEventType.UserCancelled };
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('Test query');
});
await waitFor(() => {
expect(mockAddItem).toHaveBeenCalledWith(
expect.objectContaining({
type: 'info',
text: 'User cancelled the request.',
}),
expect.any(Number),
);
});
expect(result.current.streamingState).toBe(StreamingState.Idle);
});
it('should reset thought to null when there is an error', async () => {
mockSendMessageStream.mockReturnValue(
(async function* () {
yield {
type: ServerGeminiEventType.Thought,
value: { subject: 'Some thought', description: 'Description' },
};
yield {
type: ServerGeminiEventType.Error,
value: { error: { message: 'Test error' } },
};
})(),
);
const { result } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
false,
() => {},
() => {},
() => {},
),
);
await act(async () => {
await result.current.submitQuery('Test query');
});
await waitFor(() => {
expect(mockAddItem).toHaveBeenCalledWith(
expect.objectContaining({
type: 'error',
}),
expect.any(Number),
);
});
expect(mockParseAndFormatApiError).toHaveBeenCalledWith(
{ message: 'Test error' },
expect.any(String),
undefined,
'gemini-2.5-pro',
'gemini-2.5-flash',
);
});
});
it('should process @include commands, adding user turn after processing to prevent race conditions', async () => {
const rawQuery = '@include file.txt Summarize this.';
const processedQueryParts = [
{ text: 'Summarize this with content from @file.txt' },
{ text: 'File content...' },
];
const userMessageTimestamp = Date.now();
vi.spyOn(Date, 'now').mockReturnValue(userMessageTimestamp);
handleAtCommandSpy.mockResolvedValue({
processedQuery: processedQueryParts,
shouldProceed: true,
});
const { result } = renderHook(() =>
useGeminiStream(
mockConfig.getGeminiClient() as GeminiClient,
[],
mockAddItem,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
vi.fn(),
vi.fn(),
vi.fn(),
false,
vi.fn(),
vi.fn(),
vi.fn(),
),
);
await act(async () => {
await result.current.submitQuery(rawQuery);
});
expect(handleAtCommandSpy).toHaveBeenCalledWith(
expect.objectContaining({
query: rawQuery,
}),
);
expect(mockAddItem).toHaveBeenCalledWith(
{
type: MessageType.USER,
text: rawQuery,
},
userMessageTimestamp,
);
// FIX: This expectation now correctly matches the actual function call signature.
expect(mockSendMessageStream).toHaveBeenCalledWith(
processedQueryParts, // Argument 1: The parts array directly
expect.any(AbortSignal), // Argument 2: An AbortSignal
expect.any(String), // Argument 3: The prompt_id string
);
});
});

View File

@@ -306,6 +306,13 @@ export const useGeminiStream = (
messageId: userMessageTimestamp,
signal: abortSignal,
});
// Add user's turn after @ command processing is done.
addItem(
{ type: MessageType.USER, text: trimmedQuery },
userMessageTimestamp,
);
if (!atCommandResult.shouldProceed) {
return { queryToSend: null, shouldProceed: false };
}

View File

@@ -12,7 +12,7 @@ import {
Part,
GenerateContentResponse,
} from '@google/genai';
import { GeminiChat } from './geminiChat.js';
import { GeminiChat, EmptyStreamError } from './geminiChat.js';
import { Config } from '../config/config.js';
import { setSimulate429 } from '../utils/testUtils.js';
@@ -112,7 +112,13 @@ describe('GeminiChat', () => {
response,
);
await chat.sendMessageStream({ message: 'hello' }, 'prompt-id-1');
const stream = await chat.sendMessageStream(
{ message: 'hello' },
'prompt-id-1',
);
for await (const _ of stream) {
// consume stream to trigger internal logic
}
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith(
{
@@ -475,4 +481,387 @@ describe('GeminiChat', () => {
expect(history[1]).toEqual(content2);
});
});
describe('sendMessageStream with retries', () => {
it('should retry on invalid content and succeed on the second attempt', async () => {
// Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt.
vi.mocked(mockModelsModule.generateContentStream)
.mockImplementationOnce(async () =>
// First call returns an invalid stream
(async function* () {
yield {
candidates: [{ content: { parts: [{ text: '' }] } }], // Invalid empty text part
} as unknown as GenerateContentResponse;
})(),
)
.mockImplementationOnce(async () =>
// Second call returns a valid stream
(async function* () {
yield {
candidates: [
{ content: { parts: [{ text: 'Successful response' }] } },
],
} as unknown as GenerateContentResponse;
})(),
);
const stream = await chat.sendMessageStream(
{ message: 'test' },
'prompt-id-retry-success',
);
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
// Assertions
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
expect(
chunks.some(
(c) =>
c.candidates?.[0]?.content?.parts?.[0]?.text ===
'Successful response',
),
).toBe(true);
// Check that history was recorded correctly once, with no duplicates.
const history = chat.getHistory();
expect(history.length).toBe(2);
expect(history[0]).toEqual({
role: 'user',
parts: [{ text: 'test' }],
});
expect(history[1]).toEqual({
role: 'model',
parts: [{ text: 'Successful response' }],
});
});
it('should fail after all retries on persistent invalid content', async () => {
vi.mocked(mockModelsModule.generateContentStream).mockImplementation(
async () =>
(async function* () {
yield {
candidates: [
{
content: {
parts: [{ text: '' }],
role: 'model',
},
},
],
} as unknown as GenerateContentResponse;
})(),
);
// This helper function consumes the stream and allows us to test for rejection.
async function consumeStreamAndExpectError() {
const stream = await chat.sendMessageStream(
{ message: 'test' },
'prompt-id-retry-fail',
);
for await (const _ of stream) {
// Must loop to trigger the internal logic that throws.
}
}
await expect(consumeStreamAndExpectError()).rejects.toThrow(
EmptyStreamError,
);
// Should be called 3 times (initial + 2 retries)
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(3);
// History should be clean, as if the failed turn never happened.
const history = chat.getHistory();
expect(history.length).toBe(0);
});
});
it('should correctly retry and append to an existing history mid-conversation', async () => {
// 1. Setup
const initialHistory: Content[] = [
{ role: 'user', parts: [{ text: 'First question' }] },
{ role: 'model', parts: [{ text: 'First answer' }] },
];
chat.setHistory(initialHistory);
// 2. Mock the API
vi.mocked(mockModelsModule.generateContentStream)
.mockImplementationOnce(async () =>
(async function* () {
yield {
candidates: [{ content: { parts: [{ text: '' }] } }],
} as unknown as GenerateContentResponse;
})(),
)
.mockImplementationOnce(async () =>
(async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'Second answer' }] } }],
} as unknown as GenerateContentResponse;
})(),
);
// 3. Send a new message
const stream = await chat.sendMessageStream(
{ message: 'Second question' },
'prompt-id-retry-existing',
);
for await (const _ of stream) {
// consume stream
}
// 4. Assert the final history
const history = chat.getHistory();
expect(history.length).toBe(4);
// Explicitly verify the structure of each part to satisfy TypeScript
const turn1 = history[0];
if (!turn1?.parts?.[0] || !('text' in turn1.parts[0])) {
throw new Error('Test setup error: First turn is not a valid text part.');
}
expect(turn1.parts[0].text).toBe('First question');
const turn2 = history[1];
if (!turn2?.parts?.[0] || !('text' in turn2.parts[0])) {
throw new Error(
'Test setup error: Second turn is not a valid text part.',
);
}
expect(turn2.parts[0].text).toBe('First answer');
const turn3 = history[2];
if (!turn3?.parts?.[0] || !('text' in turn3.parts[0])) {
throw new Error('Test setup error: Third turn is not a valid text part.');
}
expect(turn3.parts[0].text).toBe('Second question');
const turn4 = history[3];
if (!turn4?.parts?.[0] || !('text' in turn4.parts[0])) {
throw new Error(
'Test setup error: Fourth turn is not a valid text part.',
);
}
expect(turn4.parts[0].text).toBe('Second answer');
});
describe('concurrency control', () => {
it('should queue a subsequent sendMessage call until the first one completes', async () => {
// 1. Create promises to manually control when the API calls resolve
let firstCallResolver: (value: GenerateContentResponse) => void;
const firstCallPromise = new Promise<GenerateContentResponse>(
(resolve) => {
firstCallResolver = resolve;
},
);
let secondCallResolver: (value: GenerateContentResponse) => void;
const secondCallPromise = new Promise<GenerateContentResponse>(
(resolve) => {
secondCallResolver = resolve;
},
);
// A standard response body for the mock
const mockResponse = {
candidates: [
{
content: { parts: [{ text: 'response' }], role: 'model' },
},
],
} as unknown as GenerateContentResponse;
// 2. Mock the API to return our controllable promises in order
vi.mocked(mockModelsModule.generateContent)
.mockReturnValueOnce(firstCallPromise)
.mockReturnValueOnce(secondCallPromise);
// 3. Start the first message call. Do not await it yet.
const firstMessagePromise = chat.sendMessage(
{ message: 'first' },
'prompt-1',
);
// Give the event loop a chance to run the async call up to the `await`
await new Promise(process.nextTick);
// 4. While the first call is "in-flight", start the second message call.
const secondMessagePromise = chat.sendMessage(
{ message: 'second' },
'prompt-2',
);
// 5. CRUCIAL CHECK: At this point, only the first API call should have been made.
// The second call should be waiting on `sendPromise`.
expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(1);
expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
contents: expect.arrayContaining([
expect.objectContaining({ parts: [{ text: 'first' }] }),
]),
}),
'prompt-1',
);
// 6. Unblock the first API call and wait for the first message to fully complete.
firstCallResolver!(mockResponse);
await firstMessagePromise;
// Give the event loop a chance to unblock and run the second call.
await new Promise(process.nextTick);
// 7. CRUCIAL CHECK: Now, the second API call should have been made.
expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(2);
expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
contents: expect.arrayContaining([
expect.objectContaining({ parts: [{ text: 'second' }] }),
]),
}),
'prompt-2',
);
// 8. Clean up by resolving the second call.
secondCallResolver!(mockResponse);
await secondMessagePromise;
});
});
it('should retry if the model returns a completely empty stream (no chunks)', async () => {
// 1. Mock the API to return an empty stream first, then a valid one.
vi.mocked(mockModelsModule.generateContentStream)
.mockImplementationOnce(
// First call resolves to an async generator that yields nothing.
async () => (async function* () {})(),
)
.mockImplementationOnce(
// Second call returns a valid stream.
async () =>
(async function* () {
yield {
candidates: [
{
content: {
parts: [{ text: 'Successful response after empty' }],
},
},
],
} as unknown as GenerateContentResponse;
})(),
);
// 2. Call the method and consume the stream.
const stream = await chat.sendMessageStream(
{ message: 'test empty stream' },
'prompt-id-empty-stream',
);
const chunks = [];
for await (const chunk of stream) {
chunks.push(chunk);
}
// 3. Assert the results.
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
expect(
chunks.some(
(c) =>
c.candidates?.[0]?.content?.parts?.[0]?.text ===
'Successful response after empty',
),
).toBe(true);
const history = chat.getHistory();
expect(history.length).toBe(2);
// Explicitly verify the structure of each part to satisfy TypeScript
const turn1 = history[0];
if (!turn1?.parts?.[0] || !('text' in turn1.parts[0])) {
throw new Error('Test setup error: First turn is not a valid text part.');
}
expect(turn1.parts[0].text).toBe('test empty stream');
const turn2 = history[1];
if (!turn2?.parts?.[0] || !('text' in turn2.parts[0])) {
throw new Error(
'Test setup error: Second turn is not a valid text part.',
);
}
expect(turn2.parts[0].text).toBe('Successful response after empty');
});
it('should queue a subsequent sendMessageStream call until the first stream is fully consumed', async () => {
// 1. Create a promise to manually control the stream's lifecycle
let continueFirstStream: () => void;
const firstStreamContinuePromise = new Promise<void>((resolve) => {
continueFirstStream = resolve;
});
// 2. Mock the API to return controllable async generators
const firstStreamGenerator = (async function* () {
yield {
candidates: [
{ content: { parts: [{ text: 'first response part 1' }] } },
],
} as unknown as GenerateContentResponse;
await firstStreamContinuePromise; // Pause the stream
yield {
candidates: [{ content: { parts: [{ text: ' part 2' }] } }],
} as unknown as GenerateContentResponse;
})();
const secondStreamGenerator = (async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'second response' }] } }],
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream)
.mockResolvedValueOnce(firstStreamGenerator)
.mockResolvedValueOnce(secondStreamGenerator);
// 3. Start the first stream and consume only the first chunk to pause it
const firstStream = await chat.sendMessageStream(
{ message: 'first' },
'prompt-1',
);
const firstStreamIterator = firstStream[Symbol.asyncIterator]();
await firstStreamIterator.next();
// 4. While the first stream is paused, start the second call. It will block.
const secondStreamPromise = chat.sendMessageStream(
{ message: 'second' },
'prompt-2',
);
// 5. Assert that only one API call has been made so far.
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(1);
// 6. Unblock and fully consume the first stream to completion.
continueFirstStream!();
await firstStreamIterator.next(); // Consume the rest of the stream
await firstStreamIterator.next(); // Finish the iterator
// 7. Now that the first stream is done, await the second promise to get its generator.
const secondStream = await secondStreamPromise;
// 8. Start consuming the second stream, which triggers its internal API call.
const secondStreamIterator = secondStream[Symbol.asyncIterator]();
await secondStreamIterator.next();
// 9. The second API call should now have been made.
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
// 10. FIX: Fully consume the second stream to ensure recordHistory is called.
await secondStreamIterator.next(); // This finishes the iterator.
// 11. Final check on history.
const history = chat.getHistory();
expect(history.length).toBe(4);
const turn4 = history[3];
if (!turn4?.parts?.[0] || !('text' in turn4.parts[0])) {
throw new Error(
'Test setup error: Fourth turn is not a valid text part.',
);
}
expect(turn4.parts[0].text).toBe('second response');
});
});

View File

@@ -24,6 +24,21 @@ import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { hasCycleInSchema } from '../tools/tools.js';
import { StructuredError } from './turn.js';
/**
 * Tuning knobs for retrying a request after the model returns
 * invalid (empty) content.
 */
interface ContentRetryOptions {
  /** Total number of attempts to make (1 initial + N retries). */
  maxAttempts: number;
  /** Base delay in milliseconds; backoff grows linearly with the attempt. */
  initialDelayMs: number;
}

// One initial call plus two retries, starting from a 500ms delay.
const INVALID_CONTENT_RETRY_OPTIONS = {
  maxAttempts: 3,
  initialDelayMs: 500,
} satisfies ContentRetryOptions;
/**
* Returns true if the response is valid, false otherwise.
*/
@@ -98,15 +113,23 @@ function extractCuratedHistory(comprehensiveHistory: Content[]): Content[] {
}
if (isValid) {
curatedHistory.push(...modelOutput);
} else {
// Remove the last user input when model content is invalid.
curatedHistory.pop();
}
}
}
return curatedHistory;
}
/**
 * Raised when a model stream finishes without yielding any valid content.
 * Callers treat this as a signal to retry the request rather than as a
 * fatal failure.
 */
export class EmptyStreamError extends Error {
  constructor(message: string) {
    super(message);
    // Mirror the class name so logs and error filters can identify it.
    this.name = EmptyStreamError.name;
  }
}
/**
* Chat session that enables sending messages to the model with previous
* conversation context.
@@ -305,65 +328,121 @@ export class GeminiChat {
prompt_id: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
await this.sendPromise;
let streamDoneResolver: () => void;
const streamDonePromise = new Promise<void>((resolve) => {
streamDoneResolver = resolve;
});
this.sendPromise = streamDonePromise;
const userContent = createUserContent(params.message);
const requestContents = this.getHistory(true).concat(userContent);
try {
const apiCall = () => {
const modelToUse = this.config.getModel();
// Add user content to history ONCE before any attempts.
this.history.push(userContent);
const requestContents = this.getHistory(true);
// Prevent Flash model calls immediately after quota error
if (
this.config.getQuotaErrorOccurred() &&
modelToUse === DEFAULT_GEMINI_FLASH_MODEL
// eslint-disable-next-line @typescript-eslint/no-this-alias
const self = this;
return (async function* () {
try {
let lastError: unknown = new Error('Request failed after all retries.');
for (
let attempt = 0;
attempt <= INVALID_CONTENT_RETRY_OPTIONS.maxAttempts;
attempt++
) {
throw new Error(
'Please submit a new query to continue with the Flash model.',
);
try {
const stream = await self.makeApiCallAndProcessStream(
requestContents,
params,
prompt_id,
userContent,
);
for await (const chunk of stream) {
yield chunk;
}
lastError = null;
break;
} catch (error) {
lastError = error;
const isContentError = error instanceof EmptyStreamError;
if (isContentError) {
// Check if we have more attempts left.
if (attempt < INVALID_CONTENT_RETRY_OPTIONS.maxAttempts - 1) {
await new Promise((res) =>
setTimeout(
res,
INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs *
(attempt + 1),
),
);
continue;
}
}
break;
}
}
return this.contentGenerator.generateContentStream(
{
model: modelToUse,
contents: requestContents,
config: { ...this.generationConfig, ...params.config },
},
prompt_id,
);
};
// Note: Retrying streams can be complex. If generateContentStream itself doesn't handle retries
// for transient issues internally before yielding the async generator, this retry will re-initiate
// the stream. For simple 429/500 errors on initial call, this is fine.
// If errors occur mid-stream, this setup won't resume the stream; it will restart it.
const streamResponse = await retryWithBackoff(apiCall, {
shouldRetry: (error: unknown) => {
// Check for known error messages and codes.
if (error instanceof Error && error.message) {
if (isSchemaDepthError(error.message)) return false;
if (error.message.includes('429')) return true;
if (error.message.match(/5\d{2}/)) return true;
if (lastError) {
// If the stream fails, remove the user message that was added.
if (self.history[self.history.length - 1] === userContent) {
self.history.pop();
}
return false; // Don't retry other errors by default
throw lastError;
}
} finally {
streamDoneResolver!();
}
})();
}
private async makeApiCallAndProcessStream(
requestContents: Content[],
params: SendMessageParameters,
prompt_id: string,
userContent: Content,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const apiCall = () => {
const modelToUse = this.config.getModel();
if (
this.config.getQuotaErrorOccurred() &&
modelToUse === DEFAULT_GEMINI_FLASH_MODEL
) {
throw new Error(
'Please submit a new query to continue with the Flash model.',
);
}
return this.contentGenerator.generateContentStream(
{
model: modelToUse,
contents: requestContents,
config: { ...this.generationConfig, ...params.config },
},
onPersistent429: async (authType?: string, error?: unknown) =>
await this.handleFlashFallback(authType, error),
authType: this.config.getContentGeneratorConfig()?.authType,
});
prompt_id,
);
};
// Resolve the internal tracking of send completion promise - `sendPromise`
// for both success and failure response. The actual failure is still
// propagated by the `await streamResponse`.
this.sendPromise = Promise.resolve(streamResponse)
.then(() => undefined)
.catch(() => undefined);
const streamResponse = await retryWithBackoff(apiCall, {
shouldRetry: (error: unknown) => {
if (error instanceof Error && error.message) {
if (isSchemaDepthError(error.message)) return false;
if (error.message.includes('429')) return true;
if (error.message.match(/5\d{2}/)) return true;
}
return false;
},
onPersistent429: async (authType?: string, error?: unknown) =>
await this.handleFlashFallback(authType, error),
authType: this.config.getContentGeneratorConfig()?.authType,
});
const result = this.processStreamResponse(streamResponse, userContent);
return result;
} catch (error) {
this.sendPromise = Promise.resolve();
throw error;
}
return this.processStreamResponse(streamResponse, userContent);
}
/**
@@ -407,8 +486,6 @@ export class GeminiChat {
/**
* Adds a new entry to the chat history.
*
* @param content - The content to add to the history.
*/
addHistory(content: Content): void {
this.history.push(content);
@@ -451,41 +528,41 @@ export class GeminiChat {
private async *processStreamResponse(
streamResponse: AsyncGenerator<GenerateContentResponse>,
inputContent: Content,
) {
const outputContent: Content[] = [];
const chunks: GenerateContentResponse[] = [];
let errorOccurred = false;
userInput: Content,
): AsyncGenerator<GenerateContentResponse> {
const modelResponseParts: Part[] = [];
let isStreamInvalid = false;
let hasReceivedAnyChunk = false;
try {
for await (const chunk of streamResponse) {
if (isValidResponse(chunk)) {
chunks.push(chunk);
const content = chunk.candidates?.[0]?.content;
if (content !== undefined) {
if (this.isThoughtContent(content)) {
yield chunk;
continue;
}
outputContent.push(content);
for await (const chunk of streamResponse) {
hasReceivedAnyChunk = true;
if (isValidResponse(chunk)) {
const content = chunk.candidates?.[0]?.content;
if (content) {
// Filter out thought parts from being added to history.
if (!this.isThoughtContent(content) && content.parts) {
modelResponseParts.push(...content.parts);
}
}
yield chunk;
} else {
isStreamInvalid = true;
}
} catch (error) {
errorOccurred = true;
throw error;
yield chunk; // Yield every chunk to the UI immediately.
}
if (!errorOccurred) {
const allParts: Part[] = [];
for (const content of outputContent) {
if (content.parts) {
allParts.push(...content.parts);
}
}
// Now that the stream is finished, make a decision.
// Throw an error if the stream was invalid OR if it was completely empty.
if (isStreamInvalid || !hasReceivedAnyChunk) {
throw new EmptyStreamError(
'Model stream was invalid or completed without valid content.',
);
}
this.recordHistory(inputContent, outputContent);
// Use recordHistory to correctly save the conversation turn.
const modelOutput: Content[] = [
{ role: 'model', parts: modelResponseParts },
];
this.recordHistory(userInput, modelOutput);
}
private recordHistory(
@@ -493,88 +570,65 @@ export class GeminiChat {
modelOutput: Content[],
automaticFunctionCallingHistory?: Content[],
) {
const newHistoryEntries: Content[] = [];
// Part 1: Handle the user's part of the turn.
if (
automaticFunctionCallingHistory &&
automaticFunctionCallingHistory.length > 0
) {
newHistoryEntries.push(
...extractCuratedHistory(automaticFunctionCallingHistory),
);
} else {
// Guard for streaming calls where the user input might already be in the history.
if (
this.history.length === 0 ||
this.history[this.history.length - 1] !== userInput
) {
newHistoryEntries.push(userInput);
}
}
// Part 2: Handle the model's part of the turn, filtering out thoughts.
const nonThoughtModelOutput = modelOutput.filter(
(content) => !this.isThoughtContent(content),
);
let outputContents: Content[] = [];
if (
nonThoughtModelOutput.length > 0 &&
nonThoughtModelOutput.every((content) => content.role !== undefined)
) {
if (nonThoughtModelOutput.length > 0) {
outputContents = nonThoughtModelOutput;
} else if (nonThoughtModelOutput.length === 0 && modelOutput.length > 0) {
// This case handles when the model returns only a thought.
// We don't want to add an empty model response in this case.
} else {
// When not a function response appends an empty content when model returns empty response, so that the
// history is always alternating between user and model.
// Workaround for: https://b.corp.google.com/issues/420354090
if (!isFunctionResponse(userInput)) {
outputContents.push({
role: 'model',
parts: [],
} as Content);
}
}
if (
automaticFunctionCallingHistory &&
automaticFunctionCallingHistory.length > 0
} else if (
modelOutput.length === 0 &&
!isFunctionResponse(userInput) &&
!automaticFunctionCallingHistory
) {
this.history.push(
...extractCuratedHistory(automaticFunctionCallingHistory),
);
} else {
this.history.push(userInput);
// Add an empty model response if the model truly returned nothing.
outputContents.push({ role: 'model', parts: [] } as Content);
}
// Consolidate adjacent model roles in outputContents
// Part 3: Consolidate the parts of this turn's model response.
const consolidatedOutputContents: Content[] = [];
for (const content of outputContents) {
if (this.isThoughtContent(content)) {
continue;
}
const lastContent =
consolidatedOutputContents[consolidatedOutputContents.length - 1];
if (this.isTextContent(lastContent) && this.isTextContent(content)) {
// If both current and last are text, combine their text into the lastContent's first part
// and append any other parts from the current content.
lastContent.parts[0].text += content.parts[0].text || '';
if (content.parts.length > 1) {
lastContent.parts.push(...content.parts.slice(1));
if (outputContents.length > 0) {
for (const content of outputContents) {
const lastContent =
consolidatedOutputContents[consolidatedOutputContents.length - 1];
if (this.hasTextContent(lastContent) && this.hasTextContent(content)) {
lastContent.parts[0].text += content.parts[0].text || '';
if (content.parts.length > 1) {
lastContent.parts.push(...content.parts.slice(1));
}
} else {
consolidatedOutputContents.push(content);
}
} else {
consolidatedOutputContents.push(content);
}
}
if (consolidatedOutputContents.length > 0) {
const lastHistoryEntry = this.history[this.history.length - 1];
const canMergeWithLastHistory =
!automaticFunctionCallingHistory ||
automaticFunctionCallingHistory.length === 0;
if (
canMergeWithLastHistory &&
this.isTextContent(lastHistoryEntry) &&
this.isTextContent(consolidatedOutputContents[0])
) {
// If both current and last are text, combine their text into the lastHistoryEntry's first part
// and append any other parts from the current content.
lastHistoryEntry.parts[0].text +=
consolidatedOutputContents[0].parts[0].text || '';
if (consolidatedOutputContents[0].parts.length > 1) {
lastHistoryEntry.parts.push(
...consolidatedOutputContents[0].parts.slice(1),
);
}
consolidatedOutputContents.shift(); // Remove the first element as it's merged
}
this.history.push(...consolidatedOutputContents);
}
// Part 4: Add the new turn (user and model parts) to the main history.
this.history.push(...newHistoryEntries, ...consolidatedOutputContents);
}
private isTextContent(
private hasTextContent(
content: Content | undefined,
): content is Content & { parts: [{ text: string }, ...Part[]] } {
return !!(