mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-13 21:07:00 -07:00
Vertex ai model mapping fix (#27749)
This commit is contained in:
@@ -15,6 +15,7 @@ import {
|
||||
} from './codeAssist.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
|
||||
import { ModelMappingContentGenerator } from '../core/modelMappingContentGenerator.js';
|
||||
import { UserTierId } from './types.js';
|
||||
|
||||
// Mock dependencies
|
||||
@@ -22,11 +23,15 @@ vi.mock('./oauth2.js');
|
||||
vi.mock('./setup.js');
|
||||
vi.mock('./server.js');
|
||||
vi.mock('../core/loggingContentGenerator.js');
|
||||
vi.mock('../core/modelMappingContentGenerator.js');
|
||||
|
||||
const mockedGetOauthClient = vi.mocked(getOauthClient);
|
||||
const mockedSetupUser = vi.mocked(setupUser);
|
||||
const MockedCodeAssistServer = vi.mocked(CodeAssistServer);
|
||||
const MockedLoggingContentGenerator = vi.mocked(LoggingContentGenerator);
|
||||
const MockedModelMappingContentGenerator = vi.mocked(
|
||||
ModelMappingContentGenerator,
|
||||
);
|
||||
|
||||
describe('codeAssist', () => {
|
||||
beforeEach(() => {
|
||||
@@ -178,5 +183,47 @@ describe('codeAssist', () => {
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should unwrap and return the server if it is wrapped in a ModelMappingContentGenerator', () => {
|
||||
const mockServer = new MockedCodeAssistServer({} as never, '', {});
|
||||
const mockMapper = new MockedModelMappingContentGenerator(
|
||||
{} as never,
|
||||
{},
|
||||
);
|
||||
vi.spyOn(mockMapper, 'getWrapped').mockReturnValue(mockServer);
|
||||
|
||||
const mockConfig = {
|
||||
getContentGenerator: () => mockMapper,
|
||||
} as unknown as Config;
|
||||
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBe(mockServer);
|
||||
expect(mockMapper.getWrapped).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should recursively unwrap multiple layers of LoggingContentGenerator and ModelMappingContentGenerator', () => {
|
||||
const mockServer = new MockedCodeAssistServer({} as never, '', {});
|
||||
const mockLogger = new MockedLoggingContentGenerator(
|
||||
{} as never,
|
||||
{} as never,
|
||||
);
|
||||
const mockMapper = new MockedModelMappingContentGenerator(
|
||||
{} as never,
|
||||
{},
|
||||
);
|
||||
|
||||
// Mapper wraps Logger wraps Server
|
||||
vi.spyOn(mockMapper, 'getWrapped').mockReturnValue(mockLogger);
|
||||
vi.spyOn(mockLogger, 'getWrapped').mockReturnValue(mockServer);
|
||||
|
||||
const mockConfig = {
|
||||
getContentGenerator: () => mockMapper,
|
||||
} as unknown as Config;
|
||||
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBe(mockServer);
|
||||
expect(mockMapper.getWrapped).toHaveBeenCalled();
|
||||
expect(mockLogger.getWrapped).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -10,6 +10,7 @@ import { setupUser } from './setup.js';
|
||||
import { CodeAssistServer, type HttpOptions } from './server.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
|
||||
import { ModelMappingContentGenerator } from '../core/modelMappingContentGenerator.js';
|
||||
|
||||
export async function createCodeAssistContentGenerator(
|
||||
httpOptions: HttpOptions,
|
||||
@@ -43,9 +44,15 @@ export function getCodeAssistServer(
|
||||
): CodeAssistServer | undefined {
|
||||
let server = config.getContentGenerator();
|
||||
|
||||
// Unwrap LoggingContentGenerator if present
|
||||
if (server instanceof LoggingContentGenerator) {
|
||||
server = server.getWrapped();
|
||||
// Recursively unwrap LoggingContentGenerator and ModelMappingContentGenerator
|
||||
while (true) {
|
||||
if (server instanceof LoggingContentGenerator) {
|
||||
server = server.getWrapped();
|
||||
} else if (server instanceof ModelMappingContentGenerator) {
|
||||
server = server.getWrapped();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!(server instanceof CodeAssistServer)) {
|
||||
|
||||
@@ -4379,7 +4379,7 @@ describe('hasGemini35FlashGAAccess model setting', () => {
|
||||
expect(PREVIEW_GEMINI_FLASH_MODEL).toBe('gemini-3-flash-preview');
|
||||
});
|
||||
|
||||
it('should set DEFAULT_GEMINI_FLASH_MODEL and PREVIEW_GEMINI_FLASH_MODEL to gemini-3-flash if hasGemini35FlashGAAccess returns true and authType is not USE_GEMINI', () => {
|
||||
it('should set DEFAULT_GEMINI_FLASH_MODEL and PREVIEW_GEMINI_FLASH_MODEL to gemini-3.5-flash if hasGemini35FlashGAAccess returns true and authType is not USE_GEMINI', () => {
|
||||
const config = new Config(baseParams);
|
||||
config['contentGeneratorConfig'] = { authType: AuthType.LOGIN_WITH_GOOGLE };
|
||||
|
||||
@@ -4397,7 +4397,7 @@ describe('hasGemini35FlashGAAccess model setting', () => {
|
||||
const result = config.hasGemini35FlashGAAccess();
|
||||
expect(result).toBe(true);
|
||||
|
||||
expect(DEFAULT_GEMINI_FLASH_MODEL).toBe('gemini-3-flash');
|
||||
expect(PREVIEW_GEMINI_FLASH_MODEL).toBe('gemini-3-flash');
|
||||
expect(DEFAULT_GEMINI_FLASH_MODEL).toBe('gemini-3.5-flash');
|
||||
expect(PREVIEW_GEMINI_FLASH_MODEL).toBe('gemini-3.5-flash');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3566,7 +3566,7 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
if (authType === AuthType.USE_GEMINI) {
|
||||
setFlashModels('gemini-3-flash-preview', 'gemini-3.5-flash');
|
||||
} else {
|
||||
setFlashModels('gemini-3-flash', 'gemini-3-flash');
|
||||
setFlashModels('gemini-3.5-flash', 'gemini-3.5-flash');
|
||||
}
|
||||
} else {
|
||||
setFlashModels('gemini-3-flash-preview', 'gemini-2.5-flash');
|
||||
|
||||
@@ -574,3 +574,7 @@ export function isActiveModel(
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export const CCPA_AI_MODEL_MAPPINGS: Record<string, string> = {
|
||||
[DEFAULT_GEMINI_3_5_FLASH_MODEL]: SECONDARY_GEMINI_3_5_FLASH_MODEL,
|
||||
};
|
||||
|
||||
@@ -18,10 +18,13 @@ import { HttpProxyAgent } from 'http-proxy-agent';
|
||||
import { HttpsProxyAgent } from 'https-proxy-agent';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { LoggingContentGenerator } from './loggingContentGenerator.js';
|
||||
import { ModelMappingContentGenerator } from './modelMappingContentGenerator.js';
|
||||
import { CCPA_AI_MODEL_MAPPINGS } from '../config/models.js';
|
||||
import { loadApiKey } from './apiKeyCredentialStorage.js';
|
||||
import { FakeContentGenerator } from './fakeContentGenerator.js';
|
||||
import { RecordingContentGenerator } from './recordingContentGenerator.js';
|
||||
import { resetVersionCache } from '../utils/version.js';
|
||||
import type { LlmRole } from '../telemetry/llmRole.js';
|
||||
|
||||
vi.mock('../code_assist/codeAssist.js');
|
||||
vi.mock('@google/genai');
|
||||
@@ -36,6 +39,14 @@ const mockConfig = {
|
||||
getProxy: vi.fn().mockReturnValue(undefined),
|
||||
getUsageStatisticsEnabled: vi.fn().mockReturnValue(true),
|
||||
getClientName: vi.fn().mockReturnValue(undefined),
|
||||
getTelemetryLogPromptsEnabled: vi.fn().mockReturnValue(true),
|
||||
getTelemetryTracesEnabled: vi.fn().mockReturnValue(true),
|
||||
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||
refreshUserQuotaIfStale: vi.fn().mockResolvedValue(undefined),
|
||||
setLatestApiRequest: vi.fn(),
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
getExperiments: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as Config;
|
||||
|
||||
describe('getAuthTypeFromEnv', () => {
|
||||
@@ -142,7 +153,10 @@ describe('createContentGenerator', () => {
|
||||
);
|
||||
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
|
||||
expect(generator).toEqual(
|
||||
new LoggingContentGenerator(mockGenerator, mockConfig),
|
||||
new LoggingContentGenerator(
|
||||
new ModelMappingContentGenerator(mockGenerator, CCPA_AI_MODEL_MAPPINGS),
|
||||
mockConfig,
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -159,7 +173,10 @@ describe('createContentGenerator', () => {
|
||||
);
|
||||
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
|
||||
expect(generator).toEqual(
|
||||
new LoggingContentGenerator(mockGenerator, mockConfig),
|
||||
new LoggingContentGenerator(
|
||||
new ModelMappingContentGenerator(mockGenerator, CCPA_AI_MODEL_MAPPINGS),
|
||||
mockConfig,
|
||||
),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1095,6 +1112,178 @@ describe('createContentGenerator', () => {
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not apply model mapping for Vertex AI', async () => {
|
||||
const mockModels = {
|
||||
generateContent: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
const mockGenerator = {
|
||||
models: mockModels,
|
||||
} as unknown as GoogleGenAI;
|
||||
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
|
||||
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
apiKey: 'test-api-key',
|
||||
authType: AuthType.USE_VERTEX_AI,
|
||||
vertexai: true,
|
||||
},
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
await generator.generateContent(
|
||||
{
|
||||
model: 'gemini-3-flash',
|
||||
contents: [],
|
||||
},
|
||||
'prompt-id',
|
||||
'user' as LlmRole,
|
||||
);
|
||||
|
||||
expect(mockModels.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gemini-3-flash',
|
||||
}),
|
||||
'prompt-id',
|
||||
'user',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not apply model mapping for Gemini API', async () => {
|
||||
const mockModels = {
|
||||
generateContent: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
const mockGenerator = {
|
||||
models: mockModels,
|
||||
} as unknown as GoogleGenAI;
|
||||
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
|
||||
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
apiKey: 'test-api-key',
|
||||
authType: AuthType.USE_GEMINI,
|
||||
},
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
await generator.generateContent(
|
||||
{
|
||||
model: 'gemini-3-flash',
|
||||
contents: [],
|
||||
},
|
||||
'prompt-id',
|
||||
'user' as LlmRole,
|
||||
);
|
||||
|
||||
expect(mockModels.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gemini-3-flash',
|
||||
}),
|
||||
'prompt-id',
|
||||
'user',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not apply model mapping for GATEWAY', async () => {
|
||||
const mockModels = {
|
||||
generateContent: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
const mockGenerator = {
|
||||
models: mockModels,
|
||||
} as unknown as GoogleGenAI;
|
||||
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
|
||||
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
apiKey: 'test-api-key',
|
||||
authType: AuthType.GATEWAY,
|
||||
},
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
await generator.generateContent(
|
||||
{
|
||||
model: 'gemini-3.5-flash',
|
||||
contents: [],
|
||||
},
|
||||
'prompt-id',
|
||||
'user' as LlmRole,
|
||||
);
|
||||
|
||||
expect(mockModels.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gemini-3.5-flash',
|
||||
}),
|
||||
'prompt-id',
|
||||
'user',
|
||||
);
|
||||
});
|
||||
|
||||
it('should apply model mapping for LOGIN_WITH_GOOGLE', async () => {
|
||||
const mockInnerGenerator = {
|
||||
generateContent: vi.fn().mockResolvedValue({}),
|
||||
} as unknown as ContentGenerator;
|
||||
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
|
||||
mockInnerGenerator as never,
|
||||
);
|
||||
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||
},
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
await generator.generateContent(
|
||||
{
|
||||
model: 'gemini-3.5-flash',
|
||||
contents: [],
|
||||
},
|
||||
'prompt-id',
|
||||
'user' as LlmRole,
|
||||
);
|
||||
|
||||
expect(mockInnerGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gemini-3-flash',
|
||||
}),
|
||||
'prompt-id',
|
||||
'user',
|
||||
);
|
||||
});
|
||||
|
||||
it('should apply model mapping for COMPUTE_ADC', async () => {
|
||||
const mockInnerGenerator = {
|
||||
generateContent: vi.fn().mockResolvedValue({}),
|
||||
} as unknown as ContentGenerator;
|
||||
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
|
||||
mockInnerGenerator as never,
|
||||
);
|
||||
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
authType: AuthType.COMPUTE_ADC,
|
||||
},
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
await generator.generateContent(
|
||||
{
|
||||
model: 'gemini-3.5-flash',
|
||||
contents: [],
|
||||
},
|
||||
'prompt-id',
|
||||
'user' as LlmRole,
|
||||
);
|
||||
|
||||
expect(mockInnerGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gemini-3-flash',
|
||||
}),
|
||||
'prompt-id',
|
||||
'user',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createContentGeneratorConfig', () => {
|
||||
|
||||
@@ -30,6 +30,8 @@ import { determineSurface } from '../utils/surface.js';
|
||||
import { RecordingContentGenerator } from './recordingContentGenerator.js';
|
||||
import { getVersion, resolveModel } from '../../index.js';
|
||||
import type { LlmRole } from '../telemetry/llmRole.js';
|
||||
import { ModelMappingContentGenerator } from './modelMappingContentGenerator.js';
|
||||
import { CCPA_AI_MODEL_MAPPINGS } from '../config/models.js';
|
||||
|
||||
/**
|
||||
* Interface abstracting the core functionalities for generating content and counting tokens.
|
||||
@@ -282,11 +284,14 @@ export async function createContentGenerator(
|
||||
) {
|
||||
const httpOptions = { headers: baseHeaders };
|
||||
return new LoggingContentGenerator(
|
||||
await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
config.authType,
|
||||
gcConfig,
|
||||
sessionId,
|
||||
new ModelMappingContentGenerator(
|
||||
await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
config.authType,
|
||||
gcConfig,
|
||||
sessionId,
|
||||
),
|
||||
CCPA_AI_MODEL_MAPPINGS,
|
||||
),
|
||||
gcConfig,
|
||||
);
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { ModelMappingContentGenerator } from './modelMappingContentGenerator.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { LlmRole } from '../telemetry/llmRole.js';
|
||||
import type { GenerateContentParameters } from '@google/genai';
|
||||
|
||||
describe('ModelMappingContentGenerator', () => {
|
||||
const mockMappings = {
|
||||
'gemini-3.5-flash': 'gemini-3-flash',
|
||||
'gemini-pro': 'gemini-1.5-pro',
|
||||
};
|
||||
|
||||
it('delegates userTier, userTierName, and paidTier properties', () => {
|
||||
const mockWrapped = {
|
||||
userTier: 'free',
|
||||
userTierName: 'Free Tier',
|
||||
paidTier: { id: 'paid' },
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new ModelMappingContentGenerator(
|
||||
mockWrapped,
|
||||
mockMappings,
|
||||
);
|
||||
|
||||
expect(generator.userTier).toBe('free');
|
||||
expect(generator.userTierName).toBe('Free Tier');
|
||||
expect(generator.paidTier).toEqual({ id: 'paid' });
|
||||
});
|
||||
|
||||
it('maps matching model without prefix', async () => {
|
||||
const mockWrapped = {
|
||||
generateContent: vi.fn().mockResolvedValue({}),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new ModelMappingContentGenerator(
|
||||
mockWrapped,
|
||||
mockMappings,
|
||||
);
|
||||
const req = { model: 'gemini-3.5-flash', contents: [] };
|
||||
|
||||
await generator.generateContent(req, 'prompt-id', LlmRole.MAIN);
|
||||
|
||||
expect(mockWrapped.generateContent).toHaveBeenCalledWith(
|
||||
{ model: 'gemini-3-flash', contents: [] },
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
});
|
||||
|
||||
it('maps matching model with models/ prefix', async () => {
|
||||
const mockWrapped = {
|
||||
generateContent: vi.fn().mockResolvedValue({}),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new ModelMappingContentGenerator(
|
||||
mockWrapped,
|
||||
mockMappings,
|
||||
);
|
||||
const req = { model: 'models/gemini-3.5-flash', contents: [] };
|
||||
|
||||
await generator.generateContent(req, 'prompt-id', LlmRole.MAIN);
|
||||
|
||||
expect(mockWrapped.generateContent).toHaveBeenCalledWith(
|
||||
{ model: 'models/gemini-3-flash', contents: [] },
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
});
|
||||
|
||||
it('leaves unmapped model unchanged', async () => {
|
||||
const mockWrapped = {
|
||||
generateContent: vi.fn().mockResolvedValue({}),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new ModelMappingContentGenerator(
|
||||
mockWrapped,
|
||||
mockMappings,
|
||||
);
|
||||
const req = { model: 'unknown-model', contents: [] };
|
||||
|
||||
await generator.generateContent(req, 'prompt-id', LlmRole.MAIN);
|
||||
|
||||
expect(mockWrapped.generateContent).toHaveBeenCalledWith(
|
||||
{ model: 'unknown-model', contents: [] },
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
});
|
||||
|
||||
it('leaves model with prefix unchanged if no match after normalization', async () => {
|
||||
const mockWrapped = {
|
||||
generateContent: vi.fn().mockResolvedValue({}),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new ModelMappingContentGenerator(
|
||||
mockWrapped,
|
||||
mockMappings,
|
||||
);
|
||||
const req = { model: 'models/unknown-model', contents: [] };
|
||||
|
||||
await generator.generateContent(req, 'prompt-id', LlmRole.MAIN);
|
||||
|
||||
expect(mockWrapped.generateContent).toHaveBeenCalledWith(
|
||||
{ model: 'models/unknown-model', contents: [] },
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
});
|
||||
|
||||
it('handles missing/undefined model property safely', async () => {
|
||||
const mockWrapped = {
|
||||
generateContent: vi.fn().mockResolvedValue({}),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const generator = new ModelMappingContentGenerator(
|
||||
mockWrapped,
|
||||
mockMappings,
|
||||
);
|
||||
const req = { contents: [] } as unknown as GenerateContentParameters;
|
||||
|
||||
await generator.generateContent(req, 'prompt-id', LlmRole.MAIN);
|
||||
|
||||
expect(mockWrapped.generateContent).toHaveBeenCalledWith(
|
||||
{ contents: [] },
|
||||
'prompt-id',
|
||||
LlmRole.MAIN,
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,88 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
type CountTokensResponse,
|
||||
type GenerateContentResponse,
|
||||
type GenerateContentParameters,
|
||||
type CountTokensParameters,
|
||||
type EmbedContentResponse,
|
||||
type EmbedContentParameters,
|
||||
} from '@google/genai';
|
||||
import { type ContentGenerator } from './contentGenerator.js';
|
||||
import type { LlmRole } from '../telemetry/llmRole.js';
|
||||
import type { UserTierId, GeminiUserTier } from '../code_assist/types.js';
|
||||
import { normalizeModelId } from '../utils/modelUtils.js';
|
||||
|
||||
export class ModelMappingContentGenerator implements ContentGenerator {
|
||||
constructor(
|
||||
private readonly wrapped: ContentGenerator,
|
||||
private readonly mappings: Record<string, string>,
|
||||
) {}
|
||||
|
||||
getWrapped(): ContentGenerator {
|
||||
return this.wrapped;
|
||||
}
|
||||
|
||||
get userTier(): UserTierId | undefined {
|
||||
return this.wrapped.userTier;
|
||||
}
|
||||
|
||||
get userTierName(): string | undefined {
|
||||
return this.wrapped.userTierName;
|
||||
}
|
||||
|
||||
get paidTier(): GeminiUserTier | undefined {
|
||||
return this.wrapped.paidTier;
|
||||
}
|
||||
|
||||
private mapModel<T extends { model?: string }>(req: T): T {
|
||||
if (req.model) {
|
||||
const normalizedModel = normalizeModelId(req.model);
|
||||
if (this.mappings[normalizedModel]) {
|
||||
return {
|
||||
...req,
|
||||
model: req.model.startsWith('models/')
|
||||
? `models/${this.mappings[normalizedModel]}`
|
||||
: this.mappings[normalizedModel],
|
||||
};
|
||||
}
|
||||
}
|
||||
return req;
|
||||
}
|
||||
|
||||
generateContent(
|
||||
request: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
role: LlmRole,
|
||||
): Promise<GenerateContentResponse> {
|
||||
return this.wrapped.generateContent(
|
||||
this.mapModel(request),
|
||||
userPromptId,
|
||||
role,
|
||||
);
|
||||
}
|
||||
|
||||
generateContentStream(
|
||||
request: GenerateContentParameters,
|
||||
userPromptId: string,
|
||||
role: LlmRole,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
return this.wrapped.generateContentStream(
|
||||
this.mapModel(request),
|
||||
userPromptId,
|
||||
role,
|
||||
);
|
||||
}
|
||||
|
||||
countTokens(request: CountTokensParameters): Promise<CountTokensResponse> {
|
||||
return this.wrapped.countTokens(this.mapModel(request));
|
||||
}
|
||||
|
||||
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse> {
|
||||
return this.wrapped.embedContent(this.mapModel(request));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user