mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
Fix compression issues (#8225)
This commit is contained in:
committed by
GitHub
parent
3ceefe8732
commit
03e3c59bf0
@@ -151,8 +151,8 @@ describe('findIndexAfterFraction', () => {
|
||||
// 0: 66
|
||||
// 1: 66 + 68 = 134
|
||||
// 2: 134 + 66 = 200
|
||||
// 200 >= 166.5, so index is 2
|
||||
expect(findIndexAfterFraction(history, 0.5)).toBe(2);
|
||||
// 200 >= 166.5, so index is 3
|
||||
expect(findIndexAfterFraction(history, 0.5)).toBe(3);
|
||||
});
|
||||
|
||||
it('should handle a fraction that results in the last index', () => {
|
||||
@@ -160,8 +160,8 @@ describe('findIndexAfterFraction', () => {
|
||||
// ...
|
||||
// 3: 200 + 68 = 268
|
||||
// 4: 268 + 65 = 333
|
||||
// 333 >= 299.7, so index is 4
|
||||
expect(findIndexAfterFraction(history, 0.9)).toBe(4);
|
||||
// 333 >= 299.7, so index is 5
|
||||
expect(findIndexAfterFraction(history, 0.9)).toBe(5);
|
||||
});
|
||||
|
||||
it('should handle an empty history', () => {
|
||||
@@ -169,7 +169,7 @@ describe('findIndexAfterFraction', () => {
|
||||
});
|
||||
|
||||
it('should handle a history with only one item', () => {
|
||||
expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(0);
|
||||
expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(1);
|
||||
});
|
||||
|
||||
it('should handle history with weird parts', () => {
|
||||
@@ -178,7 +178,7 @@ describe('findIndexAfterFraction', () => {
|
||||
{ role: 'model', parts: [{ fileData: { fileUri: 'derp' } }] },
|
||||
{ role: 'user', parts: [{ text: 'Message 2' }] },
|
||||
];
|
||||
expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(1);
|
||||
expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -534,7 +534,6 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
describe('tryCompressChat', () => {
|
||||
const mockSendMessage = vi.fn();
|
||||
const mockGetHistory = vi.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -546,7 +545,6 @@ describe('Gemini Client (client.ts)', () => {
|
||||
getHistory: mockGetHistory,
|
||||
addHistory: vi.fn(),
|
||||
setHistory: vi.fn(),
|
||||
sendMessage: mockSendMessage,
|
||||
} as unknown as GeminiChat;
|
||||
});
|
||||
|
||||
@@ -559,7 +557,6 @@ describe('Gemini Client (client.ts)', () => {
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
getHistory: vi.fn().mockReturnValue(chatHistory),
|
||||
setHistory: vi.fn(),
|
||||
sendMessage: vi.fn().mockResolvedValue({ text: 'Summary' }),
|
||||
};
|
||||
vi.mocked(mockContentGenerator.countTokens)
|
||||
.mockResolvedValueOnce({ totalTokens: 1000 })
|
||||
@@ -697,10 +694,16 @@ describe('Gemini Client (client.ts)', () => {
|
||||
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
|
||||
|
||||
// Mock the summary response from the chat
|
||||
mockSendMessage.mockResolvedValue({
|
||||
role: 'model',
|
||||
parts: [{ text: 'This is a summary.' }],
|
||||
});
|
||||
mockGenerateContentFn.mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'This is a summary.' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
await client.tryCompressChat('prompt-id-3');
|
||||
|
||||
@@ -734,17 +737,23 @@ describe('Gemini Client (client.ts)', () => {
|
||||
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
|
||||
|
||||
// Mock the summary response from the chat
|
||||
mockSendMessage.mockResolvedValue({
|
||||
role: 'model',
|
||||
parts: [{ text: 'This is a summary.' }],
|
||||
});
|
||||
mockGenerateContentFn.mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'This is a summary.' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
const initialChat = client.getChat();
|
||||
const result = await client.tryCompressChat('prompt-id-3');
|
||||
const newChat = client.getChat();
|
||||
|
||||
expect(tokenLimit).toHaveBeenCalled();
|
||||
expect(mockSendMessage).toHaveBeenCalled();
|
||||
expect(mockGenerateContentFn).toHaveBeenCalled();
|
||||
|
||||
// Assert that summarization happened and returned the correct stats
|
||||
expect(result).toEqual({
|
||||
@@ -787,17 +796,23 @@ describe('Gemini Client (client.ts)', () => {
|
||||
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
|
||||
|
||||
// Mock the summary response from the chat
|
||||
mockSendMessage.mockResolvedValue({
|
||||
role: 'model',
|
||||
parts: [{ text: 'This is a summary.' }],
|
||||
});
|
||||
mockGenerateContentFn.mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'This is a summary.' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
const initialChat = client.getChat();
|
||||
const result = await client.tryCompressChat('prompt-id-3');
|
||||
const newChat = client.getChat();
|
||||
|
||||
expect(tokenLimit).toHaveBeenCalled();
|
||||
expect(mockSendMessage).toHaveBeenCalled();
|
||||
expect(mockGenerateContentFn).toHaveBeenCalled();
|
||||
|
||||
// Assert that summarization happened and returned the correct stats
|
||||
expect(result).toEqual({
|
||||
@@ -829,16 +844,22 @@ describe('Gemini Client (client.ts)', () => {
|
||||
.mockResolvedValueOnce({ totalTokens: newTokenCount });
|
||||
|
||||
// Mock the summary response from the chat
|
||||
mockSendMessage.mockResolvedValue({
|
||||
role: 'model',
|
||||
parts: [{ text: 'This is a summary.' }],
|
||||
});
|
||||
mockGenerateContentFn.mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'This is a summary.' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
const initialChat = client.getChat();
|
||||
const result = await client.tryCompressChat('prompt-id-1', true); // force = true
|
||||
const newChat = client.getChat();
|
||||
|
||||
expect(mockSendMessage).toHaveBeenCalled();
|
||||
expect(mockGenerateContentFn).toHaveBeenCalled();
|
||||
|
||||
expect(result).toEqual({
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
@@ -1503,7 +1524,6 @@ ${JSON.stringify(
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
setHistory: vi.fn(),
|
||||
sendMessage: vi.fn().mockResolvedValue({ text: 'summary' }),
|
||||
// Assume history is not empty for delta checks
|
||||
getHistory: vi
|
||||
.fn()
|
||||
@@ -1764,7 +1784,6 @@ ${JSON.stringify(
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]), // Default empty history
|
||||
setHistory: vi.fn(),
|
||||
sendMessage: vi.fn().mockResolvedValue({ text: 'summary' }),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
|
||||
@@ -86,10 +86,10 @@ export function findIndexAfterFraction(
|
||||
|
||||
let charactersSoFar = 0;
|
||||
for (let i = 0; i < contentLengths.length; i++) {
|
||||
charactersSoFar += contentLengths[i];
|
||||
if (charactersSoFar >= targetCharacters) {
|
||||
return i;
|
||||
}
|
||||
charactersSoFar += contentLengths[i];
|
||||
}
|
||||
return contentLengths.length;
|
||||
}
|
||||
@@ -836,19 +836,30 @@ export class GeminiClient {
|
||||
const historyToCompress = curatedHistory.slice(0, compressBeforeIndex);
|
||||
const historyToKeep = curatedHistory.slice(compressBeforeIndex);
|
||||
|
||||
this.getChat().setHistory(historyToCompress);
|
||||
const summaryResponse = await this.config
|
||||
.getContentGenerator()
|
||||
.generateContent(
|
||||
{
|
||||
model,
|
||||
contents: [
|
||||
...historyToCompress,
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
config: {
|
||||
systemInstruction: { text: getCompressionPrompt() },
|
||||
},
|
||||
},
|
||||
prompt_id,
|
||||
);
|
||||
const summary = getResponseText(summaryResponse) ?? '';
|
||||
|
||||
const { text: summary } = await this.getChat().sendMessage(
|
||||
{
|
||||
message: {
|
||||
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
|
||||
},
|
||||
config: {
|
||||
systemInstruction: { text: getCompressionPrompt() },
|
||||
},
|
||||
},
|
||||
prompt_id,
|
||||
);
|
||||
const chat = await this.startChat([
|
||||
{
|
||||
role: 'user',
|
||||
|
||||
@@ -135,261 +135,6 @@ describe('GeminiChat', () => {
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
describe('sendMessage', () => {
|
||||
it('should retain the initial user message when an automatic function call occurs', async () => {
|
||||
// 1. Define the user's initial text message. This is the turn that gets dropped by the buggy logic.
|
||||
const userInitialMessage: Content = {
|
||||
role: 'user',
|
||||
parts: [{ text: 'How is the weather in Boston?' }],
|
||||
};
|
||||
|
||||
// 2. Mock the full API response, including the automaticFunctionCallingHistory.
|
||||
// This history represents the full turn: user asks, model calls tool, tool responds, model answers.
|
||||
const mockAfcResponse = {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ text: 'The weather in Boston is 72 degrees and sunny.' },
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
automaticFunctionCallingHistory: [
|
||||
userInitialMessage, // The user's turn
|
||||
{
|
||||
// The model's first response: a tool call
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionCall: {
|
||||
name: 'get_weather',
|
||||
args: { location: 'Boston' },
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
// The tool's response, which has a 'user' role
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'get_weather',
|
||||
response: { temperature: 72, condition: 'sunny' },
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mockAfcResponse,
|
||||
);
|
||||
|
||||
// 3. Action: Send the initial message.
|
||||
await chat.sendMessage(
|
||||
{ message: 'How is the weather in Boston?' },
|
||||
'prompt-id-afc-bug',
|
||||
);
|
||||
|
||||
// 4. Assert: Check the final state of the history.
|
||||
const history = chat.getHistory();
|
||||
|
||||
// With the bug, history.length will be 3, because the first user message is dropped.
|
||||
// The correct behavior is for the history to contain all 4 turns.
|
||||
expect(history.length).toBe(4);
|
||||
|
||||
// Crucially, assert that the very first turn in the history matches the user's initial message.
|
||||
// This is the assertion that will fail.
|
||||
const firstTurn = history[0]!;
|
||||
expect(firstTurn.role).toBe('user');
|
||||
expect(firstTurn?.parts![0]!.text).toBe('How is the weather in Boston?');
|
||||
|
||||
// Verify the rest of the history is also correct.
|
||||
const secondTurn = history[1]!;
|
||||
expect(secondTurn.role).toBe('model');
|
||||
expect(secondTurn?.parts![0]!.functionCall).toBeDefined();
|
||||
|
||||
const thirdTurn = history[2]!;
|
||||
expect(thirdTurn.role).toBe('user');
|
||||
expect(thirdTurn?.parts![0]!.functionResponse).toBeDefined();
|
||||
|
||||
const fourthTurn = history[3]!;
|
||||
expect(fourthTurn.role).toBe('model');
|
||||
expect(fourthTurn?.parts![0]!.text).toContain('72 degrees and sunny');
|
||||
});
|
||||
|
||||
it('should throw an error when attempting to add a user turn after another user turn', async () => {
|
||||
// 1. Setup: Create a history that already ends with a user turn (a functionResponse).
|
||||
const initialHistory: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'Initial prompt' }] },
|
||||
{
|
||||
role: 'model',
|
||||
parts: [{ functionCall: { name: 'test_tool', args: {} } }],
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ functionResponse: { name: 'test_tool', response: {} } }],
|
||||
},
|
||||
];
|
||||
chat.setHistory(initialHistory);
|
||||
|
||||
// 2. Mock a valid model response so the call doesn't fail for other reasons.
|
||||
const mockResponse = {
|
||||
candidates: [
|
||||
{ content: { role: 'model', parts: [{ text: 'some response' }] } },
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mockResponse,
|
||||
);
|
||||
|
||||
// 3. Action & Assert: Expect that sending another user message immediately
|
||||
// after a user-role turn throws the specific error.
|
||||
await expect(
|
||||
chat.sendMessage(
|
||||
{ message: 'This is an invalid consecutive user message' },
|
||||
'prompt-id-1',
|
||||
),
|
||||
).rejects.toThrow('Cannot add a user turn after another user turn.');
|
||||
});
|
||||
it('should preserve text parts that are in the same response as a thought', async () => {
|
||||
// 1. Mock the API to return a single response containing both a thought and visible text.
|
||||
const mixedContentResponse = {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [
|
||||
{ thought: 'This is a thought.' },
|
||||
{ text: 'This is the visible text that should not be lost.' },
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mixedContentResponse,
|
||||
);
|
||||
|
||||
// 2. Action: Send a standard, non-streaming message.
|
||||
await chat.sendMessage(
|
||||
{ message: 'test message' },
|
||||
'prompt-id-mixed-response',
|
||||
);
|
||||
|
||||
// 3. Assert: Check the final state of the history.
|
||||
const history = chat.getHistory();
|
||||
|
||||
// The history should contain two turns: the user's message and the model's response.
|
||||
expect(history.length).toBe(2);
|
||||
|
||||
const modelTurn = history[1]!;
|
||||
expect(modelTurn.role).toBe('model');
|
||||
|
||||
// CRUCIAL ASSERTION:
|
||||
// Buggy code would discard the entire response because a "thought" was present,
|
||||
// resulting in an empty placeholder turn with 0 parts.
|
||||
// The corrected code will pass, preserving the single visible text part.
|
||||
expect(modelTurn?.parts?.length).toBe(1);
|
||||
expect(modelTurn?.parts![0]!.text).toBe(
|
||||
'This is the visible text that should not be lost.',
|
||||
);
|
||||
});
|
||||
it('should add a placeholder model turn when a tool call is followed by an empty model response', async () => {
|
||||
// 1. Setup: A history where the model has just made a function call.
|
||||
const initialHistory: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'Find a good Italian restaurant for me.' }],
|
||||
},
|
||||
{
|
||||
role: 'model',
|
||||
parts: [
|
||||
{
|
||||
functionCall: {
|
||||
name: 'find_restaurant',
|
||||
args: { cuisine: 'Italian' },
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
chat.setHistory(initialHistory);
|
||||
|
||||
// 2. Mock the API to return an empty/thought-only response.
|
||||
const emptyModelResponse = {
|
||||
candidates: [
|
||||
{ content: { role: 'model', parts: [{ thought: true }] } },
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
emptyModelResponse,
|
||||
);
|
||||
|
||||
// 3. Action: Send the function response back to the model.
|
||||
await chat.sendMessage(
|
||||
{
|
||||
message: {
|
||||
functionResponse: {
|
||||
name: 'find_restaurant',
|
||||
response: { name: 'Vesuvio' },
|
||||
},
|
||||
},
|
||||
},
|
||||
'prompt-id-1',
|
||||
);
|
||||
|
||||
// 4. Assert: The history should now have four valid, alternating turns.
|
||||
const history = chat.getHistory();
|
||||
expect(history.length).toBe(4);
|
||||
|
||||
// The final turn must be the empty model placeholder.
|
||||
const lastTurn = history[3]!;
|
||||
expect(lastTurn.role).toBe('model');
|
||||
expect(lastTurn?.parts?.length).toBe(0);
|
||||
|
||||
// The second-to-last turn must be the function response we sent.
|
||||
const secondToLastTurn = history[2]!;
|
||||
expect(secondToLastTurn.role).toBe('user');
|
||||
expect(secondToLastTurn?.parts![0]!.functionResponse).toBeDefined();
|
||||
});
|
||||
it('should call generateContent with the correct parameters', async () => {
|
||||
const response = {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'response' }],
|
||||
role: 'model',
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
index: 0,
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
text: () => 'response',
|
||||
} as unknown as GenerateContentResponse;
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
response,
|
||||
);
|
||||
|
||||
await chat.sendMessage({ message: 'hello' }, 'prompt-id-1');
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
{
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||
config: {},
|
||||
},
|
||||
'prompt-id-1',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessageStream', () => {
|
||||
it('should succeed if a tool call is followed by an empty part', async () => {
|
||||
// 1. Mock a stream that contains a tool call, then an invalid (empty) part.
|
||||
@@ -1163,20 +908,15 @@ describe('GeminiChat', () => {
|
||||
})(),
|
||||
);
|
||||
|
||||
// This helper function consumes the stream and allows us to test for rejection.
|
||||
async function consumeStreamAndExpectError() {
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ message: 'test' },
|
||||
'prompt-id-retry-fail',
|
||||
);
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ message: 'test' },
|
||||
'prompt-id-retry-fail',
|
||||
);
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// Must loop to trigger the internal logic that throws.
|
||||
}
|
||||
}
|
||||
|
||||
await expect(consumeStreamAndExpectError()).rejects.toThrow(
|
||||
EmptyStreamError,
|
||||
);
|
||||
}).rejects.toThrow(EmptyStreamError);
|
||||
|
||||
// Should be called 3 times (initial + 2 retries)
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
|
||||
@@ -1268,87 +1008,6 @@ describe('GeminiChat', () => {
|
||||
expect(turn4.parts[0].text).toBe('Second answer');
|
||||
});
|
||||
|
||||
describe('concurrency control', () => {
|
||||
it('should queue a subsequent sendMessage call until the first one completes', async () => {
|
||||
// 1. Create promises to manually control when the API calls resolve
|
||||
let firstCallResolver: (value: GenerateContentResponse) => void;
|
||||
const firstCallPromise = new Promise<GenerateContentResponse>(
|
||||
(resolve) => {
|
||||
firstCallResolver = resolve;
|
||||
},
|
||||
);
|
||||
|
||||
let secondCallResolver: (value: GenerateContentResponse) => void;
|
||||
const secondCallPromise = new Promise<GenerateContentResponse>(
|
||||
(resolve) => {
|
||||
secondCallResolver = resolve;
|
||||
},
|
||||
);
|
||||
|
||||
// A standard response body for the mock
|
||||
const mockResponse = {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'response' }], role: 'model' },
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
|
||||
// 2. Mock the API to return our controllable promises in order
|
||||
vi.mocked(mockContentGenerator.generateContent)
|
||||
.mockReturnValueOnce(firstCallPromise)
|
||||
.mockReturnValueOnce(secondCallPromise);
|
||||
|
||||
// 3. Start the first message call. Do not await it yet.
|
||||
const firstMessagePromise = chat.sendMessage(
|
||||
{ message: 'first' },
|
||||
'prompt-1',
|
||||
);
|
||||
|
||||
// Give the event loop a chance to run the async call up to the `await`
|
||||
await new Promise(process.nextTick);
|
||||
|
||||
// 4. While the first call is "in-flight", start the second message call.
|
||||
const secondMessagePromise = chat.sendMessage(
|
||||
{ message: 'second' },
|
||||
'prompt-2',
|
||||
);
|
||||
|
||||
// 5. CRUCIAL CHECK: At this point, only the first API call should have been made.
|
||||
// The second call should be waiting on `sendPromise`.
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1);
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
contents: expect.arrayContaining([
|
||||
expect.objectContaining({ parts: [{ text: 'first' }] }),
|
||||
]),
|
||||
}),
|
||||
'prompt-1',
|
||||
);
|
||||
|
||||
// 6. Unblock the first API call and wait for the first message to fully complete.
|
||||
firstCallResolver!(mockResponse);
|
||||
await firstMessagePromise;
|
||||
|
||||
// Give the event loop a chance to unblock and run the second call.
|
||||
await new Promise(process.nextTick);
|
||||
|
||||
// 7. CRUCIAL CHECK: Now, the second API call should have been made.
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2);
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
contents: expect.arrayContaining([
|
||||
expect.objectContaining({ parts: [{ text: 'second' }] }),
|
||||
]),
|
||||
}),
|
||||
'prompt-2',
|
||||
);
|
||||
|
||||
// 8. Clean up by resolving the second call.
|
||||
secondCallResolver!(mockResponse);
|
||||
await secondMessagePromise;
|
||||
});
|
||||
});
|
||||
it('should retry if the model returns a completely empty stream (no chunks)', async () => {
|
||||
// 1. Mock the API to return an empty stream first, then a valid one.
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
@@ -1510,40 +1169,6 @@ describe('GeminiChat', () => {
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
|
||||
it('should use the configured model when not in fallback mode (sendMessage)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro');
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mockResponse,
|
||||
);
|
||||
|
||||
await chat.sendMessage({ message: 'test' }, 'prompt-id-res1');
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gemini-2.5-pro',
|
||||
}),
|
||||
'prompt-id-res1',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the FLASH model when in fallback mode (sendMessage)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro');
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mockResponse,
|
||||
);
|
||||
|
||||
await chat.sendMessage({ message: 'test' }, 'prompt-id-res2');
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
}),
|
||||
'prompt-id-res2',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the FLASH model when in fallback mode (sendMessageStream)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
|
||||
@@ -1620,26 +1245,41 @@ describe('GeminiChat', () => {
|
||||
const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode');
|
||||
isInFallbackModeSpy.mockReturnValue(false);
|
||||
|
||||
vi.mocked(mockContentGenerator.generateContent)
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockRejectedValueOnce(error429) // Attempt 1 fails
|
||||
.mockResolvedValueOnce({
|
||||
candidates: [{ content: { parts: [{ text: 'Success on retry' }] } }],
|
||||
} as unknown as GenerateContentResponse); // Attempt 2 succeeds
|
||||
.mockResolvedValueOnce(
|
||||
// Attempt 2 succeeds
|
||||
(async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Success on retry' }] },
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})(),
|
||||
);
|
||||
|
||||
mockHandleFallback.mockImplementation(async () => {
|
||||
isInFallbackModeSpy.mockReturnValue(true);
|
||||
return true; // Signal retry
|
||||
});
|
||||
|
||||
const result = await chat.sendMessage(
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ message: 'trigger 429' },
|
||||
'prompt-id-fb1',
|
||||
);
|
||||
|
||||
expect(mockRetryWithBackoff).toHaveBeenCalledTimes(1);
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2);
|
||||
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
|
||||
// Consume stream to trigger logic
|
||||
for await (const _ of stream) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
|
||||
2,
|
||||
);
|
||||
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
|
||||
expect(mockHandleFallback).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
FAILED_MODEL,
|
||||
@@ -1647,23 +1287,34 @@ describe('GeminiChat', () => {
|
||||
error429,
|
||||
);
|
||||
|
||||
expect(result.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
|
||||
'Success on retry',
|
||||
);
|
||||
const history = chat.getHistory();
|
||||
const modelTurn = history[1]!;
|
||||
expect(modelTurn.parts![0]!.text).toBe('Success on retry');
|
||||
});
|
||||
|
||||
it('should stop retrying if handleFallback returns false (e.g., auth intent)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
|
||||
vi.mocked(mockContentGenerator.generateContent).mockRejectedValue(
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockRejectedValue(
|
||||
error429,
|
||||
);
|
||||
mockHandleFallback.mockResolvedValue(false);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ message: 'test stop' },
|
||||
'prompt-id-fb2',
|
||||
);
|
||||
|
||||
await expect(
|
||||
chat.sendMessage({ message: 'test stop' }, 'prompt-id-fb2'),
|
||||
(async () => {
|
||||
for await (const _ of stream) {
|
||||
/* consume stream */
|
||||
}
|
||||
})(),
|
||||
).rejects.toThrow(error429);
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1);
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
|
||||
1,
|
||||
);
|
||||
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -182,138 +182,6 @@ export class GeminiChat {
|
||||
setSystemInstruction(sysInstr: string) {
|
||||
this.generationConfig.systemInstruction = sysInstr;
|
||||
}
|
||||
/**
|
||||
* Sends a message to the model and returns the response.
|
||||
*
|
||||
* @remarks
|
||||
* This method will wait for the previous message to be processed before
|
||||
* sending the next message.
|
||||
*
|
||||
* @see {@link Chat#sendMessageStream} for streaming method.
|
||||
* @param params - parameters for sending messages within a chat session.
|
||||
* @returns The model's response.
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* const chat = ai.chats.create({model: 'gemini-2.0-flash'});
|
||||
* const response = await chat.sendMessage({
|
||||
* message: 'Why is the sky blue?'
|
||||
* });
|
||||
* console.log(response.text);
|
||||
* ```
|
||||
*/
|
||||
async sendMessage(
|
||||
params: SendMessageParameters,
|
||||
prompt_id: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
await this.sendPromise;
|
||||
const userContent = createUserContent(params.message);
|
||||
|
||||
// Record user input - capture complete message with all parts (text, files, images, etc.)
|
||||
// but skip recording function responses (tool call results) as they should be stored in tool call records
|
||||
if (!isFunctionResponse(userContent)) {
|
||||
const userMessage = Array.isArray(params.message)
|
||||
? params.message
|
||||
: [params.message];
|
||||
this.chatRecordingService.recordMessage({
|
||||
type: 'user',
|
||||
content: userMessage,
|
||||
});
|
||||
}
|
||||
const requestContents = this.getHistory(true).concat(userContent);
|
||||
|
||||
let response: GenerateContentResponse;
|
||||
|
||||
try {
|
||||
let currentAttemptModel: string | undefined;
|
||||
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: this.config.getModel();
|
||||
currentAttemptModel = modelToUse;
|
||||
|
||||
// Prevent Flash model calls immediately after quota error
|
||||
if (
|
||||
this.config.getQuotaErrorOccurred() &&
|
||||
modelToUse === DEFAULT_GEMINI_FLASH_MODEL
|
||||
) {
|
||||
throw new Error(
|
||||
'Please submit a new query to continue with the Flash model.',
|
||||
);
|
||||
}
|
||||
|
||||
return this.config.getContentGenerator().generateContent(
|
||||
{
|
||||
model: modelToUse,
|
||||
contents: requestContents,
|
||||
config: { ...this.generationConfig, ...params.config },
|
||||
},
|
||||
prompt_id,
|
||||
);
|
||||
};
|
||||
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) => {
|
||||
if (!currentAttemptModel) return null;
|
||||
return await handleFallback(
|
||||
this.config,
|
||||
currentAttemptModel,
|
||||
authType,
|
||||
error,
|
||||
);
|
||||
};
|
||||
|
||||
response = await retryWithBackoff(apiCall, {
|
||||
shouldRetry: (error: unknown) => {
|
||||
// Check for known error messages and codes.
|
||||
if (error instanceof Error && error.message) {
|
||||
if (isSchemaDepthError(error.message)) return false;
|
||||
if (error.message.includes('429')) return true;
|
||||
if (error.message.match(/5\d{2}/)) return true;
|
||||
}
|
||||
return false; // Don't retry other errors by default
|
||||
},
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
});
|
||||
|
||||
this.sendPromise = (async () => {
|
||||
const outputContent = response.candidates?.[0]?.content;
|
||||
const modelOutput = outputContent ? [outputContent] : [];
|
||||
|
||||
// Because the AFC input contains the entire curated chat history in
|
||||
// addition to the new user input, we need to truncate the AFC history
|
||||
// to deduplicate the existing chat history.
|
||||
const fullAutomaticFunctionCallingHistory =
|
||||
response.automaticFunctionCallingHistory;
|
||||
const index = this.getHistory(true).length;
|
||||
let automaticFunctionCallingHistory: Content[] = [];
|
||||
if (fullAutomaticFunctionCallingHistory != null) {
|
||||
automaticFunctionCallingHistory =
|
||||
fullAutomaticFunctionCallingHistory.slice(index) ?? [];
|
||||
}
|
||||
|
||||
this.recordHistory(
|
||||
userContent,
|
||||
modelOutput,
|
||||
automaticFunctionCallingHistory,
|
||||
);
|
||||
})();
|
||||
await this.sendPromise.catch((error) => {
|
||||
// Resets sendPromise to avoid subsequent calls failing
|
||||
this.sendPromise = Promise.resolve();
|
||||
// Re-throw the error so the caller knows something went wrong.
|
||||
throw error;
|
||||
});
|
||||
return response;
|
||||
} catch (error) {
|
||||
this.sendPromise = Promise.resolve();
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends a message to the model and returns the response in chunks.
|
||||
@@ -704,35 +572,18 @@ export class GeminiChat {
|
||||
this.recordHistory(userInput, modelOutput);
|
||||
}
|
||||
|
||||
private recordHistory(
|
||||
userInput: Content,
|
||||
modelOutput: Content[],
|
||||
automaticFunctionCallingHistory?: Content[],
|
||||
) {
|
||||
private recordHistory(userInput: Content, modelOutput: Content[]) {
|
||||
// Part 1: Handle the user's turn.
|
||||
if (
|
||||
automaticFunctionCallingHistory &&
|
||||
automaticFunctionCallingHistory.length > 0
|
||||
) {
|
||||
this.history.push(
|
||||
...extractCuratedHistory(automaticFunctionCallingHistory),
|
||||
);
|
||||
} else {
|
||||
if (
|
||||
this.history.length === 0 ||
|
||||
this.history[this.history.length - 1] !== userInput
|
||||
) {
|
||||
const lastTurn = this.history[this.history.length - 1];
|
||||
// The only time we don't push is if it's the *exact same* object,
|
||||
// which happens in streaming where we add it preemptively.
|
||||
if (lastTurn !== userInput) {
|
||||
if (lastTurn?.role === 'user') {
|
||||
// This is an invalid sequence.
|
||||
throw new Error('Cannot add a user turn after another user turn.');
|
||||
}
|
||||
this.history.push(userInput);
|
||||
}
|
||||
|
||||
const lastTurn = this.history[this.history.length - 1];
|
||||
// The only time we don't push is if it's the *exact same* object,
|
||||
// which happens in streaming where we add it preemptively.
|
||||
if (lastTurn !== userInput) {
|
||||
if (lastTurn?.role === 'user') {
|
||||
// This is an invalid sequence.
|
||||
throw new Error('Cannot add a user turn after another user turn.');
|
||||
}
|
||||
this.history.push(userInput);
|
||||
}
|
||||
|
||||
// Part 2: Process the model output into a final, consolidated list of turns.
|
||||
|
||||
Reference in New Issue
Block a user