feat(context): Introduce adaptive token calculator to more accurately calculate content sizes. (#26888)

Author: joshualitt
Date: 2026-05-12 08:51:20 -07:00
Committed by: GitHub
Parent: 7a9ed4c20a
Commit: 07792f98cd
26 changed files with 856 additions and 164 deletions
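
At its core, the adaptive calculator keeps a single learned weight and nudges it toward the ratio of API-reported tokens to heuristic base units on every ground-truth event. A minimal sketch of that update, with the learning rate and clamp bounds taken from the AdaptiveTokenCalculator diff below (the standalone function is illustrative):

const LEARNING_RATE = 0.2;

function emaUpdate(
  oldWeight: number,
  actualTokens: number,
  promptBaseUnits: number,
): number {
  if (promptBaseUnits <= 0) return oldWeight; // guard against division by zero
  const targetWeight = actualTokens / promptBaseUnits;
  const newWeight =
    oldWeight * (1 - LEARNING_RATE) + targetWeight * LEARNING_RATE;
  // Clamp so rogue usage metadata cannot poison the estimate.
  return Math.max(0.5, Math.min(newWeight, 2.0));
}

// Starting from 1.0, an API report of 50 tokens against 100 base units gives
// 1.0 * 0.8 + (50 / 100) * 0.2 = 0.9, matching the unit tests further down.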
@@ -37,8 +37,8 @@ describe('ContextManager Sync Pressure Barrier Tests', () => {
]);
// 3. Add massive history that blows past the 150k maxTokens limit
// 20 turns * 10,000 tokens/turn = ~200,000 tokens
const massiveHistory = createSyntheticHistory(20, 35000);
// 20 turns * ~20,000 tokens/turn (10k user + 10k model) = ~400,000 tokens
const massiveHistory = createSyntheticHistory(20, 10000);
chatHistory.set([...chatHistory.get(), ...massiveHistory]);
// 4. Add the Latest Turn (Protected)
@@ -60,8 +60,8 @@ describe('ContextManager Sync Pressure Barrier Tests', () => {
// Verify Episode 0 (System) was pruned, so we now start with a sentinel due to role alternation
expect(projection[0].role).toBe('user');
expect(projection[0].parts![0].text).toContain('User turn 17');
const projectionString = JSON.stringify(projection);
expect(projectionString).toContain('User turn 17');
// Filter out synthetic Yield nodes (they are model responses without actual tool/text bodies)
const contentNodes = projection.filter(
(p) =>
@@ -0,0 +1,92 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import {
createMockContextConfig,
setupContextComponentTest,
createMockLlmClient,
} from './testing/contextTestUtils.js';
import { stressTestProfile } from './config/profiles.js';
describe('ContextManager - Hot Start Calibration', () => {
it('should not perform calibration if the buffer is empty', async () => {
const mockLlm = createMockLlmClient();
const config = createMockContextConfig(undefined, mockLlm);
const { contextManager } = setupContextComponentTest(
config,
stressTestProfile,
);
// We can spy on the underlying mock LLM client's countTokens
const countTokensSpy = vi.spyOn(mockLlm, 'countTokens');
// Render an empty graph
await contextManager.renderHistory();
expect(countTokensSpy).not.toHaveBeenCalled();
});
it('should perform calibration exactly once when rendering with existing nodes', async () => {
const mockLlm = createMockLlmClient();
const countTokensSpy = vi
.spyOn(mockLlm, 'countTokens')
.mockResolvedValue({ totalTokens: 42 });
const config = createMockContextConfig(undefined, mockLlm);
const { contextManager, chatHistory } = setupContextComponentTest(
config,
stressTestProfile,
);
// We need to access the env's eventBus inside the contextManager
const env = Reflect.get(contextManager, 'env');
const emitGroundTruthSpy = vi.spyOn(env.eventBus, 'emitTokenGroundTruth');
// Add a node to make the buffer non-empty
chatHistory.set([{ role: 'user', parts: [{ text: 'Hello' }] }]);
// First render should trigger calibration
await contextManager.renderHistory();
expect(countTokensSpy).toHaveBeenCalledTimes(1);
expect(emitGroundTruthSpy).toHaveBeenCalledTimes(1);
expect(emitGroundTruthSpy).toHaveBeenCalledWith(
expect.objectContaining({
actualTokens: 42,
promptBaseUnits: 10,
}),
);
// Second render should skip calibration
await contextManager.renderHistory();
expect(countTokensSpy).toHaveBeenCalledTimes(1);
// emit hasn't been called again
expect(emitGroundTruthSpy).toHaveBeenCalledTimes(1);
});
it('should silently swallow errors if countTokens API fails', async () => {
const mockLlm = createMockLlmClient();
const countTokensSpy = vi
.spyOn(mockLlm, 'countTokens')
.mockRejectedValue(new Error('API failure'));
const config = createMockContextConfig(undefined, mockLlm);
const { contextManager, chatHistory } = setupContextComponentTest(
config,
stressTestProfile,
);
// Add a node
chatHistory.set([{ role: 'user', parts: [{ text: 'Hello' }] }]);
// Render should succeed without throwing
const result = await contextManager.renderHistory();
expect(result.history).toBeDefined();
expect(countTokensSpy).toHaveBeenCalledTimes(1);
});
});
+99 -15
@@ -18,6 +18,7 @@ import { ContextWorkingBufferImpl } from './pipeline/contextWorkingBuffer.js';
import { debugLogger } from '../utils/debugLogger.js';
import { hardenHistory } from '../utils/historyHardening.js';
import { checkContextInvariants } from './utils/invariantChecker.js';
import type { AdvancedTokenCalculator } from './utils/contextTokenCalculator.js';
export class ContextManager {
// The master state containing the pristine graph and current active graph.
@@ -36,15 +37,22 @@ export class ContextManager {
// Cache for Anomaly 3 (Redundant Renders)
private lastRenderCache?: {
nodesHash: string;
result: { history: Content[]; didApplyManagement: boolean };
result: {
history: Content[];
didApplyManagement: boolean;
baseUnits: number;
};
};
private hasPerformedHotStart = false;
constructor(
private readonly sidecar: ContextProfile,
private readonly env: ContextEnvironment,
private readonly tracer: ContextTracer,
orchestrator: PipelineOrchestrator,
chatHistory: AgentChatHistory,
private readonly advancedTokenCalculator: AdvancedTokenCalculator,
private readonly headerProvider?: () => Promise<Content | undefined>,
) {
this.eventBus = env.eventBus;
@@ -260,6 +268,10 @@ export class ContextManager {
return [...this.buffer.nodes];
}
getEnvironment(): ContextEnvironment {
return this.env;
}
/**
* Executes the final 'gc_backstop' pipeline if necessary, enforcing the token budget,
* and maps the Episodic Context Graph back into a raw Gemini Content[] array for transmission.
@@ -268,22 +280,44 @@ export class ContextManager {
async renderHistory(
pendingRequest?: Content,
activeTaskIds: Set<string> = new Set(),
): Promise<{ history: Content[]; didApplyManagement: boolean }> {
abortSignal?: AbortSignal,
): Promise<{
history: Content[];
didApplyManagement: boolean;
baseUnits: number;
}> {
this.tracer.logEvent('ContextManager', 'Starting rendering of LLM context');
let previewNodes: ConcreteNode[] = [];
if (pendingRequest) {
previewNodes = this.env.graphMapper.applyEvent({
type: 'PUSH',
payload: [pendingRequest],
});
}
// --- Hot Start Calibration ---
// If we are resuming a session with history, we don't want the adaptive token calculator
// to fly blind on its first GC pass. We do a one-time API calibration.
const hotStartPromise = (async () => {
if (!this.hasPerformedHotStart) {
this.hasPerformedHotStart = true;
if (this.buffer.nodes.length > 0) {
const nodesForHotStart = [...this.buffer.nodes, ...previewNodes];
await this.performHotStartCalibration(nodesForHotStart, abortSignal);
}
}
})();
// 1. Synchronous Pressure Barrier: Wait for background management pipelines to finish.
// This ensures that the render sees the results of recent pushes (Anomaly 2).
await this.orchestrator.waitForPipelines();
// We run hot start calibration in parallel to hide the network latency.
await Promise.all([this.orchestrator.waitForPipelines(), hotStartPromise]);
let nodes = this.buffer.nodes;
const previewNodeIds = new Set<string>();
// If we have a pending request, we need to build a 'preview' graph for this render.
if (pendingRequest) {
const previewNodes = this.env.graphMapper.applyEvent({
type: 'PUSH',
payload: [pendingRequest],
});
// Apply the preview nodes to the final graph
if (previewNodes.length > 0) {
for (const n of previewNodes) {
previewNodeIds.add(n.id);
}
@@ -294,9 +328,6 @@ export class ContextManager {
const header = this.headerProvider
? await this.headerProvider()
: undefined;
const headerTokens = header
? this.env.tokenCalculator.calculateContentTokens(header)
: 0;
// 3. Cache Check (Anomaly 3): If nodes haven't changed, return previous result.
// We combine the graph hash with a hash of the header to ensure total freshness.
@@ -314,14 +345,19 @@ export class ContextManager {
const protectionReasons = this.getProtectedNodeIds(nodes, activeTaskIds);
// Apply final GC Backstop pressure barrier synchronously before mapping
const { history: renderedHistory, didApplyManagement } = await render(
const {
history: renderedHistory,
didApplyManagement,
baseUnits,
} = await render(
nodes,
this.orchestrator,
this.sidecar,
this.tracer,
this.env,
this.advancedTokenCalculator,
protectionReasons,
headerTokens,
header,
previewNodeIds,
);
@@ -339,10 +375,58 @@ export class ContextManager {
sentinels: this.sidecar.sentinels,
}),
didApplyManagement,
baseUnits,
};
// Update cache
this.lastRenderCache = { nodesHash: totalHash, result };
return result;
}
private async performHotStartCalibration(
nodes: readonly ConcreteNode[],
abortSignal?: AbortSignal,
) {
try {
this.tracer.logEvent(
'ContextManager',
'Performing Hot Start Token Calibration',
);
const contents = this.env.graphMapper.fromGraph(nodes);
const header = this.headerProvider
? await this.headerProvider()
: undefined;
const combinedHistory = header ? [header, ...contents] : contents;
const baseUnits =
this.advancedTokenCalculator.getRawBaseUnits(nodes) +
(header
? this.advancedTokenCalculator.getRawBaseUnitsForContent(header)
: 0);
// We only make the network call if we have actual contents to send,
// avoiding 400 Bad Request errors from the API.
if (combinedHistory.length > 0) {
const result = await this.env.llmClient.countTokens({
contents: combinedHistory,
abortSignal,
});
if (result.totalTokens > 0) {
this.env.eventBus.emitTokenGroundTruth({
actualTokens: result.totalTokens,
promptBaseUnits: baseUnits,
});
}
}
} catch (error) {
// Hot start calibration is purely an optimization. If the network call fails or auth is misconfigured,
// we silently swallow the error and fall back to the un-calibrated 1.0 ratio heuristic.
this.tracer.logEvent(
'ContextManager',
'Hot Start Token Calibration Failed (Ignored)',
{ error },
);
}
}
}
+13
@@ -29,7 +29,20 @@ export interface ChunkReceivedEvent {
targetNodeIds: Set<string>;
}
export interface TokenGroundTruthEvent {
actualTokens: number;
promptBaseUnits: number;
}
export class ContextEventBus extends EventEmitter {
emitTokenGroundTruth(event: TokenGroundTruthEvent) {
this.emit('TOKEN_GROUND_TRUTH', event);
}
onTokenGroundTruth(listener: (event: TokenGroundTruthEvent) => void) {
this.on('TOKEN_GROUND_TRUTH', listener);
}
emitPristineHistoryUpdated(event: PristineHistoryUpdatedEvent) {
this.emit('PRISTINE_HISTORY_UPDATED', event);
}
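
The bus is the glue between the producers of ground truth (hot-start calibration and the per-turn usageMetadata path in GeminiClient) and the calculator, which subscribes in its constructor. A minimal wiring sketch, assuming the constructors shown elsewhere in this diff:

const eventBus = new ContextEventBus();
const registry = new NodeBehaviorRegistry();
registerBuiltInBehaviors(registry);

const calculator = new AdaptiveTokenCalculator(4, registry, eventBus);

// Any producer can then feed real API numbers back into the estimate:
eventBus.emitTokenGroundTruth({ actualTokens: 90, promptBaseUnits: 100 });
console.log(calculator.getLearnedWeight()); // 1.0 * 0.8 + 0.9 * 0.2 = 0.98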
+49 -5
@@ -8,6 +8,7 @@ import { describe, it, expect, vi } from 'vitest';
import { render } from './render.js';
import type { ConcreteNode } from './types.js';
import { NodeType } from './types.js';
import type { AdvancedTokenCalculator } from '../utils/contextTokenCalculator.js';
import type { ContextEnvironment } from '../pipeline/environment.js';
import type { ContextTracer } from '../tracer.js';
import type { ContextProfile } from '../config/profiles.js';
@@ -37,7 +38,20 @@ describe('render', () => {
const orchestrator = {} as PipelineOrchestrator;
const sidecar = { config: {} } as ContextProfile; // No budget
const mockAdvancedTokenCalculator = {
calculateTokensAndBaseUnits: vi.fn().mockReturnValue({
tokens: 100,
baseUnits: 100,
}),
getRawBaseUnits: vi.fn().mockReturnValue(100),
getRawBaseUnitsForContent: vi.fn().mockReturnValue(0),
};
const env = {
tokenCalculator: {
calculateConcreteListTokens: vi.fn().mockReturnValue(100),
calculateTokenBreakdown: vi.fn().mockReturnValue({}),
},
graphMapper: {
fromGraph: vi.fn((nodes: readonly ConcreteNode[]) =>
nodes.map((n) => ({ text: n.id })),
@@ -54,12 +68,14 @@ describe('render', () => {
sidecar,
tracer,
env,
mockAdvancedTokenCalculator as unknown as AdvancedTokenCalculator,
new Map(),
0,
undefined,
previewNodeIds,
);
expect(result.history).toEqual([{ text: '1' }, { text: '2' }]);
expect(result.baseUnits).toBe(100);
});
it('simulates the boundary knapsack problem (loose boundary policy)', async () => {
@@ -108,12 +124,24 @@ describe('render', () => {
const currentTokens = 160000;
const mockAdvancedTokenCalculator = {
calculateTokensAndBaseUnits: vi.fn((nodes: readonly ConcreteNode[]) => {
const tokens =
nodes.length === 1 ? tokenMap[nodes[0].id] : currentTokens;
return { tokens, baseUnits: tokens };
}),
getRawBaseUnits: vi.fn((nodes: readonly ConcreteNode[]) => {
if (nodes.length === 1) return tokenMap[nodes[0].id];
return currentTokens;
}),
};
const env = {
llmClient: {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1000 }),
},
tokenCalculator: {
calculateConcreteListTokens: vi.fn((nodes) => {
calculateConcreteListTokens: vi.fn((nodes: readonly ConcreteNode[]) => {
if (nodes.length === 1) return tokenMap[nodes[0].id];
return currentTokens;
}),
@@ -136,8 +164,9 @@ describe('render', () => {
sidecar,
tracer,
env,
mockAdvancedTokenCalculator as unknown as AdvancedTokenCalculator,
new Map(),
0,
undefined,
new Set(),
);
@@ -147,6 +176,7 @@ describe('render', () => {
// Adding C pushes rolling total (70k) above retainedTokens (65k).
// Under loose policy, C survives. D is strictly older and drops.
expect(surviving).toEqual(['C', 'B', 'A']); // D is dropped
expect(result.baseUnits).toBe(160000);
});
it('drops nodes that are STRICTLY older than the boundary node', async () => {
@@ -188,12 +218,24 @@ describe('render', () => {
const currentTokens = 160000;
const mockAdvancedTokenCalculator = {
calculateTokensAndBaseUnits: vi.fn((nodes: readonly ConcreteNode[]) => {
const tokens =
nodes.length === 1 ? tokenMap[nodes[0].id] : currentTokens;
return { tokens, baseUnits: tokens };
}),
getRawBaseUnits: vi.fn((nodes: readonly ConcreteNode[]) => {
if (nodes.length === 1) return tokenMap[nodes[0].id];
return currentTokens;
}),
};
const env = {
llmClient: {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1000 }),
},
tokenCalculator: {
calculateConcreteListTokens: vi.fn((nodes) => {
calculateConcreteListTokens: vi.fn((nodes: readonly ConcreteNode[]) => {
if (nodes.length === 1) return tokenMap[nodes[0].id];
return currentTokens;
}),
@@ -216,8 +258,9 @@ describe('render', () => {
sidecar,
tracer,
env,
mockAdvancedTokenCalculator as unknown as AdvancedTokenCalculator,
new Map(),
0,
undefined,
new Set(),
);
@@ -225,5 +268,6 @@ describe('render', () => {
const surviving = result.history.map((c: any) => c.text);
// C(40k), B(40k). Adding B pushes total to 80k. B is the boundary node and survives. A drops.
expect(surviving).toEqual(['B', 'C']); // A is dropped
expect(result.baseUnits).toBe(160000);
});
});
+38 -6
@@ -11,6 +11,7 @@ import type { ContextProfile } from '../config/profiles.js';
import type { PipelineOrchestrator } from '../pipeline/orchestrator.js';
import type { ContextEnvironment } from '../pipeline/environment.js';
import { performCalibration } from '../utils/tokenCalibration.js';
import type { AdvancedTokenCalculator } from '../utils/contextTokenCalculator.js';
/**
* Maps the Episodic Context Graph back into a raw Gemini Content[] array for transmission.
@@ -22,21 +23,43 @@ export async function render(
sidecar: ContextProfile,
tracer: ContextTracer,
env: ContextEnvironment,
advancedTokenCalculator: AdvancedTokenCalculator,
protectionReasons: Map<string, string> = new Map(),
headerTokens: number = 0,
header?: Content,
previewNodeIds: ReadonlySet<string> = new Set(),
): Promise<{ history: Content[]; didApplyManagement: boolean }> {
): Promise<{
history: Content[];
didApplyManagement: boolean;
baseUnits: number;
}> {
let headerTokens = 0;
let headerBaseUnits = 0;
if (header) {
const costs =
advancedTokenCalculator.calculateContentTokensAndBaseUnits(header);
headerTokens = costs.tokens;
headerBaseUnits = costs.baseUnits;
}
if (!sidecar.config.budget) {
const visibleNodes = nodes.filter((n) => !previewNodeIds.has(n.id));
const contents = env.graphMapper.fromGraph(visibleNodes);
tracer.logEvent('Render', 'Render Context to LLM (No Budget)', {
renderedContext: contents,
});
return { history: contents, didApplyManagement: false };
// Every return path reports raw base units from the token calculator interface
const baseUnits =
advancedTokenCalculator.getRawBaseUnits(nodes) + headerBaseUnits;
return { history: contents, didApplyManagement: false, baseUnits };
}
const maxTokens = sidecar.config.budget.maxTokens;
const graphTokens = env.tokenCalculator.calculateConcreteListTokens(nodes);
const { tokens: graphTokens, baseUnits: graphBaseUnits } =
advancedTokenCalculator.calculateTokensAndBaseUnits(nodes);
const currentTokens = graphTokens + headerTokens;
const protectedIds = new Set(protectionReasons.keys());
@@ -70,7 +93,11 @@ export async function render(
renderedContext: contents,
});
performCalibration(env, visibleNodes, contents);
return { history: contents, didApplyManagement: false };
return {
history: contents,
didApplyManagement: false,
baseUnits: graphBaseUnits + headerBaseUnits,
};
}
const targetDelta = currentTokens - sidecar.config.budget.retainedTokens;
tracer.logEvent(
@@ -119,5 +146,10 @@ export async function render(
renderedContextSanitized: contents,
});
performCalibration(env, visibleNodes, contents);
return { history: contents, didApplyManagement: true };
return {
history: contents,
didApplyManagement: true,
baseUnits:
advancedTokenCalculator.getRawBaseUnits(visibleNodes) + headerBaseUnits,
};
}
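
For callers the visible change is the extra calculator argument and the baseUnits field in the result. A hedged sketch of invoking the updated render, with the argument order taken from the unit tests above (values illustrative):

const { history, didApplyManagement, baseUnits } = await render(
  nodes,
  orchestrator,
  sidecar,
  tracer,
  env,
  advancedTokenCalculator, // replaces the old headerTokens: number parameter
  new Map(), // protectionReasons
  undefined, // header Content; its cost is now computed inside render
  new Set(), // previewNodeIds
);
// baseUnits is later paired with usageMetadata.promptTokenCount to emit a
// TOKEN_GROUND_TRUTH event.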
+17 -1
@@ -23,6 +23,9 @@ import { StateSnapshotProcessorOptionsSchema } from './processors/stateSnapshotP
import { StateSnapshotAsyncProcessorOptionsSchema } from './processors/stateSnapshotAsyncProcessor.js';
import { RollingSummaryProcessorOptionsSchema } from './processors/rollingSummaryProcessor.js';
import { getEnvironmentContext } from '../utils/environmentContext.js';
import { AdaptiveTokenCalculator } from './utils/adaptiveTokenCalculator.js';
import { NodeBehaviorRegistry } from './graph/behaviorRegistry.js';
import { registerBuiltInBehaviors } from './graph/builtinBehaviors.js';
export async function initializeContextManager(
config: Config,
@@ -85,6 +88,16 @@ export async function initializeContextManager(
const eventBus = new ContextEventBus();
const charsPerToken = 3;
const behaviorRegistry = new NodeBehaviorRegistry();
registerBuiltInBehaviors(behaviorRegistry);
const calculator = new AdaptiveTokenCalculator(
charsPerToken,
behaviorRegistry,
eventBus,
);
const env = new ContextEnvironmentImpl(
() => config.getBaseLlmClient(),
config.getSessionId(),
@@ -92,8 +105,10 @@ export async function initializeContextManager(
logDir,
projectTempDir,
tracer,
4,
charsPerToken,
eventBus,
calculator,
behaviorRegistry,
{
calibrateTokenCalculation:
!!process.env['GEMINI_CONTEXT_CALIBRATE_TOKEN_CALCULATIONS'],
@@ -114,6 +129,7 @@ export async function initializeContextManager(
tracer,
orchestrator,
chat.agentHistory,
calculator,
async () => {
const parts = await getEnvironmentContext(config);
return { role: 'user', parts };
@@ -8,12 +8,16 @@ import { ContextEnvironmentImpl } from './environmentImpl.js';
import { ContextTracer } from '../tracer.js';
import { ContextEventBus } from '../eventBus.js';
import { createMockLlmClient } from '../testing/contextTestUtils.js';
import { StaticTokenCalculator } from '../utils/contextTokenCalculator.js';
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
describe('ContextEnvironmentImpl', () => {
it('should initialize with defaults correctly', () => {
const tracer = new ContextTracer({ targetDir: '/tmp', sessionId: 'mock' });
const eventBus = new ContextEventBus();
const mockLlmClient = createMockLlmClient();
const behaviorRegistry = new NodeBehaviorRegistry();
const calculator = new StaticTokenCalculator(4, behaviorRegistry);
const env = new ContextEnvironmentImpl(
() => mockLlmClient,
@@ -24,6 +28,8 @@ describe('ContextEnvironmentImpl', () => {
tracer,
4,
eventBus,
calculator,
behaviorRegistry,
);
expect(env.llmClient).toBe(mockLlmClient);
@@ -8,16 +8,13 @@ import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import type { ContextTracer } from '../tracer.js';
import type { ContextEnvironment, RenderOptions } from './environment.js';
import type { ContextEventBus } from '../eventBus.js';
import { ContextTokenCalculator } from '../utils/contextTokenCalculator.js';
import type { ContextTokenCalculator } from '../utils/contextTokenCalculator.js';
import { LiveInbox } from './inbox.js';
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
import type { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
import { ContextGraphMapper } from '../graph/mapper.js';
export class ContextEnvironmentImpl implements ContextEnvironment {
readonly tokenCalculator: ContextTokenCalculator;
readonly inbox: LiveInbox;
readonly behaviorRegistry: NodeBehaviorRegistry;
readonly graphMapper: ContextGraphMapper;
constructor(
@@ -29,14 +26,10 @@ export class ContextEnvironmentImpl implements ContextEnvironment {
readonly tracer: ContextTracer,
readonly charsPerToken: number,
readonly eventBus: ContextEventBus,
readonly tokenCalculator: ContextTokenCalculator,
readonly behaviorRegistry: NodeBehaviorRegistry,
readonly renderOptions?: RenderOptions,
) {
this.behaviorRegistry = new NodeBehaviorRegistry();
registerBuiltInBehaviors(this.behaviorRegistry);
this.tokenCalculator = new ContextTokenCalculator(
this.charsPerToken,
this.behaviorRegistry,
);
this.inbox = new LiveInbox();
this.graphMapper = new ContextGraphMapper();
}
@@ -66,12 +66,12 @@ describe('BlobDegradationProcessor', () => {
const node1 = createDummyNode('ep1', NodeType.USER_PROMPT, 100, {
payload: {
fileData: { mimeType: 'video/mp4', fileUri: 'gs://test1' },
fileData: { mimeType: 'image/png', fileUri: 'gs://test1' },
},
});
const node2 = createDummyNode('ep1', NodeType.USER_PROMPT, 100, {
payload: {
fileData: { mimeType: 'video/mp4', fileUri: 'gs://test2' },
fileData: { mimeType: 'image/png', fileUri: 'gs://test2' },
},
});
@@ -52,7 +52,7 @@ describe('StateSnapshotAsyncProcessor', () => {
'PROPOSED_SNAPSHOT',
expect.objectContaining({
newText:
'{"active_tasks":[],"discovered_facts":[],"constraints_and_preferences":[],"recent_arc":[]}',
'{"active_tasks":[],"discovered_facts":[],"constraints_and_preferences":[],"recent_arc":["Mock LLM summary response"]}',
consumedIds: ['node-A', 'node-B'],
type: 'point-in-time',
}),
@@ -107,7 +107,7 @@ describe('StateSnapshotAsyncProcessor', () => {
'PROPOSED_SNAPSHOT',
expect.objectContaining({
newText:
'{"active_tasks":[],"discovered_facts":[],"constraints_and_preferences":[],"recent_arc":[]}',
'{"active_tasks":[],"discovered_facts":[],"constraints_and_preferences":[],"recent_arc":["Mock LLM summary response"]}',
consumedIds: ['node-A', 'node-B', 'node-C'], // Aggregated!
type: 'accumulate',
}),
File diff suppressed because one or more lines are too long
@@ -9,6 +9,7 @@ import { SimulationHarness } from './simulationHarness.js';
import { createMockLlmClient } from '../testing/contextTestUtils.js';
import type { ContextProfile } from '../config/profiles.js';
import { generalistProfile } from '../config/profiles.js';
import type { Content } from '@google/genai';
describe('Context Manager Hysteresis Tests', () => {
const mockLlmClient = createMockLlmClient(['<SNAPSHOT>']);
@@ -25,6 +26,12 @@ describe('Context Manager Hysteresis Tests', () => {
},
});
const getProjectionTokens = (proj: Content[], harness: SimulationHarness) =>
proj.reduce(
(sum, c) => sum + harness.env.tokenCalculator.calculateContentTokens(c),
0,
);
it('should block consolidation when deficit is below coalescing threshold', async () => {
const threshold = 1500;
const harness = await SimulationHarness.create(
@@ -35,14 +42,14 @@ describe('Context Manager Hysteresis Tests', () => {
// Turn 0: INIT
await harness.simulateTurn([{ role: 'user', parts: [{ text: 'INIT' }] }]);
// Turn 1: Add 1500 chars (~500 tokens). Total ~500. Under retained (1000).
// Turn 1: Add ~500 tokens
await harness.simulateTurn([
{ role: 'user', parts: [{ text: 'A'.repeat(1500) }] },
{ role: 'user', parts: [{ text: 'A'.repeat(500) }] },
]);
// Turn 2: Add 3000 chars (~1000 tokens). Total ~1500. Deficit ~500 < 1500.
// Turn 2: Add ~1000 tokens. Total ~1500. Deficit ~500 < 1500.
await harness.simulateTurn([
{ role: 'user', parts: [{ text: 'B'.repeat(3000) }] },
{ role: 'user', parts: [{ text: 'B'.repeat(1000) }] },
]);
await new Promise((resolve) => setTimeout(resolve, 100));
@@ -54,19 +61,19 @@ describe('Context Manager Hysteresis Tests', () => {
),
).toBe(false);
// Turn 3: Add 9000 chars (~3000 tokens). Total ~4500.
// Turn 3: Add ~3000 tokens. Total ~4500.
// Deficit ~3500 > 1500. TRIGGER!
await harness.simulateTurn([
{ role: 'user', parts: [{ text: 'C'.repeat(9000) }] },
{ role: 'user', parts: [{ text: 'C'.repeat(3000) }] },
]);
// Give it a moment for the async task to finish
await new Promise((resolve) => setTimeout(resolve, 500));
// Exceed maxTokens to force a render that shows the snapshot
// Add 3000 more tokens (9000 chars). Total ~7500 > 5000.
// Add ~3000 tokens. Total ~7500 > 5000.
await harness.simulateTurn([
{ role: 'user', parts: [{ text: 'D'.repeat(9000) }] },
{ role: 'user', parts: [{ text: 'D'.repeat(3000) }] },
]);
state = await harness.getGoldenState();
@@ -85,57 +92,51 @@ describe('Context Manager Hysteresis Tests', () => {
);
// 1. Trigger first consolidation
// Add ~9000 chars (~3000 tokens). Total ~3000. Deficit ~2000 > 1000.
// Add ~3000 tokens. Total ~3000. Deficit ~2000 > 1000.
await harness.simulateTurn([
{ role: 'user', parts: [{ text: 'A'.repeat(9000) }] },
{ role: 'user', parts: [{ text: 'A'.repeat(3000) }] },
]);
await harness.simulateTurn([{ role: 'user', parts: [{ text: 'B' }] }]); // Make eligible
await new Promise((resolve) => setTimeout(resolve, 500));
// Exceed maxTokens (5000) to see it
await harness.simulateTurn([
{ role: 'user', parts: [{ text: 'X'.repeat(9000) }] },
{ role: 'user', parts: [{ text: 'X'.repeat(3000) }] },
]);
const state = await harness.getGoldenState();
// Get baseline tokens
let state = await harness.getGoldenState();
expect(
state.finalProjection.some((c) =>
c.parts?.some((p) => p.text?.includes('<SNAPSHOT>')),
),
).toBe(true);
// Get baseline tokens
const baselineTokens =
harness.env.tokenCalculator.calculateConcreteListTokens(
harness.contextManager.getNodes(),
);
const baselineTokens = getProjectionTokens(state.finalProjection, harness);
// 2. Add nodes again, staying below threshold growth
// Add 1500 chars (~500 tokens). Growth ~500 < 1000.
// Add ~500 tokens. Growth ~500 < 1000.
await harness.simulateTurn([
{ role: 'user', parts: [{ text: 'C'.repeat(1500) }] },
{ role: 'user', parts: [{ text: 'C'.repeat(500) }] },
]);
await harness.simulateTurn([{ role: 'user', parts: [{ text: 'D' }] }]); // Make eligible
await new Promise((resolve) => setTimeout(resolve, 200));
const currentTokens =
harness.env.tokenCalculator.calculateConcreteListTokens(
harness.contextManager.getNodes(),
);
state = await harness.getGoldenState();
const currentTokens = getProjectionTokens(state.finalProjection, harness);
// Should not have shrunk further (except for D's small addition)
expect(currentTokens).toBeGreaterThanOrEqual(baselineTokens);
// 3. Exceed threshold growth
// Add 6000 chars (~2000 tokens). Growth = ~500 + ~2000 = ~2500 > 1000.
// Add ~2000 tokens. Growth = ~500 + ~2000 = ~2500 > 1000.
await harness.simulateTurn([
{ role: 'user', parts: [{ text: 'E'.repeat(6000) }] },
{ role: 'user', parts: [{ text: 'E'.repeat(2000) }] },
]);
await harness.simulateTurn([{ role: 'user', parts: [{ text: 'F' }] }]); // Make eligible
await new Promise((resolve) => setTimeout(resolve, 500));
const finalTokens = harness.env.tokenCalculator.calculateConcreteListTokens(
harness.contextManager.getNodes(),
);
state = await harness.getGoldenState();
const finalTokens = getProjectionTokens(state.finalProjection, harness);
// Now it should have consolidated again (E should be replaced by a snapshot eventually)
expect(finalTokens).toBeLessThan(currentTokens + 2000);
});
@@ -13,6 +13,9 @@ import { ContextTracer } from '../tracer.js';
import { ContextEventBus } from '../eventBus.js';
import { PipelineOrchestrator } from '../pipeline/orchestrator.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import { StaticTokenCalculator } from '../utils/contextTokenCalculator.js';
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
export interface TurnSummary {
turnIndex: number;
@@ -57,6 +60,11 @@ export class SimulationHarness {
targetDir: mockTempDir,
sessionId: 'sim-session',
});
const behaviorRegistry = new NodeBehaviorRegistry();
registerBuiltInBehaviors(behaviorRegistry);
const calculator = new StaticTokenCalculator(1, behaviorRegistry);
this.env = new ContextEnvironmentImpl(
() => mockLlmClient,
'sim-prompt',
@@ -66,6 +74,8 @@ export class SimulationHarness {
this.tracer,
1, // 1 char per token average for estimation (but estimator uses 0.33)
this.eventBus,
calculator,
behaviorRegistry,
);
this.orchestrator = new PipelineOrchestrator(
@@ -81,6 +91,7 @@ export class SimulationHarness {
this.tracer,
this.orchestrator,
this.chatHistory,
calculator,
);
}
@@ -111,11 +122,12 @@ export class SimulationHarness {
}
async getGoldenState() {
const { history: finalProjection } =
const { history: finalProjection, baseUnits } =
await this.contextManager.renderHistory();
return {
tokenTrajectory: this.tokenTrajectory,
finalProjection,
baseUnits,
};
}
}
@@ -30,6 +30,9 @@ import type { ContextProfile } from '../config/profiles.js';
import type { Mock } from 'vitest';
import { ContextWorkingBufferImpl } from '../pipeline/contextWorkingBuffer.js';
import { testTruncateProfile } from './testProfile.js';
import { StaticTokenCalculator } from '../utils/contextTokenCalculator.js';
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
/**
* Creates a valid mock GenerateContentResponse with the provided text.
@@ -134,17 +137,30 @@ export function createMockLlmClient(
);
});
const generateJsonMock = vi.fn().mockImplementation(async () => ({
active_tasks: [],
discovered_facts: [],
constraints_and_preferences: [],
recent_arc: [],
}));
const generateJsonMock = vi.fn().mockImplementation(async () => {
let mockStr = '';
if (responses && responses.length > 0) {
const callCount = generateJsonMock.mock.calls.length - 1;
const idx =
callCount < responses.length ? callCount : responses.length - 1;
const res = responses[idx];
if (typeof res === 'string') {
mockStr = res;
}
}
return {
active_tasks: [],
discovered_facts: [],
constraints_and_preferences: [],
chronological_summary: mockStr,
};
});
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
return {
generateContent: generateContentMock,
generateJson: generateJsonMock,
countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
} as unknown as MockLlmClient;
}
@@ -158,6 +174,9 @@ export function createMockEnvironment(
sessionId: 'mock-session',
});
const eventBus = new ContextEventBus();
const behaviorRegistry = new NodeBehaviorRegistry();
registerBuiltInBehaviors(behaviorRegistry);
const calculator = new StaticTokenCalculator(1, behaviorRegistry);
let env = new ContextEnvironmentImpl(
() => llmClient as BaseLlmClient,
@@ -168,6 +187,8 @@ export function createMockEnvironment(
tracer,
1,
eventBus,
calculator,
behaviorRegistry,
);
if (overrides) {
@@ -181,6 +202,8 @@ export function createMockEnvironment(
env.tracer,
env.charsPerToken,
env.eventBus,
calculator,
behaviorRegistry,
);
}
const { llmClient: _llmClient, ...restOverrides } = overrides;
@@ -273,6 +296,10 @@ export function setupContextComponentTest(
sessionId: 'test-session',
});
const eventBus = new ContextEventBus();
const behaviorRegistry = new NodeBehaviorRegistry();
registerBuiltInBehaviors(behaviorRegistry);
const calculator = new StaticTokenCalculator(1, behaviorRegistry);
const env = new ContextEnvironmentImpl(
() => config.getBaseLlmClient(),
'test prompt-id',
@@ -282,6 +309,8 @@ export function setupContextComponentTest(
tracer,
1,
eventBus,
calculator,
behaviorRegistry,
);
const orchestrator = new PipelineOrchestrator(
@@ -298,8 +327,8 @@ export function setupContextComponentTest(
tracer,
orchestrator,
chatHistory,
calculator,
);
// The async pipeline is now internally managed by ContextManager
return { chatHistory, contextManager };
}
@@ -0,0 +1,125 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { AdaptiveTokenCalculator } from './adaptiveTokenCalculator.js';
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
import { ContextEventBus } from '../eventBus.js';
import { createDummyNode } from '../testing/contextTestUtils.js';
import { NodeType } from '../graph/types.js';
describe('AdaptiveTokenCalculator', () => {
const registry = new NodeBehaviorRegistry();
registerBuiltInBehaviors(registry);
const charsPerToken = 1; // Simplifies math
it('should initialize with a learned weight of 1.0', () => {
const eventBus = new ContextEventBus();
const calculator = new AdaptiveTokenCalculator(
charsPerToken,
registry,
eventBus,
);
expect(calculator.getLearnedWeight()).toBe(1.0);
});
it('should dynamically update learned weight based on token ground truth events', () => {
const eventBus = new ContextEventBus();
const calculator = new AdaptiveTokenCalculator(
charsPerToken,
registry,
eventBus,
);
// Initial state: weight = 1.0
// Simulate an event where the API reported fewer tokens than our base units
// targetWeight = 50 / 100 = 0.5
// newWeight = 1.0 * 0.8 + 0.5 * 0.2 = 0.8 + 0.1 = 0.9
eventBus.emitTokenGroundTruth({
actualTokens: 50,
promptBaseUnits: 100,
});
// JavaScript floating point precision means we should use toBeCloseTo
expect(calculator.getLearnedWeight()).toBeCloseTo(0.9, 5);
// Simulate another event
// newWeight = 0.9 * 0.8 + (150 / 100) * 0.2 = 0.72 + 0.3 = 1.02
eventBus.emitTokenGroundTruth({
actualTokens: 150,
promptBaseUnits: 100,
});
expect(calculator.getLearnedWeight()).toBeCloseTo(1.02, 5);
});
it('should clamp the learned weight between 0.5 and 2.0', () => {
const eventBus = new ContextEventBus();
const calculator = new AdaptiveTokenCalculator(
charsPerToken,
registry,
eventBus,
);
// Push weight up extremely high (API returns 10x tokens)
for (let i = 0; i < 20; i++) {
eventBus.emitTokenGroundTruth({
actualTokens: 1000,
promptBaseUnits: 100,
});
}
expect(calculator.getLearnedWeight()).toBe(2.0);
// Push weight down extremely low (API returns 0 tokens)
for (let i = 0; i < 20; i++) {
eventBus.emitTokenGroundTruth({ actualTokens: 0, promptBaseUnits: 100 });
}
expect(calculator.getLearnedWeight()).toBe(0.5);
});
it('should correctly apply the learned weight to node calculations while keeping raw base units stable', () => {
const eventBus = new ContextEventBus();
const calculator = new AdaptiveTokenCalculator(
charsPerToken,
registry,
eventBus,
);
// Decrease the weight to exactly 0.5
for (let i = 0; i < 20; i++) {
eventBus.emitTokenGroundTruth({ actualTokens: 0, promptBaseUnits: 100 });
}
const turn1Id = 'turn-1';
const node1 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
// Get raw base units directly
const rawTokens = calculator.calculateTokensAndBaseUnits([node1]).baseUnits;
// Get adjusted tokens
const adjustedTokens = calculator.calculateConcreteListTokens([node1]);
expect(adjustedTokens).toBe(Math.round(rawTokens * 0.5));
});
it('should ignore ground truth events with 0 promptBaseUnits to prevent division by zero', () => {
const eventBus = new ContextEventBus();
const calculator = new AdaptiveTokenCalculator(
charsPerToken,
registry,
eventBus,
);
eventBus.emitTokenGroundTruth({
actualTokens: 100,
promptBaseUnits: 0,
});
expect(calculator.getLearnedWeight()).toBe(1.0);
});
});
@@ -0,0 +1,163 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Content, Part } from '@google/genai';
import type { ConcreteNode } from '../graph/types.js';
import {
StaticTokenCalculator,
type AdvancedTokenCalculator,
} from './contextTokenCalculator.js';
import type { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
import type { ContextEventBus, TokenGroundTruthEvent } from '../eventBus.js';
import { debugLogger } from '../../utils/debugLogger.js';
/**
* An Adaptive Token Calculator that dynamically learns the true token cost of the user's
 * conversation by applying an Exponential Moving Average (EMA) update to
* real usage metadata returned from the Gemini API.
*
* It wraps the deterministic `StaticTokenCalculator` base heuristic to ensure
* immutable node cost caching while still surfacing a self-corrected estimate
* to the pipeline processors.
*/
export class AdaptiveTokenCalculator implements AdvancedTokenCalculator {
private learnedWeight = 1.0;
private readonly baseCalculator: StaticTokenCalculator;
constructor(
charsPerToken: number,
registry: NodeBehaviorRegistry,
eventBus: ContextEventBus,
) {
this.baseCalculator = new StaticTokenCalculator(charsPerToken, registry);
eventBus.onTokenGroundTruth((event: TokenGroundTruthEvent) => {
this.handleGroundTruth(event.actualTokens, event.promptBaseUnits);
});
}
private handleGroundTruth(actualTokens: number, promptBaseUnits: number) {
if (promptBaseUnits <= 0) return;
// Determine what ratio we should have used
const targetWeight = actualTokens / promptBaseUnits;
const oldWeight = this.learnedWeight;
// Apply Momentum (Learning Rate)
const learningRate = 0.2;
const newWeight =
oldWeight * (1 - learningRate) + targetWeight * learningRate;
// Clamp to reasonable safety bounds to prevent rogue metadata from poisoning the system
this.learnedWeight = Math.max(0.5, Math.min(newWeight, 2.0));
debugLogger.log(
`[AdaptiveTokenCalculator] Learned weight updated to ${this.learnedWeight.toFixed(3)} ` +
`(API Tokens: ${actualTokens}, Base Units: ${promptBaseUnits}, Target Ratio: ${targetWeight.toFixed(3)})`,
);
}
/**
* Retrieves the current learned weight multiplier.
*/
getLearnedWeight(): number {
return this.learnedWeight;
}
/**
* Returns the exact, unweighted Base Heuristic Units for the graph.
* This is used exactly once per interaction to capture the baseline sent to the API.
*/
getRawBaseUnits(nodes: readonly ConcreteNode[]): number {
return this.baseCalculator.calculateConcreteListTokens(nodes);
}
/**
* Returns the exact, unweighted Base Heuristic Units for a raw content chunk.
*/
getRawBaseUnitsForContent(content: Content): number {
return this.baseCalculator.calculateContentTokens(content);
}
calculateTokensAndBaseUnits(nodes: readonly ConcreteNode[]): {
tokens: number;
baseUnits: number;
} {
const baseUnits = this.baseCalculator.calculateConcreteListTokens(nodes);
return {
tokens: Math.round(baseUnits * this.learnedWeight),
baseUnits,
};
}
calculateContentTokensAndBaseUnits(content: Content): {
tokens: number;
baseUnits: number;
} {
const baseUnits = this.baseCalculator.calculateContentTokens(content);
return {
tokens: Math.round(baseUnits * this.learnedWeight),
baseUnits,
};
}
// --- Delegation and Weighting ---
garbageCollectCache(liveNodeIds: ReadonlySet<string>): void {
this.baseCalculator.garbageCollectCache(liveNodeIds);
}
cacheNodeTokens(node: ConcreteNode): number {
return this.baseCalculator.cacheNodeTokens(node);
}
calculateTokenBreakdown(nodes: readonly ConcreteNode[]): {
text: number;
media: number;
tool: number;
overhead: number;
total: number;
} {
const raw = this.baseCalculator.calculateTokenBreakdown(nodes);
return {
text: Math.round(raw.text * this.learnedWeight),
media: Math.round(raw.media * this.learnedWeight),
tool: Math.round(raw.tool * this.learnedWeight),
overhead: Math.round(raw.overhead * this.learnedWeight),
total: Math.round(raw.total * this.learnedWeight),
};
}
estimateTokensForParts(parts: Part[]): number {
const baseUnits = this.baseCalculator.estimateTokensForParts(parts);
return Math.round(baseUnits * this.learnedWeight);
}
getTokenCost(node: ConcreteNode): number {
const baseUnits = this.baseCalculator.getTokenCost(node);
return Math.round(baseUnits * this.learnedWeight);
}
calculateConcreteListTokens(nodes: readonly ConcreteNode[]): number {
const baseUnits = this.baseCalculator.calculateConcreteListTokens(nodes);
return Math.round(baseUnits * this.learnedWeight);
}
calculateContentTokens(content: Content): number {
const baseUnits = this.baseCalculator.calculateContentTokens(content);
return Math.round(baseUnits * this.learnedWeight);
}
estimateTokensForString(text: string): number {
const baseUnits = this.baseCalculator.estimateTokensForString(text);
return Math.round(baseUnits * this.learnedWeight);
}
tokensToChars(tokens: number): number {
// If weight is > 1.0 (we are inflating tokens), a single returned token is worth fewer chars.
// We reverse the math: convert requested tokens to target base units, then get chars.
return this.baseCalculator.tokensToChars(tokens / this.learnedWeight);
}
}
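
One subtlety worth a worked example: tokensToChars divides by the learned weight before delegating, so a fixed token budget buys fewer characters while estimates are being inflated. A sketch assuming the base calculator converts linearly at charsPerToken = 4 and the learned weight has risen to 2.0 (illustrative numbers):

// 10 adjusted tokens correspond to 10 / 2.0 = 5 base units,
// i.e. roughly 5 * 4 = 20 characters rather than 40.
const chars = calculator.tokensToChars(10); // ≈ 20 under these assumptions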
@@ -5,7 +5,7 @@
*/
import { describe, it, expect } from 'vitest';
import { ContextTokenCalculator } from './contextTokenCalculator.js';
import { StaticTokenCalculator } from './contextTokenCalculator.js';
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
import { createDummyNode } from '../testing/contextTestUtils.js';
@@ -16,7 +16,7 @@ describe('ContextTokenCalculator', () => {
const registry = new NodeBehaviorRegistry();
registerBuiltInBehaviors(registry);
const charsPerToken = 1; // Simplifies math for text nodes in tests
const calculator = new ContextTokenCalculator(charsPerToken, registry);
const calculator = new StaticTokenCalculator(charsPerToken, registry);
it('should include structural overhead for each unique turn', () => {
const turn1Id = 'turn-1';
@@ -28,16 +28,16 @@ describe('ContextTokenCalculator', () => {
const nodes = [node1, node2, node3];
// Estimated tokens (using 0.33 per ASCII char heuristic):
// node1: floor(17 chars * 0.33) = 5 tokens
// node2: floor(17 chars * 0.33) = 5 tokens
// node3: floor(19 chars * 0.33) = 6 tokens
// Estimated tokens (using charsPerToken = 1):
// node1: 17 chars / 1 = 17 tokens
// node2: 17 chars / 1 = 17 tokens
// node3: 19 chars / 1 = 19 tokens
// Turn 1 overhead: 5 tokens
// Turn 2 overhead: 5 tokens
// Total: 5 + 5 + 6 + 5 + 5 = 26
// Total: 17 + 17 + 19 + 5 + 5 = 63
const total = calculator.calculateConcreteListTokens(nodes);
expect(total).toBe(26);
expect(total).toBe(63);
});
it('should handle categorical breakdown with overhead', () => {
@@ -17,7 +17,41 @@ import type { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
* by the Gemini API. We use this as a baseline heuristic for inlineData/fileData.
*/
export class ContextTokenCalculator {
export interface ContextTokenCalculator {
estimateTokensForString(text: string): number;
tokensToChars(tokens: number): number;
garbageCollectCache(liveNodeIds: ReadonlySet<string>): void;
cacheNodeTokens(node: ConcreteNode): number;
getTokenCost(node: ConcreteNode): number;
calculateTokenBreakdown(nodes: readonly ConcreteNode[]): {
text: number;
media: number;
tool: number;
overhead: number;
total: number;
};
calculateConcreteListTokens(nodes: readonly ConcreteNode[]): number;
calculateContentTokens(content: Content): number;
estimateTokensForParts(parts: Part[]): number;
}
export interface AdvancedTokenCalculator extends ContextTokenCalculator {
getRawBaseUnits(nodes: readonly ConcreteNode[]): number;
getRawBaseUnitsForContent(content: Content): number;
calculateTokensAndBaseUnits(nodes: readonly ConcreteNode[]): {
tokens: number;
baseUnits: number;
};
calculateContentTokensAndBaseUnits(content: Content): {
tokens: number;
baseUnits: number;
};
}
/**
* A fast, deterministic token heuristic calculator.
*/
export class StaticTokenCalculator implements AdvancedTokenCalculator {
private readonly tokenCache = new Map<string, number>();
constructor(
@@ -143,6 +177,34 @@ export class ContextTokenCalculator {
return breakdown;
}
/**
* For the static calculator, Raw Base Units are exactly the same as the final tokens,
* because there is no dynamic learned weight (the multiplier is effectively 1.0).
*/
getRawBaseUnits(nodes: readonly ConcreteNode[]): number {
return this.calculateConcreteListTokens(nodes);
}
getRawBaseUnitsForContent(content: Content): number {
return this.calculateContentTokens(content);
}
calculateTokensAndBaseUnits(nodes: readonly ConcreteNode[]): {
tokens: number;
baseUnits: number;
} {
const baseUnits = this.calculateConcreteListTokens(nodes);
return { tokens: baseUnits, baseUnits };
}
calculateContentTokensAndBaseUnits(content: Content): {
tokens: number;
baseUnits: number;
} {
const baseUnits = this.calculateContentTokens(content);
return { tokens: baseUnits, baseUnits };
}
/**
* Fast calculation for a flat array of ConcreteNodes (The Nodes).
* It relies entirely on the O(1) sidecar token cache.
@@ -21,6 +21,9 @@ describe('SnapshotGenerator', () => {
llmClient: {
generateJson: mockGenerateJson,
},
advancedTokenCalculator: {
getRawBaseUnits: vi.fn().mockReturnValue(100),
},
tokenCalculator: {
estimateTokensForString: vi.fn().mockReturnValue(100),
},
+4
@@ -114,6 +114,7 @@ interface _CommonGenerateOptions {
export interface CountTokenOptions {
modelConfigKey?: ModelConfigKey;
contents: Content[];
abortSignal?: AbortSignal;
}
/**
@@ -240,6 +241,9 @@ export class BaseLlmClient {
const result = await this.contentGenerator.countTokens({
model,
contents: options.contents,
config: options.abortSignal
? { abortSignal: options.abortSignal }
: undefined,
});
return { totalTokens: result.totalTokens || 0 };
}
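
Callers can now cancel an in-flight token count, which matters for hot-start calibration on shutdown. A minimal usage sketch (contents value illustrative):

const controller = new AbortController();
const { totalTokens } = await llmClient.countTokens({
  contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
  abortSignal: controller.signal, // forwarded to the content generator config
});
// controller.abort() rejects the pending request; the hot-start path
// swallows that rejection and keeps the un-calibrated 1.0 weight.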
+4 -4
@@ -1517,8 +1517,8 @@ ${JSON.stringify(
// A string of length 404 is roughly 101 tokens.
const longText = 'a'.repeat(404);
const request: Part[] = [{ text: longText }];
// estimateTextOnlyLength counts only text content (400 chars), not JSON structure
const estimatedRequestTokenCount = Math.floor(longText.length * 0.33);
// estimateTextOnlyLength counts only text content (404 chars), not JSON structure
const estimatedRequestTokenCount = Math.floor(longText.length * 0.25);
const remainingTokenCount = MOCKED_TOKEN_LIMIT - lastPromptTokenCount;
// Mock tryCompressChat to not compress
@@ -1577,8 +1577,8 @@ ${JSON.stringify(
// We need a request > 100 tokens.
const longText = 'a'.repeat(404);
const request: Part[] = [{ text: longText }];
// estimateTextOnlyLength counts only text content (400 chars), not JSON structure
const estimatedRequestTokenCount = Math.floor(longText.length * 0.33);
// estimateTextOnlyLength counts only text content (404 chars), not JSON structure
const estimatedRequestTokenCount = Math.floor(longText.length * 0.25);
const remainingTokenCount = STICKY_MODEL_LIMIT - lastPromptTokenCount;
vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
+22 -3
@@ -637,11 +637,22 @@ export class GeminiClient {
// Check for context window overflow
const modelForLimitCheck = this._getActiveModelForCurrentTurn();
let currentBaseUnits = 0;
if (this.config.getContextManagementConfig().enabled) {
if (this.contextManager) {
const pendingRequest = createUserContent(request);
const { history: newHistory, didApplyManagement } =
await this.contextManager.renderHistory(pendingRequest);
const {
history: newHistory,
didApplyManagement,
baseUnits,
} = await this.contextManager.renderHistory(
pendingRequest,
undefined,
signal,
);
currentBaseUnits = baseUnits;
if (didApplyManagement) {
// If the manager pruned history, we update the chat before continuing.
@@ -800,8 +811,16 @@ export class GeminiClient {
}
yield event;
if (event.type === GeminiEventType.Finished && this.contextManager) {
const usageMetadata = event.value.usageMetadata;
if (usageMetadata && usageMetadata.promptTokenCount !== undefined) {
this.contextManager.getEnvironment().eventBus.emitTokenGroundTruth({
actualTokens: usageMetadata.promptTokenCount,
promptBaseUnits: currentBaseUnits,
});
}
}
this.updateTelemetryTokenCount();
if (event.type === GeminiEventType.Error) {
isError = true;
}
+1 -1
@@ -242,7 +242,7 @@ describe('GeminiChat', () => {
// 'Hello': 5 chars * 0.25 = 1.25
// 'Hi there': 8 chars * 0.25 = 2.0
// Total: 3.25 -> floor(3.25) = 3
expect(chatWithHistory.getLastPromptTokenCount()).toBe(4);
expect(chatWithHistory.getLastPromptTokenCount()).toBe(3);
});
it('should initialize lastPromptTokenCount for empty history', () => {
@@ -280,6 +280,26 @@ describe('tokenCalculation', () => {
expect(tokens).toBeLessThan(30);
});
it('should respect the user supplied charsPerToken argument', () => {
const text = 'abcdefghijkl'; // 12 chars
const parts: Part[] = [{ text }];
// Default (4 chars/token) -> 12 / 4 = 3 tokens
expect(estimateTokenCountSync(parts)).toBe(3);
// Override to 3 chars/token -> 12 / 3 = 4 tokens
expect(estimateTokenCountSync(parts, 0, 3)).toBe(4);
// Override to 2 chars/token -> 12 / 2 = 6 tokens
expect(estimateTokenCountSync(parts, 0, 2)).toBe(6);
// Verify massive strings also respect the argument
const massiveText = 'a'.repeat(120_000); // Exceeds 100k
const massiveParts: Part[] = [{ text: massiveText }];
expect(estimateTokenCountSync(massiveParts, 0, 4)).toBe(30_000);
expect(estimateTokenCountSync(massiveParts, 0, 3)).toBe(40_000);
});
it('should handle empty or nullish inputs gracefully', () => {
expect(estimateTokenCountSync([])).toBe(0);
expect(estimateTokenCountSync([{ text: '' }])).toBe(0);
+3 -1
@@ -43,10 +43,12 @@ function estimateTextTokens(text: string, charsPerToken: number): number {
}
let tokens = 0;
const asciiTokensPerChar = 1 / charsPerToken;
// Optimized loop: charCodeAt is faster than for...of on large strings
for (let i = 0; i < text.length; i++) {
if (text.charCodeAt(i) <= 127) {
tokens += ASCII_TOKENS_PER_CHAR;
tokens += asciiTokensPerChar;
} else {
tokens += NON_ASCII_TOKENS_PER_CHAR;
}