diff --git a/.prettierignore b/.prettierignore index f4330b7e68..7b8a75a110 100644 --- a/.prettierignore +++ b/.prettierignore @@ -18,3 +18,4 @@ eslint.config.js gha-creds-*.json junit.xml Thumbs.db +.pytest_cache diff --git a/packages/core/src/mcp/auth-provider.ts b/packages/core/src/mcp/auth-provider.ts new file mode 100644 index 0000000000..6706a2130b --- /dev/null +++ b/packages/core/src/mcp/auth-provider.ts @@ -0,0 +1,18 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'; + +/** + * Extension of OAuthClientProvider that allows providers to inject custom headers + * into the transport request. + */ +export interface McpAuthProvider extends OAuthClientProvider { + /** + * Returns custom headers to be added to the request. + */ + getRequestHeaders?(): Promise>; +} diff --git a/packages/core/src/mcp/google-auth-provider.test.ts b/packages/core/src/mcp/google-auth-provider.test.ts index efe959ff3c..8a25f15ad7 100644 --- a/packages/core/src/mcp/google-auth-provider.test.ts +++ b/packages/core/src/mcp/google-auth-provider.test.ts @@ -86,6 +86,7 @@ describe('GoogleCredentialProvider', () => { let mockClient: { getAccessToken: Mock; credentials?: { expiry_date: number | null }; + quotaProjectId?: string; }; beforeEach(() => { @@ -153,5 +154,58 @@ describe('GoogleCredentialProvider', () => { vi.useRealTimers(); }); + + it('should return quota project ID', async () => { + mockClient['quotaProjectId'] = 'test-project-id'; + const quotaProjectId = await provider.getQuotaProjectId(); + expect(quotaProjectId).toBe('test-project-id'); + }); + + it('should return request headers with quota project ID', async () => { + mockClient['quotaProjectId'] = 'test-project-id'; + const headers = await provider.getRequestHeaders(); + expect(headers).toEqual({ + 'X-Goog-User-Project': 'test-project-id', + }); + }); + + it('should return empty request headers if quota project ID is missing', async () => { + mockClient['quotaProjectId'] = undefined; + const headers = await provider.getRequestHeaders(); + expect(headers).toEqual({}); + }); + + it('should prioritize config headers over quota project ID', async () => { + mockClient['quotaProjectId'] = 'quota-project-id'; + const configWithHeaders = { + ...validConfig, + headers: { + 'X-Goog-User-Project': 'config-project-id', + }, + }; + const providerWithHeaders = new GoogleCredentialProvider( + configWithHeaders, + ); + const headers = await providerWithHeaders.getRequestHeaders(); + expect(headers).toEqual({ + 'X-Goog-User-Project': 'config-project-id', + }); + }); + it('should prioritize config headers over quota project ID (case-insensitive)', async () => { + mockClient['quotaProjectId'] = 'quota-project-id'; + const configWithHeaders = { + ...validConfig, + headers: { + 'x-goog-user-project': 'config-project-id', + }, + }; + const providerWithHeaders = new GoogleCredentialProvider( + configWithHeaders, + ); + const headers = await providerWithHeaders.getRequestHeaders(); + expect(headers).toEqual({ + 'x-goog-user-project': 'config-project-id', + }); + }); }); }); diff --git a/packages/core/src/mcp/google-auth-provider.ts b/packages/core/src/mcp/google-auth-provider.ts index 6196914cb4..a4f61c3139 100644 --- a/packages/core/src/mcp/google-auth-provider.ts +++ b/packages/core/src/mcp/google-auth-provider.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'; +import type { McpAuthProvider } from './auth-provider.js'; import type { OAuthClientInformation, OAuthClientInformationFull, @@ -18,7 +18,7 @@ import { coreEvents } from '../utils/events.js'; const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, /^(.*\.)?luci\.app$/]; -export class GoogleCredentialProvider implements OAuthClientProvider { +export class GoogleCredentialProvider implements McpAuthProvider { private readonly auth: GoogleAuth; private cachedToken?: OAuthTokens; private tokenExpiryTime?: number; @@ -123,4 +123,35 @@ export class GoogleCredentialProvider implements OAuthClientProvider { // No-op return ''; } + /** + * Returns the project ID used for quota. + */ + async getQuotaProjectId(): Promise { + const client = await this.auth.getClient(); + return client.quotaProjectId; + } + + /** + * Returns custom headers to be added to the request. + */ + async getRequestHeaders(): Promise> { + const headers: Record = {}; + const configHeaders = this.config?.headers ?? {}; + const userProjectHeaderKey = Object.keys(configHeaders).find( + (key) => key.toLowerCase() === 'x-goog-user-project', + ); + + // If the header is present in the config (case-insensitive check), use the + // config's key and value. This prevents duplicate headers (e.g. + // 'x-goog-user-project' and 'X-Goog-User-Project') which can cause errors. + if (userProjectHeaderKey) { + headers[userProjectHeaderKey] = configHeaders[userProjectHeaderKey]; + } else { + const quotaProjectId = await this.getQuotaProjectId(); + if (quotaProjectId) { + headers['X-Goog-User-Project'] = quotaProjectId; + } + } + return headers; + } } diff --git a/packages/core/src/mcp/sa-impersonation-provider.ts b/packages/core/src/mcp/sa-impersonation-provider.ts index 2b9516d0d4..837601c0db 100644 --- a/packages/core/src/mcp/sa-impersonation-provider.ts +++ b/packages/core/src/mcp/sa-impersonation-provider.ts @@ -13,7 +13,7 @@ import type { import { GoogleAuth } from 'google-auth-library'; import { OAuthUtils, FIVE_MIN_BUFFER_MS } from './oauth-utils.js'; import type { MCPServerConfig } from '../config/config.js'; -import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'; +import type { McpAuthProvider } from './auth-provider.js'; import { coreEvents } from '../utils/events.js'; function createIamApiUrl(targetSA: string): string { @@ -22,9 +22,7 @@ function createIamApiUrl(targetSA: string): string { )}:generateIdToken`; } -export class ServiceAccountImpersonationProvider - implements OAuthClientProvider -{ +export class ServiceAccountImpersonationProvider implements McpAuthProvider { private readonly targetServiceAccount: string; private readonly targetAudience: string; // OAuth Client Id private readonly auth: GoogleAuth; diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 249ebc4c33..15e9fd8cb3 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -36,6 +36,8 @@ vi.mock('@google/genai'); vi.mock('../mcp/oauth-provider.js'); vi.mock('../mcp/oauth-token-storage.js'); vi.mock('../mcp/oauth-utils.js'); +vi.mock('google-auth-library'); +import { GoogleAuth } from 'google-auth-library'; vi.mock('../utils/events.js', () => ({ coreEvents: { @@ -578,6 +580,16 @@ describe('mcp-client', () => { }); describe('useGoogleCredentialProvider', () => { + beforeEach(() => { + // Mock GoogleAuth client + const mockClient = { + getAccessToken: vi.fn().mockResolvedValue({ token: 'test-token' }), + quotaProjectId: 'myproject', + }; + + GoogleAuth.prototype.getClient = vi.fn().mockResolvedValue(mockClient); + }); + it('should use GoogleCredentialProvider when specified', async () => { const transport = await createTransport( 'test-server', @@ -605,6 +617,64 @@ describe('mcp-client', () => { expect(googUserProject).toBe('myproject'); }); + it('should use headers from GoogleCredentialProvider', async () => { + const mockGetRequestHeaders = vi.fn().mockResolvedValue({ + 'X-Goog-User-Project': 'provider-project', + }); + vi.spyOn( + GoogleCredentialProvider.prototype, + 'getRequestHeaders', + ).mockImplementation(mockGetRequestHeaders); + + const transport = await createTransport( + 'test-server', + { + httpUrl: 'http://test.googleapis.com', + authProviderType: AuthProviderType.GOOGLE_CREDENTIALS, + oauth: { + scopes: ['scope1'], + }, + }, + false, + ); + + expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); + expect(mockGetRequestHeaders).toHaveBeenCalled(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const headers = (transport as any)._requestInit?.headers; + expect(headers['X-Goog-User-Project']).toBe('provider-project'); + }); + + it('should prioritize provider headers over config headers', async () => { + const mockGetRequestHeaders = vi.fn().mockResolvedValue({ + 'X-Goog-User-Project': 'provider-project', + }); + vi.spyOn( + GoogleCredentialProvider.prototype, + 'getRequestHeaders', + ).mockImplementation(mockGetRequestHeaders); + + const transport = await createTransport( + 'test-server', + { + httpUrl: 'http://test.googleapis.com', + authProviderType: AuthProviderType.GOOGLE_CREDENTIALS, + oauth: { + scopes: ['scope1'], + }, + headers: { + 'X-Goog-User-Project': 'config-project', + }, + }, + false, + ); + + expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const headers = (transport as any)._requestInit?.headers; + expect(headers['X-Goog-User-Project']).toBe('provider-project'); + }); + it('should use GoogleCredentialProvider with SSE transport', async () => { const transport = await createTransport( 'test-server', diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 0968dc8702..c4f9cc390e 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -35,6 +35,7 @@ import { DiscoveredMCPTool } from './mcp-tool.js'; import type { CallableTool, FunctionCall, Part, Tool } from '@google/genai'; import { basename } from 'node:path'; import { pathToFileURL } from 'node:url'; +import type { McpAuthProvider } from '../mcp/auth-provider.js'; import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; @@ -425,7 +426,9 @@ function createTransportRequestInit( * * @param mcpServerConfig The MCP server configuration */ -function createAuthProvider(mcpServerConfig: MCPServerConfig) { +function createAuthProvider( + mcpServerConfig: MCPServerConfig, +): McpAuthProvider | undefined { if ( mcpServerConfig.authProviderType === AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION @@ -1333,8 +1336,9 @@ export async function createTransport( if (mcpServerConfig.httpUrl || mcpServerConfig.url) { const authProvider = createAuthProvider(mcpServerConfig); + const headers: Record = + (await authProvider?.getRequestHeaders?.()) ?? {}; - const headers: Record = {}; if (authProvider === undefined) { // Check if we have OAuth configuration or stored tokens let accessToken: string | null = null;