diff --git a/docs/cli/settings.md b/docs/cli/settings.md index 8adccba6ae..b0c12116d6 100644 --- a/docs/cli/settings.md +++ b/docs/cli/settings.md @@ -140,6 +140,7 @@ they appear in the UI. | Plan | `experimental.plan` | Enable planning features (Plan Mode and tools). | `false` | | Model Steering | `experimental.modelSteering` | Enable model steering (user hints) to guide the model during tool execution. | `false` | | Direct Web Fetch | `experimental.directWebFetch` | Enable web fetch behavior that bypasses LLM summarization. | `false` | +| Enable Gemma Model Router | `experimental.gemmaModelRouter.enabled` | Enable the Gemma Model Router. Requires a local endpoint serving Gemma via the Gemini API using LiteRT-LM shim. | `false` | ### Skills diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 5337d973b8..c1c67803b0 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -1014,6 +1014,23 @@ their corresponding top-level category object in your `settings.json` file. - **Default:** `false` - **Requires restart:** Yes +- **`experimental.gemmaModelRouter.enabled`** (boolean): + - **Description:** Enable the Gemma Model Router. Requires a local endpoint + serving Gemma via the Gemini API using LiteRT-LM shim. + - **Default:** `false` + - **Requires restart:** Yes + +- **`experimental.gemmaModelRouter.classifier.host`** (string): + - **Description:** The host of the classifier. + - **Default:** `"http://localhost:9379"` + - **Requires restart:** Yes + +- **`experimental.gemmaModelRouter.classifier.model`** (string): + - **Description:** The model to use for the classifier. Only tested on + `gemma3-1b-gpu-custom`. 
+ - **Default:** `"gemma3-1b-gpu-custom"` + - **Requires restart:** Yes + #### `skills` - **`skills.enabled`** (boolean): diff --git a/package-lock.json b/package-lock.json index 5f0c5f058d..82bf1c2221 100644 --- a/package-lock.json +++ b/package-lock.json @@ -2292,7 +2292,6 @@ "integrity": "sha512-t54CUOsFMappY1Jbzb7fetWeO0n6K0k/4+/ZpkS+3Joz8I4VcvY9OiEBFRYISqaI2fq5sCiPtAjRDOzVYG8m+Q==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@octokit/auth-token": "^6.0.0", "@octokit/graphql": "^9.0.2", @@ -2473,7 +2472,6 @@ "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz", "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==", "license": "Apache-2.0", - "peer": true, "engines": { "node": ">=8.0.0" } @@ -2523,7 +2521,6 @@ "resolved": "https://registry.npmjs.org/@opentelemetry/core/-/core-2.5.0.tgz", "integrity": "sha512-ka4H8OM6+DlUhSAZpONu0cPBtPPTQKxbxVzC4CzVx5+K4JnroJVBtDzLAMx4/3CDTJXRvVFhpFjtl4SaiTNoyQ==", "license": "Apache-2.0", - "peer": true, "dependencies": { "@opentelemetry/semantic-conventions": "^1.29.0" }, @@ -2898,7 +2895,6 @@ "resolved": "https://registry.npmjs.org/@opentelemetry/resources/-/resources-2.5.0.tgz", "integrity": "sha512-F8W52ApePshpoSrfsSk1H2yJn9aKjCrbpQF1M9Qii0GHzbfVeFUB+rc3X4aggyZD8x9Gu3Slua+s6krmq6Dt8g==", "license": "Apache-2.0", - "peer": true, "dependencies": { "@opentelemetry/core": "2.5.0", "@opentelemetry/semantic-conventions": "^1.29.0" @@ -2932,7 +2928,6 @@ "resolved": "https://registry.npmjs.org/@opentelemetry/sdk-metrics/-/sdk-metrics-2.5.0.tgz", "integrity": "sha512-BeJLtU+f5Gf905cJX9vXFQorAr6TAfK3SPvTFqP+scfIpDQEJfRaGJWta7sJgP+m4dNtBf9y3yvBKVAZZtJQVA==", "license": "Apache-2.0", - "peer": true, "dependencies": { "@opentelemetry/core": "2.5.0", "@opentelemetry/resources": "2.5.0" @@ -2987,7 +2982,6 @@ "resolved": "https://registry.npmjs.org/@opentelemetry/sdk-trace-base/-/sdk-trace-base-2.5.0.tgz", "integrity": 
"sha512-VzRf8LzotASEyNDUxTdaJ9IRJ1/h692WyArDBInf5puLCjxbICD6XkHgpuudis56EndyS7LYFmtTMny6UABNdQ==", "license": "Apache-2.0", - "peer": true, "dependencies": { "@opentelemetry/core": "2.5.0", "@opentelemetry/resources": "2.5.0", @@ -4184,7 +4178,6 @@ "integrity": "sha512-6mDvHUFSjyT2B2yeNx2nUgMxh9LtOWvkhIU3uePn2I2oyNymUAX1NIsdgviM4CH+JSrp2D2hsMvJOkxY+0wNRA==", "devOptional": true, "license": "MIT", - "peer": true, "dependencies": { "csstype": "^3.0.2" } @@ -4458,7 +4451,6 @@ "integrity": "sha512-klQbnPAAiGYFyI02+znpBRLyjL4/BrBd0nyWkdC0s/6xFLkXYQ8OoRrSkqacS1ddVxf/LDyODIKbQ5TgKAf/Fg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.56.1", "@typescript-eslint/types": "8.56.1", @@ -5306,7 +5298,6 @@ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "license": "MIT", - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -7860,7 +7851,6 @@ "integrity": "sha512-VmQ+sifHUbI/IcSopBCF/HO3YiHQx/AVd3UVyYL6weuwW+HvON9VYn5l6Zl1WZzPWXPNZrSQpxwkkZ/VuvJZzg==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -8493,7 +8483,6 @@ "resolved": "https://registry.npmjs.org/express/-/express-5.2.1.tgz", "integrity": "sha512-hIS4idWWai69NezIdRt2xFVofaF4j+6INOpJlVOLDO8zXGpUVEVzIYk12UUi2JzjEzWL3IOAxcTubgz9Po0yXw==", "license": "MIT", - "peer": true, "dependencies": { "accepts": "^2.0.0", "body-parser": "^2.2.1", @@ -9788,7 +9777,6 @@ "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.2.tgz", "integrity": "sha512-gJnaDHXKDayjt8ue0n8Gs0A007yKXj4Xzb8+cNjZeYsSzzwKc0Lr+OZgYwVfB0pHfUs17EPoLvrOsEaJ9mj+Tg==", "license": "MIT", - "peer": true, "engines": { "node": ">=16.9.0" } @@ -10068,7 +10056,6 @@ "resolved": "https://registry.npmjs.org/@jrichman/ink/-/ink-6.4.11.tgz", "integrity": 
"sha512-93LQlzT7vvZ1XJcmOMwN4s+6W334QegendeHOMnEJBlhnpIzr8bws6/aOEHG8ZCuVD/vNeeea5m1msHIdAY6ig==", "license": "MIT", - "peer": true, "dependencies": { "@alcalzone/ansi-tokenize": "^0.2.1", "ansi-escapes": "^7.0.0", @@ -13718,7 +13705,6 @@ "resolved": "https://registry.npmjs.org/react/-/react-19.2.4.tgz", "integrity": "sha512-9nfp2hYpCwOjAN+8TZFGhtWEwgvWHXqESH8qT89AT/lWklpLON22Lc8pEtnpsZz7VmawabSU0gCjnj8aC0euHQ==", "license": "MIT", - "peer": true, "engines": { "node": ">=0.10.0" } @@ -13729,7 +13715,6 @@ "integrity": "sha512-ePrwPfxAnB+7hgnEr8vpKxL9cmnp7F322t8oqcPshbIQQhDKgFDW4tjhF2wjVbdXF9O/nyuy3sQWd9JGpiLPvA==", "devOptional": true, "license": "MIT", - "peer": true, "dependencies": { "shell-quote": "^1.6.1", "ws": "^7" @@ -15689,7 +15674,6 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -15913,8 +15897,7 @@ "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", "dev": true, - "license": "0BSD", - "peer": true + "license": "0BSD" }, "node_modules/tsx": { "version": "4.20.3", @@ -15922,7 +15905,6 @@ "integrity": "sha512-qjbnuR9Tr+FJOMBqJCW5ehvIo/buZq7vH7qD7JziU98h6l3qGy0a/yPFjwO+y0/T7GFpNgNAvEcPPVfyT8rrPQ==", "devOptional": true, "license": "MIT", - "peer": true, "dependencies": { "esbuild": "~0.25.0", "get-tsconfig": "^4.7.5" @@ -16082,7 +16064,6 @@ "integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==", "devOptional": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -16291,7 +16272,6 @@ "resolved": "https://registry.npmjs.org/vite/-/vite-7.2.2.tgz", "integrity": 
"sha512-BxAKBWmIbrDgrokdGZH1IgkIk/5mMHDreLDmCJ0qpyJaAteP8NvMhkwr/ZCQNqNH97bw/dANTE9PDzqwJghfMQ==", "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.5.0", @@ -16405,7 +16385,6 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -16418,7 +16397,6 @@ "resolved": "https://registry.npmjs.org/vitest/-/vitest-3.2.4.tgz", "integrity": "sha512-LUCP5ev3GURDysTWiP47wRRUpLKMOfPh+yKTx3kVIEiu5KOMeqzpnYNsKyOoVrULivR8tLcks4+lga33Whn90A==", "license": "MIT", - "peer": true, "dependencies": { "@types/chai": "^5.2.2", "@vitest/expect": "3.2.4", @@ -17063,7 +17041,6 @@ "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", "license": "MIT", - "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } @@ -17463,7 +17440,6 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index 75812e4442..919ad86c51 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -2765,6 +2765,66 @@ describe('loadCliConfig approval mode', () => { }); }); +describe('loadCliConfig gemmaModelRouter', () => { + beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(os.homedir).mockReturnValue('/mock/home/user'); + vi.stubEnv('GEMINI_API_KEY', 'test-api-key'); + vi.spyOn(ExtensionManager.prototype, 'getExtensions').mockReturnValue([]); + }); + + afterEach(() => { + vi.unstubAllEnvs(); + vi.restoreAllMocks(); + 
}); + + it('should have gemmaModelRouter disabled by default', async () => { + process.argv = ['node', 'script.js']; + const argv = await parseArguments(createTestMergedSettings()); + const settings = createTestMergedSettings(); + const config = await loadCliConfig(settings, 'test-session', argv); + expect(config.getGemmaModelRouterEnabled()).toBe(false); + }); + + it('should load gemmaModelRouter settings from merged settings', async () => { + process.argv = ['node', 'script.js']; + const argv = await parseArguments(createTestMergedSettings()); + const settings = createTestMergedSettings({ + experimental: { + gemmaModelRouter: { + enabled: true, + classifier: { + host: 'http://custom:1234', + model: 'custom-gemma', + }, + }, + }, + }); + const config = await loadCliConfig(settings, 'test-session', argv); + expect(config.getGemmaModelRouterEnabled()).toBe(true); + const gemmaSettings = config.getGemmaModelRouterSettings(); + expect(gemmaSettings.classifier?.host).toBe('http://custom:1234'); + expect(gemmaSettings.classifier?.model).toBe('custom-gemma'); + }); + + it('should handle partial gemmaModelRouter settings', async () => { + process.argv = ['node', 'script.js']; + const argv = await parseArguments(createTestMergedSettings()); + const settings = createTestMergedSettings({ + experimental: { + gemmaModelRouter: { + enabled: true, + }, + }, + }); + const config = await loadCliConfig(settings, 'test-session', argv); + expect(config.getGemmaModelRouterEnabled()).toBe(true); + const gemmaSettings = config.getGemmaModelRouterSettings(); + expect(gemmaSettings.classifier?.host).toBe('http://localhost:9379'); + expect(gemmaSettings.classifier?.model).toBe('gemma3-1b-gpu-custom'); + }); +}); + describe('loadCliConfig fileFiltering', () => { const originalArgv = process.argv; diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 6a4bd09470..f2870a5f57 100755 --- a/packages/cli/src/config/config.ts +++ 
b/packages/cli/src/config/config.ts @@ -856,6 +856,7 @@ export async function loadCliConfig( // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion format: (argv.outputFormat ?? settings.output?.format) as OutputFormat, }, + gemmaModelRouter: settings.experimental?.gemmaModelRouter, fakeResponses: argv.fakeResponses, recordResponses: argv.recordResponses, retryFetchErrors: settings.general?.retryFetchErrors, diff --git a/packages/cli/src/config/settingsSchema.test.ts b/packages/cli/src/config/settingsSchema.test.ts index ffe1dd2ac5..cf9dfc992f 100644 --- a/packages/cli/src/config/settingsSchema.test.ts +++ b/packages/cli/src/config/settingsSchema.test.ts @@ -444,6 +444,60 @@ describe('SettingsSchema', () => { expect(hookItemProperties.description).toBeDefined(); expect(hookItemProperties.description.type).toBe('string'); }); + + it('should have gemmaModelRouter setting in schema', () => { + const gemmaModelRouter = + getSettingsSchema().experimental.properties.gemmaModelRouter; + expect(gemmaModelRouter).toBeDefined(); + expect(gemmaModelRouter.type).toBe('object'); + expect(gemmaModelRouter.category).toBe('Experimental'); + expect(gemmaModelRouter.default).toEqual({}); + expect(gemmaModelRouter.requiresRestart).toBe(true); + expect(gemmaModelRouter.showInDialog).toBe(true); + expect(gemmaModelRouter.description).toBe( + 'Enable Gemma model router (experimental).', + ); + + const enabled = gemmaModelRouter.properties.enabled; + expect(enabled).toBeDefined(); + expect(enabled.type).toBe('boolean'); + expect(enabled.category).toBe('Experimental'); + expect(enabled.default).toBe(false); + expect(enabled.requiresRestart).toBe(true); + expect(enabled.showInDialog).toBe(true); + expect(enabled.description).toBe( + 'Enable the Gemma Model Router. 
Requires a local endpoint serving Gemma via the Gemini API using LiteRT-LM shim.', + ); + + const classifier = gemmaModelRouter.properties.classifier; + expect(classifier).toBeDefined(); + expect(classifier.type).toBe('object'); + expect(classifier.category).toBe('Experimental'); + expect(classifier.default).toEqual({}); + expect(classifier.requiresRestart).toBe(true); + expect(classifier.showInDialog).toBe(false); + expect(classifier.description).toBe('Classifier configuration.'); + + const host = classifier.properties.host; + expect(host).toBeDefined(); + expect(host.type).toBe('string'); + expect(host.category).toBe('Experimental'); + expect(host.default).toBe('http://localhost:9379'); + expect(host.requiresRestart).toBe(true); + expect(host.showInDialog).toBe(false); + expect(host.description).toBe('The host of the classifier.'); + + const model = classifier.properties.model; + expect(model).toBeDefined(); + expect(model.type).toBe('string'); + expect(model.category).toBe('Experimental'); + expect(model.default).toBe('gemma3-1b-gpu-custom'); + expect(model.requiresRestart).toBe(true); + expect(model.showInDialog).toBe(false); + expect(model.description).toBe( + 'The model to use for the classifier. 
Only tested on `gemma3-1b-gpu-custom`.', + ); + }); }); it('has JSON schema definitions for every referenced ref', () => { diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index 26faaafda7..48a7641766 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -1787,6 +1787,57 @@ const SETTINGS_SCHEMA = { 'Enable web fetch behavior that bypasses LLM summarization.', showInDialog: true, }, + gemmaModelRouter: { + type: 'object', + label: 'Gemma Model Router', + category: 'Experimental', + requiresRestart: true, + default: {}, + description: 'Enable Gemma model router (experimental).', + showInDialog: true, + properties: { + enabled: { + type: 'boolean', + label: 'Enable Gemma Model Router', + category: 'Experimental', + requiresRestart: true, + default: false, + description: + 'Enable the Gemma Model Router. Requires a local endpoint serving Gemma via the Gemini API using LiteRT-LM shim.', + showInDialog: true, + }, + classifier: { + type: 'object', + label: 'Classifier', + category: 'Experimental', + requiresRestart: true, + default: {}, + description: 'Classifier configuration.', + showInDialog: false, + properties: { + host: { + type: 'string', + label: 'Host', + category: 'Experimental', + requiresRestart: true, + default: 'http://localhost:9379', + description: 'The host of the classifier.', + showInDialog: false, + }, + model: { + type: 'string', + label: 'Model', + category: 'Experimental', + requiresRestart: true, + default: 'gemma3-1b-gpu-custom', + description: + 'The model to use for the classifier. Only tested on `gemma3-1b-gpu-custom`.', + showInDialog: false, + }, + }, + }, + }, + }, }, }, @@ -2532,7 +2583,9 @@ type InferSettings = { : T[K]['default'] : T[K]['default'] extends boolean ? boolean - : T[K]['default']; + : T[K]['default'] extends string + ? 
string + : T[K]['default']; }; type InferMergedSettings = { @@ -2544,7 +2597,9 @@ type InferMergedSettings = { : T[K]['default'] : T[K]['default'] extends boolean ? boolean - : T[K]['default']; + : T[K]['default'] extends string + ? string + : T[K]['default']; }; export type Settings = InferSettings; diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index e92f464fa2..1034246e9c 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -225,8 +225,10 @@ import type { } from '../services/modelConfigService.js'; import { ExitPlanModeTool } from '../tools/exit-plan-mode.js'; import { EnterPlanModeTool } from '../tools/enter-plan-mode.js'; +import { LocalLiteRtLmClient } from '../core/localLiteRtLmClient.js'; vi.mock('../core/baseLlmClient.js'); +vi.mock('../core/localLiteRtLmClient.js'); vi.mock('../core/tokenLimits.js', () => ({ tokenLimit: vi.fn(), })); @@ -1418,6 +1420,79 @@ describe('Server Config (config.ts)', () => { }); }); +describe('GemmaModelRouterSettings', () => { + const MODEL = DEFAULT_GEMINI_MODEL; + const SANDBOX: SandboxConfig = { + command: 'docker', + image: 'gemini-cli-sandbox', + }; + const TARGET_DIR = '/path/to/target'; + const DEBUG_MODE = false; + const QUESTION = 'test question'; + const USER_MEMORY = 'Test User Memory'; + const TELEMETRY_SETTINGS = { enabled: false }; + const EMBEDDING_MODEL = 'gemini-embedding'; + const SESSION_ID = 'test-session-id'; + const baseParams: ConfigParameters = { + cwd: '/tmp', + embeddingModel: EMBEDDING_MODEL, + sandbox: SANDBOX, + targetDir: TARGET_DIR, + debugMode: DEBUG_MODE, + question: QUESTION, + userMemory: USER_MEMORY, + telemetry: TELEMETRY_SETTINGS, + sessionId: SESSION_ID, + model: MODEL, + usageStatisticsEnabled: false, + }; + + it('should default gemmaModelRouter.enabled to false', () => { + const config = new Config(baseParams); + expect(config.getGemmaModelRouterEnabled()).toBe(false); + }); + + it('should 
return default gemma model router settings when not provided', () => { + const config = new Config(baseParams); + const settings = config.getGemmaModelRouterSettings(); + expect(settings.enabled).toBe(false); + expect(settings.classifier?.host).toBe('http://localhost:9379'); + expect(settings.classifier?.model).toBe('gemma3-1b-gpu-custom'); + }); + + it('should override default gemma model router settings when provided', () => { + const params: ConfigParameters = { + ...baseParams, + gemmaModelRouter: { + enabled: true, + classifier: { + host: 'http://custom:1234', + model: 'custom-gemma', + }, + }, + }; + const config = new Config(params); + const settings = config.getGemmaModelRouterSettings(); + expect(settings.enabled).toBe(true); + expect(settings.classifier?.host).toBe('http://custom:1234'); + expect(settings.classifier?.model).toBe('custom-gemma'); + }); + + it('should merge partial gemma model router settings with defaults', () => { + const params: ConfigParameters = { + ...baseParams, + gemmaModelRouter: { + enabled: true, + }, + }; + const config = new Config(params); + const settings = config.getGemmaModelRouterSettings(); + expect(settings.enabled).toBe(true); + expect(settings.classifier?.host).toBe('http://localhost:9379'); + expect(settings.classifier?.model).toBe('gemma3-1b-gpu-custom'); + }); +}); + describe('setApprovalMode with folder trust', () => { const baseParams: ConfigParameters = { sessionId: 'test', @@ -2069,6 +2144,71 @@ describe('Config getHooks', () => { }); }); +describe('LocalLiteRtLmClient Lifecycle', () => { + const MODEL = 'gemini-pro'; + const SANDBOX: SandboxConfig = { + command: 'docker', + image: 'gemini-cli-sandbox', + }; + const TARGET_DIR = '/path/to/target'; + const DEBUG_MODE = false; + const QUESTION = 'test question'; + const USER_MEMORY = 'Test User Memory'; + const TELEMETRY_SETTINGS = { enabled: false }; + const EMBEDDING_MODEL = 'gemini-embedding'; + const SESSION_ID = 'test-session-id'; + const baseParams: 
ConfigParameters = { + cwd: '/tmp', + embeddingModel: EMBEDDING_MODEL, + sandbox: SANDBOX, + targetDir: TARGET_DIR, + debugMode: DEBUG_MODE, + question: QUESTION, + userMemory: USER_MEMORY, + telemetry: TELEMETRY_SETTINGS, + sessionId: SESSION_ID, + model: MODEL, + usageStatisticsEnabled: false, + }; + + beforeEach(() => { + vi.clearAllMocks(); + vi.mocked(getExperiments).mockResolvedValue({ + experimentIds: [], + flags: {}, + }); + }); + + it('should successfully initialize LocalLiteRtLmClient on first call and reuse it', () => { + const config = new Config(baseParams); + const client1 = config.getLocalLiteRtLmClient(); + const client2 = config.getLocalLiteRtLmClient(); + + expect(client1).toBeDefined(); + expect(client1).toBe(client2); // Should return the same instance + }); + + it('should configure LocalLiteRtLmClient with settings from getGemmaModelRouterSettings', () => { + const customHost = 'http://my-custom-host:9999'; + const customModel = 'my-custom-gemma-model'; + const params: ConfigParameters = { + ...baseParams, + gemmaModelRouter: { + enabled: true, + classifier: { + host: customHost, + model: customModel, + }, + }, + }; + + const config = new Config(params); + config.getLocalLiteRtLmClient(); + + expect(LocalLiteRtLmClient).toHaveBeenCalledWith(config); + }); +}); + describe('Config getExperiments', () => { const baseParams: ConfigParameters = { cwd: '/tmp', diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 7297693b8e..2f5d452446 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -38,6 +38,7 @@ import { ExitPlanModeTool } from '../tools/exit-plan-mode.js'; import { EnterPlanModeTool } from '../tools/enter-plan-mode.js'; import { GeminiClient } from '../core/client.js'; import { BaseLlmClient } from '../core/baseLlmClient.js'; +import { LocalLiteRtLmClient } from '../core/localLiteRtLmClient.js'; import type { HookDefinition, HookEventName } from '../hooks/types.js'; 
import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { GitService } from '../services/gitService.js'; @@ -178,6 +179,14 @@ export interface ToolOutputMaskingConfig { protectLatestTurn: boolean; } +export interface GemmaModelRouterSettings { + enabled?: boolean; + classifier?: { + host?: string; + model?: string; + }; +} + export interface ExtensionSetting { name: string; description: string; @@ -509,6 +518,7 @@ export interface ConfigParameters { directWebFetch?: boolean; policyUpdateConfirmationRequest?: PolicyUpdateConfirmationRequest; output?: OutputSettings; + gemmaModelRouter?: GemmaModelRouterSettings; disableModelRouterForAuth?: AuthType[]; continueOnFailedApiCall?: boolean; retryFetchErrors?: boolean; @@ -599,6 +609,7 @@ export class Config { private readonly usageStatisticsEnabled: boolean; private geminiClient!: GeminiClient; private baseLlmClient!: BaseLlmClient; + private localLiteRtLmClient?: LocalLiteRtLmClient; private modelRouterService: ModelRouterService; private readonly modelAvailabilityService: ModelAvailabilityService; private readonly fileFiltering: { @@ -694,6 +705,9 @@ export class Config { | PolicyUpdateConfirmationRequest | undefined; private readonly outputSettings: OutputSettings; + + private readonly gemmaModelRouter: GemmaModelRouterSettings; + private readonly continueOnFailedApiCall: boolean; private readonly retryFetchErrors: boolean; private readonly maxAttempts: number; @@ -942,6 +956,15 @@ export class Config { this.outputSettings = { format: params.output?.format ?? OutputFormat.TEXT, }; + this.gemmaModelRouter = { + enabled: params.gemmaModelRouter?.enabled ?? false, + classifier: { + host: + params.gemmaModelRouter?.classifier?.host ?? 'http://localhost:9379', + model: + params.gemmaModelRouter?.classifier?.model ?? 'gemma3-1b-gpu-custom', + }, + }; this.retryFetchErrors = params.retryFetchErrors ?? false; this.maxAttempts = Math.min( params.maxAttempts ?? 
DEFAULT_MAX_ATTEMPTS, @@ -1245,6 +1268,13 @@ export class Config { return this.baseLlmClient; } + getLocalLiteRtLmClient(): LocalLiteRtLmClient { + if (!this.localLiteRtLmClient) { + this.localLiteRtLmClient = new LocalLiteRtLmClient(this); + } + return this.localLiteRtLmClient; + } + getSessionId(): string { return this.sessionId; } @@ -2578,6 +2608,14 @@ export class Config { return this.enableHooksUI; } + getGemmaModelRouterEnabled(): boolean { + return this.gemmaModelRouter.enabled ?? false; + } + + getGemmaModelRouterSettings(): GemmaModelRouterSettings { + return this.gemmaModelRouter; + } + /** * Get override settings for a specific agent. * Reads from agents.overrides.. diff --git a/packages/core/src/core/localLiteRtLmClient.test.ts b/packages/core/src/core/localLiteRtLmClient.test.ts new file mode 100644 index 0000000000..c4398b5b9c --- /dev/null +++ b/packages/core/src/core/localLiteRtLmClient.test.ts @@ -0,0 +1,125 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { LocalLiteRtLmClient } from './localLiteRtLmClient.js'; +import type { Config } from '../config/config.js'; +const mockGenerateContent = vi.fn(); + +vi.mock('@google/genai', () => { + const GoogleGenAI = vi.fn().mockImplementation(() => ({ + models: { + generateContent: mockGenerateContent, + }, + })); + return { GoogleGenAI }; +}); + +describe('LocalLiteRtLmClient', () => { + let mockConfig: Config; + + beforeEach(() => { + vi.clearAllMocks(); + mockGenerateContent.mockClear(); + + mockConfig = { + getGemmaModelRouterSettings: vi.fn().mockReturnValue({ + classifier: { + host: 'http://test-host:1234', + model: 'gemma:latest', + }, + }), + } as unknown as Config; + }); + + it('should successfully call generateJson and return parsed JSON', async () => { + mockGenerateContent.mockResolvedValue({ + text: '{"key": "value"}', + }); + + const client = new 
LocalLiteRtLmClient(mockConfig); + const result = await client.generateJson([], 'test-instruction'); + + expect(result).toEqual({ key: 'value' }); + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gemma:latest', + config: expect.objectContaining({ + responseMimeType: 'application/json', + temperature: 0, + }), + }), + ); + }); + + it('should throw an error if the API response has no text', async () => { + mockGenerateContent.mockResolvedValue({ + text: null, + }); + + const client = new LocalLiteRtLmClient(mockConfig); + await expect(client.generateJson([], 'test-instruction')).rejects.toThrow( + 'Invalid response from Local Gemini API: No text found', + ); + }); + + it('should throw if the JSON is malformed', async () => { + mockGenerateContent.mockResolvedValue({ + text: `{ + “key”: ‘value’, +}`, // Smart quotes, trailing comma + }); + + const client = new LocalLiteRtLmClient(mockConfig); + await expect(client.generateJson([], 'test-instruction')).rejects.toThrow( + SyntaxError, + ); + }); + + it('should add reminder to the last user message', async () => { + mockGenerateContent.mockResolvedValue({ + text: '{"key": "value"}', + }); + + const client = new LocalLiteRtLmClient(mockConfig); + await client.generateJson( + [{ role: 'user', parts: [{ text: 'initial prompt' }] }], + 'test-instruction', + 'test-reminder', + ); + + const calledContents = + vi.mocked(mockGenerateContent).mock.calls[0][0].contents; + expect(calledContents.at(-1)?.parts[0].text).toBe( + `initial prompt + +test-reminder`, + ); + }); + + it('should pass abortSignal to generateContent', async () => { + mockGenerateContent.mockResolvedValue({ + text: '{"key": "value"}', + }); + + const client = new LocalLiteRtLmClient(mockConfig); + const controller = new AbortController(); + await client.generateJson( + [], + 'test-instruction', + undefined, + controller.signal, + ); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.objectContaining({ + 
config: expect.objectContaining({ + abortSignal: controller.signal, + }), + }), + ); + }); +}); diff --git a/packages/core/src/core/localLiteRtLmClient.ts b/packages/core/src/core/localLiteRtLmClient.ts new file mode 100644 index 0000000000..8f4a020a50 --- /dev/null +++ b/packages/core/src/core/localLiteRtLmClient.ts @@ -0,0 +1,96 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { GoogleGenAI } from '@google/genai'; +import type { Config } from '../config/config.js'; +import { debugLogger } from '../utils/debugLogger.js'; +import type { Content } from '@google/genai'; + +/** + * A client for making single, non-streaming calls to a local Gemini-compatible API + * and expecting a JSON response. + */ +export class LocalLiteRtLmClient { + private readonly host: string; + private readonly model: string; + private readonly client: GoogleGenAI; + + constructor(config: Config) { + const gemmaModelRouterSettings = config.getGemmaModelRouterSettings(); + this.host = gemmaModelRouterSettings.classifier!.host!; + this.model = gemmaModelRouterSettings.classifier!.model!; + + this.client = new GoogleGenAI({ + // The LiteRT-LM server does not require an API key, but the SDK requires one to be set even for local endpoints. This is a dummy value and is not used for authentication. + apiKey: 'no-api-key-needed', + httpOptions: { + baseUrl: this.host, + // If the LiteRT-LM server is started but the wrong port is set, there will be a lengthy TCP timeout (here fixed to be 10 seconds). + // If the LiteRT-LM server is not started, there will be an immediate connection refusal. + // If the LiteRT-LM server is started and the model is unsupported or not downloaded, the server will return an error immediately. + // If the model's context window is exceeded, the server will return an error immediately. + timeout: 10000, + }, + }); + } + + /** + * Sends a prompt to the local Gemini model and expects a JSON object in response. 
+ * @param contents The history and current prompt. + * @param systemInstruction The system prompt. + * @returns A promise that resolves to the parsed JSON object. + */ + async generateJson( + contents: Content[], + systemInstruction: string, + reminder?: string, + abortSignal?: AbortSignal, + ): Promise { + const geminiContents = contents.map((c) => ({ + role: c.role, + parts: c.parts ? c.parts.map((p) => ({ text: p.text })) : [], + })); + + if (reminder) { + const lastContent = geminiContents.at(-1); + if (lastContent?.role === 'user' && lastContent.parts?.[0]?.text) { + lastContent.parts[0].text += `\n\n${reminder}`; + } + } + + try { + const result = await this.client.models.generateContent({ + model: this.model, + contents: geminiContents, + config: { + responseMimeType: 'application/json', + systemInstruction: systemInstruction + ? { parts: [{ text: systemInstruction }] } + : undefined, + temperature: 0, + maxOutputTokens: 256, + abortSignal, + }, + }); + + const text = result.text; + if (!text) { + throw new Error( + 'Invalid response from Local Gemini API: No text found', + ); + } + + // eslint-disable-next-line @typescript-eslint/no-unsafe-return + return JSON.parse(result.text); + } catch (error) { + debugLogger.error( + `[LocalLiteRtLmClient] Failed to generate content:`, + error, + ); + throw error; + } + } +} diff --git a/packages/core/src/routing/modelRouterService.test.ts b/packages/core/src/routing/modelRouterService.test.ts index 144d8d3232..ad0e3c890e 100644 --- a/packages/core/src/routing/modelRouterService.test.ts +++ b/packages/core/src/routing/modelRouterService.test.ts @@ -9,6 +9,7 @@ import { ModelRouterService } from './modelRouterService.js'; import { Config } from '../config/config.js'; import type { BaseLlmClient } from '../core/baseLlmClient.js'; +import type { LocalLiteRtLmClient } from '../core/localLiteRtLmClient.js'; import type { RoutingContext, RoutingDecision } from './routingStrategy.js'; import { DefaultStrategy } from 
'./strategies/defaultStrategy.js'; import { CompositeStrategy } from './strategies/compositeStrategy.js'; @@ -19,6 +20,7 @@ import { ClassifierStrategy } from './strategies/classifierStrategy.js'; import { NumericalClassifierStrategy } from './strategies/numericalClassifierStrategy.js'; import { logModelRouting } from '../telemetry/loggers.js'; import { ModelRoutingEvent } from '../telemetry/types.js'; +import { GemmaClassifierStrategy } from './strategies/gemmaClassifierStrategy.js'; import { ApprovalMode } from '../policy/types.js'; vi.mock('../config/config.js'); @@ -30,6 +32,7 @@ vi.mock('./strategies/overrideStrategy.js'); vi.mock('./strategies/approvalModeStrategy.js'); vi.mock('./strategies/classifierStrategy.js'); vi.mock('./strategies/numericalClassifierStrategy.js'); +vi.mock('./strategies/gemmaClassifierStrategy.js'); vi.mock('../telemetry/loggers.js'); vi.mock('../telemetry/types.js'); @@ -37,6 +40,7 @@ describe('ModelRouterService', () => { let service: ModelRouterService; let mockConfig: Config; let mockBaseLlmClient: BaseLlmClient; + let mockLocalLiteRtLmClient: LocalLiteRtLmClient; let mockContext: RoutingContext; let mockCompositeStrategy: CompositeStrategy; @@ -45,9 +49,20 @@ describe('ModelRouterService', () => { mockConfig = new Config({} as never); mockBaseLlmClient = {} as BaseLlmClient; + mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient; vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient); + vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue( + mockLocalLiteRtLmClient, + ); vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false); vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined); + vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({ + enabled: false, + classifier: { + host: 'http://localhost:1234', + model: 'gemma3-1b-gpu-custom', + }, + }); vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue( ApprovalMode.DEFAULT, ); @@ -96,6 +111,36 @@ 
describe('ModelRouterService', () => { expect(compositeStrategyArgs[1]).toBe('agent-router'); }); + it('should include GemmaClassifierStrategy when enabled', () => { + // Override the default mock for this specific test + vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({ + enabled: true, + classifier: { + host: 'http://localhost:1234', + model: 'gemma3-1b-gpu-custom', + }, + }); + + // Clear previous mock calls from beforeEach + vi.mocked(CompositeStrategy).mockClear(); + + // Re-initialize the service to pick up the new config + service = new ModelRouterService(mockConfig); + + const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0]; + const childStrategies = compositeStrategyArgs[0]; + + expect(childStrategies.length).toBe(7); + expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy); + expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy); + expect(childStrategies[2]).toBeInstanceOf(ApprovalModeStrategy); + expect(childStrategies[3]).toBeInstanceOf(GemmaClassifierStrategy); + expect(childStrategies[4]).toBeInstanceOf(ClassifierStrategy); + expect(childStrategies[5]).toBeInstanceOf(NumericalClassifierStrategy); + expect(childStrategies[6]).toBeInstanceOf(DefaultStrategy); + expect(compositeStrategyArgs[1]).toBe('agent-router'); + }); + describe('route()', () => { const strategyDecision: RoutingDecision = { model: 'strategy-chosen-model', @@ -117,6 +162,7 @@ describe('ModelRouterService', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual(strategyDecision); }); diff --git a/packages/core/src/routing/modelRouterService.ts b/packages/core/src/routing/modelRouterService.ts index 54cfa72259..1bd19f3622 100644 --- a/packages/core/src/routing/modelRouterService.ts +++ b/packages/core/src/routing/modelRouterService.ts @@ -4,10 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { GemmaClassifierStrategy } from './strategies/gemmaClassifierStrategy.js'; import type 
{ Config } from '../config/config.js'; import type { RoutingContext, RoutingDecision, + RoutingStrategy, TerminalStrategy, } from './routingStrategy.js'; import { DefaultStrategy } from './strategies/defaultStrategy.js'; @@ -35,17 +37,31 @@ export class ModelRouterService { } private initializeDefaultStrategy(): TerminalStrategy { - // Initialize the composite strategy with the desired priority order. - // The strategies are ordered in order of highest priority. + const strategies: RoutingStrategy[] = []; + + // Order matters here. Fallback and override are checked first. + strategies.push(new FallbackStrategy()); + strategies.push(new OverrideStrategy()); + + // Approval mode is next. + strategies.push(new ApprovalModeStrategy()); + + // Then, if enabled, the Gemma classifier is used. + if (this.config.getGemmaModelRouterSettings()?.enabled) { + strategies.push(new GemmaClassifierStrategy()); + } + + // The generic classifier is next. + strategies.push(new ClassifierStrategy()); + + // The numerical classifier is next. + strategies.push(new NumericalClassifierStrategy()); + + // The default strategy is the terminal strategy. 
+ const terminalStrategy = new DefaultStrategy(); + return new CompositeStrategy( - [ - new FallbackStrategy(), - new OverrideStrategy(), - new ApprovalModeStrategy(), - new ClassifierStrategy(), - new NumericalClassifierStrategy(), - new DefaultStrategy(), - ], + [...strategies, terminalStrategy], 'agent-router', ); } @@ -75,6 +91,7 @@ export class ModelRouterService { context, this.config, this.config.getBaseLlmClient(), + this.config.getLocalLiteRtLmClient(), ); debugLogger.debug( diff --git a/packages/core/src/routing/routingStrategy.ts b/packages/core/src/routing/routingStrategy.ts index de8bcf04f1..a2f9448989 100644 --- a/packages/core/src/routing/routingStrategy.ts +++ b/packages/core/src/routing/routingStrategy.ts @@ -7,6 +7,7 @@ import type { Content, PartListUnion } from '@google/genai'; import type { BaseLlmClient } from '../core/baseLlmClient.js'; import type { Config } from '../config/config.js'; +import type { LocalLiteRtLmClient } from '../core/localLiteRtLmClient.js'; /** * The output of a routing decision. It specifies which model to use and why. 
@@ -58,6 +59,7 @@ export interface RoutingStrategy { context: RoutingContext, config: Config, baseLlmClient: BaseLlmClient, + localLiteRtLmClient: LocalLiteRtLmClient, ): Promise; } @@ -74,5 +76,6 @@ export interface TerminalStrategy extends RoutingStrategy { context: RoutingContext, config: Config, baseLlmClient: BaseLlmClient, + localLiteRtLmClient: LocalLiteRtLmClient, ): Promise; } diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts index 7e024b790a..701e7de932 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts @@ -9,6 +9,7 @@ import { ClassifierStrategy } from './classifierStrategy.js'; import type { RoutingContext } from '../routingStrategy.js'; import type { Config } from '../../config/config.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { isFunctionCall, isFunctionResponse, @@ -34,6 +35,7 @@ describe('ClassifierStrategy', () => { let mockContext: RoutingContext; let mockConfig: Config; let mockBaseLlmClient: BaseLlmClient; + let mockLocalLiteRtLmClient: LocalLiteRtLmClient; let mockResolvedConfig: ResolvedModelConfig; beforeEach(() => { @@ -64,6 +66,7 @@ describe('ClassifierStrategy', () => { mockBaseLlmClient = { generateJson: vi.fn(), } as unknown as BaseLlmClient; + mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient; vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id'); }); @@ -76,6 +79,7 @@ describe('ClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toBeNull(); @@ -94,6 +98,7 @@ describe('ClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).not.toBeNull(); @@ -109,7 +114,12 @@ describe('ClassifierStrategy', 
() => { mockApiResponse, ); - await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith( expect.objectContaining({ @@ -132,6 +142,7 @@ describe('ClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce(); @@ -159,6 +170,7 @@ describe('ClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce(); @@ -183,6 +195,7 @@ describe('ClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toBeNull(); @@ -206,6 +219,7 @@ describe('ClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toBeNull(); @@ -233,7 +247,12 @@ describe('ClassifierStrategy', () => { mockApiResponse, ); - await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock .calls[0][0]; @@ -269,7 +288,12 @@ describe('ClassifierStrategy', () => { mockApiResponse, ); - await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock .calls[0][0]; @@ -305,7 +329,12 @@ describe('ClassifierStrategy', () => { mockApiResponse, ); - await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); const generateJsonCall = 
vi.mocked(mockBaseLlmClient.generateJson).mock .calls[0][0]; @@ -340,6 +369,7 @@ describe('ClassifierStrategy', () => { contextWithRequestedModel, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).not.toBeNull(); @@ -363,6 +393,7 @@ describe('ClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL); @@ -386,6 +417,7 @@ describe('ClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts index 7e54d161de..5fd6208b15 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.ts @@ -20,6 +20,7 @@ import { isFunctionResponse, } from '../../utils/messageInspectors.js'; import { debugLogger } from '../../utils/debugLogger.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { LlmRole } from '../../telemetry/types.js'; import { AuthType } from '../../core/contentGenerator.js'; @@ -132,6 +133,7 @@ export class ClassifierStrategy implements RoutingStrategy { context: RoutingContext, config: Config, baseLlmClient: BaseLlmClient, + _localLiteRtLmClient: LocalLiteRtLmClient, ): Promise { const startTime = Date.now(); try { diff --git a/packages/core/src/routing/strategies/compositeStrategy.test.ts b/packages/core/src/routing/strategies/compositeStrategy.test.ts index 1be0b8a8e3..5b627a1692 100644 --- a/packages/core/src/routing/strategies/compositeStrategy.test.ts +++ b/packages/core/src/routing/strategies/compositeStrategy.test.ts @@ -16,6 +16,7 @@ import type { Config } from '../../config/config.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import { debugLogger } from 
'../../utils/debugLogger.js'; import { coreEvents } from '../../utils/events.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; vi.mock('../../utils/debugLogger.js', () => ({ debugLogger: { @@ -27,6 +28,7 @@ describe('CompositeStrategy', () => { let mockContext: RoutingContext; let mockConfig: Config; let mockBaseLlmClient: BaseLlmClient; + let mockLocalLiteRtLmClient: LocalLiteRtLmClient; let mockStrategy1: RoutingStrategy; let mockStrategy2: RoutingStrategy; let mockTerminalStrategy: TerminalStrategy; @@ -38,6 +40,7 @@ describe('CompositeStrategy', () => { mockContext = {} as RoutingContext; mockConfig = {} as Config; mockBaseLlmClient = {} as BaseLlmClient; + mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient; emitFeedbackSpy = vi.spyOn(coreEvents, 'emitFeedback'); @@ -84,17 +87,20 @@ describe('CompositeStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(mockStrategy1.route).toHaveBeenCalledWith( mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(mockStrategy2.route).toHaveBeenCalledWith( mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(mockTerminalStrategy.route).not.toHaveBeenCalled(); @@ -112,6 +118,7 @@ describe('CompositeStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(mockStrategy1.route).toHaveBeenCalledTimes(1); @@ -136,6 +143,7 @@ describe('CompositeStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(debugLogger.warn).toHaveBeenCalledWith( @@ -152,7 +160,12 @@ describe('CompositeStrategy', () => { const composite = new CompositeStrategy([mockTerminalStrategy]); await expect( - composite.route(mockContext, mockConfig, mockBaseLlmClient), + composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ), ).rejects.toThrow(terminalError); expect(emitFeedbackSpy).toHaveBeenCalledWith( 
@@ -182,6 +195,7 @@ describe('CompositeStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(result.model).toBe('some-model'); @@ -212,6 +226,7 @@ describe('CompositeStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(result.metadata.latencyMs).toBeGreaterThanOrEqual(0); diff --git a/packages/core/src/routing/strategies/compositeStrategy.ts b/packages/core/src/routing/strategies/compositeStrategy.ts index 29e6b96355..1706282864 100644 --- a/packages/core/src/routing/strategies/compositeStrategy.ts +++ b/packages/core/src/routing/strategies/compositeStrategy.ts @@ -14,6 +14,7 @@ import type { RoutingStrategy, TerminalStrategy, } from '../routingStrategy.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; /** * A strategy that attempts a list of child strategies in order (Chain of Responsibility). @@ -40,6 +41,7 @@ export class CompositeStrategy implements TerminalStrategy { context: RoutingContext, config: Config, baseLlmClient: BaseLlmClient, + localLiteRtLmClient: LocalLiteRtLmClient, ): Promise { const startTime = performance.now(); @@ -57,7 +59,12 @@ export class CompositeStrategy implements TerminalStrategy { // Try non-terminal strategies, allowing them to fail gracefully. 
for (const strategy of nonTerminalStrategies) { try { - const decision = await strategy.route(context, config, baseLlmClient); + const decision = await strategy.route( + context, + config, + baseLlmClient, + localLiteRtLmClient, + ); if (decision) { return this.finalizeDecision(decision, startTime); } @@ -75,6 +82,7 @@ export class CompositeStrategy implements TerminalStrategy { context, config, baseLlmClient, + localLiteRtLmClient, ); return this.finalizeDecision(decision, startTime); diff --git a/packages/core/src/routing/strategies/defaultStrategy.test.ts b/packages/core/src/routing/strategies/defaultStrategy.test.ts index ceec72d171..de27a84e19 100644 --- a/packages/core/src/routing/strategies/defaultStrategy.test.ts +++ b/packages/core/src/routing/strategies/defaultStrategy.test.ts @@ -8,6 +8,7 @@ import { describe, it, expect, vi } from 'vitest'; import { DefaultStrategy } from './defaultStrategy.js'; import type { RoutingContext } from '../routingStrategy.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { DEFAULT_GEMINI_MODEL, PREVIEW_GEMINI_MODEL, @@ -26,8 +27,14 @@ describe('DefaultStrategy', () => { getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO), } as unknown as Config; const mockClient = {} as BaseLlmClient; + const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient; - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); expect(decision).toEqual({ model: DEFAULT_GEMINI_MODEL, @@ -46,8 +53,14 @@ describe('DefaultStrategy', () => { getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO), } as unknown as Config; const mockClient = {} as BaseLlmClient; + const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient; - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const 
decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); expect(decision).toEqual({ model: PREVIEW_GEMINI_MODEL, @@ -66,8 +79,14 @@ describe('DefaultStrategy', () => { getModel: vi.fn().mockReturnValue(GEMINI_MODEL_ALIAS_AUTO), } as unknown as Config; const mockClient = {} as BaseLlmClient; + const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient; - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); expect(decision).toEqual({ model: PREVIEW_GEMINI_MODEL, @@ -87,8 +106,14 @@ describe('DefaultStrategy', () => { getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_FLASH_MODEL), } as unknown as Config; const mockClient = {} as BaseLlmClient; + const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient; - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); expect(decision).toEqual({ model: PREVIEW_GEMINI_FLASH_MODEL, diff --git a/packages/core/src/routing/strategies/defaultStrategy.ts b/packages/core/src/routing/strategies/defaultStrategy.ts index 1f5b7e54c2..d380ba7ad2 100644 --- a/packages/core/src/routing/strategies/defaultStrategy.ts +++ b/packages/core/src/routing/strategies/defaultStrategy.ts @@ -12,6 +12,7 @@ import type { TerminalStrategy, } from '../routingStrategy.js'; import { resolveModel } from '../../config/models.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; export class DefaultStrategy implements TerminalStrategy { readonly name = 'default'; @@ -20,6 +21,7 @@ export class DefaultStrategy implements TerminalStrategy { _context: RoutingContext, config: Config, _baseLlmClient: BaseLlmClient, + _localLiteRtLmClient: LocalLiteRtLmClient, ): Promise { const defaultModel = resolveModel( 
config.getModel(), diff --git a/packages/core/src/routing/strategies/fallbackStrategy.test.ts b/packages/core/src/routing/strategies/fallbackStrategy.test.ts index d0be7938c4..ffe2ed6446 100644 --- a/packages/core/src/routing/strategies/fallbackStrategy.test.ts +++ b/packages/core/src/routing/strategies/fallbackStrategy.test.ts @@ -10,6 +10,7 @@ import type { RoutingContext } from '../routingStrategy.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import type { Config } from '../../config/config.js'; import type { ModelAvailabilityService } from '../../availability/modelAvailabilityService.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_FLASH_MODEL, @@ -32,6 +33,7 @@ describe('FallbackStrategy', () => { const strategy = new FallbackStrategy(); const mockContext = {} as RoutingContext; const mockClient = {} as BaseLlmClient; + const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient; let mockService: ModelAvailabilityService; let mockConfig: Config; @@ -51,7 +53,12 @@ describe('FallbackStrategy', () => { // Mock snapshot to return available vi.mocked(mockService.snapshot).mockReturnValue({ available: true }); - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); expect(decision).toBeNull(); // Should check availability of the resolved model (DEFAULT_GEMINI_MODEL) expect(mockService.snapshot).toHaveBeenCalledWith(DEFAULT_GEMINI_MODEL); @@ -69,7 +76,12 @@ describe('FallbackStrategy', () => { skipped: [], }); - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); expect(decision).toBeNull(); }); @@ -86,7 +98,12 @@ describe('FallbackStrategy', () => { skipped: [{ model: DEFAULT_GEMINI_MODEL, 
reason: 'quota' }], }); - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); expect(decision).not.toBeNull(); expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL); @@ -101,7 +118,12 @@ describe('FallbackStrategy', () => { vi.mocked(mockService.snapshot).mockReturnValue({ available: true }); vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); expect(decision).toBeNull(); // Important: check that it queried snapshot with the RESOLVED model, not 'auto' @@ -122,6 +144,7 @@ describe('FallbackStrategy', () => { contextWithRequestedModel, mockConfig, mockClient, + mockLocalLiteRtLmClient, ); expect(decision).toBeNull(); diff --git a/packages/core/src/routing/strategies/fallbackStrategy.ts b/packages/core/src/routing/strategies/fallbackStrategy.ts index a18e4fc4dd..21a080e9da 100644 --- a/packages/core/src/routing/strategies/fallbackStrategy.ts +++ b/packages/core/src/routing/strategies/fallbackStrategy.ts @@ -13,6 +13,7 @@ import type { RoutingDecision, RoutingStrategy, } from '../routingStrategy.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; export class FallbackStrategy implements RoutingStrategy { readonly name = 'fallback'; @@ -21,6 +22,7 @@ export class FallbackStrategy implements RoutingStrategy { context: RoutingContext, config: Config, _baseLlmClient: BaseLlmClient, + _localLiteRtLmClient: LocalLiteRtLmClient, ): Promise { const requestedModel = context.requestedModel ?? 
config.getModel(); const resolvedModel = resolveModel( diff --git a/packages/core/src/routing/strategies/gemmaClassifierStrategy.test.ts b/packages/core/src/routing/strategies/gemmaClassifierStrategy.test.ts new file mode 100644 index 0000000000..9425208fd7 --- /dev/null +++ b/packages/core/src/routing/strategies/gemmaClassifierStrategy.test.ts @@ -0,0 +1,324 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Mock } from 'vitest'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { GemmaClassifierStrategy } from './gemmaClassifierStrategy.js'; +import type { RoutingContext } from '../routingStrategy.js'; +import type { Config } from '../../config/config.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import { + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_MODEL, +} from '../../config/models.js'; +import type { Content } from '@google/genai'; +import { debugLogger } from '../../utils/debugLogger.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; + +vi.mock('../../core/localLiteRtLmClient.js'); + +describe('GemmaClassifierStrategy', () => { + let strategy: GemmaClassifierStrategy; + let mockContext: RoutingContext; + let mockConfig: Config; + let mockBaseLlmClient: BaseLlmClient; + let mockLocalLiteRtLmClient: LocalLiteRtLmClient; + let mockGenerateJson: Mock; + + beforeEach(() => { + vi.clearAllMocks(); + mockGenerateJson = vi.fn(); + + mockConfig = { + getGemmaModelRouterSettings: vi.fn().mockReturnValue({ + enabled: true, + classifier: { model: 'gemma3-1b-gpu-custom' }, + }), + getModel: () => DEFAULT_GEMINI_MODEL, + getPreviewFeatures: () => false, + } as unknown as Config; + + strategy = new GemmaClassifierStrategy(); + mockContext = { + history: [], + request: 'simple task', + signal: new AbortController().signal, + }; + + mockBaseLlmClient = {} as BaseLlmClient; + mockLocalLiteRtLmClient = { + generateJson: 
mockGenerateJson, + } as unknown as LocalLiteRtLmClient; + }); + + it('should return null if gemma model router is disabled', async () => { + vi.mocked(mockConfig.getGemmaModelRouterSettings).mockReturnValue({ + enabled: false, + }); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + expect(decision).toBeNull(); + }); + + it('should throw an error if the model is not gemma3-1b-gpu-custom', async () => { + vi.mocked(mockConfig.getGemmaModelRouterSettings).mockReturnValue({ + enabled: true, + classifier: { model: 'other-model' }, + }); + + await expect( + strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ), + ).rejects.toThrow('Only gemma3-1b-gpu-custom has been tested'); + }); + + it('should call generateJson with the correct parameters', async () => { + const mockApiResponse = { + reasoning: 'Simple task', + model_choice: 'flash', + }; + mockGenerateJson.mockResolvedValue(mockApiResponse); + + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + expect(mockGenerateJson).toHaveBeenCalledWith( + expect.any(Array), + expect.any(String), + expect.any(String), + expect.any(AbortSignal), + ); + }); + + it('should route to FLASH model for a simple task', async () => { + const mockApiResponse = { + reasoning: 'This is a simple task.', + model_choice: 'flash', + }; + mockGenerateJson.mockResolvedValue(mockApiResponse); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + expect(mockGenerateJson).toHaveBeenCalledOnce(); + expect(decision).toEqual({ + model: DEFAULT_GEMINI_FLASH_MODEL, + metadata: { + source: 'GemmaClassifier', + latencyMs: expect.any(Number), + reasoning: mockApiResponse.reasoning, + }, + }); + }); + + it('should route to PRO model for a complex task', async () => { + const mockApiResponse = { + reasoning: 
'This is a complex task.', + model_choice: 'pro', + }; + mockGenerateJson.mockResolvedValue(mockApiResponse); + mockContext.request = 'how do I build a spaceship?'; + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + expect(mockGenerateJson).toHaveBeenCalledOnce(); + expect(decision).toEqual({ + model: DEFAULT_GEMINI_MODEL, + metadata: { + source: 'GemmaClassifier', + latencyMs: expect.any(Number), + reasoning: mockApiResponse.reasoning, + }, + }); + }); + + it('should return null if the classifier API call fails', async () => { + const consoleWarnSpy = vi + .spyOn(debugLogger, 'warn') + .mockImplementation(() => {}); + const testError = new Error('API Failure'); + mockGenerateJson.mockRejectedValue(testError); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + expect(decision).toBeNull(); + expect(consoleWarnSpy).toHaveBeenCalled(); + consoleWarnSpy.mockRestore(); + }); + + it('should return null if the classifier returns a malformed JSON object', async () => { + const consoleWarnSpy = vi + .spyOn(debugLogger, 'warn') + .mockImplementation(() => {}); + const malformedApiResponse = { + reasoning: 'This is a simple task.', + // model_choice is missing, which will cause a Zod parsing error. 
+ }; + mockGenerateJson.mockResolvedValue(malformedApiResponse); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + expect(decision).toBeNull(); + expect(consoleWarnSpy).toHaveBeenCalled(); + consoleWarnSpy.mockRestore(); + }); + + it('should filter out tool-related history before sending to classifier', async () => { + mockContext.history = [ + { role: 'user', parts: [{ text: 'call a tool' }] }, + { + role: 'model', + parts: [{ functionCall: { name: 'test_tool', args: {} } }], + }, + { + role: 'user', + parts: [ + { functionResponse: { name: 'test_tool', response: { ok: true } } }, + ], + }, + { role: 'user', parts: [{ text: 'another user turn' }] }, + ]; + const mockApiResponse = { + reasoning: 'Simple.', + model_choice: 'flash', + }; + mockGenerateJson.mockResolvedValue(mockApiResponse); + + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + // Define a type for the arguments passed to the mock `generateJson` + type GenerateJsonCall = [Content[], string, string | undefined]; + const calls = mockGenerateJson.mock.calls as GenerateJsonCall[]; + const contents = calls[0][0]; + const lastTurn = contents.at(-1); + expect(lastTurn).toBeDefined(); + if (!lastTurn?.parts) { + // Fail test if parts is not defined. + expect(lastTurn?.parts).toBeDefined(); + return; + } + const expectedLastTurn = `You are provided with a **Chat History** and the user's **Current Request** below. 
+ +#### Chat History: +call a tool + +another user turn + +#### Current Request: +"simple task" +`; + expect(lastTurn.parts.at(0)?.text).toEqual(expectedLastTurn); + }); + + it('should respect HISTORY_SEARCH_WINDOW and HISTORY_TURNS_FOR_CONTEXT', async () => { + const longHistory: Content[] = []; + for (let i = 0; i < 30; i++) { + longHistory.push({ role: 'user', parts: [{ text: `Message ${i}` }] }); + // Add noise that should be filtered + if (i % 2 === 0) { + longHistory.push({ + role: 'model', + parts: [{ functionCall: { name: 'noise', args: {} } }], + }); + } + } + mockContext.history = longHistory; + const mockApiResponse = { + reasoning: 'Simple.', + model_choice: 'flash', + }; + mockGenerateJson.mockResolvedValue(mockApiResponse); + + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + const generateJsonCall = mockGenerateJson.mock.calls[0][0]; + + // There should be 1 item which is the flattened history. + expect(generateJsonCall).toHaveLength(1); + }); + + it('should filter out non-text parts from history', async () => { + mockContext.history = [ + { role: 'user', parts: [{ text: 'first message' }] }, + // This part has no `text` property and should be filtered out. + { role: 'user', parts: [{}] } as Content, + { role: 'user', parts: [{ text: 'second message' }] }, + ]; + const mockApiResponse = { + reasoning: 'Simple.', + model_choice: 'flash', + }; + mockGenerateJson.mockResolvedValue(mockApiResponse); + + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + type GenerateJsonCall = [Content[], string, string | undefined]; + const calls = mockGenerateJson.mock.calls as GenerateJsonCall[]; + const contents = calls[0][0]; + const lastTurn = contents.at(-1); + expect(lastTurn).toBeDefined(); + + const expectedLastTurn = `You are provided with a **Chat History** and the user's **Current Request** below. 
+ +#### Chat History: +first message + +second message + +#### Current Request: +"simple task" +`; + + expect(lastTurn!.parts!.at(0)!.text).toEqual(expectedLastTurn); + }); +}); diff --git a/packages/core/src/routing/strategies/gemmaClassifierStrategy.ts b/packages/core/src/routing/strategies/gemmaClassifierStrategy.ts new file mode 100644 index 0000000000..f1175cc101 --- /dev/null +++ b/packages/core/src/routing/strategies/gemmaClassifierStrategy.ts @@ -0,0 +1,232 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { z } from 'zod'; + +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { + RoutingContext, + RoutingDecision, + RoutingStrategy, +} from '../routingStrategy.js'; +import { resolveClassifierModel } from '../../config/models.js'; +import { createUserContent, type Content, type Part } from '@google/genai'; +import type { Config } from '../../config/config.js'; +import { + isFunctionCall, + isFunctionResponse, +} from '../../utils/messageInspectors.js'; +import { debugLogger } from '../../utils/debugLogger.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; + +// The number of recent history turns to provide to the router for context. +const HISTORY_TURNS_FOR_CONTEXT = 4; +const HISTORY_SEARCH_WINDOW = 20; + +const FLASH_MODEL = 'flash'; +const PRO_MODEL = 'pro'; + +const COMPLEXITY_RUBRIC = `### Complexity Rubric +A task is COMPLEX (Choose \`${PRO_MODEL}\`) if it meets ONE OR MORE of the following criteria: +1. **High Operational Complexity (Est. 4+ Steps/Tool Calls):** Requires dependent actions, significant planning, or multiple coordinated changes. +2. **Strategic Planning & Conceptual Design:** Asking "how" or "why." Requires advice, architecture, or high-level strategy. +3. **High Ambiguity or Large Scope (Extensive Investigation):** Broadly defined requests requiring extensive investigation. +4. 
**Deep Debugging & Root Cause Analysis:** Diagnosing unknown or complex problems from symptoms. +A task is SIMPLE (Choose \`${FLASH_MODEL}\`) if it is highly specific, bounded, and has Low Operational Complexity (Est. 1-3 tool calls). Operational simplicity overrides strategic phrasing.`; + +const OUTPUT_FORMAT = `### Output Format +Respond *only* in JSON format like this: +{ + "reasoning": Your reasoning... + "model_choice": Either ${FLASH_MODEL} or ${PRO_MODEL} +} +And you must follow the following JSON schema: +{ + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": "A brief summary of the user objective, followed by a step-by-step explanation for the model choice, referencing the rubric." + }, + "model_choice": { + "type": "string", + "enum": ["${FLASH_MODEL}", "${PRO_MODEL}"] + } + }, + "required": ["reasoning", "model_choice"] +} +You must ensure that your reasoning is no more than 2 sentences long and directly references the rubric criteria. +When making your decision, the user's request should be weighted much more heavily than the surrounding context when making your determination.`; + +const LITERT_GEMMA_CLASSIFIER_SYSTEM_PROMPT = `### Role +You are the **Lead Orchestrator** for an AI system. You do not talk to users. Your sole responsibility is to analyze the **Chat History** and delegate the **Current Request** to the most appropriate **Model** based on the request's complexity. + +### Models +Choose between \`${FLASH_MODEL}\` (SIMPLE) or \`${PRO_MODEL}\` (COMPLEX). +1. \`${FLASH_MODEL}\`: A fast, efficient model for simple, well-defined tasks. +2. \`${PRO_MODEL}\`: A powerful, advanced model for complex, open-ended, or multi-step tasks. + +${COMPLEXITY_RUBRIC} + +${OUTPUT_FORMAT} + +### Examples +**Example 1 (Strategic Planning):** +*User Prompt:* "How should I architect the data pipeline for this new analytics service?" 
+*Your JSON Output:* +{ + "reasoning": "The user is asking for high-level architectural design and strategy. This falls under 'Strategic Planning & Conceptual Design'.", + "model_choice": "${PRO_MODEL}" +} +**Example 2 (Simple Tool Use):** +*User Prompt:* "list the files in the current directory" +*Your JSON Output:* +{ + "reasoning": "This is a direct command requiring a single tool call (ls). It has Low Operational Complexity (1 step).", + "model_choice": "${FLASH_MODEL}" +} +**Example 3 (High Operational Complexity):** +*User Prompt:* "I need to add a new 'email' field to the User schema in 'src/models/user.ts', migrate the database, and update the registration endpoint." +*Your JSON Output:* +{ + "reasoning": "This request involves multiple coordinated steps across different files and systems. This meets the criteria for High Operational Complexity (4+ steps).", + "model_choice": "${PRO_MODEL}" +} +**Example 4 (Simple Read):** +*User Prompt:* "Read the contents of 'package.json'." +*Your JSON Output:* +{ + "reasoning": "This is a direct command requiring a single read. It has Low Operational Complexity (1 step).", + "model_choice": "${FLASH_MODEL}" +} +**Example 5 (Deep Debugging):** +*User Prompt:* "I'm getting an error 'Cannot read property 'map' of undefined' when I click the save button. Can you fix it?" +*Your JSON Output:* +{ + "reasoning": "The user is reporting an error symptom without a known cause. This requires investigation and falls under 'Deep Debugging'.", + "model_choice": "${PRO_MODEL}" +} +**Example 6 (Simple Edit despite Phrasing):** +*User Prompt:* "What is the best way to rename the variable 'data' to 'userData' in 'src/utils.js'?" +*Your JSON Output:* +{ + "reasoning": "Although the user uses strategic language ('best way'), the underlying task is a localized edit. The operational complexity is low (1-2 steps).", + "model_choice": "${FLASH_MODEL}" +} +`; + +const LITERT_GEMMA_CLASSIFIER_REMINDER = `### Reminder +You are a Task Routing AI. 
Your sole task is to analyze the preceding **Chat History** and **Current Request** and classify its complexity. + +${COMPLEXITY_RUBRIC} + +${OUTPUT_FORMAT} +`; + +const ClassifierResponseSchema = z.object({ + reasoning: z.string(), + model_choice: z.enum([FLASH_MODEL, PRO_MODEL]), +}); + +export class GemmaClassifierStrategy implements RoutingStrategy { + readonly name = 'gemma-classifier'; + + private flattenChatHistory(turns: Content[]): Content[] { + const formattedHistory = turns + .slice(0, -1) + .map((turn) => + turn.parts + ? turn.parts + .map((part) => part.text) + .filter(Boolean) + .join('\n') + : '', + ) + .filter(Boolean) + .join('\n\n'); + + const lastTurn = turns.at(-1); + const userRequest = + lastTurn?.parts + ?.map((part: Part) => part.text) + .filter(Boolean) + .join('\n\n') ?? ''; + + const finalPrompt = `You are provided with a **Chat History** and the user's **Current Request** below. + +#### Chat History: +${formattedHistory} + +#### Current Request: +"${userRequest}" +`; + return [createUserContent(finalPrompt)]; + } + + async route( + context: RoutingContext, + config: Config, + _baseLlmClient: BaseLlmClient, + client: LocalLiteRtLmClient, + ): Promise<RoutingDecision | null> { + const startTime = Date.now(); + const gemmaRouterSettings = config.getGemmaModelRouterSettings(); + if (!gemmaRouterSettings?.enabled) { + return null; + } + + // Only the gemma3-1b-gpu-custom model has been tested and verified. + if (gemmaRouterSettings.classifier?.model !== 'gemma3-1b-gpu-custom') { + throw new Error('Only gemma3-1b-gpu-custom has been tested'); + } + + try { + const historySlice = context.history.slice(-HISTORY_SEARCH_WINDOW); + + // Filter out tool-related turns. + // TODO - Consider using function req/res if they help accuracy. + const cleanHistory = historySlice.filter( + (content) => !isFunctionCall(content) && !isFunctionResponse(content), + ); + + // Take the last N turns from the *cleaned* history. 
+ const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT); + + const history = [...finalHistory, createUserContent(context.request)]; + const singleMessageHistory = this.flattenChatHistory(history); + + const jsonResponse = await client.generateJson( + singleMessageHistory, + LITERT_GEMMA_CLASSIFIER_SYSTEM_PROMPT, + LITERT_GEMMA_CLASSIFIER_REMINDER, + context.signal, + ); + + const routerResponse = ClassifierResponseSchema.parse(jsonResponse); + + const reasoning = routerResponse.reasoning; + const latencyMs = Date.now() - startTime; + const selectedModel = resolveClassifierModel( + context.requestedModel ?? config.getModel(), + routerResponse.model_choice, + ); + + return { + model: selectedModel, + metadata: { + source: 'GemmaClassifier', + latencyMs, + reasoning, + }, + }; + } catch (error) { + // If the classifier fails for any reason (API error, parsing error, etc.), + // we log it and return null to allow the composite strategy to proceed. + debugLogger.warn(`[Routing] GemmaClassifierStrategy failed:`, error); + return null; + } + } +} diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index b8f6c50282..77fc69a218 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -22,6 +22,7 @@ import { promptIdContext } from '../../utils/promptIdContext.js'; import type { Content } from '@google/genai'; import type { ResolvedModelConfig } from '../../services/modelConfigService.js'; import { debugLogger } from '../../utils/debugLogger.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { AuthType } from '../../core/contentGenerator.js'; vi.mock('../../core/baseLlmClient.js'); @@ -31,6 +32,7 @@ describe('NumericalClassifierStrategy', () => { let mockContext: RoutingContext; let mockConfig: Config; let 
mockBaseLlmClient: BaseLlmClient; + let mockLocalLiteRtLmClient: LocalLiteRtLmClient; let mockResolvedConfig: ResolvedModelConfig; beforeEach(() => { @@ -63,6 +65,7 @@ describe('NumericalClassifierStrategy', () => { mockBaseLlmClient = { generateJson: vi.fn(), } as unknown as BaseLlmClient; + mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient; vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id'); }); @@ -78,6 +81,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toBeNull(); @@ -91,6 +95,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toBeNull(); @@ -104,6 +109,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toBeNull(); @@ -119,7 +125,12 @@ describe('NumericalClassifierStrategy', () => { mockApiResponse, ); - await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock .calls[0][0]; @@ -151,6 +162,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual({ @@ -177,6 +189,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual({ @@ -203,6 +216,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual({ @@ -229,6 +243,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual({ @@ -257,6 +272,7 @@ 
describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual({ @@ -283,6 +299,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual({ @@ -309,6 +326,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual({ @@ -337,6 +355,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual({ @@ -364,6 +383,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual({ @@ -391,6 +411,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toEqual({ @@ -415,6 +436,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toBeNull(); @@ -437,6 +459,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision).toBeNull(); @@ -463,7 +486,12 @@ describe('NumericalClassifierStrategy', () => { mockApiResponse, ); - await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock .calls[0][0]; @@ -495,7 +523,12 @@ describe('NumericalClassifierStrategy', () => { mockApiResponse, ); - await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); const generateJsonCall = 
vi.mocked(mockBaseLlmClient.generateJson).mock .calls[0][0]; @@ -528,7 +561,12 @@ describe('NumericalClassifierStrategy', () => { mockApiResponse, ); - await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock .calls[0][0]; @@ -558,6 +596,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL); @@ -579,6 +618,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); @@ -601,6 +641,7 @@ describe('NumericalClassifierStrategy', () => { mockContext, mockConfig, mockBaseLlmClient, + mockLocalLiteRtLmClient, ); expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL); diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index 32cc6ccbb7..39805fb43c 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -16,6 +16,7 @@ import { resolveClassifierModel, isGemini3Model } from '../../config/models.js'; import { createUserContent, Type } from '@google/genai'; import type { Config } from '../../config/config.js'; import { debugLogger } from '../../utils/debugLogger.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { LlmRole } from '../../telemetry/types.js'; import { AuthType } from '../../core/contentGenerator.js'; @@ -133,6 +134,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy { context: RoutingContext, config: Config, baseLlmClient: BaseLlmClient, + _localLiteRtLmClient: 
LocalLiteRtLmClient, ): Promise<RoutingDecision | null> { const startTime = Date.now(); try { diff --git a/packages/core/src/routing/strategies/overrideStrategy.test.ts b/packages/core/src/routing/strategies/overrideStrategy.test.ts index 73c1aeec62..804ee8f962 100644 --- a/packages/core/src/routing/strategies/overrideStrategy.test.ts +++ b/packages/core/src/routing/strategies/overrideStrategy.test.ts @@ -10,18 +10,25 @@ import type { RoutingContext } from '../routingStrategy.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import type { Config } from '../../config/config.js'; import { DEFAULT_GEMINI_MODEL_AUTO } from '../../config/models.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; describe('OverrideStrategy', () => { const strategy = new OverrideStrategy(); const mockContext = {} as RoutingContext; const mockClient = {} as BaseLlmClient; + const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient; it('should return null when the override model is auto', async () => { const mockConfig = { getModel: () => DEFAULT_GEMINI_MODEL_AUTO, } as Config; - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); expect(decision).toBeNull(); }); @@ -31,7 +38,12 @@ describe('OverrideStrategy', () => { getModel: () => overrideModel, } as Config; - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); expect(decision).not.toBeNull(); expect(decision?.model).toBe(overrideModel); @@ -48,7 +60,12 @@ describe('OverrideStrategy', () => { getModel: () => overrideModel, } as Config; - const decision = await strategy.route(mockContext, mockConfig, mockClient); + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + mockLocalLiteRtLmClient, + ); 
expect(decision).not.toBeNull(); expect(decision?.model).toBe(overrideModel); @@ -68,6 +85,7 @@ describe('OverrideStrategy', () => { contextWithRequestedModel, mockConfig, mockClient, + mockLocalLiteRtLmClient, ); expect(decision).not.toBeNull(); diff --git a/packages/core/src/routing/strategies/overrideStrategy.ts b/packages/core/src/routing/strategies/overrideStrategy.ts index 5101ba9fe7..9a89d2af70 100644 --- a/packages/core/src/routing/strategies/overrideStrategy.ts +++ b/packages/core/src/routing/strategies/overrideStrategy.ts @@ -12,6 +12,7 @@ import type { RoutingDecision, RoutingStrategy, } from '../routingStrategy.js'; +import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; /** * Handles cases where the user explicitly specifies a model (override). @@ -23,6 +24,7 @@ export class OverrideStrategy implements RoutingStrategy { context: RoutingContext, config: Config, _baseLlmClient: BaseLlmClient, + _localLiteRtLmClient: LocalLiteRtLmClient, ): Promise<RoutingDecision | null> { const overrideModel = context.requestedModel ?? config.getModel(); diff --git a/schemas/settings.schema.json b/schemas/settings.schema.json index 059584a73f..51bf9c84e2 100644 --- a/schemas/settings.schema.json +++ b/schemas/settings.schema.json @@ -1694,6 +1694,47 @@ "markdownDescription": "Enable web fetch behavior that bypasses LLM summarization.\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `false`", "default": false, "type": "boolean" + }, + "gemmaModelRouter": { + "title": "Gemma Model Router", + "description": "Enable Gemma model router (experimental).", + "markdownDescription": "Enable Gemma model router (experimental).\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `{}`", + "default": {}, + "type": "object", + "properties": { + "enabled": { + "title": "Enable Gemma Model Router", + "description": "Enable the Gemma Model Router. 
Requires a local endpoint serving Gemma via the Gemini API using LiteRT-LM shim.", + "markdownDescription": "Enable the Gemma Model Router. Requires a local endpoint serving Gemma via the Gemini API using LiteRT-LM shim.\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `false`", + "default": false, + "type": "boolean" + }, + "classifier": { + "title": "Classifier", + "description": "Classifier configuration.", + "markdownDescription": "Classifier configuration.\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `{}`", + "default": {}, + "type": "object", + "properties": { + "host": { + "title": "Host", + "description": "The host of the classifier.", + "markdownDescription": "The host of the classifier.\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `http://localhost:9379`", + "default": "http://localhost:9379", + "type": "string" + }, + "model": { + "title": "Model", + "description": "The model to use for the classifier. Only tested on `gemma3-1b-gpu-custom`.", + "markdownDescription": "The model to use for the classifier. Only tested on `gemma3-1b-gpu-custom`.\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `gemma3-1b-gpu-custom`", + "default": "gemma3-1b-gpu-custom", + "type": "string" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false } }, "additionalProperties": false