mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
fix(core): Include the latest user request in countTokens for compression (#8375)
This commit is contained in:
@@ -15,6 +15,7 @@ import {
|
||||
} from 'vitest';
|
||||
|
||||
import type { Content, GenerateContentResponse, Part } from '@google/genai';
|
||||
import { createUserContent } from '@google/genai';
|
||||
import {
|
||||
findIndexAfterFraction,
|
||||
isThinkingDefault,
|
||||
@@ -578,8 +579,12 @@ describe('Gemini Client (client.ts)', () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 1000,
|
||||
});
|
||||
await client.tryCompressChat('prompt-id-4'); // Fails
|
||||
const result = await client.tryCompressChat('prompt-id-4', true);
|
||||
await client.tryCompressChat('prompt-id-4', false, [
|
||||
{ text: 'request' },
|
||||
]); // Fails
|
||||
const result = await client.tryCompressChat('prompt-id-4', true, [
|
||||
{ text: 'request' },
|
||||
]);
|
||||
|
||||
expect(result).toEqual({
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
@@ -593,7 +598,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 1000,
|
||||
});
|
||||
const result = await client.tryCompressChat('prompt-id-4', true);
|
||||
const result = await client.tryCompressChat('prompt-id-4', false, [
|
||||
{ text: 'request' },
|
||||
]);
|
||||
|
||||
expect(result).toEqual({
|
||||
compressionStatus:
|
||||
@@ -605,7 +612,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
|
||||
it('does not manipulate the source chat', async () => {
|
||||
const { client, mockChat } = setup();
|
||||
await client.tryCompressChat('prompt-id-4', true);
|
||||
await client.tryCompressChat('prompt-id-4', false, [
|
||||
{ text: 'request' },
|
||||
]);
|
||||
|
||||
expect(client['chat']).toBe(mockChat); // a new chat session was not created
|
||||
});
|
||||
@@ -625,8 +634,11 @@ describe('Gemini Client (client.ts)', () => {
|
||||
const { client } = setup({
|
||||
chatHistory: originalHistory,
|
||||
});
|
||||
const { compressionStatus } =
|
||||
await client.tryCompressChat('prompt-id-4');
|
||||
const { compressionStatus } = await client.tryCompressChat(
|
||||
'prompt-id-4',
|
||||
false,
|
||||
[{ text: 'what is your wisdom?' }],
|
||||
);
|
||||
|
||||
expect(compressionStatus).toBe(
|
||||
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
||||
@@ -638,9 +650,13 @@ describe('Gemini Client (client.ts)', () => {
|
||||
|
||||
it('will not attempt to compress context after a failure', async () => {
|
||||
const { client } = setup();
|
||||
await client.tryCompressChat('prompt-id-4');
|
||||
await client.tryCompressChat('prompt-id-4', false, [
|
||||
{ text: 'request' },
|
||||
]);
|
||||
|
||||
const result = await client.tryCompressChat('prompt-id-5');
|
||||
const result = await client.tryCompressChat('prompt-id-5', false, [
|
||||
{ text: 'request' },
|
||||
]);
|
||||
|
||||
// it counts tokens for {original, compressed} and then never again
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2);
|
||||
@@ -663,7 +679,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
const initialChat = client.getChat();
|
||||
const result = await client.tryCompressChat('prompt-id-2');
|
||||
const result = await client.tryCompressChat('prompt-id-2', false, [
|
||||
{ text: '...history...' },
|
||||
]);
|
||||
const newChat = client.getChat();
|
||||
|
||||
expect(tokenLimit).toHaveBeenCalled();
|
||||
@@ -708,7 +726,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
await client.tryCompressChat('prompt-id-3');
|
||||
await client.tryCompressChat('prompt-id-3', false, [
|
||||
{ text: '...history...' },
|
||||
]);
|
||||
|
||||
expect(
|
||||
ClearcutLogger.prototype.logChatCompressionEvent,
|
||||
@@ -752,7 +772,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
const initialChat = client.getChat();
|
||||
const result = await client.tryCompressChat('prompt-id-3');
|
||||
const result = await client.tryCompressChat('prompt-id-3', false, [
|
||||
{ text: '...history...' },
|
||||
]);
|
||||
const newChat = client.getChat();
|
||||
|
||||
expect(tokenLimit).toHaveBeenCalled();
|
||||
@@ -811,7 +833,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
const initialChat = client.getChat();
|
||||
const result = await client.tryCompressChat('prompt-id-3');
|
||||
const result = await client.tryCompressChat('prompt-id-3', false, [
|
||||
{ text: '...history...' },
|
||||
]);
|
||||
const newChat = client.getChat();
|
||||
|
||||
expect(tokenLimit).toHaveBeenCalled();
|
||||
@@ -831,7 +855,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
// 3. compressed summary message
|
||||
// 4. standard canned user summary message
|
||||
// 5. The last user message (not the last 3 because that would start with a function response)
|
||||
expect(newChat.getHistory().length).toEqual(5);
|
||||
expect(newChat.getHistory().length).toEqual(6);
|
||||
});
|
||||
|
||||
it('should always trigger summarization when force is true, regardless of token count', async () => {
|
||||
@@ -859,7 +883,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
const initialChat = client.getChat();
|
||||
const result = await client.tryCompressChat('prompt-id-1', true); // force = true
|
||||
const result = await client.tryCompressChat('prompt-id-1', false, [
|
||||
{ text: '...history...' },
|
||||
]); // force = true
|
||||
const newChat = client.getChat();
|
||||
|
||||
expect(mockGenerateContentFn).toHaveBeenCalled();
|
||||
@@ -896,7 +922,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
];
|
||||
|
||||
const mockChat = {
|
||||
getHistory: vi.fn().mockReturnValue(mockChatHistory),
|
||||
getHistory: vi.fn().mockImplementation(() => [...mockChatHistory]),
|
||||
setHistory: vi.fn(),
|
||||
sendMessage: mockSendMessage,
|
||||
} as unknown as GeminiChat;
|
||||
@@ -904,12 +930,17 @@ describe('Gemini Client (client.ts)', () => {
|
||||
client['chat'] = mockChat;
|
||||
client['startChat'] = vi.fn().mockResolvedValue(mockChat);
|
||||
|
||||
const result = await client.tryCompressChat('prompt-id-4', true);
|
||||
const request = [{ text: 'Long conversation' }];
|
||||
const result = await client.tryCompressChat(
|
||||
'prompt-id-4',
|
||||
false,
|
||||
request,
|
||||
);
|
||||
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenNthCalledWith(1, {
|
||||
model: firstCurrentModel,
|
||||
contents: mockChatHistory,
|
||||
contents: [...mockChatHistory, createUserContent(request)],
|
||||
});
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenNthCalledWith(2, {
|
||||
model: secondCurrentModel,
|
||||
@@ -1031,6 +1062,12 @@ describe('Gemini Client (client.ts)', () => {
|
||||
|
||||
vi.mocked(mockConfig.getIdeMode).mockReturnValue(true);
|
||||
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
|
||||
originalTokenCount: 0,
|
||||
newTokenCount: 0,
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
});
|
||||
|
||||
mockTurnRunFn.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
@@ -1148,6 +1185,12 @@ ${JSON.stringify(
|
||||
|
||||
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
|
||||
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
|
||||
originalTokenCount: 0,
|
||||
newTokenCount: 0,
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
});
|
||||
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})();
|
||||
@@ -1218,6 +1261,12 @@ ${JSON.stringify(
|
||||
|
||||
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
|
||||
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
|
||||
originalTokenCount: 0,
|
||||
newTokenCount: 0,
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
});
|
||||
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})();
|
||||
|
||||
@@ -12,6 +12,7 @@ import type {
|
||||
Tool,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import { createUserContent } from '@google/genai';
|
||||
import {
|
||||
getDirectoryContextString,
|
||||
getEnvironmentContext,
|
||||
@@ -453,7 +454,7 @@ export class GeminiClient {
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
}
|
||||
|
||||
const compressed = await this.tryCompressChat(prompt_id);
|
||||
const compressed = await this.tryCompressChat(prompt_id, false, request);
|
||||
|
||||
if (compressed.compressionStatus === CompressionStatus.COMPRESSED) {
|
||||
yield { type: GeminiEventType.ChatCompressed, value: compressed };
|
||||
@@ -790,6 +791,7 @@ export class GeminiClient {
|
||||
async tryCompressChat(
|
||||
prompt_id: string,
|
||||
force: boolean = false,
|
||||
request?: PartListUnion,
|
||||
): Promise<ChatCompressionInfo> {
|
||||
// If the model is 'auto', we will use a placeholder model to check.
|
||||
// Compression occurs before we choose a model, so calling `count_tokens`
|
||||
@@ -805,6 +807,10 @@ export class GeminiClient {
|
||||
|
||||
const curatedHistory = this.getChat().getHistory(true);
|
||||
|
||||
if (request) {
|
||||
curatedHistory.push(createUserContent(request));
|
||||
}
|
||||
|
||||
// Regardless of `force`, don't do anything if the history is empty.
|
||||
if (
|
||||
curatedHistory.length === 0 ||
|
||||
|
||||
Reference in New Issue
Block a user