Fix bulk of remaining issues with generalist profile (#26073)

This commit is contained in:
joshualitt
2026-05-01 15:04:39 -07:00
committed by GitHub
parent 408afd3c5a
commit de8fdcfa16
52 changed files with 2133 additions and 1364 deletions
@@ -0,0 +1,69 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { ContextTokenCalculator } from './contextTokenCalculator.js';
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
import { createDummyNode } from '../testing/contextTestUtils.js';
import { MSG_OVERHEAD_TOKENS } from '../../utils/tokenCalculation.js';
import { NodeType } from '../graph/types.js';
/**
 * Tests for ContextTokenCalculator's structural (per-turn) overhead handling,
 * covering both the flat total and the categorical breakdown.
 *
 * Expectations are derived from `getTokenCost` and `MSG_OVERHEAD_TOKENS`
 * rather than hard-coded numbers, so the suite checks the overhead rule itself
 * and does not break when the estimation heuristic, the dummy node text, or
 * the overhead constant changes.
 */
describe('ContextTokenCalculator', () => {
  const registry = new NodeBehaviorRegistry();
  registerBuiltInBehaviors(registry);
  const charsPerToken = 1; // Simplifies math for text nodes in tests
  const calculator = new ContextTokenCalculator(charsPerToken, registry);

  it('should include structural overhead for each unique turn', () => {
    const turn1Id = 'turn-1';
    const turn2Id = 'turn-2';
    const node1 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
    const node2 = createDummyNode(turn1Id, NodeType.USER_PROMPT); // Same turn
    const node3 = createDummyNode(turn2Id, NodeType.AGENT_THOUGHT); // Different turn
    const nodes = [node1, node2, node3];

    const total = calculator.calculateConcreteListTokens(nodes);

    // Every node contributes its own cost; overhead is charged once per
    // distinct turn (two turns here), not once per node (three nodes).
    const distinctTurns = 2;
    const expected =
      calculator.getTokenCost(node1) +
      calculator.getTokenCost(node2) +
      calculator.getTokenCost(node3) +
      distinctTurns * MSG_OVERHEAD_TOKENS;
    expect(total).toBe(expected);
  });

  it('should handle categorical breakdown with overhead', () => {
    const turn1Id = 'turn-1';
    const node = createDummyNode(turn1Id, NodeType.USER_PROMPT);
    const breakdown = calculator.calculateTokenBreakdown([node]);
    // A single node in a single turn carries exactly one unit of overhead.
    expect(breakdown.overhead).toBe(MSG_OVERHEAD_TOKENS);
    expect(breakdown.total).toBe(
      calculator.getTokenCost(node) + MSG_OVERHEAD_TOKENS,
    );
  });

  it('should not double-count overhead for duplicate turn IDs in separate nodes', () => {
    const turn1Id = 'turn-1';
    const node1 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
    const node2 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
    const total = calculator.calculateConcreteListTokens([node1, node2]);
    // cost(node1) + cost(node2) + 1 * overhead
    const expected =
      calculator.getTokenCost(node1) +
      calculator.getTokenCost(node2) +
      MSG_OVERHEAD_TOKENS;
    expect(total).toBe(expected);
  });
});
@@ -4,8 +4,11 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { Part } from '@google/genai';
import { estimateTokenCountSync } from '../../utils/tokenCalculation.js';
import type { Part, Content } from '@google/genai';
import {
estimateTokenCountSync,
MSG_OVERHEAD_TOKENS,
} from '../../utils/tokenCalculation.js';
import type { ConcreteNode } from '../graph/types.js';
import type { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
@@ -73,18 +76,107 @@ export class ContextTokenCalculator {
return this.cacheNodeTokens(node);
}
/**
 * Builds a per-category token tally for a list of nodes, intended for
 * calibration tracing and for debugging overestimation.
 *
 * Nodes are de-duplicated by `id`. `MSG_OVERHEAD_TOKENS` is charged once per
 * distinct non-empty `turnId`, into both `overhead` and `total`. `total` also
 * accumulates each node's `getTokenCost`, whereas the category buckets are
 * estimated part by part:
 *  - `text`: parts whose `text` is a string
 *  - `media`: parts whose inline or file data has an `image/*` MIME type
 *  - `tool`: function call / function response parts
 *  - `overhead`: every other part
 *
 * @param nodes Nodes to analyse; repeated ids are ignored after the first.
 * @returns The accumulated counters.
 */
calculateTokenBreakdown(nodes: readonly ConcreteNode[]): {
  total: number;
  text: number;
  media: number;
  tool: number;
  overhead: number;
} {
  const tally = { total: 0, text: 0, media: 0, tool: 0, overhead: 0 };
  const countedNodeIds = new Set<string>();
  const countedTurnIds = new Set<string>();

  for (const node of nodes) {
    // Only the first occurrence of a node id contributes.
    if (countedNodeIds.has(node.id)) continue;
    countedNodeIds.add(node.id);

    // Structural overhead is per turn, not per node.
    if (node.turnId && !countedTurnIds.has(node.turnId)) {
      countedTurnIds.add(node.turnId);
      tally.overhead += MSG_OVERHEAD_TOKENS;
      tally.total += MSG_OVERHEAD_TOKENS;
    }

    tally.total += this.getTokenCost(node);

    const estimatableParts = this.registry
      .get(node.type)
      .getEstimatableParts(node);
    for (const part of estimatableParts) {
      // Each part lands in exactly one bucket, so estimate it once up front.
      const partTokens = estimateTokenCountSync([part], 0, this.charsPerToken);

      if (typeof part.text === 'string') {
        tally.text += partTokens;
      } else if (
        part.inlineData?.mimeType?.startsWith('image/') ||
        part.fileData?.mimeType?.startsWith('image/')
      ) {
        tally.media += partTokens;
      } else if (part.functionCall || part.functionResponse) {
        tally.tool += partTokens;
      } else {
        tally.overhead += partTokens;
      }
    }
  }

  return tally;
}
/**
 * Fast token total for a flat array of ConcreteNodes.
 *
 * Per-node cost comes from `getTokenCost`. Each node id is counted exactly
 * once: later entries with an id already seen contribute nothing.
 * `MSG_OVERHEAD_TOKENS` is added once per distinct non-empty `turnId`, so
 * several nodes belonging to the same turn share a single overhead charge.
 *
 * @param nodes Nodes to total; may contain repeated ids.
 * @returns Sum of unique node costs plus one overhead unit per unique turn.
 */
calculateConcreteListTokens(nodes: readonly ConcreteNode[]): number {
  let tokens = 0;
  const seenIds = new Set<string>();
  const seenTurnIds = new Set<string>();

  for (const node of nodes) {
    // The node cost must only be added behind the dedupe guard; adding it
    // unconditionally as well would count unique nodes twice and still
    // charge for duplicates.
    if (seenIds.has(node.id)) continue;
    seenIds.add(node.id);
    tokens += this.getTokenCost(node);

    if (node.turnId && !seenTurnIds.has(node.turnId)) {
      seenTurnIds.add(node.turnId);
      tokens += MSG_OVERHEAD_TOKENS;
    }
  }

  return tokens;
}
/**
 * Token cost of one Gemini Content object: the estimate for its parts
 * (a missing `parts` list is treated as empty) plus `MSG_OVERHEAD_TOKENS`.
 *
 * @param content The Content object to measure.
 * @returns Estimated part tokens plus the message overhead.
 */
calculateContentTokens(content: Content): number {
  const contentParts = content.parts || [];
  const partTokens = this.estimateTokensForParts(contentParts);
  return partTokens + MSG_OVERHEAD_TOKENS;
}
/**
* Slower, precise estimation for a Gemini Content/Part graph.
* Deeply inspects the nested structure and uses the base tokenization math.
@@ -0,0 +1,51 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { ConcreteNode } from '../graph/types.js';
import { debugLogger } from '../../utils/debugLogger.js';
/**
 * Debug-oriented sanity check over a list of context graph nodes.
 *
 * Nothing is thrown or returned; each violated invariant produces one
 * `debugLogger.warn` message, emitted in this order:
 *  1. node ids that occur more than once,
 *  2. nodes with a falsy `turnId`,
 *  3. the first index whose timestamp is lower than its predecessor's.
 *
 * @param nodes Nodes to inspect, in the order they would be used.
 * @param context Label embedded in every warning to identify the call site.
 */
export function checkContextInvariants(
  nodes: readonly ConcreteNode[],
  context: string,
): void {
  const knownIds = new Set<string>();
  const repeatedIds = new Set<string>();
  let orphanCount = 0;
  let firstBackwardsIndex = -1;

  // One pass gathers everything the three warnings need.
  for (const [index, node] of nodes.entries()) {
    if (knownIds.has(node.id)) {
      repeatedIds.add(node.id);
    } else {
      knownIds.add(node.id);
    }

    if (!node.turnId) {
      orphanCount += 1;
    }

    // Only the first regression is of interest; later ones are not reported.
    if (
      firstBackwardsIndex === -1 &&
      index > 0 &&
      node.timestamp < nodes[index - 1].timestamp
    ) {
      firstBackwardsIndex = index;
    }
  }

  if (repeatedIds.size > 0) {
    debugLogger.warn(
      `[InvariantCheck][${context}] Detected ${repeatedIds.size} duplicate nodes by ID: ${Array.from(repeatedIds).join(', ')}`,
    );
  }

  if (orphanCount > 0) {
    debugLogger.warn(
      `[InvariantCheck][${context}] Detected ${orphanCount} nodes without turnId.`,
    );
  }

  if (firstBackwardsIndex !== -1) {
    debugLogger.warn(
      `[InvariantCheck][${context}] Non-linear timestamps detected at index ${firstBackwardsIndex}.`,
    );
  }
}
@@ -23,16 +23,14 @@ Output ONLY the raw factual snapshot, formatted compactly. Do not include markdo
let userPromptText = 'TRANSCRIPT TO SNAPSHOT:\n\n';
for (const node of nodes) {
const payload = node.payload;
let nodeContent = '';
if ('text' in node && typeof node.text === 'string') {
nodeContent = node.text;
} else if ('semanticParts' in node) {
nodeContent = JSON.stringify(node.semanticParts);
} else if ('observation' in node) {
nodeContent =
typeof node.observation === 'string'
? node.observation
: JSON.stringify(node.observation);
if (payload.text) {
nodeContent = payload.text;
} else if (payload.functionCall) {
nodeContent = `CALL: ${payload.functionCall.name}(${JSON.stringify(payload.functionCall.args)})`;
} else if (payload.functionResponse) {
nodeContent = `RESPONSE: ${JSON.stringify(payload.functionResponse.response)}`;
}
userPromptText += `[${node.type}]: ${nodeContent}\n`;