diff --git a/packages/core/src/code_assist/codeAssist.test.ts b/packages/core/src/code_assist/codeAssist.test.ts index 1a4ba66f27..0a20be1d43 100644 --- a/packages/core/src/code_assist/codeAssist.test.ts +++ b/packages/core/src/code_assist/codeAssist.test.ts @@ -15,6 +15,7 @@ import { } from './codeAssist.js'; import type { Config } from '../config/config.js'; import { LoggingContentGenerator } from '../core/loggingContentGenerator.js'; +import { ModelMappingContentGenerator } from '../core/modelMappingContentGenerator.js'; import { UserTierId } from './types.js'; // Mock dependencies @@ -22,11 +23,15 @@ vi.mock('./oauth2.js'); vi.mock('./setup.js'); vi.mock('./server.js'); vi.mock('../core/loggingContentGenerator.js'); +vi.mock('../core/modelMappingContentGenerator.js'); const mockedGetOauthClient = vi.mocked(getOauthClient); const mockedSetupUser = vi.mocked(setupUser); const MockedCodeAssistServer = vi.mocked(CodeAssistServer); const MockedLoggingContentGenerator = vi.mocked(LoggingContentGenerator); +const MockedModelMappingContentGenerator = vi.mocked( + ModelMappingContentGenerator, +); describe('codeAssist', () => { beforeEach(() => { @@ -178,5 +183,47 @@ describe('codeAssist', () => { const server = getCodeAssistServer(mockConfig); expect(server).toBeUndefined(); }); + + it('should unwrap and return the server if it is wrapped in a ModelMappingContentGenerator', () => { + const mockServer = new MockedCodeAssistServer({} as never, '', {}); + const mockMapper = new MockedModelMappingContentGenerator( + {} as never, + {}, + ); + vi.spyOn(mockMapper, 'getWrapped').mockReturnValue(mockServer); + + const mockConfig = { + getContentGenerator: () => mockMapper, + } as unknown as Config; + + const server = getCodeAssistServer(mockConfig); + expect(server).toBe(mockServer); + expect(mockMapper.getWrapped).toHaveBeenCalled(); + }); + + it('should recursively unwrap multiple layers of LoggingContentGenerator and ModelMappingContentGenerator', () => { + const mockServer = new MockedCodeAssistServer({} as never, '', {}); + const mockLogger = new MockedLoggingContentGenerator( + {} as never, + {} as never, + ); + const mockMapper = new MockedModelMappingContentGenerator( + {} as never, + {}, + ); + + // Mapper wraps Logger wraps Server + vi.spyOn(mockMapper, 'getWrapped').mockReturnValue(mockLogger); + vi.spyOn(mockLogger, 'getWrapped').mockReturnValue(mockServer); + + const mockConfig = { + getContentGenerator: () => mockMapper, + } as unknown as Config; + + const server = getCodeAssistServer(mockConfig); + expect(server).toBe(mockServer); + expect(mockMapper.getWrapped).toHaveBeenCalled(); + expect(mockLogger.getWrapped).toHaveBeenCalled(); + }); }); }); diff --git a/packages/core/src/code_assist/codeAssist.ts b/packages/core/src/code_assist/codeAssist.ts index 4fcbea7853..b6c28c44a7 100644 --- a/packages/core/src/code_assist/codeAssist.ts +++ b/packages/core/src/code_assist/codeAssist.ts @@ -10,6 +10,7 @@ import { setupUser } from './setup.js'; import { CodeAssistServer, type HttpOptions } from './server.js'; import type { Config } from '../config/config.js'; import { LoggingContentGenerator } from '../core/loggingContentGenerator.js'; +import { ModelMappingContentGenerator } from '../core/modelMappingContentGenerator.js'; export async function createCodeAssistContentGenerator( httpOptions: HttpOptions, @@ -43,9 +44,15 @@ export function getCodeAssistServer( ): CodeAssistServer | undefined { let server = config.getContentGenerator(); - // Unwrap LoggingContentGenerator if present - if (server instanceof LoggingContentGenerator) { - server = server.getWrapped(); + // Recursively unwrap LoggingContentGenerator and ModelMappingContentGenerator + while (true) { + if (server instanceof LoggingContentGenerator) { + server = server.getWrapped(); + } else if (server instanceof ModelMappingContentGenerator) { + server = server.getWrapped(); + } else { + break; + } } if (!(server instanceof CodeAssistServer)) { diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 48b15253bc..69a1dc5c29 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -4379,7 +4379,7 @@ describe('hasGemini35FlashGAAccess model setting', () => { expect(PREVIEW_GEMINI_FLASH_MODEL).toBe('gemini-3-flash-preview'); }); - it('should set DEFAULT_GEMINI_FLASH_MODEL and PREVIEW_GEMINI_FLASH_MODEL to gemini-3-flash if hasGemini35FlashGAAccess returns true and authType is not USE_GEMINI', () => { + it('should set DEFAULT_GEMINI_FLASH_MODEL and PREVIEW_GEMINI_FLASH_MODEL to gemini-3.5-flash if hasGemini35FlashGAAccess returns true and authType is not USE_GEMINI', () => { const config = new Config(baseParams); config['contentGeneratorConfig'] = { authType: AuthType.LOGIN_WITH_GOOGLE }; @@ -4397,7 +4397,7 @@ describe('hasGemini35FlashGAAccess model setting', () => { const result = config.hasGemini35FlashGAAccess(); expect(result).toBe(true); - expect(DEFAULT_GEMINI_FLASH_MODEL).toBe('gemini-3-flash'); - expect(PREVIEW_GEMINI_FLASH_MODEL).toBe('gemini-3-flash'); + expect(DEFAULT_GEMINI_FLASH_MODEL).toBe('gemini-3.5-flash'); + expect(PREVIEW_GEMINI_FLASH_MODEL).toBe('gemini-3.5-flash'); }); }); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index f7895931d5..5196e5cd63 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -3566,7 +3566,7 @@ export class Config implements McpContext, AgentLoopContext { if (authType === AuthType.USE_GEMINI) { setFlashModels('gemini-3-flash-preview', 'gemini-3.5-flash'); } else { - setFlashModels('gemini-3-flash', 'gemini-3-flash'); + setFlashModels('gemini-3.5-flash', 'gemini-3.5-flash'); } } else { setFlashModels('gemini-3-flash-preview', 'gemini-2.5-flash'); diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index 1fd8047f02..dd6506b837 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -574,3 +574,7 @@ export function isActiveModel( ); } } + +export const CCPA_AI_MODEL_MAPPINGS: Record = { + [DEFAULT_GEMINI_3_5_FLASH_MODEL]: SECONDARY_GEMINI_3_5_FLASH_MODEL, +}; diff --git a/packages/core/src/core/contentGenerator.test.ts b/packages/core/src/core/contentGenerator.test.ts index 72e9c5b514..60cb4c563e 100644 --- a/packages/core/src/core/contentGenerator.test.ts +++ b/packages/core/src/core/contentGenerator.test.ts @@ -18,10 +18,13 @@ import { HttpProxyAgent } from 'http-proxy-agent'; import { HttpsProxyAgent } from 'https-proxy-agent'; import type { Config } from '../config/config.js'; import { LoggingContentGenerator } from './loggingContentGenerator.js'; +import { ModelMappingContentGenerator } from './modelMappingContentGenerator.js'; +import { CCPA_AI_MODEL_MAPPINGS } from '../config/models.js'; import { loadApiKey } from './apiKeyCredentialStorage.js'; import { FakeContentGenerator } from './fakeContentGenerator.js'; import { RecordingContentGenerator } from './recordingContentGenerator.js'; import { resetVersionCache } from '../utils/version.js'; +import type { LlmRole } from '../telemetry/llmRole.js'; vi.mock('../code_assist/codeAssist.js'); vi.mock('@google/genai'); @@ -36,6 +39,14 @@ const mockConfig = { getProxy: vi.fn().mockReturnValue(undefined), getUsageStatisticsEnabled: vi.fn().mockReturnValue(true), getClientName: vi.fn().mockReturnValue(undefined), + getTelemetryLogPromptsEnabled: vi.fn().mockReturnValue(true), + getTelemetryTracesEnabled: vi.fn().mockReturnValue(true), + getSessionId: vi.fn().mockReturnValue('test-session-id'), + refreshUserQuotaIfStale: vi.fn().mockResolvedValue(undefined), + setLatestApiRequest: vi.fn(), + getContentGeneratorConfig: vi.fn().mockReturnValue({}), + isInteractive: vi.fn().mockReturnValue(false), + getExperiments: vi.fn().mockReturnValue(undefined), } as unknown as Config; describe('getAuthTypeFromEnv', () => { @@ -142,7 +153,10 @@ describe('createContentGenerator', () => { ); expect(createCodeAssistContentGenerator).toHaveBeenCalled(); expect(generator).toEqual( - new LoggingContentGenerator(mockGenerator, mockConfig), + new LoggingContentGenerator( + new ModelMappingContentGenerator(mockGenerator, CCPA_AI_MODEL_MAPPINGS), + mockConfig, + ), ); }); @@ -159,7 +173,10 @@ describe('createContentGenerator', () => { ); expect(createCodeAssistContentGenerator).toHaveBeenCalled(); expect(generator).toEqual( - new LoggingContentGenerator(mockGenerator, mockConfig), + new LoggingContentGenerator( + new ModelMappingContentGenerator(mockGenerator, CCPA_AI_MODEL_MAPPINGS), + mockConfig, + ), ); }); @@ -1095,6 +1112,178 @@ describe('createContentGenerator', () => { }), ); }); + + it('should not apply model mapping for Vertex AI', async () => { + const mockModels = { + generateContent: vi.fn().mockResolvedValue({}), + }; + const mockGenerator = { + models: mockModels, + } as unknown as GoogleGenAI; + vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never); + + const generator = await createContentGenerator( + { + apiKey: 'test-api-key', + authType: AuthType.USE_VERTEX_AI, + vertexai: true, + }, + mockConfig, + ); + + await generator.generateContent( + { + model: 'gemini-3-flash', + contents: [], + }, + 'prompt-id', + 'user' as LlmRole, + ); + + expect(mockModels.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gemini-3-flash', + }), + 'prompt-id', + 'user', + ); + }); + + it('should not apply model mapping for Gemini API', async () => { + const mockModels = { + generateContent: vi.fn().mockResolvedValue({}), + }; + const mockGenerator = { + models: mockModels, + } as unknown as GoogleGenAI; + vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never); + + const generator = await createContentGenerator( + { + apiKey: 'test-api-key', + authType: AuthType.USE_GEMINI, + }, + mockConfig, + ); + + await generator.generateContent( + { + model: 'gemini-3-flash', + contents: [], + }, + 'prompt-id', + 'user' as LlmRole, + ); + + expect(mockModels.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gemini-3-flash', + }), + 'prompt-id', + 'user', + ); + }); + + it('should not apply model mapping for GATEWAY', async () => { + const mockModels = { + generateContent: vi.fn().mockResolvedValue({}), + }; + const mockGenerator = { + models: mockModels, + } as unknown as GoogleGenAI; + vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never); + + const generator = await createContentGenerator( + { + apiKey: 'test-api-key', + authType: AuthType.GATEWAY, + }, + mockConfig, + ); + + await generator.generateContent( + { + model: 'gemini-3.5-flash', + contents: [], + }, + 'prompt-id', + 'user' as LlmRole, + ); + + expect(mockModels.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gemini-3.5-flash', + }), + 'prompt-id', + 'user', + ); + }); + + it('should apply model mapping for LOGIN_WITH_GOOGLE', async () => { + const mockInnerGenerator = { + generateContent: vi.fn().mockResolvedValue({}), + } as unknown as ContentGenerator; + vi.mocked(createCodeAssistContentGenerator).mockResolvedValue( + mockInnerGenerator as never, + ); + + const generator = await createContentGenerator( + { + authType: AuthType.LOGIN_WITH_GOOGLE, + }, + mockConfig, + ); + + await generator.generateContent( + { + model: 'gemini-3.5-flash', + contents: [], + }, + 'prompt-id', + 'user' as LlmRole, + ); + + expect(mockInnerGenerator.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gemini-3-flash', + }), + 'prompt-id', + 'user', + ); + }); + + it('should apply model mapping for COMPUTE_ADC', async () => { + const mockInnerGenerator = { + generateContent: vi.fn().mockResolvedValue({}), + } as unknown as ContentGenerator; + vi.mocked(createCodeAssistContentGenerator).mockResolvedValue( + mockInnerGenerator as never, + ); + + const generator = await createContentGenerator( + { + authType: AuthType.COMPUTE_ADC, + }, + mockConfig, + ); + + await generator.generateContent( + { + model: 'gemini-3.5-flash', + contents: [], + }, + 'prompt-id', + 'user' as LlmRole, + ); + + expect(mockInnerGenerator.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gemini-3-flash', + }), + 'prompt-id', + 'user', + ); + }); }); describe('createContentGeneratorConfig', () => { diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 04493c6d73..c893860d4c 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -30,6 +30,8 @@ import { determineSurface } from '../utils/surface.js'; import { RecordingContentGenerator } from './recordingContentGenerator.js'; import { getVersion, resolveModel } from '../../index.js'; import type { LlmRole } from '../telemetry/llmRole.js'; +import { ModelMappingContentGenerator } from './modelMappingContentGenerator.js'; +import { CCPA_AI_MODEL_MAPPINGS } from '../config/models.js'; /** * Interface abstracting the core functionalities for generating content and counting tokens. @@ -282,11 +284,14 @@ export async function createContentGenerator( ) { const httpOptions = { headers: baseHeaders }; return new LoggingContentGenerator( - await createCodeAssistContentGenerator( - httpOptions, - config.authType, - gcConfig, - sessionId, + new ModelMappingContentGenerator( + await createCodeAssistContentGenerator( + httpOptions, + config.authType, + gcConfig, + sessionId, + ), + CCPA_AI_MODEL_MAPPINGS, ), gcConfig, ); diff --git a/packages/core/src/core/modelMappingContentGenerator.test.ts b/packages/core/src/core/modelMappingContentGenerator.test.ts new file mode 100644 index 0000000000..926f55fb36 --- /dev/null +++ b/packages/core/src/core/modelMappingContentGenerator.test.ts @@ -0,0 +1,135 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi } from 'vitest'; +import { ModelMappingContentGenerator } from './modelMappingContentGenerator.js'; +import type { ContentGenerator } from './contentGenerator.js'; +import { LlmRole } from '../telemetry/llmRole.js'; +import type { GenerateContentParameters } from '@google/genai'; + +describe('ModelMappingContentGenerator', () => { + const mockMappings = { + 'gemini-3.5-flash': 'gemini-3-flash', + 'gemini-pro': 'gemini-1.5-pro', + }; + + it('delegates userTier, userTierName, and paidTier properties', () => { + const mockWrapped = { + userTier: 'free', + userTierName: 'Free Tier', + paidTier: { id: 'paid' }, + } as unknown as ContentGenerator; + + const generator = new ModelMappingContentGenerator( + mockWrapped, + mockMappings, + ); + + expect(generator.userTier).toBe('free'); + expect(generator.userTierName).toBe('Free Tier'); + expect(generator.paidTier).toEqual({ id: 'paid' }); + }); + + it('maps matching model without prefix', async () => { + const mockWrapped = { + generateContent: vi.fn().mockResolvedValue({}), + } as unknown as ContentGenerator; + + const generator = new ModelMappingContentGenerator( + mockWrapped, + mockMappings, + ); + const req = { model: 'gemini-3.5-flash', contents: [] }; + + await generator.generateContent(req, 'prompt-id', LlmRole.MAIN); + + expect(mockWrapped.generateContent).toHaveBeenCalledWith( + { model: 'gemini-3-flash', contents: [] }, + 'prompt-id', + LlmRole.MAIN, + ); + }); + + it('maps matching model with models/ prefix', async () => { + const mockWrapped = { + generateContent: vi.fn().mockResolvedValue({}), + } as unknown as ContentGenerator; + + const generator = new ModelMappingContentGenerator( + mockWrapped, + mockMappings, + ); + const req = { model: 'models/gemini-3.5-flash', contents: [] }; + + await generator.generateContent(req, 'prompt-id', LlmRole.MAIN); + + expect(mockWrapped.generateContent).toHaveBeenCalledWith( + { model: 'models/gemini-3-flash', contents: [] }, + 'prompt-id', + LlmRole.MAIN, + ); + }); + + it('leaves unmapped model unchanged', async () => { + const mockWrapped = { + generateContent: vi.fn().mockResolvedValue({}), + } as unknown as ContentGenerator; + + const generator = new ModelMappingContentGenerator( + mockWrapped, + mockMappings, + ); + const req = { model: 'unknown-model', contents: [] }; + + await generator.generateContent(req, 'prompt-id', LlmRole.MAIN); + + expect(mockWrapped.generateContent).toHaveBeenCalledWith( + { model: 'unknown-model', contents: [] }, + 'prompt-id', + LlmRole.MAIN, + ); + }); + + it('leaves model with prefix unchanged if no match after normalization', async () => { + const mockWrapped = { + generateContent: vi.fn().mockResolvedValue({}), + } as unknown as ContentGenerator; + + const generator = new ModelMappingContentGenerator( + mockWrapped, + mockMappings, + ); + const req = { model: 'models/unknown-model', contents: [] }; + + await generator.generateContent(req, 'prompt-id', LlmRole.MAIN); + + expect(mockWrapped.generateContent).toHaveBeenCalledWith( + { model: 'models/unknown-model', contents: [] }, + 'prompt-id', + LlmRole.MAIN, + ); + }); + + it('handles missing/undefined model property safely', async () => { + const mockWrapped = { + generateContent: vi.fn().mockResolvedValue({}), + } as unknown as ContentGenerator; + + const generator = new ModelMappingContentGenerator( + mockWrapped, + mockMappings, + ); + const req = { contents: [] } as unknown as GenerateContentParameters; + + await generator.generateContent(req, 'prompt-id', LlmRole.MAIN); + + expect(mockWrapped.generateContent).toHaveBeenCalledWith( + { contents: [] }, + 'prompt-id', + LlmRole.MAIN, + ); + }); +}); diff --git a/packages/core/src/core/modelMappingContentGenerator.ts b/packages/core/src/core/modelMappingContentGenerator.ts new file mode 100644 index 0000000000..ef07f614ae --- /dev/null +++ b/packages/core/src/core/modelMappingContentGenerator.ts @@ -0,0 +1,88 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + type CountTokensResponse, + type GenerateContentResponse, + type GenerateContentParameters, + type CountTokensParameters, + type EmbedContentResponse, + type EmbedContentParameters, +} from '@google/genai'; +import { type ContentGenerator } from './contentGenerator.js'; +import type { LlmRole } from '../telemetry/llmRole.js'; +import type { UserTierId, GeminiUserTier } from '../code_assist/types.js'; +import { normalizeModelId } from '../utils/modelUtils.js'; + +export class ModelMappingContentGenerator implements ContentGenerator { + constructor( + private readonly wrapped: ContentGenerator, + private readonly mappings: Record, + ) {} + + getWrapped(): ContentGenerator { + return this.wrapped; + } + + get userTier(): UserTierId | undefined { + return this.wrapped.userTier; + } + + get userTierName(): string | undefined { + return this.wrapped.userTierName; + } + + get paidTier(): GeminiUserTier | undefined { + return this.wrapped.paidTier; + } + + private mapModel(req: T): T { + if (req.model) { + const normalizedModel = normalizeModelId(req.model); + if (this.mappings[normalizedModel]) { + return { + ...req, + model: req.model.startsWith('models/') + ? `models/${this.mappings[normalizedModel]}` + : this.mappings[normalizedModel], + }; + } + } + return req; + } + + generateContent( + request: GenerateContentParameters, + userPromptId: string, + role: LlmRole, + ): Promise { + return this.wrapped.generateContent( + this.mapModel(request), + userPromptId, + role, + ); + } + + generateContentStream( + request: GenerateContentParameters, + userPromptId: string, + role: LlmRole, + ): Promise> { + return this.wrapped.generateContentStream( + this.mapModel(request), + userPromptId, + role, + ); + } + + countTokens(request: CountTokensParameters): Promise { + return this.wrapped.countTokens(this.mapModel(request)); + } + + embedContent(request: EmbedContentParameters): Promise { + return this.wrapped.embedContent(this.mapModel(request)); + } +}