mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-25 02:37:53 -07:00
Fix bulk of remaining issues with generalist profile (#26073)
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { ContextTokenCalculator } from './contextTokenCalculator.js';
|
||||
import { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
import { registerBuiltInBehaviors } from '../graph/builtinBehaviors.js';
|
||||
import { createDummyNode } from '../testing/contextTestUtils.js';
|
||||
import { MSG_OVERHEAD_TOKENS } from '../../utils/tokenCalculation.js';
|
||||
import { NodeType } from '../graph/types.js';
|
||||
|
||||
describe('ContextTokenCalculator', () => {
|
||||
const registry = new NodeBehaviorRegistry();
|
||||
registerBuiltInBehaviors(registry);
|
||||
const charsPerToken = 1; // Simplifies math for text nodes in tests
|
||||
const calculator = new ContextTokenCalculator(charsPerToken, registry);
|
||||
|
||||
it('should include structural overhead for each unique turn', () => {
|
||||
const turn1Id = 'turn-1';
|
||||
const turn2Id = 'turn-2';
|
||||
|
||||
const node1 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
|
||||
const node2 = createDummyNode(turn1Id, NodeType.USER_PROMPT); // Same turn
|
||||
const node3 = createDummyNode(turn2Id, NodeType.AGENT_THOUGHT); // Different turn
|
||||
|
||||
const nodes = [node1, node2, node3];
|
||||
|
||||
// Estimated tokens (using 0.33 per ASCII char heuristic):
|
||||
// node1: floor(17 chars * 0.33) = 5 tokens
|
||||
// node2: floor(17 chars * 0.33) = 5 tokens
|
||||
// node3: floor(19 chars * 0.33) = 6 tokens
|
||||
// Turn 1 overhead: 5 tokens
|
||||
// Turn 2 overhead: 5 tokens
|
||||
// Total: 5 + 5 + 6 + 5 + 5 = 26
|
||||
|
||||
const total = calculator.calculateConcreteListTokens(nodes);
|
||||
expect(total).toBe(26);
|
||||
});
|
||||
|
||||
it('should handle categorical breakdown with overhead', () => {
|
||||
const turn1Id = 'turn-1';
|
||||
const node = createDummyNode(turn1Id, NodeType.USER_PROMPT);
|
||||
|
||||
const breakdown = calculator.calculateTokenBreakdown([node]);
|
||||
|
||||
expect(breakdown.overhead).toBe(MSG_OVERHEAD_TOKENS);
|
||||
expect(breakdown.total).toBe(
|
||||
calculator.getTokenCost(node) + MSG_OVERHEAD_TOKENS,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not double-count overhead for duplicate turn IDs in separate nodes', () => {
|
||||
const turn1Id = 'turn-1';
|
||||
const node1 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
|
||||
const node2 = createDummyNode(turn1Id, NodeType.USER_PROMPT);
|
||||
|
||||
const total = calculator.calculateConcreteListTokens([node1, node2]);
|
||||
|
||||
// cost(node1) + cost(node2) + 1 * overhead
|
||||
const expected =
|
||||
calculator.getTokenCost(node1) +
|
||||
calculator.getTokenCost(node2) +
|
||||
MSG_OVERHEAD_TOKENS;
|
||||
expect(total).toBe(expected);
|
||||
});
|
||||
});
|
||||
@@ -4,8 +4,11 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Part } from '@google/genai';
|
||||
import { estimateTokenCountSync } from '../../utils/tokenCalculation.js';
|
||||
import type { Part, Content } from '@google/genai';
|
||||
import {
|
||||
estimateTokenCountSync,
|
||||
MSG_OVERHEAD_TOKENS,
|
||||
} from '../../utils/tokenCalculation.js';
|
||||
import type { ConcreteNode } from '../graph/types.js';
|
||||
import type { NodeBehaviorRegistry } from '../graph/behaviorRegistry.js';
|
||||
|
||||
@@ -73,18 +76,107 @@ export class ContextTokenCalculator {
|
||||
return this.cacheNodeTokens(node);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates a detailed breakdown of tokens by category for a list of nodes.
|
||||
* Useful for calibration tracing and debugging overestimation.
|
||||
*/
|
||||
calculateTokenBreakdown(nodes: readonly ConcreteNode[]): {
|
||||
total: number;
|
||||
text: number;
|
||||
media: number;
|
||||
tool: number;
|
||||
overhead: number;
|
||||
} {
|
||||
const breakdown = { total: 0, text: 0, media: 0, tool: 0, overhead: 0 };
|
||||
const seenIds = new Set<string>();
|
||||
const seenTurnIds = new Set<string>();
|
||||
|
||||
for (const node of nodes) {
|
||||
if (seenIds.has(node.id)) continue;
|
||||
seenIds.add(node.id);
|
||||
|
||||
if (node.turnId) {
|
||||
if (!seenTurnIds.has(node.turnId)) {
|
||||
seenTurnIds.add(node.turnId);
|
||||
breakdown.overhead += MSG_OVERHEAD_TOKENS;
|
||||
breakdown.total += MSG_OVERHEAD_TOKENS;
|
||||
}
|
||||
}
|
||||
|
||||
const cost = this.getTokenCost(node);
|
||||
breakdown.total += cost;
|
||||
|
||||
const behavior = this.registry.get(node.type);
|
||||
const parts = behavior.getEstimatableParts(node);
|
||||
|
||||
for (const part of parts) {
|
||||
if (typeof part.text === 'string') {
|
||||
breakdown.text += estimateTokenCountSync(
|
||||
[part],
|
||||
0,
|
||||
this.charsPerToken,
|
||||
);
|
||||
} else if (
|
||||
part.inlineData?.mimeType?.startsWith('image/') ||
|
||||
part.fileData?.mimeType?.startsWith('image/')
|
||||
) {
|
||||
breakdown.media += estimateTokenCountSync(
|
||||
[part],
|
||||
0,
|
||||
this.charsPerToken,
|
||||
);
|
||||
} else if (part.functionCall || part.functionResponse) {
|
||||
breakdown.tool += estimateTokenCountSync(
|
||||
[part],
|
||||
0,
|
||||
this.charsPerToken,
|
||||
);
|
||||
} else {
|
||||
breakdown.overhead += estimateTokenCountSync(
|
||||
[part],
|
||||
0,
|
||||
this.charsPerToken,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
return breakdown;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fast calculation for a flat array of ConcreteNodes (The Nodes).
|
||||
* It relies entirely on the O(1) sidecar token cache.
|
||||
*/
|
||||
calculateConcreteListTokens(nodes: readonly ConcreteNode[]): number {
|
||||
let tokens = 0;
|
||||
const seenIds = new Set<string>();
|
||||
const seenTurnIds = new Set<string>();
|
||||
|
||||
for (const node of nodes) {
|
||||
tokens += this.getTokenCost(node);
|
||||
if (!seenIds.has(node.id)) {
|
||||
seenIds.add(node.id);
|
||||
tokens += this.getTokenCost(node);
|
||||
|
||||
if (node.turnId) {
|
||||
if (!seenTurnIds.has(node.turnId)) {
|
||||
seenTurnIds.add(node.turnId);
|
||||
tokens += MSG_OVERHEAD_TOKENS;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return tokens;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the token cost for a single Gemini Content object.
|
||||
*/
|
||||
calculateContentTokens(content: Content): number {
|
||||
return (
|
||||
this.estimateTokensForParts(content.parts || []) + MSG_OVERHEAD_TOKENS
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Slower, precise estimation for a Gemini Content/Part graph.
|
||||
* Deeply inspects the nested structure and uses the base tokenization math.
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ConcreteNode } from '../graph/types.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Validates structural and logical invariants of the Episodic Context Graph.
|
||||
* Primarily used in debug mode to identify "smelly" states before they reach the LLM.
|
||||
*/
|
||||
export function checkContextInvariants(
|
||||
nodes: readonly ConcreteNode[],
|
||||
context: string,
|
||||
): void {
|
||||
const seenIds = new Set<string>();
|
||||
const duplicates = new Set<string>();
|
||||
|
||||
for (const node of nodes) {
|
||||
if (seenIds.has(node.id)) {
|
||||
duplicates.add(node.id);
|
||||
}
|
||||
seenIds.add(node.id);
|
||||
}
|
||||
|
||||
if (duplicates.size > 0) {
|
||||
debugLogger.warn(
|
||||
`[InvariantCheck][${context}] Detected ${duplicates.size} duplicate nodes by ID: ${Array.from(duplicates).join(', ')}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Check for orphan logic (nodes without turn association)
|
||||
const orphans = nodes.filter((n) => !n.turnId);
|
||||
if (orphans.length > 0) {
|
||||
debugLogger.warn(
|
||||
`[InvariantCheck][${context}] Detected ${orphans.length} nodes without turnId.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Check for timestamp linearity
|
||||
for (let i = 1; i < nodes.length; i++) {
|
||||
if (nodes[i].timestamp < nodes[i - 1].timestamp) {
|
||||
debugLogger.warn(
|
||||
`[InvariantCheck][${context}] Non-linear timestamps detected at index ${i}.`,
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -23,16 +23,14 @@ Output ONLY the raw factual snapshot, formatted compactly. Do not include markdo
|
||||
|
||||
let userPromptText = 'TRANSCRIPT TO SNAPSHOT:\n\n';
|
||||
for (const node of nodes) {
|
||||
const payload = node.payload;
|
||||
let nodeContent = '';
|
||||
if ('text' in node && typeof node.text === 'string') {
|
||||
nodeContent = node.text;
|
||||
} else if ('semanticParts' in node) {
|
||||
nodeContent = JSON.stringify(node.semanticParts);
|
||||
} else if ('observation' in node) {
|
||||
nodeContent =
|
||||
typeof node.observation === 'string'
|
||||
? node.observation
|
||||
: JSON.stringify(node.observation);
|
||||
if (payload.text) {
|
||||
nodeContent = payload.text;
|
||||
} else if (payload.functionCall) {
|
||||
nodeContent = `CALL: ${payload.functionCall.name}(${JSON.stringify(payload.functionCall.args)})`;
|
||||
} else if (payload.functionResponse) {
|
||||
nodeContent = `RESPONSE: ${JSON.stringify(payload.functionResponse.response)}`;
|
||||
}
|
||||
|
||||
userPromptText += `[${node.type}]: ${nodeContent}\n`;
|
||||
|
||||
Reference in New Issue
Block a user