diff --git a/packages/core/src/context/contextManager.ts b/packages/core/src/context/contextManager.ts index 25cec4d93c..ca3115d4ec 100644 --- a/packages/core/src/context/contextManager.ts +++ b/packages/core/src/context/contextManager.ts @@ -139,7 +139,11 @@ export class ContextManager { this.hasPerformedHotStart = true; if (this.buffer.nodes.length > 0) { const nodesForHotStart = [...this.buffer.nodes, ...previewNodes]; - await this.performHotStartCalibration(nodesForHotStart, abortSignal); + await this.performHotStartCalibration( + nodesForHotStart, + this.env.model, + abortSignal, + ); } } })(); @@ -468,6 +472,7 @@ export class ContextManager { private async performHotStartCalibration( nodes: readonly ConcreteNode[], + model: string, abortSignal?: AbortSignal, ) { const history = this.env.graphMapper.fromGraph(nodes); @@ -475,7 +480,7 @@ export class ContextManager { try { const { totalTokens } = await this.env.llmClient.countTokens({ - modelConfigKey: { model: 'context-calibrator' }, + modelConfigKey: { model }, contents, abortSignal, }); diff --git a/packages/core/src/context/initializer.ts b/packages/core/src/context/initializer.ts index dc19c127e1..0804f06509 100644 --- a/packages/core/src/context/initializer.ts +++ b/packages/core/src/context/initializer.ts @@ -116,6 +116,7 @@ export async function initializeContextManager( const env = new ContextEnvironmentImpl( () => config.getBaseLlmClient(), + () => config.getActiveModel(), config.getSessionId(), lastPromptId, logDir, diff --git a/packages/core/src/context/pipeline/environment.ts b/packages/core/src/context/pipeline/environment.ts index b57466638a..df4f2b6887 100644 --- a/packages/core/src/context/pipeline/environment.ts +++ b/packages/core/src/context/pipeline/environment.ts @@ -19,6 +19,7 @@ export interface RenderOptions { export interface ContextEnvironment { readonly llmClient: BaseLlmClient; + readonly model: string; readonly promptId: string; readonly sessionId: string; readonly traceDir: string; diff --git a/packages/core/src/context/pipeline/environmentImpl.test.ts b/packages/core/src/context/pipeline/environmentImpl.test.ts index d6f2835b89..7f52f1217b 100644 --- a/packages/core/src/context/pipeline/environmentImpl.test.ts +++ b/packages/core/src/context/pipeline/environmentImpl.test.ts @@ -21,6 +21,7 @@ describe('ContextEnvironmentImpl', () => { const env = new ContextEnvironmentImpl( () => mockLlmClient, + () => 'mock-model', 'mock-session', 'mock-prompt', '/tmp/trace', diff --git a/packages/core/src/context/pipeline/environmentImpl.ts b/packages/core/src/context/pipeline/environmentImpl.ts index 78f1b3dda5..405448807b 100644 --- a/packages/core/src/context/pipeline/environmentImpl.ts +++ b/packages/core/src/context/pipeline/environmentImpl.ts @@ -19,6 +19,7 @@ export class ContextEnvironmentImpl implements ContextEnvironment { constructor( private readonly llmClientProvider: () => BaseLlmClient, + private readonly activeModel: () => string, readonly sessionId: string, readonly promptId: string, readonly traceDir: string, @@ -37,4 +38,8 @@ export class ContextEnvironmentImpl implements ContextEnvironment { get llmClient(): BaseLlmClient { return this.llmClientProvider(); } + + get model(): string { + return this.activeModel(); + } } diff --git a/packages/core/src/context/system-tests/simulationHarness.ts b/packages/core/src/context/system-tests/simulationHarness.ts index f177b6ec7b..8f0d1290ea 100644 --- a/packages/core/src/context/system-tests/simulationHarness.ts +++ b/packages/core/src/context/system-tests/simulationHarness.ts @@ -68,6 +68,7 @@ export class SimulationHarness { this.env = new ContextEnvironmentImpl( () => mockLlmClient, + () => 'mock-model', 'sim-prompt', 'sim-session', mockTempDir, diff --git a/packages/core/src/context/testing/contextTestUtils.ts b/packages/core/src/context/testing/contextTestUtils.ts index fed08e2a82..389e33df52 100644 --- a/packages/core/src/context/testing/contextTestUtils.ts +++ b/packages/core/src/context/testing/contextTestUtils.ts @@ -181,7 +181,8 @@ export function createMockEnvironment( let env = new ContextEnvironmentImpl( () => llmClient as BaseLlmClient, - 'mock-session', + () => 'mock-session', + 'mock-model', 'mock-prompt-id', '/tmp/.gemini/trace', '/tmp/.gemini/tool-outputs', @@ -196,6 +197,7 @@ export function createMockEnvironment( if (overrides.llmClient) { env = new ContextEnvironmentImpl( () => overrides.llmClient!, + () => overrides.model!, env.sessionId, env.promptId, env.traceDir, @@ -276,6 +278,7 @@ export function createMockContextConfig( getTargetDir: vi.fn().mockReturnValue('/tmp'), getSessionId: vi.fn().mockReturnValue('test-session'), getExperimentalContextManagementConfig: vi.fn().mockReturnValue(undefined), + getActiveModel: vi.fn().mockReturnValue('mock-model'), }; // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion @@ -303,6 +306,7 @@ export function setupContextComponentTest( const env = new ContextEnvironmentImpl( () => config.getBaseLlmClient(), + () => config.getActiveModel(), 'test prompt-id', 'test-session', '/tmp',