Vertex AI model mapping fix (#27749)

This commit is contained in:
David Pierce
2026-06-09 20:02:50 +00:00
committed by GitHub
parent 8e99c26dd8
commit f08b4af654
9 changed files with 489 additions and 14 deletions
@@ -15,6 +15,7 @@ import {
} from './codeAssist.js';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
import { ModelMappingContentGenerator } from '../core/modelMappingContentGenerator.js';
import { UserTierId } from './types.js';
// Mock dependencies
@@ -22,11 +23,15 @@ vi.mock('./oauth2.js');
vi.mock('./setup.js');
vi.mock('./server.js');
vi.mock('../core/loggingContentGenerator.js');
vi.mock('../core/modelMappingContentGenerator.js');
const mockedGetOauthClient = vi.mocked(getOauthClient);
const mockedSetupUser = vi.mocked(setupUser);
const MockedCodeAssistServer = vi.mocked(CodeAssistServer);
const MockedLoggingContentGenerator = vi.mocked(LoggingContentGenerator);
const MockedModelMappingContentGenerator = vi.mocked(
ModelMappingContentGenerator,
);
describe('codeAssist', () => {
beforeEach(() => {
@@ -178,5 +183,47 @@ describe('codeAssist', () => {
const server = getCodeAssistServer(mockConfig);
expect(server).toBeUndefined();
});
it('should unwrap and return the server if it is wrapped in a ModelMappingContentGenerator', () => {
  // Arrange: a mocked mapper whose getWrapped() yields the Code Assist server.
  const innerServer = new MockedCodeAssistServer({} as never, '', {});
  const mapper = new MockedModelMappingContentGenerator({} as never, {});
  vi.spyOn(mapper, 'getWrapped').mockReturnValue(innerServer);

  const config = { getContentGenerator: () => mapper } as unknown as Config;

  // Act: resolve the server through the single wrapper layer.
  const resolved = getCodeAssistServer(config);

  // Assert: the wrapper was peeled off and the exact server instance returned.
  expect(resolved).toBe(innerServer);
  expect(mapper.getWrapped).toHaveBeenCalled();
});
it('should recursively unwrap multiple layers of LoggingContentGenerator and ModelMappingContentGenerator', () => {
  // Arrange: build the chain from the inside out — server, then logger, then mapper.
  const innerServer = new MockedCodeAssistServer({} as never, '', {});
  const logger = new MockedLoggingContentGenerator({} as never, {} as never);
  const mapper = new MockedModelMappingContentGenerator({} as never, {});

  // Each wrapper reports the next layer down as its wrapped generator.
  vi.spyOn(logger, 'getWrapped').mockReturnValue(innerServer);
  vi.spyOn(mapper, 'getWrapped').mockReturnValue(logger);

  const config = { getContentGenerator: () => mapper } as unknown as Config;

  // Act
  const resolved = getCodeAssistServer(config);

  // Assert: both layers were traversed and the innermost server came back.
  expect(resolved).toBe(innerServer);
  expect(mapper.getWrapped).toHaveBeenCalled();
  expect(logger.getWrapped).toHaveBeenCalled();
});
});
});
+10 -3
View File
@@ -10,6 +10,7 @@ import { setupUser } from './setup.js';
import { CodeAssistServer, type HttpOptions } from './server.js';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
import { ModelMappingContentGenerator } from '../core/modelMappingContentGenerator.js';
export async function createCodeAssistContentGenerator(
httpOptions: HttpOptions,
@@ -43,9 +44,15 @@ export function getCodeAssistServer(
): CodeAssistServer | undefined {
let server = config.getContentGenerator();
// Unwrap LoggingContentGenerator if present
if (server instanceof LoggingContentGenerator) {
server = server.getWrapped();
// Recursively unwrap LoggingContentGenerator and ModelMappingContentGenerator
while (true) {
if (server instanceof LoggingContentGenerator) {
server = server.getWrapped();
} else if (server instanceof ModelMappingContentGenerator) {
server = server.getWrapped();
} else {
break;
}
}
if (!(server instanceof CodeAssistServer)) {
+3 -3
View File
@@ -4379,7 +4379,7 @@ describe('hasGemini35FlashGAAccess model setting', () => {
expect(PREVIEW_GEMINI_FLASH_MODEL).toBe('gemini-3-flash-preview');
});
it('should set DEFAULT_GEMINI_FLASH_MODEL and PREVIEW_GEMINI_FLASH_MODEL to gemini-3-flash if hasGemini35FlashGAAccess returns true and authType is not USE_GEMINI', () => {
it('should set DEFAULT_GEMINI_FLASH_MODEL and PREVIEW_GEMINI_FLASH_MODEL to gemini-3.5-flash if hasGemini35FlashGAAccess returns true and authType is not USE_GEMINI', () => {
const config = new Config(baseParams);
config['contentGeneratorConfig'] = { authType: AuthType.LOGIN_WITH_GOOGLE };
@@ -4397,7 +4397,7 @@ describe('hasGemini35FlashGAAccess model setting', () => {
const result = config.hasGemini35FlashGAAccess();
expect(result).toBe(true);
expect(DEFAULT_GEMINI_FLASH_MODEL).toBe('gemini-3-flash');
expect(PREVIEW_GEMINI_FLASH_MODEL).toBe('gemini-3-flash');
expect(DEFAULT_GEMINI_FLASH_MODEL).toBe('gemini-3.5-flash');
expect(PREVIEW_GEMINI_FLASH_MODEL).toBe('gemini-3.5-flash');
});
});
+1 -1
View File
@@ -3566,7 +3566,7 @@ export class Config implements McpContext, AgentLoopContext {
if (authType === AuthType.USE_GEMINI) {
setFlashModels('gemini-3-flash-preview', 'gemini-3.5-flash');
} else {
setFlashModels('gemini-3-flash', 'gemini-3-flash');
setFlashModels('gemini-3.5-flash', 'gemini-3.5-flash');
}
} else {
setFlashModels('gemini-3-flash-preview', 'gemini-2.5-flash');
+4
View File
@@ -574,3 +574,7 @@ export function isActiveModel(
);
}
}
/**
 * Model-id rewrite table: each key is a requested model id and each value is
 * the model id that should be sent in its place.
 *
 * `createContentGenerator` hands this table to `ModelMappingContentGenerator`
 * when it wraps the generator returned by `createCodeAssistContentGenerator`,
 * so the rewrite only affects requests routed through that path.
 */
export const CCPA_AI_MODEL_MAPPINGS: Record<string, string> = {
// Requests for the default 3.5 Flash id are redirected to the secondary id.
[DEFAULT_GEMINI_3_5_FLASH_MODEL]: SECONDARY_GEMINI_3_5_FLASH_MODEL,
};
+191 -2
View File
@@ -18,10 +18,13 @@ import { HttpProxyAgent } from 'http-proxy-agent';
import { HttpsProxyAgent } from 'https-proxy-agent';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import { ModelMappingContentGenerator } from './modelMappingContentGenerator.js';
import { CCPA_AI_MODEL_MAPPINGS } from '../config/models.js';
import { loadApiKey } from './apiKeyCredentialStorage.js';
import { FakeContentGenerator } from './fakeContentGenerator.js';
import { RecordingContentGenerator } from './recordingContentGenerator.js';
import { resetVersionCache } from '../utils/version.js';
import type { LlmRole } from '../telemetry/llmRole.js';
vi.mock('../code_assist/codeAssist.js');
vi.mock('@google/genai');
@@ -36,6 +39,14 @@ const mockConfig = {
getProxy: vi.fn().mockReturnValue(undefined),
getUsageStatisticsEnabled: vi.fn().mockReturnValue(true),
getClientName: vi.fn().mockReturnValue(undefined),
getTelemetryLogPromptsEnabled: vi.fn().mockReturnValue(true),
getTelemetryTracesEnabled: vi.fn().mockReturnValue(true),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
refreshUserQuotaIfStale: vi.fn().mockResolvedValue(undefined),
setLatestApiRequest: vi.fn(),
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
isInteractive: vi.fn().mockReturnValue(false),
getExperiments: vi.fn().mockReturnValue(undefined),
} as unknown as Config;
describe('getAuthTypeFromEnv', () => {
@@ -142,7 +153,10 @@ describe('createContentGenerator', () => {
);
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
expect(generator).toEqual(
new LoggingContentGenerator(mockGenerator, mockConfig),
new LoggingContentGenerator(
new ModelMappingContentGenerator(mockGenerator, CCPA_AI_MODEL_MAPPINGS),
mockConfig,
),
);
});
@@ -159,7 +173,10 @@ describe('createContentGenerator', () => {
);
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
expect(generator).toEqual(
new LoggingContentGenerator(mockGenerator, mockConfig),
new LoggingContentGenerator(
new ModelMappingContentGenerator(mockGenerator, CCPA_AI_MODEL_MAPPINGS),
mockConfig,
),
);
});
@@ -1095,6 +1112,178 @@ describe('createContentGenerator', () => {
}),
);
});
it('should not apply model mapping for Vertex AI', async () => {
  // Stub the GoogleGenAI client so the request that reaches `models` can be inspected.
  const mockModels = {
    generateContent: vi.fn().mockResolvedValue({}),
  };
  const mockGenerator = {
    models: mockModels,
  } as unknown as GoogleGenAI;
  vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);

  const generator = await createContentGenerator(
    {
      apiKey: 'test-api-key',
      authType: AuthType.USE_VERTEX_AI,
      vertexai: true,
    },
    mockConfig,
  );

  // Use the id that the Code Assist auth types rewrite
  // ('gemini-3.5-flash' -> 'gemini-3-flash'). Sending an id that is never a
  // mapping key would let this test pass even if mapping were wrongly applied.
  await generator.generateContent(
    {
      model: 'gemini-3.5-flash',
      contents: [],
    },
    'prompt-id',
    'user' as LlmRole,
  );

  // The model id must arrive untouched for Vertex AI.
  expect(mockModels.generateContent).toHaveBeenCalledWith(
    expect.objectContaining({
      model: 'gemini-3.5-flash',
    }),
    'prompt-id',
    'user',
  );
});
it('should not apply model mapping for Gemini API', async () => {
  // Stub the GoogleGenAI client so the request that reaches `models` can be inspected.
  const mockModels = {
    generateContent: vi.fn().mockResolvedValue({}),
  };
  const mockGenerator = {
    models: mockModels,
  } as unknown as GoogleGenAI;
  vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);

  const generator = await createContentGenerator(
    {
      apiKey: 'test-api-key',
      authType: AuthType.USE_GEMINI,
    },
    mockConfig,
  );

  // Use the id that the Code Assist auth types rewrite
  // ('gemini-3.5-flash' -> 'gemini-3-flash'). Sending an id that is never a
  // mapping key would let this test pass even if mapping were wrongly applied.
  await generator.generateContent(
    {
      model: 'gemini-3.5-flash',
      contents: [],
    },
    'prompt-id',
    'user' as LlmRole,
  );

  // The model id must arrive untouched for the Gemini API.
  expect(mockModels.generateContent).toHaveBeenCalledWith(
    expect.objectContaining({
      model: 'gemini-3.5-flash',
    }),
    'prompt-id',
    'user',
  );
});
it('should not apply model mapping for GATEWAY', async () => {
  // Arrange: the mocked GoogleGenAI constructor hands back a client whose
  // `models.generateContent` is a spy resolving to an empty response.
  const modelsStub = { generateContent: vi.fn().mockResolvedValue({}) };
  const clientStub = { models: modelsStub } as unknown as GoogleGenAI;
  vi.mocked(GoogleGenAI).mockImplementation(() => clientStub as never);

  const gatewayConfig = {
    apiKey: 'test-api-key',
    authType: AuthType.GATEWAY,
  };
  const generator = await createContentGenerator(gatewayConfig, mockConfig);

  // Act: request the model id that the Code Assist path would rewrite.
  const request = { model: 'gemini-3.5-flash', contents: [] };
  await generator.generateContent(request, 'prompt-id', 'user' as LlmRole);

  // Assert: the id reaches the client unchanged.
  const unchangedModel = expect.objectContaining({ model: 'gemini-3.5-flash' });
  expect(modelsStub.generateContent).toHaveBeenCalledWith(
    unchangedModel,
    'prompt-id',
    'user',
  );
});
it('should apply model mapping for LOGIN_WITH_GOOGLE', async () => {
  // Arrange: the Code Assist factory is mocked to resolve to a stub generator.
  const codeAssistStub = {
    generateContent: vi.fn().mockResolvedValue({}),
  } as unknown as ContentGenerator;
  vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
    codeAssistStub as never,
  );

  const loginConfig = { authType: AuthType.LOGIN_WITH_GOOGLE };
  const generator = await createContentGenerator(loginConfig, mockConfig);

  // Act: ask for the model id that is expected to be rewritten.
  const request = { model: 'gemini-3.5-flash', contents: [] };
  await generator.generateContent(request, 'prompt-id', 'user' as LlmRole);

  // Assert: the stub sees the mapped id, with the other arguments passed through.
  const mappedModel = expect.objectContaining({ model: 'gemini-3-flash' });
  expect(codeAssistStub.generateContent).toHaveBeenCalledWith(
    mappedModel,
    'prompt-id',
    'user',
  );
});
it('should apply model mapping for COMPUTE_ADC', async () => {
  // Arrange: the Code Assist factory is mocked to resolve to a stub generator.
  const codeAssistStub = {
    generateContent: vi.fn().mockResolvedValue({}),
  } as unknown as ContentGenerator;
  vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
    codeAssistStub as never,
  );

  const adcConfig = { authType: AuthType.COMPUTE_ADC };
  const generator = await createContentGenerator(adcConfig, mockConfig);

  // Act: ask for the model id that is expected to be rewritten.
  const request = { model: 'gemini-3.5-flash', contents: [] };
  await generator.generateContent(request, 'prompt-id', 'user' as LlmRole);

  // Assert: the stub sees the mapped id, with the other arguments passed through.
  const mappedModel = expect.objectContaining({ model: 'gemini-3-flash' });
  expect(codeAssistStub.generateContent).toHaveBeenCalledWith(
    mappedModel,
    'prompt-id',
    'user',
  );
});
});
describe('createContentGeneratorConfig', () => {
+10 -5
View File
@@ -30,6 +30,8 @@ import { determineSurface } from '../utils/surface.js';
import { RecordingContentGenerator } from './recordingContentGenerator.js';
import { getVersion, resolveModel } from '../../index.js';
import type { LlmRole } from '../telemetry/llmRole.js';
import { ModelMappingContentGenerator } from './modelMappingContentGenerator.js';
import { CCPA_AI_MODEL_MAPPINGS } from '../config/models.js';
/**
* Interface abstracting the core functionalities for generating content and counting tokens.
@@ -282,11 +284,14 @@ export async function createContentGenerator(
) {
const httpOptions = { headers: baseHeaders };
return new LoggingContentGenerator(
await createCodeAssistContentGenerator(
httpOptions,
config.authType,
gcConfig,
sessionId,
new ModelMappingContentGenerator(
await createCodeAssistContentGenerator(
httpOptions,
config.authType,
gcConfig,
sessionId,
),
CCPA_AI_MODEL_MAPPINGS,
),
gcConfig,
);
@@ -0,0 +1,135 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import { ModelMappingContentGenerator } from './modelMappingContentGenerator.js';
import type { ContentGenerator } from './contentGenerator.js';
import { LlmRole } from '../telemetry/llmRole.js';
import type { GenerateContentParameters } from '@google/genai';
describe('ModelMappingContentGenerator', () => {
  // Requested model id -> model id expected to reach the wrapped generator.
  const mockMappings = {
    'gemini-3.5-flash': 'gemini-3-flash',
    'gemini-pro': 'gemini-1.5-pro',
  };

  /**
   * Creates a mapper around a stub generator whose `generateContent` is a spy
   * resolving to an empty object. Returns both so tests can drive the mapper
   * and inspect what the stub received.
   */
  const buildMapper = () => {
    const inner = {
      generateContent: vi.fn().mockResolvedValue({}),
    } as unknown as ContentGenerator;
    const mapper = new ModelMappingContentGenerator(inner, mockMappings);
    return { inner, mapper };
  };

  it('delegates userTier, userTierName, and paidTier properties', () => {
    const inner = {
      userTier: 'free',
      userTierName: 'Free Tier',
      paidTier: { id: 'paid' },
    } as unknown as ContentGenerator;
    const mapper = new ModelMappingContentGenerator(inner, mockMappings);

    expect(mapper.userTier).toBe('free');
    expect(mapper.userTierName).toBe('Free Tier');
    expect(mapper.paidTier).toEqual({ id: 'paid' });
  });

  it('maps matching model without prefix', async () => {
    const { inner, mapper } = buildMapper();
    const request = { model: 'gemini-3.5-flash', contents: [] };

    await mapper.generateContent(request, 'prompt-id', LlmRole.MAIN);

    expect(inner.generateContent).toHaveBeenCalledWith(
      { model: 'gemini-3-flash', contents: [] },
      'prompt-id',
      LlmRole.MAIN,
    );
  });

  it('maps matching model with models/ prefix', async () => {
    const { inner, mapper } = buildMapper();
    const request = { model: 'models/gemini-3.5-flash', contents: [] };

    await mapper.generateContent(request, 'prompt-id', LlmRole.MAIN);

    // The `models/` prefix from the request is kept on the mapped id.
    expect(inner.generateContent).toHaveBeenCalledWith(
      { model: 'models/gemini-3-flash', contents: [] },
      'prompt-id',
      LlmRole.MAIN,
    );
  });

  it('leaves unmapped model unchanged', async () => {
    const { inner, mapper } = buildMapper();
    const request = { model: 'unknown-model', contents: [] };

    await mapper.generateContent(request, 'prompt-id', LlmRole.MAIN);

    expect(inner.generateContent).toHaveBeenCalledWith(
      { model: 'unknown-model', contents: [] },
      'prompt-id',
      LlmRole.MAIN,
    );
  });

  it('leaves model with prefix unchanged if no match after normalization', async () => {
    const { inner, mapper } = buildMapper();
    const request = { model: 'models/unknown-model', contents: [] };

    await mapper.generateContent(request, 'prompt-id', LlmRole.MAIN);

    expect(inner.generateContent).toHaveBeenCalledWith(
      { model: 'models/unknown-model', contents: [] },
      'prompt-id',
      LlmRole.MAIN,
    );
  });

  it('handles missing/undefined model property safely', async () => {
    const { inner, mapper } = buildMapper();
    // Deliberately omit `model` to exercise the guard in the mapper.
    const request = { contents: [] } as unknown as GenerateContentParameters;

    await mapper.generateContent(request, 'prompt-id', LlmRole.MAIN);

    expect(inner.generateContent).toHaveBeenCalledWith(
      { contents: [] },
      'prompt-id',
      LlmRole.MAIN,
    );
  });
});
@@ -0,0 +1,88 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
type CountTokensResponse,
type GenerateContentResponse,
type GenerateContentParameters,
type CountTokensParameters,
type EmbedContentResponse,
type EmbedContentParameters,
} from '@google/genai';
import { type ContentGenerator } from './contentGenerator.js';
import type { LlmRole } from '../telemetry/llmRole.js';
import type { UserTierId, GeminiUserTier } from '../code_assist/types.js';
import { normalizeModelId } from '../utils/modelUtils.js';
/**
 * A {@link ContentGenerator} decorator that rewrites the `model` field of
 * outgoing requests according to a fixed lookup table before delegating to the
 * wrapped generator. Everything else (tier getters, responses) is passed
 * through untouched.
 */
