feat: add role-specific statistics to telemetry and UI (cont. #15234) (#18824)

Co-authored-by: Yuna Seol <yunaseol@google.com>
This commit is contained in:
Yuna Seol
2026-02-17 12:32:30 -05:00
committed by GitHub
parent 14aabbbe8b
commit 8aca3068cf
51 changed files with 826 additions and 20 deletions
@@ -11,7 +11,7 @@ import * as SessionContext from '../contexts/SessionContext.js';
import * as SettingsContext from '../contexts/SettingsContext.js'; import * as SettingsContext from '../contexts/SettingsContext.js';
import type { LoadedSettings } from '../../config/settings.js'; import type { LoadedSettings } from '../../config/settings.js';
import type { SessionMetrics } from '../contexts/SessionContext.js'; import type { SessionMetrics } from '../contexts/SessionContext.js';
import { ToolCallDecision } from '@google/gemini-cli-core'; import { ToolCallDecision, LlmRole } from '@google/gemini-cli-core';
// Mock the context to provide controlled data for testing // Mock the context to provide controlled data for testing
vi.mock('../contexts/SessionContext.js', async (importOriginal) => { vi.mock('../contexts/SessionContext.js', async (importOriginal) => {
@@ -118,6 +118,7 @@ describe('<ModelStatsDisplay />', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -160,6 +161,7 @@ describe('<ModelStatsDisplay />', () => {
thoughts: 2, thoughts: 2,
tool: 0, tool: 0,
}, },
roles: {},
}, },
'gemini-2.5-flash': { 'gemini-2.5-flash': {
api: { totalRequests: 1, totalErrors: 0, totalLatencyMs: 50 }, api: { totalRequests: 1, totalErrors: 0, totalLatencyMs: 50 },
@@ -172,6 +174,7 @@ describe('<ModelStatsDisplay />', () => {
thoughts: 0, thoughts: 0,
tool: 3, tool: 3,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -214,6 +217,7 @@ describe('<ModelStatsDisplay />', () => {
thoughts: 10, thoughts: 10,
tool: 5, tool: 5,
}, },
roles: {},
}, },
'gemini-2.5-flash': { 'gemini-2.5-flash': {
api: { totalRequests: 20, totalErrors: 2, totalLatencyMs: 500 }, api: { totalRequests: 20, totalErrors: 2, totalLatencyMs: 500 },
@@ -226,6 +230,7 @@ describe('<ModelStatsDisplay />', () => {
thoughts: 20, thoughts: 20,
tool: 10, tool: 10,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -271,6 +276,7 @@ describe('<ModelStatsDisplay />', () => {
thoughts: 111111111, thoughts: 111111111,
tool: 222222222, tool: 222222222,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -309,6 +315,7 @@ describe('<ModelStatsDisplay />', () => {
thoughts: 2, thoughts: 2,
tool: 1, tool: 1,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -351,6 +358,7 @@ describe('<ModelStatsDisplay />', () => {
thoughts: 100, thoughts: 100,
tool: 50, tool: 50,
}, },
roles: {},
}, },
'gemini-3-flash-preview': { 'gemini-3-flash-preview': {
api: { totalRequests: 20, totalErrors: 0, totalLatencyMs: 1000 }, api: { totalRequests: 20, totalErrors: 0, totalLatencyMs: 1000 },
@@ -363,6 +371,7 @@ describe('<ModelStatsDisplay />', () => {
thoughts: 200, thoughts: 200,
tool: 100, tool: 100,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -390,6 +399,64 @@ describe('<ModelStatsDisplay />', () => {
const output = lastFrame(); const output = lastFrame();
expect(output).toContain('gemini-3-pro-'); expect(output).toContain('gemini-3-pro-');
expect(output).toContain('gemini-3-flash-'); expect(output).toContain('gemini-3-flash-');
});
it('should display role breakdown correctly', () => {
const { lastFrame } = renderWithMockedStats({
models: {
'gemini-2.5-pro': {
api: { totalRequests: 2, totalErrors: 0, totalLatencyMs: 200 },
tokens: {
input: 20,
prompt: 30,
candidates: 40,
total: 70,
cached: 10,
thoughts: 0,
tool: 0,
},
roles: {
[LlmRole.MAIN]: {
totalRequests: 1,
totalErrors: 0,
totalLatencyMs: 100,
tokens: {
input: 10,
prompt: 15,
candidates: 20,
total: 35,
cached: 5,
thoughts: 0,
tool: 0,
},
},
},
},
},
tools: {
totalCalls: 0,
totalSuccess: 0,
totalFail: 0,
totalDurationMs: 0,
totalDecisions: {
accept: 0,
reject: 0,
modify: 0,
[ToolCallDecision.AUTO_ACCEPT]: 0,
},
byName: {},
},
files: {
totalLinesAdded: 0,
totalLinesRemoved: 0,
},
});
const output = lastFrame();
expect(output).toContain('main');
expect(output).toContain('Input');
expect(output).toContain('Output');
expect(output).toContain('Cache Reads');
expect(output).toMatchSnapshot(); expect(output).toMatchSnapshot();
}); });
@@ -427,6 +494,7 @@ describe('<ModelStatsDisplay />', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -462,4 +530,121 @@ describe('<ModelStatsDisplay />', () => {
expect(output).toContain('Tier:'); expect(output).toContain('Tier:');
expect(output).toContain('Pro'); expect(output).toContain('Pro');
}); });
it('should handle long role name layout', () => {
  // UTILITY_LOOP_DETECTOR is used as the long, valid role name for this layout check.
  const longRoleName = LlmRole.UTILITY_LOOP_DETECTOR;

  // The model totals and the single role report the same token counts, so the
  // numbers are declared once and copied into each place they are needed.
  const tokenCounts = {
    input: 10,
    prompt: 10,
    candidates: 20,
    total: 30,
    cached: 0,
    thoughts: 0,
    tool: 0,
  };

  const { lastFrame } = renderWithMockedStats({
    models: {
      'gemini-2.5-pro': {
        api: { totalRequests: 1, totalErrors: 0, totalLatencyMs: 100 },
        tokens: { ...tokenCounts },
        roles: {
          [longRoleName]: {
            totalRequests: 1,
            totalErrors: 0,
            totalLatencyMs: 100,
            tokens: { ...tokenCounts },
          },
        },
      },
    },
    // No tool or file activity: this test is only about the role rows.
    tools: {
      totalCalls: 0,
      totalSuccess: 0,
      totalFail: 0,
      totalDurationMs: 0,
      totalDecisions: {
        accept: 0,
        reject: 0,
        modify: 0,
        [ToolCallDecision.AUTO_ACCEPT]: 0,
      },
      byName: {},
    },
    files: {
      totalLinesAdded: 0,
      totalLinesRemoved: 0,
    },
  });

  const frame = lastFrame();
  // The role name must appear in the rendered frame, and the overall layout is
  // pinned by the snapshot.
  expect(frame).toContain(longRoleName);
  expect(frame).toMatchSnapshot();
});
it('should filter out invalid role names', () => {
  // A string that is not a member of LlmRole; the assertion forces it past the
  // type checker so the component's runtime filtering is what gets exercised.
  const invalidRoleName =
    'this_is_a_very_long_role_name_that_should_be_wrapped' as LlmRole;

  // Shared token numbers for the model totals and the (invalid) role entry.
  const tokenCounts = {
    input: 10,
    prompt: 10,
    candidates: 20,
    total: 30,
    cached: 0,
    thoughts: 0,
    tool: 0,
  };

  const { lastFrame } = renderWithMockedStats({
    models: {
      'gemini-2.5-pro': {
        api: { totalRequests: 1, totalErrors: 0, totalLatencyMs: 100 },
        tokens: { ...tokenCounts },
        roles: {
          [invalidRoleName]: {
            totalRequests: 1,
            totalErrors: 0,
            totalLatencyMs: 100,
            tokens: { ...tokenCounts },
          },
        },
      },
    },
    // No tool or file activity: this test is only about role filtering.
    tools: {
      totalCalls: 0,
      totalSuccess: 0,
      totalFail: 0,
      totalDurationMs: 0,
      totalDecisions: {
        accept: 0,
        reject: 0,
        modify: 0,
        [ToolCallDecision.AUTO_ACCEPT]: 0,
      },
      byName: {},
    },
    files: {
      totalLinesAdded: 0,
      totalLinesRemoved: 0,
    },
  });

  const frame = lastFrame();
  // The unknown role must not be rendered at all.
  expect(frame).not.toContain(invalidRoleName);
  expect(frame).toMatchSnapshot();
});
}); });
@@ -13,10 +13,17 @@ import {
calculateCacheHitRate, calculateCacheHitRate,
calculateErrorRate, calculateErrorRate,
} from '../utils/computeStats.js'; } from '../utils/computeStats.js';
import { useSessionStats } from '../contexts/SessionContext.js'; import {
useSessionStats,
type ModelMetrics,
} from '../contexts/SessionContext.js';
import { Table, type Column } from './Table.js'; import { Table, type Column } from './Table.js';
import { useSettings } from '../contexts/SettingsContext.js'; import { useSettings } from '../contexts/SettingsContext.js';
import { getDisplayString, isAutoModel } from '@google/gemini-cli-core'; import {
getDisplayString,
isAutoModel,
LlmRole,
} from '@google/gemini-cli-core';
import type { QuotaStats } from '../types.js'; import type { QuotaStats } from '../types.js';
import { QuotaStatsInfo } from './QuotaStatsInfo.js'; import { QuotaStatsInfo } from './QuotaStatsInfo.js';
@@ -25,9 +32,11 @@ interface StatRowData {
isSection?: boolean; isSection?: boolean;
isSubtle?: boolean; isSubtle?: boolean;
// Dynamic keys for model values // Dynamic keys for model values
[key: string]: string | React.ReactNode | boolean | undefined; [key: string]: string | React.ReactNode | boolean | undefined | number;
} }
/**
 * Metrics recorded for a single role of a single model: the value type of
 * `ModelMetrics['roles']` when indexed by an `LlmRole`, with `undefined`
 * stripped at both levels (the `roles` map itself and the per-role entry).
 */
type RoleMetrics = NonNullable<NonNullable<ModelMetrics['roles']>[LlmRole]>;
interface ModelStatsDisplayProps { interface ModelStatsDisplayProps {
selectedAuthType?: string; selectedAuthType?: string;
userEmail?: string; userEmail?: string;
@@ -81,6 +90,22 @@ export const ModelStatsDisplay: React.FC<ModelStatsDisplayProps> = ({
([, metrics]) => metrics.tokens.cached > 0, ([, metrics]) => metrics.tokens.cached > 0,
); );
// Every value of the LlmRole enum, built once so the filter below does not
// rebuild the list for each candidate role.
const knownRoles = new Set<string>(Object.values(LlmRole));

/**
 * Distinct roles reported by any active model, restricted to values that are
 * members of LlmRole. LlmRole.MAIN is ordered first; the remaining roles
 * follow in `localeCompare` order.
 */
const allRoles = [
  // The Set removes roles that appear under more than one model.
  ...new Set(
    activeModels.flatMap(([, metrics]) => Object.keys(metrics.roles ?? {})),
  ),
]
  // Drop keys that are not recognised roles; the predicate narrows the
  // element type from string to LlmRole.
  .filter((role): role is LlmRole => knownRoles.has(role))
  .sort((a, b) => {
    if (a === b) return 0;
    if (a === LlmRole.MAIN) return -1;
    if (b === LlmRole.MAIN) return 1;
    return a.localeCompare(b);
  });
// Helper to create a row with values for each model // Helper to create a row with values for each model
const createRow = ( const createRow = (
metric: string, metric: string,
@@ -204,6 +229,60 @@ export const ModelStatsDisplay: React.FC<ModelStatsDisplayProps> = ({
), ),
); );
// Roles Section: for every known role, one header row followed by one row per
// role-level metric, each holding a value for every active model.
if (allRoles.length > 0) {
  // Blank spacer row, then the section title.
  rows.push({ metric: '' });
  rows.push({ metric: 'Roles', isSection: true });

  // Metrics shown under each role, in display order.
  const roleMetricRows: Array<{
    label: string;
    render: (r: RoleMetrics) => string | React.ReactNode;
  }> = [
    { label: 'Requests', render: (r) => r.totalRequests.toLocaleString() },
    {
      label: 'Input',
      render: (r) => (
        <Text color={theme.text.primary}>
          {r.tokens.input.toLocaleString()}
        </Text>
      ),
    },
    {
      label: 'Output',
      render: (r) => (
        <Text color={theme.text.primary}>
          {r.tokens.candidates.toLocaleString()}
        </Text>
      ),
    },
    {
      label: 'Cache Reads',
      render: (r) => (
        <Text color={theme.text.secondary}>
          {r.tokens.cached.toLocaleString()}
        </Text>
      ),
    },
  ];

  for (const role of allRoles) {
    // Header row carries only the role name; no per-model values are set.
    const headerRow: StatRowData = {
      metric: role,
      isSection: true,
      color: theme.text.primary,
    };
    rows.push(headerRow);

    for (const { label, render } of roleMetricRows) {
      const metricRow: StatRowData = { metric: label, isSubtle: true };
      for (const [name, metrics] of activeModels) {
        const roleMetrics = metrics.roles?.[role];
        // Models with no data for this role show a dash placeholder.
        metricRow[name] = roleMetrics ? (
          render(roleMetrics)
        ) : (
          <Text color={theme.text.secondary}>-</Text>
        );
      }
      rows.push(metricRow);
    }
  }
}
const columns: Array<Column<StatRowData>> = [ const columns: Array<Column<StatRowData>> = [
{ {
key: 'metric', key: 'metric',
@@ -55,6 +55,7 @@ describe('<SessionSummaryDisplay />', () => {
thoughts: 300, thoughts: 300,
tool: 200, tool: 200,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -93,6 +93,7 @@ describe('<StatsDisplay />', () => {
thoughts: 100, thoughts: 100,
tool: 50, tool: 50,
}, },
roles: {},
}, },
'gemini-2.5-flash': { 'gemini-2.5-flash': {
api: { totalRequests: 5, totalErrors: 1, totalLatencyMs: 4500 }, api: { totalRequests: 5, totalErrors: 1, totalLatencyMs: 4500 },
@@ -105,6 +106,7 @@ describe('<StatsDisplay />', () => {
thoughts: 2000, thoughts: 2000,
tool: 1000, tool: 1000,
}, },
roles: {},
}, },
}, },
}); });
@@ -133,6 +135,7 @@ describe('<StatsDisplay />', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -227,6 +230,7 @@ describe('<StatsDisplay />', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}, },
}, },
}); });
@@ -411,6 +415,7 @@ describe('<StatsDisplay />', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}, },
}, },
}); });
@@ -44,6 +44,32 @@ exports[`<ModelStatsDisplay /> > should display conditional rows if at least one
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯" ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯"
`; `;
exports[`<ModelStatsDisplay /> > should display role breakdown correctly 1`] = `
"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮
│ │
│ Model Stats For Nerds │
│ │
│ │
│ Metric gemini-2.5-pro │
│ ────────────────────────────────────────────────────────────────────────────────────────────── │
│ API │
│ Requests 2 │
│ Errors 0 (0.0%) │
│ Avg Latency 100ms │
│ Tokens │
│ Total 70 │
│ ↳ Input 20 │
│ ↳ Cache Reads 10 (33.3%) │
│ ↳ Output 40 │
│ Roles │
│ main │
│ ↳ Requests 1 │
│ ↳ Input 10 │
│ ↳ Output 20 │
│ ↳ Cache Reads 5 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯"
`;
exports[`<ModelStatsDisplay /> > should display stats for multiple models correctly 1`] = ` exports[`<ModelStatsDisplay /> > should display stats for multiple models correctly 1`] = `
"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ "╭──────────────────────────────────────────────────────────────────────────────────────────────────╮
│ │ │ │
@@ -66,6 +92,25 @@ exports[`<ModelStatsDisplay /> > should display stats for multiple models correc
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯" ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯"
`; `;
exports[`<ModelStatsDisplay /> > should filter out invalid role names 1`] = `
"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮
│ │
│ Model Stats For Nerds │
│ │
│ │
│ Metric gemini-2.5-pro │
│ ────────────────────────────────────────────────────────────────────────────────────────────── │
│ API │
│ Requests 1 │
│ Errors 0 (0.0%) │
│ Avg Latency 100ms │
│ Tokens │
│ Total 30 │
│ ↳ Input 10 │
│ ↳ Output 20 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯"
`;
exports[`<ModelStatsDisplay /> > should handle large values without wrapping or overlapping 1`] = ` exports[`<ModelStatsDisplay /> > should handle large values without wrapping or overlapping 1`] = `
"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ "╭──────────────────────────────────────────────────────────────────────────────────────────────────╮
│ │ │ │
@@ -88,6 +133,31 @@ exports[`<ModelStatsDisplay /> > should handle large values without wrapping or
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯" ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯"
`; `;
exports[`<ModelStatsDisplay /> > should handle long role name layout 1`] = `
"╭──────────────────────────────────────────────────────────────────────────────────────────────────╮
│ │
│ Model Stats For Nerds │
│ │
│ │
│ Metric gemini-2.5-pro │
│ ────────────────────────────────────────────────────────────────────────────────────────────── │
│ API │
│ Requests 1 │
│ Errors 0 (0.0%) │
│ Avg Latency 100ms │
│ Tokens │
│ Total 30 │
│ ↳ Input 10 │
│ ↳ Output 20 │
│ Roles │
│ utility_loop_detector │
│ ↳ Requests 1 │
│ ↳ Input 10 │
│ ↳ Output 20 │
│ ↳ Cache Reads 0 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯"
`;
exports[`<ModelStatsDisplay /> > should handle models with long names (gemini-3-*-preview) without layout breaking 1`] = ` exports[`<ModelStatsDisplay /> > should handle models with long names (gemini-3-*-preview) without layout breaking 1`] = `
"╭──────────────────────────────────────────────────────────────────────────────╮ "╭──────────────────────────────────────────────────────────────────────────────╮
│ │ │ │
@@ -8,7 +8,7 @@ import { useState, useEffect, useCallback } from 'react';
import { Box, Text } from 'ink'; import { Box, Text } from 'ink';
import Spinner from 'ink-spinner'; import Spinner from 'ink-spinner';
import type { Config } from '@google/gemini-cli-core'; import type { Config } from '@google/gemini-cli-core';
import { debugLogger, spawnAsync } from '@google/gemini-cli-core'; import { debugLogger, spawnAsync, LlmRole } from '@google/gemini-cli-core';
import { useKeypress } from '../../hooks/useKeypress.js'; import { useKeypress } from '../../hooks/useKeypress.js';
import { keyMatchers, Command } from '../../keyMatchers.js'; import { keyMatchers, Command } from '../../keyMatchers.js';
@@ -279,6 +279,7 @@ Return a JSON object with:
}, },
abortSignal: new AbortController().signal, abortSignal: new AbortController().signal,
promptId: 'triage-duplicates', promptId: 'triage-duplicates',
role: LlmRole.UTILITY_TOOL,
}); });
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
@@ -8,7 +8,7 @@ import { useState, useEffect, useCallback, useRef } from 'react';
import { Box, Text } from 'ink'; import { Box, Text } from 'ink';
import Spinner from 'ink-spinner'; import Spinner from 'ink-spinner';
import type { Config } from '@google/gemini-cli-core'; import type { Config } from '@google/gemini-cli-core';
import { debugLogger, spawnAsync } from '@google/gemini-cli-core'; import { debugLogger, spawnAsync, LlmRole } from '@google/gemini-cli-core';
import { useKeypress } from '../../hooks/useKeypress.js'; import { useKeypress } from '../../hooks/useKeypress.js';
import { keyMatchers, Command } from '../../keyMatchers.js'; import { keyMatchers, Command } from '../../keyMatchers.js';
import { TextInput } from '../shared/TextInput.js'; import { TextInput } from '../shared/TextInput.js';
@@ -223,6 +223,7 @@ Return a JSON object with:
}, },
abortSignal: abortControllerRef.current.signal, abortSignal: abortControllerRef.current.signal,
promptId: 'triage-issues', promptId: 'triage-issues',
role: LlmRole.UTILITY_TOOL,
}); });
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
@@ -100,6 +100,7 @@ describe('SessionStatsContext', () => {
thoughts: 20, thoughts: 20,
tool: 10, tool: 10,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -180,6 +181,7 @@ describe('SessionStatsContext', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -6,7 +6,7 @@
import { useState, useCallback, useRef, useEffect, useMemo } from 'react'; import { useState, useCallback, useRef, useEffect, useMemo } from 'react';
import type { Config } from '@google/gemini-cli-core'; import type { Config } from '@google/gemini-cli-core';
import { debugLogger, getResponseText } from '@google/gemini-cli-core'; import { debugLogger, getResponseText, LlmRole } from '@google/gemini-cli-core';
import type { Content } from '@google/genai'; import type { Content } from '@google/genai';
import type { TextBuffer } from '../components/shared/text-buffer.js'; import type { TextBuffer } from '../components/shared/text-buffer.js';
import { isSlashCommand } from '../utils/commandUtils.js'; import { isSlashCommand } from '../utils/commandUtils.js';
@@ -110,6 +110,7 @@ export function usePromptCompletion({
{ model: 'prompt-completion' }, { model: 'prompt-completion' },
contents, contents,
signal, signal,
LlmRole.UTILITY_AUTOCOMPLETE,
); );
if (signal.aborted) { if (signal.aborted) {
@@ -29,6 +29,7 @@ describe('calculateErrorRate', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}; };
expect(calculateErrorRate(metrics)).toBe(0); expect(calculateErrorRate(metrics)).toBe(0);
}); });
@@ -45,6 +46,7 @@ describe('calculateErrorRate', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}; };
expect(calculateErrorRate(metrics)).toBe(20); expect(calculateErrorRate(metrics)).toBe(20);
}); });
@@ -63,6 +65,7 @@ describe('calculateAverageLatency', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}; };
expect(calculateAverageLatency(metrics)).toBe(0); expect(calculateAverageLatency(metrics)).toBe(0);
}); });
@@ -79,6 +82,7 @@ describe('calculateAverageLatency', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}; };
expect(calculateAverageLatency(metrics)).toBe(150); expect(calculateAverageLatency(metrics)).toBe(150);
}); });
@@ -97,6 +101,7 @@ describe('calculateCacheHitRate', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}; };
expect(calculateCacheHitRate(metrics)).toBe(0); expect(calculateCacheHitRate(metrics)).toBe(0);
}); });
@@ -113,6 +118,7 @@ describe('calculateCacheHitRate', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}; };
expect(calculateCacheHitRate(metrics)).toBe(25); expect(calculateCacheHitRate(metrics)).toBe(25);
}); });
@@ -170,6 +176,7 @@ describe('computeSessionStats', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -209,6 +216,7 @@ describe('computeSessionStats', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -25,6 +25,7 @@ import {
type GeminiChat, type GeminiChat,
type Config, type Config,
type MessageBus, type MessageBus,
LlmRole,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { import {
SettingScope, SettingScope,
@@ -588,7 +589,8 @@ describe('Session', () => {
}), }),
]), ]),
expect.anything(), expect.anything(),
expect.anything(), expect.any(AbortSignal),
LlmRole.MAIN,
); );
}); });
@@ -35,6 +35,7 @@ import {
startupProfiler, startupProfiler,
Kind, Kind,
partListUnionToString, partListUnionToString,
LlmRole,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import * as acp from '@agentclientprotocol/sdk'; import * as acp from '@agentclientprotocol/sdk';
import { AcpFileSystemService } from './fileSystemService.js'; import { AcpFileSystemService } from './fileSystemService.js';
@@ -493,6 +494,7 @@ export class Session {
nextMessage?.parts ?? [], nextMessage?.parts ?? [],
promptId, promptId,
pendingSend.signal, pendingSend.signal,
LlmRole.MAIN,
); );
nextMessage = null; nextMessage = null;
@@ -47,6 +47,7 @@ import {
logAgentFinish, logAgentFinish,
logRecoveryAttempt, logRecoveryAttempt,
} from '../telemetry/loggers.js'; } from '../telemetry/loggers.js';
import { LlmRole } from '../telemetry/types.js';
import { import {
AgentStartEvent, AgentStartEvent,
AgentFinishEvent, AgentFinishEvent,
@@ -1407,6 +1408,7 @@ describe('LocalAgentExecutor', () => {
expect.any(Array), expect.any(Array),
expect.any(String), expect.any(String),
expect.any(AbortSignal), expect.any(AbortSignal),
LlmRole.SUBAGENT,
); );
}); });
@@ -1452,6 +1454,7 @@ describe('LocalAgentExecutor', () => {
expect.any(Array), expect.any(Array),
expect.any(String), expect.any(String),
expect.any(AbortSignal), expect.any(AbortSignal),
LlmRole.SUBAGENT,
); );
}); });
}); });
@@ -59,6 +59,7 @@ import { getVersion } from '../utils/version.js';
import { getToolCallContext } from '../utils/toolCallContext.js'; import { getToolCallContext } from '../utils/toolCallContext.js';
import { scheduleAgentTools } from './agent-scheduler.js'; import { scheduleAgentTools } from './agent-scheduler.js';
import { DeadlineTimer } from '../utils/deadlineTimer.js'; import { DeadlineTimer } from '../utils/deadlineTimer.js';
import { LlmRole } from '../telemetry/types.js';
/** A callback function to report on agent activity. */ /** A callback function to report on agent activity. */
export type ActivityCallback = (activity: SubagentActivityEvent) => void; export type ActivityCallback = (activity: SubagentActivityEvent) => void;
@@ -699,6 +700,8 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
modelToUse = requestedModel; modelToUse = requestedModel;
} }
const role = LlmRole.SUBAGENT;
const responseStream = await chat.sendMessageStream( const responseStream = await chat.sendMessageStream(
{ {
model: modelToUse, model: modelToUse,
@@ -707,6 +710,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
message.parts || [], message.parts || [],
promptId, promptId,
signal, signal,
role,
); );
const functionCalls: FunctionCall[] = []; const functionCalls: FunctionCall[] = [];
@@ -9,6 +9,7 @@ import { CodeAssistServer } from './server.js';
import { OAuth2Client } from 'google-auth-library'; import { OAuth2Client } from 'google-auth-library';
import { UserTierId, ActionStatus } from './types.js'; import { UserTierId, ActionStatus } from './types.js';
import { FinishReason } from '@google/genai'; import { FinishReason } from '@google/genai';
import { LlmRole } from '../telemetry/types.js';
vi.mock('google-auth-library'); vi.mock('google-auth-library');
@@ -69,6 +70,7 @@ describe('CodeAssistServer', () => {
contents: [{ role: 'user', parts: [{ text: 'request' }] }], contents: [{ role: 'user', parts: [{ text: 'request' }] }],
}, },
'user-prompt-id', 'user-prompt-id',
LlmRole.MAIN,
); );
expect(mockRequest).toHaveBeenCalledWith({ expect(mockRequest).toHaveBeenCalledWith({
@@ -126,6 +128,7 @@ describe('CodeAssistServer', () => {
contents: [{ role: 'user', parts: [{ text: 'request' }] }], contents: [{ role: 'user', parts: [{ text: 'request' }] }],
}, },
'user-prompt-id', 'user-prompt-id',
LlmRole.MAIN,
); );
expect(recordConversationOfferedSpy).toHaveBeenCalledWith( expect(recordConversationOfferedSpy).toHaveBeenCalledWith(
@@ -170,6 +173,7 @@ describe('CodeAssistServer', () => {
contents: [{ role: 'user', parts: [{ text: 'request' }] }], contents: [{ role: 'user', parts: [{ text: 'request' }] }],
}, },
'user-prompt-id', 'user-prompt-id',
LlmRole.MAIN,
); );
expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith( expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith(
@@ -208,6 +212,7 @@ describe('CodeAssistServer', () => {
contents: [{ role: 'user', parts: [{ text: 'request' }] }], contents: [{ role: 'user', parts: [{ text: 'request' }] }],
}, },
'user-prompt-id', 'user-prompt-id',
LlmRole.MAIN,
); );
const mockResponseData = { const mockResponseData = {
@@ -369,6 +374,7 @@ describe('CodeAssistServer', () => {
contents: [{ role: 'user', parts: [{ text: 'request' }] }], contents: [{ role: 'user', parts: [{ text: 'request' }] }],
}, },
'user-prompt-id', 'user-prompt-id',
LlmRole.MAIN,
); );
// Push SSE data to the stream // Push SSE data to the stream
+5
View File
@@ -53,6 +53,7 @@ import {
recordConversationOffered, recordConversationOffered,
} from './telemetry.js'; } from './telemetry.js';
import { getClientMetadata } from './experiments/client_metadata.js'; import { getClientMetadata } from './experiments/client_metadata.js';
import type { LlmRole } from '../telemetry/types.js';
/** HTTP options to be used in each of the requests. */ /** HTTP options to be used in each of the requests. */
export interface HttpOptions { export interface HttpOptions {
/** Additional HTTP headers to be sent with the request. */ /** Additional HTTP headers to be sent with the request. */
@@ -75,6 +76,8 @@ export class CodeAssistServer implements ContentGenerator {
async generateContentStream( async generateContentStream(
req: GenerateContentParameters, req: GenerateContentParameters,
userPromptId: string, userPromptId: string,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
role: LlmRole,
): Promise<AsyncGenerator<GenerateContentResponse>> { ): Promise<AsyncGenerator<GenerateContentResponse>> {
const responses = const responses =
await this.requestStreamingPost<CaGenerateContentResponse>( await this.requestStreamingPost<CaGenerateContentResponse>(
@@ -125,6 +128,8 @@ export class CodeAssistServer implements ContentGenerator {
async generateContent( async generateContent(
req: GenerateContentParameters, req: GenerateContentParameters,
userPromptId: string, userPromptId: string,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
role: LlmRole,
): Promise<GenerateContentResponse> { ): Promise<GenerateContentResponse> {
const start = Date.now(); const start = Date.now();
const response = await this.requestPost<CaGenerateContentResponse>( const response = await this.requestPost<CaGenerateContentResponse>(
@@ -30,6 +30,7 @@ import { MalformedJsonResponseEvent } from '../telemetry/types.js';
import { getErrorMessage } from '../utils/errors.js'; import { getErrorMessage } from '../utils/errors.js';
import type { ModelConfigService } from '../services/modelConfigService.js'; import type { ModelConfigService } from '../services/modelConfigService.js';
import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js'; import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js';
import { LlmRole } from '../telemetry/types.js';
vi.mock('../utils/errorReporting.js'); vi.mock('../utils/errorReporting.js');
vi.mock('../telemetry/loggers.js'); vi.mock('../telemetry/loggers.js');
@@ -128,6 +129,7 @@ describe('BaseLlmClient', () => {
schema: { type: 'object', properties: { color: { type: 'string' } } }, schema: { type: 'object', properties: { color: { type: 'string' } } },
abortSignal: abortController.signal, abortSignal: abortController.signal,
promptId: 'test-prompt-id', promptId: 'test-prompt-id',
role: LlmRole.UTILITY_TOOL,
}; };
}); });
@@ -169,6 +171,7 @@ describe('BaseLlmClient', () => {
}, },
}, },
'test-prompt-id', 'test-prompt-id',
LlmRole.UTILITY_TOOL,
); );
}); });
@@ -191,6 +194,7 @@ describe('BaseLlmClient', () => {
}), }),
}), }),
expect.any(String), expect.any(String),
LlmRole.UTILITY_TOOL,
); );
}); });
@@ -209,6 +213,7 @@ describe('BaseLlmClient', () => {
expect(mockGenerateContent).toHaveBeenCalledWith( expect(mockGenerateContent).toHaveBeenCalledWith(
expect.any(Object), expect.any(Object),
customPromptId, customPromptId,
LlmRole.UTILITY_TOOL,
); );
}); });
@@ -528,6 +533,7 @@ describe('BaseLlmClient', () => {
contents: [{ role: 'user', parts: [{ text: 'Give me content.' }] }], contents: [{ role: 'user', parts: [{ text: 'Give me content.' }] }],
abortSignal: abortController.signal, abortSignal: abortController.signal,
promptId: 'content-prompt-id', promptId: 'content-prompt-id',
role: LlmRole.UTILITY_TOOL,
}; };
const result = await client.generateContent(options); const result = await client.generateContent(options);
@@ -556,6 +562,7 @@ describe('BaseLlmClient', () => {
}, },
}, },
'content-prompt-id', 'content-prompt-id',
LlmRole.UTILITY_TOOL,
); );
}); });
@@ -568,6 +575,7 @@ describe('BaseLlmClient', () => {
contents: [{ role: 'user', parts: [{ text: 'Give me content.' }] }], contents: [{ role: 'user', parts: [{ text: 'Give me content.' }] }],
abortSignal: abortController.signal, abortSignal: abortController.signal,
promptId: 'content-prompt-id', promptId: 'content-prompt-id',
role: LlmRole.UTILITY_TOOL,
}; };
await client.generateContent(options); await client.generateContent(options);
@@ -590,6 +598,7 @@ describe('BaseLlmClient', () => {
contents: [{ role: 'user', parts: [{ text: 'Give me content.' }] }], contents: [{ role: 'user', parts: [{ text: 'Give me content.' }] }],
abortSignal: abortController.signal, abortSignal: abortController.signal,
promptId: 'content-prompt-id', promptId: 'content-prompt-id',
role: LlmRole.UTILITY_TOOL,
}; };
await expect(client.generateContent(options)).rejects.toThrow( await expect(client.generateContent(options)).rejects.toThrow(
@@ -634,6 +643,7 @@ describe('BaseLlmClient', () => {
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }], contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
abortSignal: abortController.signal, abortSignal: abortController.signal,
promptId: 'content-prompt-id', promptId: 'content-prompt-id',
role: LlmRole.UTILITY_TOOL,
}; };
jsonOptions = { jsonOptions = {
@@ -655,6 +665,7 @@ describe('BaseLlmClient', () => {
await client.generateContent({ await client.generateContent({
...contentOptions, ...contentOptions,
modelConfigKey: { model: successfulModel }, modelConfigKey: { model: successfulModel },
role: LlmRole.UTILITY_TOOL,
}); });
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith( expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
@@ -680,6 +691,7 @@ describe('BaseLlmClient', () => {
...contentOptions, ...contentOptions,
modelConfigKey: { model: firstModel }, modelConfigKey: { model: firstModel },
maxAttempts: 2, maxAttempts: 2,
role: LlmRole.UTILITY_TOOL,
}); });
await vi.runAllTimersAsync(); await vi.runAllTimersAsync();
@@ -689,6 +701,7 @@ describe('BaseLlmClient', () => {
...contentOptions, ...contentOptions,
modelConfigKey: { model: firstModel }, modelConfigKey: { model: firstModel },
maxAttempts: 2, maxAttempts: 2,
role: LlmRole.UTILITY_TOOL,
}); });
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(firstModel); expect(mockConfig.setActiveModel).toHaveBeenCalledWith(firstModel);
@@ -699,6 +712,7 @@ describe('BaseLlmClient', () => {
expect(mockGenerateContent).toHaveBeenLastCalledWith( expect(mockGenerateContent).toHaveBeenLastCalledWith(
expect.objectContaining({ model: fallbackModel }), expect.objectContaining({ model: fallbackModel }),
expect.any(String), expect.any(String),
LlmRole.UTILITY_TOOL,
); );
}); });
@@ -724,6 +738,7 @@ describe('BaseLlmClient', () => {
await client.generateContent({ await client.generateContent({
...contentOptions, ...contentOptions,
modelConfigKey: { model: stickyModel }, modelConfigKey: { model: stickyModel },
role: LlmRole.UTILITY_TOOL,
}); });
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith( expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
@@ -763,6 +778,7 @@ describe('BaseLlmClient', () => {
expect(mockGenerateContent).toHaveBeenLastCalledWith( expect(mockGenerateContent).toHaveBeenLastCalledWith(
expect.objectContaining({ model: availableModel }), expect.objectContaining({ model: availableModel }),
jsonOptions.promptId, jsonOptions.promptId,
LlmRole.UTILITY_TOOL,
); );
}); });
@@ -814,6 +830,7 @@ describe('BaseLlmClient', () => {
...contentOptions, ...contentOptions,
modelConfigKey: { model: firstModel }, modelConfigKey: { model: firstModel },
maxAttempts: 2, maxAttempts: 2,
role: LlmRole.UTILITY_TOOL,
}); });
expect(mockGenerateContent).toHaveBeenCalledTimes(2); expect(mockGenerateContent).toHaveBeenCalledTimes(2);
+19 -1
View File
@@ -27,6 +27,7 @@ import {
applyModelSelection, applyModelSelection,
createAvailabilityContextProvider, createAvailabilityContextProvider,
} from '../availability/policyHelpers.js'; } from '../availability/policyHelpers.js';
import { LlmRole } from '../telemetry/types.js';
const DEFAULT_MAX_ATTEMPTS = 5; const DEFAULT_MAX_ATTEMPTS = 5;
@@ -51,6 +52,10 @@ export interface GenerateJsonOptions {
* A unique ID for the prompt, used for logging/telemetry correlation. * A unique ID for the prompt, used for logging/telemetry correlation.
*/ */
promptId: string; promptId: string;
/**
* The role of the LLM call.
*/
role: LlmRole;
/** /**
* The maximum number of attempts for the request. * The maximum number of attempts for the request.
*/ */
@@ -76,6 +81,10 @@ export interface GenerateContentOptions {
* A unique ID for the prompt, used for logging/telemetry correlation. * A unique ID for the prompt, used for logging/telemetry correlation.
*/ */
promptId: string; promptId: string;
/**
* The role of the LLM call.
*/
role: LlmRole;
/** /**
* The maximum number of attempts for the request. * The maximum number of attempts for the request.
*/ */
@@ -115,6 +124,7 @@ export class BaseLlmClient {
systemInstruction, systemInstruction,
abortSignal, abortSignal,
promptId, promptId,
role,
maxAttempts, maxAttempts,
} = options; } = options;
@@ -150,6 +160,7 @@ export class BaseLlmClient {
}, },
shouldRetryOnContent, shouldRetryOnContent,
'generateJson', 'generateJson',
role,
); );
// If we are here, the content is valid (not empty and parsable). // If we are here, the content is valid (not empty and parsable).
@@ -215,6 +226,7 @@ export class BaseLlmClient {
systemInstruction, systemInstruction,
abortSignal, abortSignal,
promptId, promptId,
role,
maxAttempts, maxAttempts,
} = options; } = options;
@@ -234,6 +246,7 @@ export class BaseLlmClient {
}, },
shouldRetryOnContent, shouldRetryOnContent,
'generateContent', 'generateContent',
role,
); );
} }
@@ -241,6 +254,7 @@ export class BaseLlmClient {
options: _CommonGenerateOptions, options: _CommonGenerateOptions,
shouldRetryOnContent: (response: GenerateContentResponse) => boolean, shouldRetryOnContent: (response: GenerateContentResponse) => boolean,
errorContext: 'generateJson' | 'generateContent', errorContext: 'generateJson' | 'generateContent',
role: LlmRole = LlmRole.UTILITY_TOOL,
): Promise<GenerateContentResponse> { ): Promise<GenerateContentResponse> {
const { const {
modelConfigKey, modelConfigKey,
@@ -293,7 +307,11 @@ export class BaseLlmClient {
config: finalConfig, config: finalConfig,
contents, contents,
}; };
return this.contentGenerator.generateContent(requestParams, promptId); return this.contentGenerator.generateContent(
requestParams,
promptId,
role,
);
}; };
return await retryWithBackoff(apiCall, { return await retryWithBackoff(apiCall, {
+5
View File
@@ -47,6 +47,7 @@ import type {
} from '../services/modelConfigService.js'; } from '../services/modelConfigService.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
import * as policyCatalog from '../availability/policyCatalog.js'; import * as policyCatalog from '../availability/policyCatalog.js';
import { LlmRole } from '../telemetry/types.js';
import { partToString } from '../utils/partUtils.js'; import { partToString } from '../utils/partUtils.js';
import { coreEvents } from '../utils/events.js'; import { coreEvents } from '../utils/events.js';
@@ -2913,6 +2914,7 @@ ${JSON.stringify(
{ model: 'test-model' }, { model: 'test-model' },
contents, contents,
abortSignal, abortSignal,
LlmRole.MAIN,
); );
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
@@ -2927,6 +2929,7 @@ ${JSON.stringify(
contents, contents,
}, },
'test-session-id', 'test-session-id',
LlmRole.MAIN,
); );
}); });
@@ -2938,6 +2941,7 @@ ${JSON.stringify(
{ model: initialModel }, { model: initialModel },
contents, contents,
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
@@ -2945,6 +2949,7 @@ ${JSON.stringify(
model: initialModel, model: initialModel,
}), }),
'test-session-id', 'test-session-id',
LlmRole.MAIN,
); );
}); });
+3
View File
@@ -64,6 +64,7 @@ import { resolveModel } from '../config/models.js';
import type { RetryAvailabilityContext } from '../utils/retry.js'; import type { RetryAvailabilityContext } from '../utils/retry.js';
import { partToString } from '../utils/partUtils.js'; import { partToString } from '../utils/partUtils.js';
import { coreEvents, CoreEvent } from '../utils/events.js'; import { coreEvents, CoreEvent } from '../utils/events.js';
import type { LlmRole } from '../telemetry/types.js';
const MAX_TURNS = 100; const MAX_TURNS = 100;
@@ -925,6 +926,7 @@ export class GeminiClient {
modelConfigKey: ModelConfigKey, modelConfigKey: ModelConfigKey,
contents: Content[], contents: Content[],
abortSignal: AbortSignal, abortSignal: AbortSignal,
role: LlmRole,
): Promise<GenerateContentResponse> { ): Promise<GenerateContentResponse> {
const desiredModelConfig = const desiredModelConfig =
this.config.modelConfigService.getResolvedConfig(modelConfigKey); this.config.modelConfigService.getResolvedConfig(modelConfigKey);
@@ -979,6 +981,7 @@ export class GeminiClient {
contents, contents,
}, },
this.lastPromptId, this.lastPromptId,
role,
); );
}; };
const onPersistent429Callback = async ( const onPersistent429Callback = async (
@@ -24,6 +24,7 @@ import { FakeContentGenerator } from './fakeContentGenerator.js';
import { parseCustomHeaders } from '../utils/customHeaderUtils.js'; import { parseCustomHeaders } from '../utils/customHeaderUtils.js';
import { RecordingContentGenerator } from './recordingContentGenerator.js'; import { RecordingContentGenerator } from './recordingContentGenerator.js';
import { getVersion, resolveModel } from '../../index.js'; import { getVersion, resolveModel } from '../../index.js';
import type { LlmRole } from '../telemetry/llmRole.js';
/** /**
* Interface abstracting the core functionalities for generating content and counting tokens. * Interface abstracting the core functionalities for generating content and counting tokens.
@@ -32,11 +33,13 @@ export interface ContentGenerator {
generateContent( generateContent(
request: GenerateContentParameters, request: GenerateContentParameters,
userPromptId: string, userPromptId: string,
role: LlmRole,
): Promise<GenerateContentResponse>; ): Promise<GenerateContentResponse>;
generateContentStream( generateContentStream(
request: GenerateContentParameters, request: GenerateContentParameters,
userPromptId: string, userPromptId: string,
role: LlmRole,
): Promise<AsyncGenerator<GenerateContentResponse>>; ): Promise<AsyncGenerator<GenerateContentResponse>>;
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>; countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
@@ -18,6 +18,7 @@ import {
type CountTokensParameters, type CountTokensParameters,
type EmbedContentParameters, type EmbedContentParameters,
} from '@google/genai'; } from '@google/genai';
import { LlmRole } from '../telemetry/types.js';
vi.mock('node:fs', async (importOriginal) => { vi.mock('node:fs', async (importOriginal) => {
const actual = await importOriginal<typeof import('node:fs')>(); const actual = await importOriginal<typeof import('node:fs')>();
@@ -79,6 +80,7 @@ describe('FakeContentGenerator', () => {
const response = await generator.generateContent( const response = await generator.generateContent(
{} as GenerateContentParameters, {} as GenerateContentParameters,
'id', 'id',
LlmRole.MAIN,
); );
expect(response).instanceOf(GenerateContentResponse); expect(response).instanceOf(GenerateContentResponse);
expect(response).toEqual(fakeGenerateContentResponse.response); expect(response).toEqual(fakeGenerateContentResponse.response);
@@ -91,6 +93,7 @@ describe('FakeContentGenerator', () => {
const stream = await generator.generateContentStream( const stream = await generator.generateContentStream(
{} as GenerateContentParameters, {} as GenerateContentParameters,
'id', 'id',
LlmRole.MAIN,
); );
const responses = []; const responses = [];
for await (const response of stream) { for await (const response of stream) {
@@ -121,7 +124,11 @@ describe('FakeContentGenerator', () => {
]; ];
const generator = new FakeContentGenerator(fakeResponses); const generator = new FakeContentGenerator(fakeResponses);
for (const fakeResponse of fakeResponses) { for (const fakeResponse of fakeResponses) {
const response = await generator[fakeResponse.method]({} as never, ''); const response = await generator[fakeResponse.method](
{} as never,
'',
LlmRole.MAIN,
);
if (fakeResponse.method === 'generateContentStream') { if (fakeResponse.method === 'generateContentStream') {
const responses = []; const responses = [];
for await (const item of response as AsyncGenerator<GenerateContentResponse>) { for await (const item of response as AsyncGenerator<GenerateContentResponse>) {
@@ -137,7 +144,11 @@ describe('FakeContentGenerator', () => {
it('should throw error when no more responses', async () => { it('should throw error when no more responses', async () => {
const generator = new FakeContentGenerator([fakeGenerateContentResponse]); const generator = new FakeContentGenerator([fakeGenerateContentResponse]);
await generator.generateContent({} as GenerateContentParameters, 'id'); await generator.generateContent(
{} as GenerateContentParameters,
'id',
LlmRole.MAIN,
);
await expect( await expect(
generator.embedContent({} as EmbedContentParameters), generator.embedContent({} as EmbedContentParameters),
).rejects.toThrowError('No more mock responses for embedContent'); ).rejects.toThrowError('No more mock responses for embedContent');
@@ -145,10 +156,18 @@ describe('FakeContentGenerator', () => {
generator.countTokens({} as CountTokensParameters), generator.countTokens({} as CountTokensParameters),
).rejects.toThrowError('No more mock responses for countTokens'); ).rejects.toThrowError('No more mock responses for countTokens');
await expect( await expect(
generator.generateContentStream({} as GenerateContentParameters, 'id'), generator.generateContentStream(
{} as GenerateContentParameters,
'id',
LlmRole.MAIN,
),
).rejects.toThrow('No more mock responses for generateContentStream'); ).rejects.toThrow('No more mock responses for generateContentStream');
await expect( await expect(
generator.generateContent({} as GenerateContentParameters, 'id'), generator.generateContent(
{} as GenerateContentParameters,
'id',
LlmRole.MAIN,
),
).rejects.toThrowError('No more mock responses for generateContent'); ).rejects.toThrowError('No more mock responses for generateContent');
}); });
@@ -161,6 +180,7 @@ describe('FakeContentGenerator', () => {
const response = await generator.generateContent( const response = await generator.generateContent(
{} as GenerateContentParameters, {} as GenerateContentParameters,
'id', 'id',
LlmRole.MAIN,
); );
expect(response).toEqual(fakeGenerateContentResponse.response); expect(response).toEqual(fakeGenerateContentResponse.response);
}); });
@@ -16,6 +16,7 @@ import { promises } from 'node:fs';
import type { ContentGenerator } from './contentGenerator.js'; import type { ContentGenerator } from './contentGenerator.js';
import type { UserTierId } from '../code_assist/types.js'; import type { UserTierId } from '../code_assist/types.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js'; import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import type { LlmRole } from '../telemetry/types.js';
export type FakeResponse = export type FakeResponse =
| { | {
@@ -79,6 +80,8 @@ export class FakeContentGenerator implements ContentGenerator {
async generateContent( async generateContent(
request: GenerateContentParameters, request: GenerateContentParameters,
_userPromptId: string, _userPromptId: string,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
role: LlmRole,
): Promise<GenerateContentResponse> { ): Promise<GenerateContentResponse> {
return Object.setPrototypeOf( return Object.setPrototypeOf(
this.getNextResponse('generateContent', request), this.getNextResponse('generateContent', request),
@@ -89,6 +92,8 @@ export class FakeContentGenerator implements ContentGenerator {
async generateContentStream( async generateContentStream(
request: GenerateContentParameters, request: GenerateContentParameters,
_userPromptId: string, _userPromptId: string,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
role: LlmRole,
): Promise<AsyncGenerator<GenerateContentResponse>> { ): Promise<AsyncGenerator<GenerateContentResponse>> {
const responses = this.getNextResponse('generateContentStream', request); const responses = this.getNextResponse('generateContentStream', request);
async function* stream() { async function* stream() {
+57 -3
View File
@@ -28,6 +28,7 @@ import type { ModelAvailabilityService } from '../availability/modelAvailability
import * as policyHelpers from '../availability/policyHelpers.js'; import * as policyHelpers from '../availability/policyHelpers.js';
import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js'; import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js';
import type { HookSystem } from '../hooks/hookSystem.js'; import type { HookSystem } from '../hooks/hookSystem.js';
import { LlmRole } from '../telemetry/types.js';
// Mock fs module to prevent actual file system operations during tests // Mock fs module to prevent actual file system operations during tests
const mockFileSystem = new Map<string, string>(); const mockFileSystem = new Map<string, string>();
@@ -287,6 +288,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-tool-call-empty-end', 'prompt-id-tool-call-empty-end',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
await expect( await expect(
(async () => { (async () => {
@@ -340,6 +342,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-no-finish-empty-end', 'prompt-id-no-finish-empty-end',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
await expect( await expect(
(async () => { (async () => {
@@ -387,6 +390,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-valid-then-invalid-end', 'prompt-id-valid-then-invalid-end',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
await expect( await expect(
(async () => { (async () => {
@@ -435,6 +439,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-empty-chunk-consolidation', 'prompt-id-empty-chunk-consolidation',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
// Consume the stream // Consume the stream
@@ -494,6 +499,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-multi-chunk', 'prompt-id-multi-chunk',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
// Consume the stream to trigger history recording. // Consume the stream to trigger history recording.
@@ -543,6 +549,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-mixed-chunk', 'prompt-id-mixed-chunk',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
// This loop consumes the stream. // This loop consumes the stream.
@@ -612,6 +619,7 @@ describe('GeminiChat', () => {
}, },
'prompt-id-stream-1', 'prompt-id-stream-1',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
// 4. Assert: The stream processing should throw an InvalidStreamError. // 4. Assert: The stream processing should throw an InvalidStreamError.
@@ -656,6 +664,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-1', 'prompt-id-1',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
// Should not throw an error // Should not throw an error
@@ -693,6 +702,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-1', 'prompt-id-1',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
await expect( await expect(
@@ -729,6 +739,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-1', 'prompt-id-1',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
await expect( await expect(
@@ -765,6 +776,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-1', 'prompt-id-1',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
// Should not throw an error // Should not throw an error
@@ -802,6 +814,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-id-malformed', 'prompt-id-malformed',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
// Should throw an error // Should throw an error
@@ -849,6 +862,7 @@ describe('GeminiChat', () => {
'test retry', 'test retry',
'prompt-id-retry-malformed', 'prompt-id-retry-malformed',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
for await (const event of stream) { for await (const event of stream) {
@@ -906,6 +920,7 @@ describe('GeminiChat', () => {
'hello', 'hello',
'prompt-id-1', 'prompt-id-1',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
// consume stream // consume stream
@@ -931,6 +946,7 @@ describe('GeminiChat', () => {
}, },
}, },
'prompt-id-1', 'prompt-id-1',
LlmRole.MAIN,
); );
}); });
@@ -954,6 +970,7 @@ describe('GeminiChat', () => {
'hello', 'hello',
'prompt-id-thinking-level', 'prompt-id-thinking-level',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
// consume stream // consume stream
@@ -970,6 +987,7 @@ describe('GeminiChat', () => {
}), }),
}), }),
'prompt-id-thinking-level', 'prompt-id-thinking-level',
LlmRole.MAIN,
); );
}); });
@@ -993,6 +1011,7 @@ describe('GeminiChat', () => {
'hello', 'hello',
'prompt-id-thinking-budget', 'prompt-id-thinking-budget',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
// consume stream // consume stream
@@ -1003,12 +1022,13 @@ describe('GeminiChat', () => {
model: 'gemini-2.0-flash', model: 'gemini-2.0-flash',
config: expect.objectContaining({ config: expect.objectContaining({
thinkingConfig: { thinkingConfig: {
thinkingBudget: DEFAULT_THINKING_MODE, thinkingBudget: 8192,
thinkingLevel: undefined, thinkingLevel: undefined,
}, },
}), }),
}), }),
'prompt-id-thinking-budget', 'prompt-id-thinking-budget',
LlmRole.MAIN,
); );
}); });
}); });
@@ -1060,6 +1080,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-id-no-retry', 'prompt-id-no-retry',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
await expect( await expect(
@@ -1108,6 +1129,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-yield-retry', 'prompt-id-yield-retry',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
for await (const event of stream) { for await (const event of stream) {
@@ -1150,6 +1172,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-id-retry-success', 'prompt-id-retry-success',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const chunks: StreamEvent[] = []; const chunks: StreamEvent[] = [];
for await (const chunk of stream) { for await (const chunk of stream) {
@@ -1222,6 +1245,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-retry-temperature', 'prompt-id-retry-temperature',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
@@ -1243,6 +1267,7 @@ describe('GeminiChat', () => {
}), }),
}), }),
'prompt-id-retry-temperature', 'prompt-id-retry-temperature',
LlmRole.MAIN,
); );
// Second call (retry) should have temperature 1 // Second call (retry) should have temperature 1
@@ -1256,6 +1281,7 @@ describe('GeminiChat', () => {
}), }),
}), }),
'prompt-id-retry-temperature', 'prompt-id-retry-temperature',
LlmRole.MAIN,
); );
}); });
@@ -1281,6 +1307,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-id-retry-fail', 'prompt-id-retry-fail',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
await expect(async () => { await expect(async () => {
for await (const _ of stream) { for await (const _ of stream) {
@@ -1347,6 +1374,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-400', 'prompt-id-400',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
await expect( await expect(
@@ -1386,9 +1414,11 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-429-retry', 'prompt-id-429-retry',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
for await (const event of stream) { for await (const event of stream) {
events.push(event); events.push(event);
} }
@@ -1435,9 +1465,11 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-500-retry', 'prompt-id-500-retry',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
for await (const event of stream) { for await (const event of stream) {
events.push(event); events.push(event);
} }
@@ -1492,9 +1524,11 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-fetch-error-retry', 'prompt-id-fetch-error-retry',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
for await (const event of stream) { for await (const event of stream) {
events.push(event); events.push(event);
} }
@@ -1556,6 +1590,7 @@ describe('GeminiChat', () => {
'Second question', 'Second question',
'prompt-id-retry-existing', 'prompt-id-retry-existing',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
// consume stream // consume stream
@@ -1628,6 +1663,7 @@ describe('GeminiChat', () => {
'test empty stream', 'test empty stream',
'prompt-id-empty-stream', 'prompt-id-empty-stream',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const chunks: StreamEvent[] = []; const chunks: StreamEvent[] = [];
for await (const chunk of stream) { for await (const chunk of stream) {
@@ -1709,6 +1745,7 @@ describe('GeminiChat', () => {
'first', 'first',
'prompt-1', 'prompt-1',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const firstStreamIterator = firstStream[Symbol.asyncIterator](); const firstStreamIterator = firstStream[Symbol.asyncIterator]();
await firstStreamIterator.next(); await firstStreamIterator.next();
@@ -1719,6 +1756,7 @@ describe('GeminiChat', () => {
'second', 'second',
'prompt-2', 'prompt-2',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
// 5. Assert that only one API call has been made so far. // 5. Assert that only one API call has been made so far.
@@ -1824,6 +1862,7 @@ describe('GeminiChat', () => {
'trigger 429', 'trigger 429',
'prompt-id-fb1', 'prompt-id-fb1',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
// Consume stream to trigger logic // Consume stream to trigger logic
@@ -1890,6 +1929,7 @@ describe('GeminiChat', () => {
'test message', 'test message',
'prompt-id-discard-test', 'prompt-id-discard-test',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
for await (const event of stream) { for await (const event of stream) {
@@ -2106,6 +2146,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-healthy', 'prompt-healthy',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
// consume // consume
@@ -2141,6 +2182,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-sticky-once', 'prompt-sticky-once',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
// consume // consume
@@ -2191,6 +2233,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-fallback-arg', 'prompt-fallback-arg',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
for await (const _ of stream) { for await (const _ of stream) {
// consume // consume
@@ -2269,6 +2312,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-config-refresh', 'prompt-config-refresh',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
// Consume to drive both attempts // Consume to drive both attempts
for await (const _ of stream) { for await (const _ of stream) {
@@ -2281,9 +2325,12 @@ describe('GeminiChat', () => {
1, 1,
expect.objectContaining({ expect.objectContaining({
model: 'model-a', model: 'model-a',
config: expect.objectContaining({ temperature: 0.1 }), config: expect.objectContaining({
temperature: 0.1,
}),
}), }),
expect.any(String), expect.any(String),
LlmRole.MAIN,
); );
expect( expect(
mockContentGenerator.generateContentStream, mockContentGenerator.generateContentStream,
@@ -2291,9 +2338,12 @@ describe('GeminiChat', () => {
2, 2,
expect.objectContaining({ expect.objectContaining({
model: 'model-b', model: 'model-b',
config: expect.objectContaining({ temperature: 0.9 }), config: expect.objectContaining({
temperature: 0.9,
}),
}), }),
expect.any(String), expect.any(String),
LlmRole.MAIN,
); );
}); });
}); });
@@ -2323,6 +2373,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-id', 'prompt-id',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
@@ -2353,6 +2404,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-id', 'prompt-id',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
@@ -2392,6 +2444,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-id', 'prompt-id',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
@@ -2428,6 +2481,7 @@ describe('GeminiChat', () => {
'test', 'test',
'prompt-id', 'prompt-id',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
+5
View File
@@ -55,6 +55,7 @@ import {
createAvailabilityContextProvider, createAvailabilityContextProvider,
} from '../availability/policyHelpers.js'; } from '../availability/policyHelpers.js';
import { coreEvents } from '../utils/events.js'; import { coreEvents } from '../utils/events.js';
import type { LlmRole } from '../telemetry/types.js';
export enum StreamEventType { export enum StreamEventType {
/** A regular content chunk from the API. */ /** A regular content chunk from the API. */
@@ -292,6 +293,7 @@ export class GeminiChat {
message: PartListUnion, message: PartListUnion,
prompt_id: string, prompt_id: string,
signal: AbortSignal, signal: AbortSignal,
role: LlmRole,
displayContent?: PartListUnion, displayContent?: PartListUnion,
): Promise<AsyncGenerator<StreamEvent>> { ): Promise<AsyncGenerator<StreamEvent>> {
await this.sendPromise; await this.sendPromise;
@@ -362,6 +364,7 @@ export class GeminiChat {
requestContents, requestContents,
prompt_id, prompt_id,
signal, signal,
role,
); );
isConnectionPhase = false; isConnectionPhase = false;
for await (const chunk of stream) { for await (const chunk of stream) {
@@ -467,6 +470,7 @@ export class GeminiChat {
requestContents: Content[], requestContents: Content[],
prompt_id: string, prompt_id: string,
abortSignal: AbortSignal, abortSignal: AbortSignal,
role: LlmRole,
): Promise<AsyncGenerator<GenerateContentResponse>> { ): Promise<AsyncGenerator<GenerateContentResponse>> {
const contentsForPreviewModel = const contentsForPreviewModel =
this.ensureActiveLoopHasThoughtSignatures(requestContents); this.ensureActiveLoopHasThoughtSignatures(requestContents);
@@ -599,6 +603,7 @@ export class GeminiChat {
config, config,
}, },
prompt_id, prompt_id,
role,
); );
}; };
@@ -14,6 +14,7 @@ import { setSimulate429 } from '../utils/testUtils.js';
import { HookSystem } from '../hooks/hookSystem.js'; import { HookSystem } from '../hooks/hookSystem.js';
import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js'; import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import { LlmRole } from '../telemetry/types.js';
// Mock fs module // Mock fs module
vi.mock('node:fs', async (importOriginal) => { vi.mock('node:fs', async (importOriginal) => {
@@ -154,6 +155,7 @@ describe('GeminiChat Network Retries', () => {
'test message', 'test message',
'prompt-id-retry-network', 'prompt-id-retry-network',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
@@ -223,6 +225,7 @@ describe('GeminiChat Network Retries', () => {
'test message', 'test message',
'prompt-id-retry-fetch', 'prompt-id-retry-fetch',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
@@ -263,6 +266,7 @@ describe('GeminiChat Network Retries', () => {
'test message', 'test message',
'prompt-id-no-retry', 'prompt-id-no-retry',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
await expect(async () => { await expect(async () => {
@@ -304,6 +308,7 @@ describe('GeminiChat Network Retries', () => {
'test message', 'test message',
'prompt-id-ssl-retry', 'prompt-id-ssl-retry',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
@@ -353,6 +358,7 @@ describe('GeminiChat Network Retries', () => {
'test message', 'test message',
'prompt-id-connection-retry', 'prompt-id-connection-retry',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
@@ -384,6 +390,7 @@ describe('GeminiChat Network Retries', () => {
'test message', 'test message',
'prompt-id-no-connection-retry', 'prompt-id-no-connection-retry',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
await expect(async () => { await expect(async () => {
@@ -438,6 +445,7 @@ describe('GeminiChat Network Retries', () => {
'test message', 'test message',
'prompt-id-ssl-mid-stream', 'prompt-id-ssl-mid-stream',
new AbortController().signal, new AbortController().signal,
LlmRole.MAIN,
); );
const events: StreamEvent[] = []; const events: StreamEvent[] = [];
@@ -30,8 +30,8 @@ import type {
import type { ContentGenerator } from './contentGenerator.js'; import type { ContentGenerator } from './contentGenerator.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js'; import { LoggingContentGenerator } from './loggingContentGenerator.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import { ApiRequestEvent } from '../telemetry/types.js';
import { UserTierId } from '../code_assist/types.js'; import { UserTierId } from '../code_assist/types.js';
import { ApiRequestEvent, LlmRole } from '../telemetry/types.js';
describe('LoggingContentGenerator', () => { describe('LoggingContentGenerator', () => {
let wrapped: ContentGenerator; let wrapped: ContentGenerator;
@@ -89,13 +89,18 @@ describe('LoggingContentGenerator', () => {
const promise = loggingContentGenerator.generateContent( const promise = loggingContentGenerator.generateContent(
req, req,
userPromptId, userPromptId,
LlmRole.MAIN,
); );
vi.advanceTimersByTime(1000); vi.advanceTimersByTime(1000);
await promise; await promise;
expect(wrapped.generateContent).toHaveBeenCalledWith(req, userPromptId); expect(wrapped.generateContent).toHaveBeenCalledWith(
req,
userPromptId,
LlmRole.MAIN,
);
expect(logApiRequest).toHaveBeenCalledWith( expect(logApiRequest).toHaveBeenCalledWith(
config, config,
expect.any(ApiRequestEvent), expect.any(ApiRequestEvent),
@@ -118,6 +123,7 @@ describe('LoggingContentGenerator', () => {
const promise = loggingContentGenerator.generateContent( const promise = loggingContentGenerator.generateContent(
req, req,
userPromptId, userPromptId,
LlmRole.MAIN,
); );
vi.advanceTimersByTime(1000); vi.advanceTimersByTime(1000);
@@ -156,12 +162,17 @@ describe('LoggingContentGenerator', () => {
vi.mocked(wrapped.generateContentStream).mockResolvedValue( vi.mocked(wrapped.generateContentStream).mockResolvedValue(
createAsyncGenerator(), createAsyncGenerator(),
); );
const startTime = new Date('2025-01-01T00:00:00.000Z'); const startTime = new Date('2025-01-01T00:00:00.000Z');
vi.setSystemTime(startTime); vi.setSystemTime(startTime);
const stream = await loggingContentGenerator.generateContentStream( const stream = await loggingContentGenerator.generateContentStream(
req, req,
userPromptId, userPromptId,
LlmRole.MAIN,
); );
vi.advanceTimersByTime(1000); vi.advanceTimersByTime(1000);
@@ -173,6 +184,7 @@ describe('LoggingContentGenerator', () => {
expect(wrapped.generateContentStream).toHaveBeenCalledWith( expect(wrapped.generateContentStream).toHaveBeenCalledWith(
req, req,
userPromptId, userPromptId,
LlmRole.MAIN,
); );
expect(logApiRequest).toHaveBeenCalledWith( expect(logApiRequest).toHaveBeenCalledWith(
config, config,
@@ -203,6 +215,7 @@ describe('LoggingContentGenerator', () => {
const stream = await loggingContentGenerator.generateContentStream( const stream = await loggingContentGenerator.generateContentStream(
req, req,
userPromptId, userPromptId,
LlmRole.MAIN,
); );
vi.advanceTimersByTime(1000); vi.advanceTimersByTime(1000);
@@ -240,6 +253,7 @@ describe('LoggingContentGenerator', () => {
await loggingContentGenerator.generateContentStream( await loggingContentGenerator.generateContentStream(
req, req,
mainAgentPromptId, mainAgentPromptId,
LlmRole.MAIN,
); );
expect(config.setLatestApiRequest).toHaveBeenCalledWith(req); expect(config.setLatestApiRequest).toHaveBeenCalledWith(req);
@@ -264,6 +278,7 @@ describe('LoggingContentGenerator', () => {
await loggingContentGenerator.generateContentStream( await loggingContentGenerator.generateContentStream(
req, req,
subAgentPromptId, subAgentPromptId,
LlmRole.SUBAGENT,
); );
expect(config.setLatestApiRequest).not.toHaveBeenCalled(); expect(config.setLatestApiRequest).not.toHaveBeenCalled();
@@ -22,6 +22,7 @@ import {
ApiResponseEvent, ApiResponseEvent,
ApiErrorEvent, ApiErrorEvent,
} from '../telemetry/types.js'; } from '../telemetry/types.js';
import type { LlmRole } from '../telemetry/llmRole.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import type { UserTierId } from '../code_assist/types.js'; import type { UserTierId } from '../code_assist/types.js';
import { import {
@@ -65,6 +66,7 @@ export class LoggingContentGenerator implements ContentGenerator {
contents: Content[], contents: Content[],
model: string, model: string,
promptId: string, promptId: string,
role: LlmRole,
generationConfig?: GenerateContentConfig, generationConfig?: GenerateContentConfig,
serverDetails?: ServerDetails, serverDetails?: ServerDetails,
): void { ): void {
@@ -80,6 +82,7 @@ export class LoggingContentGenerator implements ContentGenerator {
server: serverDetails, server: serverDetails,
}, },
requestText, requestText,
role,
), ),
); );
} }
@@ -122,6 +125,7 @@ export class LoggingContentGenerator implements ContentGenerator {
durationMs: number, durationMs: number,
model: string, model: string,
prompt_id: string, prompt_id: string,
role: LlmRole,
responseId: string | undefined, responseId: string | undefined,
responseCandidates?: Candidate[], responseCandidates?: Candidate[],
usageMetadata?: GenerateContentResponseUsageMetadata, usageMetadata?: GenerateContentResponseUsageMetadata,
@@ -147,6 +151,7 @@ export class LoggingContentGenerator implements ContentGenerator {
this.config.getContentGeneratorConfig()?.authType, this.config.getContentGeneratorConfig()?.authType,
usageMetadata, usageMetadata,
responseText, responseText,
role,
), ),
); );
} }
@@ -157,6 +162,7 @@ export class LoggingContentGenerator implements ContentGenerator {
model: string, model: string,
prompt_id: string, prompt_id: string,
requestContents: Content[], requestContents: Content[],
role: LlmRole,
generationConfig?: GenerateContentConfig, generationConfig?: GenerateContentConfig,
serverDetails?: ServerDetails, serverDetails?: ServerDetails,
): void { ): void {
@@ -181,6 +187,7 @@ export class LoggingContentGenerator implements ContentGenerator {
? // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion ? // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
(error as StructuredError).status (error as StructuredError).status
: undefined, : undefined,
role,
), ),
); );
} }
@@ -188,6 +195,7 @@ export class LoggingContentGenerator implements ContentGenerator {
async generateContent( async generateContent(
req: GenerateContentParameters, req: GenerateContentParameters,
userPromptId: string, userPromptId: string,
role: LlmRole,
): Promise<GenerateContentResponse> { ): Promise<GenerateContentResponse> {
return runInDevTraceSpan( return runInDevTraceSpan(
{ {
@@ -203,6 +211,7 @@ export class LoggingContentGenerator implements ContentGenerator {
contents, contents,
req.model, req.model,
userPromptId, userPromptId,
role,
req.config, req.config,
serverDetails, serverDetails,
); );
@@ -211,6 +220,7 @@ export class LoggingContentGenerator implements ContentGenerator {
const response = await this.wrapped.generateContent( const response = await this.wrapped.generateContent(
req, req,
userPromptId, userPromptId,
role,
); );
spanMetadata.output = { spanMetadata.output = {
response, response,
@@ -222,6 +232,7 @@ export class LoggingContentGenerator implements ContentGenerator {
durationMs, durationMs,
response.modelVersion || req.model, response.modelVersion || req.model,
userPromptId, userPromptId,
role,
response.responseId, response.responseId,
response.candidates, response.candidates,
response.usageMetadata, response.usageMetadata,
@@ -247,6 +258,7 @@ export class LoggingContentGenerator implements ContentGenerator {
req.model, req.model,
userPromptId, userPromptId,
contents, contents,
role,
req.config, req.config,
serverDetails, serverDetails,
); );
@@ -259,6 +271,7 @@ export class LoggingContentGenerator implements ContentGenerator {
async generateContentStream( async generateContentStream(
req: GenerateContentParameters, req: GenerateContentParameters,
userPromptId: string, userPromptId: string,
role: LlmRole,
): Promise<AsyncGenerator<GenerateContentResponse>> { ): Promise<AsyncGenerator<GenerateContentResponse>> {
return runInDevTraceSpan( return runInDevTraceSpan(
{ {
@@ -283,13 +296,18 @@ export class LoggingContentGenerator implements ContentGenerator {
toContents(req.contents), toContents(req.contents),
req.model, req.model,
userPromptId, userPromptId,
role,
req.config, req.config,
serverDetails, serverDetails,
); );
let stream: AsyncGenerator<GenerateContentResponse>; let stream: AsyncGenerator<GenerateContentResponse>;
try { try {
stream = await this.wrapped.generateContentStream(req, userPromptId); stream = await this.wrapped.generateContentStream(
req,
userPromptId,
role,
);
} catch (error) { } catch (error) {
const durationMs = Date.now() - startTime; const durationMs = Date.now() - startTime;
this._logApiError( this._logApiError(
@@ -298,6 +316,7 @@ export class LoggingContentGenerator implements ContentGenerator {
req.model, req.model,
userPromptId, userPromptId,
toContents(req.contents), toContents(req.contents),
role,
req.config, req.config,
serverDetails, serverDetails,
); );
@@ -309,6 +328,7 @@ export class LoggingContentGenerator implements ContentGenerator {
stream, stream,
startTime, startTime,
userPromptId, userPromptId,
role,
spanMetadata, spanMetadata,
endSpan, endSpan,
); );
@@ -321,6 +341,7 @@ export class LoggingContentGenerator implements ContentGenerator {
stream: AsyncGenerator<GenerateContentResponse>, stream: AsyncGenerator<GenerateContentResponse>,
startTime: number, startTime: number,
userPromptId: string, userPromptId: string,
role: LlmRole,
spanMetadata: SpanMetadata, spanMetadata: SpanMetadata,
endSpan: () => void, endSpan: () => void,
): AsyncGenerator<GenerateContentResponse> { ): AsyncGenerator<GenerateContentResponse> {
@@ -344,6 +365,7 @@ export class LoggingContentGenerator implements ContentGenerator {
durationMs, durationMs,
responses[0]?.modelVersion || req.model, responses[0]?.modelVersion || req.model,
userPromptId, userPromptId,
role,
responses[0]?.responseId, responses[0]?.responseId,
responses.flatMap((response) => response.candidates || []), responses.flatMap((response) => response.candidates || []),
lastUsageMetadata, lastUsageMetadata,
@@ -378,6 +400,7 @@ export class LoggingContentGenerator implements ContentGenerator {
responses[0]?.modelVersion || req.model, responses[0]?.modelVersion || req.model,
userPromptId, userPromptId,
requestContents, requestContents,
role,
req.config, req.config,
serverDetails, serverDetails,
); );
@@ -18,6 +18,7 @@ import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
import { safeJsonStringify } from '../utils/safeJsonStringify.js'; import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import type { ContentGenerator } from './contentGenerator.js'; import type { ContentGenerator } from './contentGenerator.js';
import { RecordingContentGenerator } from './recordingContentGenerator.js'; import { RecordingContentGenerator } from './recordingContentGenerator.js';
import { LlmRole } from '../telemetry/types.js';
vi.mock('node:fs', () => ({ vi.mock('node:fs', () => ({
appendFileSync: vi.fn(), appendFileSync: vi.fn(),
@@ -51,9 +52,14 @@ describe('RecordingContentGenerator', () => {
const response = await recorder.generateContent( const response = await recorder.generateContent(
{} as GenerateContentParameters, {} as GenerateContentParameters,
'id1', 'id1',
LlmRole.MAIN,
); );
expect(response).toEqual(mockResponse); expect(response).toEqual(mockResponse);
expect(mockRealGenerator.generateContent).toHaveBeenCalledWith({}, 'id1'); expect(mockRealGenerator.generateContent).toHaveBeenCalledWith(
{},
'id1',
LlmRole.MAIN,
);
expect(appendFileSync).toHaveBeenCalledWith( expect(appendFileSync).toHaveBeenCalledWith(
filePath, filePath,
@@ -90,6 +96,7 @@ describe('RecordingContentGenerator', () => {
const stream = await recorder.generateContentStream( const stream = await recorder.generateContentStream(
{} as GenerateContentParameters, {} as GenerateContentParameters,
'id1', 'id1',
LlmRole.MAIN,
); );
const responses = []; const responses = [];
for await (const response of stream) { for await (const response of stream) {
@@ -100,6 +107,7 @@ describe('RecordingContentGenerator', () => {
expect(mockRealGenerator.generateContentStream).toHaveBeenCalledWith( expect(mockRealGenerator.generateContentStream).toHaveBeenCalledWith(
{}, {},
'id1', 'id1',
LlmRole.MAIN,
); );
expect(appendFileSync).toHaveBeenCalledWith( expect(appendFileSync).toHaveBeenCalledWith(
@@ -17,6 +17,7 @@ import type { ContentGenerator } from './contentGenerator.js';
import type { FakeResponse } from './fakeContentGenerator.js'; import type { FakeResponse } from './fakeContentGenerator.js';
import type { UserTierId } from '../code_assist/types.js'; import type { UserTierId } from '../code_assist/types.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js'; import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import type { LlmRole } from '../telemetry/types.js';
// A ContentGenerator that wraps another content generator and records all the // A ContentGenerator that wraps another content generator and records all the
// responses, with the ability to write them out to a file. These files are // responses, with the ability to write them out to a file. These files are
@@ -41,10 +42,12 @@ export class RecordingContentGenerator implements ContentGenerator {
async generateContent( async generateContent(
request: GenerateContentParameters, request: GenerateContentParameters,
userPromptId: string, userPromptId: string,
role: LlmRole,
): Promise<GenerateContentResponse> { ): Promise<GenerateContentResponse> {
const response = await this.realGenerator.generateContent( const response = await this.realGenerator.generateContent(
request, request,
userPromptId, userPromptId,
role,
); );
const recordedResponse: FakeResponse = { const recordedResponse: FakeResponse = {
method: 'generateContent', method: 'generateContent',
@@ -61,6 +64,7 @@ export class RecordingContentGenerator implements ContentGenerator {
async generateContentStream( async generateContentStream(
request: GenerateContentParameters, request: GenerateContentParameters,
userPromptId: string, userPromptId: string,
role: LlmRole,
): Promise<AsyncGenerator<GenerateContentResponse>> { ): Promise<AsyncGenerator<GenerateContentResponse>> {
const recordedResponse: FakeResponse = { const recordedResponse: FakeResponse = {
method: 'generateContentStream', method: 'generateContentStream',
@@ -70,6 +74,7 @@ export class RecordingContentGenerator implements ContentGenerator {
const realResponses = await this.realGenerator.generateContentStream( const realResponses = await this.realGenerator.generateContentStream(
request, request,
userPromptId, userPromptId,
role,
); );
async function* stream(filePath: string) { async function* stream(filePath: string) {
+2
View File
@@ -14,6 +14,7 @@ import type { GenerateContentResponse, Part, Content } from '@google/genai';
import { reportError } from '../utils/errorReporting.js'; import { reportError } from '../utils/errorReporting.js';
import type { GeminiChat } from './geminiChat.js'; import type { GeminiChat } from './geminiChat.js';
import { InvalidStreamError, StreamEventType } from './geminiChat.js'; import { InvalidStreamError, StreamEventType } from './geminiChat.js';
import { LlmRole } from '../telemetry/types.js';
const mockSendMessageStream = vi.fn(); const mockSendMessageStream = vi.fn();
const mockGetHistory = vi.fn(); const mockGetHistory = vi.fn();
@@ -102,6 +103,7 @@ describe('Turn', () => {
reqParts, reqParts,
'prompt-id-1', 'prompt-id-1',
expect.any(AbortSignal), expect.any(AbortSignal),
LlmRole.MAIN,
undefined, undefined,
); );
+3
View File
@@ -29,6 +29,7 @@ import { parseThought, type ThoughtSummary } from '../utils/thoughtUtils.js';
import { createUserContent } from '@google/genai'; import { createUserContent } from '@google/genai';
import type { ModelConfigKey } from '../services/modelConfigService.js'; import type { ModelConfigKey } from '../services/modelConfigService.js';
import { getCitations } from '../utils/generateContentResponseUtilities.js'; import { getCitations } from '../utils/generateContentResponseUtilities.js';
import { LlmRole } from '../telemetry/types.js';
import { import {
type ToolCallRequestInfo, type ToolCallRequestInfo,
@@ -251,6 +252,7 @@ export class Turn {
req: PartListUnion, req: PartListUnion,
signal: AbortSignal, signal: AbortSignal,
displayContent?: PartListUnion, displayContent?: PartListUnion,
role: LlmRole = LlmRole.MAIN,
): AsyncGenerator<ServerGeminiStreamEvent> { ): AsyncGenerator<ServerGeminiStreamEvent> {
try { try {
// Note: This assumes `sendMessageStream` yields events like // Note: This assumes `sendMessageStream` yields events like
@@ -260,6 +262,7 @@ export class Turn {
req, req,
this.prompt_id, this.prompt_id,
signal, signal,
role,
displayContent, displayContent,
); );
@@ -79,6 +79,7 @@ describe('JsonFormatter', () => {
thoughts: 103, thoughts: 103,
tool: 0, tool: 0,
}, },
roles: {},
}, },
'gemini-2.5-flash': { 'gemini-2.5-flash': {
api: { api: {
@@ -95,6 +96,7 @@ describe('JsonFormatter', () => {
thoughts: 138, thoughts: 138,
tool: 0, tool: 0,
}, },
roles: {},
}, },
}, },
tools: { tools: {
@@ -289,6 +289,7 @@ describe('StreamJsonFormatter', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}; };
metrics.tools.totalCalls = 2; metrics.tools.totalCalls = 2;
metrics.tools.totalDecisions[ToolCallDecision.AUTO_ACCEPT] = 2; metrics.tools.totalDecisions[ToolCallDecision.AUTO_ACCEPT] = 2;
@@ -319,6 +320,7 @@ describe('StreamJsonFormatter', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}; };
metrics.models['gemini-ultra'] = { metrics.models['gemini-ultra'] = {
api: { totalRequests: 1, totalErrors: 0, totalLatencyMs: 2000 }, api: { totalRequests: 1, totalErrors: 0, totalLatencyMs: 2000 },
@@ -331,6 +333,7 @@ describe('StreamJsonFormatter', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}; };
metrics.tools.totalCalls = 5; metrics.tools.totalCalls = 5;
@@ -360,6 +363,7 @@ describe('StreamJsonFormatter', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}; };
const result = formatter.convertToStreamStats(metrics, 1200); const result = formatter.convertToStreamStats(metrics, 1200);
@@ -20,6 +20,7 @@ import {
isFunctionResponse, isFunctionResponse,
} from '../../utils/messageInspectors.js'; } from '../../utils/messageInspectors.js';
import { debugLogger } from '../../utils/debugLogger.js'; import { debugLogger } from '../../utils/debugLogger.js';
import { LlmRole } from '../../telemetry/types.js';
// The number of recent history turns to provide to the router for context. // The number of recent history turns to provide to the router for context.
const HISTORY_TURNS_FOR_CONTEXT = 4; const HISTORY_TURNS_FOR_CONTEXT = 4;
@@ -161,6 +162,7 @@ export class ClassifierStrategy implements RoutingStrategy {
systemInstruction: CLASSIFIER_SYSTEM_PROMPT, systemInstruction: CLASSIFIER_SYSTEM_PROMPT,
abortSignal: context.signal, abortSignal: context.signal,
promptId, promptId,
role: LlmRole.UTILITY_ROUTER,
}); });
const routerResponse = ClassifierResponseSchema.parse(jsonResponse); const routerResponse = ClassifierResponseSchema.parse(jsonResponse);
@@ -16,6 +16,7 @@ import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
import { createUserContent, Type } from '@google/genai'; import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js'; import type { Config } from '../../config/config.js';
import { debugLogger } from '../../utils/debugLogger.js'; import { debugLogger } from '../../utils/debugLogger.js';
import { LlmRole } from '../../telemetry/types.js';
// The number of recent history turns to provide to the router for context. // The number of recent history turns to provide to the router for context.
const HISTORY_TURNS_FOR_CONTEXT = 8; const HISTORY_TURNS_FOR_CONTEXT = 8;
@@ -169,6 +170,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
systemInstruction: CLASSIFIER_SYSTEM_PROMPT, systemInstruction: CLASSIFIER_SYSTEM_PROMPT,
abortSignal: context.signal, abortSignal: context.signal,
promptId, promptId,
role: LlmRole.UTILITY_ROUTER,
}); });
const routerResponse = ClassifierResponseSchema.parse(jsonResponse); const routerResponse = ClassifierResponseSchema.parse(jsonResponse);
@@ -31,6 +31,7 @@ import {
PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_FLASH_MODEL,
} from '../config/models.js'; } from '../config/models.js';
import { PreCompressTrigger } from '../hooks/types.js'; import { PreCompressTrigger } from '../hooks/types.js';
import { LlmRole } from '../telemetry/types.js';
/** /**
* Default threshold for compression token count as a fraction of the model's * Default threshold for compression token count as a fraction of the model's
@@ -339,6 +340,7 @@ export class ChatCompressionService {
promptId, promptId,
// TODO(joshualitt): wire up a sensible abort signal, // TODO(joshualitt): wire up a sensible abort signal,
abortSignal: abortSignal ?? new AbortController().signal, abortSignal: abortSignal ?? new AbortController().signal,
role: LlmRole.UTILITY_COMPRESSOR,
}); });
const summary = getResponseText(summaryResponse) ?? ''; const summary = getResponseText(summaryResponse) ?? '';
@@ -365,6 +367,7 @@ export class ChatCompressionService {
], ],
systemInstruction: { text: getCompressionPrompt(config) }, systemInstruction: { text: getCompressionPrompt(config) },
promptId: `${promptId}-verify`, promptId: `${promptId}-verify`,
role: LlmRole.UTILITY_COMPRESSOR,
abortSignal: abortSignal ?? new AbortController().signal, abortSignal: abortSignal ?? new AbortController().signal,
}); });
@@ -25,6 +25,7 @@ import {
isFunctionResponse, isFunctionResponse,
} from '../utils/messageInspectors.js'; } from '../utils/messageInspectors.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import { LlmRole } from '../telemetry/types.js';
const TOOL_CALL_LOOP_THRESHOLD = 5; const TOOL_CALL_LOOP_THRESHOLD = 5;
const CONTENT_LOOP_THRESHOLD = 10; const CONTENT_LOOP_THRESHOLD = 10;
@@ -554,6 +555,7 @@ export class LoopDetectionService {
abortSignal: signal, abortSignal: signal,
promptId: this.promptId, promptId: this.promptId,
maxAttempts: 2, maxAttempts: 2,
role: LlmRole.UTILITY_LOOP_DETECTOR,
}); });
if ( if (
@@ -10,6 +10,7 @@ import { partListUnionToString } from '../core/geminiRequest.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import type { Content } from '@google/genai'; import type { Content } from '@google/genai';
import { getResponseText } from '../utils/partUtils.js'; import { getResponseText } from '../utils/partUtils.js';
import { LlmRole } from '../telemetry/types.js';
const DEFAULT_MAX_MESSAGES = 20; const DEFAULT_MAX_MESSAGES = 20;
const DEFAULT_TIMEOUT_MS = 5000; const DEFAULT_TIMEOUT_MS = 5000;
@@ -124,6 +125,7 @@ export class SessionSummaryService {
contents, contents,
abortSignal: abortController.signal, abortSignal: abortController.signal,
promptId: 'session-summary-generation', promptId: 'session-summary-generation',
role: LlmRole.UTILITY_SUMMARIZER,
}); });
const summary = getResponseText(response); const summary = getResponseText(response);
+1
View File
@@ -65,6 +65,7 @@ export {
ToolCallDecision, ToolCallDecision,
RewindEvent, RewindEvent,
} from './types.js'; } from './types.js';
export { LlmRole } from './llmRole.js';
export { makeSlashCommandEvent, makeChatCompressionEvent } from './types.js'; export { makeSlashCommandEvent, makeChatCompressionEvent } from './types.js';
export type { TelemetryEvent } from './types.js'; export type { TelemetryEvent } from './types.js';
export { SpanStatusCode, ValueType } from '@opentelemetry/api'; export { SpanStatusCode, ValueType } from '@opentelemetry/api';
+18
View File
@@ -0,0 +1,18 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
 * Labels the part of the system on whose behalf an LLM call is made.
 *
 * The string value of a member is what the API request, response and error
 * telemetry events copy into their `role` log attribute when a role is
 * supplied, so these values are part of the emitted telemetry.
 */
export enum LlmRole {
/** Default role used by `Turn` when the caller does not pass one explicitly. */
MAIN = 'main',
/** NOTE(review): presumably calls issued by a sub-agent rather than the main loop; confirm against callers. */
SUBAGENT = 'subagent',
/** NOTE(review): no call site visible here; presumably LLM calls made from inside a tool. */
UTILITY_TOOL = 'utility_tool',
/** Passed by `ChatCompressionService` for both its summary call and its `-verify` follow-up call. */
UTILITY_COMPRESSOR = 'utility_compressor',
/** Passed by `SessionSummaryService` when generating a session summary. */
UTILITY_SUMMARIZER = 'utility_summarizer',
/** Passed by the routing classifiers (`ClassifierStrategy`, `NumericalClassifierStrategy`). */
UTILITY_ROUTER = 'utility_router',
/** Passed by `LoopDetectionService` for its LLM-based loop check. */
UTILITY_LOOP_DETECTOR = 'utility_loop_detector',
/** NOTE(review): no call site visible here; presumably the "who speaks next" check. */
UTILITY_NEXT_SPEAKER = 'utility_next_speaker',
/** NOTE(review): no call site visible here; presumably LLM-assisted edit correction. */
UTILITY_EDIT_CORRECTOR = 'utility_edit_corrector',
/** NOTE(review): no call site visible here; presumably prompt autocompletion. */
UTILITY_AUTOCOMPLETE = 'utility_autocomplete',
}
@@ -93,6 +93,7 @@ import {
EVENT_EXTENSION_UPDATE, EVENT_EXTENSION_UPDATE,
HookCallEvent, HookCallEvent,
EVENT_HOOK_CALL, EVENT_HOOK_CALL,
LlmRole,
} from './types.js'; } from './types.js';
import * as metrics from './metrics.js'; import * as metrics from './metrics.js';
import { FileOperation } from './metrics.js'; import { FileOperation } from './metrics.js';
@@ -520,6 +521,30 @@ describe('loggers', () => {
'event.timestamp': '2025-01-01T00:00:00.000Z', 'event.timestamp': '2025-01-01T00:00:00.000Z',
}); });
}); });
// A role handed to ApiResponseEvent must surface as the `role` attribute of
// the emitted log record.
it('should log an API response with a role', () => {
  // Only the attributes relevant to this case are pinned; the rest may vary.
  const expectedAttributes = expect.objectContaining({
    'event.name': EVENT_API_RESPONSE,
    prompt_id: 'prompt-id-role',
    role: 'subagent',
  });
  const responseEvent = new ApiResponseEvent(
    'test-model',
    100,
    { prompt_id: 'prompt-id-role', contents: [] },
    { candidates: [] },
    AuthType.LOGIN_WITH_GOOGLE,
    {},
    'test-response',
    LlmRole.SUBAGENT,
  );

  logApiResponse(mockConfig, responseEvent);

  expect(mockLogger.emit).toHaveBeenCalledWith({
    body: 'API response from test-model. Status: 200. Duration: 100ms.',
    attributes: expectedAttributes,
  });
});
}); });
describe('logApiError', () => { describe('logApiError', () => {
@@ -654,6 +679,30 @@ describe('loggers', () => {
'event.timestamp': '2025-01-01T00:00:00.000Z', 'event.timestamp': '2025-01-01T00:00:00.000Z',
}); });
}); });
// A role handed to ApiErrorEvent must surface as the `role` attribute of the
// emitted log record.
it('should log an API error with a role', () => {
  // Only the attributes relevant to this case are pinned; the rest may vary.
  const expectedAttributes = expect.objectContaining({
    'event.name': EVENT_API_ERROR,
    prompt_id: 'prompt-id-role',
    role: 'subagent',
  });
  const errorEvent = new ApiErrorEvent(
    'test-model',
    'error',
    100,
    { prompt_id: 'prompt-id-role', contents: [] },
    AuthType.LOGIN_WITH_GOOGLE,
    'ApiError',
    503,
    LlmRole.SUBAGENT,
  );

  logApiError(mockConfig, errorEvent);

  expect(mockLogger.emit).toHaveBeenCalledWith({
    body: 'API error for test-model. Error: error. Duration: 100ms.',
    attributes: expectedAttributes,
  });
});
}); });
describe('logApiRequest', () => { describe('logApiRequest', () => {
@@ -917,6 +966,26 @@ describe('loggers', () => {
}), }),
}); });
}); });
// A role handed to ApiRequestEvent must surface as the `role` attribute of
// the emitted log record.
it('should log an API request with a role', () => {
  // Only the attributes relevant to this case are pinned; the rest may vary.
  const expectedAttributes = expect.objectContaining({
    'event.name': EVENT_API_REQUEST,
    prompt_id: 'prompt-id-role',
    role: 'subagent',
  });
  const requestEvent = new ApiRequestEvent(
    'test-model',
    { prompt_id: 'prompt-id-role', contents: [] },
    'request text',
    LlmRole.SUBAGENT,
  );

  logApiRequest(mockConfig, requestEvent);

  expect(mockLogger.emit).toHaveBeenCalledWith({
    body: 'API request to test-model.',
    attributes: expectedAttributes,
  });
});
}); });
describe('logFlashFallback', () => { describe('logFlashFallback', () => {
+21
View File
@@ -41,6 +41,8 @@ import {
} from './semantic.js'; } from './semantic.js';
import { sanitizeHookName } from './sanitize.js'; import { sanitizeHookName } from './sanitize.js';
import { getFileDiffFromResultDisplay } from '../utils/fileDiffUtils.js'; import { getFileDiffFromResultDisplay } from '../utils/fileDiffUtils.js';
import { LlmRole } from './llmRole.js';
export { LlmRole };
export interface BaseTelemetryEvent { export interface BaseTelemetryEvent {
'event.name': string; 'event.name': string;
@@ -375,17 +377,20 @@ export class ApiRequestEvent implements BaseTelemetryEvent {
model: string; model: string;
prompt: GenAIPromptDetails; prompt: GenAIPromptDetails;
request_text?: string; request_text?: string;
role?: LlmRole;
constructor( constructor(
model: string, model: string,
prompt_details: GenAIPromptDetails, prompt_details: GenAIPromptDetails,
request_text?: string, request_text?: string,
role?: LlmRole,
) { ) {
this['event.name'] = 'api_request'; this['event.name'] = 'api_request';
this['event.timestamp'] = new Date().toISOString(); this['event.timestamp'] = new Date().toISOString();
this.model = model; this.model = model;
this.prompt = prompt_details; this.prompt = prompt_details;
this.request_text = request_text; this.request_text = request_text;
this.role = role;
} }
toLogRecord(config: Config): LogRecord { toLogRecord(config: Config): LogRecord {
@@ -397,6 +402,9 @@ export class ApiRequestEvent implements BaseTelemetryEvent {
prompt_id: this.prompt.prompt_id, prompt_id: this.prompt.prompt_id,
request_text: this.request_text, request_text: this.request_text,
}; };
if (this.role) {
attributes['role'] = this.role;
}
return { body: `API request to ${this.model}.`, attributes }; return { body: `API request to ${this.model}.`, attributes };
} }
@@ -445,6 +453,7 @@ export class ApiErrorEvent implements BaseTelemetryEvent {
status_code?: number | string; status_code?: number | string;
duration_ms: number; duration_ms: number;
auth_type?: string; auth_type?: string;
role?: LlmRole;
constructor( constructor(
model: string, model: string,
@@ -454,6 +463,7 @@ export class ApiErrorEvent implements BaseTelemetryEvent {
auth_type?: string, auth_type?: string,
error_type?: string, error_type?: string,
status_code?: number | string, status_code?: number | string,
role?: LlmRole,
) { ) {
this['event.name'] = 'api_error'; this['event.name'] = 'api_error';
this['event.timestamp'] = new Date().toISOString(); this['event.timestamp'] = new Date().toISOString();
@@ -464,6 +474,7 @@ export class ApiErrorEvent implements BaseTelemetryEvent {
this.duration_ms = duration_ms; this.duration_ms = duration_ms;
this.prompt = prompt_details; this.prompt = prompt_details;
this.auth_type = auth_type; this.auth_type = auth_type;
this.role = role;
} }
toLogRecord(config: Config): LogRecord { toLogRecord(config: Config): LogRecord {
@@ -482,6 +493,10 @@ export class ApiErrorEvent implements BaseTelemetryEvent {
auth_type: this.auth_type, auth_type: this.auth_type,
}; };
if (this.role) {
attributes['role'] = this.role;
}
if (this.error_type) { if (this.error_type) {
attributes['error.type'] = this.error_type; attributes['error.type'] = this.error_type;
} }
@@ -590,6 +605,7 @@ export class ApiResponseEvent implements BaseTelemetryEvent {
response: GenAIResponseDetails; response: GenAIResponseDetails;
usage: GenAIUsageDetails; usage: GenAIUsageDetails;
finish_reasons: OTelFinishReason[]; finish_reasons: OTelFinishReason[];
role?: LlmRole;
constructor( constructor(
model: string, model: string,
@@ -599,6 +615,7 @@ export class ApiResponseEvent implements BaseTelemetryEvent {
auth_type?: string, auth_type?: string,
usage_data?: GenerateContentResponseUsageMetadata, usage_data?: GenerateContentResponseUsageMetadata,
response_text?: string, response_text?: string,
role?: LlmRole,
) { ) {
this['event.name'] = 'api_response'; this['event.name'] = 'api_response';
this['event.timestamp'] = new Date().toISOString(); this['event.timestamp'] = new Date().toISOString();
@@ -619,6 +636,7 @@ export class ApiResponseEvent implements BaseTelemetryEvent {
total_token_count: usage_data?.totalTokenCount ?? 0, total_token_count: usage_data?.totalTokenCount ?? 0,
}; };
this.finish_reasons = toFinishReasons(this.response.candidates); this.finish_reasons = toFinishReasons(this.response.candidates);
this.role = role;
} }
toLogRecord(config: Config): LogRecord { toLogRecord(config: Config): LogRecord {
@@ -639,6 +657,9 @@ export class ApiResponseEvent implements BaseTelemetryEvent {
status_code: this.status_code, status_code: this.status_code,
finish_reasons: this.finish_reasons, finish_reasons: this.finish_reasons,
}; };
if (this.role) {
attributes['role'] = this.role;
}
if (this.response_text) { if (this.response_text) {
attributes['response_text'] = this.response_text; attributes['response_text'] = this.response_text;
} }
@@ -181,6 +181,7 @@ describe('UiTelemetryService', () => {
thoughts: 2, thoughts: 2,
tool: 3, tool: 3,
}, },
roles: {},
}); });
expect(service.getLastPromptTokenCount()).toBe(0); expect(service.getLastPromptTokenCount()).toBe(0);
}); });
@@ -236,6 +237,7 @@ describe('UiTelemetryService', () => {
thoughts: 6, thoughts: 6,
tool: 9, tool: 9,
}, },
roles: {},
}); });
expect(service.getLastPromptTokenCount()).toBe(0); expect(service.getLastPromptTokenCount()).toBe(0);
}); });
@@ -311,6 +313,7 @@ describe('UiTelemetryService', () => {
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}); });
}); });
@@ -356,6 +359,35 @@ describe('UiTelemetryService', () => {
thoughts: 2, thoughts: 2,
tool: 3, tool: 3,
}, },
roles: {},
});
});
it('should update role metrics when processing an ApiErrorEvent with a role', () => {
const event = {
'event.name': EVENT_API_ERROR,
model: 'gemini-2.5-pro',
duration_ms: 300,
error: 'Something went wrong',
role: 'utility_tool',
} as unknown as ApiErrorEvent & { 'event.name': typeof EVENT_API_ERROR };
service.addEvent(event);
const metrics = service.getMetrics();
expect(metrics.models['gemini-2.5-pro'].roles['utility_tool']).toEqual({
totalRequests: 1,
totalErrors: 1,
totalLatencyMs: 300,
tokens: {
input: 0,
prompt: 0,
candidates: 0,
total: 0,
cached: 0,
thoughts: 0,
tool: 0,
},
}); });
}); });
}); });
@@ -18,6 +18,8 @@ import type {
ToolCallEvent, ToolCallEvent,
} from './types.js'; } from './types.js';
import type { LlmRole } from './types.js';
export type UiEvent = export type UiEvent =
| (ApiResponseEvent & { 'event.name': typeof EVENT_API_RESPONSE }) | (ApiResponseEvent & { 'event.name': typeof EVENT_API_RESPONSE })
| (ApiErrorEvent & { 'event.name': typeof EVENT_API_ERROR }) | (ApiErrorEvent & { 'event.name': typeof EVENT_API_ERROR })
@@ -36,6 +38,21 @@ export interface ToolCallStats {
}; };
} }
/**
 * Request, error, latency and token totals accumulated for one LLM role.
 *
 * One record is stored per role under a model's `roles` map; the record is
 * created lazily the first time an API response or API error carrying that
 * role is processed.
 */
export interface RoleMetrics {
/** Incremented once for every API response and every API error recorded for the role. */
totalRequests: number;
/** Incremented once for every API error recorded for the role. */
totalErrors: number;
/** Running sum of `duration_ms` over the role's recorded responses and errors. */
totalLatencyMs: number;
/** Token totals; only API responses contribute, errors leave these untouched. */
tokens: {
/** Derived value: `max(0, prompt - cached)`, recomputed after each response. */
input: number;
/** Running sum of `usage.input_token_count`. */
prompt: number;
/** Running sum of `usage.output_token_count`. */
candidates: number;
/** Running sum of `usage.total_token_count`. */
total: number;
/** Running sum of `usage.cached_content_token_count`. */
cached: number;
/** Running sum of `usage.thoughts_token_count`. */
thoughts: number;
/** Running sum of `usage.tool_token_count`. */
tool: number;
};
}
export interface ModelMetrics { export interface ModelMetrics {
api: { api: {
totalRequests: number; totalRequests: number;
@@ -51,6 +68,7 @@ export interface ModelMetrics {
thoughts: number; thoughts: number;
tool: number; tool: number;
}; };
roles: Partial<Record<LlmRole, RoleMetrics>>;
} }
export interface SessionMetrics { export interface SessionMetrics {
@@ -74,6 +92,21 @@ export interface SessionMetrics {
}; };
} }
/**
 * Builds a fresh, fully zeroed {@link RoleMetrics} record.
 *
 * Each call allocates new objects (including the nested `tokens` record), so
 * the result can be mutated without affecting records created for other roles.
 *
 * @returns A role metrics record with every counter set to 0.
 */
const createInitialRoleMetrics = (): RoleMetrics => {
  // Built separately so the nested record is visibly a new object per call.
  const zeroedTokens = {
    input: 0,
    prompt: 0,
    candidates: 0,
    total: 0,
    cached: 0,
    thoughts: 0,
    tool: 0,
  };

  return {
    totalRequests: 0,
    totalErrors: 0,
    totalLatencyMs: 0,
    tokens: zeroedTokens,
  };
};
const createInitialModelMetrics = (): ModelMetrics => ({ const createInitialModelMetrics = (): ModelMetrics => ({
api: { api: {
totalRequests: 0, totalRequests: 0,
@@ -89,6 +122,7 @@ const createInitialModelMetrics = (): ModelMetrics => ({
thoughts: 0, thoughts: 0,
tool: 0, tool: 0,
}, },
roles: {},
}); });
const createInitialMetrics = (): SessionMetrics => ({ const createInitialMetrics = (): SessionMetrics => ({
@@ -177,6 +211,25 @@ export class UiTelemetryService extends EventEmitter {
0, 0,
modelMetrics.tokens.prompt - modelMetrics.tokens.cached, modelMetrics.tokens.prompt - modelMetrics.tokens.cached,
); );
if (event.role) {
if (!modelMetrics.roles[event.role]) {
modelMetrics.roles[event.role] = createInitialRoleMetrics();
}
const roleMetrics = modelMetrics.roles[event.role]!;
roleMetrics.totalRequests++;
roleMetrics.totalLatencyMs += event.duration_ms;
roleMetrics.tokens.prompt += event.usage.input_token_count;
roleMetrics.tokens.candidates += event.usage.output_token_count;
roleMetrics.tokens.total += event.usage.total_token_count;
roleMetrics.tokens.cached += event.usage.cached_content_token_count;
roleMetrics.tokens.thoughts += event.usage.thoughts_token_count;
roleMetrics.tokens.tool += event.usage.tool_token_count;
roleMetrics.tokens.input = Math.max(
0,
roleMetrics.tokens.prompt - roleMetrics.tokens.cached,
);
}
} }
private processApiError(event: ApiErrorEvent) { private processApiError(event: ApiErrorEvent) {
@@ -184,6 +237,16 @@ export class UiTelemetryService extends EventEmitter {
modelMetrics.api.totalRequests++; modelMetrics.api.totalRequests++;
modelMetrics.api.totalErrors++; modelMetrics.api.totalErrors++;
modelMetrics.api.totalLatencyMs += event.duration_ms; modelMetrics.api.totalLatencyMs += event.duration_ms;
if (event.role) {
if (!modelMetrics.roles[event.role]) {
modelMetrics.roles[event.role] = createInitialRoleMetrics();
}
const roleMetrics = modelMetrics.roles[event.role]!;
roleMetrics.totalRequests++;
roleMetrics.totalErrors++;
roleMetrics.totalLatencyMs += event.duration_ms;
}
} }
private processToolCall(event: ToolCallEvent) { private processToolCall(event: ToolCallEvent) {
+3
View File
@@ -27,6 +27,7 @@ import {
logWebFetchFallbackAttempt, logWebFetchFallbackAttempt,
WebFetchFallbackAttemptEvent, WebFetchFallbackAttemptEvent,
} from '../telemetry/index.js'; } from '../telemetry/index.js';
import { LlmRole } from '../telemetry/llmRole.js';
import { WEB_FETCH_TOOL_NAME } from './tool-names.js'; import { WEB_FETCH_TOOL_NAME } from './tool-names.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import { retryWithBackoff } from '../utils/retry.js'; import { retryWithBackoff } from '../utils/retry.js';
@@ -189,6 +190,7 @@ ${textContent}
{ model: 'web-fetch-fallback' }, { model: 'web-fetch-fallback' },
[{ role: 'user', parts: [{ text: fallbackPrompt }] }], [{ role: 'user', parts: [{ text: fallbackPrompt }] }],
signal, signal,
LlmRole.UTILITY_TOOL,
); );
const resultText = getResponseText(result) || ''; const resultText = getResponseText(result) || '';
return { return {
@@ -278,6 +280,7 @@ ${textContent}
{ model: 'web-fetch' }, { model: 'web-fetch' },
[{ role: 'user', parts: [{ text: userPrompt }] }], [{ role: 'user', parts: [{ text: userPrompt }] }],
signal, // Pass signal signal, // Pass signal
LlmRole.UTILITY_TOOL,
); );
debugLogger.debug( debugLogger.debug(
+2
View File
@@ -17,6 +17,7 @@ import { getResponseText } from '../utils/partUtils.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import { WEB_SEARCH_DEFINITION } from './definitions/coreTools.js'; import { WEB_SEARCH_DEFINITION } from './definitions/coreTools.js';
import { resolveToolDeclaration } from './definitions/resolver.js'; import { resolveToolDeclaration } from './definitions/resolver.js';
import { LlmRole } from '../telemetry/llmRole.js';
interface GroundingChunkWeb { interface GroundingChunkWeb {
uri?: string; uri?: string;
@@ -86,6 +87,7 @@ class WebSearchToolInvocation extends BaseToolInvocation<
{ model: 'web-search' }, { model: 'web-search' },
[{ role: 'user', parts: [{ text: this.params.query }] }], [{ role: 'user', parts: [{ text: this.params.query }] }],
signal, signal,
LlmRole.UTILITY_TOOL,
); );
const responseText = getResponseText(response); const responseText = getResponseText(response);
+5
View File
@@ -23,6 +23,7 @@ import * as fs from 'node:fs';
import { promptIdContext } from './promptIdContext.js'; import { promptIdContext } from './promptIdContext.js';
import { debugLogger } from './debugLogger.js'; import { debugLogger } from './debugLogger.js';
import { LRUCache } from 'mnemonist'; import { LRUCache } from 'mnemonist';
import { LlmRole } from '../telemetry/types.js';
const CODE_CORRECTION_SYSTEM_PROMPT = ` const CODE_CORRECTION_SYSTEM_PROMPT = `
You are an expert code-editing assistant. Your task is to analyze a failed edit attempt and provide a corrected version of the text snippets. You are an expert code-editing assistant. Your task is to analyze a failed edit attempt and provide a corrected version of the text snippets.
@@ -439,6 +440,7 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
abortSignal, abortSignal,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT, systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(), promptId: getPromptId(),
role: LlmRole.UTILITY_EDIT_CORRECTOR,
}); });
if ( if (
@@ -528,6 +530,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
abortSignal, abortSignal,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT, systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(), promptId: getPromptId(),
role: LlmRole.UTILITY_EDIT_CORRECTOR,
}); });
if ( if (
@@ -598,6 +601,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
abortSignal, abortSignal,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT, systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(), promptId: getPromptId(),
role: LlmRole.UTILITY_EDIT_CORRECTOR,
}); });
if ( if (
@@ -665,6 +669,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
abortSignal, abortSignal,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT, systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(), promptId: getPromptId(),
role: LlmRole.UTILITY_EDIT_CORRECTOR,
}); });
if ( if (
@@ -10,6 +10,7 @@ import { type BaseLlmClient } from '../core/baseLlmClient.js';
import { LRUCache } from 'mnemonist'; import { LRUCache } from 'mnemonist';
import { getPromptIdWithFallback } from './promptIdContext.js'; import { getPromptIdWithFallback } from './promptIdContext.js';
import { debugLogger } from './debugLogger.js'; import { debugLogger } from './debugLogger.js';
import { LlmRole } from '../telemetry/types.js';
const MAX_CACHE_SIZE = 50; const MAX_CACHE_SIZE = 50;
const GENERATE_JSON_TIMEOUT_MS = 40000; // 40 seconds const GENERATE_JSON_TIMEOUT_MS = 40000; // 40 seconds
@@ -181,6 +182,7 @@ export async function FixLLMEditWithInstruction(
systemInstruction: EDIT_SYS_PROMPT, systemInstruction: EDIT_SYS_PROMPT,
promptId, promptId,
maxAttempts: 1, maxAttempts: 1,
role: LlmRole.UTILITY_EDIT_CORRECTOR,
}, },
GENERATE_JSON_TIMEOUT_MS, GENERATE_JSON_TIMEOUT_MS,
); );
@@ -9,6 +9,7 @@ import type { BaseLlmClient } from '../core/baseLlmClient.js';
import type { GeminiChat } from '../core/geminiChat.js'; import type { GeminiChat } from '../core/geminiChat.js';
import { isFunctionResponse } from './messageInspectors.js'; import { isFunctionResponse } from './messageInspectors.js';
import { debugLogger } from './debugLogger.js'; import { debugLogger } from './debugLogger.js';
import { LlmRole } from '../telemetry/types.js';
const CHECK_PROMPT = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you). const CHECK_PROMPT = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
**Decision Rules (apply in order):** **Decision Rules (apply in order):**
@@ -116,6 +117,7 @@ export async function checkNextSpeaker(
schema: RESPONSE_SCHEMA, schema: RESPONSE_SCHEMA,
abortSignal, abortSignal,
promptId, promptId,
role: LlmRole.UTILITY_NEXT_SPEAKER,
})) as unknown as NextSpeakerResponse; })) as unknown as NextSpeakerResponse;
if ( if (
+2
View File
@@ -11,6 +11,7 @@ import { getResponseText, partToString } from './partUtils.js';
import { debugLogger } from './debugLogger.js'; import { debugLogger } from './debugLogger.js';
import type { ModelConfigKey } from '../services/modelConfigService.js'; import type { ModelConfigKey } from '../services/modelConfigService.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import { LlmRole } from '../telemetry/llmRole.js';
/** /**
* A function that summarizes the result of a tool execution. * A function that summarizes the result of a tool execution.
@@ -94,6 +95,7 @@ export async function summarizeToolOutput(
modelConfigKey, modelConfigKey,
contents, contents,
abortSignal, abortSignal,
LlmRole.UTILITY_SUMMARIZER,
); );
return getResponseText(parsedResponse) || textToSummarize; return getResponseText(parsedResponse) || textToSummarize;
} catch (error) { } catch (error) {