feat(mcp): Inject GoogleCredentialProvider headers in McpClient (#13783)

This commit is contained in:
sai-sunder-s
2025-11-26 12:08:19 -08:00
committed by GitHub
parent 3406dc5b2e
commit 0f12d6c426
7 changed files with 184 additions and 8 deletions

View File

@@ -18,3 +18,4 @@ eslint.config.js
gha-creds-*.json
junit.xml
Thumbs.db
.pytest_cache

View File

@@ -0,0 +1,18 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
/**
* Extension of OAuthClientProvider that allows providers to inject custom headers
* into the transport request.
*/
export interface McpAuthProvider extends OAuthClientProvider {
/**
* Returns custom headers to be added to the request.
*/
getRequestHeaders?(): Promise<Record<string, string>>;
}

View File

@@ -86,6 +86,7 @@ describe('GoogleCredentialProvider', () => {
let mockClient: {
getAccessToken: Mock;
credentials?: { expiry_date: number | null };
quotaProjectId?: string;
};
beforeEach(() => {
@@ -153,5 +154,58 @@ describe('GoogleCredentialProvider', () => {
vi.useRealTimers();
});
it('should return quota project ID', async () => {
mockClient['quotaProjectId'] = 'test-project-id';
const quotaProjectId = await provider.getQuotaProjectId();
expect(quotaProjectId).toBe('test-project-id');
});
it('should return request headers with quota project ID', async () => {
mockClient['quotaProjectId'] = 'test-project-id';
const headers = await provider.getRequestHeaders();
expect(headers).toEqual({
'X-Goog-User-Project': 'test-project-id',
});
});
it('should return empty request headers if quota project ID is missing', async () => {
mockClient['quotaProjectId'] = undefined;
const headers = await provider.getRequestHeaders();
expect(headers).toEqual({});
});
it('should prioritize config headers over quota project ID', async () => {
mockClient['quotaProjectId'] = 'quota-project-id';
const configWithHeaders = {
...validConfig,
headers: {
'X-Goog-User-Project': 'config-project-id',
},
};
const providerWithHeaders = new GoogleCredentialProvider(
configWithHeaders,
);
const headers = await providerWithHeaders.getRequestHeaders();
expect(headers).toEqual({
'X-Goog-User-Project': 'config-project-id',
});
});
it('should prioritize config headers over quota project ID (case-insensitive)', async () => {
mockClient['quotaProjectId'] = 'quota-project-id';
const configWithHeaders = {
...validConfig,
headers: {
'x-goog-user-project': 'config-project-id',
},
};
const providerWithHeaders = new GoogleCredentialProvider(
configWithHeaders,
);
const headers = await providerWithHeaders.getRequestHeaders();
expect(headers).toEqual({
'x-goog-user-project': 'config-project-id',
});
});
});
});

View File

@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
import type { McpAuthProvider } from './auth-provider.js';
import type {
OAuthClientInformation,
OAuthClientInformationFull,
@@ -18,7 +18,7 @@ import { coreEvents } from '../utils/events.js';
const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, /^(.*\.)?luci\.app$/];
export class GoogleCredentialProvider implements OAuthClientProvider {
export class GoogleCredentialProvider implements McpAuthProvider {
private readonly auth: GoogleAuth;
private cachedToken?: OAuthTokens;
private tokenExpiryTime?: number;
@@ -123,4 +123,35 @@ export class GoogleCredentialProvider implements OAuthClientProvider {
// No-op
return '';
}
/**
* Returns the project ID used for quota.
*/
async getQuotaProjectId(): Promise<string | undefined> {
const client = await this.auth.getClient();
return client.quotaProjectId;
}
/**
* Returns custom headers to be added to the request.
*/
async getRequestHeaders(): Promise<Record<string, string>> {
const headers: Record<string, string> = {};
const configHeaders = this.config?.headers ?? {};
const userProjectHeaderKey = Object.keys(configHeaders).find(
(key) => key.toLowerCase() === 'x-goog-user-project',
);
// If the header is present in the config (case-insensitive check), use the
// config's key and value. This prevents duplicate headers (e.g.
// 'x-goog-user-project' and 'X-Goog-User-Project') which can cause errors.
if (userProjectHeaderKey) {
headers[userProjectHeaderKey] = configHeaders[userProjectHeaderKey];
} else {
const quotaProjectId = await this.getQuotaProjectId();
if (quotaProjectId) {
headers['X-Goog-User-Project'] = quotaProjectId;
}
}
return headers;
}
}

View File

@@ -13,7 +13,7 @@ import type {
import { GoogleAuth } from 'google-auth-library';
import { OAuthUtils, FIVE_MIN_BUFFER_MS } from './oauth-utils.js';
import type { MCPServerConfig } from '../config/config.js';
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
import type { McpAuthProvider } from './auth-provider.js';
import { coreEvents } from '../utils/events.js';
function createIamApiUrl(targetSA: string): string {
@@ -22,9 +22,7 @@ function createIamApiUrl(targetSA: string): string {
)}:generateIdToken`;
}
export class ServiceAccountImpersonationProvider
implements OAuthClientProvider
{
export class ServiceAccountImpersonationProvider implements McpAuthProvider {
private readonly targetServiceAccount: string;
private readonly targetAudience: string; // OAuth Client Id
private readonly auth: GoogleAuth;

View File

@@ -36,6 +36,8 @@ vi.mock('@google/genai');
vi.mock('../mcp/oauth-provider.js');
vi.mock('../mcp/oauth-token-storage.js');
vi.mock('../mcp/oauth-utils.js');
vi.mock('google-auth-library');
import { GoogleAuth } from 'google-auth-library';
vi.mock('../utils/events.js', () => ({
coreEvents: {
@@ -578,6 +580,16 @@ describe('mcp-client', () => {
});
describe('useGoogleCredentialProvider', () => {
beforeEach(() => {
// Mock GoogleAuth client
const mockClient = {
getAccessToken: vi.fn().mockResolvedValue({ token: 'test-token' }),
quotaProjectId: 'myproject',
};
GoogleAuth.prototype.getClient = vi.fn().mockResolvedValue(mockClient);
});
it('should use GoogleCredentialProvider when specified', async () => {
const transport = await createTransport(
'test-server',
@@ -605,6 +617,64 @@ describe('mcp-client', () => {
expect(googUserProject).toBe('myproject');
});
it('should use headers from GoogleCredentialProvider', async () => {
const mockGetRequestHeaders = vi.fn().mockResolvedValue({
'X-Goog-User-Project': 'provider-project',
});
vi.spyOn(
GoogleCredentialProvider.prototype,
'getRequestHeaders',
).mockImplementation(mockGetRequestHeaders);
const transport = await createTransport(
'test-server',
{
httpUrl: 'http://test.googleapis.com',
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
oauth: {
scopes: ['scope1'],
},
},
false,
);
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
expect(mockGetRequestHeaders).toHaveBeenCalled();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const headers = (transport as any)._requestInit?.headers;
expect(headers['X-Goog-User-Project']).toBe('provider-project');
});
it('should prioritize provider headers over config headers', async () => {
const mockGetRequestHeaders = vi.fn().mockResolvedValue({
'X-Goog-User-Project': 'provider-project',
});
vi.spyOn(
GoogleCredentialProvider.prototype,
'getRequestHeaders',
).mockImplementation(mockGetRequestHeaders);
const transport = await createTransport(
'test-server',
{
httpUrl: 'http://test.googleapis.com',
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
oauth: {
scopes: ['scope1'],
},
headers: {
'X-Goog-User-Project': 'config-project',
},
},
false,
);
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const headers = (transport as any)._requestInit?.headers;
expect(headers['X-Goog-User-Project']).toBe('provider-project');
});
it('should use GoogleCredentialProvider with SSE transport', async () => {
const transport = await createTransport(
'test-server',

View File

@@ -35,6 +35,7 @@ import { DiscoveredMCPTool } from './mcp-tool.js';
import type { CallableTool, FunctionCall, Part, Tool } from '@google/genai';
import { basename } from 'node:path';
import { pathToFileURL } from 'node:url';
import type { McpAuthProvider } from '../mcp/auth-provider.js';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
import { OAuthUtils } from '../mcp/oauth-utils.js';
@@ -425,7 +426,9 @@ function createTransportRequestInit(
*
* @param mcpServerConfig The MCP server configuration
*/
function createAuthProvider(mcpServerConfig: MCPServerConfig) {
function createAuthProvider(
mcpServerConfig: MCPServerConfig,
): McpAuthProvider | undefined {
if (
mcpServerConfig.authProviderType ===
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
@@ -1333,8 +1336,9 @@ export async function createTransport(
if (mcpServerConfig.httpUrl || mcpServerConfig.url) {
const authProvider = createAuthProvider(mcpServerConfig);
const headers: Record<string, string> =
(await authProvider?.getRequestHeaders?.()) ?? {};
const headers: Record<string, string> = {};
if (authProvider === undefined) {
// Check if we have OAuth configuration or stored tokens
let accessToken: string | null = null;