diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 19aa737053..7a546bd6ad 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -16,7 +16,7 @@ import { import type { Content, GenerateContentResponse, Part } from '@google/genai'; import { - findIndexAfterFraction, + findCompressSplitPoint, isThinkingDefault, isThinkingSupported, GeminiClient, @@ -126,51 +126,70 @@ async function fromAsync(promise: AsyncGenerator): Promise { } describe('findIndexAfterFraction', () => { - const history: Content[] = [ - { role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 - { role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 - { role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 - { role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 - { role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65 - ]; - // Total length: 333 - it('should throw an error for non-positive numbers', () => { - expect(() => findIndexAfterFraction(history, 0)).toThrow( + expect(() => findCompressSplitPoint([], 0)).toThrow( 'Fraction must be between 0 and 1', ); }); it('should throw an error for a fraction greater than or equal to 1', () => { - expect(() => findIndexAfterFraction(history, 1)).toThrow( + expect(() => findCompressSplitPoint([], 1)).toThrow( 'Fraction must be between 0 and 1', ); }); - it('should handle a fraction in the middle', () => { - // 333 * 0.5 = 166.5 - // 0: 66 - // 1: 66 + 68 = 134 - // 2: 134 + 66 = 200 - // 200 >= 166.5, so index is 3 - expect(findIndexAfterFraction(history, 0.5)).toBe(3); - }); - - it('should handle a fraction that results in the last index', () => { - // 333 * 0.9 = 299.7 - // ... 
- // 3: 200 + 68 = 268 - // 4: 268 + 65 = 333 - // 333 >= 299.7, so index is 5 - expect(findIndexAfterFraction(history, 0.9)).toBe(5); - }); - it('should handle an empty history', () => { - expect(findIndexAfterFraction([], 0.5)).toBe(0); + expect(findCompressSplitPoint([], 0.5)).toBe(0); + }); + + it('should handle a fraction in the middle', () => { + const history: Content[] = [ + { role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (19%) + { role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (40%) + { role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (60%) + { role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (80%) + { role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65 (100%) + ]; + expect(findCompressSplitPoint(history, 0.5)).toBe(2); + }); + + it('should handle a fraction of last index', () => { + const history: Content[] = [ + { role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (19%) + { role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (40%) + { role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (60%) + { role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (80%) + { role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65 (100%) + ]; + expect(findCompressSplitPoint(history, 0.9)).toBe(4); + }); + + it('should handle a fraction of after last index', () => { + const history: Content[] = [ + { role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (24%%) + { role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (50%) + { role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (74%) + { role: 'model', parts: [{ text: 'This is the fourth message.' 
}] }, // JSON length: 68 (100%) + ]; + expect(findCompressSplitPoint(history, 0.8)).toBe(4); + }); + + it('should return earlier splitpoint if no valid ones are after threshhold', () => { + const history: Content[] = [ + { role: 'user', parts: [{ text: 'This is the first message.' }] }, + { role: 'model', parts: [{ text: 'This is the second message.' }] }, + { role: 'user', parts: [{ text: 'This is the third message.' }] }, + { role: 'model', parts: [{ functionCall: {} }] }, + ]; + // Can't return 4 because the previous item has a function call. + expect(findCompressSplitPoint(history, 0.99)).toBe(2); }); it('should handle a history with only one item', () => { - expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(1); + const historyWithEmptyParts: Content[] = [ + { role: 'user', parts: [{ text: 'Message 1' }] }, + ]; + expect(findCompressSplitPoint(historyWithEmptyParts, 0.5)).toBe(0); }); it('should handle history with weird parts', () => { @@ -179,7 +198,7 @@ describe('findIndexAfterFraction', () => { { role: 'model', parts: [{ fileData: { fileUri: 'derp' } }] }, { role: 'user', parts: [{ text: 'Message 2' }] }, ]; - expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(2); + expect(findCompressSplitPoint(historyWithEmptyParts, 0.5)).toBe(2); }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 1d98bc2628..587daeae7a 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -26,7 +26,6 @@ import { reportError } from '../utils/errorReporting.js'; import { GeminiChat } from './geminiChat.js'; import { retryWithBackoff } from '../utils/retry.js'; import { getErrorMessage } from '../utils/errors.js'; -import { isFunctionResponse } from '../utils/messageInspectors.js'; import { tokenLimit } from './tokenLimits.js'; import type { ChatRecordingService } from '../services/chatRecordingService.js'; import type { ContentGenerator } from './contentGenerator.js'; @@ -63,36 +62,51 @@ 
/**
 * Returns the index of the oldest item to keep when compressing a chat
 * history. Items before the returned index are meant to be compressed and
 * items from the returned index onward are meant to be kept.
 *
 * Only a "plain" user message (role `user` with no `functionResponse` part)
 * is a valid place to split. The size of each item is measured as the length
 * of its `JSON.stringify` output. The split lands on the first plain user
 * message at which the cumulative size, counted up to and including that
 * message, reaches `fraction` of the total size.
 *
 * Exported for testing purposes.
 *
 * @param contents - The history to split, oldest item first.
 * @param fraction - Share of the total character count to aim for; must be
 *   strictly between 0 and 1.
 * @returns An index in `[0, contents.length]`. `0` means "compress nothing";
 *   `contents.length` means "compress everything" and is only returned when
 *   the last item is a model message without a `functionCall` part.
 * @throws Error if `fraction` is not strictly between 0 and 1 (including
 *   `NaN`).
 */
export function findCompressSplitPoint(
  contents: Content[],
  fraction: number,
): number {
  // Written as a negated range check so that NaN (for which every comparison
  // is false) is rejected instead of silently slipping through.
  if (!(fraction > 0 && fraction < 1)) {
    throw new Error('Fraction must be between 0 and 1');
  }

  const charCounts = contents.map((content) => JSON.stringify(content).length);
  const totalCharCount = charCounts.reduce((sum, count) => sum + count, 0);
  const targetCharCount = totalCharCount * fraction;

  let lastSplitPoint = 0; // 0 is always valid (compress nothing)
  let cumulativeCharCount = 0;
  for (const [index, content] of contents.entries()) {
    cumulativeCharCount += charCounts[index] ?? 0;

    const isPlainUserMessage =
      content.role === 'user' &&
      !content.parts?.some((part) => Boolean(part.functionResponse));
    if (!isPlainUserMessage) {
      // Model messages and function responses cannot start the kept tail.
      continue;
    }

    if (cumulativeCharCount >= targetCharCount) {
      return index;
    }
    lastSplitPoint = index;
  }

  // No valid split point reached the target. Compressing everything is only
  // safe when the history ends on a model message that has no function call.
  const lastContent = contents[contents.length - 1];
  const endsOnPlainModelMessage =
    lastContent?.role === 'model' &&
    !lastContent.parts?.some((part) => Boolean(part.functionCall));
  if (endsOnPlainModelMessage) {
    return contents.length;
  }

  // Can't compress everything, so fall back to the latest valid split point.
  return lastSplitPoint;
}