mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
Improve compression splitpoint logic. (#8691)
This commit is contained in:
committed by
GitHub
parent
f41db212ec
commit
853ae56e7e
@@ -16,7 +16,7 @@ import {
|
||||
|
||||
import type { Content, GenerateContentResponse, Part } from '@google/genai';
|
||||
import {
|
||||
findIndexAfterFraction,
|
||||
findCompressSplitPoint,
|
||||
isThinkingDefault,
|
||||
isThinkingSupported,
|
||||
GeminiClient,
|
||||
@@ -126,51 +126,70 @@ async function fromAsync<T>(promise: AsyncGenerator<T>): Promise<readonly T[]> {
|
||||
}
|
||||
|
||||
describe('findIndexAfterFraction', () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66
|
||||
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68
|
||||
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66
|
||||
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68
|
||||
{ role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65
|
||||
];
|
||||
// Total length: 333
|
||||
|
||||
it('should throw an error for non-positive numbers', () => {
|
||||
expect(() => findIndexAfterFraction(history, 0)).toThrow(
|
||||
expect(() => findCompressSplitPoint([], 0)).toThrow(
|
||||
'Fraction must be between 0 and 1',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error for a fraction greater than or equal to 1', () => {
|
||||
expect(() => findIndexAfterFraction(history, 1)).toThrow(
|
||||
expect(() => findCompressSplitPoint([], 1)).toThrow(
|
||||
'Fraction must be between 0 and 1',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle a fraction in the middle', () => {
|
||||
// 333 * 0.5 = 166.5
|
||||
// 0: 66
|
||||
// 1: 66 + 68 = 134
|
||||
// 2: 134 + 66 = 200
|
||||
// 200 >= 166.5, so index is 3
|
||||
expect(findIndexAfterFraction(history, 0.5)).toBe(3);
|
||||
});
|
||||
|
||||
it('should handle a fraction that results in the last index', () => {
|
||||
// 333 * 0.9 = 299.7
|
||||
// ...
|
||||
// 3: 200 + 68 = 268
|
||||
// 4: 268 + 65 = 333
|
||||
// 333 >= 299.7, so index is 5
|
||||
expect(findIndexAfterFraction(history, 0.9)).toBe(5);
|
||||
});
|
||||
|
||||
it('should handle an empty history', () => {
|
||||
expect(findIndexAfterFraction([], 0.5)).toBe(0);
|
||||
expect(findCompressSplitPoint([], 0.5)).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle a fraction in the middle', () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (19%)
|
||||
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (40%)
|
||||
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (60%)
|
||||
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (80%)
|
||||
{ role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65 (100%)
|
||||
];
|
||||
expect(findCompressSplitPoint(history, 0.5)).toBe(2);
|
||||
});
|
||||
|
||||
it('should handle a fraction of last index', () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (19%)
|
||||
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (40%)
|
||||
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (60%)
|
||||
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (80%)
|
||||
{ role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65 (100%)
|
||||
];
|
||||
expect(findCompressSplitPoint(history, 0.9)).toBe(4);
|
||||
});
|
||||
|
||||
it('should handle a fraction of after last index', () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (24%)
|
||||
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (50%)
|
||||
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (74%)
|
||||
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (100%)
|
||||
];
|
||||
expect(findCompressSplitPoint(history, 0.8)).toBe(4);
|
||||
});
|
||||
|
||||
it('should return earlier splitpoint if no valid ones are after threshhold', () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'This is the first message.' }] },
|
||||
{ role: 'model', parts: [{ text: 'This is the second message.' }] },
|
||||
{ role: 'user', parts: [{ text: 'This is the third message.' }] },
|
||||
{ role: 'model', parts: [{ functionCall: {} }] },
|
||||
];
|
||||
// Can't return 4 (compress everything) because the last item is a model turn containing a function call.
|
||||
expect(findCompressSplitPoint(history, 0.99)).toBe(2);
|
||||
});
|
||||
|
||||
it('should handle a history with only one item', () => {
|
||||
expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(1);
|
||||
const historyWithEmptyParts: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'Message 1' }] },
|
||||
];
|
||||
expect(findCompressSplitPoint(historyWithEmptyParts, 0.5)).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle history with weird parts', () => {
|
||||
@@ -179,7 +198,7 @@ describe('findIndexAfterFraction', () => {
|
||||
{ role: 'model', parts: [{ fileData: { fileUri: 'derp' } }] },
|
||||
{ role: 'user', parts: [{ text: 'Message 2' }] },
|
||||
];
|
||||
expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(2);
|
||||
expect(findCompressSplitPoint(historyWithEmptyParts, 0.5)).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ import { reportError } from '../utils/errorReporting.js';
|
||||
import { GeminiChat } from './geminiChat.js';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { tokenLimit } from './tokenLimits.js';
|
||||
import type { ChatRecordingService } from '../services/chatRecordingService.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
@@ -63,36 +62,51 @@ export function isThinkingDefault(model: string) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the index of the content after the fraction of the total characters in the history.
|
||||
* Returns the index of the oldest item to keep when compressing. May return
|
||||
* contents.length which indicates that everything should be compressed.
|
||||
*
|
||||
* Exported for testing purposes.
|
||||
*/
|
||||
export function findIndexAfterFraction(
|
||||
history: Content[],
|
||||
export function findCompressSplitPoint(
|
||||
contents: Content[],
|
||||
fraction: number,
|
||||
): number {
|
||||
if (fraction <= 0 || fraction >= 1) {
|
||||
throw new Error('Fraction must be between 0 and 1');
|
||||
}
|
||||
|
||||
const contentLengths = history.map(
|
||||
(content) => JSON.stringify(content).length,
|
||||
);
|
||||
const charCounts = contents.map((content) => JSON.stringify(content).length);
|
||||
const totalCharCount = charCounts.reduce((a, b) => a + b, 0);
|
||||
const targetCharCount = totalCharCount * fraction;
|
||||
|
||||
const totalCharacters = contentLengths.reduce(
|
||||
(sum, length) => sum + length,
|
||||
0,
|
||||
);
|
||||
const targetCharacters = totalCharacters * fraction;
|
||||
|
||||
let charactersSoFar = 0;
|
||||
for (let i = 0; i < contentLengths.length; i++) {
|
||||
if (charactersSoFar >= targetCharacters) {
|
||||
return i;
|
||||
let lastSplitPoint = 0; // 0 is always valid (compress nothing)
|
||||
let cumulativeCharCount = 0;
|
||||
for (let i = 0; i < contents.length; i++) {
|
||||
cumulativeCharCount += charCounts[i];
|
||||
const content = contents[i];
|
||||
if (
|
||||
content.role === 'user' &&
|
||||
!content.parts?.some((part) => !!part.functionResponse)
|
||||
) {
|
||||
if (cumulativeCharCount >= targetCharCount) {
|
||||
return i;
|
||||
}
|
||||
lastSplitPoint = i;
|
||||
}
|
||||
charactersSoFar += contentLengths[i];
|
||||
}
|
||||
return contentLengths.length;
|
||||
|
||||
// We found no split points after targetCharCount.
|
||||
// Check if it's safe to compress everything.
|
||||
const lastContent = contents[contents.length - 1];
|
||||
if (
|
||||
lastContent?.role === 'model' &&
|
||||
!lastContent?.parts?.some((part) => part.functionCall)
|
||||
) {
|
||||
return contents.length;
|
||||
}
|
||||
|
||||
// Can't compress everything so just compress at last splitpoint.
|
||||
return lastSplitPoint;
|
||||
}
|
||||
|
||||
const MAX_TURNS = 100;
|
||||
@@ -696,21 +710,13 @@ export class GeminiClient {
|
||||
}
|
||||
}
|
||||
|
||||
let compressBeforeIndex = findIndexAfterFraction(
|
||||
const splitPoint = findCompressSplitPoint(
|
||||
curatedHistory,
|
||||
1 - COMPRESSION_PRESERVE_THRESHOLD,
|
||||
);
|
||||
// Find the first user message after the index. This is the start of the next turn.
|
||||
while (
|
||||
compressBeforeIndex < curatedHistory.length &&
|
||||
(curatedHistory[compressBeforeIndex]?.role === 'model' ||
|
||||
isFunctionResponse(curatedHistory[compressBeforeIndex]))
|
||||
) {
|
||||
compressBeforeIndex++;
|
||||
}
|
||||
|
||||
const historyToCompress = curatedHistory.slice(0, compressBeforeIndex);
|
||||
const historyToKeep = curatedHistory.slice(compressBeforeIndex);
|
||||
const historyToCompress = curatedHistory.slice(0, splitPoint);
|
||||
const historyToKeep = curatedHistory.slice(splitPoint);
|
||||
|
||||
const summaryResponse = await this.config
|
||||
.getContentGenerator()
|
||||
|
||||
Reference in New Issue
Block a user