mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-13 12:57:12 -07:00
feat(optimization): finalized iterative-surgical optimization suite (checkpoint)
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { debugLogger } from '../../../../../packages/core/src/utils/debugLogger.js';
|
||||
import { debugLogger } from '../../../../packages/core/src/utils/debugLogger.js';
|
||||
import { DEFAULT_EVAL_CONFIG } from '../config.js';
|
||||
import { MetricObjective } from '../types.js';
|
||||
import type { MetricResult } from '../types.js';
|
||||
|
||||
@@ -7,77 +7,62 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { evaluateToolAlignment } from './toolAlignment.js';
|
||||
import { MetricObjective } from '../types.js';
|
||||
import type { Scenario } from '../schema.js';
|
||||
|
||||
describe('evaluateToolAlignment', () => {
|
||||
const mockScenario: Scenario = {
|
||||
const mockScenario = {
|
||||
id: 'test-scenario',
|
||||
metadata: { tags: ['test'], created_at: '2026-03-02' },
|
||||
input: { user_query: 'test query' },
|
||||
expected: {
|
||||
tool_calls: [{ name: 'read_file', arguments: { file_path: 'test.ts' } }],
|
||||
rationale: 'Testing alignment',
|
||||
tool_calls: [{ name: 'test_tool', arguments: { arg: 1 } }],
|
||||
},
|
||||
negatives: [
|
||||
{
|
||||
tool_calls: [
|
||||
{ name: 'run_shell_command', arguments: { command: 'cat test.ts' } },
|
||||
],
|
||||
reason: 'Avoid shell',
|
||||
tool_calls: [{ name: 'shell', arguments: { cmd: 'rm -rf' } }],
|
||||
reason: 'Matched negative shell pattern',
|
||||
severity: 'high',
|
||||
},
|
||||
}
|
||||
],
|
||||
};
|
||||
} as any;
|
||||
|
||||
it('should return 1.0 for a perfect match', () => {
|
||||
it('should return 1.0 for a perfect functional match', () => {
|
||||
const prediction = {
|
||||
tool_calls: [{ name: 'read_file', arguments: { file_path: 'test.ts' } }],
|
||||
tool_calls: [{ name: 'test_tool', arguments: { arg: 1 } }],
|
||||
};
|
||||
const result = evaluateToolAlignment(prediction, mockScenario);
|
||||
expect(result.score).toBe(1.0);
|
||||
expect(result.objective).toBe(MetricObjective.ALIGNMENT);
|
||||
expect(result.reason).toContain('Functional Success');
|
||||
});
|
||||
|
||||
it('should return 0.0 for a hard failure (negative match)', () => {
|
||||
const prediction = {
|
||||
tool_calls: [
|
||||
{ name: 'run_shell_command', arguments: { command: 'cat test.ts' } },
|
||||
],
|
||||
tool_calls: [{ name: 'shell', arguments: { cmd: 'rm -rf' } }],
|
||||
};
|
||||
const result = evaluateToolAlignment(prediction, mockScenario);
|
||||
expect(result.score).toBe(0.0);
|
||||
expect(result.reason).toContain('Hard Failure');
|
||||
expect(result.metadata?.['matchedNegativeReason']).toBe('Avoid shell');
|
||||
expect(result.reason).toContain('Matched negative shell pattern');
|
||||
});
|
||||
|
||||
it('should return 0.1 for an incorrect tool selection', () => {
|
||||
const prediction = {
|
||||
tool_calls: [
|
||||
{
|
||||
name: 'write_file',
|
||||
arguments: { file_path: 'test.ts', content: 'test' },
|
||||
},
|
||||
],
|
||||
tool_calls: [{ name: 'wrong_tool', arguments: { arg: 1 } }],
|
||||
};
|
||||
const result = evaluateToolAlignment(prediction, mockScenario);
|
||||
expect(result.score).toBe(0.1);
|
||||
expect(result.reason).toContain('wrong tool');
|
||||
});
|
||||
|
||||
it('should return 0.4 for correct tool but wrong arguments', () => {
|
||||
const prediction = {
|
||||
tool_calls: [{ name: 'read_file', arguments: { file_path: 'wrong.ts' } }],
|
||||
tool_calls: [{ name: 'test_tool', arguments: { arg: 999 } }],
|
||||
};
|
||||
const result = evaluateToolAlignment(prediction, mockScenario);
|
||||
expect(result.score).toBe(0.4);
|
||||
expect(result.reason).toContain('arguments are incorrect');
|
||||
});
|
||||
|
||||
it('should return 0.1 for an empty tool call list', () => {
|
||||
const prediction = { tool_calls: [] };
|
||||
const prediction = {
|
||||
tool_calls: [],
|
||||
};
|
||||
const result = evaluateToolAlignment(prediction, mockScenario);
|
||||
expect(result.score).toBe(0.1);
|
||||
expect(result.reason).toContain('failed to produce any tool calls');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { debugLogger } from '../../../../../packages/core/src/utils/debugLogger.js';
|
||||
import { debugLogger } from '../../../../packages/core/src/utils/debugLogger.js';
|
||||
import type { Scenario, ToolCall } from '../schema.js';
|
||||
import { DEFAULT_EVAL_CONFIG } from '../config.js';
|
||||
import { MetricObjective } from '../types.js';
|
||||
@@ -12,7 +12,7 @@ import type { MetricResult } from '../types.js';
|
||||
|
||||
/**
|
||||
* Evaluates the alignment of a model's predicted tool calls against a golden scenario.
|
||||
* Focuses on accuracy and shell avoidance.
|
||||
* Focuses strictly on functional correctness (tool selection and argument precision).
|
||||
*/
|
||||
export function evaluateToolAlignment(
|
||||
prediction: { tool_calls: ToolCall[] },
|
||||
@@ -25,6 +25,7 @@ export function evaluateToolAlignment(
|
||||
debugLogger.debug(`[Eval:${scenarioId}] Evaluating tool alignment...`);
|
||||
|
||||
// 1. Check for Hard Failures (Explicit Negatives)
|
||||
// These are for specific "Forbidden" tool uses (e.g., using shell instead of read_file)
|
||||
for (const negative of negatives) {
|
||||
const isNegativeMatch = negative.tool_calls.every((negCall: ToolCall) =>
|
||||
predictedCalls.some(
|
||||
@@ -35,26 +36,17 @@ export function evaluateToolAlignment(
|
||||
);
|
||||
|
||||
if (isNegativeMatch && negative.tool_calls.length > 0) {
|
||||
debugLogger.debug(
|
||||
`[Eval:${scenarioId}] Hard Failure: Matched negative pattern.`,
|
||||
);
|
||||
return {
|
||||
score: config.hardFailureScore,
|
||||
objective: MetricObjective.ALIGNMENT,
|
||||
reason: `Hard Failure: ${negative.reason}`,
|
||||
metadata: {
|
||||
matchedNegativeReason: negative.reason,
|
||||
severity: negative.severity,
|
||||
},
|
||||
metadata: { matchedNegativeReason: negative.reason },
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Structural Check
|
||||
if (predictedCalls.length === 0) {
|
||||
debugLogger.debug(
|
||||
`[Eval:${scenarioId}] Invalid Response: No tool calls found.`,
|
||||
);
|
||||
return {
|
||||
score: config.invalidResponseScore,
|
||||
objective: MetricObjective.ALIGNMENT,
|
||||
@@ -71,9 +63,6 @@ export function evaluateToolAlignment(
|
||||
);
|
||||
|
||||
if (!namesMatch) {
|
||||
debugLogger.debug(
|
||||
`[Eval:${scenarioId}] Failure: Incorrect tool selection.`,
|
||||
);
|
||||
return {
|
||||
score: config.invalidResponseScore,
|
||||
objective: MetricObjective.ALIGNMENT,
|
||||
@@ -91,9 +80,6 @@ export function evaluateToolAlignment(
|
||||
);
|
||||
|
||||
if (!argsMatch) {
|
||||
debugLogger.debug(
|
||||
`[Eval:${scenarioId}] Partial Success: Right tool, wrong arguments.`,
|
||||
);
|
||||
return {
|
||||
score: config.toolNameMatchOnlyScore,
|
||||
objective: MetricObjective.ALIGNMENT,
|
||||
@@ -102,14 +88,10 @@ export function evaluateToolAlignment(
|
||||
}
|
||||
|
||||
// 4. Perfect Success
|
||||
debugLogger.debug(
|
||||
`[Eval:${scenarioId}] Perfect Functional Alignment achieved.`,
|
||||
);
|
||||
return {
|
||||
score: config.functionalSuccessScore,
|
||||
objective: MetricObjective.ALIGNMENT,
|
||||
reason:
|
||||
'Functional Success: Tool and arguments align perfectly with golden scenario.',
|
||||
reason: 'Functional Success: Tool and arguments align perfectly.',
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"models": {
|
||||
"student": {
|
||||
"provider": "google-gemini",
|
||||
"modelId": "gemini-3.1-pro-preview"
|
||||
},
|
||||
"teacher": {
|
||||
"provider": "google-gemini",
|
||||
"modelId": "gemini-3.1-pro-preview"
|
||||
}
|
||||
},
|
||||
"gepa": {
|
||||
"numTrials": 5,
|
||||
"minibatch": true,
|
||||
"maxMetricCalls": 10
|
||||
},
|
||||
"paths": {
|
||||
"scenarios": "data/tool_alignment.jsonl",
|
||||
"targets": "data/optimization/targets.json",
|
||||
"outputDir": "data/optimization"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { maskVariables, unmaskVariables } from './masking.js';
|
||||
|
||||
vi.mock('node:fs');
|
||||
|
||||
describe('Optimization Pipeline Infrastructure', () => {
|
||||
// --- Masking Tests ---
|
||||
describe('Masking Utility', () => {
|
||||
it('should mask unique template variables and restore them', () => {
|
||||
const input =
|
||||
'Use ${TOOL_A} to read ${FILE_PATH}. ${TOOL_A} is efficient.';
|
||||
const { maskedText, maskMap } = maskVariables(input);
|
||||
|
||||
expect(maskedText).toContain('[[GCLI_VAR_0]]');
|
||||
expect(maskedText).not.toContain('${TOOL_A}');
|
||||
|
||||
const restored = unmaskVariables(maskedText, maskMap);
|
||||
expect(restored).toBe(input);
|
||||
});
|
||||
|
||||
it('should handle text with no variables', () => {
|
||||
const input = 'Static text.';
|
||||
const { maskedText, maskMap } = maskVariables(input);
|
||||
expect(maskedText).toBe(input);
|
||||
expect(Object.keys(maskMap).length).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
// Note: Extraction tests remain in extract.test.ts
|
||||
// Optimization logic is verified via dry runs and Pareto frontier outputs.
|
||||
});
|
||||
@@ -0,0 +1,213 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as fs from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
import { ai, AxGEPA, ax } from '@ax-llm/ax';
|
||||
import { evaluateToolAlignment } from './evals/metrics/toolAlignment.js';
|
||||
import { evaluateBrevity } from './evals/metrics/brevityMetric.js';
|
||||
import type { Scenario, ToolCall } from './evals/schema.js';
|
||||
|
||||
interface OptimizationConfig {
|
||||
models: {
|
||||
student: { provider: string; modelId: string };
|
||||
teacher: { provider: string; modelId: string };
|
||||
};
|
||||
gepa: {
|
||||
numTrials?: number;
|
||||
minibatch?: boolean;
|
||||
maxMetricCalls?: number;
|
||||
};
|
||||
paths?: {
|
||||
scenarios?: string;
|
||||
targets?: string;
|
||||
outputDir?: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface AxPrediction {
|
||||
tool_calls?: ToolCall[];
|
||||
output_text?: string;
|
||||
}
|
||||
|
||||
interface MetricArgs {
|
||||
prediction: AxPrediction;
|
||||
example: Scenario;
|
||||
}
|
||||
|
||||
let currentCallCount = 0;
|
||||
let maxCallsExpected = 0;
|
||||
|
||||
/**
|
||||
* multiObjectiveMetric: Evaluates model performance with structured logging.
|
||||
*/
|
||||
function multiObjectiveMetric({ prediction, example }: MetricArgs): Record<string, number> {
|
||||
currentCallCount++;
|
||||
|
||||
const modelOutput = {
|
||||
tool_calls: prediction.tool_calls || [],
|
||||
output_text: prediction.output_text || '',
|
||||
};
|
||||
|
||||
const alignment = evaluateToolAlignment(modelOutput, example);
|
||||
const brevity = evaluateBrevity(modelOutput);
|
||||
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(`\n[ EVAL: ${currentCallCount}/${maxCallsExpected} | ${example.id} ]`);
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(`Scores: Acc=${alignment.score.toFixed(2)} | Brev=${brevity.score.toFixed(2)}`);
|
||||
|
||||
return {
|
||||
accuracy: alignment.score,
|
||||
brevity: brevity.score,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Evolve a specific target snippet using GEPA.
|
||||
*/
|
||||
async function evolveTarget(
|
||||
id: string,
|
||||
allTargets: any[],
|
||||
scenarios: any[],
|
||||
config: OptimizationConfig,
|
||||
apiKey: string,
|
||||
outputDir: string
|
||||
) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(`\n🎯 TARGETED EVOLUTION: ${id}`);
|
||||
|
||||
const target = allTargets.find((t) => t.id === id);
|
||||
const backgroundContext = allTargets
|
||||
.filter((t) => t.id !== id)
|
||||
.map((t) => `\n### ${t.id}\n${t.maskedText}`)
|
||||
.join('\n');
|
||||
|
||||
const student = ai({
|
||||
name: config.models.student.provider,
|
||||
apiKey,
|
||||
config: { model: config.models.student.modelId },
|
||||
});
|
||||
|
||||
const teacher = ai({
|
||||
name: config.models.teacher.provider,
|
||||
apiKey,
|
||||
config: { model: config.models.teacher.modelId },
|
||||
});
|
||||
|
||||
// Standard field names to avoid signature validation errors.
|
||||
const gcliProgram = ax(
|
||||
'user_query:string, platform:string, tags:string[], background_context:string -> tool_calls:json, output_text:string',
|
||||
{ instructions: target.maskedText }
|
||||
);
|
||||
|
||||
const dataset = scenarios.map(s => ({ ...s, background_context: backgroundContext }));
|
||||
|
||||
const optimizer = new AxGEPA({
|
||||
studentAI: student,
|
||||
teacherAI: teacher,
|
||||
numTrials: config.gepa.numTrials || 16,
|
||||
minibatch: config.gepa.minibatch !== false,
|
||||
verbose: true,
|
||||
});
|
||||
|
||||
currentCallCount = 0;
|
||||
maxCallsExpected = config.gepa.maxMetricCalls || 100;
|
||||
|
||||
const result = (await optimizer.compile(
|
||||
gcliProgram,
|
||||
dataset,
|
||||
multiObjectiveMetric,
|
||||
{
|
||||
maxMetricCalls: maxCallsExpected,
|
||||
}
|
||||
)) as any;
|
||||
|
||||
// Save to consolidated registry
|
||||
const resultsPath = path.join(outputDir, 'results.json');
|
||||
let registry: Record<string, any> = {};
|
||||
if (fs.existsSync(resultsPath)) {
|
||||
registry = JSON.parse(fs.readFileSync(resultsPath, 'utf8'));
|
||||
}
|
||||
|
||||
// The 'instruction' field in result.optimizedProgram contains the winner for the mutable part
|
||||
const optimizedText = result.optimizedProgram?.instruction || "ERROR_EXTRACTING";
|
||||
|
||||
registry[id] = {
|
||||
timestamp: new Date().toISOString(),
|
||||
bestScore: result.optimizedProgram?.bestScore,
|
||||
optimizedText,
|
||||
stats: result.stats,
|
||||
report: result.report,
|
||||
paretoFront: result.paretoFront?.map((entry: any) => ({
|
||||
scores: entry.scores,
|
||||
isBest: entry.isBest,
|
||||
text: entry.instruction || entry.program?.instruction
|
||||
}))
|
||||
};
|
||||
|
||||
fs.writeFileSync(resultsPath, JSON.stringify(registry, null, 2));
|
||||
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(`✅ Evolution complete for ${id}.`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Main Optimization Runner.
|
||||
*/
|
||||
export async function runOptimization(configPath: string) {
|
||||
const apiKey = process.env.GEMINI_API_KEY;
|
||||
if (!apiKey) throw new Error('GEMINI_API_KEY is not set.');
|
||||
|
||||
if (!fs.existsSync(configPath)) throw new Error(`Config not found: ${configPath}`);
|
||||
|
||||
const config = JSON.parse(fs.readFileSync(configPath, 'utf8')) as OptimizationConfig;
|
||||
const scenariosPath = config.paths?.scenarios || 'data/tool_alignment.jsonl';
|
||||
const targetsPath = config.paths?.targets || 'data/optimization/targets.json';
|
||||
const outputDir = config.paths?.outputDir || 'data/optimization';
|
||||
|
||||
if (!fs.existsSync(targetsPath)) throw new Error(`Targets file not found: ${targetsPath}`);
|
||||
|
||||
const allTargets = JSON.parse(fs.readFileSync(targetsPath, 'utf8'));
|
||||
const scenarios = fs
|
||||
.readFileSync(scenariosPath, 'utf8')
|
||||
.split('\n')
|
||||
.filter(Boolean)
|
||||
.map((line) => {
|
||||
const data = JSON.parse(line);
|
||||
return {
|
||||
id: data.id,
|
||||
user_query: data.input.user_query,
|
||||
platform: data.metadata?.platform || 'unknown',
|
||||
tags: data.metadata?.tags || [],
|
||||
expected: data.expected,
|
||||
negatives: data.negatives || [],
|
||||
};
|
||||
});
|
||||
|
||||
// Iterative mode over ALL targets
|
||||
const targetsToOptimize = allTargets;
|
||||
|
||||
if (config.gepa.numTrials === 1) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.log('🧪 Micro-Trial detected: Optimizing first target.');
|
||||
await evolveTarget(targetsToOptimize[0].id, allTargets, scenarios.slice(0, 2), config, apiKey, outputDir);
|
||||
} else {
|
||||
for (const t of targetsToOptimize) {
|
||||
await evolveTarget(t.id, allTargets, scenarios, config, apiKey, outputDir);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CLI Entrypoint
|
||||
const currentFilePath = fileURLToPath(import.meta.url);
|
||||
const isMain = process.argv[1] && fs.realpathSync(currentFilePath) === fs.realpathSync(process.argv[1]);
|
||||
|
||||
if (isMain) {
|
||||
const configPath = path.join(path.dirname(currentFilePath), 'optimization.config.json');
|
||||
runOptimization(configPath).catch(console.error);
|
||||
}
|
||||
Reference in New Issue
Block a user