mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-18 07:43:00 -07:00
refactor(config): revert experiment getter names and async signatures
This commit is contained in:
@@ -120,7 +120,7 @@ export const createMockConfig = (overrides: Partial<Config> = {}): Config =>
|
||||
isTrustedFolder: vi.fn().mockReturnValue(true),
|
||||
getCompressionThreshold: vi.fn().mockResolvedValue(undefined),
|
||||
getUserCaching: vi.fn().mockResolvedValue(false),
|
||||
isNumericalRoutingEnabled: vi.fn().mockReturnValue(false),
|
||||
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
|
||||
isAwesomeEnabled: vi.fn().mockReturnValue(false),
|
||||
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
|
||||
getBannerTextNoCapacityIssues: vi.fn().mockResolvedValue(''),
|
||||
|
||||
@@ -14,7 +14,7 @@ describe('Config CLI Override', () => {
|
||||
const cwd = process.cwd();
|
||||
const model = 'gemini-pro';
|
||||
|
||||
it('should prioritize CLI argument over local setting', () => {
|
||||
it('should prioritize CLI argument over local setting', async () => {
|
||||
const config = new Config({
|
||||
sessionId,
|
||||
targetDir,
|
||||
@@ -25,10 +25,10 @@ describe('Config CLI Override', () => {
|
||||
experimentalSettings: { 'enable-numerical-routing': false },
|
||||
});
|
||||
|
||||
expect(config.isNumericalRoutingEnabled()).toBe(true);
|
||||
expect(await config.getNumericalRoutingEnabled()).toBe(true);
|
||||
});
|
||||
|
||||
it('should prioritize CLI argument over remote experiment', () => {
|
||||
it('should prioritize CLI argument over remote experiment', async () => {
|
||||
const config = new Config({
|
||||
sessionId,
|
||||
targetDir,
|
||||
@@ -44,6 +44,6 @@ describe('Config CLI Override', () => {
|
||||
},
|
||||
});
|
||||
|
||||
expect(config.isNumericalRoutingEnabled()).toBe(true);
|
||||
expect(await config.getNumericalRoutingEnabled()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3020,15 +3020,15 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
return remoteThreshold;
|
||||
}
|
||||
|
||||
getUserCaching(): boolean | undefined {
|
||||
async getUserCaching(): Promise<boolean | undefined> {
|
||||
return this.getExperimentValue<boolean>(ExperimentFlags.USER_CACHING);
|
||||
}
|
||||
|
||||
getPlanModeRoutingEnabled(): boolean {
|
||||
async getPlanModeRoutingEnabled(): Promise<boolean> {
|
||||
return this.planModeRoutingEnabled;
|
||||
}
|
||||
|
||||
isNumericalRoutingEnabled(): boolean {
|
||||
async getNumericalRoutingEnabled(): Promise<boolean> {
|
||||
return (
|
||||
this.getExperimentValue<boolean>(
|
||||
ExperimentFlags.ENABLE_NUMERICAL_ROUTING,
|
||||
@@ -3041,8 +3041,8 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
* If a remote threshold is provided and within range (0-100), it is returned.
|
||||
* Otherwise, the default threshold (90) is returned.
|
||||
*/
|
||||
getResolvedClassifierThreshold(): number {
|
||||
const remoteValue = this.getClassifierThreshold();
|
||||
async getResolvedClassifierThreshold(): Promise<number> {
|
||||
const remoteValue = await this.getClassifierThreshold();
|
||||
const defaultValue = 90;
|
||||
|
||||
if (
|
||||
@@ -3057,13 +3057,13 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
getClassifierThreshold(): number | undefined {
|
||||
async getClassifierThreshold(): Promise<number | undefined> {
|
||||
return this.getExperimentValue<number>(
|
||||
ExperimentFlags.CLASSIFIER_THRESHOLD,
|
||||
);
|
||||
}
|
||||
|
||||
getBannerTextNoCapacityIssues(): string {
|
||||
async getBannerTextNoCapacityIssues(): Promise<string> {
|
||||
return (
|
||||
this.getExperimentValue<string>(
|
||||
ExperimentFlags.BANNER_TEXT_NO_CAPACITY_ISSUES,
|
||||
@@ -3071,7 +3071,7 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
);
|
||||
}
|
||||
|
||||
getBannerTextCapacityIssues(): string {
|
||||
async getBannerTextCapacityIssues(): Promise<string> {
|
||||
return (
|
||||
this.getExperimentValue<string>(
|
||||
ExperimentFlags.BANNER_TEXT_CAPACITY_ISSUES,
|
||||
@@ -3079,7 +3079,6 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
);
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
/**
|
||||
* Returns whether the user has access to Pro models.
|
||||
* This is determined by the PRO_MODEL_NO_ACCESS experiment flag.
|
||||
@@ -3198,11 +3197,12 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
return (
|
||||
this.experiments?.flags[ExperimentFlags.GEMINI_3_1_FLASH_LITE_LAUNCHED]
|
||||
?.boolValue ?? false
|
||||
=======
|
||||
);
|
||||
}
|
||||
|
||||
isAwesomeEnabled(): boolean {
|
||||
return (
|
||||
this.getExperimentValue<boolean>(ExperimentFlags.ENABLE_AWESOME) ?? false
|
||||
>>>>>>> d2ce1460f (feat(config): add enable-awesome experiment to show custom ASCII art)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -51,12 +51,15 @@ describe('ModelRouterService', () => {
|
||||
mockBaseLlmClient = {} as BaseLlmClient;
|
||||
mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient);
|
||||
vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue(
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
vi.spyOn(mockConfig, 'isNumericalRoutingEnabled').mockReturnValue(true);
|
||||
vi.spyOn(mockConfig, 'getResolvedClassifierThreshold').mockReturnValue(90);
|
||||
vi.spyOn(mockConfig, 'getClassifierThreshold').mockReturnValue(undefined); vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue(
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(true);
|
||||
vi.spyOn(mockConfig, 'getResolvedClassifierThreshold').mockResolvedValue(
|
||||
90,
|
||||
);
|
||||
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
|
||||
vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
|
||||
enabled: false,
|
||||
classifier: {
|
||||
host: 'http://localhost:1234',
|
||||
|
||||
@@ -76,8 +76,10 @@ export class ModelRouterService {
|
||||
const startTime = Date.now();
|
||||
let decision: RoutingDecision;
|
||||
|
||||
const enableNumericalRouting = this.config.isNumericalRoutingEnabled();
|
||||
const thresholdValue = this.config.getResolvedClassifierThreshold();
|
||||
const [enableNumericalRouting, thresholdValue] = await Promise.all([
|
||||
this.config.getNumericalRoutingEnabled(),
|
||||
this.config.getResolvedClassifierThreshold(),
|
||||
]);
|
||||
const classifierThreshold = String(thresholdValue);
|
||||
|
||||
let failed = false;
|
||||
|
||||
@@ -57,7 +57,7 @@ describe('ClassifierStrategy', () => {
|
||||
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
|
||||
},
|
||||
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
|
||||
isNumericalRoutingEnabled: vi.fn().mockReturnValue(false),
|
||||
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
|
||||
getGemini31Launched: vi.fn().mockResolvedValue(false),
|
||||
getGemini31FlashLiteLaunched: vi.fn().mockResolvedValue(false),
|
||||
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
|
||||
@@ -78,7 +78,7 @@ describe('ClassifierStrategy', () => {
|
||||
});
|
||||
|
||||
it('should return null if numerical routing is enabled and model is Gemini 3', async () => {
|
||||
vi.mocked(mockConfig.isNumericalRoutingEnabled).mockReturnValue(true);
|
||||
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
|
||||
|
||||
const decision = await strategy.route(
|
||||
@@ -93,7 +93,7 @@ describe('ClassifierStrategy', () => {
|
||||
});
|
||||
|
||||
it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => {
|
||||
vi.mocked(mockConfig.isNumericalRoutingEnabled).mockReturnValue(true);
|
||||
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({
|
||||
reasoning: 'test',
|
||||
|
||||
@@ -137,7 +137,10 @@ export class ClassifierStrategy implements RoutingStrategy {
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
const model = context.requestedModel ?? config.getModel();
|
||||
if (config.isNumericalRoutingEnabled() && isGemini3Model(model, config)) {
|
||||
if (
|
||||
(await config.getNumericalRoutingEnabled()) &&
|
||||
isGemini3Model(model, config)
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -55,9 +55,9 @@ describe('NumericalClassifierStrategy', () => {
|
||||
},
|
||||
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
|
||||
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
|
||||
isNumericalRoutingEnabled: vi.fn().mockReturnValue(true),
|
||||
getResolvedClassifierThreshold: vi.fn().mockReturnValue(90),
|
||||
getClassifierThreshold: vi.fn().mockReturnValue(undefined),
|
||||
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
|
||||
getResolvedClassifierThreshold: vi.fn().mockResolvedValue(90),
|
||||
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
|
||||
getGemini31Launched: vi.fn().mockResolvedValue(false),
|
||||
getGemini31FlashLiteLaunched: vi.fn().mockResolvedValue(false),
|
||||
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
|
||||
@@ -82,7 +82,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
});
|
||||
|
||||
it('should return null if numerical routing is disabled', async () => {
|
||||
vi.mocked(mockConfig.isNumericalRoutingEnabled).mockReturnValue(false);
|
||||
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(false);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
@@ -211,9 +211,12 @@ describe('NumericalClassifierStrategy', () => {
|
||||
describe('Remote Threshold Logic', () => {
|
||||
it('should use the remote CLASSIFIER_THRESHOLD if provided (int value)', async () => {
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(70);
|
||||
<<<<<<< HEAD
|
||||
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
|
||||
70,
|
||||
);
|
||||
=======
|
||||
>>>>>>> b9035a18d (refactor(config): revert experiment getter names and async signatures)
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 60,
|
||||
@@ -241,9 +244,12 @@ describe('NumericalClassifierStrategy', () => {
|
||||
|
||||
it('should use the remote CLASSIFIER_THRESHOLD if provided (float value)', async () => {
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(45.5);
|
||||
<<<<<<< HEAD
|
||||
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
|
||||
45.5,
|
||||
);
|
||||
=======
|
||||
>>>>>>> b9035a18d (refactor(config): revert experiment getter names and async signatures)
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 40,
|
||||
@@ -271,9 +277,12 @@ describe('NumericalClassifierStrategy', () => {
|
||||
|
||||
it('should use PRO model if score >= remote CLASSIFIER_THRESHOLD', async () => {
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(30);
|
||||
<<<<<<< HEAD
|
||||
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
|
||||
30,
|
||||
);
|
||||
=======
|
||||
>>>>>>> b9035a18d (refactor(config): revert experiment getter names and async signatures)
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 35,
|
||||
|
||||
@@ -105,7 +105,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
const model = context.requestedModel ?? config.getModel();
|
||||
if (!config.isNumericalRoutingEnabled()) {
|
||||
if (!(await config.getNumericalRoutingEnabled())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -187,8 +187,8 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
groupLabel: string;
|
||||
modelAlias: typeof FLASH_MODEL | typeof PRO_MODEL;
|
||||
}> {
|
||||
const threshold = config.getResolvedClassifierThreshold();
|
||||
const remoteThresholdValue = config.getClassifierThreshold();
|
||||
const threshold = await config.getResolvedClassifierThreshold();
|
||||
const remoteThresholdValue = await config.getClassifierThreshold();
|
||||
|
||||
let groupLabel: string;
|
||||
if (threshold === remoteThresholdValue) {
|
||||
|
||||
Reference in New Issue
Block a user