feat(optimization): finalized iterative-surgical optimization suite (checkpoint)

This commit is contained in:
Abhijit Balaji
2026-03-24 14:29:05 -07:00
parent 419d674b70
commit e06a562176
7 changed files with 294 additions and 55 deletions
@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { debugLogger } from '../../../../../packages/core/src/utils/debugLogger.js';
import { debugLogger } from '../../../../packages/core/src/utils/debugLogger.js';
import { DEFAULT_EVAL_CONFIG } from '../config.js';
import { MetricObjective } from '../types.js';
import type { MetricResult } from '../types.js';
@@ -7,77 +7,62 @@
import { describe, it, expect } from 'vitest';
import { evaluateToolAlignment } from './toolAlignment.js';
import { MetricObjective } from '../types.js';
import type { Scenario } from '../schema.js';
describe('evaluateToolAlignment', () => {
const mockScenario: Scenario = {
const mockScenario = {
id: 'test-scenario',
metadata: { tags: ['test'], created_at: '2026-03-02' },
input: { user_query: 'test query' },
expected: {
tool_calls: [{ name: 'read_file', arguments: { file_path: 'test.ts' } }],
rationale: 'Testing alignment',
tool_calls: [{ name: 'test_tool', arguments: { arg: 1 } }],
},
negatives: [
{
tool_calls: [
{ name: 'run_shell_command', arguments: { command: 'cat test.ts' } },
],
reason: 'Avoid shell',
tool_calls: [{ name: 'shell', arguments: { cmd: 'rm -rf' } }],
reason: 'Matched negative shell pattern',
severity: 'high',
},
}
],
};
} as any;
it('should return 1.0 for a perfect match', () => {
it('should return 1.0 for a perfect functional match', () => {
const prediction = {
tool_calls: [{ name: 'read_file', arguments: { file_path: 'test.ts' } }],
tool_calls: [{ name: 'test_tool', arguments: { arg: 1 } }],
};
const result = evaluateToolAlignment(prediction, mockScenario);
expect(result.score).toBe(1.0);
expect(result.objective).toBe(MetricObjective.ALIGNMENT);
expect(result.reason).toContain('Functional Success');
});
it('should return 0.0 for a hard failure (negative match)', () => {
const prediction = {
tool_calls: [
{ name: 'run_shell_command', arguments: { command: 'cat test.ts' } },
],
tool_calls: [{ name: 'shell', arguments: { cmd: 'rm -rf' } }],
};
const result = evaluateToolAlignment(prediction, mockScenario);
expect(result.score).toBe(0.0);
expect(result.reason).toContain('Hard Failure');
expect(result.metadata?.['matchedNegativeReason']).toBe('Avoid shell');
expect(result.reason).toContain('Matched negative shell pattern');
});
it('should return 0.1 for an incorrect tool selection', () => {
const prediction = {
tool_calls: [
{
name: 'write_file',
arguments: { file_path: 'test.ts', content: 'test' },
},
],
tool_calls: [{ name: 'wrong_tool', arguments: { arg: 1 } }],
};
const result = evaluateToolAlignment(prediction, mockScenario);
expect(result.score).toBe(0.1);
expect(result.reason).toContain('wrong tool');
});
it('should return 0.4 for correct tool but wrong arguments', () => {
const prediction = {
tool_calls: [{ name: 'read_file', arguments: { file_path: 'wrong.ts' } }],
tool_calls: [{ name: 'test_tool', arguments: { arg: 999 } }],
};
const result = evaluateToolAlignment(prediction, mockScenario);
expect(result.score).toBe(0.4);
expect(result.reason).toContain('arguments are incorrect');
});
it('should return 0.1 for an empty tool call list', () => {
const prediction = { tool_calls: [] };
const prediction = {
tool_calls: [],
};
const result = evaluateToolAlignment(prediction, mockScenario);
expect(result.score).toBe(0.1);
expect(result.reason).toContain('failed to produce any tool calls');
});
});
@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { debugLogger } from '../../../../../packages/core/src/utils/debugLogger.js';
import { debugLogger } from '../../../../packages/core/src/utils/debugLogger.js';
import type { Scenario, ToolCall } from '../schema.js';
import { DEFAULT_EVAL_CONFIG } from '../config.js';
import { MetricObjective } from '../types.js';
@@ -12,7 +12,7 @@ import type { MetricResult } from '../types.js';
/**
* Evaluates the alignment of a model's predicted tool calls against a golden scenario.
* Focuses on accuracy and shell avoidance.
* Focuses strictly on functional correctness (tool selection and argument precision).
*/
export function evaluateToolAlignment(
prediction: { tool_calls: ToolCall[] },
@@ -25,6 +25,7 @@ export function evaluateToolAlignment(
debugLogger.debug(`[Eval:${scenarioId}] Evaluating tool alignment...`);
// 1. Check for Hard Failures (Explicit Negatives)
// These are for specific "Forbidden" tool uses (e.g., using shell instead of read_file)
for (const negative of negatives) {
const isNegativeMatch = negative.tool_calls.every((negCall: ToolCall) =>
predictedCalls.some(
@@ -35,26 +36,17 @@ export function evaluateToolAlignment(
);
if (isNegativeMatch && negative.tool_calls.length > 0) {
debugLogger.debug(
`[Eval:${scenarioId}] Hard Failure: Matched negative pattern.`,
);
return {
score: config.hardFailureScore,
objective: MetricObjective.ALIGNMENT,
reason: `Hard Failure: ${negative.reason}`,
metadata: {
matchedNegativeReason: negative.reason,
severity: negative.severity,
},
metadata: { matchedNegativeReason: negative.reason },
};
}
}
// 2. Structural Check
if (predictedCalls.length === 0) {
debugLogger.debug(
`[Eval:${scenarioId}] Invalid Response: No tool calls found.`,
);
return {
score: config.invalidResponseScore,
objective: MetricObjective.ALIGNMENT,
@@ -71,9 +63,6 @@ export function evaluateToolAlignment(
);
if (!namesMatch) {
debugLogger.debug(
`[Eval:${scenarioId}] Failure: Incorrect tool selection.`,
);
return {
score: config.invalidResponseScore,
objective: MetricObjective.ALIGNMENT,
@@ -91,9 +80,6 @@ export function evaluateToolAlignment(
);
if (!argsMatch) {
debugLogger.debug(
`[Eval:${scenarioId}] Partial Success: Right tool, wrong arguments.`,
);
return {
score: config.toolNameMatchOnlyScore,
objective: MetricObjective.ALIGNMENT,
@@ -102,14 +88,10 @@ export function evaluateToolAlignment(
}
// 4. Perfect Success
debugLogger.debug(
`[Eval:${scenarioId}] Perfect Functional Alignment achieved.`,
);
return {
score: config.functionalSuccessScore,
objective: MetricObjective.ALIGNMENT,
reason:
'Functional Success: Tool and arguments align perfectly with golden scenario.',
reason: 'Functional Success: Tool and arguments align perfectly.',
};
}
@@ -0,0 +1,22 @@
{
"models": {
"student": {
"provider": "google-gemini",
"modelId": "gemini-3.1-pro-preview"
},
"teacher": {
"provider": "google-gemini",
"modelId": "gemini-3.1-pro-preview"
}
},
"gepa": {
"numTrials": 5,
"minibatch": true,
"maxMetricCalls": 10
},
"paths": {
"scenarios": "data/tool_alignment.jsonl",
"targets": "data/optimization/targets.json",
"outputDir": "data/optimization"
}
}
+37
View File
@@ -0,0 +1,37 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import { maskVariables, unmaskVariables } from './masking.js';
vi.mock('node:fs');
describe('Optimization Pipeline Infrastructure', () => {
// --- Masking Tests ---
describe('Masking Utility', () => {
it('should mask unique template variables and restore them', () => {
const input =
'Use ${TOOL_A} to read ${FILE_PATH}. ${TOOL_A} is efficient.';
const { maskedText, maskMap } = maskVariables(input);
expect(maskedText).toContain('[[GCLI_VAR_0]]');
expect(maskedText).not.toContain('${TOOL_A}');
const restored = unmaskVariables(maskedText, maskMap);
expect(restored).toBe(input);
});
it('should handle text with no variables', () => {
const input = 'Static text.';
const { maskedText, maskMap } = maskVariables(input);
expect(maskedText).toBe(input);
expect(Object.keys(maskMap).length).toBe(0);
});
});
// Note: Extraction tests remain in extract.test.ts
// Optimization logic is verified via dry runs and Pareto frontier outputs.
});
+213
View File
@@ -0,0 +1,213 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'node:fs';
import * as path from 'node:path';
import { fileURLToPath } from 'node:url';
import { ai, AxGEPA, ax } from '@ax-llm/ax';
import { evaluateToolAlignment } from './evals/metrics/toolAlignment.js';
import { evaluateBrevity } from './evals/metrics/brevityMetric.js';
import type { Scenario, ToolCall } from './evals/schema.js';
interface OptimizationConfig {
models: {
student: { provider: string; modelId: string };
teacher: { provider: string; modelId: string };
};
gepa: {
numTrials?: number;
minibatch?: boolean;
maxMetricCalls?: number;
};
paths?: {
scenarios?: string;
targets?: string;
outputDir?: string;
};
}
interface AxPrediction {
tool_calls?: ToolCall[];
output_text?: string;
}
interface MetricArgs {
prediction: AxPrediction;
example: Scenario;
}
let currentCallCount = 0;
let maxCallsExpected = 0;
/**
* multiObjectiveMetric: Evaluates model performance with structured logging.
*/
function multiObjectiveMetric({ prediction, example }: MetricArgs): Record<string, number> {
currentCallCount++;
const modelOutput = {
tool_calls: prediction.tool_calls || [],
output_text: prediction.output_text || '',
};
const alignment = evaluateToolAlignment(modelOutput, example);
const brevity = evaluateBrevity(modelOutput);
// eslint-disable-next-line no-console
console.log(`\n[ EVAL: ${currentCallCount}/${maxCallsExpected} | ${example.id} ]`);
// eslint-disable-next-line no-console
console.log(`Scores: Acc=${alignment.score.toFixed(2)} | Brev=${brevity.score.toFixed(2)}`);
return {
accuracy: alignment.score,
brevity: brevity.score,
};
}
/**
* Evolve a specific target snippet using GEPA.
*/
async function evolveTarget(
id: string,
allTargets: any[],
scenarios: any[],
config: OptimizationConfig,
apiKey: string,
outputDir: string
) {
// eslint-disable-next-line no-console
console.log(`\n🎯 TARGETED EVOLUTION: ${id}`);
const target = allTargets.find((t) => t.id === id);
const backgroundContext = allTargets
.filter((t) => t.id !== id)
.map((t) => `\n### ${t.id}\n${t.maskedText}`)
.join('\n');
const student = ai({
name: config.models.student.provider,
apiKey,
config: { model: config.models.student.modelId },
});
const teacher = ai({
name: config.models.teacher.provider,
apiKey,
config: { model: config.models.teacher.modelId },
});
// Standard field names to avoid signature validation errors.
const gcliProgram = ax(
'user_query:string, platform:string, tags:string[], background_context:string -> tool_calls:json, output_text:string',
{ instructions: target.maskedText }
);
const dataset = scenarios.map(s => ({ ...s, background_context: backgroundContext }));
const optimizer = new AxGEPA({
studentAI: student,
teacherAI: teacher,
numTrials: config.gepa.numTrials || 16,
minibatch: config.gepa.minibatch !== false,
verbose: true,
});
currentCallCount = 0;
maxCallsExpected = config.gepa.maxMetricCalls || 100;
const result = (await optimizer.compile(
gcliProgram,
dataset,
multiObjectiveMetric,
{
maxMetricCalls: maxCallsExpected,
}
)) as any;
// Save to consolidated registry
const resultsPath = path.join(outputDir, 'results.json');
let registry: Record<string, any> = {};
if (fs.existsSync(resultsPath)) {
registry = JSON.parse(fs.readFileSync(resultsPath, 'utf8'));
}
// The 'instruction' field in result.optimizedProgram contains the winner for the mutable part
const optimizedText = result.optimizedProgram?.instruction || "ERROR_EXTRACTING";
registry[id] = {
timestamp: new Date().toISOString(),
bestScore: result.optimizedProgram?.bestScore,
optimizedText,
stats: result.stats,
report: result.report,
paretoFront: result.paretoFront?.map((entry: any) => ({
scores: entry.scores,
isBest: entry.isBest,
text: entry.instruction || entry.program?.instruction
}))
};
fs.writeFileSync(resultsPath, JSON.stringify(registry, null, 2));
// eslint-disable-next-line no-console
console.log(`✅ Evolution complete for ${id}.`);
}
/**
* Main Optimization Runner.
*/
export async function runOptimization(configPath: string) {
const apiKey = process.env.GEMINI_API_KEY;
if (!apiKey) throw new Error('GEMINI_API_KEY is not set.');
if (!fs.existsSync(configPath)) throw new Error(`Config not found: ${configPath}`);
const config = JSON.parse(fs.readFileSync(configPath, 'utf8')) as OptimizationConfig;
const scenariosPath = config.paths?.scenarios || 'data/tool_alignment.jsonl';
const targetsPath = config.paths?.targets || 'data/optimization/targets.json';
const outputDir = config.paths?.outputDir || 'data/optimization';
if (!fs.existsSync(targetsPath)) throw new Error(`Targets file not found: ${targetsPath}`);
const allTargets = JSON.parse(fs.readFileSync(targetsPath, 'utf8'));
const scenarios = fs
.readFileSync(scenariosPath, 'utf8')
.split('\n')
.filter(Boolean)
.map((line) => {
const data = JSON.parse(line);
return {
id: data.id,
user_query: data.input.user_query,
platform: data.metadata?.platform || 'unknown',
tags: data.metadata?.tags || [],
expected: data.expected,
negatives: data.negatives || [],
};
});
// Iterative mode over ALL targets
const targetsToOptimize = allTargets;
if (config.gepa.numTrials === 1) {
// eslint-disable-next-line no-console
console.log('🧪 Micro-Trial detected: Optimizing first target.');
await evolveTarget(targetsToOptimize[0].id, allTargets, scenarios.slice(0, 2), config, apiKey, outputDir);
} else {
for (const t of targetsToOptimize) {
await evolveTarget(t.id, allTargets, scenarios, config, apiKey, outputDir);
}
}
}
// CLI Entrypoint
const currentFilePath = fileURLToPath(import.meta.url);
const isMain = process.argv[1] && fs.realpathSync(currentFilePath) === fs.realpathSync(process.argv[1]);
if (isMain) {
const configPath = path.join(path.dirname(currentFilePath), 'optimization.config.json');
runOptimization(configPath).catch(console.error);
}