diff --git a/packages/core/src/context/contextManager.ts b/packages/core/src/context/contextManager.ts
index 41d93aa47f..82bc22aaa6 100644
--- a/packages/core/src/context/contextManager.ts
+++ b/packages/core/src/context/contextManager.ts
@@ -9,7 +9,7 @@ import type {
   AgentChatHistory,
   HistoryTurn,
 } from '../core/agentChatHistory.js';
-import { isToolExecution, type ConcreteNode } from './graph/types.js';
+import type { ConcreteNode } from './graph/types.js';
 import type { ContextEventBus } from './eventBus.js';
 import type { ContextTracer } from './tracer.js';
 import type { ContextEnvironment } from './pipeline/environment.js';
@@ -235,24 +235,7 @@ export class ContextManager {
       }
     }
 
-    // 2. Identify active tool calls that must NEVER be truncated
-    const calls = nodes.filter((n) => isToolExecution(n) && n.role === 'model');
-    const responses = new Set(
-      nodes
-        .filter((n) => isToolExecution(n) && n.role === 'user')
-        .map((n) => n.payload.functionResponse?.id)
-        .filter((id): id is string => !!id),
-    );
-
-    for (const call of calls) {
-      const id = call.payload.functionCall?.id;
-      // If we have a call but no response in the current graph, it's 'in flight'
-      if (id && !responses.has(id)) {
-        protectionMap.set(call.id, 'in_flight_tool_call');
-      }
-    }
-
-    // 3. Any externally requested protections
+    // 2. Any externally requested protections
     for (const id of extraProtectedIds) {
       protectionMap.set(id, 'external_active_task');
     }
diff --git a/packages/core/src/context/graph/toGraph.ts b/packages/core/src/context/graph/toGraph.ts
index 11b76e8d5b..c5dfc0738c 100644
--- a/packages/core/src/context/graph/toGraph.ts
+++ b/packages/core/src/context/graph/toGraph.ts
@@ -10,6 +10,7 @@ import { createHash } from 'node:crypto';
 import { debugLogger } from '../../utils/debugLogger.js';
 import type { NodeIdService } from './nodeIdService.js';
 import type { HistoryTurn } from '../../core/agentChatHistory.js';
+import { isSnapshotState } from '../utils/snapshotGenerator.js';
 
 // Global WeakMap to cache hashes for Part objects.
 // This optimizes getStableId by avoiding redundant stringify/hash operations
