feat(iap support): Add service account impersonation provider to MCPServers to support IAP on Cloud Run (#8505)

Co-authored-by: Bryan Morgan <bryanmorgan@google.com>
This commit is contained in:
Adam Weidman
2025-09-27 10:12:24 +02:00
committed by GitHub
parent 19400ba8c7
commit db51e3f4cd
4 changed files with 360 additions and 0 deletions

View File

@@ -169,12 +169,18 @@ export class MCPServerConfig {
// OAuth configuration
readonly oauth?: MCPOAuthConfig,
readonly authProviderType?: AuthProviderType,
// Service Account Configuration
/* targetAudience format: CLIENT_ID.apps.googleusercontent.com */
readonly targetAudience?: string,
/* targetServiceAccount format: <service-account-name>@<project-num>.iam.gserviceaccount.com */
readonly targetServiceAccount?: string,
) {}
}
export enum AuthProviderType {
DYNAMIC_DISCOVERY = 'dynamic_discovery',
GOOGLE_CREDENTIALS = 'google_credentials',
SERVICE_ACCOUNT_IMPERSONATION = 'service_account_impersonation',
}
export interface SandboxConfig {

View File

@@ -0,0 +1,153 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ServiceAccountImpersonationProvider } from './sa-impersonation-provider.js';
import type { MCPServerConfig } from '../config/config.js';
const mockRequest = vi.fn();
const mockGetClient = vi.fn(() => ({
request: mockRequest,
}));
// Mock the google-auth-library to use a shared mock function
vi.mock('google-auth-library', async (importOriginal) => {
const actual = await importOriginal<typeof import('google-auth-library')>();
return {
...actual,
GoogleAuth: vi.fn().mockImplementation(() => ({
getClient: mockGetClient,
})),
};
});
const defaultSAConfig: MCPServerConfig = {
url: 'https://my-iap-service.run.app',
targetAudience: 'my-audience',
targetServiceAccount: 'my-sa',
};
describe('ServiceAccountImpersonationProvider', () => {
beforeEach(() => {
// Reset mocks before each test
vi.clearAllMocks();
});
it('should throw an error if no URL is provided', () => {
const config: MCPServerConfig = {};
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
'A url or httpUrl must be provided for the Service Account Impersonation provider',
);
});
it('should throw an error if no targetAudience is provided', () => {
const config: MCPServerConfig = {
url: 'https://my-iap-service.run.app',
};
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
'targetAudience must be provided for the Service Account Impersonation provider',
);
});
it('should throw an error if no targetSA is provided', () => {
const config: MCPServerConfig = {
url: 'https://my-iap-service.run.app',
targetAudience: 'my-audience',
};
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
'targetServiceAccount must be provided for the Service Account Impersonation provider',
);
});
it('should correctly get tokens for a valid config', async () => {
const mockToken = 'mock-id-token-123';
mockRequest.mockResolvedValue({ data: { token: mockToken } });
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
const tokens = await provider.tokens();
expect(tokens).toBeDefined();
expect(tokens?.access_token).toBe(mockToken);
expect(tokens?.token_type).toBe('Bearer');
});
it('should return undefined if token acquisition fails', async () => {
mockRequest.mockResolvedValue({ data: { token: null } });
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
const tokens = await provider.tokens();
expect(tokens).toBeUndefined();
});
it('should make a request with the correct parameters', async () => {
mockRequest.mockResolvedValue({ data: { token: 'test-token' } });
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
await provider.tokens();
expect(mockRequest).toHaveBeenCalledWith({
url: 'https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/my-sa:generateIdToken',
method: 'POST',
data: {
audience: 'my-audience',
includeEmail: true,
},
});
});
it('should return a cached token if it is not expired', async () => {
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
vi.useFakeTimers();
// jwt payload with exp set to 1 hour from now
const payload = { exp: Math.floor(Date.now() / 1000) + 3600 };
const jwt = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
mockRequest.mockResolvedValue({ data: { token: jwt } });
const firstTokens = await provider.tokens();
expect(firstTokens?.access_token).toBe(jwt);
expect(mockRequest).toHaveBeenCalledTimes(1);
// Advance time by 30 minutes
vi.advanceTimersByTime(1800 * 1000);
// Seturn cached token
const secondTokens = await provider.tokens();
expect(secondTokens).toBe(firstTokens);
expect(mockRequest).toHaveBeenCalledTimes(1);
vi.useRealTimers();
});
it('should fetch a new token if the cached token is expired (using fake timers)', async () => {
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
vi.useFakeTimers();
// Get and cache a token that expires in 1 second
const expiredPayload = { exp: Math.floor(Date.now() / 1000) + 1 };
const expiredJwt = `header.${Buffer.from(JSON.stringify(expiredPayload)).toString('base64')}.signature`;
mockRequest.mockResolvedValue({ data: { token: expiredJwt } });
const firstTokens = await provider.tokens();
expect(firstTokens?.access_token).toBe(expiredJwt);
expect(mockRequest).toHaveBeenCalledTimes(1);
// Prepare the mock for the *next* call
const newPayload = { exp: Math.floor(Date.now() / 1000) + 3600 };
const newJwt = `header.${Buffer.from(JSON.stringify(newPayload)).toString('base64')}.signature`;
mockRequest.mockResolvedValue({ data: { token: newJwt } });
vi.advanceTimersByTime(1001);
const newTokens = await provider.tokens();
expect(newTokens?.access_token).toBe(newJwt);
expect(newTokens?.access_token).not.toBe(expiredJwt);
expect(mockRequest).toHaveBeenCalledTimes(2); // Confirms a new fetch
vi.useRealTimers();
});
});

View File

