diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 530d4caf0f..427187540b 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -169,12 +169,18 @@ export class MCPServerConfig { // OAuth configuration readonly oauth?: MCPOAuthConfig, readonly authProviderType?: AuthProviderType, + // Service Account Configuration + /* targetAudience format: CLIENT_ID.apps.googleusercontent.com */ + readonly targetAudience?: string, + /* targetServiceAccount format: @.iam.gserviceaccount.com */ + readonly targetServiceAccount?: string, ) {} } export enum AuthProviderType { DYNAMIC_DISCOVERY = 'dynamic_discovery', GOOGLE_CREDENTIALS = 'google_credentials', + SERVICE_ACCOUNT_IMPERSONATION = 'service_account_impersonation', } export interface SandboxConfig { diff --git a/packages/core/src/mcp/sa-impersonation-provider.test.ts b/packages/core/src/mcp/sa-impersonation-provider.test.ts new file mode 100644 index 0000000000..c86da645cb --- /dev/null +++ b/packages/core/src/mcp/sa-impersonation-provider.test.ts @@ -0,0 +1,153 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { ServiceAccountImpersonationProvider } from './sa-impersonation-provider.js'; +import type { MCPServerConfig } from '../config/config.js'; + +const mockRequest = vi.fn(); +const mockGetClient = vi.fn(() => ({ + request: mockRequest, +})); + +// Mock the google-auth-library to use a shared mock function +vi.mock('google-auth-library', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + GoogleAuth: vi.fn().mockImplementation(() => ({ + getClient: mockGetClient, + })), + }; +}); + +const defaultSAConfig: MCPServerConfig = { + url: 'https://my-iap-service.run.app', + targetAudience: 'my-audience', + targetServiceAccount: 'my-sa', +}; + +describe('ServiceAccountImpersonationProvider', () => { + beforeEach(() => { + // Reset mocks before each test + vi.clearAllMocks(); + }); + + it('should throw an error if no URL is provided', () => { + const config: MCPServerConfig = {}; + expect(() => new ServiceAccountImpersonationProvider(config)).toThrow( + 'A url or httpUrl must be provided for the Service Account Impersonation provider', + ); + }); + + it('should throw an error if no targetAudience is provided', () => { + const config: MCPServerConfig = { + url: 'https://my-iap-service.run.app', + }; + expect(() => new ServiceAccountImpersonationProvider(config)).toThrow( + 'targetAudience must be provided for the Service Account Impersonation provider', + ); + }); + + it('should throw an error if no targetSA is provided', () => { + const config: MCPServerConfig = { + url: 'https://my-iap-service.run.app', + targetAudience: 'my-audience', + }; + expect(() => new ServiceAccountImpersonationProvider(config)).toThrow( + 'targetServiceAccount must be provided for the Service Account Impersonation provider', + ); + }); + + it('should correctly get tokens for a valid config', async () => { + const mockToken = 'mock-id-token-123'; + mockRequest.mockResolvedValue({ data: { token: mockToken } }); + + const provider = new ServiceAccountImpersonationProvider(defaultSAConfig); + const tokens = await provider.tokens(); + + expect(tokens).toBeDefined(); + expect(tokens?.access_token).toBe(mockToken); + expect(tokens?.token_type).toBe('Bearer'); + }); + + it('should return undefined if token acquisition fails', async () => { + mockRequest.mockResolvedValue({ data: { token: null } }); + + const provider = new ServiceAccountImpersonationProvider(defaultSAConfig); + const tokens = await provider.tokens(); + + expect(tokens).toBeUndefined(); + }); + + it('should make a request with the correct parameters', async () => { + mockRequest.mockResolvedValue({ data: { token: 'test-token' } }); + + const provider = new ServiceAccountImpersonationProvider(defaultSAConfig); + await provider.tokens(); + + expect(mockRequest).toHaveBeenCalledWith({ + url: 'https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/my-sa:generateIdToken', + method: 'POST', + data: { + audience: 'my-audience', + includeEmail: true, + }, + }); + }); + + it('should return a cached token if it is not expired', async () => { + const provider = new ServiceAccountImpersonationProvider(defaultSAConfig); + vi.useFakeTimers(); + + // jwt payload with exp set to 1 hour from now + const payload = { exp: Math.floor(Date.now() / 1000) + 3600 }; + const jwt = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`; + mockRequest.mockResolvedValue({ data: { token: jwt } }); + + const firstTokens = await provider.tokens(); + expect(firstTokens?.access_token).toBe(jwt); + expect(mockRequest).toHaveBeenCalledTimes(1); + + // Advance time by 30 minutes + vi.advanceTimersByTime(1800 * 1000); + + // Seturn cached token + const secondTokens = await provider.tokens(); + expect(secondTokens).toBe(firstTokens); + expect(mockRequest).toHaveBeenCalledTimes(1); + + vi.useRealTimers(); + }); + + it('should fetch a new token if the cached token is expired (using fake timers)', async () => { + const provider = new ServiceAccountImpersonationProvider(defaultSAConfig); + vi.useFakeTimers(); + + // Get and cache a token that expires in 1 second + const expiredPayload = { exp: Math.floor(Date.now() / 1000) + 1 }; + const expiredJwt = `header.${Buffer.from(JSON.stringify(expiredPayload)).toString('base64')}.signature`; + + mockRequest.mockResolvedValue({ data: { token: expiredJwt } }); + const firstTokens = await provider.tokens(); + expect(firstTokens?.access_token).toBe(expiredJwt); + expect(mockRequest).toHaveBeenCalledTimes(1); + + // Prepare the mock for the *next* call + const newPayload = { exp: Math.floor(Date.now() / 1000) + 3600 }; + const newJwt = `header.${Buffer.from(JSON.stringify(newPayload)).toString('base64')}.signature`; + mockRequest.mockResolvedValue({ data: { token: newJwt } }); + + vi.advanceTimersByTime(1001); + + const newTokens = await provider.tokens(); + expect(newTokens?.access_token).toBe(newJwt); + expect(newTokens?.access_token).not.toBe(expiredJwt); + expect(mockRequest).toHaveBeenCalledTimes(2); // Confirms a new fetch + + vi.useRealTimers(); + }); +}); diff --git a/packages/core/src/mcp/sa-impersonation-provider.ts b/packages/core/src/mcp/sa-impersonation-provider.ts new file mode 100644 index 0000000000..e3336693d2 --- /dev/null +++ b/packages/core/src/mcp/sa-impersonation-provider.ts @@ -0,0 +1,171 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + OAuthClientInformation, + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthTokens, +} from '@modelcontextprotocol/sdk/shared/auth.js'; +import { GoogleAuth } from 'google-auth-library'; +import type { MCPServerConfig } from '../config/config.js'; +import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'; + +const fiveMinBufferMs = 5 * 60 * 1000; + +function createIamApiUrl(targetSA: string): string { + return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`; +} + +export class ServiceAccountImpersonationProvider + implements OAuthClientProvider +{ + private readonly targetServiceAccount: string; + private readonly targetAudience: string; // OAuth Client Id + private readonly auth: GoogleAuth; + private cachedToken?: OAuthTokens; + private tokenExpiryTime?: number; + + // Properties required by OAuthClientProvider, with no-op values + readonly redirectUrl = ''; + readonly clientMetadata: OAuthClientMetadata = { + client_name: 'Gemini CLI (Service Account Impersonation)', + redirect_uris: [], + grant_types: [], + response_types: [], + token_endpoint_auth_method: 'none', + }; + private _clientInformation?: OAuthClientInformationFull; + + constructor(private readonly config: MCPServerConfig) { + // This check is done in mcp-client.ts. This is just an additional check. + if (!this.config.httpUrl && !this.config.url) { + throw new Error( + 'A url or httpUrl must be provided for the Service Account Impersonation provider', + ); + } + + if (!config.targetAudience) { + throw new Error( + 'targetAudience must be provided for the Service Account Impersonation provider', + ); + } + this.targetAudience = config.targetAudience; + + if (!config.targetServiceAccount) { + throw new Error( + 'targetServiceAccount must be provided for the Service Account Impersonation provider', + ); + } + this.targetServiceAccount = config.targetServiceAccount; + + this.auth = new GoogleAuth(); + } + + clientInformation(): OAuthClientInformation | undefined { + return this._clientInformation; + } + + saveClientInformation(clientInformation: OAuthClientInformationFull): void { + this._clientInformation = clientInformation; + } + + async tokens(): Promise { + // 1. Check if we have a valid, non-expired cached token. + if ( + this.cachedToken && + this.tokenExpiryTime && + Date.now() < this.tokenExpiryTime - fiveMinBufferMs + ) { + return this.cachedToken; + } + + // 2. Clear any invalid/expired cache. + this.cachedToken = undefined; + this.tokenExpiryTime = undefined; + + // 3. Fetch a new ID token. + const client = await this.auth.getClient(); + const url = createIamApiUrl(this.targetServiceAccount); + + let idToken: string; + try { + const res = await client.request<{ token: string }>({ + url, + method: 'POST', + data: { + audience: this.targetAudience, + includeEmail: true, + }, + }); + idToken = res.data.token; + + if (!idToken || idToken.length === 0) { + console.error('Failed to get ID token from Google'); + return undefined; + } + } catch (e) { + console.error('Failed to fetch ID token from Google:', e); + return undefined; + } + + const expiryTime = this.parseTokenExpiry(idToken); + // Note: We are placing the OIDC ID Token into the `access_token` field. + // This is because the CLI uses this field to construct the + // `Authorization: Bearer ` header, which is the correct way to + // present an ID token. + const newTokens: OAuthTokens = { + access_token: idToken, + token_type: 'Bearer', + }; + + if (expiryTime) { + this.tokenExpiryTime = expiryTime; + this.cachedToken = newTokens; + } + + return newTokens; + } + + saveTokens(_tokens: OAuthTokens): void { + // No-op + } + + redirectToAuthorization(_authorizationUrl: URL): void { + // No-op + } + + saveCodeVerifier(_codeVerifier: string): void { + // No-op + } + + codeVerifier(): string { + // No-op + return ''; + } + + /** + * Parses a JWT string to extract its expiry time. + * @param idToken The JWT ID token. + * @returns The expiry time in **milliseconds**, or undefined if parsing fails. + */ + private parseTokenExpiry(idToken: string): number | undefined { + try { + const payload = JSON.parse( + Buffer.from(idToken.split('.')[1], 'base64').toString(), + ); + + if (payload && typeof payload.exp === 'number') { + return payload.exp * 1000; // Convert seconds to milliseconds + } + } catch (e) { + console.error('Failed to parse ID token for expiry time with error:', e); + } + + // Return undefined if try block fails or 'exp' is missing/invalid + return undefined; + } +} diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 1879948a99..e6ece88e45 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -24,6 +24,7 @@ import { parse } from 'shell-quote'; import type { Config, MCPServerConfig } from '../config/config.js'; import { AuthProviderType } from '../config/config.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; +import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; import type { FunctionDeclaration } from '@google/genai'; @@ -440,6 +441,7 @@ async function createTransportWithOAuth( * @param toolRegistry The central registry where discovered tools will be registered. * @returns A promise that resolves when the discovery process has been attempted for all servers. */ + export async function discoverMcpTools( mcpServers: Record, mcpServerCommand: string | undefined, @@ -1171,6 +1173,34 @@ export async function createTransport( mcpServerConfig: MCPServerConfig, debugMode: boolean, ): Promise { + if ( + mcpServerConfig.authProviderType === + AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION + ) { + const provider = new ServiceAccountImpersonationProvider(mcpServerConfig); + const transportOptions: + | StreamableHTTPClientTransportOptions + | SSEClientTransportOptions = { + authProvider: provider, + }; + + if (mcpServerConfig.httpUrl) { + return new StreamableHTTPClientTransport( + new URL(mcpServerConfig.httpUrl), + transportOptions, + ); + } else if (mcpServerConfig.url) { + // Default to SSE if only url is provided + return new SSEClientTransport( + new URL(mcpServerConfig.url), + transportOptions, + ); + } + throw new Error( + 'No URL configured for ServiceAccountImpersonation MCP Server', + ); + } + if ( mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS ) {