mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-11 03:46:49 -07:00
[Gemma x Gemini CLI] Add an Experimental Gemma Router that uses a LiteRT-LM shim into the Composite Model Classifier Strategy (#17231)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Allen Hutchison <adh@google.com>
This commit is contained in:
@@ -140,6 +140,7 @@ they appear in the UI.
|
||||
| Plan | `experimental.plan` | Enable planning features (Plan Mode and tools). | `false` |
|
||||
| Model Steering | `experimental.modelSteering` | Enable model steering (user hints) to guide the model during tool execution. | `false` |
|
||||
| Direct Web Fetch | `experimental.directWebFetch` | Enable web fetch behavior that bypasses LLM summarization. | `false` |
|
||||
| Enable Gemma Model Router | `experimental.gemmaModelRouter.enabled` | Enable the Gemma Model Router. Requires a local endpoint serving Gemma via the Gemini API using LiteRT-LM shim. | `false` |
|
||||
|
||||
### Skills
|
||||
|
||||
|
||||
@@ -1014,6 +1014,23 @@ their corresponding top-level category object in your `settings.json` file.
|
||||
- **Default:** `false`
|
||||
- **Requires restart:** Yes
|
||||
|
||||
- **`experimental.gemmaModelRouter.enabled`** (boolean):
|
||||
- **Description:** Enable the Gemma Model Router. Requires a local endpoint
|
||||
serving Gemma via the Gemini API using LiteRT-LM shim.
|
||||
- **Default:** `false`
|
||||
- **Requires restart:** Yes
|
||||
|
||||
- **`experimental.gemmaModelRouter.classifier.host`** (string):
|
||||
- **Description:** The host of the classifier.
|
||||
- **Default:** `"http://localhost:9379"`
|
||||
- **Requires restart:** Yes
|
||||
|
||||
- **`experimental.gemmaModelRouter.classifier.model`** (string):
|
||||
- **Description:** The model to use for the classifier. Only tested on
|
||||
`gemma3-1b-gpu-custom`.
|
||||
- **Default:** `"gemma3-1b-gpu-custom"`
|
||||
- **Requires restart:** Yes
|
||||
|
||||
#### `skills`
|
||||
|
||||
- **`skills.enabled`** (boolean):
|
||||
|
||||
Generated
+1
-25
@@ -2292,7 +2292,6 @@
|
||||
"integrity": "sha512-t54CUOsFMappY1Jbzb7fetWeO0n6K0k/4+/ZpkS+3Joz8I4VcvY9OiEBFRYISqaI2fq5sCiPtAjRDOzVYG8m+Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@octokit/auth-token": "^6.0.0",
|
||||
"@octokit/graphql": "^9.0.2",
|
||||
@@ -2473,7 +2472,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz",
|
||||
"integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==",
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=8.0.0"
|
||||
}
|
||||
@@ -2523,7 +2521,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@opentelemetry/core/-/core-2.5.0.tgz",
|
||||
"integrity": "sha512-ka4H8OM6+DlUhSAZpONu0cPBtPPTQKxbxVzC4CzVx5+K4JnroJVBtDzLAMx4/3CDTJXRvVFhpFjtl4SaiTNoyQ==",
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@opentelemetry/semantic-conventions": "^1.29.0"
|
||||
},
|
||||
@@ -2898,7 +2895,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@opentelemetry/resources/-/resources-2.5.0.tgz",
|
||||
"integrity": "sha512-F8W52ApePshpoSrfsSk1H2yJn9aKjCrbpQF1M9Qii0GHzbfVeFUB+rc3X4aggyZD8x9Gu3Slua+s6krmq6Dt8g==",
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@opentelemetry/core": "2.5.0",
|
||||
"@opentelemetry/semantic-conventions": "^1.29.0"
|
||||
@@ -2932,7 +2928,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@opentelemetry/sdk-metrics/-/sdk-metrics-2.5.0.tgz",
|
||||
"integrity": "sha512-BeJLtU+f5Gf905cJX9vXFQorAr6TAfK3SPvTFqP+scfIpDQEJfRaGJWta7sJgP+m4dNtBf9y3yvBKVAZZtJQVA==",
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@opentelemetry/core": "2.5.0",
|
||||
"@opentelemetry/resources": "2.5.0"
|
||||
@@ -2987,7 +2982,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@opentelemetry/sdk-trace-base/-/sdk-trace-base-2.5.0.tgz",
|
||||
"integrity": "sha512-VzRf8LzotASEyNDUxTdaJ9IRJ1/h692WyArDBInf5puLCjxbICD6XkHgpuudis56EndyS7LYFmtTMny6UABNdQ==",
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@opentelemetry/core": "2.5.0",
|
||||
"@opentelemetry/resources": "2.5.0",
|
||||
@@ -4184,7 +4178,6 @@
|
||||
"integrity": "sha512-6mDvHUFSjyT2B2yeNx2nUgMxh9LtOWvkhIU3uePn2I2oyNymUAX1NIsdgviM4CH+JSrp2D2hsMvJOkxY+0wNRA==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"csstype": "^3.0.2"
|
||||
}
|
||||
@@ -4458,7 +4451,6 @@
|
||||
"integrity": "sha512-klQbnPAAiGYFyI02+znpBRLyjL4/BrBd0nyWkdC0s/6xFLkXYQ8OoRrSkqacS1ddVxf/LDyODIKbQ5TgKAf/Fg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@typescript-eslint/scope-manager": "8.56.1",
|
||||
"@typescript-eslint/types": "8.56.1",
|
||||
@@ -5306,7 +5298,6 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -7860,7 +7851,6 @@
|
||||
"integrity": "sha512-VmQ+sifHUbI/IcSopBCF/HO3YiHQx/AVd3UVyYL6weuwW+HvON9VYn5l6Zl1WZzPWXPNZrSQpxwkkZ/VuvJZzg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@eslint-community/eslint-utils": "^4.8.0",
|
||||
"@eslint-community/regexpp": "^4.12.1",
|
||||
@@ -8493,7 +8483,6 @@
|
||||
"resolved": "https://registry.npmjs.org/express/-/express-5.2.1.tgz",
|
||||
"integrity": "sha512-hIS4idWWai69NezIdRt2xFVofaF4j+6INOpJlVOLDO8zXGpUVEVzIYk12UUi2JzjEzWL3IOAxcTubgz9Po0yXw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"accepts": "^2.0.0",
|
||||
"body-parser": "^2.2.1",
|
||||
@@ -9788,7 +9777,6 @@
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.2.tgz",
|
||||
"integrity": "sha512-gJnaDHXKDayjt8ue0n8Gs0A007yKXj4Xzb8+cNjZeYsSzzwKc0Lr+OZgYwVfB0pHfUs17EPoLvrOsEaJ9mj+Tg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
}
|
||||
@@ -10068,7 +10056,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@jrichman/ink/-/ink-6.4.11.tgz",
|
||||
"integrity": "sha512-93LQlzT7vvZ1XJcmOMwN4s+6W334QegendeHOMnEJBlhnpIzr8bws6/aOEHG8ZCuVD/vNeeea5m1msHIdAY6ig==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@alcalzone/ansi-tokenize": "^0.2.1",
|
||||
"ansi-escapes": "^7.0.0",
|
||||
@@ -13718,7 +13705,6 @@
|
||||
"resolved": "https://registry.npmjs.org/react/-/react-19.2.4.tgz",
|
||||
"integrity": "sha512-9nfp2hYpCwOjAN+8TZFGhtWEwgvWHXqESH8qT89AT/lWklpLON22Lc8pEtnpsZz7VmawabSU0gCjnj8aC0euHQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
@@ -13729,7 +13715,6 @@
|
||||
"integrity": "sha512-ePrwPfxAnB+7hgnEr8vpKxL9cmnp7F322t8oqcPshbIQQhDKgFDW4tjhF2wjVbdXF9O/nyuy3sQWd9JGpiLPvA==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"shell-quote": "^1.6.1",
|
||||
"ws": "^7"
|
||||
@@ -15689,7 +15674,6 @@
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -15913,8 +15897,7 @@
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz",
|
||||
"integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==",
|
||||
"dev": true,
|
||||
"license": "0BSD",
|
||||
"peer": true
|
||||
"license": "0BSD"
|
||||
},
|
||||
"node_modules/tsx": {
|
||||
"version": "4.20.3",
|
||||
@@ -15922,7 +15905,6 @@
|
||||
"integrity": "sha512-qjbnuR9Tr+FJOMBqJCW5ehvIo/buZq7vH7qD7JziU98h6l3qGy0a/yPFjwO+y0/T7GFpNgNAvEcPPVfyT8rrPQ==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "~0.25.0",
|
||||
"get-tsconfig": "^4.7.5"
|
||||
@@ -16082,7 +16064,6 @@
|
||||
"integrity": "sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==",
|
||||
"devOptional": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -16291,7 +16272,6 @@
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-7.2.2.tgz",
|
||||
"integrity": "sha512-BxAKBWmIbrDgrokdGZH1IgkIk/5mMHDreLDmCJ0qpyJaAteP8NvMhkwr/ZCQNqNH97bw/dANTE9PDzqwJghfMQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.5.0",
|
||||
@@ -16405,7 +16385,6 @@
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -16418,7 +16397,6 @@
|
||||
"resolved": "https://registry.npmjs.org/vitest/-/vitest-3.2.4.tgz",
|
||||
"integrity": "sha512-LUCP5ev3GURDysTWiP47wRRUpLKMOfPh+yKTx3kVIEiu5KOMeqzpnYNsKyOoVrULivR8tLcks4+lga33Whn90A==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@types/chai": "^5.2.2",
|
||||
"@vitest/expect": "3.2.4",
|
||||
@@ -17063,7 +17041,6 @@
|
||||
"resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz",
|
||||
"integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/colinhacks"
|
||||
}
|
||||
@@ -17463,7 +17440,6 @@
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
|
||||
@@ -2765,6 +2765,66 @@ describe('loadCliConfig approval mode', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('loadCliConfig gemmaModelRouter', () => {
  // process.argv is process-global state. It is captured once here so each
  // test can overwrite it and afterEach can put the original value back,
  // instead of leaking ['node', 'script.js'] into suites that run later.
  const originalArgv = process.argv;

  beforeEach(() => {
    vi.resetAllMocks();
    vi.mocked(os.homedir).mockReturnValue('/mock/home/user');
    vi.stubEnv('GEMINI_API_KEY', 'test-api-key');
    vi.spyOn(ExtensionManager.prototype, 'getExtensions').mockReturnValue([]);
    // Every test in this suite parses the same bare command line.
    process.argv = ['node', 'script.js'];
  });

  afterEach(() => {
    process.argv = originalArgv;
    vi.unstubAllEnvs();
    vi.restoreAllMocks();
  });

  it('should have gemmaModelRouter disabled by default', async () => {
    const argv = await parseArguments(createTestMergedSettings());
    const settings = createTestMergedSettings();
    const config = await loadCliConfig(settings, 'test-session', argv);
    expect(config.getGemmaModelRouterEnabled()).toBe(false);
  });

  it('should load gemmaModelRouter settings from merged settings', async () => {
    const argv = await parseArguments(createTestMergedSettings());
    const settings = createTestMergedSettings({
      experimental: {
        gemmaModelRouter: {
          enabled: true,
          classifier: {
            host: 'http://custom:1234',
            model: 'custom-gemma',
          },
        },
      },
    });
    const config = await loadCliConfig(settings, 'test-session', argv);
    expect(config.getGemmaModelRouterEnabled()).toBe(true);
    const gemmaSettings = config.getGemmaModelRouterSettings();
    expect(gemmaSettings.classifier?.host).toBe('http://custom:1234');
    expect(gemmaSettings.classifier?.model).toBe('custom-gemma');
  });

  it('should handle partial gemmaModelRouter settings', async () => {
    const argv = await parseArguments(createTestMergedSettings());
    // Only `enabled` is supplied; the classifier values must fall back to
    // the defaults.
    const settings = createTestMergedSettings({
      experimental: {
        gemmaModelRouter: {
          enabled: true,
        },
      },
    });
    const config = await loadCliConfig(settings, 'test-session', argv);
    expect(config.getGemmaModelRouterEnabled()).toBe(true);
    const gemmaSettings = config.getGemmaModelRouterSettings();
    expect(gemmaSettings.classifier?.host).toBe('http://localhost:9379');
    expect(gemmaSettings.classifier?.model).toBe('gemma3-1b-gpu-custom');
  });
});
|
||||
|
||||
describe('loadCliConfig fileFiltering', () => {
|
||||
const originalArgv = process.argv;
|
||||
|
||||
|
||||
@@ -856,6 +856,7 @@ export async function loadCliConfig(
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
format: (argv.outputFormat ?? settings.output?.format) as OutputFormat,
|
||||
},
|
||||
gemmaModelRouter: settings.experimental?.gemmaModelRouter,
|
||||
fakeResponses: argv.fakeResponses,
|
||||
recordResponses: argv.recordResponses,
|
||||
retryFetchErrors: settings.general?.retryFetchErrors,
|
||||
|
||||
@@ -444,6 +444,60 @@ describe('SettingsSchema', () => {
|
||||
expect(hookItemProperties.description).toBeDefined();
|
||||
expect(hookItemProperties.description.type).toBe('string');
|
||||
});
|
||||
|
||||
it('should have gemmaModelRouter setting in schema', () => {
  // Metadata that differs from node to node in the gemmaModelRouter subtree.
  interface ExpectedNode {
    type: string;
    default: unknown;
    showInDialog: boolean;
    description: string;
  }

  // Metadata every node in the subtree carries; category and requiresRestart
  // are the same for all of them and are checked against fixed values below.
  interface SchemaNode extends ExpectedNode {
    category: string;
    requiresRestart: boolean;
  }

  // Runs the full set of per-node expectations against one schema node.
  const expectNode = (node: SchemaNode, expected: ExpectedNode) => {
    expect(node).toBeDefined();
    expect(node.type).toBe(expected.type);
    expect(node.category).toBe('Experimental');
    expect(node.default).toEqual(expected.default);
    expect(node.requiresRestart).toBe(true);
    expect(node.showInDialog).toBe(expected.showInDialog);
    expect(node.description).toBe(expected.description);
  };

  const router = getSettingsSchema().experimental.properties.gemmaModelRouter;
  expectNode(router, {
    type: 'object',
    default: {},
    showInDialog: true,
    description: 'Enable Gemma model router (experimental).',
  });

  expectNode(router.properties.enabled, {
    type: 'boolean',
    default: false,
    showInDialog: true,
    description:
      'Enable the Gemma Model Router. Requires a local endpoint serving Gemma via the Gemini API using LiteRT-LM shim.',
  });

  const classifierNode = router.properties.classifier;
  expectNode(classifierNode, {
    type: 'object',
    default: {},
    showInDialog: false,
    description: 'Classifier configuration.',
  });

  expectNode(classifierNode.properties.host, {
    type: 'string',
    default: 'http://localhost:9379',
    showInDialog: false,
    description: 'The host of the classifier.',
  });

  expectNode(classifierNode.properties.model, {
    type: 'string',
    default: 'gemma3-1b-gpu-custom',
    showInDialog: false,
    description:
      'The model to use for the classifier. Only tested on `gemma3-1b-gpu-custom`.',
  });
});
|
||||
});
|
||||
|
||||
it('has JSON schema definitions for every referenced ref', () => {
|
||||
|
||||
@@ -1787,6 +1787,57 @@ const SETTINGS_SCHEMA = {
|
||||
'Enable web fetch behavior that bypasses LLM summarization.',
|
||||
showInDialog: true,
|
||||
},
|
||||
gemmaModelRouter: {
|
||||
type: 'object',
|
||||
label: 'Gemma Model Router',
|
||||
category: 'Experimental',
|
||||
requiresRestart: true,
|
||||
default: {},
|
||||
description: 'Enable Gemma model router (experimental).',
|
||||
showInDialog: true,
|
||||
properties: {
|
||||
enabled: {
|
||||
type: 'boolean',
|
||||
label: 'Enable Gemma Model Router',
|
||||
category: 'Experimental',
|
||||
requiresRestart: true,
|
||||
default: false,
|
||||
description:
|
||||
'Enable the Gemma Model Router. Requires a local endpoint serving Gemma via the Gemini API using LiteRT-LM shim.',
|
||||
showInDialog: true,
|
||||
},
|
||||
classifier: {
|
||||
type: 'object',
|
||||
label: 'Classifier',
|
||||
category: 'Experimental',
|
||||
requiresRestart: true,
|
||||
default: {},
|
||||
description: 'Classifier configuration.',
|
||||
showInDialog: false,
|
||||
properties: {
|
||||
host: {
|
||||
type: 'string',
|
||||
label: 'Host',
|
||||
category: 'Experimental',
|
||||
requiresRestart: true,
|
||||
default: 'http://localhost:9379',
|
||||
description: 'The host of the classifier.',
|
||||
showInDialog: false,
|
||||
},
|
||||
model: {
|
||||
type: 'string',
|
||||
label: 'Model',
|
||||
category: 'Experimental',
|
||||
requiresRestart: true,
|
||||
default: 'gemma3-1b-gpu-custom',
|
||||
description:
|
||||
'The model to use for the classifier. Only tested on `gemma3-1b-gpu-custom`.',
|
||||
showInDialog: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -2532,7 +2583,9 @@ type InferSettings<T extends SettingsSchema> = {
|
||||
: T[K]['default']
|
||||
: T[K]['default'] extends boolean
|
||||
? boolean
|
||||
: T[K]['default'];
|
||||
: T[K]['default'] extends string
|
||||
? string
|
||||
: T[K]['default'];
|
||||
};
|
||||
|
||||
type InferMergedSettings<T extends SettingsSchema> = {
|
||||
@@ -2544,7 +2597,9 @@ type InferMergedSettings<T extends SettingsSchema> = {
|
||||
: T[K]['default']
|
||||
: T[K]['default'] extends boolean
|
||||
? boolean
|
||||
: T[K]['default'];
|
||||
: T[K]['default'] extends string
|
||||
? string
|
||||
: T[K]['default'];
|
||||
};
|
||||
|
||||
export type Settings = InferSettings<SettingsSchemaType>;
|
||||
|
||||
@@ -225,8 +225,10 @@ import type {
|
||||
} from '../services/modelConfigService.js';
|
||||
import { ExitPlanModeTool } from '../tools/exit-plan-mode.js';
|
||||
import { EnterPlanModeTool } from '../tools/enter-plan-mode.js';
|
||||
import { LocalLiteRtLmClient } from '../core/localLiteRtLmClient.js';
|
||||
|
||||
vi.mock('../core/baseLlmClient.js');
|
||||
vi.mock('../core/localLiteRtLmClient.js');
|
||||
vi.mock('../core/tokenLimits.js', () => ({
|
||||
tokenLimit: vi.fn(),
|
||||
}));
|
||||
@@ -1418,6 +1420,79 @@ describe('Server Config (config.ts)', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('GemmaModelRouterSettings', () => {
  const sandbox: SandboxConfig = {
    command: 'docker',
    image: 'gemini-cli-sandbox',
  };

  // Minimal parameter set shared by every test; none of it configures the
  // Gemma model router, so the defaults are what gets exercised.
  const baseParams: ConfigParameters = {
    cwd: '/tmp',
    embeddingModel: 'gemini-embedding',
    sandbox,
    targetDir: '/path/to/target',
    debugMode: false,
    question: 'test question',
    userMemory: 'Test User Memory',
    telemetry: { enabled: false },
    sessionId: 'test-session-id',
    model: DEFAULT_GEMINI_MODEL,
    usageStatisticsEnabled: false,
  };

  // Builds a Config from the base parameters plus optional router settings
  // and returns the router settings it exposes.
  const routerSettingsFor = (
    gemmaModelRouter?: ConfigParameters['gemmaModelRouter'],
  ) => {
    const params: ConfigParameters =
      gemmaModelRouter === undefined
        ? baseParams
        : { ...baseParams, gemmaModelRouter };
    return new Config(params).getGemmaModelRouterSettings();
  };

  it('should default gemmaModelRouter.enabled to false', () => {
    expect(new Config(baseParams).getGemmaModelRouterEnabled()).toBe(false);
  });

  it('should return default gemma model router settings when not provided', () => {
    const { enabled, classifier } = routerSettingsFor();
    expect(enabled).toBe(false);
    expect(classifier?.host).toBe('http://localhost:9379');
    expect(classifier?.model).toBe('gemma3-1b-gpu-custom');
  });

  it('should override default gemma model router settings when provided', () => {
    const { enabled, classifier } = routerSettingsFor({
      enabled: true,
      classifier: {
        host: 'http://custom:1234',
        model: 'custom-gemma',
      },
    });
    expect(enabled).toBe(true);
    expect(classifier?.host).toBe('http://custom:1234');
    expect(classifier?.model).toBe('custom-gemma');
  });

  it('should merge partial gemma model router settings with defaults', () => {
    const { enabled, classifier } = routerSettingsFor({ enabled: true });
    expect(enabled).toBe(true);
    expect(classifier?.host).toBe('http://localhost:9379');
    expect(classifier?.model).toBe('gemma3-1b-gpu-custom');
  });
});
|
||||
|
||||
describe('setApprovalMode with folder trust', () => {
|
||||
const baseParams: ConfigParameters = {
|
||||
sessionId: 'test',
|
||||
@@ -2069,6 +2144,71 @@ describe('Config getHooks', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('LocalLiteRtLmClient Lifecycle', () => {
  const sandbox: SandboxConfig = {
    command: 'docker',
    image: 'gemini-cli-sandbox',
  };

  // Minimal Config parameters shared by the tests in this suite.
  const baseParams: ConfigParameters = {
    cwd: '/tmp',
    embeddingModel: 'gemini-embedding',
    sandbox,
    targetDir: '/path/to/target',
    debugMode: false,
    question: 'test question',
    userMemory: 'Test User Memory',
    telemetry: { enabled: false },
    sessionId: 'test-session-id',
    model: 'gemini-pro',
    usageStatisticsEnabled: false,
  };

  beforeEach(() => {
    vi.clearAllMocks();
    vi.mocked(getExperiments).mockResolvedValue({
      experimentIds: [],
      flags: {},
    });
  });

  it('should successfully initialize LocalLiteRtLmClient on first call and reuse it', () => {
    const config = new Config(baseParams);
    const first = config.getLocalLiteRtLmClient();
    const second = config.getLocalLiteRtLmClient();

    expect(first).toBeDefined();
    // The second call must hand back the very same instance.
    expect(first).toBe(second);
  });

  it('should configure LocalLiteRtLmClient with settings from getGemmaModelRouterSettings', () => {
    const config = new Config({
      ...baseParams,
      gemmaModelRouter: {
        enabled: true,
        classifier: {
          host: 'http://my-custom-host:9999',
          model: 'my-custom-gemma-model',
        },
      },
    });

    config.getLocalLiteRtLmClient();

    // The client receives the Config itself and reads its settings from it.
    expect(LocalLiteRtLmClient).toHaveBeenCalledWith(config);
  });
});
|
||||
|
||||
describe('Config getExperiments', () => {
|
||||
const baseParams: ConfigParameters = {
|
||||
cwd: '/tmp',
|
||||
|
||||
@@ -38,6 +38,7 @@ import { ExitPlanModeTool } from '../tools/exit-plan-mode.js';
|
||||
import { EnterPlanModeTool } from '../tools/enter-plan-mode.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import { LocalLiteRtLmClient } from '../core/localLiteRtLmClient.js';
|
||||
import type { HookDefinition, HookEventName } from '../hooks/types.js';
|
||||
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
||||
import { GitService } from '../services/gitService.js';
|
||||
@@ -178,6 +179,14 @@ export interface ToolOutputMaskingConfig {
|
||||
protectLatestTurn: boolean;
|
||||
}
|
||||
|
||||
/**
 * Settings for the experimental Gemma model router. Every field is optional.
 */
export interface GemmaModelRouterSettings {
  /** Whether the Gemma model router is enabled. */
  enabled?: boolean;
  /** Where and how to reach the classifier. */
  classifier?: {
    /** Host of the classifier, e.g. `http://localhost:9379`. */
    host?: string;
    /** Name of the model the classifier uses. */
    model?: string;
  };
}
|
||||
|
||||
export interface ExtensionSetting {
|
||||
name: string;
|
||||
description: string;
|
||||
@@ -509,6 +518,7 @@ export interface ConfigParameters {
|
||||
directWebFetch?: boolean;
|
||||
policyUpdateConfirmationRequest?: PolicyUpdateConfirmationRequest;
|
||||
output?: OutputSettings;
|
||||
gemmaModelRouter?: GemmaModelRouterSettings;
|
||||
disableModelRouterForAuth?: AuthType[];
|
||||
continueOnFailedApiCall?: boolean;
|
||||
retryFetchErrors?: boolean;
|
||||
@@ -599,6 +609,7 @@ export class Config {
|
||||
private readonly usageStatisticsEnabled: boolean;
|
||||
private geminiClient!: GeminiClient;
|
||||
private baseLlmClient!: BaseLlmClient;
|
||||
private localLiteRtLmClient?: LocalLiteRtLmClient;
|
||||
private modelRouterService: ModelRouterService;
|
||||
private readonly modelAvailabilityService: ModelAvailabilityService;
|
||||
private readonly fileFiltering: {
|
||||
@@ -694,6 +705,9 @@ export class Config {
|
||||
| PolicyUpdateConfirmationRequest
|
||||
| undefined;
|
||||
private readonly outputSettings: OutputSettings;
|
||||
|
||||
private readonly gemmaModelRouter: GemmaModelRouterSettings;
|
||||
|
||||
private readonly continueOnFailedApiCall: boolean;
|
||||
private readonly retryFetchErrors: boolean;
|
||||
private readonly maxAttempts: number;
|
||||
@@ -942,6 +956,15 @@ export class Config {
|
||||
this.outputSettings = {
|
||||
format: params.output?.format ?? OutputFormat.TEXT,
|
||||
};
|
||||
this.gemmaModelRouter = {
|
||||
enabled: params.gemmaModelRouter?.enabled ?? false,
|
||||
classifier: {
|
||||
host:
|
||||
params.gemmaModelRouter?.classifier?.host ?? 'http://localhost:9379',
|
||||
model:
|
||||
params.gemmaModelRouter?.classifier?.model ?? 'gemma3-1b-gpu-custom',
|
||||
},
|
||||
};
|
||||
this.retryFetchErrors = params.retryFetchErrors ?? false;
|
||||
this.maxAttempts = Math.min(
|
||||
params.maxAttempts ?? DEFAULT_MAX_ATTEMPTS,
|
||||
@@ -1245,6 +1268,13 @@ export class Config {
|
||||
return this.baseLlmClient;
|
||||
}
|
||||
|
||||
/**
 * Returns the LocalLiteRtLmClient for this Config, constructing it on the
 * first call and handing back the same instance on every later call.
 */
getLocalLiteRtLmClient(): LocalLiteRtLmClient {
  this.localLiteRtLmClient ??= new LocalLiteRtLmClient(this);
  return this.localLiteRtLmClient;
}
|
||||
|
||||
getSessionId(): string {
|
||||
return this.sessionId;
|
||||
}
|
||||
@@ -2578,6 +2608,14 @@ export class Config {
|
||||
return this.enableHooksUI;
|
||||
}
|
||||
|
||||
/**
 * Reports whether the Gemma model router is enabled; an unset flag counts
 * as disabled.
 */
getGemmaModelRouterEnabled(): boolean {
  const { enabled } = this.gemmaModelRouter;
  return enabled ?? false;
}
|
||||
|
||||
/**
 * Returns the Gemma model router settings held by this Config. The stored
 * object itself is returned, not a copy.
 */
getGemmaModelRouterSettings(): GemmaModelRouterSettings {
  const { gemmaModelRouter } = this;
  return gemmaModelRouter;
}
|
||||
|
||||
/**
|
||||
* Get override settings for a specific agent.
|
||||
* Reads from agents.overrides.<agentName>.
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { LocalLiteRtLmClient } from './localLiteRtLmClient.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
// Stand-in for `models.generateContent`; each test decides what it resolves to.
const mockGenerateContent = vi.fn();

// Replace the SDK with a constructor whose instances expose only
// `models.generateContent`. `mockGenerateContent` is read when an instance is
// created, not when this factory runs.
vi.mock('@google/genai', () => ({
  GoogleGenAI: vi.fn().mockImplementation(() => ({
    models: {
      generateContent: mockGenerateContent,
    },
  })),
}));
|
||||
|
||||
describe('LocalLiteRtLmClient', () => {
  let mockConfig: Config;

  // Makes the mocked generateContent resolve with the given `text` payload.
  const respondWith = (text: string | null) =>
    mockGenerateContent.mockResolvedValue({ text });

  // Builds a client against the per-test mock Config.
  const makeClient = () => new LocalLiteRtLmClient(mockConfig);

  beforeEach(() => {
    vi.clearAllMocks();
    mockGenerateContent.mockClear();

    // Only the single Config method the client reads is provided.
    mockConfig = {
      getGemmaModelRouterSettings: vi.fn().mockReturnValue({
        classifier: {
          host: 'http://test-host:1234',
          model: 'gemma:latest',
        },
      }),
    } as unknown as Config;
  });

  it('should successfully call generateJson and return parsed JSON', async () => {
    respondWith('{"key": "value"}');

    const parsed = await makeClient().generateJson([], 'test-instruction');

    expect(parsed).toEqual({ key: 'value' });
    expect(mockGenerateContent).toHaveBeenCalledWith(
      expect.objectContaining({
        model: 'gemma:latest',
        config: expect.objectContaining({
          responseMimeType: 'application/json',
          temperature: 0,
        }),
      }),
    );
  });

  it('should throw an error if the API response has no text', async () => {
    respondWith(null);

    await expect(
      makeClient().generateJson([], 'test-instruction'),
    ).rejects.toThrow('Invalid response from Local Gemini API: No text found');
  });

  it('should throw if the JSON is malformed', async () => {
    // Smart quotes, trailing comma
    respondWith(`{
“key”: ‘value’,
}`);

    await expect(
      makeClient().generateJson([], 'test-instruction'),
    ).rejects.toThrow(SyntaxError);
  });

  it('should add reminder to the last user message', async () => {
    respondWith('{"key": "value"}');

    await makeClient().generateJson(
      [{ role: 'user', parts: [{ text: 'initial prompt' }] }],
      'test-instruction',
      'test-reminder',
    );

    // Inspect the request passed to the first generateContent call.
    const [firstCall] = vi.mocked(mockGenerateContent).mock.calls;
    const sentContents = firstCall[0].contents;
    expect(sentContents.at(-1)?.parts[0].text).toBe(
      'initial prompt\n\ntest-reminder',
    );
  });

  it('should pass abortSignal to generateContent', async () => {
    respondWith('{"key": "value"}');

    const { signal } = new AbortController();
    await makeClient().generateJson([], 'test-instruction', undefined, signal);

    expect(mockGenerateContent).toHaveBeenCalledWith(
      expect.objectContaining({
        config: expect.objectContaining({
          abortSignal: signal,
        }),
      }),
    );
  });
});
|
||||
@@ -0,0 +1,96 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { Content } from '@google/genai';
|
||||
|
||||
/**
 * A client for making single, non-streaming calls to a local Gemini-compatible API
 * and expecting a JSON response.
 *
 * The host and model are read once, at construction time, from the Gemma model
 * router settings exposed by the supplied {@link Config}.
 */
export class LocalLiteRtLmClient {
  private readonly host: string;
  private readonly model: string;
  private readonly client: GoogleGenAI;

  /**
   * @param config The configuration whose `getGemmaModelRouterSettings()`
   *   provides the classifier host and model.
   */
  constructor(config: Config) {
    const gemmaModelRouterSettings = config.getGemmaModelRouterSettings();
    // NOTE(review): the non-null assertions assume the classifier host/model are
    // always populated (e.g. by settings defaults) — confirm against Config.
    this.host = gemmaModelRouterSettings.classifier!.host!;
    this.model = gemmaModelRouterSettings.classifier!.model!;

    this.client = new GoogleGenAI({
      // The LiteRT-LM server does not require an API key, but the SDK requires one to be set even for local endpoints. This is a dummy value and is not used for authentication.
      apiKey: 'no-api-key-needed',
      httpOptions: {
        baseUrl: this.host,
        // If the LiteRT-LM server is started but the wrong port is set, there will be a lengthy TCP timeout (here fixed to be 10 seconds).
        // If the LiteRT-LM server is not started, there will be an immediate connection refusal.
        // If the LiteRT-LM server is started and the model is unsupported or not downloaded, the server will return an error immediately.
        // If the model's context window is exceeded, the server will return an error immediately.
        timeout: 10000,
      },
    });
  }

  /**
   * Sends a prompt to the local Gemini model and expects a JSON object in response.
   * @param contents The history and current prompt. Only the `text` of each part
   *   is forwarded to the model.
   * @param systemInstruction The system prompt.
   * @param reminder Optional text appended (after a blank line) to the first part
   *   of the last message, but only when that message has the `user` role and
   *   its first part carries non-empty text.
   * @param abortSignal Optional signal forwarded to the SDK request config.
   * @returns A promise that resolves to the parsed JSON object.
   * @throws {Error} If the response has no text or the parsed JSON is not an object.
   * @throws {SyntaxError} If the response text is not valid JSON.
   */
  async generateJson(
    contents: Content[],
    systemInstruction: string,
    reminder?: string,
    abortSignal?: AbortSignal,
  ): Promise<object> {
    // Work on a text-only copy so the caller's history is never mutated.
    const geminiContents = contents.map((c) => ({
      role: c.role,
      parts: c.parts ? c.parts.map((p) => ({ text: p.text })) : [],
    }));

    if (reminder) {
      const lastContent = geminiContents.at(-1);
      if (lastContent?.role === 'user' && lastContent.parts?.[0]?.text) {
        lastContent.parts[0].text += `\n\n${reminder}`;
      }
    }

    try {
      const result = await this.client.models.generateContent({
        model: this.model,
        contents: geminiContents,
        config: {
          responseMimeType: 'application/json',
          systemInstruction: systemInstruction
            ? { parts: [{ text: systemInstruction }] }
            : undefined,
          temperature: 0,
          maxOutputTokens: 256,
          abortSignal,
        },
      });

      const text = result.text;
      if (!text) {
        throw new Error(
          'Invalid response from Local Gemini API: No text found',
        );
      }

      // JSON.parse also accepts bare primitives and `null`; reject those so the
      // declared `Promise<object>` contract actually holds for callers.
      const parsed: unknown = JSON.parse(text);
      if (typeof parsed !== 'object' || parsed === null) {
        throw new Error(
          'Invalid response from Local Gemini API: Expected a JSON object',
        );
      }

      return parsed;
    } catch (error) {
      // Log for diagnostics, then let the caller decide how to recover.
      debugLogger.error(
        `[LocalLiteRtLmClient] Failed to generate content:`,
        error,
      );
      throw error;
    }
  }
}
|
||||
@@ -9,6 +9,7 @@ import { ModelRouterService } from './modelRouterService.js';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import type { LocalLiteRtLmClient } from '../core/localLiteRtLmClient.js';
|
||||
import type { RoutingContext, RoutingDecision } from './routingStrategy.js';
|
||||
import { DefaultStrategy } from './strategies/defaultStrategy.js';
|
||||
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
||||
@@ -19,6 +20,7 @@ import { ClassifierStrategy } from './strategies/classifierStrategy.js';
|
||||
import { NumericalClassifierStrategy } from './strategies/numericalClassifierStrategy.js';
|
||||
import { logModelRouting } from '../telemetry/loggers.js';
|
||||
import { ModelRoutingEvent } from '../telemetry/types.js';
|
||||
import { GemmaClassifierStrategy } from './strategies/gemmaClassifierStrategy.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
|
||||
vi.mock('../config/config.js');
|
||||
@@ -30,6 +32,7 @@ vi.mock('./strategies/overrideStrategy.js');
|
||||
vi.mock('./strategies/approvalModeStrategy.js');
|
||||
vi.mock('./strategies/classifierStrategy.js');
|
||||
vi.mock('./strategies/numericalClassifierStrategy.js');
|
||||
vi.mock('./strategies/gemmaClassifierStrategy.js');
|
||||
vi.mock('../telemetry/loggers.js');
|
||||
vi.mock('../telemetry/types.js');
|
||||
|
||||
@@ -37,6 +40,7 @@ describe('ModelRouterService', () => {
|
||||
let service: ModelRouterService;
|
||||
let mockConfig: Config;
|
||||
let mockBaseLlmClient: BaseLlmClient;
|
||||
let mockLocalLiteRtLmClient: LocalLiteRtLmClient;
|
||||
let mockContext: RoutingContext;
|
||||
let mockCompositeStrategy: CompositeStrategy;
|
||||
|
||||
@@ -45,9 +49,20 @@ describe('ModelRouterService', () => {
|
||||
|
||||
mockConfig = new Config({} as never);
|
||||
mockBaseLlmClient = {} as BaseLlmClient;
|
||||
mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient);
|
||||
vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue(
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false);
|
||||
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
|
||||
vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
|
||||
enabled: false,
|
||||
classifier: {
|
||||
host: 'http://localhost:1234',
|
||||
model: 'gemma3-1b-gpu-custom',
|
||||
},
|
||||
});
|
||||
vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue(
|
||||
ApprovalMode.DEFAULT,
|
||||
);
|
||||
@@ -96,6 +111,36 @@ describe('ModelRouterService', () => {
|
||||
expect(compositeStrategyArgs[1]).toBe('agent-router');
|
||||
});
|
||||
|
||||
it('should include GemmaClassifierStrategy when enabled', () => {
  // Turn the Gemma router on for this test only.
  vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
    enabled: true,
    classifier: {
      host: 'http://localhost:1234',
      model: 'gemma3-1b-gpu-custom',
    },
  });

  // Forget constructor calls recorded earlier, then rebuild the service so it
  // reads the overridden settings.
  vi.mocked(CompositeStrategy).mockClear();
  service = new ModelRouterService(mockConfig);

  const [childStrategies, routerName] =
    vi.mocked(CompositeStrategy).mock.calls[0];

  // Expected priority order, with the Gemma classifier ahead of the generic one.
  const expectedOrder = [
    FallbackStrategy,
    OverrideStrategy,
    ApprovalModeStrategy,
    GemmaClassifierStrategy,
    ClassifierStrategy,
    NumericalClassifierStrategy,
    DefaultStrategy,
  ];

  expect(childStrategies.length).toBe(7);
  expectedOrder.forEach((StrategyClass, position) => {
    expect(childStrategies[position]).toBeInstanceOf(StrategyClass);
  });
  expect(routerName).toBe('agent-router');
});
|
||||
|
||||
describe('route()', () => {
|
||||
const strategyDecision: RoutingDecision = {
|
||||
model: 'strategy-chosen-model',
|
||||
@@ -117,6 +162,7 @@ describe('ModelRouterService', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
expect(decision).toEqual(strategyDecision);
|
||||
});
|
||||
|
||||
@@ -4,10 +4,12 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { GemmaClassifierStrategy } from './strategies/gemmaClassifierStrategy.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type {
|
||||
RoutingContext,
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
TerminalStrategy,
|
||||
} from './routingStrategy.js';
|
||||
import { DefaultStrategy } from './strategies/defaultStrategy.js';
|
||||
@@ -35,17 +37,31 @@ export class ModelRouterService {
|
||||
}
|
||||
|
||||
private initializeDefaultStrategy(): TerminalStrategy {
|
||||
// Initialize the composite strategy with the desired priority order.
|
||||
// The strategies are ordered in order of highest priority.
|
||||
const strategies: RoutingStrategy[] = [];
|
||||
|
||||
// Order matters here. Fallback and override are checked first.
|
||||
strategies.push(new FallbackStrategy());
|
||||
strategies.push(new OverrideStrategy());
|
||||
|
||||
// Approval mode is next.
|
||||
strategies.push(new ApprovalModeStrategy());
|
||||
|
||||
// Then, if enabled, the Gemma classifier is used.
|
||||
if (this.config.getGemmaModelRouterSettings()?.enabled) {
|
||||
strategies.push(new GemmaClassifierStrategy());
|
||||
}
|
||||
|
||||
// The generic classifier is next.
|
||||
strategies.push(new ClassifierStrategy());
|
||||
|
||||
// The numerical classifier is next.
|
||||
strategies.push(new NumericalClassifierStrategy());
|
||||
|
||||
// The default strategy is the terminal strategy.
|
||||
const terminalStrategy = new DefaultStrategy();
|
||||
|
||||
return new CompositeStrategy(
|
||||
[
|
||||
new FallbackStrategy(),
|
||||
new OverrideStrategy(),
|
||||
new ApprovalModeStrategy(),
|
||||
new ClassifierStrategy(),
|
||||
new NumericalClassifierStrategy(),
|
||||
new DefaultStrategy(),
|
||||
],
|
||||
[...strategies, terminalStrategy],
|
||||
'agent-router',
|
||||
);
|
||||
}
|
||||
@@ -75,6 +91,7 @@ export class ModelRouterService {
|
||||
context,
|
||||
this.config,
|
||||
this.config.getBaseLlmClient(),
|
||||
this.config.getLocalLiteRtLmClient(),
|
||||
);
|
||||
|
||||
debugLogger.debug(
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import type { Content, PartListUnion } from '@google/genai';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { LocalLiteRtLmClient } from '../core/localLiteRtLmClient.js';
|
||||
|
||||
/**
|
||||
* The output of a routing decision. It specifies which model to use and why.
|
||||
@@ -58,6 +59,7 @@ export interface RoutingStrategy {
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
localLiteRtLmClient: LocalLiteRtLmClient,
|
||||
): Promise<RoutingDecision | null>;
|
||||
}
|
||||
|
||||
@@ -74,5 +76,6 @@ export interface TerminalStrategy extends RoutingStrategy {
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
localLiteRtLmClient: LocalLiteRtLmClient,
|
||||
): Promise<RoutingDecision>;
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import { ClassifierStrategy } from './classifierStrategy.js';
|
||||
import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
import {
|
||||
isFunctionCall,
|
||||
isFunctionResponse,
|
||||
@@ -34,6 +35,7 @@ describe('ClassifierStrategy', () => {
|
||||
let mockContext: RoutingContext;
|
||||
let mockConfig: Config;
|
||||
let mockBaseLlmClient: BaseLlmClient;
|
||||
let mockLocalLiteRtLmClient: LocalLiteRtLmClient;
|
||||
let mockResolvedConfig: ResolvedModelConfig;
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -64,6 +66,7 @@ describe('ClassifierStrategy', () => {
|
||||
mockBaseLlmClient = {
|
||||
generateJson: vi.fn(),
|
||||
} as unknown as BaseLlmClient;
|
||||
mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
|
||||
vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id');
|
||||
});
|
||||
@@ -76,6 +79,7 @@ describe('ClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
@@ -94,6 +98,7 @@ describe('ClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
@@ -109,7 +114,12 @@ describe('ClassifierStrategy', () => {
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
@@ -132,6 +142,7 @@ describe('ClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce();
|
||||
@@ -159,6 +170,7 @@ describe('ClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce();
|
||||
@@ -183,6 +195,7 @@ describe('ClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
@@ -206,6 +219,7 @@ describe('ClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
@@ -233,7 +247,12 @@ describe('ClassifierStrategy', () => {
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
@@ -269,7 +288,12 @@ describe('ClassifierStrategy', () => {
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
@@ -305,7 +329,12 @@ describe('ClassifierStrategy', () => {
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
@@ -340,6 +369,7 @@ describe('ClassifierStrategy', () => {
|
||||
contextWithRequestedModel,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
@@ -363,6 +393,7 @@ describe('ClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL);
|
||||
@@ -386,6 +417,7 @@ describe('ClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL);
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
isFunctionResponse,
|
||||
} from '../../utils/messageInspectors.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
import { LlmRole } from '../../telemetry/types.js';
|
||||
import { AuthType } from '../../core/contentGenerator.js';
|
||||
|
||||
@@ -132,6 +133,7 @@ export class ClassifierStrategy implements RoutingStrategy {
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
_localLiteRtLmClient: LocalLiteRtLmClient,
|
||||
): Promise<RoutingDecision | null> {
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
|
||||
@@ -16,6 +16,7 @@ import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import { coreEvents } from '../../utils/events.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
|
||||
vi.mock('../../utils/debugLogger.js', () => ({
|
||||
debugLogger: {
|
||||
@@ -27,6 +28,7 @@ describe('CompositeStrategy', () => {
|
||||
let mockContext: RoutingContext;
|
||||
let mockConfig: Config;
|
||||
let mockBaseLlmClient: BaseLlmClient;
|
||||
let mockLocalLiteRtLmClient: LocalLiteRtLmClient;
|
||||
let mockStrategy1: RoutingStrategy;
|
||||
let mockStrategy2: RoutingStrategy;
|
||||
let mockTerminalStrategy: TerminalStrategy;
|
||||
@@ -38,6 +40,7 @@ describe('CompositeStrategy', () => {
|
||||
mockContext = {} as RoutingContext;
|
||||
mockConfig = {} as Config;
|
||||
mockBaseLlmClient = {} as BaseLlmClient;
|
||||
mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
|
||||
emitFeedbackSpy = vi.spyOn(coreEvents, 'emitFeedback');
|
||||
|
||||
@@ -84,17 +87,20 @@ describe('CompositeStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(mockStrategy1.route).toHaveBeenCalledWith(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
expect(mockStrategy2.route).toHaveBeenCalledWith(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
expect(mockTerminalStrategy.route).not.toHaveBeenCalled();
|
||||
|
||||
@@ -112,6 +118,7 @@ describe('CompositeStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(mockStrategy1.route).toHaveBeenCalledTimes(1);
|
||||
@@ -136,6 +143,7 @@ describe('CompositeStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(debugLogger.warn).toHaveBeenCalledWith(
|
||||
@@ -152,7 +160,12 @@ describe('CompositeStrategy', () => {
|
||||
const composite = new CompositeStrategy([mockTerminalStrategy]);
|
||||
|
||||
await expect(
|
||||
composite.route(mockContext, mockConfig, mockBaseLlmClient),
|
||||
composite.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
),
|
||||
).rejects.toThrow(terminalError);
|
||||
|
||||
expect(emitFeedbackSpy).toHaveBeenCalledWith(
|
||||
@@ -182,6 +195,7 @@ describe('CompositeStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(result.model).toBe('some-model');
|
||||
@@ -212,6 +226,7 @@ describe('CompositeStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(result.metadata.latencyMs).toBeGreaterThanOrEqual(0);
|
||||
|
||||
@@ -14,6 +14,7 @@ import type {
|
||||
RoutingStrategy,
|
||||
TerminalStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
|
||||
/**
|
||||
* A strategy that attempts a list of child strategies in order (Chain of Responsibility).
|
||||
@@ -40,6 +41,7 @@ export class CompositeStrategy implements TerminalStrategy {
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
localLiteRtLmClient: LocalLiteRtLmClient,
|
||||
): Promise<RoutingDecision> {
|
||||
const startTime = performance.now();
|
||||
|
||||
@@ -57,7 +59,12 @@ export class CompositeStrategy implements TerminalStrategy {
|
||||
// Try non-terminal strategies, allowing them to fail gracefully.
|
||||
for (const strategy of nonTerminalStrategies) {
|
||||
try {
|
||||
const decision = await strategy.route(context, config, baseLlmClient);
|
||||
const decision = await strategy.route(
|
||||
context,
|
||||
config,
|
||||
baseLlmClient,
|
||||
localLiteRtLmClient,
|
||||
);
|
||||
if (decision) {
|
||||
return this.finalizeDecision(decision, startTime);
|
||||
}
|
||||
@@ -75,6 +82,7 @@ export class CompositeStrategy implements TerminalStrategy {
|
||||
context,
|
||||
config,
|
||||
baseLlmClient,
|
||||
localLiteRtLmClient,
|
||||
);
|
||||
|
||||
return this.finalizeDecision(decision, startTime);
|
||||
|
||||
@@ -8,6 +8,7 @@ import { describe, it, expect, vi } from 'vitest';
|
||||
import { DefaultStrategy } from './defaultStrategy.js';
|
||||
import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
@@ -26,8 +27,14 @@ describe('DefaultStrategy', () => {
|
||||
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
|
||||
} as unknown as Config;
|
||||
const mockClient = {} as BaseLlmClient;
|
||||
const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
@@ -46,8 +53,14 @@ describe('DefaultStrategy', () => {
|
||||
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
|
||||
} as unknown as Config;
|
||||
const mockClient = {} as BaseLlmClient;
|
||||
const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
@@ -66,8 +79,14 @@ describe('DefaultStrategy', () => {
|
||||
getModel: vi.fn().mockReturnValue(GEMINI_MODEL_ALIAS_AUTO),
|
||||
} as unknown as Config;
|
||||
const mockClient = {} as BaseLlmClient;
|
||||
const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
@@ -87,8 +106,14 @@ describe('DefaultStrategy', () => {
|
||||
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_FLASH_MODEL),
|
||||
} as unknown as Config;
|
||||
const mockClient = {} as BaseLlmClient;
|
||||
const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||
|
||||
@@ -12,6 +12,7 @@ import type {
|
||||
TerminalStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
import { resolveModel } from '../../config/models.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
|
||||
export class DefaultStrategy implements TerminalStrategy {
|
||||
readonly name = 'default';
|
||||
@@ -20,6 +21,7 @@ export class DefaultStrategy implements TerminalStrategy {
|
||||
_context: RoutingContext,
|
||||
config: Config,
|
||||
_baseLlmClient: BaseLlmClient,
|
||||
_localLiteRtLmClient: LocalLiteRtLmClient,
|
||||
): Promise<RoutingDecision> {
|
||||
const defaultModel = resolveModel(
|
||||
config.getModel(),
|
||||
|
||||
@@ -10,6 +10,7 @@ import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { ModelAvailabilityService } from '../../availability/modelAvailabilityService.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
@@ -32,6 +33,7 @@ describe('FallbackStrategy', () => {
|
||||
const strategy = new FallbackStrategy();
|
||||
const mockContext = {} as RoutingContext;
|
||||
const mockClient = {} as BaseLlmClient;
|
||||
const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
let mockService: ModelAvailabilityService;
|
||||
let mockConfig: Config;
|
||||
|
||||
@@ -51,7 +53,12 @@ describe('FallbackStrategy', () => {
|
||||
// Mock snapshot to return available
|
||||
vi.mocked(mockService.snapshot).mockReturnValue({ available: true });
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
expect(decision).toBeNull();
|
||||
// Should check availability of the resolved model (DEFAULT_GEMINI_MODEL)
|
||||
expect(mockService.snapshot).toHaveBeenCalledWith(DEFAULT_GEMINI_MODEL);
|
||||
@@ -69,7 +76,12 @@ describe('FallbackStrategy', () => {
|
||||
skipped: [],
|
||||
});
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
expect(decision).toBeNull();
|
||||
});
|
||||
|
||||
@@ -86,7 +98,12 @@ describe('FallbackStrategy', () => {
|
||||
skipped: [{ model: DEFAULT_GEMINI_MODEL, reason: 'quota' }],
|
||||
});
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
@@ -101,7 +118,12 @@ describe('FallbackStrategy', () => {
|
||||
vi.mocked(mockService.snapshot).mockReturnValue({ available: true });
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
// Important: check that it queried snapshot with the RESOLVED model, not 'auto'
|
||||
@@ -122,6 +144,7 @@ describe('FallbackStrategy', () => {
|
||||
contextWithRequestedModel,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
|
||||
@@ -13,6 +13,7 @@ import type {
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
|
||||
export class FallbackStrategy implements RoutingStrategy {
|
||||
readonly name = 'fallback';
|
||||
@@ -21,6 +22,7 @@ export class FallbackStrategy implements RoutingStrategy {
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
_baseLlmClient: BaseLlmClient,
|
||||
_localLiteRtLmClient: LocalLiteRtLmClient,
|
||||
): Promise<RoutingDecision | null> {
|
||||
const requestedModel = context.requestedModel ?? config.getModel();
|
||||
const resolvedModel = resolveModel(
|
||||
|
||||
@@ -0,0 +1,324 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Mock } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { GemmaClassifierStrategy } from './gemmaClassifierStrategy.js';
|
||||
import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
} from '../../config/models.js';
|
||||
import type { Content } from '@google/genai';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
|
||||
vi.mock('../../core/localLiteRtLmClient.js');
|
||||
|
||||
describe('GemmaClassifierStrategy', () => {
|
||||
let strategy: GemmaClassifierStrategy;
|
||||
let mockContext: RoutingContext;
|
||||
let mockConfig: Config;
|
||||
let mockBaseLlmClient: BaseLlmClient;
|
||||
let mockLocalLiteRtLmClient: LocalLiteRtLmClient;
|
||||
let mockGenerateJson: Mock;
|
||||
|
||||
beforeEach(() => {
  vi.clearAllMocks();

  // Fresh collaborators for every test: the local client is a stub whose only
  // member is the generateJson mock.
  mockGenerateJson = vi.fn();
  mockLocalLiteRtLmClient = {
    generateJson: mockGenerateJson,
  } as unknown as LocalLiteRtLmClient;
  mockBaseLlmClient = {} as BaseLlmClient;

  // Router enabled with the supported classifier model by default.
  const enabledRouterSettings = {
    enabled: true,
    classifier: { model: 'gemma3-1b-gpu-custom' },
  };
  mockConfig = {
    getGemmaModelRouterSettings: vi.fn().mockReturnValue(enabledRouterSettings),
    getModel: () => DEFAULT_GEMINI_MODEL,
    getPreviewFeatures: () => false,
  } as unknown as Config;

  mockContext = {
    history: [],
    request: 'simple task',
    signal: new AbortController().signal,
  };

  strategy = new GemmaClassifierStrategy();
});
|
||||
|
||||
it('should return null if gemma model router is disabled', async () => {
  // Override the default settings so the router reports itself as disabled.
  const disabledSettings = { enabled: false };
  vi.mocked(mockConfig.getGemmaModelRouterSettings).mockReturnValue(
    disabledSettings,
  );

  const routingResult = await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );

  expect(routingResult).toBeNull();
});
|
||||
|
||||
it('should throw an error if the model is not gemma3-1b-gpu-custom', async () => {
  // Enabled, but pointing at a classifier model other than the supported one.
  const unsupportedSettings = {
    enabled: true,
    classifier: { model: 'other-model' },
  };
  vi.mocked(mockConfig.getGemmaModelRouterSettings).mockReturnValue(
    unsupportedSettings,
  );

  const pendingDecision = strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );

  await expect(pendingDecision).rejects.toThrow(
    'Only gemma3-1b-gpu-custom has been tested',
  );
});
|
||||
|
||||
it('should call generateJson with the correct parameters', async () => {
  mockGenerateJson.mockResolvedValue({
    reasoning: 'Simple task',
    model_choice: 'flash',
  });

  await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );

  // Argument order: contents, system instruction, reminder, abort signal.
  const expectedArgs = [
    expect.any(Array),
    expect.any(String),
    expect.any(String),
    expect.any(AbortSignal),
  ];
  expect(mockGenerateJson).toHaveBeenCalledWith(...expectedArgs);
});
|
||||
|
||||
it('should route to FLASH model for a simple task', async () => {
  const reasoning = 'This is a simple task.';
  mockGenerateJson.mockResolvedValue({ reasoning, model_choice: 'flash' });

  const routingResult = await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );

  expect(mockGenerateJson).toHaveBeenCalledOnce();
  // A 'flash' choice maps to the default flash model, and the classifier's
  // reasoning is carried through in the metadata.
  expect(routingResult).toEqual({
    model: DEFAULT_GEMINI_FLASH_MODEL,
    metadata: {
      source: 'GemmaClassifier',
      latencyMs: expect.any(Number),
      reasoning,
    },
  });
});
|
||||
|
||||
// A 'pro' verdict from the classifier resolves to the default Pro model.
it('should route to PRO model for a complex task', async () => {
  const reasoning = 'This is a complex task.';
  mockContext.request = 'how do I build a spaceship?';
  mockGenerateJson.mockResolvedValue({ reasoning, model_choice: 'pro' });

  const decision = await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );

  expect(mockGenerateJson).toHaveBeenCalledOnce();
  expect(decision).toEqual({
    model: DEFAULT_GEMINI_MODEL,
    metadata: {
      source: 'GemmaClassifier',
      latencyMs: expect.any(Number),
      reasoning,
    },
  });
});
|
||||
|
||||
// When the local classifier call rejects, the strategy must swallow the
// error, log it through debugLogger.warn and return null so that the
// composite strategy can fall through to the next router.
it('should return null if the classifier API call fails', async () => {
  // Silence the expected warning while still recording the call.
  const warnSpy = vi.spyOn(debugLogger, 'warn').mockImplementation(() => {});
  const testError = new Error('API Failure');
  mockGenerateJson.mockRejectedValue(testError);

  try {
    const decision = await strategy.route(
      mockContext,
      mockConfig,
      mockBaseLlmClient,
      mockLocalLiteRtLmClient,
    );

    expect(decision).toBeNull();
    // Pin what is logged, not merely that something was logged.
    expect(warnSpy).toHaveBeenCalledWith(
      '[Routing] GemmaClassifierStrategy failed:',
      testError,
    );
  } finally {
    // Restore even when an assertion above fails, so the spy cannot leak
    // into later tests.
    warnSpy.mockRestore();
  }
});
|
||||
|
||||
// A reply that does not satisfy the response schema is treated like any
// other classifier failure: warn and return null.
it('should return null if the classifier returns a malformed JSON object', async () => {
  // Silence the expected warning while still recording the call.
  const warnSpy = vi.spyOn(debugLogger, 'warn').mockImplementation(() => {});
  const malformedApiResponse = {
    reasoning: 'This is a simple task.',
    // model_choice is missing, which will cause a Zod parsing error.
  };
  mockGenerateJson.mockResolvedValue(malformedApiResponse);

  try {
    const decision = await strategy.route(
      mockContext,
      mockConfig,
      mockBaseLlmClient,
      mockLocalLiteRtLmClient,
    );

    // The classifier must have been consulted; the null result comes from
    // rejecting its reply, not from skipping the call.
    expect(mockGenerateJson).toHaveBeenCalledOnce();
    expect(decision).toBeNull();
    expect(warnSpy).toHaveBeenCalled();
  } finally {
    // Restore even when an assertion above fails, so the spy cannot leak
    // into later tests.
    warnSpy.mockRestore();
  }
});
|
||||
|
||||
// Function-call and function-response turns must not reach the classifier;
// only the plain-text turns are flattened into the prompt.
it('should filter out tool-related history before sending to classifier', async () => {
  mockContext.history = [
    { role: 'user', parts: [{ text: 'call a tool' }] },
    {
      role: 'model',
      parts: [{ functionCall: { name: 'test_tool', args: {} } }],
    },
    {
      role: 'user',
      parts: [
        { functionResponse: { name: 'test_tool', response: { ok: true } } },
      ],
    },
    { role: 'user', parts: [{ text: 'another user turn' }] },
  ];
  mockGenerateJson.mockResolvedValue({
    reasoning: 'Simple.',
    model_choice: 'flash',
  });

  await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );

  // Shape of the leading arguments passed to the mocked `generateJson`.
  type GenerateJsonCall = [Content[], string, string | undefined];
  const [firstCall] = mockGenerateJson.mock.calls as GenerateJsonCall[];
  const lastTurn = firstCall[0].at(-1);
  expect(lastTurn).toBeDefined();

  const promptParts = lastTurn?.parts;
  expect(promptParts).toBeDefined();
  if (!promptParts) {
    // Unreachable after the assertion above; narrows the type for TypeScript.
    return;
  }

  // Empty entries are the blank lines of the prompt; the trailing one is the
  // final newline.
  const expectedLastTurn = [
    "You are provided with a **Chat History** and the user's **Current Request** below.",
    '',
    '#### Chat History:',
    'call a tool',
    '',
    'another user turn',
    '',
    '#### Current Request:',
    '"simple task"',
    '',
  ].join('\n');

  expect(promptParts.at(0)?.text).toEqual(expectedLastTurn);
});
|
||||
|
||||
// Builds a history far longer than the strategy's limits (30 text turns plus
// a function-call "noise" turn after every even-numbered one) and checks
// which turns actually reach the classifier prompt.
it('should respect HISTORY_SEARCH_WINDOW and HISTORY_TURNS_FOR_CONTEXT', async () => {
  const longHistory: Content[] = [];
  for (let i = 0; i < 30; i++) {
    longHistory.push({ role: 'user', parts: [{ text: `Message ${i}` }] });
    // Add noise that should be filtered
    if (i % 2 === 0) {
      longHistory.push({
        role: 'model',
        parts: [{ functionCall: { name: 'noise', args: {} } }],
      });
    }
  }
  mockContext.history = longHistory;
  const mockApiResponse = {
    reasoning: 'Simple.',
    model_choice: 'flash',
  };
  mockGenerateJson.mockResolvedValue(mockApiResponse);

  await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );

  type GenerateJsonCall = [Content[], string, string | undefined];
  const calls = mockGenerateJson.mock.calls as GenerateJsonCall[];
  const contents = calls[0][0];

  // There should be 1 item which is the flattened history.
  expect(contents).toHaveLength(1);

  // The last 20 raw entries hold Messages 17..29 once noise is removed, and
  // only the final 4 of those are kept. With this fixture the 4-turn cap is
  // the binding limit.
  const promptText = contents[0].parts?.at(0)?.text ?? '';
  expect(promptText).toContain(
    'Message 26\n\nMessage 27\n\nMessage 28\n\nMessage 29',
  );
  // Nothing older than the 4 retained turns may appear in the prompt.
  expect(promptText).not.toContain('Message 25');
});
|
||||
|
||||
// Turns whose parts carry no `text` contribute nothing to the flattened
// chat history.
it('should filter out non-text parts from history', async () => {
  // This turn's only part has no `text` property and should be filtered out.
  const textlessTurn = { role: 'user', parts: [{}] } as Content;
  mockContext.history = [
    { role: 'user', parts: [{ text: 'first message' }] },
    textlessTurn,
    { role: 'user', parts: [{ text: 'second message' }] },
  ];
  mockGenerateJson.mockResolvedValue({
    reasoning: 'Simple.',
    model_choice: 'flash',
  });

  await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );

  type GenerateJsonCall = [Content[], string, string | undefined];
  const [firstCall] = mockGenerateJson.mock.calls as GenerateJsonCall[];
  const lastTurn = firstCall[0].at(-1);
  expect(lastTurn).toBeDefined();

  // Empty entries are the blank lines of the prompt; the trailing one is the
  // final newline.
  const expectedLastTurn = [
    "You are provided with a **Chat History** and the user's **Current Request** below.",
    '',
    '#### Chat History:',
    'first message',
    '',
    'second message',
    '',
    '#### Current Request:',
    '"simple task"',
    '',
  ].join('\n');

  expect(lastTurn?.parts?.at(0)?.text).toEqual(expectedLastTurn);
});
|
||||
});
|
||||
@@ -0,0 +1,232 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { z } from 'zod';
|
||||
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type {
|
||||
RoutingContext,
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
import { resolveClassifierModel } from '../../config/models.js';
|
||||
import { createUserContent, type Content, type Part } from '@google/genai';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import {
|
||||
isFunctionCall,
|
||||
isFunctionResponse,
|
||||
} from '../../utils/messageInspectors.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
|
||||
// Labels the classifier is asked to answer with; they are interpolated into
// the prompt text and used as the allowed values of the response schema.
const FLASH_MODEL = 'flash';
const PRO_MODEL = 'pro';

// How many of the most recent raw history entries are examined before
// tool-related turns are filtered out.
const HISTORY_SEARCH_WINDOW = 20;
// The number of recent (already filtered) history turns to provide to the
// router for context.
const HISTORY_TURNS_FOR_CONTEXT = 4;
|
||||
|
||||
const COMPLEXITY_RUBRIC = `### Complexity Rubric
|
||||
A task is COMPLEX (Choose \`${PRO_MODEL}\`) if it meets ONE OR MORE of the following criteria:
|
||||
1. **High Operational Complexity (Est. 4+ Steps/Tool Calls):** Requires dependent actions, significant planning, or multiple coordinated changes.
|
||||
2. **Strategic Planning & Conceptual Design:** Asking "how" or "why." Requires advice, architecture, or high-level strategy.
|
||||
3. **High Ambiguity or Large Scope (Extensive Investigation):** Broadly defined requests requiring extensive investigation.
|
||||
4. **Deep Debugging & Root Cause Analysis:** Diagnosing unknown or complex problems from symptoms.
|
||||
A task is SIMPLE (Choose \`${FLASH_MODEL}\`) if it is highly specific, bounded, and has Low Operational Complexity (Est. 1-3 tool calls). Operational simplicity overrides strategic phrasing.`;
|
||||
|
||||
const OUTPUT_FORMAT = `### Output Format
|
||||
Respond *only* in JSON format like this:
|
||||
{
|
||||
"reasoning": Your reasoning...
|
||||
"model_choice": Either ${FLASH_MODEL} or ${PRO_MODEL}
|
||||
}
|
||||
And you must follow the following JSON schema:
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reasoning": {
|
||||
"type": "string",
|
||||
"description": "A brief summary of the user objective, followed by a step-by-step explanation for the model choice, referencing the rubric."
|
||||
},
|
||||
"model_choice": {
|
||||
"type": "string",
|
||||
"enum": ["${FLASH_MODEL}", "${PRO_MODEL}"]
|
||||
}
|
||||
},
|
||||
"required": ["reasoning", "model_choice"]
|
||||
}
|
||||
You must ensure that your reasoning is no more than 2 sentences long and directly references the rubric criteria.
|
||||
When making your decision, the user's request should be weighted much more heavily than the surrounding context when making your determination.`;
|
||||
|
||||
const LITERT_GEMMA_CLASSIFIER_SYSTEM_PROMPT = `### Role
|
||||
You are the **Lead Orchestrator** for an AI system. You do not talk to users. Your sole responsibility is to analyze the **Chat History** and delegate the **Current Request** to the most appropriate **Model** based on the request's complexity.
|
||||
|
||||
### Models
|
||||
Choose between \`${FLASH_MODEL}\` (SIMPLE) or \`${PRO_MODEL}\` (COMPLEX).
|
||||
1. \`${FLASH_MODEL}\`: A fast, efficient model for simple, well-defined tasks.
|
||||
2. \`${PRO_MODEL}\`: A powerful, advanced model for complex, open-ended, or multi-step tasks.
|
||||
|
||||
${COMPLEXITY_RUBRIC}
|
||||
|
||||
${OUTPUT_FORMAT}
|
||||
|
||||
### Examples
|
||||
**Example 1 (Strategic Planning):**
|
||||
*User Prompt:* "How should I architect the data pipeline for this new analytics service?"
|
||||
*Your JSON Output:*
|
||||
{
|
||||
"reasoning": "The user is asking for high-level architectural design and strategy. This falls under 'Strategic Planning & Conceptual Design'.",
|
||||
"model_choice": "${PRO_MODEL}"
|
||||
}
|
||||
**Example 2 (Simple Tool Use):**
|
||||
*User Prompt:* "list the files in the current directory"
|
||||
*Your JSON Output:*
|
||||
{
|
||||
"reasoning": "This is a direct command requiring a single tool call (ls). It has Low Operational Complexity (1 step).",
|
||||
"model_choice": "${FLASH_MODEL}"
|
||||
}
|
||||
**Example 3 (High Operational Complexity):**
|
||||
*User Prompt:* "I need to add a new 'email' field to the User schema in 'src/models/user.ts', migrate the database, and update the registration endpoint."
|
||||
*Your JSON Output:*
|
||||
{
|
||||
"reasoning": "This request involves multiple coordinated steps across different files and systems. This meets the criteria for High Operational Complexity (4+ steps).",
|
||||
"model_choice": "${PRO_MODEL}"
|
||||
}
|
||||
**Example 4 (Simple Read):**
|
||||
*User Prompt:* "Read the contents of 'package.json'."
|
||||
*Your JSON Output:*
|
||||
{
|
||||
"reasoning": "This is a direct command requiring a single read. It has Low Operational Complexity (1 step).",
|
||||
"model_choice": "${FLASH_MODEL}"
|
||||
}
|
||||
**Example 5 (Deep Debugging):**
|
||||
*User Prompt:* "I'm getting an error 'Cannot read property 'map' of undefined' when I click the save button. Can you fix it?"
|
||||
*Your JSON Output:*
|
||||
{
|
||||
"reasoning": "The user is reporting an error symptom without a known cause. This requires investigation and falls under 'Deep Debugging'.",
|
||||
"model_choice": "${PRO_MODEL}"
|
||||
}
|
||||
**Example 6 (Simple Edit despite Phrasing):**
|
||||
*User Prompt:* "What is the best way to rename the variable 'data' to 'userData' in 'src/utils.js'?"
|
||||
*Your JSON Output:*
|
||||
{
|
||||
"reasoning": "Although the user uses strategic language ('best way'), the underlying task is a localized edit. The operational complexity is low (1-2 steps).",
|
||||
"model_choice": "${FLASH_MODEL}"
|
||||
}
|
||||
`;
|
||||
|
||||
const LITERT_GEMMA_CLASSIFIER_REMINDER = `### Reminder
|
||||
You are a Task Routing AI. Your sole task is to analyze the preceding **Chat History** and **Current Request** and classify its complexity.
|
||||
|
||||
${COMPLEXITY_RUBRIC}
|
||||
|
||||
${OUTPUT_FORMAT}
|
||||
`;
|
||||
|
||||
// Runtime validation for the classifier's JSON reply: a free-text rationale
// plus a verdict restricted to the two known model labels. `parse` throws on
// anything else.
const ClassifierResponseSchema = z.object({
  model_choice: z.enum([FLASH_MODEL, PRO_MODEL]),
  reasoning: z.string(),
});
|
||||
|
||||
/**
 * Routing strategy that asks a Gemma model, reached through the
 * LocalLiteRtLmClient, to label the current request as simple or complex and
 * maps that label to a concrete model via `resolveClassifierModel`.
 *
 * The strategy is opt-in: it returns `null` unless the Gemma model router is
 * enabled in the config, and also returns `null` whenever the classifier
 * call or response validation fails, so a composite strategy can move on.
 */
export class GemmaClassifierStrategy implements RoutingStrategy {
  readonly name = 'gemma-classifier';

  /**
   * Collapses a list of turns into one user message for the classifier.
   *
   * Every turn except the last is treated as chat history; the last turn is
   * the current request. Only `text` parts are kept, and turns that end up
   * with no text are dropped.
   *
   * @param turns History turns followed by the current request turn.
   * @returns A single-element array holding the combined user prompt.
   */
  private flattenChatHistory(turns: Content[]): Content[] {
    // History: text parts of a turn joined by newlines, turns separated by a
    // blank line.
    const historyBlocks: string[] = [];
    for (const turn of turns.slice(0, -1)) {
      const turnText = (turn.parts ?? [])
        .map((part) => part.text)
        .filter(Boolean)
        .join('\n');
      if (turnText) {
        historyBlocks.push(turnText);
      }
    }
    const formattedHistory = historyBlocks.join('\n\n');

    // Current request: text parts of the final turn, separated by blank lines.
    const currentTurn = turns.at(-1);
    let userRequest = '';
    if (currentTurn?.parts) {
      userRequest = currentTurn.parts
        .map((part: Part) => part.text)
        .filter(Boolean)
        .join('\n\n');
    }

    // Empty entries are the blank lines of the prompt; the trailing one is
    // the final newline.
    const finalPrompt = [
      "You are provided with a **Chat History** and the user's **Current Request** below.",
      '',
      '#### Chat History:',
      formattedHistory,
      '',
      '#### Current Request:',
      `"${userRequest}"`,
      '',
    ].join('\n');

    return [createUserContent(finalPrompt)];
  }

  /**
   * Picks a model for the current request using the local Gemma classifier.
   *
   * @param context Routing context (history, request, abort signal and an
   *   optional requested model).
   * @param config Supplies the Gemma router settings and the configured model.
   * @param _baseLlmClient Unused by this strategy.
   * @param client Local LiteRT-LM client used to run the classifier.
   * @returns The routing decision, or `null` when the router is disabled or
   *   the classifier fails.
   * @throws Error when the router is enabled but the configured classifier
   *   model is not `gemma3-1b-gpu-custom`.
   */
  async route(
    context: RoutingContext,
    config: Config,
    _baseLlmClient: BaseLlmClient,
    client: LocalLiteRtLmClient,
  ): Promise<RoutingDecision | null> {
    const startTime = Date.now();

    const settings = config.getGemmaModelRouterSettings();
    if (!settings?.enabled) {
      return null;
    }

    // Only the gemma3-1b-gpu-custom model has been tested and verified.
    // This check sits outside the try block on purpose: a misconfiguration
    // is surfaced to the caller instead of being logged and ignored.
    const classifierModel = settings.classifier?.model;
    if (classifierModel !== 'gemma3-1b-gpu-custom') {
      throw new Error('Only gemma3-1b-gpu-custom has been tested');
    }

    try {
      // Look at the most recent window of history, drop tool-related turns,
      // then keep the last N turns of what remains.
      // TODO - Consider using function req/res if they help accuracy.
      const recentTurns = context.history
        .slice(-HISTORY_SEARCH_WINDOW)
        .filter(
          (content) => !(isFunctionCall(content) || isFunctionResponse(content)),
        )
        .slice(-HISTORY_TURNS_FOR_CONTEXT);

      const prompt = this.flattenChatHistory([
        ...recentTurns,
        createUserContent(context.request),
      ]);

      const rawResponse = await client.generateJson(
        prompt,
        LITERT_GEMMA_CLASSIFIER_SYSTEM_PROMPT,
        LITERT_GEMMA_CLASSIFIER_REMINDER,
        context.signal,
      );

      // Throws on replies that do not match the schema; handled below.
      const { reasoning, model_choice: modelChoice } =
        ClassifierResponseSchema.parse(rawResponse);

      const latencyMs = Date.now() - startTime;
      const model = resolveClassifierModel(
        context.requestedModel ?? config.getModel(),
        modelChoice,
      );

      return {
        model,
        metadata: {
          source: 'GemmaClassifier',
          latencyMs,
          reasoning,
        },
      };
    } catch (error) {
      // If the classifier fails for any reason (API error, parsing error,
      // etc.), log it and return null to allow the composite strategy to
      // proceed.
      debugLogger.warn(`[Routing] GemmaClassifierStrategy failed:`, error);
      return null;
    }
  }
}
|
||||
@@ -22,6 +22,7 @@ import { promptIdContext } from '../../utils/promptIdContext.js';
|
||||
import type { Content } from '@google/genai';
|
||||
import type { ResolvedModelConfig } from '../../services/modelConfigService.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
import { AuthType } from '../../core/contentGenerator.js';
|
||||
|
||||
vi.mock('../../core/baseLlmClient.js');
|
||||
@@ -31,6 +32,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
let mockContext: RoutingContext;
|
||||
let mockConfig: Config;
|
||||
let mockBaseLlmClient: BaseLlmClient;
|
||||
let mockLocalLiteRtLmClient: LocalLiteRtLmClient;
|
||||
let mockResolvedConfig: ResolvedModelConfig;
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -63,6 +65,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockBaseLlmClient = {
|
||||
generateJson: vi.fn(),
|
||||
} as unknown as BaseLlmClient;
|
||||
mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
|
||||
vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id');
|
||||
});
|
||||
@@ -78,6 +81,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
@@ -91,6 +95,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
@@ -104,6 +109,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
@@ -119,7 +125,12 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
@@ -151,6 +162,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
@@ -177,6 +189,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
@@ -203,6 +216,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
@@ -229,6 +243,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
@@ -257,6 +272,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
@@ -283,6 +299,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
@@ -309,6 +326,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
@@ -337,6 +355,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
@@ -364,6 +383,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
@@ -391,6 +411,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
@@ -415,6 +436,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
@@ -437,6 +459,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
@@ -463,7 +486,12 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
@@ -495,7 +523,12 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
@@ -528,7 +561,12 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
@@ -558,6 +596,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL);
|
||||
@@ -579,6 +618,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL);
|
||||
@@ -601,6 +641,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL);
|
||||
|
||||
@@ -16,6 +16,7 @@ import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
|
||||
import { createUserContent, Type } from '@google/genai';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
import { LlmRole } from '../../telemetry/types.js';
|
||||
import { AuthType } from '../../core/contentGenerator.js';
|
||||
|
||||
@@ -133,6 +134,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
_localLiteRtLmClient: LocalLiteRtLmClient,
|
||||
): Promise<RoutingDecision | null> {
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
|
||||
@@ -10,18 +10,25 @@ import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { DEFAULT_GEMINI_MODEL_AUTO } from '../../config/models.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
|
||||
describe('OverrideStrategy', () => {
|
||||
const strategy = new OverrideStrategy();
|
||||
const mockContext = {} as RoutingContext;
|
||||
const mockClient = {} as BaseLlmClient;
|
||||
const mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
|
||||
it('should return null when the override model is auto', async () => {
|
||||
const mockConfig = {
|
||||
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
|
||||
} as Config;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
expect(decision).toBeNull();
|
||||
});
|
||||
|
||||
@@ -31,7 +38,12 @@ describe('OverrideStrategy', () => {
|
||||
getModel: () => overrideModel,
|
||||
} as Config;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
expect(decision?.model).toBe(overrideModel);
|
||||
@@ -48,7 +60,12 @@ describe('OverrideStrategy', () => {
|
||||
getModel: () => overrideModel,
|
||||
} as Config;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
expect(decision?.model).toBe(overrideModel);
|
||||
@@ -68,6 +85,7 @@ describe('OverrideStrategy', () => {
|
||||
contextWithRequestedModel,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
|
||||
@@ -12,6 +12,7 @@ import type {
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
|
||||
/**
|
||||
* Handles cases where the user explicitly specifies a model (override).
|
||||
@@ -23,6 +24,7 @@ export class OverrideStrategy implements RoutingStrategy {
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
_baseLlmClient: BaseLlmClient,
|
||||
_localLiteRtLmClient: LocalLiteRtLmClient,
|
||||
): Promise<RoutingDecision | null> {
|
||||
const overrideModel = context.requestedModel ?? config.getModel();
|
||||
|
||||
|
||||
@@ -1694,6 +1694,47 @@
|
||||
"markdownDescription": "Enable web fetch behavior that bypasses LLM summarization.\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `false`",
|
||||
"default": false,
|
||||
"type": "boolean"
|
||||
},
|
||||
"gemmaModelRouter": {
|
||||
"title": "Gemma Model Router",
|
||||
"description": "Enable Gemma model router (experimental).",
|
||||
"markdownDescription": "Enable Gemma model router (experimental).\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `{}`",
|
||||
"default": {},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {
|
||||
"title": "Enable Gemma Model Router",
|
||||
"description": "Enable the Gemma Model Router. Requires a local endpoint serving Gemma via the Gemini API using LiteRT-LM shim.",
|
||||
"markdownDescription": "Enable the Gemma Model Router. Requires a local endpoint serving Gemma via the Gemini API using LiteRT-LM shim.\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `false`",
|
||||
"default": false,
|
||||
"type": "boolean"
|
||||
},
|
||||
"classifier": {
|
||||
"title": "Classifier",
|
||||
"description": "Classifier configuration.",
|
||||
"markdownDescription": "Classifier configuration.\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `{}`",
|
||||
"default": {},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"host": {
|
||||
"title": "Host",
|
||||
"description": "The host of the classifier.",
|
||||
"markdownDescription": "The host of the classifier.\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `http://localhost:9379`",
|
||||
"default": "http://localhost:9379",
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"title": "Model",
|
||||
"description": "The model to use for the classifier. Only tested on `gemma3-1b-gpu-custom`.",
|
||||
"markdownDescription": "The model to use for the classifier. Only tested on `gemma3-1b-gpu-custom`.\n\n- Category: `Experimental`\n- Requires restart: `yes`\n- Default: `gemma3-1b-gpu-custom`",
|
||||
"default": "gemma3-1b-gpu-custom",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
|
||||
Reference in New Issue
Block a user