mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 13:22:35 -07:00
fix(context): implement loose boundary policy for gc backstop. (#26594)
This commit is contained in:
@@ -60,9 +60,7 @@ describe('ContextManager Sync Pressure Barrier Tests', () => {
|
||||
|
||||
// Verify Episode 0 (System) was pruned, so we now start with a sentinel due to role alternation
|
||||
expect(projection[0].role).toBe('user');
|
||||
expect(projection[0].parts![0].text).toBe(
|
||||
'[Continuing from previous AI thoughts...]',
|
||||
);
|
||||
expect(projection[0].parts![0].text).toContain('User turn 17');
|
||||
|
||||
// Filter out synthetic Yield nodes (they are model responses without actual tool/text bodies)
|
||||
const contentNodes = projection.filter(
|
||||
|
||||
@@ -72,7 +72,11 @@ export class ContextManager {
|
||||
event.targets,
|
||||
event.returnedNodes,
|
||||
);
|
||||
this.evaluateTriggers(new Set());
|
||||
// We explicitly DO NOT call evaluateTriggers here.
|
||||
// The Context Manager is a one-way assembly line. It only evaluates triggers
|
||||
// when fundamentally new organic context is added via PristineHistoryUpdated.
|
||||
// Re-evaluating after a processor finishes creates infinite feedback loops if
|
||||
// the processor fails to reduce the token count below the threshold.
|
||||
});
|
||||
|
||||
this.historyObserver.start();
|
||||
@@ -126,10 +130,15 @@ export class ContextManager {
|
||||
// Walk backwards finding nodes that fall out of the retained budget
|
||||
for (let i = this.buffer.nodes.length - 1; i >= 0; i--) {
|
||||
const node = this.buffer.nodes[i];
|
||||
const priorTokens = rollingTokens;
|
||||
rollingTokens += this.env.tokenCalculator.calculateConcreteListTokens([
|
||||
node,
|
||||
]);
|
||||
if (rollingTokens > this.sidecar.config.budget.retainedTokens) {
|
||||
|
||||
// Loose Boundary Policy: If this node is the one that pushes us over the retained limit,
|
||||
// we KEEP it to prevent aggressive undershooting. We only age out nodes that are
|
||||
// strictly *older* than the boundary node.
|
||||
if (priorTokens > this.sidecar.config.budget.retainedTokens) {
|
||||
// Only age out if not protected
|
||||
if (!protectedIds.has(node.id)) {
|
||||
agedOutNodes.add(node.id);
|
||||
|
||||
@@ -61,4 +61,169 @@ describe('render', () => {
|
||||
|
||||
expect(result.history).toEqual([{ text: '1' }, { text: '2' }]);
|
||||
});
|
||||
|
||||
it('simulates the boundary knapsack problem (loose boundary policy)', async () => {
|
||||
// Token sizes, newest to oldest: A = 10k, B = 20k, C = 40k, D = 5k (array order below is oldest first: D, C, B, A)
|
||||
const mockNodes: ConcreteNode[] = [
|
||||
{
|
||||
id: 'D',
|
||||
type: NodeType.USER_PROMPT,
|
||||
payload: {} as Part,
|
||||
} as unknown as ConcreteNode,
|
||||
{
|
||||
id: 'C',
|
||||
type: NodeType.AGENT_THOUGHT,
|
||||
payload: {} as Part,
|
||||
} as unknown as ConcreteNode,
|
||||
{
|
||||
id: 'B',
|
||||
type: NodeType.USER_PROMPT,
|
||||
payload: {} as Part,
|
||||
} as unknown as ConcreteNode,
|
||||
{
|
||||
id: 'A',
|
||||
type: NodeType.AGENT_THOUGHT,
|
||||
payload: {} as Part,
|
||||
} as unknown as ConcreteNode,
|
||||
];
|
||||
|
||||
const tokenMap: Record<string, number> = {
|
||||
D: 5000,
|
||||
C: 40000,
|
||||
B: 20000,
|
||||
A: 10000,
|
||||
};
|
||||
|
||||
const orchestrator = {
|
||||
executeTriggerSync: vi.fn(async (trigger, nodes, agedOutNodes) =>
|
||||
nodes.filter((n: ConcreteNode) => !agedOutNodes.has(n.id)),
|
||||
),
|
||||
} as unknown as PipelineOrchestrator;
|
||||
|
||||
const sidecar = {
|
||||
config: {
|
||||
budget: { maxTokens: 150000, retainedTokens: 65000 },
|
||||
},
|
||||
} as unknown as ContextProfile;
|
||||
|
||||
const currentTokens = 160000;
|
||||
|
||||
const env = {
|
||||
llmClient: {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1000 }),
|
||||
},
|
||||
tokenCalculator: {
|
||||
calculateConcreteListTokens: vi.fn((nodes) => {
|
||||
if (nodes.length === 1) return tokenMap[nodes[0].id];
|
||||
return currentTokens;
|
||||
}),
|
||||
calculateTokenBreakdown: vi.fn(() => ({})),
|
||||
},
|
||||
graphMapper: {
|
||||
fromGraph: vi.fn((nodes: readonly ConcreteNode[]) =>
|
||||
nodes.map((n) => ({ text: n.id })),
|
||||
),
|
||||
},
|
||||
} as unknown as ContextEnvironment;
|
||||
|
||||
const tracer = {
|
||||
logEvent: vi.fn(),
|
||||
} as unknown as ContextTracer;
|
||||
|
||||
const result = await render(
|
||||
mockNodes,
|
||||
orchestrator,
|
||||
sidecar,
|
||||
tracer,
|
||||
env,
|
||||
new Map(),
|
||||
0,
|
||||
new Set(),
|
||||
);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const surviving = result.history.map((c: any) => c.text);
|
||||
// Loose Boundary: A (10k), B (20k), C (40k). Total = 70k.
|
||||
// Adding C pushes rolling total (70k) above retainedTokens (65k).
|
||||
// Under loose policy, C survives. D is strictly older and drops.
|
||||
expect(surviving).toEqual(['C', 'B', 'A']); // D is dropped
|
||||
});
|
||||
|
||||
it('drops nodes that are STRICTLY older than the boundary node', async () => {
|
||||
const mockNodes: ConcreteNode[] = [
|
||||
{
|
||||
id: 'A',
|
||||
type: NodeType.USER_PROMPT,
|
||||
payload: {} as Part,
|
||||
} as unknown as ConcreteNode,
|
||||
{
|
||||
id: 'B',
|
||||
type: NodeType.AGENT_THOUGHT,
|
||||
payload: {} as Part,
|
||||
} as unknown as ConcreteNode,
|
||||
{
|
||||
id: 'C',
|
||||
type: NodeType.USER_PROMPT,
|
||||
payload: {} as Part,
|
||||
} as unknown as ConcreteNode,
|
||||
];
|
||||
|
||||
const tokenMap: Record<string, number> = {
|
||||
C: 40000,
|
||||
B: 40000,
|
||||
A: 10000,
|
||||
};
|
||||
|
||||
const orchestrator = {
|
||||
executeTriggerSync: vi.fn(async (trigger, nodes, agedOutNodes) =>
|
||||
nodes.filter((n: ConcreteNode) => !agedOutNodes.has(n.id)),
|
||||
),
|
||||
} as unknown as PipelineOrchestrator;
|
||||
|
||||
const sidecar = {
|
||||
config: {
|
||||
budget: { maxTokens: 150000, retainedTokens: 65000 },
|
||||
},
|
||||
} as unknown as ContextProfile;
|
||||
|
||||
const currentTokens = 160000;
|
||||
|
||||
const env = {
|
||||
llmClient: {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1000 }),
|
||||
},
|
||||
tokenCalculator: {
|
||||
calculateConcreteListTokens: vi.fn((nodes) => {
|
||||
if (nodes.length === 1) return tokenMap[nodes[0].id];
|
||||
return currentTokens;
|
||||
}),
|
||||
calculateTokenBreakdown: vi.fn(() => ({})),
|
||||
},
|
||||
graphMapper: {
|
||||
fromGraph: vi.fn((nodes: readonly ConcreteNode[]) =>
|
||||
nodes.map((n) => ({ text: n.id })),
|
||||
),
|
||||
},
|
||||
} as unknown as ContextEnvironment;
|
||||
|
||||
const tracer = {
|
||||
logEvent: vi.fn(),
|
||||
} as unknown as ContextTracer;
|
||||
|
||||
const result = await render(
|
||||
mockNodes,
|
||||
orchestrator,
|
||||
sidecar,
|
||||
tracer,
|
||||
env,
|
||||
new Map(),
|
||||
0,
|
||||
new Set(),
|
||||
);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const surviving = result.history.map((c: any) => c.text);
|
||||
// C(40k), B(40k). Adding B pushes total to 80k. B is the boundary node and survives. A drops.
|
||||
expect(surviving).toEqual(['B', 'C']); // A is dropped
|
||||
});
|
||||
});
|
||||
|
||||
@@ -10,6 +10,7 @@ import type { ContextTracer } from '../tracer.js';
|
||||
import type { ContextProfile } from '../config/profiles.js';
|
||||
import type { PipelineOrchestrator } from '../pipeline/orchestrator.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import { performCalibration } from '../utils/tokenCalibration.js';
|
||||
|
||||
/**
|
||||
* Maps the Episodic Context Graph back into a raw Gemini Content[] array for transmission.
|
||||
@@ -68,6 +69,7 @@ export async function render(
|
||||
tracer.logEvent('Render', 'Render Context for LLM', {
|
||||
renderedContext: contents,
|
||||
});
|
||||
performCalibration(env, visibleNodes, contents);
|
||||
return { history: contents, didApplyManagement: false };
|
||||
}
|
||||
const targetDelta = currentTokens - sidecar.config.budget.retainedTokens;
|
||||
@@ -83,9 +85,12 @@ export async function render(
|
||||
// Start from newest and count backwards
|
||||
for (let i = nodes.length - 1; i >= 0; i--) {
|
||||
const node = nodes[i];
|
||||
const priorTokens = rollingTokens;
|
||||
const nodeTokens = env.tokenCalculator.calculateConcreteListTokens([node]);
|
||||
rollingTokens += nodeTokens;
|
||||
if (rollingTokens > sidecar.config.budget.retainedTokens) {
|
||||
|
||||
// Loose Boundary Policy: Keep the node that crosses the boundary
|
||||
if (priorTokens > sidecar.config.budget.retainedTokens) {
|
||||
agedOutNodes.add(node.id);
|
||||
}
|
||||
}
|
||||
@@ -113,5 +118,6 @@ export async function render(
|
||||
tracer.logEvent('Render', 'Render Sanitized Context for LLM', {
|
||||
renderedContextSanitized: contents,
|
||||
});
|
||||
performCalibration(env, visibleNodes, contents);
|
||||
return { history: contents, didApplyManagement: true };
|
||||
}
|
||||
|
||||
@@ -94,6 +94,10 @@ export async function initializeContextManager(
|
||||
tracer,
|
||||
4,
|
||||
eventBus,
|
||||
{
|
||||
calibrateTokenCalculation:
|
||||
!!process.env['GEMINI_CONTEXT_CALIBRATE_TOKEN_CALCULATIONS'],
|
||||
},
|
||||
);
|
||||
|
||||
const orchestrator = new PipelineOrchestrator(
|
||||
|
||||
@@ -13,6 +13,10 @@ import type { ContextGraphMapper } from '../graph/mapper.js';
|
||||
|
||||
export type { ContextTracer, ContextEventBus };
|
||||
|
||||
/**
 * Optional switches that tune how the context is rendered.
 */
export interface RenderOptions {
  /**
   * When truthy, `performCalibration` compares the estimated token count of the
   * rendered context against the count returned by the LLM client's
   * `countTokens` call. When omitted or false, calibration returns immediately.
   */
  calibrateTokenCalculation?: boolean;
}
|
||||
|
||||
export interface ContextEnvironment {
|
||||
readonly llmClient: BaseLlmClient;
|
||||
readonly promptId: string;
|
||||
@@ -26,4 +30,5 @@ export interface ContextEnvironment {
|
||||
readonly inbox: LiveInbox;
|
||||
readonly behaviorRegistry: NodeBehaviorRegistry;
|
||||
readonly graphMapper: ContextGraphMapper;
|
||||
readonly renderOptions?: RenderOptions;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type { ContextTracer } from '../tracer.js';
|
||||
import type { ContextEnvironment } from './environment.js';
|
||||
import type { ContextEnvironment, RenderOptions } from './environment.js';
|
||||
import type { ContextEventBus } from '../eventBus.js';
|
||||
import { ContextTokenCalculator } from '../utils/contextTokenCalculator.js';
|
||||
import { LiveInbox } from './inbox.js';
|
||||
@@ -29,6 +29,7 @@ export class ContextEnvironmentImpl implements ContextEnvironment {
|
||||
readonly tracer: ContextTracer,
|
||||
readonly charsPerToken: number,
|
||||
readonly eventBus: ContextEventBus,
|
||||
readonly renderOptions?: RenderOptions,
|
||||
) {
|
||||
this.behaviorRegistry = new NodeBehaviorRegistry();
|
||||
registerBuiltInBehaviors(this.behaviorRegistry);
|
||||
|
||||
+193
-23
File diff suppressed because one or more lines are too long
@@ -9,11 +9,7 @@ import fs from 'node:fs';
|
||||
import { SimulationHarness } from './simulationHarness.js';
|
||||
import { createMockLlmClient } from '../testing/contextTestUtils.js';
|
||||
import type { ContextProfile } from '../config/profiles.js';
|
||||
import { createToolMaskingProcessor } from '../processors/toolMaskingProcessor.js';
|
||||
import { createBlobDegradationProcessor } from '../processors/blobDegradationProcessor.js';
|
||||
import { createStateSnapshotProcessor } from '../processors/stateSnapshotProcessor.js';
|
||||
import { createHistoryTruncationProcessor } from '../processors/historyTruncationProcessor.js';
|
||||
import { createStateSnapshotAsyncProcessor } from '../processors/stateSnapshotAsyncProcessor.js';
|
||||
import { stressTestProfile } from '../config/profiles.js';
|
||||
|
||||
expect.addSnapshotSerializer({
|
||||
test: (val) =>
|
||||
@@ -52,57 +48,22 @@ describe('System Lifecycle Golden Tests', () => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
const getAggressiveConfig = (): ContextProfile => ({
|
||||
name: 'Aggressive Test',
|
||||
config: {
|
||||
budget: { maxTokens: 1000, retainedTokens: 500 }, // Extremely tight limits
|
||||
},
|
||||
buildPipelines: (env) => [
|
||||
{
|
||||
name: 'Pressure Relief', // Emits from eventBus 'retained_exceeded'
|
||||
triggers: ['retained_exceeded'],
|
||||
processors: [
|
||||
createBlobDegradationProcessor('BlobDegradationProcessor', env),
|
||||
createToolMaskingProcessor('ToolMaskingProcessor', env, {
|
||||
stringLengthThresholdTokens: 50,
|
||||
}),
|
||||
createStateSnapshotProcessor('StateSnapshotProcessor', env, {}),
|
||||
],
|
||||
},
|
||||
{
|
||||
name: 'Immediate Sanitization', // The magic string the projector is hardcoded to use
|
||||
triggers: ['retained_exceeded'],
|
||||
processors: [
|
||||
createHistoryTruncationProcessor(
|
||||
'HistoryTruncationProcessor',
|
||||
env,
|
||||
{},
|
||||
),
|
||||
],
|
||||
},
|
||||
],
|
||||
buildAsyncPipelines: (env) => [
|
||||
{
|
||||
name: 'Async',
|
||||
triggers: ['nodes_aged_out'],
|
||||
processors: [
|
||||
createStateSnapshotAsyncProcessor(
|
||||
'StateSnapshotAsyncProcessor',
|
||||
env,
|
||||
{},
|
||||
),
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const mockLlmClient = createMockLlmClient([
|
||||
'<MOCKED_STATE_SNAPSHOT_SUMMARY>',
|
||||
]);
|
||||
// Uses dynamic role-based mocking to differentiate Snapshot vs Distillation output automatically.
|
||||
const mockLlmClient = createMockLlmClient();
|
||||
|
||||
it('Scenario 1: Organic Growth with Huge Tool Output & Images', async () => {
|
||||
// Override stressTestProfile limits slightly to ensure immediate overflow
|
||||
// without having to push 50,000 characters to cross the generalist boundaries.
|
||||
const customProfile: ContextProfile = {
|
||||
...stressTestProfile,
|
||||
config: {
|
||||
...stressTestProfile.config,
|
||||
budget: { maxTokens: 1000, retainedTokens: 500 },
|
||||
},
|
||||
};
|
||||
|
||||
const harness = await SimulationHarness.create(
|
||||
getAggressiveConfig(),
|
||||
customProfile,
|
||||
mockLlmClient,
|
||||
);
|
||||
|
||||
@@ -169,6 +130,9 @@ describe('System Lifecycle Golden Tests', () => {
|
||||
{ role: 'model', parts: [{ text: 'Yes we can.' }] },
|
||||
]);
|
||||
|
||||
// Give the background tasks a moment to inject the snapshot into the graph
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Get final state
|
||||
const goldenState = await harness.getGoldenState();
|
||||
|
||||
@@ -212,54 +176,117 @@ describe('System Lifecycle Golden Tests', () => {
|
||||
expect(goldenState).toMatchSnapshot();
|
||||
});
|
||||
|
||||
it('Scenario 3: Async-Driven Background GC', async () => {
|
||||
const gcConfig: ContextProfile = {
|
||||
name: 'GC Test Config',
|
||||
it('Scenario 3: Node Distillation of Large Historical Messages', async () => {
|
||||
// 1 Turn = ~2520 tokens.
|
||||
// retainedTokens = 4000 ensures Turn 0 is kept intact until Turn 1 pushes the total to ~5040.
|
||||
const customProfile: ContextProfile = {
|
||||
...stressTestProfile,
|
||||
config: {
|
||||
budget: { maxTokens: 200, retainedTokens: 100 },
|
||||
},
|
||||
buildPipelines: () => [],
|
||||
buildAsyncPipelines: (env) => [
|
||||
{
|
||||
name: 'Async',
|
||||
triggers: ['nodes_aged_out'],
|
||||
processors: [
|
||||
createStateSnapshotAsyncProcessor(
|
||||
'StateSnapshotAsyncProcessor',
|
||||
env,
|
||||
{},
|
||||
),
|
||||
],
|
||||
...stressTestProfile.config,
|
||||
budget: { maxTokens: 10000, retainedTokens: 4000 },
|
||||
processorOptions: {
|
||||
...stressTestProfile.config?.processorOptions,
|
||||
NodeDistillation: {
|
||||
type: 'NodeDistillationProcessor',
|
||||
options: {
|
||||
nodeThresholdTokens: 1000, // 1250 > 1000, so older messages will be distilled
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
// Disable async pipelines (StateSnapshots) so they don't compete with the Normalization pipeline
|
||||
buildAsyncPipelines: () => [],
|
||||
};
|
||||
|
||||
const harness = await SimulationHarness.create(gcConfig, mockLlmClient);
|
||||
const harness = await SimulationHarness.create(
|
||||
customProfile,
|
||||
mockLlmClient,
|
||||
);
|
||||
|
||||
// Turn 0
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'A'.repeat(50) }] },
|
||||
{ role: 'model', parts: [{ text: 'B'.repeat(50) }] },
|
||||
{ role: 'user', parts: [{ text: 'A'.repeat(5000) }] },
|
||||
{ role: 'model', parts: [{ text: 'B'.repeat(5000) }] },
|
||||
]);
|
||||
|
||||
// Turn 1 (Should trigger StateSnapshotasync pipeline because we exceed 100 retainedTokens)
|
||||
// Turn 1
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'C'.repeat(50) }] },
|
||||
{ role: 'model', parts: [{ text: 'D'.repeat(50) }] },
|
||||
{ role: 'user', parts: [{ text: 'C'.repeat(5000) }] },
|
||||
{ role: 'model', parts: [{ text: 'D'.repeat(5000) }] },
|
||||
]);
|
||||
|
||||
// Give the async background pipeline an extra beat to complete its async execution and emit variants
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Turn 2
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: 'E'.repeat(50) }] },
|
||||
{ role: 'model', parts: [{ text: 'F'.repeat(50) }] },
|
||||
{ role: 'user', parts: [{ text: 'E'.repeat(5000) }] },
|
||||
{ role: 'model', parts: [{ text: 'F'.repeat(5000) }] },
|
||||
]);
|
||||
|
||||
const goldenState = await harness.getGoldenState();
|
||||
|
||||
// We should see ROLLING_SUMMARY nodes injected into the graph, proving the async pipeline ran in the background
|
||||
// We should see MOCKED_DISTILLED_NODE replacing older bloated messages, while recent messages are untouched.
|
||||
expect(goldenState).toMatchSnapshot();
|
||||
});
|
||||
|
||||
it('Scenario 4: Async-Driven Background GC via State Snapshots', async () => {
|
||||
// Mathematical Token Budgeting:
|
||||
// 200 chars ≈ 50 tokens.
|
||||
// 1 Turn (User + Model + Overhead) ≈ 50 + 50 + 20 = 120 Tokens.
|
||||
const customProfile: ContextProfile = {
|
||||
...stressTestProfile,
|
||||
config: {
|
||||
...stressTestProfile.config,
|
||||
// Retain 3 Turns (~360 tokens). Max 5 Turns (~600 tokens).
|
||||
budget: { maxTokens: 600, retainedTokens: 360 },
|
||||
},
|
||||
};
|
||||
|
||||
const harness = await SimulationHarness.create(
|
||||
customProfile,
|
||||
mockLlmClient,
|
||||
);
|
||||
|
||||
const createMessage = (index: number) =>
|
||||
`Msg ${index} `.repeat(25).padEnd(200, '.');
|
||||
|
||||
// Turn 0 (~120 tokens) Total: 120
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: createMessage(0) }] },
|
||||
{ role: 'model', parts: [{ text: createMessage(1) }] },
|
||||
]);
|
||||
|
||||
// Turn 1 (~120 tokens) Total: 240
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: createMessage(2) }] },
|
||||
{ role: 'model', parts: [{ text: createMessage(3) }] },
|
||||
]);
|
||||
|
||||
// Turn 2 (~120 tokens) Total: 360 (At retainedTokens boundary)
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: createMessage(4) }] },
|
||||
{ role: 'model', parts: [{ text: createMessage(5) }] },
|
||||
]);
|
||||
|
||||
// Turn 3 (~120 tokens) Total: 480 (Exceeds retainedTokens! Triggers GC on Turn 0 & 1)
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: createMessage(6) }] },
|
||||
{ role: 'model', parts: [{ text: createMessage(7) }] },
|
||||
]);
|
||||
|
||||
// Give the async background snapshot pipeline time to complete
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Turn 4 (~120 tokens).
|
||||
// If GC succeeded, Turn 0 and 1 are now a ~10 token snapshot.
|
||||
// Total should be: 10 (Snapshot) + 120 (Turn 2) + 120 (Turn 3) + 120 (Turn 4) = ~370 tokens.
|
||||
await harness.simulateTurn([
|
||||
{ role: 'user', parts: [{ text: createMessage(8) }] },
|
||||
{ role: 'model', parts: [{ text: createMessage(9) }] },
|
||||
]);
|
||||
|
||||
const goldenState = await harness.getGoldenState();
|
||||
|
||||
// We should see a MOCKED_STATE_SNAPSHOT_SUMMARY rolling up Turns 0 and 1,
|
||||
// while Turns 2, 3, and 4 remain fully intact.
|
||||
expect(goldenState).toMatchSnapshot();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -19,7 +19,10 @@ import {
|
||||
} from '../graph/types.js';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type {
|
||||
BaseLlmClient,
|
||||
GenerateContentOptions,
|
||||
} from '../../core/baseLlmClient.js';
|
||||
import type { Content, GenerateContentResponse } from '@google/genai';
|
||||
import { InboxSnapshotImpl } from '../pipeline/inbox.js';
|
||||
import type { InboxMessage, ProcessArgs } from '../pipeline.js';
|
||||
@@ -98,38 +101,38 @@ export function createDummyToolNode(
|
||||
|
||||
export interface MockLlmClient extends BaseLlmClient {
|
||||
generateContent: Mock;
|
||||
countTokens: Mock;
|
||||
}
|
||||
|
||||
export function createMockLlmClient(
|
||||
responses?: Array<string | GenerateContentResponse>,
|
||||
): MockLlmClient {
|
||||
const generateContentMock = vi.fn();
|
||||
|
||||
if (responses && responses.length > 0) {
|
||||
for (const response of responses) {
|
||||
if (typeof response === 'string') {
|
||||
generateContentMock.mockResolvedValueOnce(
|
||||
createMockGenerateContentResponse(response),
|
||||
const generateContentMock = vi
|
||||
.fn()
|
||||
.mockImplementation((options: GenerateContentOptions) => {
|
||||
// Array-based logic for backwards compatibility, if provided
|
||||
if (responses && responses.length > 0) {
|
||||
const callCount = generateContentMock.mock.calls.length - 1;
|
||||
const idx =
|
||||
callCount < responses.length ? callCount : responses.length - 1;
|
||||
const res = responses[idx];
|
||||
return Promise.resolve(
|
||||
typeof res === 'string'
|
||||
? createMockGenerateContentResponse(res)
|
||||
: res,
|
||||
);
|
||||
} else {
|
||||
generateContentMock.mockResolvedValueOnce(response);
|
||||
}
|
||||
}
|
||||
// Fallback to the last response for any subsequent calls
|
||||
const lastResponse = responses[responses.length - 1];
|
||||
if (typeof lastResponse === 'string') {
|
||||
generateContentMock.mockResolvedValue(
|
||||
createMockGenerateContentResponse(lastResponse),
|
||||
|
||||
const lastContent = options.contents[options.contents.length - 1];
|
||||
const lastPart = lastContent?.parts?.[lastContent.parts.length - 1];
|
||||
const lastPartString = JSON.stringify(lastPart ?? {});
|
||||
const contentSample = `${lastPartString.slice(0, 10)}...${lastPartString.slice(-10)}`;
|
||||
return Promise.resolve(
|
||||
createMockGenerateContentResponse(
|
||||
`Mock response from: ${options.role}, for: ${contentSample}`,
|
||||
),
|
||||
);
|
||||
} else {
|
||||
generateContentMock.mockResolvedValue(lastResponse);
|
||||
}
|
||||
} else {
|
||||
// Default fallback
|
||||
generateContentMock.mockResolvedValue(
|
||||
createMockGenerateContentResponse('Mock LLM response'),
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
return {
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
import type { Content } from '@google/genai';
|
||||
import type { ContextEnvironment } from '../pipeline/environment.js';
|
||||
import type { ConcreteNode } from '../graph/types.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
|
||||
/**
 * Largest relative gap between the exact and the estimated token count that is
 * still considered acceptable (20% of the larger of the two values).
 */
const CALIBRATION_TOLERANCE_RATIO = 0.2;

/**
 * Compares the token calculator's estimate for the rendered context with the
 * count returned by the LLM client's `countTokens`, and records the outcome.
 *
 * Does nothing unless `env.renderOptions.calibrateTokenCalculation` is truthy.
 * When enabled, the measurement runs in a detached async task, so this function
 * returns immediately; any error raised inside that task is swallowed.
 *
 * @param env Supplies the LLM client, token calculator, tracer and render options.
 * @param finalNodes Nodes passed to the token calculator for the estimate.
 * @param finalContents Rendered contents passed to `countTokens` for the exact count.
 */
export function performCalibration(
  env: ContextEnvironment,
  finalNodes: readonly ConcreteNode[],
  finalContents: Content[],
): void {
  if (!env.renderOptions?.calibrateTokenCalculation) {
    return;
  }

  // Fire-and-forget: the caller must not wait on (or fail because of) calibration.
  void (async () => {
    try {
      const { totalTokens: exactTokens } = await env.llmClient.countTokens({
        contents: finalContents,
      });

      // Without a usable exact count there is nothing to calibrate against.
      // Treating it as 0 would report a bogus "large deviation" for every
      // non-empty context, so the measurement is skipped instead.
      if (typeof exactTokens !== 'number' || !Number.isFinite(exactTokens)) {
        return;
      }

      const estimatedTokens =
        env.tokenCalculator.calculateConcreteListTokens(finalNodes);

      const delta = Math.abs(exactTokens - estimatedTokens);
      const tolerance =
        Math.max(exactTokens, estimatedTokens) * CALIBRATION_TOLERANCE_RATIO;
      const isWithinTolerance = delta <= tolerance;

      env.tracer.logEvent('Render', 'Token Calibration Measurement', {
        exactTokens,
        estimatedTokens,
        delta,
        isWithinTolerance,
      });

      if (!isWithinTolerance) {
        debugLogger.error(
          `[Token Calibration] Large deviation detected: exact ${exactTokens} vs estimated ${estimatedTokens} (delta: ${delta})`,
        );
      }
    } catch {
      // Deliberate best-effort: API failures during background calibration are ignored.
    }
  })();
}
|
||||
@@ -111,6 +111,11 @@ interface _CommonGenerateOptions {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Arguments accepted by `BaseLlmClient.countTokens`.
 */
export interface CountTokenOptions {
  /**
   * Key used to resolve the model whose tokenizer should be used. When omitted,
   * `countTokens` falls back to the config's active model.
   */
  modelConfigKey?: ModelConfigKey;
  /** The contents whose tokens are counted. */
  contents: Content[];
}
|
||||
|
||||
/**
|
||||
* A client dedicated to stateless, utility-focused LLM calls.
|
||||
*/
|
||||
@@ -225,6 +230,20 @@ export class BaseLlmClient {
|
||||
return text;
|
||||
}
|
||||
|
||||
/**
 * Counts the tokens in `options.contents` via the content generator.
 *
 * The model is the one resolved from `options.modelConfigKey` when a key is
 * supplied; otherwise the config's active model is used.
 *
 * @param options Contents to measure plus an optional model config key.
 * @returns The total reported by the content generator, or 0 when the
 *   response carries no (or a falsy) total.
 */
async countTokens(
  options: CountTokenOptions,
): Promise<{ totalTokens: number }> {
  const { modelConfigKey, contents } = options;

  const model = modelConfigKey
    ? this.config.modelConfigService.getResolvedConfig(modelConfigKey).model
    : this.config.getActiveModel();

  const result = await this.contentGenerator.countTokens({ model, contents });

  // `||` rather than `??` is kept on purpose: any falsy total collapses to 0.
  return { totalTokens: result.totalTokens || 0 };
}
|
||||
|
||||
async generateContent(
|
||||
options: GenerateContentOptions,
|
||||
): Promise<GenerateContentResponse> {
|
||||
|
||||
Reference in New Issue
Block a user