feat(vertex): add settings for Vertex AI request routing (#25513)

This commit is contained in:
Gordon Hui
2026-04-22 01:48:30 +08:00
committed by GitHub
parent aee2cde1a3
commit 27344833cb
9 changed files with 210 additions and 0 deletions
+25
View File
@@ -836,12 +836,37 @@ describe('Server Config (config.ts)', () => {
undefined,
undefined,
undefined,
undefined,
);
// Verify that contentGeneratorConfig is updated
expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig);
expect(GeminiClient).toHaveBeenCalledWith(config);
});
it('should pass Vertex AI routing settings when refreshing auth', async () => {
const vertexAiRouting = {
requestType: 'shared' as const,
sharedRequestType: 'priority' as const,
};
const config = new Config({
...baseParams,
vertexAiRouting,
});
vi.mocked(createContentGeneratorConfig).mockResolvedValue({});
await config.refreshAuth(AuthType.USE_VERTEX_AI);
expect(createContentGeneratorConfig).toHaveBeenCalledWith(
config,
AuthType.USE_VERTEX_AI,
undefined,
undefined,
undefined,
vertexAiRouting,
);
});
it('should reset model availability status', async () => {
const config = new Config(baseParams);
const service = config.getModelAvailabilityService();
+5
View File
@@ -23,6 +23,7 @@ import {
createContentGeneratorConfig,
type ContentGenerator,
type ContentGeneratorConfig,
type VertexAiRoutingConfig,
} from '../core/contentGenerator.js';
import type { OverageStrategy } from '../billing/billing.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
@@ -731,6 +732,7 @@ export interface ConfigParameters {
billing?: {
overageStrategy?: OverageStrategy;
};
vertexAiRouting?: VertexAiRoutingConfig;
}
export class Config implements McpContext, AgentLoopContext {
@@ -936,6 +938,7 @@ export class Config implements McpContext, AgentLoopContext {
private readonly billing: {
overageStrategy: OverageStrategy;
};
private readonly vertexAiRouting: VertexAiRoutingConfig | undefined;
private readonly enableAgents: boolean;
private agents: AgentSettings;
@@ -1362,6 +1365,7 @@ export class Config implements McpContext, AgentLoopContext {
this.billing = {
overageStrategy: params.billing?.overageStrategy ?? 'ask',
};
this.vertexAiRouting = params.vertexAiRouting;
if (params.contextFileName) {
setGeminiMdFilename(params.contextFileName);
@@ -1549,6 +1553,7 @@ export class Config implements McpContext, AgentLoopContext {
apiKey,
baseUrl,
customHeaders,
this.vertexAiRouting,
);
this.contentGenerator = await createContentGenerator(
newContentGeneratorConfig,
@@ -385,6 +385,44 @@ describe('createContentGenerator', () => {
);
});
it('should include Vertex AI routing headers for Vertex AI requests', async () => {
const mockConfig = {
getModel: vi.fn().mockReturnValue('gemini-pro'),
getProxy: vi.fn().mockReturnValue(undefined),
getUsageStatisticsEnabled: () => false,
getClientName: vi.fn().mockReturnValue(undefined),
} as unknown as Config;
const mockGenerator = {
models: {},
} as unknown as GoogleGenAI;
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
await createContentGenerator(
{
apiKey: 'test-api-key',
vertexai: true,
authType: AuthType.USE_VERTEX_AI,
vertexAiRouting: {
requestType: 'shared',
sharedRequestType: 'priority',
},
},
mockConfig,
);
expect(GoogleGenAI).toHaveBeenCalledWith(
expect.objectContaining({
httpOptions: expect.objectContaining({
headers: expect.objectContaining({
'X-Vertex-AI-LLM-Request-Type': 'shared',
'X-Vertex-AI-LLM-Shared-Request-Type': 'priority',
}),
}),
}),
);
});
it('should pass api key as Authorization Header when GEMINI_API_KEY_AUTH_MECHANISM is set to bearer', async () => {
const mockConfig = {
getModel: vi.fn().mockReturnValue('gemini-pro'),
@@ -887,6 +925,25 @@ describe('createContentGeneratorConfig', () => {
expect(config.vertexai).toBe(true);
});
it('should include Vertex AI routing settings in content generator config', async () => {
vi.stubEnv('GOOGLE_API_KEY', 'env-google-key');
const vertexAiRouting = {
requestType: 'shared' as const,
sharedRequestType: 'priority' as const,
};
const config = await createContentGeneratorConfig(
mockConfig,
AuthType.USE_VERTEX_AI,
undefined,
undefined,
undefined,
vertexAiRouting,
);
expect(config.vertexAiRouting).toEqual(vertexAiRouting);
});
it('should configure for Vertex AI using GCP project and location when set', async () => {
vi.stubEnv('GOOGLE_API_KEY', undefined);
vi.stubEnv('GOOGLE_CLOUD_PROJECT', 'env-gcp-project');
@@ -99,9 +99,21 @@ export type ContentGeneratorConfig = {
proxy?: string;
baseUrl?: string;
customHeaders?: Record<string, string>;
vertexAiRouting?: VertexAiRoutingConfig;
};
export type VertexAiRequestType = 'dedicated' | 'shared';
export type VertexAiSharedRequestType = 'priority' | 'flex';
export interface VertexAiRoutingConfig {
requestType?: VertexAiRequestType;
sharedRequestType?: VertexAiSharedRequestType;
}
const LOCAL_HOSTNAMES = ['localhost', '127.0.0.1', '[::1]'];
const VERTEX_AI_REQUEST_TYPE_HEADER = 'X-Vertex-AI-LLM-Request-Type';
const VERTEX_AI_SHARED_REQUEST_TYPE_HEADER =
'X-Vertex-AI-LLM-Shared-Request-Type';
function validateBaseUrl(baseUrl: string): void {
let url: URL;
@@ -122,6 +134,7 @@ export async function createContentGeneratorConfig(
apiKey?: string,
baseUrl?: string,
customHeaders?: Record<string, string>,
vertexAiRouting?: VertexAiRoutingConfig,
): Promise<ContentGeneratorConfig> {
const geminiApiKey =
apiKey ||
@@ -140,6 +153,7 @@ export async function createContentGeneratorConfig(
proxy: config?.getProxy(),
baseUrl,
customHeaders,
vertexAiRouting,
};
// If we are using Google auth or we are in Cloud Shell, there is nothing else to validate for now
@@ -280,6 +294,21 @@ export async function createContentGenerator(
if (config.customHeaders) {
headers = { ...headers, ...config.customHeaders };
}
if (
config.authType === AuthType.USE_VERTEX_AI &&
config.vertexAiRouting
) {
const { requestType, sharedRequestType } = config.vertexAiRouting;
headers = {
...headers,
...(requestType
? { [VERTEX_AI_REQUEST_TYPE_HEADER]: requestType }
: {}),
...(sharedRequestType
? { [VERTEX_AI_SHARED_REQUEST_TYPE_HEADER]: sharedRequestType }
: {}),
};
}
if (gcConfig?.getUsageStatisticsEnabled()) {
const installationManager = new InstallationManager();
const installationId = installationManager.getInstallationId();