mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-23 19:44:30 -07:00
feat: implement adaptive thinking budget
This commit is contained in:
@@ -73,6 +73,7 @@ import type { ModelConfigServiceConfig } from '../services/modelConfigService.js
|
||||
import { ModelConfigService } from '../services/modelConfigService.js';
|
||||
import { DEFAULT_MODEL_CONFIGS } from './defaultModelConfigs.js';
|
||||
import { ContextManager } from '../services/contextManager.js';
|
||||
import { AdaptiveBudgetService } from '../services/adaptiveBudgetService.js';
|
||||
|
||||
// Re-export OAuth config type
|
||||
export type { MCPOAuthConfig, AnyToolInvocation };
|
||||
@@ -335,6 +336,10 @@ export interface ConfigParameters {
|
||||
disableModelRouterForAuth?: AuthType[];
|
||||
codebaseInvestigatorSettings?: CodebaseInvestigatorSettings;
|
||||
introspectionAgentSettings?: IntrospectionAgentSettings;
|
||||
adaptiveThinking?: {
|
||||
enabled?: boolean;
|
||||
classifierModel?: string;
|
||||
};
|
||||
continueOnFailedApiCall?: boolean;
|
||||
retryFetchErrors?: boolean;
|
||||
enableShellOutputEfficiency?: boolean;
|
||||
@@ -460,6 +465,10 @@ export class Config {
|
||||
private readonly outputSettings: OutputSettings;
|
||||
private readonly codebaseInvestigatorSettings: CodebaseInvestigatorSettings;
|
||||
private readonly introspectionAgentSettings: IntrospectionAgentSettings;
|
||||
private readonly adaptiveThinking: {
|
||||
enabled: boolean;
|
||||
classifierModel: string;
|
||||
};
|
||||
private readonly continueOnFailedApiCall: boolean;
|
||||
private readonly retryFetchErrors: boolean;
|
||||
private readonly enableShellOutputEfficiency: boolean;
|
||||
@@ -491,6 +500,7 @@ export class Config {
|
||||
private readonly experimentalJitContext: boolean;
|
||||
private contextManager?: ContextManager;
|
||||
private terminalBackground: string | undefined = undefined;
|
||||
private adaptiveBudgetService!: AdaptiveBudgetService;
|
||||
|
||||
constructor(params: ConfigParameters) {
|
||||
this.sessionId = params.sessionId;
|
||||
@@ -618,6 +628,10 @@ export class Config {
|
||||
this.introspectionAgentSettings = {
|
||||
enabled: params.introspectionAgentSettings?.enabled ?? false,
|
||||
};
|
||||
this.adaptiveThinking = {
|
||||
enabled: params.adaptiveThinking?.enabled ?? false,
|
||||
classifierModel: params.adaptiveThinking?.classifierModel ?? 'classifier',
|
||||
};
|
||||
this.continueOnFailedApiCall = params.continueOnFailedApiCall ?? true;
|
||||
this.enableShellOutputEfficiency =
|
||||
params.enableShellOutputEfficiency ?? true;
|
||||
@@ -763,6 +777,13 @@ export class Config {
|
||||
await this.contextManager.refresh();
|
||||
}
|
||||
|
||||
this.adaptiveBudgetService = new AdaptiveBudgetService(this);
|
||||
if (this.adaptiveThinking.enabled) {
|
||||
debugLogger.debug(
|
||||
`Adaptive Thinking Budget enabled (classifier: ${this.adaptiveThinking.classifierModel})`,
|
||||
);
|
||||
}
|
||||
|
||||
await this.geminiClient.initialize();
|
||||
}
|
||||
|
||||
@@ -770,6 +791,10 @@ export class Config {
|
||||
return this.contentGenerator;
|
||||
}
|
||||
|
||||
/**
 * Returns the AdaptiveBudgetService held by this Config.
 *
 * The backing field is declared with a definite-assignment assertion and is
 * assigned after construction (next to the `geminiClient.initialize()` call),
 * so the value is not available before that assignment has run.
 */
getAdaptiveBudgetService(): AdaptiveBudgetService {
  const { adaptiveBudgetService } = this;
  return adaptiveBudgetService;
}
|
||||
|
||||
async refreshAuth(authMethod: AuthType) {
|
||||
// Reset availability service when switching auth
|
||||
this.modelAvailabilityService.reset();
|
||||
@@ -1664,6 +1689,10 @@ export class Config {
|
||||
return this.introspectionAgentSettings;
|
||||
}
|
||||
|
||||
/**
 * Returns the resolved adaptive-thinking settings.
 *
 * The constructor fills in defaults, so both fields are always present:
 * `enabled` falls back to `false` and `classifierModel` to `'classifier'`.
 * The stored object itself is returned, not a copy.
 */
getAdaptiveThinkingConfig(): { enabled: boolean; classifierModel: string } {
  const { adaptiveThinking } = this;
  return adaptiveThinking;
}
|
||||
|
||||
async createToolRegistry(): Promise<ToolRegistry> {
|
||||
const registry = new ToolRegistry(this, this.messageBus);
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ import { GeminiChat } from './geminiChat.js';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { tokenLimit } from './tokenLimits.js';
|
||||
import { partListUnionToString } from './geminiRequest.js';
|
||||
import type {
|
||||
ChatRecordingService,
|
||||
ResumedSessionData,
|
||||
@@ -620,6 +621,25 @@ export class GeminiClient {
|
||||
|
||||
// availability logic
|
||||
const modelConfigKey: ModelConfigKey = { model: modelToUse };
|
||||
|
||||
// Adaptive Thinking Budget Integration
|
||||
if (
|
||||
!isInvalidStreamRetry &&
|
||||
this.config.getAdaptiveThinkingConfig().enabled
|
||||
) {
|
||||
const userMessage = partListUnionToString(request);
|
||||
if (userMessage) {
|
||||
const adaptiveConfig = await this.config
|
||||
.getAdaptiveBudgetService()
|
||||
.determineAdaptiveConfig(userMessage, modelToUse);
|
||||
|
||||
if (adaptiveConfig) {
|
||||
modelConfigKey.thinkingBudget = adaptiveConfig.thinkingBudget;
|
||||
modelConfigKey.thinkingLevel = adaptiveConfig.thinkingLevel;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const { model: finalModel } = applyModelSelection(
|
||||
this.config,
|
||||
modelConfigKey,
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import {
|
||||
AdaptiveBudgetService,
|
||||
ComplexityLevel,
|
||||
} from './adaptiveBudgetService.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ThinkingLevel } from '@google/genai';
|
||||
|
||||
describe('AdaptiveBudgetService', () => {
  /**
   * Builds a Config stub whose classifier always answers with `classifierReply`.
   * The mock is returned alongside the config so tests can inspect its calls.
   */
  const stubConfig = (classifierReply: string) => {
    const generateContent = vi.fn().mockResolvedValue({
      candidates: [{ content: { parts: [{ text: classifierReply }] } }],
    });
    const config = {
      getBaseLlmClient: () => ({ generateContent }),
      getAdaptiveThinkingConfig: () => ({
        enabled: true,
        classifierModel: 'gemini-2.0-flash',
      }),
    } as unknown as Config;
    return { config, generateContent };
  };

  it('should map complexity levels to correct V2 budgets', () => {
    const service = new AdaptiveBudgetService({} as Config);
    const expectedBudgets: Array<[ComplexityLevel, number]> = [
      [ComplexityLevel.SIMPLE, 1024],
      [ComplexityLevel.MODERATE, 4096],
      [ComplexityLevel.HIGH, 16384],
      [ComplexityLevel.EXTREME, 32768],
    ];

    for (const [complexity, budget] of expectedBudgets) {
      expect(service.getThinkingBudgetV2(complexity)).toBe(budget);
    }
  });

  it('should map complexity levels to correct V3 levels', () => {
    const service = new AdaptiveBudgetService({} as Config);
    const expectedLevels: Array<[ComplexityLevel, ThinkingLevel]> = [
      [ComplexityLevel.SIMPLE, ThinkingLevel.LOW],
      [ComplexityLevel.MODERATE, ThinkingLevel.LOW],
      [ComplexityLevel.HIGH, ThinkingLevel.HIGH],
      [ComplexityLevel.EXTREME, ThinkingLevel.HIGH],
    ];

    for (const [complexity, thinkingLevel] of expectedLevels) {
      expect(service.getThinkingLevelV3(complexity)).toBe(thinkingLevel);
    }
  });

  it('should determine adaptive config based on LLM response', async () => {
    // Classifier answers "3" (HIGH); a Gemini 2.x target model gets a budget.
    const { config, generateContent } = stubConfig('3');
    const service = new AdaptiveBudgetService(config);

    const result = await service.determineAdaptiveConfig(
      'Complex task',
      'gemini-2.5-pro',
    );

    expect(result?.complexity).toBe(ComplexityLevel.HIGH);
    expect(result?.thinkingBudget).toBe(16384);
    expect(generateContent).toHaveBeenCalled();
  });

  it('should handle Gemini 3 models with thinkingLevel', async () => {
    // Classifier answers "1" (SIMPLE); a preview target model gets a level
    // and no numeric budget.
    const { config } = stubConfig('1');
    const service = new AdaptiveBudgetService(config);

    const result = await service.determineAdaptiveConfig(
      'Hi',
      'gemini-3-pro-preview',
    );

    expect(result?.complexity).toBe(ComplexityLevel.SIMPLE);
    expect(result?.thinkingLevel).toBe(ThinkingLevel.LOW);
    expect(result?.thinkingBudget).toBeUndefined();
  });
});
|
||||
@@ -0,0 +1,132 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
import type { Config } from '../config/config.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { isGemini2Model, isPreviewModel } from '../config/models.js';
|
||||
import { ThinkingLevel } from '@google/genai';
|
||||
|
||||
/**
 * Task complexity as rated by the classifier model.
 *
 * The numeric values are exactly the integers the classifier prompt asks the
 * model to output (1 to 4), so they must not be renumbered.
 */
export enum ComplexityLevel {
  /** Quick fixes, syntax questions, simple explanations, greetings. */
  SIMPLE = 1,
  /** Function-level logic, writing small scripts, standard debugging. */
  MODERATE = 2,
  /** Module-level refactoring, complex feature implementation, multi-file changes. */
  HIGH = 3,
  /** Architecture design, deep root-cause analysis, large-scale migrations. */
  EXTREME = 4,
}
|
||||
|
||||
/**
 * `thinkingBudget` value used for each complexity level on models that take
 * a numeric budget (the Gemini 2.x branch of AdaptiveBudgetService).
 */
export const BUDGET_MAPPING_V2: Record<ComplexityLevel, number> = {
  [ComplexityLevel.SIMPLE]: 1 * 1024, // 1024
  [ComplexityLevel.MODERATE]: 4 * 1024, // 4096
  [ComplexityLevel.HIGH]: 16 * 1024, // 16384
  [ComplexityLevel.EXTREME]: 32 * 1024, // 32768
};
|
||||
|
||||
/**
 * `thinkingLevel` used for each complexity level on models that take a
 * discrete level instead of a numeric budget. Only two levels are used:
 * the lower two complexities map to LOW and the upper two map to HIGH.
 */
export const LEVEL_MAPPING_V3: Record<ComplexityLevel, ThinkingLevel> = {
  // Lower half of the scale.
  [ComplexityLevel.SIMPLE]: ThinkingLevel.LOW,
  [ComplexityLevel.MODERATE]: ThinkingLevel.LOW,
  // Upper half of the scale.
  [ComplexityLevel.HIGH]: ThinkingLevel.HIGH,
  [ComplexityLevel.EXTREME]: ThinkingLevel.HIGH,
};
|
||||
|
||||
/**
 * Outcome of classifying a prompt with AdaptiveBudgetService.
 *
 * At most one of `thinkingBudget` / `thinkingLevel` is populated, depending
 * on which model family the request targets; both are absent when the model
 * matches neither family.
 */
export interface AdaptiveBudgetResult {
  /** Complexity the classifier assigned to the prompt. */
  complexity: ComplexityLevel;
  /** Numeric budget, set for models handled by the Gemini 2.x branch. */
  thinkingBudget?: number;
  /** Discrete level, set for models handled by the preview-model branch. */
  thinkingLevel?: ThinkingLevel;
  /** Free-form note; currently only set for EXTREME complexity. */
  strategyNote?: string;
}
|
||||
|
||||
/**
 * Picks a thinking configuration for a request by asking a classifier model
 * to rate the user's prompt on a 1-4 complexity scale.
 */
export class AdaptiveBudgetService {
  constructor(private readonly config: Config) {}

  /**
   * Analyzes the user prompt and determines the optimal thinking configuration.
   *
   * Note on future scaling (per arXiv:2512.19585):
   * At Complexity 4 (Extreme), we should consider:
   * 1. Best-of-N: Generate multiple solutions.
   * 2. LLM-as-a-Judge: Use a strong model to evaluate candidates.
   * 3. Compiler Verification: Check code correctness via environment tools.
   *
   * @param userPrompt Text of the user's request; embedded verbatim in the
   *   classifier prompt.
   * @param model Model the main request will use. It decides whether the
   *   result carries a `thinkingLevel` or a `thinkingBudget`.
   * @returns The classification result, or `undefined` when the classifier
   *   call fails, returns no text, or returns something that is not exactly
   *   one of the integers 1-4. This method never throws for classifier
   *   errors; callers should fall back to their default thinking settings.
   */
  async determineAdaptiveConfig(
    userPrompt: string,
    model: string,
  ): Promise<AdaptiveBudgetResult | undefined> {
    const { classifierModel } = this.config.getAdaptiveThinkingConfig();

    try {
      const llm = this.config.getBaseLlmClient();
      debugLogger.debug(
        `AdaptiveBudgetService: Classifying prompt complexity using ${classifierModel}...`,
      );

      const response = await llm.generateContent({
        modelConfigKey: { model: classifierModel },
        contents: [
          {
            role: 'user',
            parts: [{ text: this.buildClassifierPrompt(userPrompt) }],
          },
        ],
        promptId: 'adaptive-budget-classifier',
        // NOTE(review): this signal is never aborted, so the classifier call
        // has no timeout of its own and cannot be cancelled by the caller.
        abortSignal: new AbortController().signal,
      });

      const text = this.extractAnswerText(
        response.candidates?.[0]?.content?.parts,
      );
      if (!text) {
        debugLogger.debug(
          'AdaptiveBudgetService: No response from classifier.',
        );
        return undefined;
      }

      const level = this.parseComplexityLevel(text);
      if (level === undefined) {
        debugLogger.debug(
          `AdaptiveBudgetService: Invalid complexity level returned: ${text}`,
        );
        return undefined;
      }

      const result: AdaptiveBudgetResult = { complexity: level };

      // Determine mapping based on model version
      // Gemini 3 uses ThinkingLevel, Gemini 2.x uses thinkingBudget
      // NOTE(review): Gemini 3 is detected via isPreviewModel(); confirm that
      // helper matches exactly the models that accept thinkingLevel.
      if (isPreviewModel(model)) {
        result.thinkingLevel = LEVEL_MAPPING_V3[level] ?? ThinkingLevel.HIGH;
      } else if (isGemini2Model(model)) {
        result.thinkingBudget = BUDGET_MAPPING_V2[level];
      }

      if (level === ComplexityLevel.EXTREME) {
        result.strategyNote =
          'EXTREME complexity detected. Future implementations should use Best-of-N + Verification.';
      }

      debugLogger.debug(
        `AdaptiveBudgetService: Complexity ${level} -> Thinking Param: ${result.thinkingLevel ?? result.thinkingBudget}`,
      );
      return result;
    } catch (error) {
      // Classification is best-effort: log and let the caller use defaults.
      debugLogger.error(
        'AdaptiveBudgetService: Error classifying complexity',
        error,
      );
      return undefined;
    }
  }

  /** Returns the numeric thinking budget configured for `level`. */
  getThinkingBudgetV2(level: ComplexityLevel): number {
    return BUDGET_MAPPING_V2[level];
  }

  /**
   * Returns the thinking level configured for `level`, falling back to HIGH
   * if the mapping has no entry for the given value.
   */
  getThinkingLevelV3(level: ComplexityLevel): ThinkingLevel {
    return LEVEL_MAPPING_V3[level] ?? ThinkingLevel.HIGH;
  }

  /** Builds the single-turn classification prompt sent as the user message. */
  private buildClassifierPrompt(userPrompt: string): string {
    return `You are a complexity classifier for a coding assistant.
Analyze the user's request and determine the complexity of the task.
Output ONLY a single integer from 1 to 4 based on the following scale:

1 (Simple): Quick fixes, syntax questions, simple explanations, greetings.
2 (Moderate): Function-level logic, writing small scripts, standard debugging.
3 (High): Module-level refactoring, complex feature implementation, multi-file changes.
4 (Extreme): Architecture design, deep root-cause analysis of obscure bugs, large-scale migrations.

Request: ${userPrompt}
Complexity Level:`;
  }

  /**
   * Joins the text of every non-thought part of the first candidate.
   *
   * Reading only the first part would lose the answer whenever that part is a
   * thought part or carries no text.
   */
  private extractAnswerText(
    parts: ReadonlyArray<{ text?: string; thought?: boolean }> | undefined,
  ): string {
    return (parts ?? [])
      .filter((part) => !part.thought && typeof part.text === 'string')
      .map((part) => part.text)
      .join('')
      .trim();
  }

  /**
   * Parses the classifier's reply into a ComplexityLevel.
   *
   * The reply must start with a lone digit 1-4. Trailing explanation such as
   * "3 (High)" is tolerated, but numbers that parseInt would silently
   * truncate ("3.9", "12") are rejected instead of being misread.
   */
  private parseComplexityLevel(text: string): ComplexityLevel | undefined {
    const match = /^([1-4])(?!\d|[.,]\d)/.exec(text);
    switch (match?.[1]) {
      case '1':
        return ComplexityLevel.SIMPLE;
      case '2':
        return ComplexityLevel.MODERATE;
      case '3':
        return ComplexityLevel.HIGH;
      case '4':
        return ComplexityLevel.EXTREME;
      default:
        return undefined;
    }
  }
}
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { GenerateContentConfig } from '@google/genai';
|
||||
import type { GenerateContentConfig, ThinkingLevel } from '@google/genai';
|
||||
|
||||
// The primary key for the ModelConfig is the model string. However, we also
|
||||
// support a secondary key to limit the override scope, typically an agent name.
|
||||
@@ -26,6 +26,10 @@ export interface ModelConfigKey {
|
||||
// This allows overrides to specify different settings (e.g., higher temperature)
|
||||
// specifically for retry scenarios.
|
||||
isRetry?: boolean;
|
||||
|
||||
// Dynamic thinking configuration determined at runtime (e.g. via complexity classification)
|
||||
thinkingBudget?: number;
|
||||
thinkingLevel?: ThinkingLevel;
|
||||
}
|
||||
|
||||
export interface ModelConfig {
|
||||
@@ -205,6 +209,22 @@ export class ModelConfigService {
|
||||
}
|
||||
}
|
||||
|
||||
// Apply dynamic thinking parameters from context if present
|
||||
if (
|
||||
context.thinkingBudget !== undefined ||
|
||||
context.thinkingLevel !== undefined
|
||||
) {
|
||||
resolvedConfig.thinkingConfig = {
|
||||
...(resolvedConfig.thinkingConfig as object),
|
||||
...(context.thinkingBudget !== undefined
|
||||
? { thinkingBudget: context.thinkingBudget }
|
||||
: {}),
|
||||
...(context.thinkingLevel !== undefined
|
||||
? { thinkingLevel: context.thinkingLevel }
|
||||
: {}),
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
model: baseModel,
|
||||
generateContentConfig: resolvedConfig,
|
||||
|
||||
Reference in New Issue
Block a user