mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-21 10:34:35 -07:00
feat(cli, core): Support hybrid evals.
This commit is contained in:
@@ -16,6 +16,7 @@ import {
|
||||
type ContentGenerator,
|
||||
type ContentGeneratorConfig,
|
||||
} from '../core/contentGenerator.js';
|
||||
import type { ScriptItem } from '../core/scriptUtils.js';
|
||||
import type { OverageStrategy } from '../billing/billing.js';
|
||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { ResourceRegistry } from '../resources/resource-registry.js';
|
||||
@@ -513,6 +514,11 @@ export interface PolicyUpdateConfirmationRequest {
|
||||
newHash: string;
|
||||
}
|
||||
|
||||
export interface FakeModelConfig {
|
||||
responses: string | ScriptItem[];
|
||||
hybridHandoff?: boolean;
|
||||
}
|
||||
|
||||
export interface ConfigParameters {
|
||||
sessionId: string;
|
||||
clientName?: string;
|
||||
@@ -537,6 +543,7 @@ export interface ConfigParameters {
|
||||
mcpEnablementCallbacks?: McpEnablementCallbacks;
|
||||
userMemory?: string | HierarchicalMemory;
|
||||
geminiMdFileCount?: number;
|
||||
contentGenerator?: ContentGenerator;
|
||||
geminiMdFilePaths?: string[];
|
||||
approvalMode?: ApprovalMode;
|
||||
showMemoryUsage?: boolean;
|
||||
@@ -608,7 +615,8 @@ export interface ConfigParameters {
|
||||
maxAttempts?: number;
|
||||
enableShellOutputEfficiency?: boolean;
|
||||
shellToolInactivityTimeout?: number;
|
||||
fakeResponses?: string;
|
||||
fakeModelConfig?: FakeModelConfig;
|
||||
fakeResponses?: string | ScriptItem[];
|
||||
recordResponses?: string;
|
||||
ptyInfo?: string;
|
||||
disableYoloMode?: boolean;
|
||||
@@ -814,7 +822,8 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
private readonly maxAttempts: number;
|
||||
private readonly enableShellOutputEfficiency: boolean;
|
||||
private readonly shellToolInactivityTimeout: number;
|
||||
readonly fakeResponses?: string;
|
||||
readonly fakeModelConfig?: FakeModelConfig;
|
||||
private readonly hasCustomContentGenerator: boolean;
|
||||
readonly recordResponses?: string;
|
||||
private readonly disableYoloMode: boolean;
|
||||
private readonly disableAlwaysAllow: boolean;
|
||||
@@ -896,6 +905,10 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
this.pendingIncludeDirectories = params.includeDirectories ?? [];
|
||||
this.debugMode = params.debugMode;
|
||||
this.question = params.question;
|
||||
this.hasCustomContentGenerator = !!params.contentGenerator;
|
||||
if (params.contentGenerator) {
|
||||
this.contentGenerator = params.contentGenerator;
|
||||
}
|
||||
|
||||
this.coreTools = params.coreTools;
|
||||
this.mainAgentTools = params.mainAgentTools;
|
||||
@@ -1093,7 +1106,14 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
this.storage = new Storage(this.targetDir, this._sessionId);
|
||||
this.storage.setCustomPlansDir(params.planSettings?.directory);
|
||||
|
||||
this.fakeResponses = params.fakeResponses;
|
||||
if (params.fakeModelConfig) {
|
||||
this.fakeModelConfig = params.fakeModelConfig;
|
||||
} else if (params.fakeResponses) {
|
||||
this.fakeModelConfig = {
|
||||
responses: params.fakeResponses,
|
||||
};
|
||||
}
|
||||
|
||||
this.recordResponses = params.recordResponses;
|
||||
this.fileExclusions = new FileExclusions(this);
|
||||
this.eventEmitter = params.eventEmitter;
|
||||
@@ -1198,6 +1218,10 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
this.modelRouterService = new ModelRouterService(this);
|
||||
}
|
||||
|
||||
get fakeResponses(): string | ScriptItem[] | undefined {
|
||||
return this.fakeModelConfig?.responses;
|
||||
}
|
||||
|
||||
get config(): Config {
|
||||
return this;
|
||||
}
|
||||
@@ -1359,11 +1383,13 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
baseUrl,
|
||||
customHeaders,
|
||||
);
|
||||
this.contentGenerator = await createContentGenerator(
|
||||
newContentGeneratorConfig,
|
||||
this,
|
||||
this.getSessionId(),
|
||||
);
|
||||
if (!this.hasCustomContentGenerator) {
|
||||
this.contentGenerator = await createContentGenerator(
|
||||
newContentGeneratorConfig,
|
||||
this,
|
||||
this.getSessionId(),
|
||||
);
|
||||
}
|
||||
// Only assign to instance properties after successful initialization
|
||||
this.contentGeneratorConfig = newContentGeneratorConfig;
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ describe('createContentGenerator', () => {
|
||||
);
|
||||
const fakeResponsesFile = 'fake/responses.yaml';
|
||||
const mockConfigWithFake = {
|
||||
...mockConfig,
|
||||
fakeResponses: fakeResponsesFile,
|
||||
getClientName: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as Config;
|
||||
@@ -74,6 +75,7 @@ describe('createContentGenerator', () => {
|
||||
const fakeResponsesFile = 'fake/responses.yaml';
|
||||
const recordResponsesFile = 'record/responses.yaml';
|
||||
const mockConfigWithRecordResponses = {
|
||||
...mockConfig,
|
||||
fakeResponses: fakeResponsesFile,
|
||||
recordResponses: recordResponsesFile,
|
||||
getClientName: vi.fn().mockReturnValue(undefined),
|
||||
|
||||
@@ -22,9 +22,14 @@ import { LoggingContentGenerator } from './loggingContentGenerator.js';
|
||||
import { InstallationManager } from '../utils/installationManager.js';
|
||||
import { FakeContentGenerator } from './fakeContentGenerator.js';
|
||||
import { parseCustomHeaders } from '../utils/customHeaderUtils.js';
|
||||
import { extractFakeResponses } from './scriptUtils.js';
|
||||
import { determineSurface } from '../utils/surface.js';
|
||||
import { RecordingContentGenerator } from './recordingContentGenerator.js';
|
||||
import { getVersion, resolveModel } from '../../index.js';
|
||||
import {
|
||||
FallbackContentGenerator,
|
||||
getVersion,
|
||||
resolveModel,
|
||||
} from '../../index.js';
|
||||
import type { LlmRole } from '../telemetry/llmRole.js';
|
||||
|
||||
/**
|
||||
@@ -47,6 +52,8 @@ export interface ContentGenerator {
|
||||
|
||||
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
|
||||
|
||||
getSentRequests?(): GenerateContentParameters[];
|
||||
|
||||
userTier?: UserTierId;
|
||||
|
||||
userTierName?: string;
|
||||
@@ -166,11 +173,15 @@ export async function createContentGenerator(
|
||||
sessionId?: string,
|
||||
): Promise<ContentGenerator> {
|
||||
const generator = await (async () => {
|
||||
let fakeGenerator: FakeContentGenerator | undefined;
|
||||
if (gcConfig.fakeResponses) {
|
||||
const fakeGenerator = await FakeContentGenerator.fromFile(
|
||||
gcConfig.fakeResponses,
|
||||
);
|
||||
return new LoggingContentGenerator(fakeGenerator, gcConfig);
|
||||
fakeGenerator = Array.isArray(gcConfig.fakeResponses)
|
||||
? new FakeContentGenerator(extractFakeResponses(gcConfig.fakeResponses))
|
||||
: await FakeContentGenerator.fromFile(gcConfig.fakeResponses);
|
||||
|
||||
if (!gcConfig.fakeModelConfig?.hybridHandoff) {
|
||||
return new LoggingContentGenerator(fakeGenerator, gcConfig);
|
||||
}
|
||||
}
|
||||
const version = await getVersion();
|
||||
const model = resolveModel(
|
||||
@@ -208,23 +219,21 @@ export async function createContentGenerator(
|
||||
) {
|
||||
baseHeaders['Authorization'] = `Bearer ${config.apiKey}`;
|
||||
}
|
||||
|
||||
let realGenerator: ContentGenerator;
|
||||
|
||||
if (
|
||||
config.authType === AuthType.LOGIN_WITH_GOOGLE ||
|
||||
config.authType === AuthType.COMPUTE_ADC
|
||||
) {
|
||||
const httpOptions = { headers: baseHeaders };
|
||||
return new LoggingContentGenerator(
|
||||
await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
config.authType,
|
||||
gcConfig,
|
||||
sessionId,
|
||||
),
|
||||
realGenerator = await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
config.authType,
|
||||
gcConfig,
|
||||
sessionId,
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
} else if (
|
||||
config.authType === AuthType.USE_GEMINI ||
|
||||
config.authType === AuthType.USE_VERTEX_AI ||
|
||||
config.authType === AuthType.GATEWAY
|
||||
@@ -268,11 +277,21 @@ export async function createContentGenerator(
|
||||
httpOptions,
|
||||
...(apiVersionEnv && { apiVersion: apiVersionEnv }),
|
||||
});
|
||||
return new LoggingContentGenerator(googleGenAI.models, gcConfig);
|
||||
realGenerator = googleGenAI.models;
|
||||
} else {
|
||||
throw new Error(
|
||||
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
|
||||
);
|
||||
}
|
||||
throw new Error(
|
||||
`Error creating contentGenerator: Unsupported authType: ${config.authType}`,
|
||||
);
|
||||
|
||||
if (fakeGenerator && gcConfig.fakeModelConfig?.hybridHandoff) {
|
||||
realGenerator = new FallbackContentGenerator(
|
||||
fakeGenerator,
|
||||
realGenerator,
|
||||
);
|
||||
}
|
||||
|
||||
return new LoggingContentGenerator(realGenerator, gcConfig);
|
||||
})();
|
||||
|
||||
if (gcConfig.recordResponses) {
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import {
|
||||
FakeContentGenerator,
|
||||
MockExhaustedError,
|
||||
type FakeResponse,
|
||||
} from './fakeContentGenerator.js';
|
||||
import { promises } from 'node:fs';
|
||||
@@ -142,7 +143,7 @@ describe('FakeContentGenerator', () => {
|
||||
}
|
||||
});
|
||||
|
||||
it('should throw error when no more responses', async () => {
|
||||
it('should throw MockExhaustedError when no more responses', async () => {
|
||||
const generator = new FakeContentGenerator([fakeGenerateContentResponse]);
|
||||
await generator.generateContent(
|
||||
{} as GenerateContentParameters,
|
||||
@@ -151,24 +152,45 @@ describe('FakeContentGenerator', () => {
|
||||
);
|
||||
await expect(
|
||||
generator.embedContent({} as EmbedContentParameters),
|
||||
).rejects.toThrowError('No more mock responses for embedContent');
|
||||
).rejects.toThrow(MockExhaustedError);
|
||||
await expect(
|
||||
generator.countTokens({} as CountTokensParameters),
|
||||
).rejects.toThrowError('No more mock responses for countTokens');
|
||||
).rejects.toThrow(MockExhaustedError);
|
||||
await expect(
|
||||
generator.generateContentStream(
|
||||
{} as GenerateContentParameters,
|
||||
'id',
|
||||
LlmRole.MAIN,
|
||||
),
|
||||
).rejects.toThrow('No more mock responses for generateContentStream');
|
||||
).rejects.toThrow(MockExhaustedError);
|
||||
await expect(
|
||||
generator.generateContent(
|
||||
{} as GenerateContentParameters,
|
||||
'id',
|
||||
LlmRole.MAIN,
|
||||
),
|
||||
).rejects.toThrowError('No more mock responses for generateContent');
|
||||
).rejects.toThrow(MockExhaustedError);
|
||||
});
|
||||
|
||||
it('should track sent requests via getSentRequests', async () => {
|
||||
const generator = new FakeContentGenerator([
|
||||
fakeGenerateContentResponse,
|
||||
fakeGenerateContentStreamResponse,
|
||||
]);
|
||||
const req1 = {
|
||||
contents: [{ role: 'user', parts: [{ text: 'a' }] }],
|
||||
} as GenerateContentParameters;
|
||||
const req2 = {
|
||||
contents: [{ role: 'user', parts: [{ text: 'b' }] }],
|
||||
} as GenerateContentParameters;
|
||||
|
||||
await generator.generateContent(req1, 'id1', LlmRole.MAIN);
|
||||
await generator.generateContentStream(req2, 'id2', LlmRole.MAIN);
|
||||
|
||||
const sent = generator.getSentRequests();
|
||||
expect(sent).toHaveLength(2);
|
||||
expect(sent[0]).toBe(req1);
|
||||
expect(sent[1]).toBe(req2);
|
||||
});
|
||||
|
||||
describe('fromFile', () => {
|
||||
|
||||
@@ -18,6 +18,16 @@ import type { UserTierId, GeminiUserTier } from '../code_assist/types.js';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
import type { LlmRole } from '../telemetry/types.js';
|
||||
|
||||
export class MockExhaustedError extends Error {
|
||||
constructor(method: string, request?: unknown) {
|
||||
super(
|
||||
`No more mock responses for ${method}, got request:\n` +
|
||||
safeJsonStringify(request),
|
||||
);
|
||||
this.name = 'MockExhaustedError';
|
||||
}
|
||||
}
|
||||
|
||||
export type FakeResponse =
|
||||
| {
|
||||
method: 'generateContent';
|
||||
@@ -42,12 +52,17 @@ export type FakeResponse =
|
||||
// CLI argument.
|
||||
export class FakeContentGenerator implements ContentGenerator {
|
||||
private callCounter = 0;
|
||||
private sentRequests: GenerateContentParameters[] = [];
|
||||
userTier?: UserTierId;
|
||||
userTierName?: string;
|
||||
paidTier?: GeminiUserTier;
|
||||
|
||||
constructor(private readonly responses: FakeResponse[]) {}
|
||||
|
||||
getSentRequests(): GenerateContentParameters[] {
|
||||
return this.sentRequests;
|
||||
}
|
||||
|
||||
static async fromFile(filePath: string): Promise<FakeContentGenerator> {
|
||||
const fileContent = await promises.readFile(filePath, 'utf-8');
|
||||
const responses = fileContent
|
||||
@@ -62,13 +77,14 @@ export class FakeContentGenerator implements ContentGenerator {
|
||||
M extends FakeResponse['method'],
|
||||
R = Extract<FakeResponse, { method: M }>['response'],
|
||||
>(method: M, request: unknown): R {
|
||||
const response = this.responses[this.callCounter++];
|
||||
const response = this.responses[this.callCounter];
|
||||
if (!response) {
|
||||
throw new Error(
|
||||
`No more mock responses for ${method}, got request:\n` +
|
||||
safeJsonStringify(request),
|
||||
);
|
||||
throw new MockExhaustedError(method, request);
|
||||
}
|
||||
|
||||
// We only increment the counter if we actually consume a mock response
|
||||
this.callCounter++;
|
||||
|
||||
if (response.method !== method) {
|
||||
throw new Error(
|
||||
`Unexpected response type, next response was for ${response.method} but expected ${method}`,
|
||||
@@ -81,24 +97,27 @@ export class FakeContentGenerator implements ContentGenerator {
|
||||
async generateContent(
|
||||
request: GenerateContentParameters,
|
||||
_userPromptId: string,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
role: LlmRole,
|
||||
_role: LlmRole,
|
||||
): Promise<GenerateContentResponse> {
|
||||
this.sentRequests.push(request);
|
||||
const next = this.getNextResponse('generateContent', request);
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
|
||||
return Object.setPrototypeOf(
|
||||
this.getNextResponse('generateContent', request),
|
||||
GenerateContentResponse.prototype,
|
||||
);
|
||||
return Object.setPrototypeOf(next, GenerateContentResponse.prototype);
|
||||
}
|
||||
|
||||
async generateContentStream(
|
||||
request: GenerateContentParameters,
|
||||
_userPromptId: string,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
role: LlmRole,
|
||||
_role: LlmRole,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
this.sentRequests.push(request);
|
||||
const responses = this.getNextResponse('generateContentStream', request);
|
||||
|
||||
async function* stream() {
|
||||
// Add a tiny delay to ensure React has time to render the 'Responding'
|
||||
// state. If the mock stream finishes synchronously, AppRig's
|
||||
// awaitingResponse flag may never be cleared, causing the rig to hang.
|
||||
await new Promise((resolve) => setTimeout(resolve, 5));
|
||||
for (const response of responses) {
|
||||
yield Object.setPrototypeOf(
|
||||
response,
|
||||
@@ -112,16 +131,15 @@ export class FakeContentGenerator implements ContentGenerator {
|
||||
async countTokens(
|
||||
request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
return this.getNextResponse('countTokens', request);
|
||||
const next = this.getNextResponse('countTokens', request);
|
||||
return next;
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
request: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
const next = this.getNextResponse('embedContent', request);
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
|
||||
return Object.setPrototypeOf(
|
||||
this.getNextResponse('embedContent', request),
|
||||
EmbedContentResponse.prototype,
|
||||
);
|
||||
return Object.setPrototypeOf(next, EmbedContentResponse.prototype);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { FallbackContentGenerator } from './fallbackContentGenerator.js';
|
||||
import { MockExhaustedError } from './fakeContentGenerator.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import type { GenerateContentParameters } from '@google/genai';
|
||||
import { LlmRole } from '../telemetry/types.js';
|
||||
|
||||
describe('FallbackContentGenerator', () => {
|
||||
const dummyRequest: GenerateContentParameters = {
|
||||
model: 'gemini',
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||
};
|
||||
|
||||
it('delegates to the primary generator if successful', async () => {
|
||||
const mockPrimary = {
|
||||
generateContent: vi.fn().mockResolvedValue({ text: 'primary response' }),
|
||||
} as unknown as ContentGenerator;
|
||||
const mockFallback = {
|
||||
generateContent: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new FallbackContentGenerator(mockPrimary, mockFallback);
|
||||
const result = await generator.generateContent(
|
||||
dummyRequest,
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
|
||||
expect(result).toEqual({ text: 'primary response' });
|
||||
expect(mockPrimary.generateContent).toHaveBeenCalledWith(
|
||||
dummyRequest,
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
expect(mockFallback.generateContent).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('bubbles up regular errors from the primary generator', async () => {
|
||||
const mockPrimary = {
|
||||
generateContent: vi.fn().mockRejectedValue(new Error('Network failure')),
|
||||
} as unknown as ContentGenerator;
|
||||
const mockFallback = {
|
||||
generateContent: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new FallbackContentGenerator(mockPrimary, mockFallback);
|
||||
await expect(
|
||||
generator.generateContent(dummyRequest, 'prompt-id', LlmRole.MAIN),
|
||||
).rejects.toThrow('Network failure');
|
||||
expect(mockFallback.generateContent).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('falls back to the secondary generator if primary throws MockExhaustedError', async () => {
|
||||
const mockPrimary = {
|
||||
generateContent: vi
|
||||
.fn()
|
||||
.mockRejectedValue(new MockExhaustedError('generateContent')),
|
||||
} as unknown as ContentGenerator;
|
||||
const mockFallback = {
|
||||
generateContent: vi.fn().mockResolvedValue({ text: 'fallback response' }),
|
||||
} as unknown as ContentGenerator;
|
||||
const onFallback = vi.fn();
|
||||
|
||||
const generator = new FallbackContentGenerator(
|
||||
mockPrimary,
|
||||
mockFallback,
|
||||
onFallback,
|
||||
);
|
||||
const result = await generator.generateContent(
|
||||
dummyRequest,
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
|
||||
expect(result).toEqual({ text: 'fallback response' });
|
||||
expect(mockPrimary.generateContent).toHaveBeenCalled();
|
||||
expect(onFallback).toHaveBeenCalledWith('generateContent');
|
||||
expect(mockFallback.generateContent).toHaveBeenCalledWith(
|
||||
dummyRequest,
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
});
|
||||
|
||||
it('bubbles up MockExhaustedError if the fallback generator also exhausts', async () => {
|
||||
const mockPrimary = {
|
||||
generateContent: vi
|
||||
.fn()
|
||||
.mockRejectedValue(new MockExhaustedError('generateContent')),
|
||||
} as unknown as ContentGenerator;
|
||||
const mockFallback = {
|
||||
generateContent: vi
|
||||
.fn()
|
||||
.mockRejectedValue(new MockExhaustedError('generateContent')),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new FallbackContentGenerator(mockPrimary, mockFallback);
|
||||
await expect(
|
||||
generator.generateContent(dummyRequest, 'prompt-id', LlmRole.MAIN),
|
||||
).rejects.toThrow(MockExhaustedError);
|
||||
});
|
||||
|
||||
it('handles stream delegation and fallback', async () => {
|
||||
const asyncStream = async function* () {
|
||||
yield { text: 'stream chunk' };
|
||||
};
|
||||
const mockPrimary = {
|
||||
generateContentStream: vi
|
||||
.fn()
|
||||
.mockRejectedValue(new MockExhaustedError('generateContentStream')),
|
||||
} as unknown as ContentGenerator;
|
||||
const mockFallback = {
|
||||
generateContentStream: vi.fn().mockResolvedValue(asyncStream()),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new FallbackContentGenerator(mockPrimary, mockFallback);
|
||||
const result = await generator.generateContentStream(
|
||||
dummyRequest,
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const chunks: any[] = [];
|
||||
for await (const chunk of result) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
expect(chunks).toEqual([{ text: 'stream chunk' }]);
|
||||
expect(mockFallback.generateContentStream).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('handles optional methods like countTokens that are missing on primary', async () => {
|
||||
const mockPrimary = {} as unknown as ContentGenerator;
|
||||
const mockFallback = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 42 }),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new FallbackContentGenerator(mockPrimary, mockFallback);
|
||||
const result = await generator.countTokens({
|
||||
model: 'gemini',
|
||||
contents: [],
|
||||
});
|
||||
|
||||
expect(result).toEqual({ totalTokens: 42 });
|
||||
expect(mockFallback.countTokens).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('handles optional methods like embedContent that are missing on primary', async () => {
|
||||
const mockPrimary = {} as unknown as ContentGenerator;
|
||||
const mockFallback = {
|
||||
embedContent: vi.fn().mockResolvedValue({ embedding: { values: [0.1] } }),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new FallbackContentGenerator(mockPrimary, mockFallback);
|
||||
const result = await generator.embedContent({
|
||||
model: 'gemini',
|
||||
contents: { parts: [{ text: '' }] },
|
||||
});
|
||||
|
||||
expect(result).toEqual({ embedding: { values: [0.1] } });
|
||||
expect(mockFallback.embedContent).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('proxies tier properties from the primary', () => {
|
||||
const mockPrimary = {
|
||||
userTier: 'test-tier',
|
||||
userTierName: 'Test Tier',
|
||||
paidTier: true,
|
||||
} as unknown as ContentGenerator;
|
||||
const mockFallback = {} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new FallbackContentGenerator(mockPrimary, mockFallback);
|
||||
expect(generator.userTier).toBe('test-tier');
|
||||
expect(generator.userTierName).toBe('Test Tier');
|
||||
expect(generator.paidTier).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,109 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import type {
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
CountTokensParameters,
|
||||
CountTokensResponse,
|
||||
EmbedContentParameters,
|
||||
EmbedContentResponse,
|
||||
} from '@google/genai';
|
||||
import type { LlmRole } from '../telemetry/types.js';
|
||||
import { MockExhaustedError } from './fakeContentGenerator.js';
|
||||
|
||||
/**
|
||||
* A ContentGenerator that attempts to use a primary generator,
|
||||
* and falls back to a secondary generator if the primary throws MockExhaustedError.
|
||||
*/
|
||||
export class FallbackContentGenerator implements ContentGenerator {
|
||||
get userTier() {
|
||||
return this.primary.userTier;
|
||||
}
|
||||
get userTierName() {
|
||||
return this.primary.userTierName;
|
||||
}
|
||||
get paidTier() {
|
||||
return this.primary.paidTier;
|
||||
}
|
||||
|
||||
constructor(
|
||||
private readonly primary: ContentGenerator,
|
||||
private readonly fallback: ContentGenerator,
|
||||
private readonly onFallback?: (method: string) => void,
|
||||
) {}
|
||||
|
||||
async generateContent(
|
||||
request: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
role: LlmRole,
|
||||
): Promise<GenerateContentResponse> {
|
||||
try {
|
||||
return await this.primary.generateContent(request, userPromptId, role);
|
||||
} catch (error) {
|
||||
if (error instanceof MockExhaustedError) {
|
||||
this.onFallback?.('generateContent');
|
||||
return this.fallback.generateContent(request, userPromptId, role);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async generateContentStream(
|
||||
request: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
role: LlmRole,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
try {
|
||||
return await this.primary.generateContentStream(
|
||||
request,
|
||||
userPromptId,
|
||||
role,
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof MockExhaustedError) {
|
||||
this.onFallback?.('generateContentStream');
|
||||
return this.fallback.generateContentStream(request, userPromptId, role);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async countTokens(
|
||||
request: CountTokensParameters,
|
||||
): Promise<CountTokensResponse> {
|
||||
try {
|
||||
if (!this.primary.countTokens) {
|
||||
throw new MockExhaustedError('countTokens');
|
||||
}
|
||||
return await this.primary.countTokens(request);
|
||||
} catch (error) {
|
||||
if (error instanceof MockExhaustedError && this.fallback.countTokens) {
|
||||
this.onFallback?.('countTokens');
|
||||
return this.fallback.countTokens(request);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
request: EmbedContentParameters,
|
||||
): Promise<EmbedContentResponse> {
|
||||
try {
|
||||
if (!this.primary.embedContent) {
|
||||
throw new MockExhaustedError('embedContent');
|
||||
}
|
||||
return await this.primary.embedContent(request);
|
||||
} catch (error) {
|
||||
if (error instanceof MockExhaustedError && this.fallback.embedContent) {
|
||||
this.onFallback?.('embedContent');
|
||||
return this.fallback.embedContent(request);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -168,6 +168,10 @@ export class LoggingContentGenerator implements ContentGenerator {
|
||||
return this.wrapped.paidTier;
|
||||
}
|
||||
|
||||
getSentRequests?(): GenerateContentParameters[] {
|
||||
return this.wrapped.getSentRequests?.() || [];
|
||||
}
|
||||
|
||||
private logApiRequest(
|
||||
contents: Content[],
|
||||
model: string,
|
||||
|
||||
@@ -39,6 +39,10 @@ export class RecordingContentGenerator implements ContentGenerator {
|
||||
return this.realGenerator.userTierName;
|
||||
}
|
||||
|
||||
getSentRequests?(): GenerateContentParameters[] {
|
||||
return this.realGenerator.getSentRequests?.() || [];
|
||||
}
|
||||
|
||||
async generateContent(
|
||||
request: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import {
|
||||
mockGenerateContentStreamText,
|
||||
mockGenerateContentText,
|
||||
userText,
|
||||
isFakeResponse,
|
||||
isFakeRequest,
|
||||
extractUserPrompts,
|
||||
extractFakeResponses,
|
||||
type ScriptItem,
|
||||
} from './scriptUtils.js';
|
||||
|
||||
describe('scriptUtils', () => {
|
||||
describe('mockGenerateContentStreamText', () => {
|
||||
it('creates a valid FakeResponse for generateContentStream', () => {
|
||||
const result = mockGenerateContentStreamText('hello stream');
|
||||
expect(result.method).toBe('generateContentStream');
|
||||
expect(Array.isArray(result.response)).toBe(true);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const responseArray = result.response as any[];
|
||||
expect(responseArray[0].candidates[0].content.parts[0].text).toBe(
|
||||
'hello stream',
|
||||
);
|
||||
expect(responseArray[0].candidates[0].finishReason).toBe('STOP');
|
||||
});
|
||||
});
|
||||
|
||||
describe('mockGenerateContentText', () => {
|
||||
it('creates a valid FakeResponse for generateContent', () => {
|
||||
const result = mockGenerateContentText('hello block');
|
||||
expect(result.method).toBe('generateContent');
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const responseObj = result.response as any;
|
||||
expect(responseObj.candidates[0].content.parts[0].text).toBe(
|
||||
'hello block',
|
||||
);
|
||||
expect(responseObj.candidates[0].finishReason).toBe('STOP');
|
||||
});
|
||||
});
|
||||
|
||||
describe('userText', () => {
|
||||
it('creates a valid FakeRequest', () => {
|
||||
const result = userText('user input');
|
||||
expect(result.method).toBe('userText');
|
||||
expect(result.text).toBe('user input');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Type Guards', () => {
|
||||
it('correctly identifies FakeResponse vs FakeRequest', () => {
|
||||
const fakeRes = mockGenerateContentText('test');
|
||||
const fakeReq = userText('test');
|
||||
|
||||
expect(isFakeResponse(fakeRes)).toBe(true);
|
||||
expect(isFakeResponse(fakeReq)).toBe(false);
|
||||
|
||||
expect(isFakeRequest(fakeReq)).toBe(true);
|
||||
expect(isFakeRequest(fakeRes)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractUserPrompts and extractFakeResponses', () => {
|
||||
it('correctly partitions a mixed script array', () => {
|
||||
const script: ScriptItem[] = [
|
||||
userText('prompt 1'),
|
||||
mockGenerateContentText('response 1'),
|
||||
userText('prompt 2'),
|
||||
mockGenerateContentStreamText('response 2'),
|
||||
];
|
||||
|
||||
const prompts = extractUserPrompts(script);
|
||||
expect(prompts).toEqual(['prompt 1', 'prompt 2']);
|
||||
|
||||
const responses = extractFakeResponses(script);
|
||||
expect(responses).toHaveLength(2);
|
||||
expect(responses[0].method).toBe('generateContent');
|
||||
expect(responses[1].method).toBe('generateContentStream');
|
||||
});
|
||||
|
||||
it('handles empty scripts', () => {
|
||||
expect(extractUserPrompts([])).toEqual([]);
|
||||
expect(extractFakeResponses([])).toEqual([]);
|
||||
});
|
||||
|
||||
it('handles scripts with only one type', () => {
|
||||
const justPrompts = [userText('a'), userText('b')];
|
||||
expect(extractUserPrompts(justPrompts)).toEqual(['a', 'b']);
|
||||
expect(extractFakeResponses(justPrompts)).toEqual([]);
|
||||
|
||||
const justResponses = [
|
||||
mockGenerateContentText('a'),
|
||||
mockGenerateContentText('b'),
|
||||
];
|
||||
expect(extractUserPrompts(justResponses)).toEqual([]);
|
||||
expect(extractFakeResponses(justResponses)).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,59 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { GenerateContentResponse } from '@google/genai';
|
||||
import type { FakeResponse } from './fakeContentGenerator.js';
|
||||
|
||||
export type FakeRequest = { method: 'userText'; text: string };
|
||||
export type ScriptItem = FakeResponse | FakeRequest;
|
||||
|
||||
export function mockGenerateContentStreamText(text: string): FakeResponse {
|
||||
return {
|
||||
method: 'generateContentStream',
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
response: [
|
||||
{
|
||||
candidates: [{ content: { parts: [{ text }] }, finishReason: 'STOP' }],
|
||||
},
|
||||
] as GenerateContentResponse[],
|
||||
};
|
||||
}
|
||||
|
||||
export function mockGenerateContentText(text: string): FakeResponse {
|
||||
return {
|
||||
method: 'generateContent',
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
response: {
|
||||
candidates: [{ content: { parts: [{ text }] }, finishReason: 'STOP' }],
|
||||
} as GenerateContentResponse,
|
||||
};
|
||||
}
|
||||
|
||||
export function userText(text: string): FakeRequest {
|
||||
return { method: 'userText', text };
|
||||
}
|
||||
|
||||
export function isFakeResponse(item: ScriptItem): item is FakeResponse {
|
||||
return item.method !== 'userText';
|
||||
}
|
||||
|
||||
export function isFakeRequest(item: ScriptItem): item is FakeRequest {
|
||||
return item.method === 'userText';
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts all FakeRequests from a script array and maps them to their string text.
|
||||
*/
|
||||
export function extractUserPrompts(script: ScriptItem[]): string[] {
|
||||
return script.filter(isFakeRequest).map((req) => req.text);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts all FakeResponses from a script array.
|
||||
*/
|
||||
export function extractFakeResponses(script: ScriptItem[]): FakeResponse[] {
|
||||
return script.filter(isFakeResponse);
|
||||
}
|
||||
@@ -24,6 +24,7 @@ export * from './config/extensions/integrityTypes.js';
|
||||
export * from './billing/index.js';
|
||||
export * from './confirmation-bus/types.js';
|
||||
export * from './confirmation-bus/message-bus.js';
|
||||
export * from './safety/conseca/conseca.js';
|
||||
|
||||
// Export Commands logic
|
||||
export * from './commands/extensions.js';
|
||||
@@ -36,6 +37,9 @@ export * from './commands/types.js';
|
||||
export * from './core/baseLlmClient.js';
|
||||
export * from './core/client.js';
|
||||
export * from './core/contentGenerator.js';
|
||||
export * from './core/fakeContentGenerator.js';
|
||||
export * from './core/fallbackContentGenerator.js';
|
||||
export * from './core/scriptUtils.js';
|
||||
export * from './core/loggingContentGenerator.js';
|
||||
export * from './core/geminiChat.js';
|
||||
export * from './core/logger.js';
|
||||
|
||||
Reference in New Issue
Block a user