@@ -0,0 +1,171 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type {
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthTokens,
} from '@modelcontextprotocol/sdk/shared/auth.js';
import { GoogleAuth } from 'google-auth-library';
import type { MCPServerConfig } from '../config/config.js';
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
const fiveMinBufferMs = 5 * 60 * 1000;
function createIamApiUrl(targetSA: string): string {
return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`;
}
export class ServiceAccountImpersonationProvider
implements OAuthClientProvider
{
private readonly targetServiceAccount: string;
private readonly targetAudience: string; // OAuth Client Id
private readonly auth: GoogleAuth;
private cachedToken?: OAuthTokens;
private tokenExpiryTime?: number;
// Properties required by OAuthClientProvider, with no-op values
readonly redirectUrl = '';
readonly clientMetadata: OAuthClientMetadata = {
client_name: 'Gemini CLI (Service Account Impersonation)',
redirect_uris: [],
grant_types: [],
response_types: [],
token_endpoint_auth_method: 'none',
};
private _clientInformation?: OAuthClientInformationFull;
constructor(private readonly config: MCPServerConfig) {
// This check is done in mcp-client.ts. This is just an additional check.
if (!this.config.httpUrl && !this.config.url) {
throw new Error(
'A url or httpUrl must be provided for the Service Account Impersonation provider',
);
}
if (!config.targetAudience) {
throw new Error(
'targetAudience must be provided for the Service Account Impersonation provider',
);
}
this.targetAudience = config.targetAudience;
if (!config.targetServiceAccount) {
throw new Error(
'targetServiceAccount must be provided for the Service Account Impersonation provider',
);
}
this.targetServiceAccount = config.targetServiceAccount;
this.auth = new GoogleAuth();
}
clientInformation(): OAuthClientInformation | undefined {
return this._clientInformation;
}
saveClientInformation(clientInformation: OAuthClientInformationFull): void {
this._clientInformation = clientInformation;
}
async tokens(): Promise<OAuthTokens | undefined> {
// 1. Check if we have a valid, non-expired cached token.
if (
this.cachedToken &&
this.tokenExpiryTime &&
Date.now() < this.tokenExpiryTime - fiveMinBufferMs
) {
return this.cachedToken;
}
// 2. Clear any invalid/expired cache.
this.cachedToken = undefined;
this.tokenExpiryTime = undefined;
// 3. Fetch a new ID token.
const client = await this.auth.getClient();
const url = createIamApiUrl(this.targetServiceAccount);
let idToken: string;
try {
const res = await client.request<{ token: string }>({
url,
method: 'POST',
data: {
audience: this.targetAudience,
includeEmail: true,
},
});
idToken = res.data.token;
if (!idToken || idToken.length === 0) {
console.error('Failed to get ID token from Google');
return undefined;
}
} catch (e) {
console.error('Failed to fetch ID token from Google:', e);
return undefined;
}
const expiryTime = this.parseTokenExpiry(idToken);
// Note: We are placing the OIDC ID Token into the `access_token` field.
// This is because the CLI uses this field to construct the
// `Authorization: Bearer <token>` header, which is the correct way to
// present an ID token.
const newTokens: OAuthTokens = {
access_token: idToken,
token_type: 'Bearer',
};
if (expiryTime) {
this.tokenExpiryTime = expiryTime;
this.cachedToken = newTokens;
}
return newTokens;
}
saveTokens(_tokens: OAuthTokens): void {
// No-op
}
redirectToAuthorization(_authorizationUrl: URL): void {
// No-op
}
saveCodeVerifier(_codeVerifier: string): void {
// No-op
}
codeVerifier(): string {
// No-op
return '';
}
/**
* Parses a JWT string to extract its expiry time.
* @param idToken The JWT ID token.
* @returns The expiry time in **milliseconds**, or undefined if parsing fails.
*/
private parseTokenExpiry(idToken: string): number | undefined {
try {
const payload = JSON.parse(
Buffer.from(idToken.split('.')[1], 'base64').toString(),
);
if (payload && typeof payload.exp === 'number') {
return payload.exp * 1000; // Convert seconds to milliseconds
}
} catch (e) {
console.error('Failed to parse ID token for expiry time with error:', e);
}
// Return undefined if try block fails or 'exp' is missing/invalid
return undefined;
}
}

View File

@@ -24,6 +24,7 @@ import { parse } from 'shell-quote';
import type { Config, MCPServerConfig } from '../config/config.js';
import { AuthProviderType } from '../config/config.js';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import type { FunctionDeclaration } from '@google/genai';
@@ -440,6 +441,7 @@ async function createTransportWithOAuth(
* @param toolRegistry The central registry where discovered tools will be registered.
* @returns A promise that resolves when the discovery process has been attempted for all servers.
*/
export async function discoverMcpTools(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
@@ -1171,6 +1173,34 @@ export async function createTransport(
mcpServerConfig: MCPServerConfig,
debugMode: boolean,
): Promise<Transport> {
if (
mcpServerConfig.authProviderType ===
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
) {
const provider = new ServiceAccountImpersonationProvider(mcpServerConfig);
const transportOptions:
| StreamableHTTPClientTransportOptions
| SSEClientTransportOptions = {
authProvider: provider,
};
if (mcpServerConfig.httpUrl) {
return new StreamableHTTPClientTransport(
new URL(mcpServerConfig.httpUrl),
transportOptions,
);
} else if (mcpServerConfig.url) {
// Default to SSE if only url is provided
return new SSEClientTransport(
new URL(mcpServerConfig.url),
transportOptions,
);
}
throw new Error(
'No URL configured for ServiceAccountImpersonation MCP Server',
);
}
if (
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
) {