Record model responses with --record-responses (for use in testing) (#11894)

This commit is contained in:
Jacob MacDonald
2025-10-28 12:13:45 -07:00
committed by GitHub
parent ab1f195508
commit 44bdd3ad11
19 changed files with 549 additions and 326 deletions

View File

@@ -74,6 +74,7 @@ export interface CliArgs {
useWriteTodos: boolean | undefined;
outputFormat: string | undefined;
fakeResponses: string | undefined;
recordResponses: string | undefined;
}
export async function parseArguments(settings: Settings): Promise<CliArgs> {
@@ -202,6 +203,12 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
.option('fake-responses', {
type: 'string',
description: 'Path to a file with fake model responses for testing.',
hidden: true,
})
.option('record-responses', {
type: 'string',
description: 'Path to a file to record model responses for testing.',
hidden: true,
})
.deprecateOption(
'prompt',
@@ -700,6 +707,7 @@ export async function loadCliConfig(
codebaseInvestigatorSettings:
settings.experimental?.codebaseInvestigatorSettings,
fakeResponses: argv.fakeResponses,
recordResponses: argv.recordResponses,
retryFetchErrors: settings.general?.retryFetchErrors ?? false,
ptyInfo: ptyInfo?.name,
});

View File

@@ -364,6 +364,7 @@ describe('gemini.tsx main function kitty protocol', () => {
useWriteTodos: undefined,
outputFormat: undefined,
fakeResponses: undefined,
recordResponses: undefined,
});
await main();

View File

@@ -284,6 +284,7 @@ export interface ConfigParameters {
retryFetchErrors?: boolean;
enableShellOutputEfficiency?: boolean;
fakeResponses?: string;
recordResponses?: string;
ptyInfo?: string;
disableYoloMode?: boolean;
}
@@ -383,6 +384,7 @@ export class Config {
private readonly retryFetchErrors: boolean;
private readonly enableShellOutputEfficiency: boolean;
readonly fakeResponses?: string;
readonly recordResponses?: string;
private readonly disableYoloMode: boolean;
constructor(params: ConfigParameters) {
@@ -493,6 +495,7 @@ export class Config {
this.extensionManagement = params.extensionManagement ?? true;
this.storage = new Storage(this.targetDir);
this.fakeResponses = params.fakeResponses;
this.recordResponses = params.recordResponses;
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
this.fileExclusions = new FileExclusions(this);
this.eventEmitter = params.eventEmitter;

View File

@@ -16,6 +16,7 @@ import { GoogleGenAI } from '@google/genai';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import { FakeContentGenerator } from './fakeContentGenerator.js';
import { RecordingContentGenerator } from './recordingContentGenerator.js';
vi.mock('../code_assist/codeAssist.js');
vi.mock('@google/genai');
@@ -45,6 +46,22 @@ describe('createContentGenerator', () => {
expect(generator).toEqual(mockGenerator);
});
it('should create a RecordingContentGenerator', async () => {
const fakeResponsesFile = 'fake/responses.yaml';
const recordResponsesFile = 'record/responses.yaml';
const mockConfigWithRecordResponses = {
fakeResponses: fakeResponsesFile,
recordResponses: recordResponsesFile,
} as unknown as Config;
const generator = await createContentGenerator(
{
authType: AuthType.USE_GEMINI,
},
mockConfigWithRecordResponses,
);
expect(generator).toBeInstanceOf(RecordingContentGenerator);
});
it('should create a CodeAssistContentGenerator', async () => {
const mockGenerator = {} as unknown as ContentGenerator;
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(

View File

@@ -20,6 +20,7 @@ import type { UserTierId } from '../code_assist/types.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import { InstallationManager } from '../utils/installationManager.js';
import { FakeContentGenerator } from './fakeContentGenerator.js';
import { RecordingContentGenerator } from './recordingContentGenerator.js';
/**
* Interface abstracting the core functionalities for generating content and counting tokens.
@@ -106,55 +107,61 @@ export async function createContentGenerator(
gcConfig: Config,
sessionId?: string,
): Promise<ContentGenerator> {
if (gcConfig.fakeResponses) {
return FakeContentGenerator.fromFile(gcConfig.fakeResponses);
}
const version = process.env['CLI_VERSION'] || process.version;
const userAgent = `GeminiCLI/${version} (${process.platform}; ${process.arch})`;
const baseHeaders: Record<string, string> = {
'User-Agent': userAgent,
};
if (
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
config.authType === AuthType.CLOUD_SHELL
) {
const httpOptions = { headers: baseHeaders };
return new LoggingContentGenerator(
await createCodeAssistContentGenerator(
httpOptions,
config.authType,
gcConfig,
sessionId,
),
gcConfig,
);
}
if (
config.authType === AuthType.USE_GEMINI ||
config.authType === AuthType.USE_VERTEX_AI
) {
let headers: Record<string, string> = { ...baseHeaders };
if (gcConfig?.getUsageStatisticsEnabled()) {
const installationManager = new InstallationManager();
const installationId = installationManager.getInstallationId();
headers = {
...headers,
'x-gemini-api-privileged-user-id': `${installationId}`,
};
const generator = await (async () => {
if (gcConfig.fakeResponses) {
return FakeContentGenerator.fromFile(gcConfig.fakeResponses);
}
const version = process.env['CLI_VERSION'] || process.version;
const userAgent = `GeminiCLI/${version} (${process.platform}; ${process.arch})`;
const baseHeaders: Record<string, string> = {
'User-Agent': userAgent,
};
if (
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
config.authType === AuthType.CLOUD_SHELL
) {
const httpOptions = { headers: baseHeaders };
return new LoggingContentGenerator(
await createCodeAssistContentGenerator(
httpOptions,
config.authType,
gcConfig,
sessionId,
),
gcConfig,
);
}
const httpOptions = { headers };
const googleGenAI = new GoogleGenAI({
apiKey: config.apiKey === '' ? undefined : config.apiKey,
vertexai: config.vertexai,
httpOptions,
});
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
if (
config.authType === AuthType.USE_GEMINI ||
config.authType === AuthType.USE_VERTEX_AI
) {
let headers: Record<string, string> = { ...baseHeaders };
if (gcConfig?.getUsageStatisticsEnabled()) {
const installationManager = new InstallationManager();
const installationId = installationManager.getInstallationId();
headers = {
...headers,
'x-gemini-api-privileged-user-id': `${installationId}`,
};
}
const httpOptions = { headers };
const googleGenAI = new GoogleGenAI({
apiKey: config.apiKey === '' ? undefined : config.apiKey,
vertexai: config.vertexai,
httpOptions,
});
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
}
throw new Error(
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
);
})();
if (gcConfig.recordResponses) {
return new RecordingContentGenerator(generator, gcConfig.recordResponses);
}
throw new Error(
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
);
return generator;
}

View File

@@ -5,16 +5,18 @@
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { FakeContentGenerator } from './fakeContentGenerator.js';
import {
FakeContentGenerator,
type FakeResponse,
} from './fakeContentGenerator.js';
import { promises } from 'node:fs';
import type { FakeResponses } from './fakeContentGenerator.js';
import type {
import {
GenerateContentResponse,
CountTokensResponse,
EmbedContentResponse,
GenerateContentParameters,
CountTokensParameters,
EmbedContentParameters,
type CountTokensResponse,
type EmbedContentResponse,
type GenerateContentParameters,
type CountTokensParameters,
type EmbedContentParameters,
} from '@google/genai';
vi.mock('node:fs', async (importOriginal) => {
@@ -31,32 +33,41 @@ vi.mock('node:fs', async (importOriginal) => {
const mockReadFile = vi.mocked(promises.readFile);
describe('FakeContentGenerator', () => {
const fakeResponses: FakeResponses = {
generateContent: [
const fakeGenerateContentResponse: FakeResponse = {
method: 'generateContent',
response: {
candidates: [
{ content: { parts: [{ text: 'response1' }], role: 'model' } },
],
} as GenerateContentResponse,
};
const fakeGenerateContentStreamResponse: FakeResponse = {
method: 'generateContentStream',
response: [
{
candidates: [
{ content: { parts: [{ text: 'response1' }], role: 'model' } },
{ content: { parts: [{ text: 'chunk1' }], role: 'model' } },
],
},
{
candidates: [
{ content: { parts: [{ text: 'chunk2' }], role: 'model' } },
],
},
] as GenerateContentResponse[],
generateContentStream: [
[
{
candidates: [
{ content: { parts: [{ text: 'chunk1' }], role: 'model' } },
],
},
{
candidates: [
{ content: { parts: [{ text: 'chunk2' }], role: 'model' } },
],
},
],
] as GenerateContentResponse[][],
countTokens: [{ totalTokens: 10 }] as CountTokensResponse[],
embedContent: [
{ embeddings: [{ values: [1, 2, 3] }] },
] as EmbedContentResponse[],
};
const fakeCountTokensResponse: FakeResponse = {
method: 'countTokens',
response: { totalTokens: 10 } as CountTokensResponse,
};
const fakeEmbedContentResponse: FakeResponse = {
method: 'embedContent',
response: {
embeddings: [{ values: [1, 2, 3] }],
} as EmbedContentResponse,
};
beforeEach(() => {
@@ -64,90 +75,86 @@ describe('FakeContentGenerator', () => {
});
it('should return responses for generateContent', async () => {
const generator = new FakeContentGenerator(fakeResponses);
const generator = new FakeContentGenerator([fakeGenerateContentResponse]);
const response = await generator.generateContent(
{} as GenerateContentParameters,
'id',
);
expect(response).toEqual(fakeResponses.generateContent[0]);
});
it('should throw error when no more generateContent responses', async () => {
const generator = new FakeContentGenerator({
...fakeResponses,
generateContent: [],
});
await expect(
generator.generateContent({} as GenerateContentParameters, 'id'),
).rejects.toThrowError('No more mock responses for generateContent');
expect(response).instanceOf(GenerateContentResponse);
expect(response).toEqual(fakeGenerateContentResponse.response);
});
it('should return responses for generateContentStream', async () => {
const generator = new FakeContentGenerator(fakeResponses);
const generator = new FakeContentGenerator([
fakeGenerateContentStreamResponse,
]);
const stream = await generator.generateContentStream(
{} as GenerateContentParameters,
'id',
);
const responses = [];
for await (const response of stream) {
expect(response).instanceOf(GenerateContentResponse);
responses.push(response);
}
expect(responses).toEqual(fakeResponses.generateContentStream[0]);
});
it('should throw error when no more generateContentStream responses', async () => {
const generator = new FakeContentGenerator({
...fakeResponses,
generateContentStream: [],
});
await expect(
generator.generateContentStream({} as GenerateContentParameters, 'id'),
).rejects.toThrow('No more mock responses for generateContentStream');
expect(responses).toEqual(fakeGenerateContentStreamResponse.response);
});
it('should return responses for countTokens', async () => {
const generator = new FakeContentGenerator(fakeResponses);
const generator = new FakeContentGenerator([fakeCountTokensResponse]);
const response = await generator.countTokens({} as CountTokensParameters);
expect(response).toEqual(fakeResponses.countTokens[0]);
});
it('should throw error when no more countTokens responses', async () => {
const generator = new FakeContentGenerator({
...fakeResponses,
countTokens: [],
});
await expect(
generator.countTokens({} as CountTokensParameters),
).rejects.toThrowError('No more mock responses for countTokens');
expect(response).toEqual(fakeCountTokensResponse.response);
});
it('should return responses for embedContent', async () => {
const generator = new FakeContentGenerator(fakeResponses);
const generator = new FakeContentGenerator([fakeEmbedContentResponse]);
const response = await generator.embedContent({} as EmbedContentParameters);
expect(response).toEqual(fakeResponses.embedContent[0]);
expect(response).toEqual(fakeEmbedContentResponse.response);
});
it('should throw error when no more embedContent responses', async () => {
const generator = new FakeContentGenerator({
...fakeResponses,
embedContent: [],
});
it('should handle a mixture of calls', async () => {
const fakeResponses = [
fakeGenerateContentResponse,
fakeGenerateContentStreamResponse,
fakeCountTokensResponse,
fakeEmbedContentResponse,
];
const generator = new FakeContentGenerator(fakeResponses);
for (const fakeResponse of fakeResponses) {
const response = await generator[fakeResponse.method]({} as never, '');
if (fakeResponse.method === 'generateContentStream') {
const responses = [];
for await (const item of response as AsyncGenerator<GenerateContentResponse>) {
expect(item).instanceOf(GenerateContentResponse);
responses.push(item);
}
expect(responses).toEqual(fakeResponse.response);
} else {
expect(response).toEqual(fakeResponse.response);
}
}
});
it('should throw error when no more responses', async () => {
const generator = new FakeContentGenerator([fakeGenerateContentResponse]);
await generator.generateContent({} as GenerateContentParameters, 'id');
await expect(
generator.embedContent({} as EmbedContentParameters),
).rejects.toThrowError('No more mock responses for embedContent');
});
it('should handle multiple calls and exhaust responses', async () => {
const generator = new FakeContentGenerator(fakeResponses);
await generator.generateContent({} as GenerateContentParameters, 'id');
await expect(
generator.countTokens({} as CountTokensParameters),
).rejects.toThrowError('No more mock responses for countTokens');
await expect(
generator.generateContentStream({} as GenerateContentParameters, 'id'),
).rejects.toThrow('No more mock responses for generateContentStream');
await expect(
generator.generateContent({} as GenerateContentParameters, 'id'),
).rejects.toThrow();
).rejects.toThrowError('No more mock responses for generateContent');
});
describe('fromFile', () => {
it('should create a generator from a file', async () => {
const fileContent = JSON.stringify(fakeResponses);
const fileContent = JSON.stringify(fakeGenerateContentResponse) + '\n';
mockReadFile.mockResolvedValue(fileContent);
const generator = await FakeContentGenerator.fromFile('fake-path.json');
@@ -155,51 +162,7 @@ describe('FakeContentGenerator', () => {
{} as GenerateContentParameters,
'id',
);
expect(response).toEqual(fakeResponses.generateContent[0]);
});
});
describe('constructor with partial responses', () => {
it('should handle missing generateContent', async () => {
const responses = { ...fakeResponses, generateContent: undefined };
const generator = new FakeContentGenerator(
responses as unknown as FakeResponses,
);
await expect(
generator.generateContent({} as GenerateContentParameters, 'id'),
).rejects.toThrowError('No more mock responses for generateContent');
});
it('should handle missing generateContentStream', async () => {
const responses = { ...fakeResponses, generateContentStream: undefined };
const generator = new FakeContentGenerator(
responses as unknown as FakeResponses,
);
await expect(
generator.generateContentStream({} as GenerateContentParameters, 'id'),
).rejects.toThrowError(
'No more mock responses for generateContentStream',
);
});
it('should handle missing countTokens', async () => {
const responses = { ...fakeResponses, countTokens: undefined };
const generator = new FakeContentGenerator(
responses as unknown as FakeResponses,
);
await expect(
generator.countTokens({} as CountTokensParameters),
).rejects.toThrowError('No more mock responses for countTokens');
});
it('should handle missing embedContent', async () => {
const responses = { ...fakeResponses, embedContent: undefined };
const generator = new FakeContentGenerator(
responses as unknown as FakeResponses,
);
await expect(
generator.embedContent({} as EmbedContentParameters),
).rejects.toThrowError('No more mock responses for embedContent');
expect(response).toEqual(fakeGenerateContentResponse.response);
});
});
});

View File

@@ -4,98 +4,113 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type {
CountTokensResponse,
import {
GenerateContentResponse,
GenerateContentParameters,
CountTokensParameters,
type CountTokensResponse,
type GenerateContentParameters,
type CountTokensParameters,
EmbedContentResponse,
EmbedContentParameters,
type EmbedContentParameters,
} from '@google/genai';
import { promises } from 'node:fs';
import type { ContentGenerator } from './contentGenerator.js';
import type { UserTierId } from '../code_assist/types.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
export type FakeResponses = {
generateContent: GenerateContentResponse[];
generateContentStream: GenerateContentResponse[][];
countTokens: CountTokensResponse[];
embedContent: EmbedContentResponse[];
};
export type FakeResponse =
| {
method: 'generateContent';
response: GenerateContentResponse;
}
| {
method: 'generateContentStream';
response: GenerateContentResponse[];
}
| {
method: 'countTokens';
response: CountTokensResponse;
}
| {
method: 'embedContent';
response: EmbedContentResponse;
};
// A ContentGenerator that responds with canned responses.
//
// Typically these would come from a file, provided by the `--fake-responses`
// CLI argument.
export class FakeContentGenerator implements ContentGenerator {
private responses: FakeResponses;
private callCounters = {
generateContent: 0,
generateContentStream: 0,
countTokens: 0,
embedContent: 0,
};
private callCounter = 0;
userTier?: UserTierId;
constructor(responses: FakeResponses) {
this.responses = {
generateContent: responses.generateContent ?? [],
generateContentStream: responses.generateContentStream ?? [],
countTokens: responses.countTokens ?? [],
embedContent: responses.embedContent ?? [],
};
}
constructor(private readonly responses: FakeResponse[]) {}
static async fromFile(filePath: string): Promise<FakeContentGenerator> {
const fileContent = await promises.readFile(filePath, 'utf-8');
const responses = JSON.parse(fileContent) as FakeResponses;
const responses = fileContent
.split('\n')
.filter((line) => line.trim() !== '')
.map((line) => JSON.parse(line) as FakeResponse);
return new FakeContentGenerator(responses);
}
private getNextResponse<K extends keyof FakeResponses>(
method: K,
request: unknown,
): FakeResponses[K][number] {
const response = this.responses[method][this.callCounters[method]++];
private getNextResponse<
M extends FakeResponse['method'],
R = Extract<FakeResponse, { method: M }>['response'],
>(method: M, request: unknown): R {
const response = this.responses[this.callCounter++];
if (!response) {
throw new Error(
`No more mock responses for ${method}, got request:\n` +
safeJsonStringify(request),
);
}
return response;
if (response.method !== method) {
throw new Error(
`Unexpected response type, next response was for ${response.method} but expected ${method}`,
);
}
return response.response as R;
}
async generateContent(
_request: GenerateContentParameters,
request: GenerateContentParameters,
_userPromptId: string,
): Promise<GenerateContentResponse> {
return this.getNextResponse('generateContent', _request);
return Object.setPrototypeOf(
this.getNextResponse('generateContent', request),
GenerateContentResponse.prototype,
);
}
async generateContentStream(
_request: GenerateContentParameters,
request: GenerateContentParameters,
_userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const responses = this.getNextResponse('generateContentStream', _request);
const responses = this.getNextResponse('generateContentStream', request);
async function* stream() {
for (const response of responses) {
yield response;
yield Object.setPrototypeOf(
response,
GenerateContentResponse.prototype,
);
}
}
return stream();
}
async countTokens(
_request: CountTokensParameters,
request: CountTokensParameters,
): Promise<CountTokensResponse> {
return this.getNextResponse('countTokens', _request);
return this.getNextResponse('countTokens', request);
}
async embedContent(
_request: EmbedContentParameters,
request: EmbedContentParameters,
): Promise<EmbedContentResponse> {
return this.getNextResponse('embedContent', _request);
return Object.setPrototypeOf(
this.getNextResponse('embedContent', request),
EmbedContentResponse.prototype,
);
}
}

View File

@@ -0,0 +1,151 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
GenerateContentResponse,
CountTokensResponse,
EmbedContentResponse,
GenerateContentParameters,
CountTokensParameters,
EmbedContentParameters,
ContentEmbedding,
} from '@google/genai';
import { appendFileSync } from 'node:fs';
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import type { ContentGenerator } from './contentGenerator.js';
import { RecordingContentGenerator } from './recordingContentGenerator.js';
vi.mock('node:fs', () => ({
appendFileSync: vi.fn(),
}));
describe('RecordingContentGenerator', () => {
let mockRealGenerator: ContentGenerator;
let recorder: RecordingContentGenerator;
const filePath = '/test/file/responses.json';
beforeEach(() => {
mockRealGenerator = {
generateContent: vi.fn(),
generateContentStream: vi.fn(),
countTokens: vi.fn(),
embedContent: vi.fn(),
};
recorder = new RecordingContentGenerator(mockRealGenerator, filePath);
vi.clearAllMocks();
});
it('should record generateContent responses', async () => {
const mockResponse = {
candidates: [
{ content: { parts: [{ text: 'response' }], role: 'model' } },
],
usageMetadata: { totalTokenCount: 10 },
} as GenerateContentResponse;
(mockRealGenerator.generateContent as Mock).mockResolvedValue(mockResponse);
const response = await recorder.generateContent(
{} as GenerateContentParameters,
'id1',
);
expect(response).toEqual(mockResponse);
expect(mockRealGenerator.generateContent).toHaveBeenCalledWith({}, 'id1');
expect(appendFileSync).toHaveBeenCalledWith(
filePath,
safeJsonStringify({
method: 'generateContent',
response: mockResponse,
}) + '\n',
);
});
it('should record generateContentStream responses', async () => {
const mockResponse1 = {
candidates: [
{ content: { parts: [{ text: 'response1' }], role: 'model' } },
],
usageMetadata: { totalTokenCount: 10 },
} as GenerateContentResponse;
const mockResponse2 = {
candidates: [
{ content: { parts: [{ text: 'response2' }], role: 'model' } },
],
usageMetadata: { totalTokenCount: 20 },
} as GenerateContentResponse;
async function* mockStream() {
yield mockResponse1;
yield mockResponse2;
}
(mockRealGenerator.generateContentStream as Mock).mockResolvedValue(
mockStream(),
);
const stream = await recorder.generateContentStream(
{} as GenerateContentParameters,
'id1',
);
const responses = [];
for await (const response of stream) {
responses.push(response);
}
expect(responses).toEqual([mockResponse1, mockResponse2]);
expect(mockRealGenerator.generateContentStream).toHaveBeenCalledWith(
{},
'id1',
);
expect(appendFileSync).toHaveBeenCalledWith(
filePath,
safeJsonStringify({
method: 'generateContentStream',
response: responses,
}) + '\n',
);
});
it('should record countTokens responses', async () => {
const mockResponse = {
totalTokens: 100,
cachedContentTokenCount: 10,
} as CountTokensResponse;
(mockRealGenerator.countTokens as Mock).mockResolvedValue(mockResponse);
const response = await recorder.countTokens({} as CountTokensParameters);
expect(response).toEqual(mockResponse);
expect(mockRealGenerator.countTokens).toHaveBeenCalledWith({});
expect(appendFileSync).toHaveBeenCalledWith(
filePath,
safeJsonStringify({
method: 'countTokens',
response: mockResponse,
}) + '\n',
);
});
it('should record embedContent responses', async () => {
const mockResponse = {
embeddings: [{ values: [1, 2, 3] } as ContentEmbedding],
} as EmbedContentResponse;
(mockRealGenerator.embedContent as Mock).mockResolvedValue(mockResponse);
const response = await recorder.embedContent({} as EmbedContentParameters);
expect(response).toEqual(mockResponse);
expect(mockRealGenerator.embedContent).toHaveBeenCalledWith({});
expect(appendFileSync).toHaveBeenCalledWith(
filePath,
safeJsonStringify({
method: 'embedContent',
response: mockResponse,
}) + '\n',
);
});
});

View File

@@ -0,0 +1,112 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
CountTokensResponse,
GenerateContentParameters,
GenerateContentResponse,
CountTokensParameters,
EmbedContentResponse,
EmbedContentParameters,
} from '@google/genai';
import { appendFileSync } from 'node:fs';
import type { ContentGenerator } from './contentGenerator.js';
import type { FakeResponse } from './fakeContentGenerator.js';
import type { UserTierId } from '../code_assist/types.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
// A ContentGenerator that wraps another content generator and records all the
// responses, with the ability to write them out to a file. These files are
// intended to be consumed later on by a FakeContentGenerator, given the
// `--fake-responses` CLI argument.
//
// Note that only the "interesting" bits of the responses are actually kept.
export class RecordingContentGenerator implements ContentGenerator {
userTier?: UserTierId;
constructor(
private readonly realGenerator: ContentGenerator,
private readonly filePath: string,
) {}
async generateContent(
request: GenerateContentParameters,
userPromptId: string,
): Promise<GenerateContentResponse> {
const response = await this.realGenerator.generateContent(
request,
userPromptId,
);
const recordedResponse: FakeResponse = {
method: 'generateContent',
response: {
candidates: response.candidates,
usageMetadata: response.usageMetadata,
} as GenerateContentResponse,
};
appendFileSync(this.filePath, `${safeJsonStringify(recordedResponse)}\n`);
return response;
}
async generateContentStream(
request: GenerateContentParameters,
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const recordedResponse: FakeResponse = {
method: 'generateContentStream',
response: [],
};
const realResponses = await this.realGenerator.generateContentStream(
request,
userPromptId,
);
async function* stream(filePath: string) {
for await (const response of realResponses) {
(recordedResponse.response as GenerateContentResponse[]).push({
candidates: response.candidates,
usageMetadata: response.usageMetadata,
} as GenerateContentResponse);
yield response;
}
appendFileSync(filePath, `${safeJsonStringify(recordedResponse)}\n`);
}
return Promise.resolve(stream(this.filePath));
}
async countTokens(
request: CountTokensParameters,
): Promise<CountTokensResponse> {
const response = await this.realGenerator.countTokens(request);
const recordedResponse: FakeResponse = {
method: 'countTokens',
response: {
totalTokens: response.totalTokens,
cachedContentTokenCount: response.cachedContentTokenCount,
},
};
appendFileSync(this.filePath, `${safeJsonStringify(recordedResponse)}\n`);
return response;
}
async embedContent(
request: EmbedContentParameters,
): Promise<EmbedContentResponse> {
const response = await this.realGenerator.embedContent(request);
const recordedResponse: FakeResponse = {
method: 'embedContent',
response: {
embeddings: response.embeddings,
metadata: response.metadata,
},
};
appendFileSync(this.filePath, `${safeJsonStringify(recordedResponse)}\n`);
return response;
}
}

View File

@@ -29,6 +29,7 @@ export * from './core/turn.js';
export * from './core/geminiRequest.js';
export * from './core/coreToolScheduler.js';
export * from './core/nonInteractiveToolExecutor.js';
export * from './core/recordingContentGenerator.js';
export * from './fallback/types.js';