mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-30 06:54:15 -07:00
feat(core): Use lastPromptTokenCount to determine if we need to compress (#10000)
This commit is contained in:
@@ -115,6 +115,7 @@ vi.mock('../ide/ideContext.js');
|
|||||||
vi.mock('../telemetry/uiTelemetry.js', () => ({
|
vi.mock('../telemetry/uiTelemetry.js', () => ({
|
||||||
uiTelemetryService: {
|
uiTelemetryService: {
|
||||||
setLastPromptTokenCount: vi.fn(),
|
setLastPromptTokenCount: vi.fn(),
|
||||||
|
getLastPromptTokenCount: vi.fn(),
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
@@ -261,7 +262,6 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
mockContentGenerator = {
|
mockContentGenerator = {
|
||||||
generateContent: mockGenerateContentFn,
|
generateContent: mockGenerateContentFn,
|
||||||
generateContentStream: vi.fn(),
|
generateContentStream: vi.fn(),
|
||||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
|
|
||||||
batchEmbedContents: vi.fn(),
|
batchEmbedContents: vi.fn(),
|
||||||
} as unknown as ContentGenerator;
|
} as unknown as ContentGenerator;
|
||||||
|
|
||||||
@@ -402,102 +402,161 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
{ role: 'user', parts: [{ text: 'Long conversation' }] },
|
{ role: 'user', parts: [{ text: 'Long conversation' }] },
|
||||||
{ role: 'model', parts: [{ text: 'Long response' }] },
|
{ role: 'model', parts: [{ text: 'Long response' }] },
|
||||||
] as Content[],
|
] as Content[],
|
||||||
|
originalTokenCount = 1000,
|
||||||
|
summaryText = 'This is a summary.',
|
||||||
} = {}) {
|
} = {}) {
|
||||||
const mockChat: Partial<GeminiChat> = {
|
const mockOriginalChat: Partial<GeminiChat> = {
|
||||||
getHistory: vi.fn().mockReturnValue(chatHistory),
|
getHistory: vi.fn((_curated?: boolean) => chatHistory),
|
||||||
setHistory: vi.fn(),
|
setHistory: vi.fn(),
|
||||||
};
|
};
|
||||||
vi.mocked(mockContentGenerator.countTokens)
|
client['chat'] = mockOriginalChat as GeminiChat;
|
||||||
.mockResolvedValueOnce({ totalTokens: 1000 })
|
|
||||||
.mockResolvedValueOnce({ totalTokens: 5000 });
|
|
||||||
|
|
||||||
client['chat'] = mockChat as GeminiChat;
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
|
||||||
client['startChat'] = vi.fn().mockResolvedValue({ ...mockChat });
|
originalTokenCount,
|
||||||
|
);
|
||||||
|
|
||||||
return { client, mockChat };
|
mockGenerateContentFn.mockResolvedValue({
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: summaryText }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
} as unknown as GenerateContentResponse);
|
||||||
|
|
||||||
|
// Calculate what the new history will be
|
||||||
|
const splitPoint = findCompressSplitPoint(chatHistory, 0.7); // 1 - 0.3
|
||||||
|
const historyToKeep = chatHistory.slice(splitPoint);
|
||||||
|
|
||||||
|
// This is the history that the new chat will have.
|
||||||
|
// It includes the default startChat history + the extra history from tryCompressChat
|
||||||
|
const newCompressedHistory: Content[] = [
|
||||||
|
// Mocked envParts + canned response from startChat
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: 'Mocked env context' }],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Got it. Thanks for the context!' }],
|
||||||
|
},
|
||||||
|
// extraHistory from tryCompressChat
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: summaryText }],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Got it. Thanks for the additional context!' }],
|
||||||
|
},
|
||||||
|
...historyToKeep,
|
||||||
|
];
|
||||||
|
|
||||||
|
const mockNewChat: Partial<GeminiChat> = {
|
||||||
|
getHistory: vi.fn().mockReturnValue(newCompressedHistory),
|
||||||
|
setHistory: vi.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
client['startChat'] = vi
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue(mockNewChat as GeminiChat);
|
||||||
|
|
||||||
|
const totalChars = newCompressedHistory.reduce(
|
||||||
|
(total, content) => total + JSON.stringify(content).length,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
const estimatedNewTokenCount = Math.floor(totalChars / 4);
|
||||||
|
|
||||||
|
return {
|
||||||
|
client,
|
||||||
|
mockOriginalChat,
|
||||||
|
mockNewChat,
|
||||||
|
estimatedNewTokenCount,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
describe('when compression inflates the token count', () => {
|
describe('when compression inflates the token count', () => {
|
||||||
it('allows compression to be forced/manual after a failure', async () => {
|
it('allows compression to be forced/manual after a failure', async () => {
|
||||||
const { client } = setup();
|
// Call 1 (Fails): Setup with a long summary to inflate tokens
|
||||||
|
const longSummary = 'long summary '.repeat(100);
|
||||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
const { client, estimatedNewTokenCount: inflatedTokenCount } = setup({
|
||||||
totalTokens: 1000,
|
originalTokenCount: 100,
|
||||||
|
summaryText: longSummary,
|
||||||
});
|
});
|
||||||
|
expect(inflatedTokenCount).toBeGreaterThan(100); // Ensure setup is correct
|
||||||
|
|
||||||
await client.tryCompressChat('prompt-id-4', false); // Fails
|
await client.tryCompressChat('prompt-id-4', false); // Fails
|
||||||
const result = await client.tryCompressChat('prompt-id-4', true);
|
|
||||||
|
// Call 2 (Forced): Re-setup with a short summary
|
||||||
|
const shortSummary = 'short';
|
||||||
|
const { estimatedNewTokenCount: compressedTokenCount } = setup({
|
||||||
|
originalTokenCount: 100,
|
||||||
|
summaryText: shortSummary,
|
||||||
|
});
|
||||||
|
expect(compressedTokenCount).toBeLessThanOrEqual(100); // Ensure setup is correct
|
||||||
|
|
||||||
|
const result = await client.tryCompressChat('prompt-id-4', true); // Forced
|
||||||
|
|
||||||
expect(result).toEqual({
|
expect(result).toEqual({
|
||||||
compressionStatus: CompressionStatus.COMPRESSED,
|
compressionStatus: CompressionStatus.COMPRESSED,
|
||||||
newTokenCount: 1000,
|
newTokenCount: compressedTokenCount,
|
||||||
originalTokenCount: 1000,
|
originalTokenCount: 100,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('yields the result even if the compression inflated the tokens', async () => {
|
it('yields the result even if the compression inflated the tokens', async () => {
|
||||||
const { client } = setup();
|
const longSummary = 'long summary '.repeat(100);
|
||||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
const { client, estimatedNewTokenCount } = setup({
|
||||||
totalTokens: 1000,
|
originalTokenCount: 100,
|
||||||
|
summaryText: longSummary,
|
||||||
});
|
});
|
||||||
|
expect(estimatedNewTokenCount).toBeGreaterThan(100); // Ensure setup is correct
|
||||||
|
|
||||||
const result = await client.tryCompressChat('prompt-id-4', false);
|
const result = await client.tryCompressChat('prompt-id-4', false);
|
||||||
|
|
||||||
expect(result).toEqual({
|
expect(result).toEqual({
|
||||||
compressionStatus:
|
compressionStatus:
|
||||||
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
||||||
newTokenCount: 5000,
|
newTokenCount: estimatedNewTokenCount,
|
||||||
originalTokenCount: 1000,
|
originalTokenCount: 100,
|
||||||
});
|
});
|
||||||
expect(uiTelemetryService.setLastPromptTokenCount).toHaveBeenCalledWith(
|
// IMPORTANT: The change in client.ts means setLastPromptTokenCount is NOT called on failure
|
||||||
5000,
|
|
||||||
);
|
|
||||||
expect(
|
expect(
|
||||||
uiTelemetryService.setLastPromptTokenCount,
|
uiTelemetryService.setLastPromptTokenCount,
|
||||||
).toHaveBeenCalledTimes(1);
|
).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('does not manipulate the source chat', async () => {
|
it('does not manipulate the source chat', async () => {
|
||||||
const { client, mockChat } = setup();
|
const longSummary = 'long summary '.repeat(100);
|
||||||
|
const { client, mockOriginalChat, estimatedNewTokenCount } = setup({
|
||||||
|
originalTokenCount: 100,
|
||||||
|
summaryText: longSummary,
|
||||||
|
});
|
||||||
|
expect(estimatedNewTokenCount).toBeGreaterThan(100); // Ensure setup is correct
|
||||||
|
|
||||||
await client.tryCompressChat('prompt-id-4', false);
|
await client.tryCompressChat('prompt-id-4', false);
|
||||||
|
|
||||||
expect(client['chat']).toBe(mockChat); // a new chat session was not created
|
// On failure, the chat should NOT be replaced
|
||||||
});
|
expect(client['chat']).toBe(mockOriginalChat);
|
||||||
|
|
||||||
it('restores the history back to the original', async () => {
|
|
||||||
vi.mocked(tokenLimit).mockReturnValue(1000);
|
|
||||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
|
||||||
totalTokens: 999,
|
|
||||||
});
|
|
||||||
|
|
||||||
const originalHistory: Content[] = [
|
|
||||||
{ role: 'user', parts: [{ text: 'what is your wisdom?' }] },
|
|
||||||
{ role: 'model', parts: [{ text: 'some wisdom' }] },
|
|
||||||
{ role: 'user', parts: [{ text: 'ahh that is a good a wisdom' }] },
|
|
||||||
];
|
|
||||||
|
|
||||||
const { client } = setup({
|
|
||||||
chatHistory: originalHistory,
|
|
||||||
});
|
|
||||||
const { compressionStatus } = await client.tryCompressChat(
|
|
||||||
'prompt-id-4',
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
expect(compressionStatus).toBe(
|
|
||||||
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
|
||||||
);
|
|
||||||
expect(client['chat']?.setHistory).toHaveBeenCalledWith(
|
|
||||||
originalHistory,
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('will not attempt to compress context after a failure', async () => {
|
it('will not attempt to compress context after a failure', async () => {
|
||||||
const { client } = setup();
|
const longSummary = 'long summary '.repeat(100);
|
||||||
await client.tryCompressChat('prompt-id-4', false);
|
const { client, estimatedNewTokenCount } = setup({
|
||||||
|
originalTokenCount: 100,
|
||||||
|
summaryText: longSummary,
|
||||||
|
});
|
||||||
|
expect(estimatedNewTokenCount).toBeGreaterThan(100); // Ensure setup is correct
|
||||||
|
|
||||||
|
await client.tryCompressChat('prompt-id-4', false); // This fails and sets hasFailedCompressionAttempt = true
|
||||||
|
|
||||||
|
// This call should now be a NOOP
|
||||||
const result = await client.tryCompressChat('prompt-id-5', false);
|
const result = await client.tryCompressChat('prompt-id-5', false);
|
||||||
|
|
||||||
// it counts tokens for {original, compressed} and then never again
|
// generateContent (for summary) should only have been called once
|
||||||
expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2);
|
expect(mockGenerateContentFn).toHaveBeenCalledTimes(1);
|
||||||
expect(result).toEqual({
|
expect(result).toEqual({
|
||||||
compressionStatus: CompressionStatus.NOOP,
|
compressionStatus: CompressionStatus.NOOP,
|
||||||
newTokenCount: 0,
|
newTokenCount: 0,
|
||||||
@@ -512,9 +571,10 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
mockGetHistory.mockReturnValue([
|
mockGetHistory.mockReturnValue([
|
||||||
{ role: 'user', parts: [{ text: '...history...' }] },
|
{ role: 'user', parts: [{ text: '...history...' }] },
|
||||||
]);
|
]);
|
||||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
const originalTokenCount = MOCKED_TOKEN_LIMIT * 0.699;
|
||||||
totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
|
||||||
});
|
originalTokenCount,
|
||||||
|
);
|
||||||
|
|
||||||
const initialChat = client.getChat();
|
const initialChat = client.getChat();
|
||||||
const result = await client.tryCompressChat('prompt-id-2', false);
|
const result = await client.tryCompressChat('prompt-id-2', false);
|
||||||
@@ -523,8 +583,8 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
expect(tokenLimit).toHaveBeenCalled();
|
expect(tokenLimit).toHaveBeenCalled();
|
||||||
expect(result).toEqual({
|
expect(result).toEqual({
|
||||||
compressionStatus: CompressionStatus.NOOP,
|
compressionStatus: CompressionStatus.NOOP,
|
||||||
newTokenCount: 699,
|
newTokenCount: originalTokenCount,
|
||||||
originalTokenCount: 699,
|
originalTokenCount,
|
||||||
});
|
});
|
||||||
expect(newChat).toBe(initialChat);
|
expect(newChat).toBe(initialChat);
|
||||||
});
|
});
|
||||||
@@ -538,17 +598,43 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({
|
vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({
|
||||||
contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD,
|
contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD,
|
||||||
});
|
});
|
||||||
mockGetHistory.mockReturnValue([
|
const history = [{ role: 'user', parts: [{ text: '...history...' }] }];
|
||||||
{ role: 'user', parts: [{ text: '...history...' }] },
|
mockGetHistory.mockReturnValue(history);
|
||||||
]);
|
|
||||||
|
|
||||||
const originalTokenCount =
|
const originalTokenCount =
|
||||||
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
|
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
|
||||||
const newTokenCount = 100;
|
|
||||||
|
|
||||||
vi.mocked(mockContentGenerator.countTokens)
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
|
||||||
.mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
|
originalTokenCount,
|
||||||
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
|
);
|
||||||
|
|
||||||
|
// We need to control the estimated new token count.
|
||||||
|
// We mock startChat to return a chat with a known history.
|
||||||
|
const summaryText = 'This is a summary.';
|
||||||
|
const splitPoint = findCompressSplitPoint(history, 0.7);
|
||||||
|
const historyToKeep = history.slice(splitPoint);
|
||||||
|
const newCompressedHistory: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'Mocked env context' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] },
|
||||||
|
{ role: 'user', parts: [{ text: summaryText }] },
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Got it. Thanks for the additional context!' }],
|
||||||
|
},
|
||||||
|
...historyToKeep,
|
||||||
|
];
|
||||||
|
const mockNewChat: Partial<GeminiChat> = {
|
||||||
|
getHistory: vi.fn().mockReturnValue(newCompressedHistory),
|
||||||
|
};
|
||||||
|
client['startChat'] = vi
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue(mockNewChat as GeminiChat);
|
||||||
|
|
||||||
|
const totalChars = newCompressedHistory.reduce(
|
||||||
|
(total, content) => total + JSON.stringify(content).length,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
const newTokenCount = Math.floor(totalChars / 4);
|
||||||
|
|
||||||
// Mock the summary response from the chat
|
// Mock the summary response from the chat
|
||||||
mockGenerateContentFn.mockResolvedValue({
|
mockGenerateContentFn.mockResolvedValue({
|
||||||
@@ -556,7 +642,7 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
{
|
{
|
||||||
content: {
|
content: {
|
||||||
role: 'model',
|
role: 'model',
|
||||||
parts: [{ text: 'This is a summary.' }],
|
parts: [{ text: summaryText }],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -587,17 +673,42 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({
|
vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({
|
||||||
contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD,
|
contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD,
|
||||||
});
|
});
|
||||||
mockGetHistory.mockReturnValue([
|
const history = [{ role: 'user', parts: [{ text: '...history...' }] }];
|
||||||
{ role: 'user', parts: [{ text: '...history...' }] },
|
mockGetHistory.mockReturnValue(history);
|
||||||
]);
|
|
||||||
|
|
||||||
const originalTokenCount =
|
const originalTokenCount =
|
||||||
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
|
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
|
||||||
const newTokenCount = 100;
|
|
||||||
|
|
||||||
vi.mocked(mockContentGenerator.countTokens)
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
|
||||||
.mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
|
originalTokenCount,
|
||||||
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
|
);
|
||||||
|
|
||||||
|
// Mock summary and new chat
|
||||||
|
const summaryText = 'This is a summary.';
|
||||||
|
const splitPoint = findCompressSplitPoint(history, 0.7);
|
||||||
|
const historyToKeep = history.slice(splitPoint);
|
||||||
|
const newCompressedHistory: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'Mocked env context' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] },
|
||||||
|
{ role: 'user', parts: [{ text: summaryText }] },
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Got it. Thanks for the additional context!' }],
|
||||||
|
},
|
||||||
|
...historyToKeep,
|
||||||
|
];
|
||||||
|
const mockNewChat: Partial<GeminiChat> = {
|
||||||
|
getHistory: vi.fn().mockReturnValue(newCompressedHistory),
|
||||||
|
};
|
||||||
|
client['startChat'] = vi
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue(mockNewChat as GeminiChat);
|
||||||
|
|
||||||
|
const totalChars = newCompressedHistory.reduce(
|
||||||
|
(total, content) => total + JSON.stringify(content).length,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
const newTokenCount = Math.floor(totalChars / 4);
|
||||||
|
|
||||||
// Mock the summary response from the chat
|
// Mock the summary response from the chat
|
||||||
mockGenerateContentFn.mockResolvedValue({
|
mockGenerateContentFn.mockResolvedValue({
|
||||||
@@ -605,7 +716,7 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
{
|
{
|
||||||
content: {
|
content: {
|
||||||
role: 'model',
|
role: 'model',
|
||||||
parts: [{ text: 'This is a summary.' }],
|
parts: [{ text: summaryText }],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -632,7 +743,7 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
it('should not compress across a function call response', async () => {
|
it('should not compress across a function call response', async () => {
|
||||||
const MOCKED_TOKEN_LIMIT = 1000;
|
const MOCKED_TOKEN_LIMIT = 1000;
|
||||||
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
|
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
|
||||||
mockGetHistory.mockReturnValue([
|
const history: Content[] = [
|
||||||
{ role: 'user', parts: [{ text: '...history 1...' }] },
|
{ role: 'user', parts: [{ text: '...history 1...' }] },
|
||||||
{ role: 'model', parts: [{ text: '...history 2...' }] },
|
{ role: 'model', parts: [{ text: '...history 2...' }] },
|
||||||
{ role: 'user', parts: [{ text: '...history 3...' }] },
|
{ role: 'user', parts: [{ text: '...history 3...' }] },
|
||||||
@@ -649,14 +760,45 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
{ role: 'model', parts: [{ text: '...history 10...' }] },
|
{ role: 'model', parts: [{ text: '...history 10...' }] },
|
||||||
// Instead we will break here.
|
// Instead we will break here.
|
||||||
{ role: 'user', parts: [{ text: '...history 10...' }] },
|
{ role: 'user', parts: [{ text: '...history 10...' }] },
|
||||||
]);
|
];
|
||||||
|
mockGetHistory.mockReturnValue(history);
|
||||||
|
|
||||||
const originalTokenCount = 1000 * 0.7;
|
const originalTokenCount = 1000 * 0.7;
|
||||||
const newTokenCount = 100;
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
|
||||||
|
originalTokenCount,
|
||||||
|
);
|
||||||
|
|
||||||
vi.mocked(mockContentGenerator.countTokens)
|
// Mock summary and new chat
|
||||||
.mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
|
const summaryText = 'This is a summary.';
|
||||||
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
|
const splitPoint = findCompressSplitPoint(history, 0.7); // This should be 10
|
||||||
|
expect(splitPoint).toBe(10); // Verify split point logic
|
||||||
|
const historyToKeep = history.slice(splitPoint); // Should keep last user message
|
||||||
|
expect(historyToKeep).toEqual([
|
||||||
|
{ role: 'user', parts: [{ text: '...history 10...' }] },
|
||||||
|
]);
|
||||||
|
|
||||||
|
const newCompressedHistory: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'Mocked env context' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] },
|
||||||
|
{ role: 'user', parts: [{ text: summaryText }] },
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Got it. Thanks for the additional context!' }],
|
||||||
|
},
|
||||||
|
...historyToKeep,
|
||||||
|
];
|
||||||
|
const mockNewChat: Partial<GeminiChat> = {
|
||||||
|
getHistory: vi.fn().mockReturnValue(newCompressedHistory),
|
||||||
|
};
|
||||||
|
client['startChat'] = vi
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue(mockNewChat as GeminiChat);
|
||||||
|
|
||||||
|
const totalChars = newCompressedHistory.reduce(
|
||||||
|
(total, content) => total + JSON.stringify(content).length,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
const newTokenCount = Math.floor(totalChars / 4);
|
||||||
|
|
||||||
// Mock the summary response from the chat
|
// Mock the summary response from the chat
|
||||||
mockGenerateContentFn.mockResolvedValue({
|
mockGenerateContentFn.mockResolvedValue({
|
||||||
@@ -664,7 +806,7 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
{
|
{
|
||||||
content: {
|
content: {
|
||||||
role: 'model',
|
role: 'model',
|
||||||
parts: [{ text: 'This is a summary.' }],
|
parts: [{ text: summaryText }],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -686,25 +828,49 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
// Assert that the chat was reset
|
// Assert that the chat was reset
|
||||||
expect(newChat).not.toBe(initialChat);
|
expect(newChat).not.toBe(initialChat);
|
||||||
|
|
||||||
// 1. standard start context message
|
// 1. standard start context message (env)
|
||||||
// 2. standard canned user start message
|
// 2. standard canned model response
|
||||||
// 3. compressed summary message
|
// 3. compressed summary message (user)
|
||||||
// 4. standard canned user summary message
|
// 4. standard canned model response
|
||||||
// 5. The last user message (not the last 3 because that would start with a function response)
|
// 5. The last user message (historyToKeep)
|
||||||
expect(newChat.getHistory().length).toEqual(5);
|
expect(newChat.getHistory().length).toEqual(5);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should always trigger summarization when force is true, regardless of token count', async () => {
|
it('should always trigger summarization when force is true, regardless of token count', async () => {
|
||||||
mockGetHistory.mockReturnValue([
|
const history = [{ role: 'user', parts: [{ text: '...history...' }] }];
|
||||||
{ role: 'user', parts: [{ text: '...history...' }] },
|
mockGetHistory.mockReturnValue(history);
|
||||||
]);
|
|
||||||
|
|
||||||
const originalTokenCount = 10; // Well below threshold
|
const originalTokenCount = 100; // Well below threshold, but > estimated new count
|
||||||
const newTokenCount = 5;
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
|
||||||
|
originalTokenCount,
|
||||||
|
);
|
||||||
|
|
||||||
vi.mocked(mockContentGenerator.countTokens)
|
// Mock summary and new chat
|
||||||
.mockResolvedValueOnce({ totalTokens: originalTokenCount })
|
const summaryText = 'This is a summary.';
|
||||||
.mockResolvedValueOnce({ totalTokens: newTokenCount });
|
const splitPoint = findCompressSplitPoint(history, 0.7);
|
||||||
|
const historyToKeep = history.slice(splitPoint);
|
||||||
|
const newCompressedHistory: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'Mocked env context' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] },
|
||||||
|
{ role: 'user', parts: [{ text: summaryText }] },
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Got it. Thanks for the additional context!' }],
|
||||||
|
},
|
||||||
|
...historyToKeep,
|
||||||
|
];
|
||||||
|
const mockNewChat: Partial<GeminiChat> = {
|
||||||
|
getHistory: vi.fn().mockReturnValue(newCompressedHistory),
|
||||||
|
};
|
||||||
|
client['startChat'] = vi
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue(mockNewChat as GeminiChat);
|
||||||
|
|
||||||
|
const totalChars = newCompressedHistory.reduce(
|
||||||
|
(total, content) => total + JSON.stringify(content).length,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
const newTokenCount = Math.floor(totalChars / 4);
|
||||||
|
|
||||||
// Mock the summary response from the chat
|
// Mock the summary response from the chat
|
||||||
mockGenerateContentFn.mockResolvedValue({
|
mockGenerateContentFn.mockResolvedValue({
|
||||||
@@ -712,14 +878,14 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
{
|
{
|
||||||
content: {
|
content: {
|
||||||
role: 'model',
|
role: 'model',
|
||||||
parts: [{ text: 'This is a summary.' }],
|
parts: [{ text: summaryText }],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
} as unknown as GenerateContentResponse);
|
} as unknown as GenerateContentResponse);
|
||||||
|
|
||||||
const initialChat = client.getChat();
|
const initialChat = client.getChat();
|
||||||
const result = await client.tryCompressChat('prompt-id-1', false); // force = true
|
const result = await client.tryCompressChat('prompt-id-1', true); // force = true
|
||||||
const newChat = client.getChat();
|
const newChat = client.getChat();
|
||||||
|
|
||||||
expect(mockGenerateContentFn).toHaveBeenCalled();
|
expect(mockGenerateContentFn).toHaveBeenCalled();
|
||||||
@@ -776,10 +942,6 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
||||||
},
|
},
|
||||||
{ compressionStatus: CompressionStatus.NOOP },
|
{ compressionStatus: CompressionStatus.NOOP },
|
||||||
{
|
|
||||||
compressionStatus:
|
|
||||||
CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
|
|
||||||
},
|
|
||||||
])(
|
])(
|
||||||
'does not emit a compression event when the status is $compressionStatus',
|
'does not emit a compression event when the status is $compressionStatus',
|
||||||
async ({ compressionStatus }) => {
|
async ({ compressionStatus }) => {
|
||||||
|
|||||||
@@ -679,21 +679,7 @@ export class GeminiClient {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
const { totalTokens: originalTokenCount } =
|
const originalTokenCount = uiTelemetryService.getLastPromptTokenCount();
|
||||||
await this.getContentGeneratorOrFail().countTokens({
|
|
||||||
model,
|
|
||||||
contents: curatedHistory,
|
|
||||||
});
|
|
||||||
if (originalTokenCount === undefined) {
|
|
||||||
console.warn(`Could not determine token count for model ${model}.`);
|
|
||||||
this.hasFailedCompressionAttempt = !force && true;
|
|
||||||
return {
|
|
||||||
originalTokenCount: 0,
|
|
||||||
newTokenCount: 0,
|
|
||||||
compressionStatus:
|
|
||||||
CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const contextPercentageThreshold =
|
const contextPercentageThreshold =
|
||||||
this.config.getChatCompression()?.contextPercentageThreshold;
|
this.config.getChatCompression()?.contextPercentageThreshold;
|
||||||
@@ -756,23 +742,13 @@ export class GeminiClient {
|
|||||||
]);
|
]);
|
||||||
this.forceFullIdeContext = true;
|
this.forceFullIdeContext = true;
|
||||||
|
|
||||||
const { totalTokens: newTokenCount } =
|
// Estimate token count 1 token ≈ 4 characters
|
||||||
await this.getContentGeneratorOrFail().countTokens({
|
const newTokenCount = Math.floor(
|
||||||
model,
|
chat
|
||||||
contents: chat.getHistory(),
|
.getHistory()
|
||||||
});
|
.reduce((total, content) => total + JSON.stringify(content).length, 0) /
|
||||||
if (newTokenCount === undefined) {
|
4,
|
||||||
console.warn('Could not determine compressed history token count.');
|
);
|
||||||
this.hasFailedCompressionAttempt = !force && true;
|
|
||||||
return {
|
|
||||||
originalTokenCount,
|
|
||||||
newTokenCount: originalTokenCount,
|
|
||||||
compressionStatus:
|
|
||||||
CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
uiTelemetryService.setLastPromptTokenCount(newTokenCount);
|
|
||||||
|
|
||||||
logChatCompression(
|
logChatCompression(
|
||||||
this.config,
|
this.config,
|
||||||
@@ -783,7 +759,6 @@ export class GeminiClient {
|
|||||||
);
|
);
|
||||||
|
|
||||||
if (newTokenCount > originalTokenCount) {
|
if (newTokenCount > originalTokenCount) {
|
||||||
this.getChat().setHistory(curatedHistory);
|
|
||||||
this.hasFailedCompressionAttempt = !force && true;
|
this.hasFailedCompressionAttempt = !force && true;
|
||||||
return {
|
return {
|
||||||
originalTokenCount,
|
originalTokenCount,
|
||||||
@@ -793,6 +768,7 @@ export class GeminiClient {
|
|||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
this.chat = chat; // Chat compression successful, set new state.
|
this.chat = chat; // Chat compression successful, set new state.
|
||||||
|
uiTelemetryService.setLastPromptTokenCount(newTokenCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
Reference in New Issue
Block a user