export class ModelMappingContentGenerator implements ContentGenerator {
  /**
   * @param wrapped Generator that receives the (possibly rewritten) requests.
   * @param mappings Table of requested model id -> replacement model id. Keys
   *   are compared against the request's model after `normalizeModelId`.
   */
  constructor(
    private readonly wrapped: ContentGenerator,
    private readonly mappings: Record<string, string>,
  ) {}

  /** Returns the generator this instance decorates. */
  getWrapped(): ContentGenerator {
    return this.wrapped;
  }

  /** User tier reported by the wrapped generator. */
  get userTier(): UserTierId | undefined {
    return this.wrapped.userTier;
  }

  /** User tier display name reported by the wrapped generator. */
  get userTierName(): string | undefined {
    return this.wrapped.userTierName;
  }

  /** Paid tier reported by the wrapped generator. */
  get paidTier(): GeminiUserTier | undefined {
    return this.wrapped.paidTier;
  }

  /**
   * Returns `req` with its `model` replaced when the normalized model id is a
   * key of the mapping table; otherwise returns `req` itself. The input object
   * is never mutated. A `models/` prefix on the requested id is carried over
   * to the replacement.
   */
  private mapModel<T extends { model?: string }>(req: T): T {
    const requested = req.model;
    if (!requested) {
      return req;
    }

    const key = normalizeModelId(requested);
    // Only honour the table's own entries. A bare `mappings[key]` lookup would
    // also resolve members inherited from Object.prototype ("constructor",
    // "toString", "__proto__", ...) and replace the model with a function or
    // object instead of a string.
    if (!Object.prototype.hasOwnProperty.call(this.mappings, key)) {
      return req;
    }

    const target = this.mappings[key];
    // An empty replacement is treated as "no mapping".
    if (!target) {
      return req;
    }

    return {
      ...req,
      model: requested.startsWith('models/') ? `models/${target}` : target,
    };
  }

  /** Delegates to the wrapped generator after applying the model mapping. */
  generateContent(
    request: GenerateContentParameters,
    userPromptId: string,
    role: LlmRole,
  ): Promise<GenerateContentResponse> {
    return this.wrapped.generateContent(
      this.mapModel(request),
      userPromptId,
      role,
    );
  }

  /** Streaming variant of {@link generateContent}; same mapping is applied. */
  generateContentStream(
    request: GenerateContentParameters,
    userPromptId: string,
    role: LlmRole,
  ): Promise<AsyncGenerator<GenerateContentResponse>> {
    return this.wrapped.generateContentStream(
      this.mapModel(request),
      userPromptId,
      role,
    );
  }

  /** Delegates token counting after applying the model mapping. */
  countTokens(request: CountTokensParameters): Promise<CountTokensResponse> {
    return this.wrapped.countTokens(this.mapModel(request));
  }

  /** Delegates embedding after applying the model mapping. */
  embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse> {
    return this.wrapped.embedContent(this.mapModel(request));
  }
}