mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 05:12:55 -07:00
feat(context): Introduce adaptive token calculator to more accurately calculate content sizes. (#26888)
This commit is contained in:
@@ -37,8 +37,8 @@ describe('ContextManager Sync Pressure Barrier Tests', () => {
|
||||
]);
|
||||
|
||||
// 3. Add massive history that blows past the 150k maxTokens limit
|
||||
// 20 turns * 10,000 tokens/turn = ~200,000 tokens
|
||||
const massiveHistory = createSyntheticHistory(20, 35000);
|
||||
// 20 turns * ~20,000 tokens/turn (10k user + 10k model) = ~400,000 tokens
|
||||
const massiveHistory = createSyntheticHistory(20, 10000);
|
||||
chatHistory.set([...chatHistory.get(), ...massiveHistory]);
|
||||
|
||||
// 4. Add the Latest Turn (Protected)
|
||||
@@ -60,8 +60,8 @@ describe('ContextManager Sync Pressure Barrier Tests', () => {
|
||||
|
||||
// Verify Episode 0 (System) was pruned, so we now start with a sentinel due to role alternation
|
||||
expect(projection[0].role).toBe('user');
|
||||
expect(projection[0].parts![0].text).toContain('User turn 17');
|
||||
|
||||
const projectionString = JSON.stringify(projection);
|
||||
expect(projectionString).toContain('User turn 17');
|
||||
// Filter out synthetic Yield nodes (they are model responses without actual tool/text bodies)
|
||||
const contentNodes = projection.filter(
|
||||
(p) =>
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import {
|
||||
createMockContextConfig,
|
||||
setupContextComponentTest,
|
||||
createMockLlmClient,
|
||||
} from './testing/contextTestUtils.js';
|
||||
import { stressTestProfile } from './config/profiles.js';
|
||||
|
||||
describe('ContextManager - Hot Start Calibration', () => {
|
||||
it('should not perform calibration if the buffer is empty', async () => {
|
||||
const mockLlm = createMockLlmClient();
|
||||
const config = createMockContextConfig(undefined, mockLlm);
|
||||
const { contextManager } = setupContextComponentTest(
|
||||
config,
|
||||
stressTestProfile,
|
||||
);
|
||||
|
||||
// We can spy on the underlying mock LLM client countTokens
|
||||
const countTokensSpy = vi.spyOn(mockLlm, 'countTokens');
|
||||
|
||||
// Render an empty graph
|
||||
await contextManager.renderHistory();
|
||||
|
||||
expect(countTokensSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should perform calibration exactly once when rendering with existing nodes', async () => {
|
||||
const mockLlm = createMockLlmClient();
|
||||
const countTokensSpy = vi
|
||||
.spyOn(mockLlm, 'countTokens')
|
||||
.mockResolvedValue({ totalTokens: 42 });
|
||||
|
||||
const config = createMockContextConfig(undefined, mockLlm);
|
||||
const { contextManager, chatHistory } = setupContextComponentTest(
|
||||
config,
|
||||
stressTestProfile,
|
||||
);
|
||||
|
||||
// We need to access the env's eventBus inside the contextManager
|
||||
const env = Reflect.get(contextManager, 'env');
|
||||
const emitGroundTruthSpy = vi.spyOn(env.eventBus, 'emitTokenGroundTruth');
|
||||
|
||||
// Add a node to make the buffer non-empty
|
||||
chatHistory.set([{ role: 'user', parts: [{ text: 'Hello' }] }]);
|
||||
|
||||
// First render should trigger calibration
|
||||
await contextManager.renderHistory();
|
||||
|
||||
expect(countTokensSpy).toHaveBeenCalledTimes(1);
|
||||
expect(emitGroundTruthSpy).toHaveBeenCalledTimes(1);
|
||||
expect(emitGroundTruthSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
actualTokens: 42,
|
||||
promptBaseUnits: 10,
|
||||
}),
|
||||
);
|
||||
|
||||
// Second render should skip calibration
|
||||
await contextManager.renderHistory();
|
||||
expect(countTokensSpy).toHaveBeenCalledTimes(1);
|
||||
// emit hasn't been called again
|
||||
expect(emitGroundTruthSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should silently swallow errors if countTokens API fails', async () => {
|
||||
const mockLlm = createMockLlmClient();
|
||||
const countTokensSpy = vi
|
||||
.spyOn(mockLlm, 'countTokens')
|
||||
.mockRejectedValue(new Error('API failure'));
|
||||
|
||||
const config = createMockContextConfig(undefined, mockLlm);
|
||||
const { contextManager, chatHistory } = setupContextComponentTest(
|
||||
config,
|
||||
stressTestProfile,
|
||||
);
|
||||
|
||||
// Add a node
|
||||
chatHistory.set([{ role: 'user', parts: [{ text: 'Hello' }] }]);
|
||||
|
||||
// Render should succeed without throwing
|
||||
const result = await contextManager.renderHistory();
|
||||
|
||||
expect(result.history).toBeDefined();
|
||||
expect(countTokensSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
@@ -18,6 +18,7 @@ import { ContextWorkingBufferImpl } from './pipeline/contextWorkingBuffer.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { hardenHistory } from '../utils/historyHardening.js';
|
||||
import { checkContextInvariants } from './utils/invariantChecker.js';
|
||||
import type { AdvancedTokenCalculator } from './utils/contextTokenCalculator.js';
|
||||
|
||||
export class ContextManager {
|
||||
// The master state containing the pristine graph and current active graph.
|
||||
@@ -36,15 +37,22 @@ export class ContextManager {
|
||||
// Cache for Anomaly 3 (Redundant Renders)
|
||||
private lastRenderCache?: {
|
||||
nodesHash: string;
|
||||
result: { history: Content[]; didApplyManagement: boolean };
|
||||
result: {
|
||||
history: Content[];
|
||||
didApplyManagement: boolean;
|
||||
baseUnits: number;
|
||||
};
|
||||
};
|
||||
|
||||
private hasPerformedHotStart = false;
|
||||
|
||||
constructor(
|
||||
private readonly sidecar: ContextProfile,
|
||||
private readonly env: ContextEnvironment,
|
||||
private readonly tracer: ContextTracer,
|
||||
orchestrator: PipelineOrchestrator,
|
||||
chatHistory: AgentChatHistory,
|
||||
private readonly advancedTokenCalculator: AdvancedTokenCalculator,
|
||||
private readonly headerProvider?: () => Promise<Content | undefined>,
|
||||
) {
|
||||
this.eventBus = env.eventBus;
|
||||
@@ -260,6 +268,10 @@ export class ContextManager {
|
||||
return [...this.buffer.nodes];
|
||||
}
|
||||
|
||||
getEnvironment(): ContextEnvironment {
|
||||
return this.env;
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes the final 'gc_backstop' pipeline if necessary, enforcing the token budget,
|
||||
* and maps the Episodic Context Graph back into a raw Gemini Content[] array for transmission.
|
||||
@@ -268,22 +280,44 @@ export class ContextManager {
|
||||
async renderHistory(
|
||||
pendingRequest?: Content,
|
||||
activeTaskIds: Set<string> = new Set(),
|
||||
): Promise<{ history: Content[]; didApplyManagement: boolean }> {
|
||||
abortSignal?: AbortSignal,
|
||||
): Promise<{
|
||||
history: Content[];
|
||||
didApplyManagement: boolean;
|
||||
baseUnits: number;
|
||||
}> {
|
||||
this.tracer.logEvent('ContextManager', 'Starting rendering of LLM context');
|
||||
|
||||
let previewNodes: ConcreteNode[] = [];
|
||||
if (pendingRequest) {
|
||||
previewNodes = this.env.graphMapper.applyEvent({
|
||||
type: 'PUSH',
|
||||
payload: [pendingRequest],
|
||||
});
|
||||
}
|
||||
|
||||
// --- Hot Start Calibration ---
|
||||
// If we are resuming a session with history, we don't want the adaptive token calculator
|
||||
// to fly blind on its first GC pass. We do a one-time API calibration.
|
||||
const hotStartPromise = (async () => {
|
||||
if (!this.hasPerformedHotStart) {
|
||||
this.hasPerformedHotStart = true;
|
||||
if (this.buffer.nodes.length > 0) {
|
||||
const nodesForHotStart = [...this.buffer.nodes, ...previewNodes];
|
||||
await this.performHotStartCalibration(nodesForHotStart, abortSignal);
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
// 1. Synchronous Pressure Barrier: Wait for background management pipelines to finish.
|
||||
// This ensures that the render sees the results of recent pushes (Anomaly 2).
|
||||
await this.orchestrator.waitForPipelines();
|
||||
// We run hot start calibration in parallel to hide the network latency.
|
||||
await Promise.all([this.orchestrator.waitForPipelines(), hotStartPromise]);
|
||||
|
||||
let nodes = this.buffer.nodes;
|
||||
const previewNodeIds = new Set<string>();
|
||||
|
||||
// If we have a pending request, we need to build a 'preview' graph for this render.
|
||||
if (pendingRequest) {
|
||||
const previewNodes = this.env.graphMapper.applyEvent({
|
||||
type: 'PUSH',
|
||||
payload: [pendingRequest],
|
||||
});
|
||||
// Apply the preview nodes to the final graph
|
||||
if (previewNodes.length > 0) {
|
||||
for (const n of previewNodes) {
|
||||
previewNodeIds.add(n.id);
|
||||
}
|
||||
@@ -294,9 +328,6 @@ export class ContextManager {
|
||||
const header = this.headerProvider
|
||||
? await this.headerProvider()
|
||||
: undefined;
|
||||
const headerTokens = header
|
||||
? this.env.tokenCalculator.calculateContentTokens(header)
|
||||
: 0;
|
||||
|
||||
// 3. Cache Check (Anomaly 3): If nodes haven't changed, return previous result.
|
||||
// We combine the graph hash with a hash of the header to ensure total freshness.
|
||||
@@ -314,14 +345,19 @@ export class ContextManager {
|
||||
const protectionReasons = this.getProtectedNodeIds(nodes, activeTaskIds);
|
||||
|
||||
// Apply final GC Backstop pressure barrier synchronously before mapping
|
||||
const { history: renderedHistory, didApplyManagement } = await render(
|
||||
const {
|
||||
history: renderedHistory,
|
||||
didApplyManagement,
|
||||
baseUnits,
|
||||
} = await render(
|
||||
nodes,
|
||||
this.orchestrator,
|
||||
this.sidecar,
|
||||
this.tracer,
|
||||
this.env,
|
||||
this.advancedTokenCalculator,
|
||||
protectionReasons,
|
||||
headerTokens,
|
||||
header,
|
||||
previewNodeIds,
|
||||
);
|
||||
|
||||
@@ -339,10 +375,58 @@ export class ContextManager {
|
||||
sentinels: this.sidecar.sentinels,
|
||||
}),
|
||||
didApplyManagement,
|
||||
baseUnits,
|
||||
};
|
||||
|
||||
// Update cache
|
||||
this.lastRenderCache = { nodesHash: totalHash, result };
|
||||
return result;
|
||||
}
|
||||
|
||||
private async performHotStartCalibration(
|
||||
nodes: readonly ConcreteNode[],
|
||||
abortSignal?: AbortSignal,
|
||||
) {
|
||||
try {
|
||||
this.tracer.logEvent(
|
||||
'ContextManager',
|
||||
'Performing Hot Start Token Calibration',
|
||||
);
|
||||
|
||||
const contents = this.env.graphMapper.fromGraph(nodes);
|
||||
const header = this.headerProvider
|
||||
? await this.headerProvider()
|
||||
: undefined;
|
||||
const combinedHistory = header ? [header, ...contents] : contents;
|
||||
|
||||
const baseUnits =
|
||||
this.advancedTokenCalculator.getRawBaseUnits(nodes) +
|
||||
(header
|
||||
? this.advancedTokenCalculator.getRawBaseUnitsForContent(header)
|
||||
: 0);
|
||||
|
||||
// We only make the network call if we have actual contents to send,
|
||||
// avoiding 400 Bad Request errors from the API.
|
||||
if (combinedHistory.length > 0) {
|
||||
const result = await this.env.llmClient.countTokens({
|
||||
contents: combinedHistory,
|
||||
abortSignal,
|
||||
});
|
||||
if (result.totalTokens > 0) {
|
||||
this.env.eventBus.emitTokenGroundTruth({
|
||||
actualTokens: result.totalTokens,
|
||||
promptBaseUnits: baseUnits,
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Hot start calibration is purely an optimization. If the network fails or auth is weird,
|
||||
// we silently swallow and fallback to the un-calibrated 1.0 ratio heuristic.
|
||||
this.tracer.logEvent(
|
||||
'ContextManager',
|
||||
'Hot Start Token Calibration Failed (Ignored)',
|
||||
{ error },
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,20 @@ export interface ChunkReceivedEvent {
|
||||
targetNodeIds: Set<string>;
|
||||
}
|
||||
|
||||
export interface TokenGroundTruthEvent {
|
||||
actualTokens: number;
|
||||
promptBaseUnits: number;
|
||||
}
|
||||
|
||||
export class ContextEventBus extends EventEmitter {
|
||||
emitTokenGroundTruth(event: TokenGroundTruthEvent) {
|
||||
this.emit('TOKEN_GROUND_TRUTH', event);
|
||||
}
|
||||
|
||||
onTokenGroundTruth(listener: (event: TokenGroundTruthEvent) => void) {
|
||||
this.on('TOKEN_GROUND_TRUTH', listener);
|
||||
}
|
||||
|
||||
emitPristineHistoryUpdated(event: PristineHistoryUpdatedEvent) {
|
||||
this.emit('PRISTINE_HISTORY_UPDATED', event);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import { describe, it, expect, vi } from 'vitest';
|
||||
import { render } from './render.js';
|
||||
import type { ConcreteNode } from './types.js';
|
||||
import { NodeType } from './types.js';
|
||||
import type { AdvancedTokenCalculator } from '../utils/contextTokenCalculator.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import type { ContextTracer } from '../tracer.js';
|
||||
import type { ContextProfile } from '../config/profiles.js';
|
||||
@@ -37,7 +38,20 @@ describe('render', () => {
|
||||
|
||||
const orchestrator = {} as PipelineOrchestrator;
|
||||
const sidecar = { config: {} } as ContextProfile; // No budget
|
||||
const mockAdvancedTokenCalculator = {
|
||||
calculateTokensAndBaseUnits: vi.fn().mockReturnValue({
|
||||
tokens: 100,
|
||||
baseUnits: 100,
|
||||
}),
|
||||
getRawBaseUnits: vi.fn().mockReturnValue(100),
|
||||
getRawBaseUnitsForContent: vi.fn().mockReturnValue(0),
|
||||
};
|
||||
|
||||
const env = {
|
||||
tokenCalculator: {
|
||||
calculateConcreteListTokens: vi.fn().mockReturnValue(100),
|
||||
calculateTokenBreakdown: vi.fn().mockReturnValue({}),
|
||||
},
|
||||
graphMapper: {
|
||||
fromGraph: vi.fn((nodes: readonly ConcreteNode[]) =>
|
||||
nodes.map((n) => ({ text: n.id })),
|
||||
@@ -54,12 +68,14 @@ describe('render', () => {
|
||||
sidecar,
|
||||
tracer,
|
||||
env,
|
||||
mockAdvancedTokenCalculator as unknown as AdvancedTokenCalculator,
|
||||
new Map(),
|
||||
0,
|
||||
undefined,
|
||||
previewNodeIds,
|
||||
);
|
||||
|
||||
expect(result.history).toEqual([{ text: '1' }, { text: '2' }]);
|
||||
expect(result.baseUnits).toBe(100);
|
||||
});
|
||||
|
||||
it('simulates the boundary knapsack problem (loose boundary policy)', async () => {
|
||||
@@ -108,12 +124,24 @@ describe('render', () => {
|
||||
|
||||
const currentTokens = 160000;
|
||||
|
||||
const mockAdvancedTokenCalculator = {
|
||||
calculateTokensAndBaseUnits: vi.fn((nodes: readonly ConcreteNode[]) => {
|
||||
const tokens =
|
||||
nodes.length === 1 ? tokenMap[nodes[0].id] : currentTokens;
|
||||
return { tokens, baseUnits: tokens };
|
||||
}),
|
||||
getRawBaseUnits: vi.fn((nodes: readonly ConcreteNode[]) => {
|
||||
if (nodes.length === 1) return tokenMap[nodes[0].id];
|
||||
return currentTokens;
|
||||
}),
|
||||
};
|
||||
|
||||
const env = {
|
||||
llmClient: {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1000 }),
|
||||
},
|
||||
tokenCalculator: {
|
||||
calculateConcreteListTokens: vi.fn((nodes) => {
|
||||
calculateConcreteListTokens: vi.fn((nodes: readonly ConcreteNode[]) => {
|
||||
if (nodes.length === 1) return tokenMap[nodes[0].id];
|
||||
return currentTokens;
|
||||
}),
|
||||
@@ -136,8 +164,9 @@ describe('render', () => {
|
||||
sidecar,
|
||||
tracer,
|
||||
env,
|
||||
mockAdvancedTokenCalculator as unknown as AdvancedTokenCalculator,
|
||||
new Map(),
|
||||
0,
|
||||
undefined,
|
||||
new Set(),
|
||||
);
|
||||
|
||||
@@ -147,6 +176,7 @@ describe('render', () => {
|
||||
// Adding C pushes rolling total (70k) above retainedTokens (65k).
|
||||
// Under loose policy, C survives. D is strictly older and drops.
|
||||
expect(surviving).toEqual(['C', 'B', 'A']); // D is dropped
|
||||
expect(result.baseUnits).toBe(160000);
|
||||
});
|
||||
|
||||
it('drops nodes that are STRICTLY older than the boundary node', async () => {
|
||||
@@ -188,12 +218,24 @@ describe('render', () => {
|
||||
|
||||
const currentTokens = 160000;
|
||||
|
||||
const mockAdvancedTokenCalculator = {
|
||||
calculateTokensAndBaseUnits: vi.fn((nodes: readonly ConcreteNode[]) => {
|
||||
const tokens =
|
||||
nodes.length === 1 ? tokenMap[nodes[0].id] : currentTokens;
|
||||
return { tokens, baseUnits: tokens };
|
||||
}),
|
||||
getRawBaseUnits: vi.fn((nodes: readonly ConcreteNode[]) => {
|
||||
if (nodes.length === 1) return tokenMap[nodes[0].id];
|
||||
return currentTokens;
|
||||
}),
|
||||
};
|
||||
|
||||
const env = {
|
||||
llmClient: {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1000 }),
|
||||
},
|
||||
tokenCalculator: {
|
||||
calculateConcreteListTokens: vi.fn((nodes) => {
|
||||
calculateConcreteListTokens: vi.fn((nodes: readonly ConcreteNode[]) => {
|
||||
if (nodes.length === 1) return tokenMap[nodes[0].id];
|
||||
return currentTokens;
|
||||
}),
|
||||
@@ -216,8 +258,9 @@ describe('render', () => {
|
||||
sidecar,
|
||||
tracer,
|
||||
env,
|
||||
mockAdvancedTokenCalculator as unknown as AdvancedTokenCalculator,
|
||||
new Map(),
|
||||
0,
|
||||
undefined,
|
||||
new Set(),
|
||||
);
|
||||
|
||||
@@ -225,5 +268,6 @@ describe('render', () => {
|
||||
const surviving = result.history.map((c: any) => c.text);
|
||||
// C(40k), B(40k). Adding B pushes total to 80k. B is the boundary node and survives. A drops.
|
||||
expect(surviving).toEqual(['B', 'C']); // A is dropped
|
||||
expect(result.baseUnits).toBe(160000);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,6 +11,7 @@ import type { ContextProfile } from '../config/profiles.js';
|
||||
import type { PipelineOrchestrator } from '../pipeline/orchestrator.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import { performCalibration } from '../utils/tokenCalibration.js';
|
||||
import type { AdvancedTokenCalculator } from '../utils/contextTokenCalculator.js';
|
||||
|
||||
/**
|
||||
* Maps the Episodic Context Graph back into a raw Gemini Content[] array for transmission.
|
||||
@@ -22,21 +23,43 @@ export async function render(
|
||||
sidecar: ContextProfile,
|
||||
tracer: ContextTracer,
|
||||
env: ContextEnvironment,
|
||||
advancedTokenCalculator: AdvancedTokenCalculator,
|
||||
protectionReasons: Map<string, string> = new Map(),
|
||||
headerTokens: number = 0,
|
||||
header?: Content,
|
||||
previewNodeIds: ReadonlySet<string> = new Set(),
|
||||
): Promise<{ history: Content[]; didApplyManagement: boolean }> {
|
||||
): Promise<{
|
||||
history: Content[];
|
||||
didApplyManagement: boolean;
|
||||
baseUnits: number;
|
||||
}> {
|
||||
let headerTokens = 0;
|
||||
let headerBaseUnits = 0;
|
||||
if (header) {
|
||||
const costs =
|
||||
advancedTokenCalculator.calculateContentTokensAndBaseUnits(header);
|
||||
headerTokens = costs.tokens;
|
||||
headerBaseUnits = costs.baseUnits;
|
||||
}
|
||||
|
||||
if (!sidecar.config.budget) {
|
||||
const visibleNodes = nodes.filter((n) => !previewNodeIds.has(n.id));
|
||||
const contents = env.graphMapper.fromGraph(visibleNodes);
|
||||
tracer.logEvent('Render', 'Render Context to LLM (No Budget)', {
|
||||
renderedContext: contents,
|
||||
});
|
||||
return { history: contents, didApplyManagement: false };
|
||||
|
||||
// In all cases, retrieve raw base units from the token calculator interface
|
||||
const baseUnits =
|
||||
advancedTokenCalculator.getRawBaseUnits(nodes) + headerBaseUnits;
|
||||
|
||||
return { history: contents, didApplyManagement: false, baseUnits };
|
||||
}
|
||||
|
||||
const maxTokens = sidecar.config.budget.maxTokens;
|
||||
const graphTokens = env.tokenCalculator.calculateConcreteListTokens(nodes);
|
||||
|
||||
const { tokens: graphTokens, baseUnits: graphBaseUnits } =
|
||||
advancedTokenCalculator.calculateTokensAndBaseUnits(nodes);
|
||||
|
||||
const currentTokens = graphTokens + headerTokens;
|
||||
|
||||
const protectedIds = new Set(protectionReasons.keys());
|
||||
@@ -70,7 +93,11 @@ export async function render(
|
||||
renderedContext: contents,
|
||||
});
|
||||
performCalibration(env, visibleNodes, contents);
|
||||
return { history: contents, didApplyManagement: false };
|
||||
return {
|
||||
history: contents,
|
||||
didApplyManagement: false,
|
||||
baseUnits: graphBaseUnits + headerBaseUnits,
|
||||
};
|
||||
}
|
||||
const targetDelta = currentTokens - sidecar.config.budget.retainedTokens;
|
||||
tracer.logEvent(
|
||||
@@ -119,5 +146,10 @@ export async function render(
|
||||
renderedContextSanitized: contents,
|
||||
});
|
||||
performCalibration(env, visibleNodes, contents);
|
||||
return { history: contents, didApplyManagement: true };
|
||||
return {
|
||||
history: contents,
|
||||
didApplyManagement: true,
|
||||
baseUnits:
|
||||
advancedTokenCalculator.getRawBaseUnits(visibleNodes) + headerBaseUnits,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -23,6 +23,9 @@ import { StateSnapshotProcessorOptionsSchema } from './processors/stateSnapshotP
|
||||
import { StateSnapshotAsyncProcessorOptionsSchema } from './processors/stateSnapshotAsyncProcessor.js';
|
||||
import { RollingSummaryProcessorOptionsSchema } from './processors/rollingSummaryProcessor.js';
|
||||
import { getEnvironmentContext } from '../utils/environmentContext.js';
|
||||
import { AdaptiveTokenCalculator } from './utils/adaptiveTokenCalculator.js';
|
||||
import { NodeBehaviorRegistry } from './graph/behaviorRegistry.js';
|
||||
import { registerBuiltInBehaviors } from './graph/builtinBehaviors.js';
|
||||
|
||||
export async function initializeContextManager(
|
||||
config: Config,
|
||||
@@ -85,6 +88,16 @@ export async function initializeContextManager(
|
||||
|
||||
const eventBus = new ContextEventBus();
|
||||
|
||||
const charsPerToken = 3;
|
||||
const behaviorRegistry = new NodeBehaviorRegistry();
|
||||
registerBuiltInBehaviors(behaviorRegistry);
|
||||
|
||||
const calculator = new AdaptiveTokenCalculator(
|
||||
charsPerToken,
|
||||
behaviorRegistry,
|
||||
eventBus,
|
||||
);
|
||||
|
||||
const env = new ContextEnvironmentImpl(
|
||||
() => config.getBaseLlmClient(),
|
||||
config.getSessionId(),
|
||||
@@ -92,8 +105,10 @@ export async function initializeContextManager(
|
||||
logDir,
|
||||
projectTempDir,
|
||||
tracer,
|
||||
4,
|
||||
charsPerToken,
|
||||
eventBus,
|
||||
calculator,
|
||||
behaviorRegistry,
|
||||
{
|
||||
calibrateTokenCalculation:
|
||||
!!process.env['GEMINI_CONTEXT_CALIBRATE_TOKEN_CALCULATIONS'],
|
||||
@@ -114,6 +129,7 @@ export async function initializeContextManager(
|
||||
tracer,
|
||||
orchestrator,
|
||||
chat.agentHistory,
|
||||
calculator,
|
||||
async () => {
|
||||
const parts = await getEnvironmentContext(config);
|
||||
return { role: 'user', parts };
|
||||
|
||||
@@ -8,12 +8,16 @@ import { ContextEnvironmentImpl } from './environmentImpl.js';
|
||||
import { ContextTracer } from '../tracer.js';
|
||||
import { ContextEventBus } from '../eventBus.js';
|
||||
import { createMockLlmClient } from '../testing/contextTestUtils.js';
|
||||
import { StaticTokenCalculator } from '../utils/contextTokenCalculator.js';
|
||||
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
|
||||
describe('ContextEnvironmentImpl', () => {
|
||||
it('should initialize with defaults correctly', () => {
|
||||
const tracer = new ContextTracer({ targetDir: '/tmp', sessionId: 'mock' });
|
||||
const eventBus = new ContextEventBus();
|
||||
const mockLlmClient = createMockLlmClient();
|
||||
const behaviorRegistry = new NodeBehaviorRegistry();
|
||||
const calculator = new StaticTokenCalculator(4, behaviorRegistry);
|
||||
|
||||
const env = new ContextEnvironmentImpl(
|
||||
() => mockLlmClient,
|
||||
@@ -24,6 +28,8 @@ describe('ContextEnvironmentImpl', () => {
|
||||
tracer,
|
||||
4,
|
||||
eventBus,
|
||||
calculator,
|
||||
behaviorRegistry,
|
||||
);
|
||||
|
||||
expect(env.llmClient).toBe(mockLlmClient);
|
||||
|
||||
@@ -8,16 +8,13 @@ import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type { ContextTracer } from '../tracer.js';
|
||||
import type { ContextEnvironment, RenderOptions } from './environment.js';
|
||||
import type { ContextEventBus } from '../eventBus.js';
|
||||
import { ContextTokenCalculator } from '../utils/contextTokenCalculator.js';
|
||||
import type { ContextTokenCalculator } from '../utils/contextTokenCalculator.js';
|
||||
import { LiveInbox } from './inbox.js';
|
||||
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
|
||||
import type { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
import { ContextGraphMapper } from '../graph/mapper.js';
|
||||
|
||||
export class ContextEnvironmentImpl implements ContextEnvironment {
|
||||
readonly tokenCalculator: ContextTokenCalculator;
|
||||
readonly inbox: LiveInbox;
|
||||
readonly behaviorRegistry: NodeBehaviorRegistry;
|
||||
readonly graphMapper: ContextGraphMapper;
|
||||
|
||||
constructor(
|
||||
@@ -29,14 +26,10 @@ export class ContextEnvironmentImpl implements ContextEnvironment {
|
||||
readonly tracer: ContextTracer,
|
||||
readonly charsPerToken: number,
|
||||
readonly eventBus: ContextEventBus,
|
||||
readonly tokenCalculator: ContextTokenCalculator,
|
||||
readonly behaviorRegistry: NodeBehaviorRegistry,
|
||||
readonly renderOptions?: RenderOptions,
|
||||
) {
|
||||
this.behaviorRegistry = new NodeBehaviorRegistry();
|
||||
registerBuiltInBehaviors(this.behaviorRegistry);
|
||||
this.tokenCalculator = new ContextTokenCalculator(
|
||||
this.charsPerToken,
|
||||
this.behaviorRegistry,
|
||||
);
|
||||
this.inbox = new LiveInbox();
|
||||
this.graphMapper = new ContextGraphMapper();
|
||||
}
|
||||
|
||||
@@ -66,12 +66,12 @@ describe('BlobDegradationProcessor', () => {
|
||||
|
||||
const node1 = createDummyNode('ep1', NodeType.USER_PROMPT, 100, {
|
||||
payload: {
|
||||
fileData: { mimeType: 'video/mp4', fileUri: 'gs://test1' },
|
||||
fileData: { mimeType: 'image/png', fileUri: 'gs://test1' },
|
||||
},
|
||||
});
|
||||
const node2 = createDummyNode('ep1', NodeType.USER_PROMPT, 100, {
|
||||
payload: {
|
||||
fileData: { mimeType: 'video/mp4', fileUri: 'gs://test2' },
|
||||
fileData: { mimeType: 'image/png', fileUri: 'gs://test2' },
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ describe('StateSnapshotAsyncProcessor', () => {
|
||||
'PROPOSED_SNAPSHOT',
|
||||
expect.objectContaining({
|
||||
newText:
|
||||
'{"active_tasks":[],"discovered_facts":[],"constraints_and_preferences":[],"recent_arc":[]}',
|
||||
'{"active_tasks":[],"discovered_facts":[],"constraints_and_preferences":[],"recent_arc":["Mock LLM summary response"]}',
|
||||
consumedIds: ['node-A', 'node-B'],
|
||||
type: 'point-in-time',
|
||||
}),
|
||||
@@ -107,7 +107,7 @@ describe('StateSnapshotAsyncProcessor', () => {
|
||||
'PROPOSED_SNAPSHOT',
|
||||
expect.objectContaining({
|
||||
newText:
|
||||
'{"active_tasks":[],"discovered_facts":[],"constraints_and_preferences":[],"recent_arc":[]}',
|
||||
'{"active_tasks":[],"discovered_facts":[],"constraints_and_preferences":[],"recent_arc":["Mock LLM summary response"]}',
|
||||
consumedIds: ['node-A', 'node-B', 'node-C'], // Aggregated!
|
||||
type: 'accumulate',
|
||||
}),
|
||||
|
||||
+36
-64
File diff suppressed because one or more lines are too long
@@ -9,6 +9,7 @@ import { SimulationHarness } from './simulationHarness.js';
|
||||
import { createMockLlmClient } from '../testing/contextTestUtils.js';
|
||||
import type { ContextProfile } from '../config/profiles.js';
|
||||
import { generalistProfile } from '../config/profiles.js';
|
||||
import type { Content } from '@google/genai';
|
||||
|
||||
describe('Context Manager Hysteresis Tests', () => {
|
||||
const mockLlmClient = createMockLlmClient(['<SNAPSHOT>']);
|
||||
@@ -25,6 +26,12 @@ describe('Context Manager Hysteresis Tests', () => {
|
||||
},
|
||||
});
|
||||
|
||||
const getProjectionTokens = (proj: Content[], harness: SimulationHarness) =>
|
||||
proj.reduce(
|
||||
(sum, c) => sum + harness.env.tokenCalculator.calculateContentTokens(c),
|
||||
0,
|
||||
);
|
||||
|
||||
it('should block consolidation when deficit is below coalescing threshold', async () => {
|
||||
const threshold = 1500;
|
||||
const harness = await SimulationHarness.create(
|
||||
@@ -35,14 +42,14 @@ describe('Context Manager Hysteresis Tests', () => {
|
||||
// Turn 0: INIT
|
||||
await harness.simulateTurn([{ role: 'user', parts: [{ text: 'INIT' }] }]);
|
||||
|
||||
// Turn 1: Add 1500 chars (~500 tokens). Total ~500. Under retained (1000).
|
||||
// Turn 1: Add ~500 tokens
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'A'.repeat(1500) }] },
|
||||
{ role: 'user', parts: [{ text: 'A'.repeat(500) }] },
|
||||
]);
|
||||
|
||||
// Turn 2: Add 3000 chars (~1000 tokens). Total ~1500. Deficit ~500 < 1500.
|
||||
// Turn 2: Add ~1000 tokens. Total ~1500. Deficit ~500 < 1500.
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'B'.repeat(3000) }] },
|
||||
{ role: 'user', parts: [{ text: 'B'.repeat(1000) }] },
|
||||
]);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
@@ -54,19 +61,19 @@ describe('Context Manager Hysteresis Tests', () => {
|
||||
),
|
||||
).toBe(false);
|
||||
|
||||
// Turn 3: Add 9000 chars (~3000 tokens). Total ~4500.
|
||||
// Turn 3: Add ~3000 tokens. Total ~4500.
|
||||
// Deficit ~3500 > 1500. TRIGGER!
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'C'.repeat(9000) }] },
|
||||
{ role: 'user', parts: [{ text: 'C'.repeat(3000) }] },
|
||||
]);
|
||||
|
||||
// Give it a moment for the async task to finish
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
|
||||
// Exceed maxTokens to force a render that shows the snapshot
|
||||
// Add 3000 more tokens (9000 chars). Total ~7500 > 5000.
|
||||
// Add ~3000 tokens. Total ~7500 > 5000.
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'D'.repeat(9000) }] },
|
||||
{ role: 'user', parts: [{ text: 'D'.repeat(3000) }] },
|
||||
]);
|
||||
|
||||
state = await harness.getGoldenState();
|
||||
@@ -85,57 +92,51 @@ describe('Context Manager Hysteresis Tests', () => {
|
||||
);
|
||||
|
||||
// 1. Trigger first consolidation
|
||||
// Add ~9000 chars (~3000 tokens). Total ~3000. Deficit ~2000 > 1000.
|
||||
// Add ~3000 tokens. Total ~3000. Deficit ~2000 > 1000.
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'A'.repeat(9000) }] },
|
||||
{ role: 'user', parts: [{ text: 'A'.repeat(3000) }] },
|
||||
]);
|
||||
await harness.simulateTurn([{ role: 'user', parts: [{ text: 'B' }] }]); // Make eligible
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
// Exceed maxTokens (5000) to see it
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'X'.repeat(9000) }] },
|
||||
{ role: 'user', parts: [{ text: 'X'.repeat(3000) }] },
|
||||
]);
|
||||
|
||||
const state = await harness.getGoldenState();
|
||||
// Get baseline tokens
|
||||
let state = await harness.getGoldenState();
|
||||
expect(
|
||||
state.finalProjection.some((c) =>
|
||||
c.parts?.some((p) => p.text?.includes('<SNAPSHOT>')),
|
||||
),
|
||||
).toBe(true);
|
||||
|
||||
// Get baseline tokens
|
||||
const baselineTokens =
|
||||
harness.env.tokenCalculator.calculateConcreteListTokens(
|
||||
harness.contextManager.getNodes(),
|
||||
);
|
||||
const baselineTokens = getProjectionTokens(state.finalProjection, harness);
|
||||
|
||||
// 2. Add nodes again, staying below threshold growth
|
||||
// Add 1500 chars (~500 tokens). Growth ~500 < 1000.
|
||||
// Add ~500 tokens. Growth ~500 < 1000.
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'C'.repeat(1500) }] },
|
||||
{ role: 'user', parts: [{ text: 'C'.repeat(500) }] },
|
||||
]);
|
||||
await harness.simulateTurn([{ role: 'user', parts: [{ text: 'D' }] }]); // Make eligible
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||
const currentTokens =
|
||||
harness.env.tokenCalculator.calculateConcreteListTokens(
|
||||
harness.contextManager.getNodes(),
|
||||
);
|
||||
state = await harness.getGoldenState();
|
||||
const currentTokens = getProjectionTokens(state.finalProjection, harness);
|
||||
// Should not have shrunk further (except for D's small addition)
|
||||
expect(currentTokens).toBeGreaterThanOrEqual(baselineTokens);
|
||||
|
||||
// 3. Exceed threshold growth
|
||||
// Add 6000 chars (~2000 tokens). Growth = ~500 + ~2000 = ~2500 > 1000.
|
||||
// Add ~2000 tokens. Growth = ~500 + ~2000 = ~2500 > 1000.
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'E'.repeat(6000) }] },
|
||||
{ role: 'user', parts: [{ text: 'E'.repeat(2000) }] },
|
||||
]);
|
||||
await harness.simulateTurn([{ role: 'user', parts: [{ text: 'F' }] }]); // Make eligible
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
const finalTokens = harness.env.tokenCalculator.calculateConcreteListTokens(
|
||||
harness.contextManager.getNodes(),
|
||||
);
|
||||
state = await harness.getGoldenState();
|
||||
const finalTokens = getProjectionTokens(state.finalProjection, harness);
|
||||
// Now it should have consolidated again (E should be replaced by a snapshot eventually)
|
||||
expect(finalTokens).toBeLessThan(currentTokens + 2000);
|
||||
});
|
||||
|
||||
@@ -13,6 +13,9 @@ import { ContextTracer } from '../tracer.js';
|
||||
import { ContextEventBus } from '../eventBus.js';
|
||||
import { PipelineOrchestrator } from '../pipeline/orchestrator.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import { StaticTokenCalculator } from '../utils/contextTokenCalculator.js';
|
||||
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
|
||||
|
||||
export interface TurnSummary {
|
||||
turnIndex: number;
|
||||
@@ -57,6 +60,11 @@ export class SimulationHarness {
|
||||
targetDir: mockTempDir,
|
||||
sessionId: 'sim-session',
|
||||
});
|
||||
|
||||
const behaviorRegistry = new NodeBehaviorRegistry();
|
||||
registerBuiltInBehaviors(behaviorRegistry);
|
||||
const calculator = new StaticTokenCalculator(1, behaviorRegistry);
|
||||
|
||||
this.env = new ContextEnvironmentImpl(
|
||||
() => mockLlmClient,
|
||||
'sim-prompt',
|
||||
@@ -66,6 +74,8 @@ export class SimulationHarness {
|
||||
this.tracer,
|
||||
1, // 1 char per token average for estimation (but estimator uses 0.33)
|
||||
this.eventBus,
|
||||
calculator,
|
||||
behaviorRegistry,
|
||||
);
|
||||
|
||||
this.orchestrator = new PipelineOrchestrator(
|
||||
@@ -81,6 +91,7 @@ export class SimulationHarness {
|
||||
this.tracer,
|
||||
this.orchestrator,
|
||||
this.chatHistory,
|
||||
calculator,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -111,11 +122,12 @@ export class SimulationHarness {
|
||||
}
|
||||
|
||||
async getGoldenState() {
|
||||
const { history: finalProjection } =
|
||||
const { history: finalProjection, baseUnits } =
|
||||
await this.contextManager.renderHistory();
|
||||
return {
|
||||
tokenTrajectory: this.tokenTrajectory,
|
||||
finalProjection,
|
||||
baseUnits,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,9 @@ import type { ContextProfile } from '../config/profiles.js';
|
||||
import type { Mock } from 'vitest';
|
||||
import { ContextWorkingBufferImpl } from '../pipeline/contextWorkingBuffer.js';
|
||||
import { testTruncateProfile } from './testProfile.js';
|
||||
import { StaticTokenCalculator } from '../utils/contextTokenCalculator.js';
|
||||
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
|
||||
|
||||
/**
|
||||
* Creates a valid mock GenerateContentResponse with the provided text.
|
||||
@@ -134,17 +137,30 @@ export function createMockLlmClient(
|
||||
);
|
||||
});
|
||||
|
||||
const generateJsonMock = vi.fn().mockImplementation(async () => ({
|
||||
active_tasks: [],
|
||||
discovered_facts: [],
|
||||
constraints_and_preferences: [],
|
||||
recent_arc: [],
|
||||
}));
|
||||
const generateJsonMock = vi.fn().mockImplementation(async () => {
|
||||
let mockStr = '';
|
||||
if (responses && responses.length > 0) {
|
||||
const callCount = generateJsonMock.mock.calls.length - 1;
|
||||
const idx =
|
||||
callCount < responses.length ? callCount : responses.length - 1;
|
||||
const res = responses[idx];
|
||||
if (typeof res === 'string') {
|
||||
mockStr = res;
|
||||
}
|
||||
}
|
||||
return {
|
||||
active_tasks: [],
|
||||
discovered_facts: [],
|
||||
constraints_and_preferences: [],
|
||||
chronological_summary: mockStr,
|
||||
};
|
||||
});
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
return {
|
||||
generateContent: generateContentMock,
|
||||
generateJson: generateJsonMock,
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
|
||||
} as unknown as MockLlmClient;
|
||||
}
|
||||
|
||||
@@ -158,6 +174,9 @@ export function createMockEnvironment(
|
||||
sessionId: 'mock-session',
|
||||
});
|
||||
const eventBus = new ContextEventBus();
|
||||
const behaviorRegistry = new NodeBehaviorRegistry();
|
||||
registerBuiltInBehaviors(behaviorRegistry);
|
||||
const calculator = new StaticTokenCalculator(1, behaviorRegistry);
|
||||
|
||||
let env = new ContextEnvironmentImpl(
|
||||
() => llmClient as BaseLlmClient,
|
||||
@@ -168,6 +187,8 @@ export function createMockEnvironment(
|
||||
tracer,
|
||||
1,
|
||||
eventBus,
|
||||
calculator,
|
||||
behaviorRegistry,
|
||||
);
|
||||
|
||||
if (overrides) {
|
||||
@@ -181,6 +202,8 @@ export function createMockEnvironment(
|
||||
env.tracer,
|
||||
env.charsPerToken,
|
||||
env.eventBus,
|
||||
calculator,
|
||||
behaviorRegistry,
|
||||
);
|
||||
}
|
||||
const { llmClient: _llmClient, ...restOverrides } = overrides;
|
||||
@@ -273,6 +296,10 @@ export function setupContextComponentTest(
|
||||
sessionId: 'test-session',
|
||||
});
|
||||
const eventBus = new ContextEventBus();
|
||||
const behaviorRegistry = new NodeBehaviorRegistry();
|
||||
registerBuiltInBehaviors(behaviorRegistry);
|
||||
const calculator = new StaticTokenCalculator(1, behaviorRegistry);
|
||||
|
||||
const env = new ContextEnvironmentImpl(
|
||||
() => config.getBaseLlmClient(),
|
||||
'test prompt-id',
|
||||
@@ -282,6 +309,8 @@ export function setupContextComponentTest(
|
||||
tracer,
|
||||
1,
|
||||
eventBus,
|
||||
calculator,
|
||||
behaviorRegistry,
|
||||
);
|
||||
|
||||
const orchestrator = new PipelineOrchestrator(
|
||||
@@ -298,8 +327,8 @@ export function setupContextComponentTest(
|
||||
tracer,
|
||||
orchestrator,
|
||||
chatHistory,
|
||||
calculator,
|
||||
);
|
||||
|
||||
// The async pipeline is now internally managed by ContextManager
|
||||
return { chatHistory, contextManager };
|
||||
}
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { AdaptiveTokenCalculator } from './adaptiveTokenCalculator.js';
|
||||
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
|
||||
import { ContextEventBus } from '../eventBus.js';
|
||||
import { createDummyNode } from '../testing/contextTestUtils.js';
|
||||
import { NodeType } from '../graph/types.js';
|
||||
|
||||
describe('AdaptiveTokenCalculator', () => {
|
||||
const registry = new NodeBehaviorRegistry();
|
||||
registerBuiltInBehaviors(registry);
|
||||
const charsPerToken = 1; // Simplifies math
|
||||
|
||||
it('should initialize with a learned weight of 1.0', () => {
|
||||
const eventBus = new ContextEventBus();
|
||||
const calculator = new AdaptiveTokenCalculator(
|
||||
charsPerToken,
|
||||
registry,
|
||||
eventBus,
|
||||
);
|
||||
expect(calculator.getLearnedWeight()).toBe(1.0);
|
||||
});
|
||||
|
||||
it('should dynamically update learned weight based on token ground truth events', () => {
|
||||
const eventBus = new ContextEventBus();
|
||||
const calculator = new AdaptiveTokenCalculator(
|
||||
charsPerToken,
|
||||
registry,
|
||||
eventBus,
|
||||
);
|
||||
|
||||
// Initial state: weight = 1.0
|
||||
|
||||
// Simulate an event where the API reported fewer tokens than our base units
|
||||
// targetWeight = 50 / 100 = 0.5
|
||||
// newWeight = 1.0 * 0.8 + 0.5 * 0.2 = 0.8 + 0.1 = 0.9
|
||||
eventBus.emitTokenGroundTruth({
|
||||
actualTokens: 50,
|
||||
promptBaseUnits: 100,
|
||||
});
|
||||
|
||||
// JavaScript floating point precision means we should use toBeCloseTo
|
||||
expect(calculator.getLearnedWeight()).toBeCloseTo(0.9, 5);
|
||||
|
||||
// Simulate another event
|
||||
// newWeight = 0.9 * 0.8 + (150 / 100) * 0.2 = 0.72 + 0.3 = 1.02
|
||||
eventBus.emitTokenGroundTruth({
|
||||
actualTokens: 150,
|
||||
promptBaseUnits: 100,
|
||||
});
|
||||
|
||||
expect(calculator.getLearnedWeight()).toBeCloseTo(1.02, 5);
|
||||
});
|
||||
|
||||
it('should clamp the learned weight between 0.5 and 2.0', () => {
|
||||
const eventBus = new ContextEventBus();
|
||||
const calculator = new AdaptiveTokenCalculator(
|
||||
charsPerToken,
|
||||
registry,
|
||||
eventBus,
|
||||
);
|
||||
|
||||
// Push weight up extremely high (API returns 10x tokens)
|
||||
for (let i = 0; i < 20; i++) {
|
||||
eventBus.emitTokenGroundTruth({
|
||||
actualTokens: 1000,
|
||||
promptBaseUnits: 100,
|
||||
});
|
||||
}
|
||||
expect(calculator.getLearnedWeight()).toBe(2.0);
|
||||
|
||||
// Push weight down extremely low (API returns 0 tokens)
|
||||
for (let i = 0; i < 20; i++) {
|
||||
eventBus.emitTokenGroundTruth({ actualTokens: 0, promptBaseUnits: 100 });
|
||||
}
|
||||
expect(calculator.getLearnedWeight()).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should correctly apply the learned weight to node calculations while keeping raw base units stable', () => {
|
||||
const eventBus = new ContextEventBus();
|
||||
const calculator = new AdaptiveTokenCalculator(
|
||||
charsPerToken,
|
||||
registry,
|
||||
eventBus,
|
||||
);
|
||||
|
||||
// Decrease the weight to exactly 0.5
|
||||
for (let i = 0; i < 20; i++) {
|
||||
eventBus.emitTokenGroundTruth({ actualTokens: 0, promptBaseUnits: 100 });
|
||||
}
|
||||
|
||||
const turn1Id = 'turn-1';
|
||||
const node1 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
|
||||
|
||||
// Get raw base units directly
|
||||
const rawTokens = calculator.calculateTokensAndBaseUnits([node1]).baseUnits;
|
||||
|
||||
// Get adjusted tokens
|
||||
const adjustedTokens = calculator.calculateConcreteListTokens([node1]);
|
||||
|
||||
expect(adjustedTokens).toBe(Math.round(rawTokens * 0.5));
|
||||
});
|
||||
|
||||
it('should ignore ground truth events with 0 promptBaseUnits to prevent division by zero', () => {
|
||||
const eventBus = new ContextEventBus();
|
||||
const calculator = new AdaptiveTokenCalculator(
|
||||
charsPerToken,
|
||||
registry,
|
||||
eventBus,
|
||||
);
|
||||
|
||||
eventBus.emitTokenGroundTruth({
|
||||
actualTokens: 100,
|
||||
promptBaseUnits: 0,
|
||||
});
|
||||
|
||||
expect(calculator.getLearnedWeight()).toBe(1.0);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,163 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Content, Part } from '@google/genai';
|
||||
import type { ConcreteNode } from '../graph/types.js';
|
||||
import {
|
||||
StaticTokenCalculator,
|
||||
type AdvancedTokenCalculator,
|
||||
} from './contextTokenCalculator.js';
|
||||
import type { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
import type { ContextEventBus, TokenGroundTruthEvent } from '../eventBus.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* An Adaptive Token Calculator that dynamically learns the true token cost of the user's
|
||||
* conversation by applying an Exponential Moving Average (EMA) gradient descent to
|
||||
* real usage metadata returned from the Gemini API.
|
||||
*
|
||||
* It wraps the deterministic `StaticTokenCalculator` base heuristic to ensure
|
||||
* immutable node cost caching while still surfacing a self-corrected estimate
|
||||
* to the pipeline processors.
|
||||
*/
|
||||
export class AdaptiveTokenCalculator implements AdvancedTokenCalculator {
|
||||
private learnedWeight = 1.0;
|
||||
private readonly baseCalculator: StaticTokenCalculator;
|
||||
|
||||
constructor(
|
||||
charsPerToken: number,
|
||||
registry: NodeBehaviorRegistry,
|
||||
eventBus: ContextEventBus,
|
||||
) {
|
||||
this.baseCalculator = new StaticTokenCalculator(charsPerToken, registry);
|
||||
eventBus.onTokenGroundTruth((event: TokenGroundTruthEvent) => {
|
||||
this.handleGroundTruth(event.actualTokens, event.promptBaseUnits);
|
||||
});
|
||||
}
|
||||
|
||||
private handleGroundTruth(actualTokens: number, promptBaseUnits: number) {
|
||||
if (promptBaseUnits <= 0) return;
|
||||
|
||||
// Determine what ratio we should have used
|
||||
const targetWeight = actualTokens / promptBaseUnits;
|
||||
const oldWeight = this.learnedWeight;
|
||||
|
||||
// Apply Momentum (Learning Rate)
|
||||
const learningRate = 0.2;
|
||||
const newWeight =
|
||||
oldWeight * (1 - learningRate) + targetWeight * learningRate;
|
||||
|
||||
// Clamp to reasonable safety bounds to prevent rogue metadata poisoning the system
|
||||
this.learnedWeight = Math.max(0.5, Math.min(newWeight, 2.0));
|
||||
|
||||
debugLogger.log(
|
||||
`[AdaptiveTokenCalculator] Learned weight updated to ${this.learnedWeight.toFixed(3)} ` +
|
||||
`(API Tokens: ${actualTokens}, Base Units: ${promptBaseUnits}, Target Ratio: ${targetWeight.toFixed(3)})`,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the current learned weight multiplier.
|
||||
*/
|
||||
getLearnedWeight(): number {
|
||||
return this.learnedWeight;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the exact, unweighted Base Heuristic Units for the graph.
|
||||
* This is used exactly once per interaction to capture the baseline sent to the API.
|
||||
*/
|
||||
getRawBaseUnits(nodes: readonly ConcreteNode[]): number {
|
||||
return this.baseCalculator.calculateConcreteListTokens(nodes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the exact, unweighted Base Heuristic Units for a raw content chunk.
|
||||
*/
|
||||
getRawBaseUnitsForContent(content: Content): number {
|
||||
return this.baseCalculator.calculateContentTokens(content);
|
||||
}
|
||||
|
||||
calculateTokensAndBaseUnits(nodes: readonly ConcreteNode[]): {
|
||||
tokens: number;
|
||||
baseUnits: number;
|
||||
} {
|
||||
const baseUnits = this.baseCalculator.calculateConcreteListTokens(nodes);
|
||||
return {
|
||||
tokens: Math.round(baseUnits * this.learnedWeight),
|
||||
baseUnits,
|
||||
};
|
||||
}
|
||||
|
||||
calculateContentTokensAndBaseUnits(content: Content): {
|
||||
tokens: number;
|
||||
baseUnits: number;
|
||||
} {
|
||||
const baseUnits = this.baseCalculator.calculateContentTokens(content);
|
||||
return {
|
||||
tokens: Math.round(baseUnits * this.learnedWeight),
|
||||
baseUnits,
|
||||
};
|
||||
}
|
||||
|
||||
// --- Delegation and Weighting ---
|
||||
|
||||
garbageCollectCache(liveNodeIds: ReadonlySet<string>): void {
|
||||
this.baseCalculator.garbageCollectCache(liveNodeIds);
|
||||
}
|
||||
|
||||
cacheNodeTokens(node: ConcreteNode): number {
|
||||
return this.baseCalculator.cacheNodeTokens(node);
|
||||
}
|
||||
|
||||
calculateTokenBreakdown(nodes: readonly ConcreteNode[]): {
|
||||
text: number;
|
||||
media: number;
|
||||
tool: number;
|
||||
overhead: number;
|
||||
total: number;
|
||||
} {
|
||||
const raw = this.baseCalculator.calculateTokenBreakdown(nodes);
|
||||
return {
|
||||
text: Math.round(raw.text * this.learnedWeight),
|
||||
media: Math.round(raw.media * this.learnedWeight),
|
||||
tool: Math.round(raw.tool * this.learnedWeight),
|
||||
overhead: Math.round(raw.overhead * this.learnedWeight),
|
||||
total: Math.round(raw.total * this.learnedWeight),
|
||||
};
|
||||
}
|
||||
|
||||
estimateTokensForParts(parts: Part[]): number {
|
||||
const baseUnits = this.baseCalculator.estimateTokensForParts(parts);
|
||||
return Math.round(baseUnits * this.learnedWeight);
|
||||
}
|
||||
|
||||
getTokenCost(node: ConcreteNode): number {
|
||||
const baseUnits = this.baseCalculator.getTokenCost(node);
|
||||
return Math.round(baseUnits * this.learnedWeight);
|
||||
}
|
||||
|
||||
calculateConcreteListTokens(nodes: readonly ConcreteNode[]): number {
|
||||
const baseUnits = this.baseCalculator.calculateConcreteListTokens(nodes);
|
||||
return Math.round(baseUnits * this.learnedWeight);
|
||||
}
|
||||
|
||||
calculateContentTokens(content: Content): number {
|
||||
const baseUnits = this.baseCalculator.calculateContentTokens(content);
|
||||
return Math.round(baseUnits * this.learnedWeight);
|
||||
}
|
||||
|
||||
estimateTokensForString(text: string): number {
|
||||
const baseUnits = this.baseCalculator.estimateTokensForString(text);
|
||||
return Math.round(baseUnits * this.learnedWeight);
|
||||
}
|
||||
|
||||
tokensToChars(tokens: number): number {
|
||||
// If weight is > 1.0 (we are inflating tokens), a single returned token is worth fewer chars.
|
||||
// We reverse the math: convert requested tokens to target base units, then get chars.
|
||||
return this.baseCalculator.tokensToChars(tokens / this.learnedWeight);
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { ContextTokenCalculator } from './contextTokenCalculator.js';
|
||||
import { StaticTokenCalculator } from './contextTokenCalculator.js';
|
||||
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
|
||||
import { createDummyNode } from '../testing/contextTestUtils.js';
|
||||
@@ -16,7 +16,7 @@ describe('ContextTokenCalculator', () => {
|
||||
const registry = new NodeBehaviorRegistry();
|
||||
registerBuiltInBehaviors(registry);
|
||||
const charsPerToken = 1; // Simplifies math for text nodes in tests
|
||||
const calculator = new ContextTokenCalculator(charsPerToken, registry);
|
||||
const calculator = new StaticTokenCalculator(charsPerToken, registry);
|
||||
|
||||
it('should include structural overhead for each unique turn', () => {
|
||||
const turn1Id = 'turn-1';
|
||||
@@ -28,16 +28,16 @@ describe('ContextTokenCalculator', () => {
|
||||
|
||||
const nodes = [node1, node2, node3];
|
||||
|
||||
// Estimated tokens (using 0.33 per ASCII char heuristic):
|
||||
// node1: floor(17 chars * 0.33) = 5 tokens
|
||||
// node2: floor(17 chars * 0.33) = 5 tokens
|
||||
// node3: floor(19 chars * 0.33) = 6 tokens
|
||||
// Estimated tokens (using charsPerToken = 1):
|
||||
// node1: 17 chars / 1 = 17 tokens
|
||||
// node2: 17 chars / 1 = 17 tokens
|
||||
// node3: 19 chars / 1 = 19 tokens
|
||||
// Turn 1 overhead: 5 tokens
|
||||
// Turn 2 overhead: 5 tokens
|
||||
// Total: 5 + 5 + 6 + 5 + 5 = 26
|
||||
// Total: 17 + 17 + 19 + 5 + 5 = 63
|
||||
|
||||
const total = calculator.calculateConcreteListTokens(nodes);
|
||||
expect(total).toBe(26);
|
||||
expect(total).toBe(63);
|
||||
});
|
||||
|
||||
it('should handle categorical breakdown with overhead', () => {
|
||||
|
||||
@@ -17,7 +17,41 @@ import type { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
* by the Gemini API. We use this as a baseline heuristic for inlineData/fileData.
|
||||
*/
|
||||
|
||||
export class ContextTokenCalculator {
|
||||
/**
 * Token-estimation contract consumed by the context pipeline.
 * Implementations convert nodes, contents, parts, and plain strings into
 * token counts and maintain a per-node cost cache.
 */
export interface ContextTokenCalculator {
  /** Token estimate for a plain string. */
  estimateTokensForString(text: string): number;
  /** Inverse conversion: how many characters correspond to `tokens`. */
  tokensToChars(tokens: number): number;
  /** Drops cached costs for nodes that are no longer live in the graph. */
  garbageCollectCache(liveNodeIds: ReadonlySet<string>): void;
  /** Computes (and caches) the token cost of a single node. */
  cacheNodeTokens(node: ConcreteNode): number;
  /** Token cost of a single node. */
  getTokenCost(node: ConcreteNode): number;
  /** Per-category token totals (text / media / tool / structural overhead). */
  calculateTokenBreakdown(nodes: readonly ConcreteNode[]): {
    text: number;
    media: number;
    tool: number;
    overhead: number;
    total: number;
  };
  /** Total token estimate for a flat list of concrete nodes. */
  calculateConcreteListTokens(nodes: readonly ConcreteNode[]): number;
  /** Token estimate for a single raw Content entry. */
  calculateContentTokens(content: Content): number;
  /** Token estimate for an array of Parts. */
  estimateTokensForParts(parts: Part[]): number;
}
|
||||
|
||||
/**
 * Extended contract that additionally exposes the raw, unweighted "base
 * units" alongside the (possibly weighted) token estimates, so callers can
 * record the exact baseline that was sent to the API.
 */
export interface AdvancedTokenCalculator extends ContextTokenCalculator {
  /** Unweighted base units for a node list. */
  getRawBaseUnits(nodes: readonly ConcreteNode[]): number;
  /** Unweighted base units for a single Content entry. */
  getRawBaseUnitsForContent(content: Content): number;
  /** Token estimate paired with the stable raw base units. */
  calculateTokensAndBaseUnits(nodes: readonly ConcreteNode[]): {
    tokens: number;
    baseUnits: number;
  };
  /** Content-level variant of {@link calculateTokensAndBaseUnits}. */
  calculateContentTokensAndBaseUnits(content: Content): {
    tokens: number;
    baseUnits: number;
  };
}
|
||||
|
||||
/**
|
||||
* A fast, deterministic token heuristic calculator.
|
||||
*/
|
||||
export class StaticTokenCalculator implements AdvancedTokenCalculator {
|
||||
private readonly tokenCache = new Map<string, number>();
|
||||
|
||||
constructor(
|
||||
@@ -143,6 +177,34 @@ export class ContextTokenCalculator {
|
||||
return breakdown;
|
||||
}
|
||||
|
||||
/**
 * For the static calculator, Raw Base Units are exactly the same as the final tokens,
 * because there is no dynamic learned weight (the multiplier is effectively 1.0).
 */
getRawBaseUnits(nodes: readonly ConcreteNode[]): number {
  return this.calculateConcreteListTokens(nodes);
}
|
||||
|
||||
/** Content-level counterpart: static tokens already equal the base units. */
getRawBaseUnitsForContent(content: Content): number {
  return this.calculateContentTokens(content);
}
|
||||
|
||||
/**
 * Weighted/raw pair for a node list. The static heuristic applies no
 * learned weight, so both fields carry the same value.
 */
calculateTokensAndBaseUnits(nodes: readonly ConcreteNode[]): {
  tokens: number;
  baseUnits: number;
} {
  const units = this.calculateConcreteListTokens(nodes);
  return { tokens: units, baseUnits: units };
}
|
||||
|
||||
/**
 * Weighted/raw pair for one Content entry; the values are identical since
 * the static multiplier is effectively 1.0.
 */
calculateContentTokensAndBaseUnits(content: Content): {
  tokens: number;
  baseUnits: number;
} {
  const units = this.calculateContentTokens(content);
  return { tokens: units, baseUnits: units };
}
|
||||
|
||||
/**
|
||||
* Fast calculation for a flat array of ConcreteNodes (The Nodes).
|
||||
* It relies entirely on the O(1) sidecar token cache.
|
||||
|
||||
@@ -21,6 +21,9 @@ describe('SnapshotGenerator', () => {
|
||||
llmClient: {
|
||||
generateJson: mockGenerateJson,
|
||||
},
|
||||
advancedTokenCalculator: {
|
||||
getRawBaseUnits: vi.fn().mockReturnValue(100),
|
||||
},
|
||||
tokenCalculator: {
|
||||
estimateTokensForString: vi.fn().mockReturnValue(100),
|
||||
},
|
||||
|
||||
@@ -114,6 +114,7 @@ interface _CommonGenerateOptions {
|
||||
export interface CountTokenOptions {
|
||||
modelConfigKey?: ModelConfigKey;
|
||||
contents: Content[];
|
||||
abortSignal?: AbortSignal;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -240,6 +241,9 @@ export class BaseLlmClient {
|
||||
const result = await this.contentGenerator.countTokens({
|
||||
model,
|
||||
contents: options.contents,
|
||||
config: options.abortSignal
|
||||
? { abortSignal: options.abortSignal }
|
||||
: undefined,
|
||||
});
|
||||
return { totalTokens: result.totalTokens || 0 };
|
||||
}
|
||||
|
||||
@@ -1517,8 +1517,8 @@ ${JSON.stringify(
|
||||
// A string of length 404 is roughly 101 tokens.
|
||||
const longText = 'a'.repeat(404);
|
||||
const request: Part[] = [{ text: longText }];
|
||||
// estimateTextOnlyLength counts only text content (400 chars), not JSON structure
|
||||
const estimatedRequestTokenCount = Math.floor(longText.length * 0.33);
|
||||
// estimateTextOnlyLength counts only text content (404 chars), not JSON structure
|
||||
const estimatedRequestTokenCount = Math.floor(longText.length * 0.25);
|
||||
const remainingTokenCount = MOCKED_TOKEN_LIMIT - lastPromptTokenCount;
|
||||
|
||||
// Mock tryCompressChat to not compress
|
||||
@@ -1577,8 +1577,8 @@ ${JSON.stringify(
|
||||
// We need a request > 100 tokens.
|
||||
const longText = 'a'.repeat(404);
|
||||
const request: Part[] = [{ text: longText }];
|
||||
// estimateTextOnlyLength counts only text content (400 chars), not JSON structure
|
||||
const estimatedRequestTokenCount = Math.floor(longText.length * 0.33);
|
||||
// estimateTextOnlyLength counts only text content (404 chars), not JSON structure
|
||||
const estimatedRequestTokenCount = Math.floor(longText.length * 0.25);
|
||||
const remainingTokenCount = STICKY_MODEL_LIMIT - lastPromptTokenCount;
|
||||
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
|
||||
|
||||
@@ -637,11 +637,22 @@ export class GeminiClient {
|
||||
// Check for context window overflow
|
||||
const modelForLimitCheck = this._getActiveModelForCurrentTurn();
|
||||
|
||||
let currentBaseUnits = 0;
|
||||
|
||||
if (this.config.getContextManagementConfig().enabled) {
|
||||
if (this.contextManager) {
|
||||
const pendingRequest = createUserContent(request);
|
||||
const { history: newHistory, didApplyManagement } =
|
||||
await this.contextManager.renderHistory(pendingRequest);
|
||||
const {
|
||||
history: newHistory,
|
||||
didApplyManagement,
|
||||
baseUnits,
|
||||
} = await this.contextManager.renderHistory(
|
||||
pendingRequest,
|
||||
undefined,
|
||||
signal,
|
||||
);
|
||||
|
||||
currentBaseUnits = baseUnits;
|
||||
|
||||
if (didApplyManagement) {
|
||||
// If the manager pruned history, we update the chat before continuing.
|
||||
@@ -800,8 +811,16 @@ export class GeminiClient {
|
||||
}
|
||||
yield event;
|
||||
|
||||
if (event.type === GeminiEventType.Finished && this.contextManager) {
|
||||
const usageMetadata = event.value.usageMetadata;
|
||||
if (usageMetadata && usageMetadata.promptTokenCount !== undefined) {
|
||||
this.contextManager.getEnvironment().eventBus.emitTokenGroundTruth({
|
||||
actualTokens: usageMetadata.promptTokenCount,
|
||||
promptBaseUnits: currentBaseUnits,
|
||||
});
|
||||
}
|
||||
}
|
||||
this.updateTelemetryTokenCount();
|
||||
|
||||
if (event.type === GeminiEventType.Error) {
|
||||
isError = true;
|
||||
}
|
||||
|
||||
@@ -242,7 +242,7 @@ describe('GeminiChat', () => {
|
||||
// 'Hello': 5 chars * 0.25 = 1.25
|
||||
// 'Hi there': 8 chars * 0.25 = 2.0
|
||||
// Total: 3.25 -> floor(3.25) = 3
|
||||
expect(chatWithHistory.getLastPromptTokenCount()).toBe(4);
|
||||
expect(chatWithHistory.getLastPromptTokenCount()).toBe(3);
|
||||
});
|
||||
|
||||
it('should initialize lastPromptTokenCount for empty history', () => {
|
||||
|
||||
@@ -280,6 +280,26 @@ describe('tokenCalculation', () => {
|
||||
expect(tokens).toBeLessThan(30);
|
||||
});
|
||||
|
||||
it('should respect the user supplied charsPerToken argument', () => {
|
||||
const text = 'abcdefghijkl'; // 12 chars
|
||||
const parts: Part[] = [{ text }];
|
||||
|
||||
// Default (4 chars/token) -> 12 / 4 = 3 tokens
|
||||
expect(estimateTokenCountSync(parts)).toBe(3);
|
||||
|
||||
// Override to 3 chars/token -> 12 / 3 = 4 tokens
|
||||
expect(estimateTokenCountSync(parts, 0, 3)).toBe(4);
|
||||
|
||||
// Override to 2 chars/token -> 12 / 2 = 6 tokens
|
||||
expect(estimateTokenCountSync(parts, 0, 2)).toBe(6);
|
||||
|
||||
// Verify massive strings also respect the argument
|
||||
const massiveText = 'a'.repeat(120_000); // Exceeds 100k
|
||||
const massiveParts: Part[] = [{ text: massiveText }];
|
||||
expect(estimateTokenCountSync(massiveParts, 0, 4)).toBe(30_000);
|
||||
expect(estimateTokenCountSync(massiveParts, 0, 3)).toBe(40_000);
|
||||
});
|
||||
|
||||
it('should handle empty or nullish inputs gracefully', () => {
|
||||
expect(estimateTokenCountSync([])).toBe(0);
|
||||
expect(estimateTokenCountSync([{ text: '' }])).toBe(0);
|
||||
|
||||
@@ -43,10 +43,12 @@ function estimateTextTokens(text: string, charsPerToken: number): number {
|
||||
}
|
||||
|
||||
let tokens = 0;
|
||||
const asciiTokensPerChar = 1 / charsPerToken;
|
||||
|
||||
// Optimized loop: charCodeAt is faster than for...of on large strings
|
||||
for (let i = 0; i < text.length; i++) {
|
||||
if (text.charCodeAt(i) <= 127) {
|
||||
tokens += ASCII_TOKENS_PER_CHAR;
|
||||
tokens += asciiTokensPerChar;
|
||||
} else {
|
||||
tokens += NON_ASCII_TOKENS_PER_CHAR;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user