@@ -215,12 +216,16 @@ export class ContextGraphBuilder {
             ? `${apiId}_${turnSalt}_${partIdx}`
             : `${turnSalt}_${partIdx}`;
 
+        const isSnapshot = isTextPart(part) && isSnapshotState(part.text);
+
         const node: ConcreteNode = {
           id,
           timestamp: Date.now(),
           type: isFunctionResponsePart(part)
             ? NodeType.TOOL_EXECUTION
-            : NodeType.USER_PROMPT,
+            : isSnapshot
+              ? NodeType.SNAPSHOT
+              : NodeType.USER_PROMPT,
           role: 'user',
           payload: part,
           turnId,
diff --git a/packages/core/src/context/initializer.ts b/packages/core/src/context/initializer.ts
index 3916210bea..ac6208a78e 100644
--- a/packages/core/src/context/initializer.ts
+++ b/packages/core/src/context/initializer.ts
@@ -24,6 +24,7 @@ import { StateSnapshotAsyncProcessorOptionsSchema } from './processors/stateSnap
 import { RollingSummaryProcessorOptionsSchema } from './processors/rollingSummaryProcessor.js';
 import { getEnvironmentContext } from '../utils/environmentContext.js';
 import { AdaptiveTokenCalculator } from './utils/adaptiveTokenCalculator.js';
+import { estimateContextBreakdown } from '../core/loggingContentGenerator.js';
 import { NodeBehaviorRegistry } from './graph/behaviorRegistry.js';
 import { registerBuiltInBehaviors } from './graph/builtinBehaviors.js';
 
@@ -92,10 +93,28 @@ export async function initializeContextManager(
   const behaviorRegistry = new NodeBehaviorRegistry();
   registerBuiltInBehaviors(behaviorRegistry);
 
+  // Estimates the tokens taken by the system instruction and tool definitions;
+  // the calculator subtracts this from the API-reported prompt token count.
+  const getOverheadTokens = () => {
+    const breakdown = estimateContextBreakdown([], {
+      systemInstruction: {
+        role: 'system',
+        parts: [{ text: chat.getSystemInstruction() }],
+      },
+      tools: chat.getTools(),
+    });
+    return (
+      breakdown.system_instructions +
+      breakdown.tool_definitions +
+      breakdown.mcp_servers
+    );
+  };
+
   const calculator = new AdaptiveTokenCalculator(
     charsPerToken,
     behaviorRegistry,
     eventBus,
+    getOverheadTokens,
   );
 
   const env = new ContextEnvironmentImpl(
diff --git a/packages/core/src/context/utils/adaptiveTokenCalculator.test.ts b/packages/core/src/context/utils/adaptiveTokenCalculator.test.ts
index 6e89d1baca..f396f87d94 100644
--- a/packages/core/src/context/utils/adaptiveTokenCalculator.test.ts
+++ b/packages/core/src/context/utils/adaptiveTokenCalculator.test.ts
@@ -122,4 +122,29 @@ describe('AdaptiveTokenCalculator', () => {
 
     expect(calculator.getLearnedWeight()).toBe(1.0);
   });
+
+  it('should subtract overhead tokens from actual tokens when determining target weight', () => {
+    const eventBus = new ContextEventBus();
+    const getOverheadTokens = () => 40;
+    const calculator = new AdaptiveTokenCalculator(
+      charsPerToken,
+      registry,
+      eventBus,
+      getOverheadTokens,
+    );
+
+    // Initial state: weight = 1.0
+
+    // Simulate an event where the API reported 100 tokens, and our base units were 100
+    // But overhead is 40.
+    // actualGraphTokens = 100 - 40 = 60
+    // targetWeight = 60 / 100 = 0.6
+    // newWeight = 1.0 * 0.8 + 0.6 * 0.2 = 0.8 + 0.12 = 0.92
+    eventBus.emitTokenGroundTruth({
+      actualTokens: 100,
+      promptBaseUnits: 100,
+    });
+
+    expect(calculator.getLearnedWeight()).toBeCloseTo(0.92, 5);
+  });
 });
diff --git a/packages/core/src/context/utils/adaptiveTokenCalculator.ts b/packages/core/src/context/utils/adaptiveTokenCalculator.ts
index 2ac3825ef5..69be8a202f 100644
--- a/packages/core/src/context/utils/adaptiveTokenCalculator.ts
+++ b/packages/core/src/context/utils/adaptiveTokenCalculator.ts
@@ -31,6 +31,7 @@ export class AdaptiveTokenCalculator implements AdvancedTokenCalculator {
     charsPerToken: number,
     registry: NodeBehaviorRegistry,
     eventBus: ContextEventBus,
+    private readonly getOverheadTokens?: () => number,
   ) {
     this.baseCalculator = new StaticTokenCalculator(charsPerToken, registry);
     eventBus.onTokenGroundTruth((event: TokenGroundTruthEvent) => {
@@ -41,8 +42,18 @@ export class AdaptiveTokenCalculator implements AdvancedTokenCalculator {
   private handleGroundTruth(actualTokens: number, promptBaseUnits: number) {
     if (promptBaseUnits <= 0) return;
 
+    const overheadTokens = this.getOverheadTokens ? this.getOverheadTokens() : 0;
+
+    // The Gemini API token count includes the static overhead (system instruction + tools)
+    // and the dynamic chat history (which we measure as promptBaseUnits).
+    // We subtract the overhead so the adaptive calculator is comparing "apples to apples"
+    // when learning the weight multiplier for the graph nodes.
+    // NOTE(review): when overheadTokens >= actualTokens this clamps to 0, so
+    // targetWeight below is 0 for that sample - confirm that is the intended signal.
+    const actualGraphTokens = Math.max(0, actualTokens - overheadTokens);
+
     // Determine what ratio we should have used
-    const targetWeight = actualTokens / promptBaseUnits;
+    const targetWeight = actualGraphTokens / promptBaseUnits;
     const oldWeight = this.learnedWeight;
 
     // Apply Momentum (Learning Rate)
diff --git a/packages/core/src/context/utils/snapshotGenerator.ts b/packages/core/src/context/utils/snapshotGenerator.ts
index e715d7a336..3156f516c6 100644
--- a/packages/core/src/context/utils/snapshotGenerator.ts
+++ b/packages/core/src/context/utils/snapshotGenerator.ts
@@ -48,6 +48,33 @@ export interface SnapshotState {
   recent_arc: string[];
 }
 
+/**
+ * Heuristically detects whether `text` is a serialized {@link SnapshotState}.
+ *
+ * Returns true only when the trimmed text starts with `{`, ends with `}`,
+ * parses as JSON, passes `isRecord`, and has `active_tasks`,
+ * `discovered_facts`, `constraints_and_preferences` and `recent_arc` all set
+ * to arrays. Element types are not checked; unparsable text yields false.
+ */
+export function isSnapshotState(text: string): boolean {
+  const trimmed = text.trim();
+  if (!trimmed.startsWith('{') || !trimmed.endsWith('}')) {
+    return false;
+  }
+  try {
+    const parsed: unknown = JSON.parse(trimmed);
+    if (!isRecord(parsed)) return false;
+    return (
+      Array.isArray(parsed['active_tasks']) &&
+      Array.isArray(parsed['discovered_facts']) &&
+      Array.isArray(parsed['constraints_and_preferences']) &&
+      Array.isArray(parsed['recent_arc'])
+    );
+  } catch {
+    return false;
+  }
+}
+
 export interface BaselineSnapshotInfo {
   text: string;
   abstractsIds: string[];
diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts
index 67783a62d7..2aa813142b 100644
--- a/packages/core/src/core/geminiChat.ts
+++ b/packages/core/src/core/geminiChat.ts
@@ -329,6 +329,11 @@ export class GeminiChat {
     this.systemInstruction = sysInstr;
   }
 
+  /** Returns the system instruction currently set on this chat. */
+  getSystemInstruction(): string {
+    return this.systemInstruction;
+  }
+
   /**
    * Sends a message to the model and returns the response in chunks.
    *
@@ -1019,6 +1024,11 @@ export class GeminiChat {
     this.tools = tools;
   }
 
+  /** Returns the tools currently set on this chat (the stored array, not a copy). */
+  getTools(): Tool[] {
+    return this.tools;
+  }
+
   async maybeIncludeSchemaDepthContext(error: StructuredError): Promise<void> {
     // Check for potentially problematic cyclic tools with cyclic schemas
     // and include a recommendation to remove potentially problematic tools.