diff --git a/packages/cli/src/zed-integration/zedIntegration.test.ts b/packages/cli/src/zed-integration/zedIntegration.test.ts index 45e115ca9e..a86db9b272 100644 --- a/packages/cli/src/zed-integration/zedIntegration.test.ts +++ b/packages/cli/src/zed-integration/zedIntegration.test.ts @@ -207,7 +207,7 @@ describe('GeminiAgent', () => { }); expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION); - expect(response.authMethods).toHaveLength(3); + expect(response.authMethods).toHaveLength(4); const geminiAuth = response.authMethods?.find( (m) => m.id === AuthType.USE_GEMINI, ); @@ -227,6 +227,8 @@ describe('GeminiAgent', () => { expect(mockConfig.refreshAuth).toHaveBeenCalledWith( AuthType.LOGIN_WITH_GOOGLE, undefined, + undefined, + undefined, ); expect(mockSettings.setValue).toHaveBeenCalledWith( SettingScope.User, @@ -246,6 +248,8 @@ describe('GeminiAgent', () => { expect(mockConfig.refreshAuth).toHaveBeenCalledWith( AuthType.USE_GEMINI, 'test-api-key', + undefined, + undefined, ); expect(mockSettings.setValue).toHaveBeenCalledWith( SettingScope.User, @@ -267,6 +271,37 @@ describe('GeminiAgent', () => { AuthType.USE_GEMINI, 'test-api-key', 'https://custom.api.endpoint', + undefined, + ); + expect(mockSettings.setValue).toHaveBeenCalledWith( + SettingScope.User, + 'security.auth.selectedType', + AuthType.USE_GEMINI, + ); + }); + + it('should authenticate correctly with gateway method', async () => { + await agent.authenticate({ + methodId: 'gateway', + _meta: { + gateway: { + baseUrl: 'https://gateway.example.com', + headers: { + Authorization: 'Bearer test-token', + 'X-Custom-Header': 'custom-value', + }, + }, + }, + } as unknown as acp.AuthenticateRequest); + + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.USE_GEMINI, + undefined, + 'https://gateway.example.com', + { + Authorization: 'Bearer test-token', + 'X-Custom-Header': 'custom-value', + }, ); expect(mockSettings.setValue).toHaveBeenCalledWith( SettingScope.User, diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 8cf630dc3b..2ccc8a57c6 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -117,14 +117,14 @@ export class GeminiAgent { }, }, { - id: 'gemini-custom-url', - name: 'Gemini API (Custom URL)', - description: 'Use an API key and custom base URL', + id: 'gateway', + name: 'AI API Gateway', + description: 'Use a custom AI API Gateway', _meta: { - 'api-key': { - provider: 'google', + gateway: { + protocol: 'google', + restartRequired: 'false', }, - 'base-url': {}, }, }, { @@ -177,8 +177,28 @@ export class GeminiAgent { const meta = hasMeta(req) ? req._meta : undefined; const apiKey = typeof meta?.['api-key'] === 'string' ? meta['api-key'] : undefined; - const baseUrl = + let baseUrl = typeof meta?.['base-url'] === 'string' ? meta['base-url'] : undefined; + let headers: Record | undefined; + + if (methodId === 'gateway') { + method = AuthType.USE_GEMINI; + // Gateway specific handling + const gatewaySchema = z + .object({ + baseUrl: z.string().optional(), + headers: z.record(z.string()).optional(), + }) + .optional(); + const gatewayParams = gatewaySchema.parse(meta?.['gateway']); + + if (gatewayParams?.baseUrl) { + baseUrl = gatewayParams.baseUrl; + } + if (gatewayParams?.headers) { + headers = gatewayParams.headers; + } + } // Refresh auth with the requested method // This will reuse existing credentials if they're valid, @@ -194,6 +214,7 @@ export class GeminiAgent { method, apiKey ?? this.apiKey, baseUrl ?? this.baseUrl, + headers, ); } catch (e) { throw new acp.RequestError(-32000, getAcpErrorMessage(e)); diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index e92f464fa2..c08d88ae20 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -500,6 +500,8 @@ describe('Server Config (config.ts)', () => { config, authType, undefined, + undefined, + undefined, ); // Verify that contentGeneratorConfig is updated expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 7751daa269..17fd97678e 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -1126,7 +1126,12 @@ export class Config { return this.contentGenerator; } - async refreshAuth(authMethod: AuthType, apiKey?: string, baseUrl?: string) { + async refreshAuth( + authMethod: AuthType, + apiKey?: string, + baseUrl?: string, + headers?: Record, + ) { // Reset availability service when switching auth this.modelAvailabilityService.reset(); @@ -1154,6 +1159,7 @@ export class Config { authMethod, apiKey, baseUrl, + headers, ); this.contentGenerator = await createContentGenerator( newContentGeneratorConfig, diff --git a/packages/core/src/core/contentGenerator.test.ts b/packages/core/src/core/contentGenerator.test.ts index 9b7c3ac802..2b19781e49 100644 --- a/packages/core/src/core/contentGenerator.test.ts +++ b/packages/core/src/core/contentGenerator.test.ts @@ -174,7 +174,6 @@ describe('createContentGenerator', () => { headers: expect.objectContaining({ 'User-Agent': expect.any(String), 'X-Test-Header': 'test-value', - 'Another-Header': 'another value', }), }, AuthType.LOGIN_WITH_GOOGLE, @@ -475,6 +474,64 @@ describe('createContentGenerator', () => { apiVersion: 'v1alpha', }); }); + + it('should include config.headers for Gemini auth but not for others', async () => { + const mockConfig = { + getModel: vi.fn().mockReturnValue('gemini-pro'), + getProxy: vi.fn().mockReturnValue(undefined), + getUsageStatisticsEnabled: () => false, + } as unknown as Config; + + const mockGenerator = { + models: {}, + } as unknown as GoogleGenAI; + vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never); + + // Test USE_GEMINI gets headers + await createContentGenerator( + { + apiKey: 'test-api-key', + authType: AuthType.USE_GEMINI, + headers: { 'X-Custom-Config-Header': 'custom-value' }, + }, + mockConfig, + ); + + expect(GoogleGenAI).toHaveBeenCalledWith( + expect.objectContaining({ + httpOptions: expect.objectContaining({ + headers: expect.objectContaining({ + 'X-Custom-Config-Header': 'custom-value', + }), + }), + }), + ); + + // Test LOGIN_WITH_GOOGLE does NOT get headers (unless they were in baseHeaders, which they are not) + const mockCodeAssistGenerator = {} as unknown as ContentGenerator; + vi.mocked(createCodeAssistContentGenerator).mockResolvedValue( + mockCodeAssistGenerator as never, + ); + + await createContentGenerator( + { + authType: AuthType.LOGIN_WITH_GOOGLE, + headers: { 'X-Should-Not-Be-Here': 'nope' }, + }, + mockConfig, + ); + + expect(createCodeAssistContentGenerator).toHaveBeenCalledWith( + expect.objectContaining({ + headers: expect.not.objectContaining({ + 'X-Should-Not-Be-Here': 'nope', + }), + }), + AuthType.LOGIN_WITH_GOOGLE, + mockConfig, + undefined, + ); + }); }); describe('createContentGeneratorConfig', () => { diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index b96eb4f8fe..2ea6096c9e 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -86,6 +86,7 @@ export type ContentGeneratorConfig = { vertexai?: boolean; authType?: AuthType; proxy?: string; + headers?: Record; }; export async function createContentGeneratorConfig( @@ -93,6 +94,7 @@ export async function createContentGeneratorConfig( authType: AuthType | undefined, apiKey?: string, baseUrl?: string, + headers?: Record, ): Promise { const geminiApiKey = apiKey || @@ -109,6 +111,7 @@ export async function createContentGeneratorConfig( const contentGeneratorConfig: ContentGeneratorConfig = { authType, proxy: config?.getProxy(), + headers, }; // If we are using Google auth or we are in Cloud Shell, there is nothing else to validate for now @@ -200,7 +203,10 @@ export async function createContentGenerator( config.authType === AuthType.USE_GEMINI || config.authType === AuthType.USE_VERTEX_AI ) { - let headers: Record = { ...baseHeaders }; + let headers: Record = { + ...baseHeaders, + ...config.headers, + }; if (gcConfig?.getUsageStatisticsEnabled()) { const installationManager = new InstallationManager(); const installationId = installationManager.getInstallationId();