mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 14:10:37 -07:00
feat(mcp): Inject GoogleCredentialProvider headers in McpClient (#13783)
This commit is contained in:
@@ -18,3 +18,4 @@ eslint.config.js
|
||||
gha-creds-*.json
|
||||
junit.xml
|
||||
Thumbs.db
|
||||
.pytest_cache
|
||||
|
||||
18
packages/core/src/mcp/auth-provider.ts
Normal file
18
packages/core/src/mcp/auth-provider.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
|
||||
/**
|
||||
* Extension of OAuthClientProvider that allows providers to inject custom headers
|
||||
* into the transport request.
|
||||
*/
|
||||
export interface McpAuthProvider extends OAuthClientProvider {
|
||||
/**
|
||||
* Returns custom headers to be added to the request.
|
||||
*/
|
||||
getRequestHeaders?(): Promise<Record<string, string>>;
|
||||
}
|
||||
@@ -86,6 +86,7 @@ describe('GoogleCredentialProvider', () => {
|
||||
let mockClient: {
|
||||
getAccessToken: Mock;
|
||||
credentials?: { expiry_date: number | null };
|
||||
quotaProjectId?: string;
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -153,5 +154,58 @@ describe('GoogleCredentialProvider', () => {
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should return quota project ID', async () => {
|
||||
mockClient['quotaProjectId'] = 'test-project-id';
|
||||
const quotaProjectId = await provider.getQuotaProjectId();
|
||||
expect(quotaProjectId).toBe('test-project-id');
|
||||
});
|
||||
|
||||
it('should return request headers with quota project ID', async () => {
|
||||
mockClient['quotaProjectId'] = 'test-project-id';
|
||||
const headers = await provider.getRequestHeaders();
|
||||
expect(headers).toEqual({
|
||||
'X-Goog-User-Project': 'test-project-id',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return empty request headers if quota project ID is missing', async () => {
|
||||
mockClient['quotaProjectId'] = undefined;
|
||||
const headers = await provider.getRequestHeaders();
|
||||
expect(headers).toEqual({});
|
||||
});
|
||||
|
||||
it('should prioritize config headers over quota project ID', async () => {
|
||||
mockClient['quotaProjectId'] = 'quota-project-id';
|
||||
const configWithHeaders = {
|
||||
...validConfig,
|
||||
headers: {
|
||||
'X-Goog-User-Project': 'config-project-id',
|
||||
},
|
||||
};
|
||||
const providerWithHeaders = new GoogleCredentialProvider(
|
||||
configWithHeaders,
|
||||
);
|
||||
const headers = await providerWithHeaders.getRequestHeaders();
|
||||
expect(headers).toEqual({
|
||||
'X-Goog-User-Project': 'config-project-id',
|
||||
});
|
||||
});
|
||||
it('should prioritize config headers over quota project ID (case-insensitive)', async () => {
|
||||
mockClient['quotaProjectId'] = 'quota-project-id';
|
||||
const configWithHeaders = {
|
||||
...validConfig,
|
||||
headers: {
|
||||
'x-goog-user-project': 'config-project-id',
|
||||
},
|
||||
};
|
||||
const providerWithHeaders = new GoogleCredentialProvider(
|
||||
configWithHeaders,
|
||||
);
|
||||
const headers = await providerWithHeaders.getRequestHeaders();
|
||||
expect(headers).toEqual({
|
||||
'x-goog-user-project': 'config-project-id',
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
import type { McpAuthProvider } from './auth-provider.js';
|
||||
import type {
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
@@ -18,7 +18,7 @@ import { coreEvents } from '../utils/events.js';
|
||||
|
||||
const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, /^(.*\.)?luci\.app$/];
|
||||
|
||||
export class GoogleCredentialProvider implements OAuthClientProvider {
|
||||
export class GoogleCredentialProvider implements McpAuthProvider {
|
||||
private readonly auth: GoogleAuth;
|
||||
private cachedToken?: OAuthTokens;
|
||||
private tokenExpiryTime?: number;
|
||||
@@ -123,4 +123,35 @@ export class GoogleCredentialProvider implements OAuthClientProvider {
|
||||
// No-op
|
||||
return '';
|
||||
}
|
||||
/**
|
||||
* Returns the project ID used for quota.
|
||||
*/
|
||||
async getQuotaProjectId(): Promise<string | undefined> {
|
||||
const client = await this.auth.getClient();
|
||||
return client.quotaProjectId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns custom headers to be added to the request.
|
||||
*/
|
||||
async getRequestHeaders(): Promise<Record<string, string>> {
|
||||
const headers: Record<string, string> = {};
|
||||
const configHeaders = this.config?.headers ?? {};
|
||||
const userProjectHeaderKey = Object.keys(configHeaders).find(
|
||||
(key) => key.toLowerCase() === 'x-goog-user-project',
|
||||
);
|
||||
|
||||
// If the header is present in the config (case-insensitive check), use the
|
||||
// config's key and value. This prevents duplicate headers (e.g.
|
||||
// 'x-goog-user-project' and 'X-Goog-User-Project') which can cause errors.
|
||||
if (userProjectHeaderKey) {
|
||||
headers[userProjectHeaderKey] = configHeaders[userProjectHeaderKey];
|
||||
} else {
|
||||
const quotaProjectId = await this.getQuotaProjectId();
|
||||
if (quotaProjectId) {
|
||||
headers['X-Goog-User-Project'] = quotaProjectId;
|
||||
}
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import type {
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
import { OAuthUtils, FIVE_MIN_BUFFER_MS } from './oauth-utils.js';
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
import type { McpAuthProvider } from './auth-provider.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
|
||||
function createIamApiUrl(targetSA: string): string {
|
||||
@@ -22,9 +22,7 @@ function createIamApiUrl(targetSA: string): string {
|
||||
)}:generateIdToken`;
|
||||
}
|
||||
|
||||
export class ServiceAccountImpersonationProvider
|
||||
implements OAuthClientProvider
|
||||
{
|
||||
export class ServiceAccountImpersonationProvider implements McpAuthProvider {
|
||||
private readonly targetServiceAccount: string;
|
||||
private readonly targetAudience: string; // OAuth Client Id
|
||||
private readonly auth: GoogleAuth;
|
||||
|
||||
@@ -36,6 +36,8 @@ vi.mock('@google/genai');
|
||||
vi.mock('../mcp/oauth-provider.js');
|
||||
vi.mock('../mcp/oauth-token-storage.js');
|
||||
vi.mock('../mcp/oauth-utils.js');
|
||||
vi.mock('google-auth-library');
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
|
||||
vi.mock('../utils/events.js', () => ({
|
||||
coreEvents: {
|
||||
@@ -578,6 +580,16 @@ describe('mcp-client', () => {
|
||||
});
|
||||
|
||||
describe('useGoogleCredentialProvider', () => {
|
||||
beforeEach(() => {
|
||||
// Mock GoogleAuth client
|
||||
const mockClient = {
|
||||
getAccessToken: vi.fn().mockResolvedValue({ token: 'test-token' }),
|
||||
quotaProjectId: 'myproject',
|
||||
};
|
||||
|
||||
GoogleAuth.prototype.getClient = vi.fn().mockResolvedValue(mockClient);
|
||||
});
|
||||
|
||||
it('should use GoogleCredentialProvider when specified', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
@@ -605,6 +617,64 @@ describe('mcp-client', () => {
|
||||
expect(googUserProject).toBe('myproject');
|
||||
});
|
||||
|
||||
it('should use headers from GoogleCredentialProvider', async () => {
|
||||
const mockGetRequestHeaders = vi.fn().mockResolvedValue({
|
||||
'X-Goog-User-Project': 'provider-project',
|
||||
});
|
||||
vi.spyOn(
|
||||
GoogleCredentialProvider.prototype,
|
||||
'getRequestHeaders',
|
||||
).mockImplementation(mockGetRequestHeaders);
|
||||
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
httpUrl: 'http://test.googleapis.com',
|
||||
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
||||
oauth: {
|
||||
scopes: ['scope1'],
|
||||
},
|
||||
},
|
||||
false,
|
||||
);
|
||||
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(mockGetRequestHeaders).toHaveBeenCalled();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const headers = (transport as any)._requestInit?.headers;
|
||||
expect(headers['X-Goog-User-Project']).toBe('provider-project');
|
||||
});
|
||||
|
||||
it('should prioritize provider headers over config headers', async () => {
|
||||
const mockGetRequestHeaders = vi.fn().mockResolvedValue({
|
||||
'X-Goog-User-Project': 'provider-project',
|
||||
});
|
||||
vi.spyOn(
|
||||
GoogleCredentialProvider.prototype,
|
||||
'getRequestHeaders',
|
||||
).mockImplementation(mockGetRequestHeaders);
|
||||
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
httpUrl: 'http://test.googleapis.com',
|
||||
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
||||
oauth: {
|
||||
scopes: ['scope1'],
|
||||
},
|
||||
headers: {
|
||||
'X-Goog-User-Project': 'config-project',
|
||||
},
|
||||
},
|
||||
false,
|
||||
);
|
||||
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const headers = (transport as any)._requestInit?.headers;
|
||||
expect(headers['X-Goog-User-Project']).toBe('provider-project');
|
||||
});
|
||||
|
||||
it('should use GoogleCredentialProvider with SSE transport', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
|
||||
@@ -35,6 +35,7 @@ import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import type { CallableTool, FunctionCall, Part, Tool } from '@google/genai';
|
||||
import { basename } from 'node:path';
|
||||
import { pathToFileURL } from 'node:url';
|
||||
import type { McpAuthProvider } from '../mcp/auth-provider.js';
|
||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
@@ -425,7 +426,9 @@ function createTransportRequestInit(
|
||||
*
|
||||
* @param mcpServerConfig The MCP server configuration
|
||||
*/
|
||||
function createAuthProvider(mcpServerConfig: MCPServerConfig) {
|
||||
function createAuthProvider(
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
): McpAuthProvider | undefined {
|
||||
if (
|
||||
mcpServerConfig.authProviderType ===
|
||||
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
|
||||
@@ -1333,8 +1336,9 @@ export async function createTransport(
|
||||
|
||||
if (mcpServerConfig.httpUrl || mcpServerConfig.url) {
|
||||
const authProvider = createAuthProvider(mcpServerConfig);
|
||||
const headers: Record<string, string> =
|
||||
(await authProvider?.getRequestHeaders?.()) ?? {};
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (authProvider === undefined) {
|
||||
// Check if we have OAuth configuration or stored tokens
|
||||
let accessToken: string | null = null;
|
||||
|
||||
Reference in New Issue
Block a user