mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-16 08:10:46 -07:00
Record model responses with --record-responses (for use in testing) (#11894)
This commit is contained in:
@@ -74,6 +74,7 @@ export interface CliArgs {
|
||||
useWriteTodos: boolean | undefined;
|
||||
outputFormat: string | undefined;
|
||||
fakeResponses: string | undefined;
|
||||
recordResponses: string | undefined;
|
||||
}
|
||||
|
||||
export async function parseArguments(settings: Settings): Promise<CliArgs> {
|
||||
@@ -202,6 +203,12 @@ export async function parseArguments(settings: Settings): Promise<CliArgs> {
|
||||
.option('fake-responses', {
|
||||
type: 'string',
|
||||
description: 'Path to a file with fake model responses for testing.',
|
||||
hidden: true,
|
||||
})
|
||||
.option('record-responses', {
|
||||
type: 'string',
|
||||
description: 'Path to a file to record model responses for testing.',
|
||||
hidden: true,
|
||||
})
|
||||
.deprecateOption(
|
||||
'prompt',
|
||||
@@ -700,6 +707,7 @@ export async function loadCliConfig(
|
||||
codebaseInvestigatorSettings:
|
||||
settings.experimental?.codebaseInvestigatorSettings,
|
||||
fakeResponses: argv.fakeResponses,
|
||||
recordResponses: argv.recordResponses,
|
||||
retryFetchErrors: settings.general?.retryFetchErrors ?? false,
|
||||
ptyInfo: ptyInfo?.name,
|
||||
});
|
||||
|
||||
@@ -364,6 +364,7 @@ describe('gemini.tsx main function kitty protocol', () => {
|
||||
useWriteTodos: undefined,
|
||||
outputFormat: undefined,
|
||||
fakeResponses: undefined,
|
||||
recordResponses: undefined,
|
||||
});
|
||||
|
||||
await main();
|
||||
|
||||
@@ -284,6 +284,7 @@ export interface ConfigParameters {
|
||||
retryFetchErrors?: boolean;
|
||||
enableShellOutputEfficiency?: boolean;
|
||||
fakeResponses?: string;
|
||||
recordResponses?: string;
|
||||
ptyInfo?: string;
|
||||
disableYoloMode?: boolean;
|
||||
}
|
||||
@@ -383,6 +384,7 @@ export class Config {
|
||||
private readonly retryFetchErrors: boolean;
|
||||
private readonly enableShellOutputEfficiency: boolean;
|
||||
readonly fakeResponses?: string;
|
||||
readonly recordResponses?: string;
|
||||
private readonly disableYoloMode: boolean;
|
||||
|
||||
constructor(params: ConfigParameters) {
|
||||
@@ -493,6 +495,7 @@ export class Config {
|
||||
this.extensionManagement = params.extensionManagement ?? true;
|
||||
this.storage = new Storage(this.targetDir);
|
||||
this.fakeResponses = params.fakeResponses;
|
||||
this.recordResponses = params.recordResponses;
|
||||
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
|
||||
this.fileExclusions = new FileExclusions(this);
|
||||
this.eventEmitter = params.eventEmitter;
|
||||
|
||||
@@ -16,6 +16,7 @@ import { GoogleGenAI } from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { LoggingContentGenerator } from './loggingContentGenerator.js';
|
||||
import { FakeContentGenerator } from './fakeContentGenerator.js';
|
||||
import { RecordingContentGenerator } from './recordingContentGenerator.js';
|
||||
|
||||
vi.mock('../code_assist/codeAssist.js');
|
||||
vi.mock('@google/genai');
|
||||
@@ -45,6 +46,22 @@ describe('createContentGenerator', () => {
|
||||
expect(generator).toEqual(mockGenerator);
|
||||
});
|
||||
|
||||
it('should create a RecordingContentGenerator', async () => {
|
||||
const fakeResponsesFile = 'fake/responses.yaml';
|
||||
const recordResponsesFile = 'record/responses.yaml';
|
||||
const mockConfigWithRecordResponses = {
|
||||
fakeResponses: fakeResponsesFile,
|
||||
recordResponses: recordResponsesFile,
|
||||
} as unknown as Config;
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
authType: AuthType.USE_GEMINI,
|
||||
},
|
||||
mockConfigWithRecordResponses,
|
||||
);
|
||||
expect(generator).toBeInstanceOf(RecordingContentGenerator);
|
||||
});
|
||||
|
||||
it('should create a CodeAssistContentGenerator', async () => {
|
||||
const mockGenerator = {} as unknown as ContentGenerator;
|
||||
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
|
||||
|
||||
@@ -20,6 +20,7 @@ import type { UserTierId } from '../code_assist/types.js';
|
||||
import { LoggingContentGenerator } from './loggingContentGenerator.js';
|
||||
import { InstallationManager } from '../utils/installationManager.js';
|
||||
import { FakeContentGenerator } from './fakeContentGenerator.js';
|
||||
import { RecordingContentGenerator } from './recordingContentGenerator.js';
|
||||
|
||||
/**
|
||||
* Interface abstracting the core functionalities for generating content and counting tokens.
|
||||
@@ -106,55 +107,61 @@ export async function createContentGenerator(
|
||||
gcConfig: Config,
|
||||
sessionId?: string,
|
||||
): Promise<ContentGenerator> {
|
||||
if (gcConfig.fakeResponses) {
|
||||
return FakeContentGenerator.fromFile(gcConfig.fakeResponses);
|
||||
}
|
||||
|
||||
const version = process.env['CLI_VERSION'] || process.version;
|
||||
const userAgent = `GeminiCLI/${version} (${process.platform}; ${process.arch})`;
|
||||
const baseHeaders: Record<string, string> = {
|
||||
'User-Agent': userAgent,
|
||||
};
|
||||
|
||||
if (
|
||||
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
|
||||
config.authType === AuthType.CLOUD_SHELL
|
||||
) {
|
||||
const httpOptions = { headers: baseHeaders };
|
||||
return new LoggingContentGenerator(
|
||||
await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
config.authType,
|
||||
gcConfig,
|
||||
sessionId,
|
||||
),
|
||||
gcConfig,
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
config.authType === AuthType.USE_GEMINI ||
|
||||
config.authType === AuthType.USE_VERTEX_AI
|
||||
) {
|
||||
let headers: Record<string, string> = { ...baseHeaders };
|
||||
if (gcConfig?.getUsageStatisticsEnabled()) {
|
||||
const installationManager = new InstallationManager();
|
||||
const installationId = installationManager.getInstallationId();
|
||||
headers = {
|
||||
...headers,
|
||||
'x-gemini-api-privileged-user-id': `${installationId}`,
|
||||
};
|
||||
const generator = await (async () => {
|
||||
if (gcConfig.fakeResponses) {
|
||||
return FakeContentGenerator.fromFile(gcConfig.fakeResponses);
|
||||
}
|
||||
const version = process.env['CLI_VERSION'] || process.version;
|
||||
const userAgent = `GeminiCLI/${version} (${process.platform}; ${process.arch})`;
|
||||
const baseHeaders: Record<string, string> = {
|
||||
'User-Agent': userAgent,
|
||||
};
|
||||
if (
|
||||
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
|
||||
config.authType === AuthType.CLOUD_SHELL
|
||||
) {
|
||||
const httpOptions = { headers: baseHeaders };
|
||||
return new LoggingContentGenerator(
|
||||
await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
config.authType,
|
||||
gcConfig,
|
||||
sessionId,
|
||||
),
|
||||
gcConfig,
|
||||
);
|
||||
}
|
||||
const httpOptions = { headers };
|
||||
|
||||
const googleGenAI = new GoogleGenAI({
|
||||
apiKey: config.apiKey === '' ? undefined : config.apiKey,
|
||||
vertexai: config.vertexai,
|
||||
httpOptions,
|
||||
});
|
||||
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
|
||||
if (
|
||||
config.authType === AuthType.USE_GEMINI ||
|
||||
config.authType === AuthType.USE_VERTEX_AI
|
||||
) {
|
||||
let headers: Record<string, string> = { ...baseHeaders };
|
||||
if (gcConfig?.getUsageStatisticsEnabled()) {
|
||||
const installationManager = new InstallationManager();
|
||||
const installationId = installationManager.getInstallationId();
|
||||
headers = {
|
||||
...headers,
|
||||
'x-gemini-api-privileged-user-id': `${installationId}`,
|
||||
};
|
||||
}
|
||||
const httpOptions = { headers };
|
||||
|
||||
const googleGenAI = new GoogleGenAI({
|
||||
apiKey: config.apiKey === '' ? undefined : config.apiKey,
|
||||
vertexai: config.vertexai,
|
||||
httpOptions,
|
||||
});
|
||||
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
|
||||
}
|
||||
throw new Error(
|
||||
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
|
||||
);
|
||||
})();
|
||||
|
||||
if (gcConfig.recordResponses) {
|
||||
return new RecordingContentGenerator(generator, gcConfig.recordResponses);
|
||||
}
|
||||
throw new Error(
|
||||
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
|
||||
);
|
||||
|
||||
return generator;
|
||||
}
|
||||
|
||||
@@ -5,16 +5,18 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { FakeContentGenerator } from './fakeContentGenerator.js';
|
||||
import {
|
||||
FakeContentGenerator,
|
||||
type FakeResponse,
|
||||
} from './fakeContentGenerator.js';
|
||||
import { promises } from 'node:fs';
|
||||
import type { FakeResponses } from './fakeContentGenerator.js';
|
||||
import type {
|
||||
import {
|
||||
GenerateContentResponse,
|
||||
CountTokensResponse,
|
||||
EmbedContentResponse,
|
||||
GenerateContentParameters,
|
||||
CountTokensParameters,
|
||||
EmbedContentParameters,
|
||||
type CountTokensResponse,
|
||||
type EmbedContentResponse,
|
||||
type GenerateContentParameters,
|
||||
type CountTokensParameters,
|
||||
type EmbedContentParameters,
|
||||
} from '@google/genai';
|
||||
|
||||
vi.mock('node:fs', async (importOriginal) => {
|
||||
@@ -31,32 +33,41 @@ vi.mock('node:fs', async (importOriginal) => {
|
||||
const mockReadFile = vi.mocked(promises.readFile);
|
||||
|
||||
describe('FakeContentGenerator', () => {
|
||||
const fakeResponses: FakeResponses = {
|
||||
generateContent: [
|
||||
const fakeGenerateContentResponse: FakeResponse = {
|
||||
method: 'generateContent',
|
||||
response: {
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'response1' }], role: 'model' } },
|
||||
],
|
||||
} as GenerateContentResponse,
|
||||
};
|
||||
|
||||
const fakeGenerateContentStreamResponse: FakeResponse = {
|
||||
method: 'generateContentStream',
|
||||
response: [
|
||||
{
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'response1' }], role: 'model' } },
|
||||
{ content: { parts: [{ text: 'chunk1' }], role: 'model' } },
|
||||
],
|
||||
},
|
||||
{
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'chunk2' }], role: 'model' } },
|
||||
],
|
||||
},
|
||||
] as GenerateContentResponse[],
|
||||
generateContentStream: [
|
||||
[
|
||||
{
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'chunk1' }], role: 'model' } },
|
||||
],
|
||||
},
|
||||
{
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'chunk2' }], role: 'model' } },
|
||||
],
|
||||
},
|
||||
],
|
||||
] as GenerateContentResponse[][],
|
||||
countTokens: [{ totalTokens: 10 }] as CountTokensResponse[],
|
||||
embedContent: [
|
||||
{ embeddings: [{ values: [1, 2, 3] }] },
|
||||
] as EmbedContentResponse[],
|
||||
};
|
||||
|
||||
const fakeCountTokensResponse: FakeResponse = {
|
||||
method: 'countTokens',
|
||||
response: { totalTokens: 10 } as CountTokensResponse,
|
||||
};
|
||||
|
||||
const fakeEmbedContentResponse: FakeResponse = {
|
||||
method: 'embedContent',
|
||||
response: {
|
||||
embeddings: [{ values: [1, 2, 3] }],
|
||||
} as EmbedContentResponse,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -64,90 +75,86 @@ describe('FakeContentGenerator', () => {
|
||||
});
|
||||
|
||||
it('should return responses for generateContent', async () => {
|
||||
const generator = new FakeContentGenerator(fakeResponses);
|
||||
const generator = new FakeContentGenerator([fakeGenerateContentResponse]);
|
||||
const response = await generator.generateContent(
|
||||
{} as GenerateContentParameters,
|
||||
'id',
|
||||
);
|
||||
expect(response).toEqual(fakeResponses.generateContent[0]);
|
||||
});
|
||||
|
||||
it('should throw error when no more generateContent responses', async () => {
|
||||
const generator = new FakeContentGenerator({
|
||||
...fakeResponses,
|
||||
generateContent: [],
|
||||
});
|
||||
await expect(
|
||||
generator.generateContent({} as GenerateContentParameters, 'id'),
|
||||
).rejects.toThrowError('No more mock responses for generateContent');
|
||||
expect(response).instanceOf(GenerateContentResponse);
|
||||
expect(response).toEqual(fakeGenerateContentResponse.response);
|
||||
});
|
||||
|
||||
it('should return responses for generateContentStream', async () => {
|
||||
const generator = new FakeContentGenerator(fakeResponses);
|
||||
const generator = new FakeContentGenerator([
|
||||
fakeGenerateContentStreamResponse,
|
||||
]);
|
||||
const stream = await generator.generateContentStream(
|
||||
{} as GenerateContentParameters,
|
||||
'id',
|
||||
);
|
||||
const responses = [];
|
||||
for await (const response of stream) {
|
||||
expect(response).instanceOf(GenerateContentResponse);
|
||||
responses.push(response);
|
||||
}
|
||||
expect(responses).toEqual(fakeResponses.generateContentStream[0]);
|
||||
});
|
||||
|
||||
it('should throw error when no more generateContentStream responses', async () => {
|
||||
const generator = new FakeContentGenerator({
|
||||
...fakeResponses,
|
||||
generateContentStream: [],
|
||||
});
|
||||
await expect(
|
||||
generator.generateContentStream({} as GenerateContentParameters, 'id'),
|
||||
).rejects.toThrow('No more mock responses for generateContentStream');
|
||||
expect(responses).toEqual(fakeGenerateContentStreamResponse.response);
|
||||
});
|
||||
|
||||
it('should return responses for countTokens', async () => {
|
||||
const generator = new FakeContentGenerator(fakeResponses);
|
||||
const generator = new FakeContentGenerator([fakeCountTokensResponse]);
|
||||
const response = await generator.countTokens({} as CountTokensParameters);
|
||||
expect(response).toEqual(fakeResponses.countTokens[0]);
|
||||
});
|
||||
|
||||
it('should throw error when no more countTokens responses', async () => {
|
||||
const generator = new FakeContentGenerator({
|
||||
...fakeResponses,
|
||||
countTokens: [],
|
||||
});
|
||||
await expect(
|
||||
generator.countTokens({} as CountTokensParameters),
|
||||
).rejects.toThrowError('No more mock responses for countTokens');
|
||||
expect(response).toEqual(fakeCountTokensResponse.response);
|
||||
});
|
||||
|
||||
it('should return responses for embedContent', async () => {
|
||||
const generator = new FakeContentGenerator(fakeResponses);
|
||||
const generator = new FakeContentGenerator([fakeEmbedContentResponse]);
|
||||
const response = await generator.embedContent({} as EmbedContentParameters);
|
||||
expect(response).toEqual(fakeResponses.embedContent[0]);
|
||||
expect(response).toEqual(fakeEmbedContentResponse.response);
|
||||
});
|
||||
|
||||
it('should throw error when no more embedContent responses', async () => {
|
||||
const generator = new FakeContentGenerator({
|
||||
...fakeResponses,
|
||||
embedContent: [],
|
||||
});
|
||||
it('should handle a mixture of calls', async () => {
|
||||
const fakeResponses = [
|
||||
fakeGenerateContentResponse,
|
||||
fakeGenerateContentStreamResponse,
|
||||
fakeCountTokensResponse,
|
||||
fakeEmbedContentResponse,
|
||||
];
|
||||
const generator = new FakeContentGenerator(fakeResponses);
|
||||
for (const fakeResponse of fakeResponses) {
|
||||
const response = await generator[fakeResponse.method]({} as never, '');
|
||||
if (fakeResponse.method === 'generateContentStream') {
|
||||
const responses = [];
|
||||
for await (const item of response as AsyncGenerator<GenerateContentResponse>) {
|
||||
expect(item).instanceOf(GenerateContentResponse);
|
||||
responses.push(item);
|
||||
}
|
||||
expect(responses).toEqual(fakeResponse.response);
|
||||
} else {
|
||||
expect(response).toEqual(fakeResponse.response);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it('should throw error when no more responses', async () => {
|
||||
const generator = new FakeContentGenerator([fakeGenerateContentResponse]);
|
||||
await generator.generateContent({} as GenerateContentParameters, 'id');
|
||||
await expect(
|
||||
generator.embedContent({} as EmbedContentParameters),
|
||||
).rejects.toThrowError('No more mock responses for embedContent');
|
||||
});
|
||||
|
||||
it('should handle multiple calls and exhaust responses', async () => {
|
||||
const generator = new FakeContentGenerator(fakeResponses);
|
||||
await generator.generateContent({} as GenerateContentParameters, 'id');
|
||||
await expect(
|
||||
generator.countTokens({} as CountTokensParameters),
|
||||
).rejects.toThrowError('No more mock responses for countTokens');
|
||||
await expect(
|
||||
generator.generateContentStream({} as GenerateContentParameters, 'id'),
|
||||
).rejects.toThrow('No more mock responses for generateContentStream');
|
||||
await expect(
|
||||
generator.generateContent({} as GenerateContentParameters, 'id'),
|
||||
).rejects.toThrow();
|
||||
).rejects.toThrowError('No more mock responses for generateContent');
|
||||
});
|
||||
|
||||
describe('fromFile', () => {
|
||||
it('should create a generator from a file', async () => {
|
||||
const fileContent = JSON.stringify(fakeResponses);
|
||||
const fileContent = JSON.stringify(fakeGenerateContentResponse) + '\n';
|
||||
mockReadFile.mockResolvedValue(fileContent);
|
||||
|
||||
const generator = await FakeContentGenerator.fromFile('fake-path.json');
|
||||
@@ -155,51 +162,7 @@ describe('FakeContentGenerator', () => {
|
||||
{} as GenerateContentParameters,
|
||||
'id',
|
||||
);
|
||||
expect(response).toEqual(fakeResponses.generateContent[0]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('constructor with partial responses', () => {
|
||||
it('should handle missing generateContent', async () => {
|
||||
const responses = { ...fakeResponses, generateContent: undefined };
|
||||
const generator = new FakeContentGenerator(
|
||||
responses as unknown as FakeResponses,
|
||||
);
|
||||
await expect(
|
||||
generator.generateContent({} as GenerateContentParameters, 'id'),
|
||||
).rejects.toThrowError('No more mock responses for generateContent');
|
||||
});
|
||||
|
||||
it('should handle missing generateContentStream', async () => {
|
||||
const responses = { ...fakeResponses, generateContentStream: undefined };
|
||||
const generator = new FakeContentGenerator(
|
||||
responses as unknown as FakeResponses,
|
||||
);
|
||||
await expect(
|
||||
generator.generateContentStream({} as GenerateContentParameters, 'id'),
|
||||
).rejects.toThrowError(
|
||||
'No more mock responses for generateContentStream',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle missing countTokens', async () => {
|
||||
const responses = { ...fakeResponses, countTokens: undefined };
|
||||
const generator = new FakeContentGenerator(
|
||||
responses as unknown as FakeResponses,
|
||||
);
|
||||
await expect(
|
||||
generator.countTokens({} as CountTokensParameters),
|
||||
).rejects.toThrowError('No more mock responses for countTokens');
|
||||
});
|
||||
|
||||
it('should handle missing embedContent', async () => {
|
||||
const responses = { ...fakeResponses, embedContent: undefined };
|
||||
const generator = new FakeContentGenerator(
|
||||
responses as unknown as FakeResponses,
|
||||
);
|
||||
await expect(
|
||||
generator.embedContent({} as EmbedContentParameters),
|
||||
).rejects.toThrowError('No more mock responses for embedContent');
|
||||
expect(response).toEqual(fakeGenerateContentResponse.response);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,98 +4,113 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
CountTokensResponse,
|
||||
import {
|
||||
GenerateContentResponse,
|
||||
GenerateContentParameters,
|
||||
CountTokensParameters,
|
||||
type CountTokensResponse,
|
||||
type GenerateContentParameters,
|
||||
type CountTokensParameters,
|
||||
EmbedContentResponse,
|
||||
EmbedContentParameters,
|
||||
type EmbedContentParameters,
|
||||
} from '@google/genai';
|
||||
import { promises } from 'node:fs';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import type { UserTierId } from '../code_assist/types.js';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
|
||||
export type FakeResponses = {
|
||||
generateContent: GenerateContentResponse[];
|
||||
generateContentStream: GenerateContentResponse[][];
|
||||
countTokens: CountTokensResponse[];
|
||||
embedContent: EmbedContentResponse[];
|
||||
};
|
||||
export type FakeResponse =
|
||||
| {
|
||||
method: 'generateContent';
|
||||
response: GenerateContentResponse;
|
||||
}
|
||||
| {
|
||||
method: 'generateContentStream';
|
||||
response: GenerateContentResponse[];
|
||||
}
|
||||
| {
|
||||
method: 'countTokens';
|
||||
response: CountTokensResponse;
|
||||
}
|
||||
| {
|
||||
method: 'embedContent';
|
||||
response: EmbedContentResponse;
|
||||
};
|
||||
|
||||
// A ContentGenerator that responds with canned responses.
|
||||
//
|
||||
// Typically these would come from a file, provided by the `--fake-responses`
|
||||
// CLI argument.
|
||||
export class FakeContentGenerator implements ContentGenerator {
|
||||
private responses: FakeResponses;
|
||||
private callCounters = {
|
||||
generateContent: 0,
|
||||
generateContentStream: 0,
|
||||
countTokens: 0,
|
||||
embedContent: 0,
|
||||
};
|
||||
private callCounter = 0;
|
||||
userTier?: UserTierId;
|
||||
|
||||
constructor(responses: FakeResponses) {
|
||||
this.responses = {
|
||||
generateContent: responses.generateContent ?? [],
|
||||
generateContentStream: responses.generateContentStream ?? [],
|
||||
countTokens: responses.countTokens ?? [],
|
||||
embedContent: responses.embedContent ?? [],
|
||||
};
|
||||
}
|
||||
constructor(private readonly responses: FakeResponse[]) {}
|
||||
|
||||
static async fromFile(filePath: string): Promise<FakeContentGenerator> {
|
||||
const fileContent = await promises.readFile(filePath, 'utf-8');
|
||||
const responses = JSON.parse(fileContent) as FakeResponses;
|
||||
const responses = fileContent
|
||||
.split('\n')
|
||||
.filter((line) => line.trim() !== '')
|
||||
.map((line) => JSON.parse(line) as FakeResponse);
|
||||
return new FakeContentGenerator(responses);
|
||||
}
|
||||
|
||||
private getNextResponse<K extends keyof FakeResponses>(
|
||||
method: K,
|
||||
request: unknown,
|
||||
): FakeResponses[K][number] {
|
||||
const response = this.responses[method][this.callCounters[method]++];
|
||||
private getNextResponse<
|
||||
M extends FakeResponse['method'],
|
||||
R = Extract<FakeResponse, { method: M }>['response'],
|
||||
>(method: M, request: unknown): R {
|
||||
const response = this.responses[this.callCounter++];
|
||||
if (!response) {
|
||||
throw new Error(
|
||||
`No more mock responses for ${method}, got request:\n` +
|
||||
safeJsonStringify(request),
|
||||
);
|
||||
}
|
||||
return response;
|
||||
if (response.method !== method) {
|
||||
throw new Error(
|
||||
`Unexpected response type, next response was for ${response.method} but expected ${method}`,
|
||||
);
|
||||
}
|
||||
return response.response as R;
|
||||
}
|
||||
|
||||
async generateContent(
|
||||
_request: GenerateContentParameters,
|
||||
request: GenerateContentParameters,
|
||||
_userPromptId: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
return this.getNextResponse('generateContent', _request);
|
||||
return Object.setPrototypeOf(
|
||||
this.getNextResponse('generateContent', request),
|
||||
GenerateContentResponse.prototype,
|
||||
);
|
||||
}
|
||||
|
||||
async generateContentStream(
|
||||
_request: GenerateContentParameters,
|
||||
request: GenerateContentParameters,
|
||||
_userPromptId: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const responses = this.getNextResponse('generateContentStream', _request);
|
||||
const responses = this.getNextResponse('generateContentStream', request);
|
||||
async function* stream() {
|
||||
for (const response of responses) {
|
||||
yield response;
|
||||
yield Object.setPrototypeOf(
|
||||
response,
|
||||
GenerateContentResponse.prototype,
|
||||
);
|
||||
}
|
||||
}
|
||||
return stream();
|
||||
}
|
||||
|
||||
async countTokens(
|
||||
_request: CountTokensParameters,
|
||||
request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
return this.getNextResponse('countTokens', _request);
|
||||
return this.getNextResponse('countTokens', request);
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
_request: EmbedContentParameters,
|
||||
request: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
return this.getNextResponse('embedContent', _request);
|
||||
return Object.setPrototypeOf(
|
||||
this.getNextResponse('embedContent', request),
|
||||
EmbedContentResponse.prototype,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
151
packages/core/src/core/recordingContentGenerator.test.ts
Normal file
151
packages/core/src/core/recordingContentGenerator.test.ts
Normal file
@@ -0,0 +1,151 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
GenerateContentResponse,
|
||||
CountTokensResponse,
|
||||
EmbedContentResponse,
|
||||
GenerateContentParameters,
|
||||
CountTokensParameters,
|
||||
EmbedContentParameters,
|
||||
ContentEmbedding,
|
||||
} from '@google/genai';
|
||||
import { appendFileSync } from 'node:fs';
|
||||
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { RecordingContentGenerator } from './recordingContentGenerator.js';
|
||||
|
||||
vi.mock('node:fs', () => ({
|
||||
appendFileSync: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('RecordingContentGenerator', () => {
|
||||
let mockRealGenerator: ContentGenerator;
|
||||
let recorder: RecordingContentGenerator;
|
||||
const filePath = '/test/file/responses.json';
|
||||
|
||||
beforeEach(() => {
|
||||
mockRealGenerator = {
|
||||
generateContent: vi.fn(),
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
};
|
||||
recorder = new RecordingContentGenerator(mockRealGenerator, filePath);
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should record generateContent responses', async () => {
|
||||
const mockResponse = {
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'response' }], role: 'model' } },
|
||||
],
|
||||
usageMetadata: { totalTokenCount: 10 },
|
||||
} as GenerateContentResponse;
|
||||
(mockRealGenerator.generateContent as Mock).mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await recorder.generateContent(
|
||||
{} as GenerateContentParameters,
|
||||
'id1',
|
||||
);
|
||||
expect(response).toEqual(mockResponse);
|
||||
expect(mockRealGenerator.generateContent).toHaveBeenCalledWith({}, 'id1');
|
||||
|
||||
expect(appendFileSync).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
safeJsonStringify({
|
||||
method: 'generateContent',
|
||||
response: mockResponse,
|
||||
}) + '\n',
|
||||
);
|
||||
});
|
||||
|
||||
it('should record generateContentStream responses', async () => {
|
||||
const mockResponse1 = {
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'response1' }], role: 'model' } },
|
||||
],
|
||||
usageMetadata: { totalTokenCount: 10 },
|
||||
} as GenerateContentResponse;
|
||||
const mockResponse2 = {
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'response2' }], role: 'model' } },
|
||||
],
|
||||
usageMetadata: { totalTokenCount: 20 },
|
||||
} as GenerateContentResponse;
|
||||
|
||||
async function* mockStream() {
|
||||
yield mockResponse1;
|
||||
yield mockResponse2;
|
||||
}
|
||||
|
||||
(mockRealGenerator.generateContentStream as Mock).mockResolvedValue(
|
||||
mockStream(),
|
||||
);
|
||||
|
||||
const stream = await recorder.generateContentStream(
|
||||
{} as GenerateContentParameters,
|
||||
'id1',
|
||||
);
|
||||
const responses = [];
|
||||
for await (const response of stream) {
|
||||
responses.push(response);
|
||||
}
|
||||
|
||||
expect(responses).toEqual([mockResponse1, mockResponse2]);
|
||||
expect(mockRealGenerator.generateContentStream).toHaveBeenCalledWith(
|
||||
{},
|
||||
'id1',
|
||||
);
|
||||
|
||||
expect(appendFileSync).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
safeJsonStringify({
|
||||
method: 'generateContentStream',
|
||||
response: responses,
|
||||
}) + '\n',
|
||||
);
|
||||
});
|
||||
|
||||
it('should record countTokens responses', async () => {
|
||||
const mockResponse = {
|
||||
totalTokens: 100,
|
||||
cachedContentTokenCount: 10,
|
||||
} as CountTokensResponse;
|
||||
(mockRealGenerator.countTokens as Mock).mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await recorder.countTokens({} as CountTokensParameters);
|
||||
expect(response).toEqual(mockResponse);
|
||||
expect(mockRealGenerator.countTokens).toHaveBeenCalledWith({});
|
||||
|
||||
expect(appendFileSync).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
safeJsonStringify({
|
||||
method: 'countTokens',
|
||||
response: mockResponse,
|
||||
}) + '\n',
|
||||
);
|
||||
});
|
||||
|
||||
it('should record embedContent responses', async () => {
|
||||
const mockResponse = {
|
||||
embeddings: [{ values: [1, 2, 3] } as ContentEmbedding],
|
||||
} as EmbedContentResponse;
|
||||
(mockRealGenerator.embedContent as Mock).mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await recorder.embedContent({} as EmbedContentParameters);
|
||||
expect(response).toEqual(mockResponse);
|
||||
expect(mockRealGenerator.embedContent).toHaveBeenCalledWith({});
|
||||
expect(appendFileSync).toHaveBeenCalledWith(
|
||||
filePath,
|
||||
safeJsonStringify({
|
||||
method: 'embedContent',
|
||||
response: mockResponse,
|
||||
}) + '\n',
|
||||
);
|
||||
});
|
||||
});
|
||||
112
packages/core/src/core/recordingContentGenerator.ts
Normal file
112
packages/core/src/core/recordingContentGenerator.ts
Normal file
@@ -0,0 +1,112 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
CountTokensResponse,
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
CountTokensParameters,
|
||||
EmbedContentResponse,
|
||||
EmbedContentParameters,
|
||||
} from '@google/genai';
|
||||
import { appendFileSync } from 'node:fs';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import type { FakeResponse } from './fakeContentGenerator.js';
|
||||
import type { UserTierId } from '../code_assist/types.js';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
|
||||
// A ContentGenerator that wraps another content generator and records all the
|
||||
// responses, with the ability to write them out to a file. These files are
|
||||
// intended to be consumed later on by a FakeContentGenerator, given the
|
||||
// `--fake-responses` CLI argument.
|
||||
//
|
||||
// Note that only the "interesting" bits of the responses are actually kept.
|
||||
export class RecordingContentGenerator implements ContentGenerator {
|
||||
userTier?: UserTierId;
|
||||
|
||||
constructor(
|
||||
private readonly realGenerator: ContentGenerator,
|
||||
private readonly filePath: string,
|
||||
) {}
|
||||
|
||||
async generateContent(
|
||||
request: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
const response = await this.realGenerator.generateContent(
|
||||
request,
|
||||
userPromptId,
|
||||
);
|
||||
const recordedResponse: FakeResponse = {
|
||||
method: 'generateContent',
|
||||
response: {
|
||||
candidates: response.candidates,
|
||||
usageMetadata: response.usageMetadata,
|
||||
} as GenerateContentResponse,
|
||||
};
|
||||
appendFileSync(this.filePath, `${safeJsonStringify(recordedResponse)}\n`);
|
||||
return response;
|
||||
}
|
||||
|
||||
async generateContentStream(
|
||||
request: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const recordedResponse: FakeResponse = {
|
||||
method: 'generateContentStream',
|
||||
response: [],
|
||||
};
|
||||
|
||||
const realResponses = await this.realGenerator.generateContentStream(
|
||||
request,
|
||||
userPromptId,
|
||||
);
|
||||
|
||||
async function* stream(filePath: string) {
|
||||
for await (const response of realResponses) {
|
||||
(recordedResponse.response as GenerateContentResponse[]).push({
|
||||
candidates: response.candidates,
|
||||
usageMetadata: response.usageMetadata,
|
||||
} as GenerateContentResponse);
|
||||
yield response;
|
||||
}
|
||||
appendFileSync(filePath, `${safeJsonStringify(recordedResponse)}\n`);
|
||||
}
|
||||
|
||||
return Promise.resolve(stream(this.filePath));
|
||||
}
|
||||
|
||||
async countTokens(
|
||||
request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
const response = await this.realGenerator.countTokens(request);
|
||||
const recordedResponse: FakeResponse = {
|
||||
method: 'countTokens',
|
||||
response: {
|
||||
totalTokens: response.totalTokens,
|
||||
cachedContentTokenCount: response.cachedContentTokenCount,
|
||||
},
|
||||
};
|
||||
appendFileSync(this.filePath, `${safeJsonStringify(recordedResponse)}\n`);
|
||||
return response;
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
request: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
const response = await this.realGenerator.embedContent(request);
|
||||
|
||||
const recordedResponse: FakeResponse = {
|
||||
method: 'embedContent',
|
||||
response: {
|
||||
embeddings: response.embeddings,
|
||||
metadata: response.metadata,
|
||||
},
|
||||
};
|
||||
appendFileSync(this.filePath, `${safeJsonStringify(recordedResponse)}\n`);
|
||||
return response;
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,7 @@ export * from './core/turn.js';
|
||||
export * from './core/geminiRequest.js';
|
||||
export * from './core/coreToolScheduler.js';
|
||||
export * from './core/nonInteractiveToolExecutor.js';
|
||||
export * from './core/recordingContentGenerator.js';
|
||||
|
||||
export * from './fallback/types.js';
|
||||
|
||||
|
||||
Reference in New Issue
Block a user