Fix compression issues (#8225)

This commit is contained in:
Tommaso Sciortino
2025-09-10 16:04:20 -07:00
committed by GitHub
parent 3ceefe8732
commit 03e3c59bf0
4 changed files with 129 additions and 597 deletions

View File

@@ -151,8 +151,8 @@ describe('findIndexAfterFraction', () => {
// 0: 66
// 1: 66 + 68 = 134
// 2: 134 + 66 = 200
// 200 >= 166.5, so index is 2
expect(findIndexAfterFraction(history, 0.5)).toBe(2);
// 200 >= 166.5, so index is 3
expect(findIndexAfterFraction(history, 0.5)).toBe(3);
});
it('should handle a fraction that results in the last index', () => {
@@ -160,8 +160,8 @@ describe('findIndexAfterFraction', () => {
// ...
// 3: 200 + 68 = 268
// 4: 268 + 65 = 333
// 333 >= 299.7, so index is 4
expect(findIndexAfterFraction(history, 0.9)).toBe(4);
// 333 >= 299.7, so index is 5
expect(findIndexAfterFraction(history, 0.9)).toBe(5);
});
it('should handle an empty history', () => {
@@ -169,7 +169,7 @@ describe('findIndexAfterFraction', () => {
});
it('should handle a history with only one item', () => {
expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(0);
expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(1);
});
it('should handle history with weird parts', () => {
@@ -178,7 +178,7 @@ describe('findIndexAfterFraction', () => {
{ role: 'model', parts: [{ fileData: { fileUri: 'derp' } }] },
{ role: 'user', parts: [{ text: 'Message 2' }] },
];
expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(1);
expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(2);
});
});
@@ -534,7 +534,6 @@ describe('Gemini Client (client.ts)', () => {
});
describe('tryCompressChat', () => {
const mockSendMessage = vi.fn();
const mockGetHistory = vi.fn();
beforeEach(() => {
@@ -546,7 +545,6 @@ describe('Gemini Client (client.ts)', () => {
getHistory: mockGetHistory,
addHistory: vi.fn(),
setHistory: vi.fn(),
sendMessage: mockSendMessage,
} as unknown as GeminiChat;
});
@@ -559,7 +557,6 @@ describe('Gemini Client (client.ts)', () => {
const mockChat: Partial<GeminiChat> = {
getHistory: vi.fn().mockReturnValue(chatHistory),
setHistory: vi.fn(),
sendMessage: vi.fn().mockResolvedValue({ text: 'Summary' }),
};
vi.mocked(mockContentGenerator.countTokens)
.mockResolvedValueOnce({ totalTokens: 1000 })
@@ -697,10 +694,16 @@ describe('Gemini Client (client.ts)', () => {
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
// Mock the summary response from the chat
mockSendMessage.mockResolvedValue({
role: 'model',
parts: [{ text: 'This is a summary.' }],
});
mockGenerateContentFn.mockResolvedValue({
candidates: [
{
content: {
role: 'model',
parts: [{ text: 'This is a summary.' }],
},
},
],
} as unknown as GenerateContentResponse);
await client.tryCompressChat('prompt-id-3');
@@ -734,17 +737,23 @@ describe('Gemini Client (client.ts)', () => {
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
// Mock the summary response from the chat
mockSendMessage.mockResolvedValue({
role: 'model',
parts: [{ text: 'This is a summary.' }],
});
mockGenerateContentFn.mockResolvedValue({
candidates: [
{
content: {
role: 'model',
parts: [{ text: 'This is a summary.' }],
},
},
],
} as unknown as GenerateContentResponse);
const initialChat = client.getChat();
const result = await client.tryCompressChat('prompt-id-3');
const newChat = client.getChat();
expect(tokenLimit).toHaveBeenCalled();
expect(mockSendMessage).toHaveBeenCalled();
expect(mockGenerateContentFn).toHaveBeenCalled();
// Assert that summarization happened and returned the correct stats
expect(result).toEqual({
@@ -787,17 +796,23 @@ describe('Gemini Client (client.ts)', () => {
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
// Mock the summary response from the chat
mockSendMessage.mockResolvedValue({
role: 'model',
parts: [{ text: 'This is a summary.' }],
});
mockGenerateContentFn.mockResolvedValue({
candidates: [
{
content: {
role: 'model',
parts: [{ text: 'This is a summary.' }],
},
},
],
} as unknown as GenerateContentResponse);
const initialChat = client.getChat();
const result = await client.tryCompressChat('prompt-id-3');
const newChat = client.getChat();
expect(tokenLimit).toHaveBeenCalled();
expect(mockSendMessage).toHaveBeenCalled();
expect(mockGenerateContentFn).toHaveBeenCalled();
// Assert that summarization happened and returned the correct stats
expect(result).toEqual({
@@ -829,16 +844,22 @@ describe('Gemini Client (client.ts)', () => {
.mockResolvedValueOnce({ totalTokens: newTokenCount });
// Mock the summary response from the chat
mockSendMessage.mockResolvedValue({
role: 'model',
parts: [{ text: 'This is a summary.' }],
});
mockGenerateContentFn.mockResolvedValue({
candidates: [
{
content: {
role: 'model',
parts: [{ text: 'This is a summary.' }],
},
},
],
} as unknown as GenerateContentResponse);
const initialChat = client.getChat();
const result = await client.tryCompressChat('prompt-id-1', true); // force = true
const newChat = client.getChat();
expect(mockSendMessage).toHaveBeenCalled();
expect(mockGenerateContentFn).toHaveBeenCalled();
expect(result).toEqual({
compressionStatus: CompressionStatus.COMPRESSED,
@@ -1503,7 +1524,6 @@ ${JSON.stringify(
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
setHistory: vi.fn(),
sendMessage: vi.fn().mockResolvedValue({ text: 'summary' }),
// Assume history is not empty for delta checks
getHistory: vi
.fn()
@@ -1764,7 +1784,6 @@ ${JSON.stringify(
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]), // Default empty history
setHistory: vi.fn(),
sendMessage: vi.fn().mockResolvedValue({ text: 'summary' }),
};
client['chat'] = mockChat as GeminiChat;

View File

@@ -86,10 +86,10 @@ export function findIndexAfterFraction(
let charactersSoFar = 0;
for (let i = 0; i < contentLengths.length; i++) {
charactersSoFar += contentLengths[i];
if (charactersSoFar >= targetCharacters) {
return i;
}
charactersSoFar += contentLengths[i];
}
return contentLengths.length;
}
@@ -836,19 +836,30 @@ export class GeminiClient {
const historyToCompress = curatedHistory.slice(0, compressBeforeIndex);
const historyToKeep = curatedHistory.slice(compressBeforeIndex);
this.getChat().setHistory(historyToCompress);
const summaryResponse = await this.config
.getContentGenerator()
.generateContent(
{
model,
contents: [
...historyToCompress,
{
role: 'user',
parts: [
{
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
},
],
},
],
config: {
systemInstruction: { text: getCompressionPrompt() },
},
},
prompt_id,
);
const summary = getResponseText(summaryResponse) ?? '';
const { text: summary } = await this.getChat().sendMessage(
{
message: {
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
},
config: {
systemInstruction: { text: getCompressionPrompt() },
},
},
prompt_id,
);
const chat = await this.startChat([
{
role: 'user',

View File

@@ -135,261 +135,6 @@ describe('GeminiChat', () => {
vi.resetAllMocks();
});
describe('sendMessage', () => {
// Regression test for history recording with automatic function calling (AFC):
// when the response carries an `automaticFunctionCallingHistory`, the recorded
// history must still start with the user's original message and contain every
// turn of the exchange (user, model tool call, tool response, final model answer).
it('should retain the initial user message when an automatic function call occurs', async () => {
// 1. Define the user's initial text message. This is the turn that gets dropped by the buggy logic.
const userInitialMessage: Content = {
role: 'user',
parts: [{ text: 'How is the weather in Boston?' }],
};
// 2. Mock the full API response, including the automaticFunctionCallingHistory.
// This history represents the full turn: user asks, model calls tool, tool responds, model answers.
const mockAfcResponse = {
candidates: [
{
content: {
role: 'model',
parts: [
{ text: 'The weather in Boston is 72 degrees and sunny.' },
],
},
},
],
automaticFunctionCallingHistory: [
userInitialMessage, // The user's turn
{
// The model's first response: a tool call
role: 'model',
parts: [
{
functionCall: {
name: 'get_weather',
args: { location: 'Boston' },
},
},
],
},
{
// The tool's response, which has a 'user' role
role: 'user',
parts: [
{
functionResponse: {
name: 'get_weather',
response: { temperature: 72, condition: 'sunny' },
},
},
],
},
],
} as unknown as GenerateContentResponse;
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
mockAfcResponse,
);
// 3. Action: Send the initial message.
await chat.sendMessage(
{ message: 'How is the weather in Boston?' },
'prompt-id-afc-bug',
);
// 4. Assert: Check the final state of the history.
const history = chat.getHistory();
// With the bug, history.length will be 3, because the first user message is dropped.
// The correct behavior is for the history to contain all 4 turns
// (the 3 AFC entries mocked above plus the final model answer from `candidates`).
expect(history.length).toBe(4);
// Crucially, assert that the very first turn in the history matches the user's initial message.
// This is the assertion that fails when the bug is present.
const firstTurn = history[0]!;
expect(firstTurn.role).toBe('user');
expect(firstTurn?.parts![0]!.text).toBe('How is the weather in Boston?');
// Verify the rest of the history is also correct.
const secondTurn = history[1]!;
expect(secondTurn.role).toBe('model');
expect(secondTurn?.parts![0]!.functionCall).toBeDefined();
const thirdTurn = history[2]!;
expect(thirdTurn.role).toBe('user');
expect(thirdTurn?.parts![0]!.functionResponse).toBeDefined();
const fourthTurn = history[3]!;
expect(fourthTurn.role).toBe('model');
expect(fourthTurn?.parts![0]!.text).toContain('72 degrees and sunny');
});
// Guards the turn-ordering invariant: if the existing history already ends with
// a 'user' turn, sending another user message must reject with
// 'Cannot add a user turn after another user turn.'.
it('should throw an error when attempting to add a user turn after another user turn', async () => {
// 1. Setup: Create a history that already ends with a user turn (a functionResponse).
const initialHistory: Content[] = [
{ role: 'user', parts: [{ text: 'Initial prompt' }] },
{
role: 'model',
parts: [{ functionCall: { name: 'test_tool', args: {} } }],
},
{
role: 'user',
parts: [{ functionResponse: { name: 'test_tool', response: {} } }],
},
];
chat.setHistory(initialHistory);
// 2. Mock a valid model response so the call doesn't fail for other reasons.
const mockResponse = {
candidates: [
{ content: { role: 'model', parts: [{ text: 'some response' }] } },
],
} as unknown as GenerateContentResponse;
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
mockResponse,
);
// 3. Action & Assert: Expect that sending another user message immediately
// after a user-role turn throws the specific error.
await expect(
chat.sendMessage(
{ message: 'This is an invalid consecutive user message' },
'prompt-id-1',
),
).rejects.toThrow('Cannot add a user turn after another user turn.');
});
// A single model response that mixes a thought part with a visible text part
// must be recorded with the text part intact: the model turn in history is
// expected to hold exactly one part, the visible text.
it('should preserve text parts that are in the same response as a thought', async () => {
// 1. Mock the API to return a single response containing both a thought and visible text.
const mixedContentResponse = {
candidates: [
{
content: {
role: 'model',
parts: [
{ thought: 'This is a thought.' },
{ text: 'This is the visible text that should not be lost.' },
],
},
},
],
} as unknown as GenerateContentResponse;
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
mixedContentResponse,
);
// 2. Action: Send a standard, non-streaming message.
await chat.sendMessage(
{ message: 'test message' },
'prompt-id-mixed-response',
);
// 3. Assert: Check the final state of the history.
const history = chat.getHistory();
// The history should contain two turns: the user's message and the model's response.
expect(history.length).toBe(2);
const modelTurn = history[1]!;
expect(modelTurn.role).toBe('model');
// CRUCIAL ASSERTION:
// Buggy code would discard the entire response because a "thought" was present,
// resulting in an empty placeholder turn with 0 parts.
// The corrected code will pass, preserving the single visible text part.
expect(modelTurn?.parts?.length).toBe(1);
expect(modelTurn?.parts![0]!.text).toBe(
'This is the visible text that should not be lost.',
);
});
// When a function response is sent and the model replies with a thought-only
// part, the history must still gain a model turn (with zero parts) after the
// function-response user turn, giving four alternating turns in total.
it('should add a placeholder model turn when a tool call is followed by an empty model response', async () => {
// 1. Setup: A history where the model has just made a function call.
const initialHistory: Content[] = [
{
role: 'user',
parts: [{ text: 'Find a good Italian restaurant for me.' }],
},
{
role: 'model',
parts: [
{
functionCall: {
name: 'find_restaurant',
args: { cuisine: 'Italian' },
},
},
],
},
];
chat.setHistory(initialHistory);
// 2. Mock the API to return an empty/thought-only response.
const emptyModelResponse = {
candidates: [
{ content: { role: 'model', parts: [{ thought: true }] } },
],
} as unknown as GenerateContentResponse;
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
emptyModelResponse,
);
// 3. Action: Send the function response back to the model.
await chat.sendMessage(
{
message: {
functionResponse: {
name: 'find_restaurant',
response: { name: 'Vesuvio' },
},
},
},
'prompt-id-1',
);
// 4. Assert: The history should now have four valid, alternating turns.
const history = chat.getHistory();
expect(history.length).toBe(4);
// The final turn must be the empty model placeholder.
const lastTurn = history[3]!;
expect(lastTurn.role).toBe('model');
expect(lastTurn?.parts?.length).toBe(0);
// The second-to-last turn must be the function response we sent.
const secondToLastTurn = history[2]!;
expect(secondToLastTurn.role).toBe('user');
expect(secondToLastTurn?.parts![0]!.functionResponse).toBeDefined();
});
// Pins the exact request passed to the content generator for a plain text
// message: the model name, the message wrapped as a single 'user' Content,
// an empty config object, and the prompt id as the second argument.
// NOTE(review): 'gemini-pro' presumably comes from the suite's mock config,
// which is set up outside this view — confirm.
it('should call generateContent with the correct parameters', async () => {
const response = {
candidates: [
{
content: {
parts: [{ text: 'response' }],
role: 'model',
},
finishReason: 'STOP',
index: 0,
safetyRatings: [],
},
],
text: () => 'response',
} as unknown as GenerateContentResponse;
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
response,
);
await chat.sendMessage({ message: 'hello' }, 'prompt-id-1');
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
{
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
config: {},
},
'prompt-id-1',
);
});
});
describe('sendMessageStream', () => {
it('should succeed if a tool call is followed by an empty part', async () => {
// 1. Mock a stream that contains a tool call, then an invalid (empty) part.
@@ -1163,20 +908,15 @@ describe('GeminiChat', () => {
})(),
);
// This helper function consumes the stream and allows us to test for rejection.
async function consumeStreamAndExpectError() {
const stream = await chat.sendMessageStream(
{ message: 'test' },
'prompt-id-retry-fail',
);
const stream = await chat.sendMessageStream(
{ message: 'test' },
'prompt-id-retry-fail',
);
await expect(async () => {
for await (const _ of stream) {
// Must loop to trigger the internal logic that throws.
}
}
await expect(consumeStreamAndExpectError()).rejects.toThrow(
EmptyStreamError,
);
}).rejects.toThrow(EmptyStreamError);
// Should be called 3 times (initial + 2 retries)
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
@@ -1268,87 +1008,6 @@ describe('GeminiChat', () => {
expect(turn4.parts[0].text).toBe('Second answer');
});
// Verifies that overlapping `sendMessage` calls are serialized: a second call
// issued while the first is still in flight must not reach the content
// generator until the first call has fully completed.
describe('concurrency control', () => {
it('should queue a subsequent sendMessage call until the first one completes', async () => {
// 1. Create promises to manually control when the API calls resolve
let firstCallResolver: (value: GenerateContentResponse) => void;
const firstCallPromise = new Promise<GenerateContentResponse>(
(resolve) => {
firstCallResolver = resolve;
},
);
let secondCallResolver: (value: GenerateContentResponse) => void;
const secondCallPromise = new Promise<GenerateContentResponse>(
(resolve) => {
secondCallResolver = resolve;
},
);
// A standard response body for the mock
const mockResponse = {
candidates: [
{
content: { parts: [{ text: 'response' }], role: 'model' },
},
],
} as unknown as GenerateContentResponse;
// 2. Mock the API to return our controllable promises in order
vi.mocked(mockContentGenerator.generateContent)
.mockReturnValueOnce(firstCallPromise)
.mockReturnValueOnce(secondCallPromise);
// 3. Start the first message call. Do not await it yet.
const firstMessagePromise = chat.sendMessage(
{ message: 'first' },
'prompt-1',
);
// Give the event loop a chance to run the async call up to the `await`
await new Promise(process.nextTick);
// 4. While the first call is "in-flight", start the second message call.
const secondMessagePromise = chat.sendMessage(
{ message: 'second' },
'prompt-2',
);
// 5. CRUCIAL CHECK: At this point, only the first API call should have been made.
// The second call should be waiting on `sendPromise`.
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1);
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
contents: expect.arrayContaining([
expect.objectContaining({ parts: [{ text: 'first' }] }),
]),
}),
'prompt-1',
);
// 6. Unblock the first API call and wait for the first message to fully complete.
firstCallResolver!(mockResponse);
await firstMessagePromise;
// Give the event loop a chance to unblock and run the second call.
await new Promise(process.nextTick);
// 7. CRUCIAL CHECK: Now, the second API call should have been made.
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2);
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
contents: expect.arrayContaining([
expect.objectContaining({ parts: [{ text: 'second' }] }),
]),
}),
'prompt-2',
);
// 8. Clean up by resolving the second call so no promise is left pending.
secondCallResolver!(mockResponse);
await secondMessagePromise;
});
});
it('should retry if the model returns a completely empty stream (no chunks)', async () => {
// 1. Mock the API to return an empty stream first, then a valid one.
vi.mocked(mockContentGenerator.generateContentStream)
@@ -1510,40 +1169,6 @@ describe('GeminiChat', () => {
],
} as unknown as GenerateContentResponse;
// With fallback mode off, the non-streaming `sendMessage` path must send the
// request using the model returned by `config.getModel()`.
it('should use the configured model when not in fallback mode (sendMessage)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro');
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
mockResponse,
);
await chat.sendMessage({ message: 'test' }, 'prompt-id-res1');
// Only the `model` field is pinned; the rest of the request is not asserted here.
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
model: 'gemini-2.5-pro',
}),
'prompt-id-res1',
);
});
// With fallback mode on, the non-streaming `sendMessage` path must ignore the
// configured model and send the request with DEFAULT_GEMINI_FLASH_MODEL.
it('should use the FLASH model when in fallback mode (sendMessage)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro');
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
mockResponse,
);
await chat.sendMessage({ message: 'test' }, 'prompt-id-res2');
// Only the `model` field is pinned; the rest of the request is not asserted here.
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
model: DEFAULT_GEMINI_FLASH_MODEL,
}),
'prompt-id-res2',
);
});
it('should use the FLASH model when in fallback mode (sendMessageStream)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
@@ -1620,26 +1245,41 @@ describe('GeminiChat', () => {
const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode');
isInFallbackModeSpy.mockReturnValue(false);
vi.mocked(mockContentGenerator.generateContent)
vi.mocked(mockContentGenerator.generateContentStream)
.mockRejectedValueOnce(error429) // Attempt 1 fails
.mockResolvedValueOnce({
candidates: [{ content: { parts: [{ text: 'Success on retry' }] } }],
} as unknown as GenerateContentResponse); // Attempt 2 succeeds
.mockResolvedValueOnce(
// Attempt 2 succeeds
(async function* () {
yield {
candidates: [
{
content: { parts: [{ text: 'Success on retry' }] },
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})(),
);
mockHandleFallback.mockImplementation(async () => {
isInFallbackModeSpy.mockReturnValue(true);
return true; // Signal retry
});
const result = await chat.sendMessage(
const stream = await chat.sendMessageStream(
{ message: 'trigger 429' },
'prompt-id-fb1',
);
expect(mockRetryWithBackoff).toHaveBeenCalledTimes(1);
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2);
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
// Consume stream to trigger logic
for await (const _ of stream) {
// no-op
}
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
2,
);
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
expect(mockHandleFallback).toHaveBeenCalledWith(
mockConfig,
FAILED_MODEL,
@@ -1647,23 +1287,34 @@ describe('GeminiChat', () => {
error429,
);
expect(result.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
'Success on retry',
);
const history = chat.getHistory();
const modelTurn = history[1]!;
expect(modelTurn.parts![0]!.text).toBe('Success on retry');
});
it('should stop retrying if handleFallback returns false (e.g., auth intent)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
vi.mocked(mockContentGenerator.generateContent).mockRejectedValue(
vi.mocked(mockContentGenerator.generateContentStream).mockRejectedValue(
error429,
);
mockHandleFallback.mockResolvedValue(false);
const stream = await chat.sendMessageStream(
{ message: 'test stop' },
'prompt-id-fb2',
);
await expect(
chat.sendMessage({ message: 'test stop' }, 'prompt-id-fb2'),
(async () => {
for await (const _ of stream) {
/* consume stream */
}
})(),
).rejects.toThrow(error429);
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1);
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
1,
);
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
});
});

View File

@@ -182,138 +182,6 @@ export class GeminiChat {
setSystemInstruction(sysInstr: string) {
this.generationConfig.systemInstruction = sysInstr;
}
/**
* Sends a message to the model and returns the complete (non-streaming) response.
*
* @remarks
* This method awaits `this.sendPromise` before doing anything else, so it waits
* for the previous message to be processed before sending the next message.
* The user message is passed to the chat recording service unless it is a
* function response. The request is issued through `retryWithBackoff`, and
* after a successful response the user turn and model output are written to
* the chat history via `recordHistory`.
*
* @see {@link GeminiChat#sendMessageStream} for streaming method.
* @param params - parameters for sending messages within a chat session;
* `params.config` is spread over this chat's `generationConfig` for the request.
* @param prompt_id - identifier forwarded unchanged to the content generator's
* `generateContent` call.
* @returns The model's response.
* @throws Error when a quota error has been flagged on the config and the
* request would use the Flash model; also rethrows any error from the API call
* or from recording history (after resetting `sendPromise`).
*
* @example
* ```ts
* const response = await chat.sendMessage(
*   { message: 'Why is the sky blue?' },
*   'prompt-id-1',
* );
* console.log(response.text);
* ```
*/
async sendMessage(
params: SendMessageParameters,
prompt_id: string,
): Promise<GenerateContentResponse> {
// Serialize with any earlier send: `sendPromise` is set below to the
// history-recording task of the previous call.
await this.sendPromise;
const userContent = createUserContent(params.message);
// Record user input - capture complete message with all parts (text, files, images, etc.)
// but skip recording function responses (tool call results) as they should be stored in tool call records
if (!isFunctionResponse(userContent)) {
const userMessage = Array.isArray(params.message)
? params.message
: [params.message];
this.chatRecordingService.recordMessage({
type: 'user',
content: userMessage,
});
}
// Request body = existing history followed by the new user turn.
// NOTE(review): `getHistory(true)` presumably returns the curated history
// (the AFC comment below refers to it as such) — confirm against getHistory.
const requestContents = this.getHistory(true).concat(userContent);
let response: GenerateContentResponse;
try {
// Model chosen by the most recent attempt; read by the 429 handler below.
let currentAttemptModel: string | undefined;
// Wrapped in a closure so the model choice is re-evaluated every time
// retryWithBackoff invokes it (fallback mode may change between attempts).
const apiCall = () => {
const modelToUse = this.config.isInFallbackMode()
? DEFAULT_GEMINI_FLASH_MODEL
: this.config.getModel();
currentAttemptModel = modelToUse;
// Prevent Flash model calls immediately after quota error
if (
this.config.getQuotaErrorOccurred() &&
modelToUse === DEFAULT_GEMINI_FLASH_MODEL
) {
throw new Error(
'Please submit a new query to continue with the Flash model.',
);
}
return this.config.getContentGenerator().generateContent(
{
model: modelToUse,
contents: requestContents,
config: { ...this.generationConfig, ...params.config },
},
prompt_id,
);
};
// Passed to retryWithBackoff as `onPersistent429`; delegates to
// handleFallback with the model of the latest attempt, or returns null if
// no attempt has recorded a model yet.
const onPersistent429Callback = async (
authType?: string,
error?: unknown,
) => {
if (!currentAttemptModel) return null;
return await handleFallback(
this.config,
currentAttemptModel,
authType,
error,
);
};
response = await retryWithBackoff(apiCall, {
shouldRetry: (error: unknown) => {
// Check for known error messages and codes.
if (error instanceof Error && error.message) {
// Schema depth errors are never retried.
if (isSchemaDepthError(error.message)) return false;
if (error.message.includes('429')) return true;
// NOTE(review): /5\d{2}/ matches a '5' followed by any two digits
// anywhere in the message, not only an HTTP 5xx status — confirm
// this breadth is intended.
if (error.message.match(/5\d{2}/)) return true;
}
return false; // Don't retry other errors by default
},
onPersistent429: onPersistent429Callback,
authType: this.config.getContentGeneratorConfig()?.authType,
});
// Record history inside a promise stored on the instance so that the next
// sendMessage call (which awaits `sendPromise` first) waits for it.
this.sendPromise = (async () => {
const outputContent = response.candidates?.[0]?.content;
const modelOutput = outputContent ? [outputContent] : [];
// Because the AFC input contains the entire curated chat history in
// addition to the new user input, we need to truncate the AFC history
// to deduplicate the existing chat history.
const fullAutomaticFunctionCallingHistory =
response.automaticFunctionCallingHistory;
// History has not been updated for this turn yet, so this is the number
// of pre-existing entries to drop from the front of the AFC history.
const index = this.getHistory(true).length;
let automaticFunctionCallingHistory: Content[] = [];
if (fullAutomaticFunctionCallingHistory != null) {
// NOTE(review): `slice()` always returns an array, so the `?? []`
// fallback can never take effect.
automaticFunctionCallingHistory =
fullAutomaticFunctionCallingHistory.slice(index) ?? [];
}
this.recordHistory(
userContent,
modelOutput,
automaticFunctionCallingHistory,
);
})();
await this.sendPromise.catch((error) => {
// Resets sendPromise to avoid subsequent calls failing
this.sendPromise = Promise.resolve();
// Re-throw the error so the caller knows something went wrong.
throw error;
});
return response;
} catch (error) {
// Any failure (API call, retry exhaustion, or history recording) leaves
// `sendPromise` resolved so later calls are not blocked or poisoned.
this.sendPromise = Promise.resolve();
throw error;
}
}
/**
* Sends a message to the model and returns the response in chunks.
@@ -704,35 +572,18 @@ export class GeminiChat {
this.recordHistory(userInput, modelOutput);
}
private recordHistory(
userInput: Content,
modelOutput: Content[],
automaticFunctionCallingHistory?: Content[],
) {
private recordHistory(userInput: Content, modelOutput: Content[]) {
// Part 1: Handle the user's turn.
if (
automaticFunctionCallingHistory &&
automaticFunctionCallingHistory.length > 0
) {
this.history.push(
...extractCuratedHistory(automaticFunctionCallingHistory),
);
} else {
if (
this.history.length === 0 ||
this.history[this.history.length - 1] !== userInput
) {
const lastTurn = this.history[this.history.length - 1];
// The only time we don't push is if it's the *exact same* object,
// which happens in streaming where we add it preemptively.
if (lastTurn !== userInput) {
if (lastTurn?.role === 'user') {
// This is an invalid sequence.
throw new Error('Cannot add a user turn after another user turn.');
}
this.history.push(userInput);
}
const lastTurn = this.history[this.history.length - 1];
// The only time we don't push is if it's the *exact same* object,
// which happens in streaming where we add it preemptively.
if (lastTurn !== userInput) {
if (lastTurn?.role === 'user') {
// This is an invalid sequence.
throw new Error('Cannot add a user turn after another user turn.');
}
this.history.push(userInput);
}
// Part 2: Process the model output into a final, consolidated list of turns.