Improve compression splitpoint logic. (#8691)

This commit is contained in:
Tommaso Sciortino
2025-09-18 10:59:13 -07:00
committed by GitHub
parent f41db212ec
commit 853ae56e7e
2 changed files with 88 additions and 63 deletions

View File

@@ -16,7 +16,7 @@ import {
import type { Content, GenerateContentResponse, Part } from '@google/genai';
import {
findIndexAfterFraction,
findCompressSplitPoint,
isThinkingDefault,
isThinkingSupported,
GeminiClient,
@@ -126,51 +126,70 @@ async function fromAsync<T>(promise: AsyncGenerator<T>): Promise<readonly T[]> {
}
describe('findIndexAfterFraction', () => {
const history: Content[] = [
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68
{ role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65
];
// Total length: 333
it('should throw an error for non-positive numbers', () => {
expect(() => findIndexAfterFraction(history, 0)).toThrow(
expect(() => findCompressSplitPoint([], 0)).toThrow(
'Fraction must be between 0 and 1',
);
});
it('should throw an error for a fraction greater than or equal to 1', () => {
expect(() => findIndexAfterFraction(history, 1)).toThrow(
expect(() => findCompressSplitPoint([], 1)).toThrow(
'Fraction must be between 0 and 1',
);
});
it('should handle a fraction in the middle', () => {
// 333 * 0.5 = 166.5
// 0: 66
// 1: 66 + 68 = 134
// 2: 134 + 66 = 200
// 200 >= 166.5, so index is 3
expect(findIndexAfterFraction(history, 0.5)).toBe(3);
});
it('should handle a fraction that results in the last index', () => {
// 333 * 0.9 = 299.7
// ...
// 3: 200 + 68 = 268
// 4: 268 + 65 = 333
// 333 >= 299.7, so index is 5
expect(findIndexAfterFraction(history, 0.9)).toBe(5);
});
it('should handle an empty history', () => {
expect(findIndexAfterFraction([], 0.5)).toBe(0);
expect(findCompressSplitPoint([], 0.5)).toBe(0);
});
it('should handle a fraction in the middle', () => {
const history: Content[] = [
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (19%)
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (40%)
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (60%)
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (80%)
{ role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65 (100%)
];
expect(findCompressSplitPoint(history, 0.5)).toBe(2);
});
it('should handle a fraction of last index', () => {
const history: Content[] = [
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (19%)
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (40%)
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (60%)
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (80%)
{ role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65 (100%)
];
expect(findCompressSplitPoint(history, 0.9)).toBe(4);
});
it('should handle a fraction of after last index', () => {
const history: Content[] = [
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (24%%)
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (50%)
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (74%)
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (100%)
];
expect(findCompressSplitPoint(history, 0.8)).toBe(4);
});
it('should return earlier splitpoint if no valid ones are after threshhold', () => {
const history: Content[] = [
{ role: 'user', parts: [{ text: 'This is the first message.' }] },
{ role: 'model', parts: [{ text: 'This is the second message.' }] },
{ role: 'user', parts: [{ text: 'This is the third message.' }] },
{ role: 'model', parts: [{ functionCall: {} }] },
];
// Can't return 4 because the previous item has a function call.
expect(findCompressSplitPoint(history, 0.99)).toBe(2);
});
it('should handle a history with only one item', () => {
expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(1);
const historyWithEmptyParts: Content[] = [
{ role: 'user', parts: [{ text: 'Message 1' }] },
];
expect(findCompressSplitPoint(historyWithEmptyParts, 0.5)).toBe(0);
});
it('should handle history with weird parts', () => {
@@ -179,7 +198,7 @@ describe('findIndexAfterFraction', () => {
{ role: 'model', parts: [{ fileData: { fileUri: 'derp' } }] },
{ role: 'user', parts: [{ text: 'Message 2' }] },
];
expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(2);
expect(findCompressSplitPoint(historyWithEmptyParts, 0.5)).toBe(2);
});
});

View File

@@ -26,7 +26,6 @@ import { reportError } from '../utils/errorReporting.js';
import { GeminiChat } from './geminiChat.js';
import { retryWithBackoff } from '../utils/retry.js';
import { getErrorMessage } from '../utils/errors.js';
import { isFunctionResponse } from '../utils/messageInspectors.js';
import { tokenLimit } from './tokenLimits.js';
import type { ChatRecordingService } from '../services/chatRecordingService.js';
import type { ContentGenerator } from './contentGenerator.js';
@@ -63,36 +62,51 @@ export function isThinkingDefault(model: string) {
}
/**
* Returns the index of the content after the fraction of the total characters in the history.
* Returns the index of the oldest item to keep when compressing. May return
* contents.length which indicates that everything should be compressed.
*
* Exported for testing purposes.
*/
export function findIndexAfterFraction(
history: Content[],
export function findCompressSplitPoint(
contents: Content[],
fraction: number,
): number {
if (fraction <= 0 || fraction >= 1) {
throw new Error('Fraction must be between 0 and 1');
}
const contentLengths = history.map(
(content) => JSON.stringify(content).length,
);
const charCounts = contents.map((content) => JSON.stringify(content).length);
const totalCharCount = charCounts.reduce((a, b) => a + b, 0);
const targetCharCount = totalCharCount * fraction;
const totalCharacters = contentLengths.reduce(
(sum, length) => sum + length,
0,
);
const targetCharacters = totalCharacters * fraction;
let charactersSoFar = 0;
for (let i = 0; i < contentLengths.length; i++) {
if (charactersSoFar >= targetCharacters) {
return i;
let lastSplitPoint = 0; // 0 is always valid (compress nothing)
let cumulativeCharCount = 0;
for (let i = 0; i < contents.length; i++) {
cumulativeCharCount += charCounts[i];
const content = contents[i];
if (
content.role === 'user' &&
!content.parts?.some((part) => !!part.functionResponse)
) {
if (cumulativeCharCount >= targetCharCount) {
return i;
}
lastSplitPoint = i;
}
charactersSoFar += contentLengths[i];
}
return contentLengths.length;
// We found no split points after targetCharCount.
// Check if it's safe to compress everything.
const lastContent = contents[contents.length - 1];
if (
lastContent?.role === 'model' &&
!lastContent?.parts?.some((part) => part.functionCall)
) {
return contents.length;
}
// Can't compress everything so just compress at last splitpoint.
return lastSplitPoint;
}
const MAX_TURNS = 100;
@@ -696,21 +710,13 @@ export class GeminiClient {
}
}
let compressBeforeIndex = findIndexAfterFraction(
const splitPoint = findCompressSplitPoint(
curatedHistory,
1 - COMPRESSION_PRESERVE_THRESHOLD,
);
// Find the first user message after the index. This is the start of the next turn.
while (
compressBeforeIndex < curatedHistory.length &&
(curatedHistory[compressBeforeIndex]?.role === 'model' ||
isFunctionResponse(curatedHistory[compressBeforeIndex]))
) {
compressBeforeIndex++;
}
const historyToCompress = curatedHistory.slice(0, compressBeforeIndex);
const historyToKeep = curatedHistory.slice(compressBeforeIndex);
const historyToCompress = curatedHistory.slice(0, splitPoint);
const historyToKeep = curatedHistory.slice(splitPoint);
const summaryResponse = await this.config
.getContentGenerator()