mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
feat(iap support): Add service account impersonation provider to MCPServers to support IAP on Cloud Run (#8505)
Co-authored-by: Bryan Morgan <bryanmorgan@google.com>
This commit is contained in:
@@ -169,12 +169,18 @@ export class MCPServerConfig {
|
||||
// OAuth configuration
|
||||
readonly oauth?: MCPOAuthConfig,
|
||||
readonly authProviderType?: AuthProviderType,
|
||||
// Service Account Configuration
|
||||
/* targetAudience format: CLIENT_ID.apps.googleusercontent.com */
|
||||
readonly targetAudience?: string,
|
||||
/* targetServiceAccount format: <service-account-name>@<project-num>.iam.gserviceaccount.com */
|
||||
readonly targetServiceAccount?: string,
|
||||
) {}
|
||||
}
|
||||
|
||||
export enum AuthProviderType {
|
||||
DYNAMIC_DISCOVERY = 'dynamic_discovery',
|
||||
GOOGLE_CREDENTIALS = 'google_credentials',
|
||||
SERVICE_ACCOUNT_IMPERSONATION = 'service_account_impersonation',
|
||||
}
|
||||
|
||||
export interface SandboxConfig {
|
||||
|
||||
153
packages/core/src/mcp/sa-impersonation-provider.test.ts
Normal file
153
packages/core/src/mcp/sa-impersonation-provider.test.ts
Normal file
@@ -0,0 +1,153 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ServiceAccountImpersonationProvider } from './sa-impersonation-provider.js';
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
|
||||
const mockRequest = vi.fn();
|
||||
const mockGetClient = vi.fn(() => ({
|
||||
request: mockRequest,
|
||||
}));
|
||||
|
||||
// Mock the google-auth-library to use a shared mock function
|
||||
vi.mock('google-auth-library', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('google-auth-library')>();
|
||||
return {
|
||||
...actual,
|
||||
GoogleAuth: vi.fn().mockImplementation(() => ({
|
||||
getClient: mockGetClient,
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
const defaultSAConfig: MCPServerConfig = {
|
||||
url: 'https://my-iap-service.run.app',
|
||||
targetAudience: 'my-audience',
|
||||
targetServiceAccount: 'my-sa',
|
||||
};
|
||||
|
||||
describe('ServiceAccountImpersonationProvider', () => {
|
||||
beforeEach(() => {
|
||||
// Reset mocks before each test
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should throw an error if no URL is provided', () => {
|
||||
const config: MCPServerConfig = {};
|
||||
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
|
||||
'A url or httpUrl must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if no targetAudience is provided', () => {
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://my-iap-service.run.app',
|
||||
};
|
||||
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
|
||||
'targetAudience must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if no targetSA is provided', () => {
|
||||
const config: MCPServerConfig = {
|
||||
url: 'https://my-iap-service.run.app',
|
||||
targetAudience: 'my-audience',
|
||||
};
|
||||
expect(() => new ServiceAccountImpersonationProvider(config)).toThrow(
|
||||
'targetServiceAccount must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
});
|
||||
|
||||
it('should correctly get tokens for a valid config', async () => {
|
||||
const mockToken = 'mock-id-token-123';
|
||||
mockRequest.mockResolvedValue({ data: { token: mockToken } });
|
||||
|
||||
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
|
||||
const tokens = await provider.tokens();
|
||||
|
||||
expect(tokens).toBeDefined();
|
||||
expect(tokens?.access_token).toBe(mockToken);
|
||||
expect(tokens?.token_type).toBe('Bearer');
|
||||
});
|
||||
|
||||
it('should return undefined if token acquisition fails', async () => {
|
||||
mockRequest.mockResolvedValue({ data: { token: null } });
|
||||
|
||||
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
|
||||
const tokens = await provider.tokens();
|
||||
|
||||
expect(tokens).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should make a request with the correct parameters', async () => {
|
||||
mockRequest.mockResolvedValue({ data: { token: 'test-token' } });
|
||||
|
||||
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
|
||||
await provider.tokens();
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith({
|
||||
url: 'https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/my-sa:generateIdToken',
|
||||
method: 'POST',
|
||||
data: {
|
||||
audience: 'my-audience',
|
||||
includeEmail: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should return a cached token if it is not expired', async () => {
|
||||
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
|
||||
vi.useFakeTimers();
|
||||
|
||||
// jwt payload with exp set to 1 hour from now
|
||||
const payload = { exp: Math.floor(Date.now() / 1000) + 3600 };
|
||||
const jwt = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
|
||||
mockRequest.mockResolvedValue({ data: { token: jwt } });
|
||||
|
||||
const firstTokens = await provider.tokens();
|
||||
expect(firstTokens?.access_token).toBe(jwt);
|
||||
expect(mockRequest).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Advance time by 30 minutes
|
||||
vi.advanceTimersByTime(1800 * 1000);
|
||||
|
||||
// Seturn cached token
|
||||
const secondTokens = await provider.tokens();
|
||||
expect(secondTokens).toBe(firstTokens);
|
||||
expect(mockRequest).toHaveBeenCalledTimes(1);
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should fetch a new token if the cached token is expired (using fake timers)', async () => {
|
||||
const provider = new ServiceAccountImpersonationProvider(defaultSAConfig);
|
||||
vi.useFakeTimers();
|
||||
|
||||
// Get and cache a token that expires in 1 second
|
||||
const expiredPayload = { exp: Math.floor(Date.now() / 1000) + 1 };
|
||||
const expiredJwt = `header.${Buffer.from(JSON.stringify(expiredPayload)).toString('base64')}.signature`;
|
||||
|
||||
mockRequest.mockResolvedValue({ data: { token: expiredJwt } });
|
||||
const firstTokens = await provider.tokens();
|
||||
expect(firstTokens?.access_token).toBe(expiredJwt);
|
||||
expect(mockRequest).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Prepare the mock for the *next* call
|
||||
const newPayload = { exp: Math.floor(Date.now() / 1000) + 3600 };
|
||||
const newJwt = `header.${Buffer.from(JSON.stringify(newPayload)).toString('base64')}.signature`;
|
||||
mockRequest.mockResolvedValue({ data: { token: newJwt } });
|
||||
|
||||
vi.advanceTimersByTime(1001);
|
||||
|
||||
const newTokens = await provider.tokens();
|
||||
expect(newTokens?.access_token).toBe(newJwt);
|
||||
expect(newTokens?.access_token).not.toBe(expiredJwt);
|
||||
expect(mockRequest).toHaveBeenCalledTimes(2); // Confirms a new fetch
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
});
|
||||
171
packages/core/src/mcp/sa-impersonation-provider.ts
Normal file
171
packages/core/src/mcp/sa-impersonation-provider.ts
Normal file
@@ -0,0 +1,171 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
|
||||
const fiveMinBufferMs = 5 * 60 * 1000;
|
||||
|
||||
function createIamApiUrl(targetSA: string): string {
|
||||
return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`;
|
||||
}
|
||||
|
||||
export class ServiceAccountImpersonationProvider
|
||||
implements OAuthClientProvider
|
||||
{
|
||||
private readonly targetServiceAccount: string;
|
||||
private readonly targetAudience: string; // OAuth Client Id
|
||||
private readonly auth: GoogleAuth;
|
||||
private cachedToken?: OAuthTokens;
|
||||
private tokenExpiryTime?: number;
|
||||
|
||||
// Properties required by OAuthClientProvider, with no-op values
|
||||
readonly redirectUrl = '';
|
||||
readonly clientMetadata: OAuthClientMetadata = {
|
||||
client_name: 'Gemini CLI (Service Account Impersonation)',
|
||||
redirect_uris: [],
|
||||
grant_types: [],
|
||||
response_types: [],
|
||||
token_endpoint_auth_method: 'none',
|
||||
};
|
||||
private _clientInformation?: OAuthClientInformationFull;
|
||||
|
||||
constructor(private readonly config: MCPServerConfig) {
|
||||
// This check is done in mcp-client.ts. This is just an additional check.
|
||||
if (!this.config.httpUrl && !this.config.url) {
|
||||
throw new Error(
|
||||
'A url or httpUrl must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
}
|
||||
|
||||
if (!config.targetAudience) {
|
||||
throw new Error(
|
||||
'targetAudience must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
}
|
||||
this.targetAudience = config.targetAudience;
|
||||
|
||||
if (!config.targetServiceAccount) {
|
||||
throw new Error(
|
||||
'targetServiceAccount must be provided for the Service Account Impersonation provider',
|
||||
);
|
||||
}
|
||||
this.targetServiceAccount = config.targetServiceAccount;
|
||||
|
||||
this.auth = new GoogleAuth();
|
||||
}
|
||||
|
||||
clientInformation(): OAuthClientInformation | undefined {
|
||||
return this._clientInformation;
|
||||
}
|
||||
|
||||
saveClientInformation(clientInformation: OAuthClientInformationFull): void {
|
||||
this._clientInformation = clientInformation;
|
||||
}
|
||||
|
||||
async tokens(): Promise<OAuthTokens | undefined> {
|
||||
// 1. Check if we have a valid, non-expired cached token.
|
||||
if (
|
||||
this.cachedToken &&
|
||||
this.tokenExpiryTime &&
|
||||
Date.now() < this.tokenExpiryTime - fiveMinBufferMs
|
||||
) {
|
||||
return this.cachedToken;
|
||||
}
|
||||
|
||||
// 2. Clear any invalid/expired cache.
|
||||
this.cachedToken = undefined;
|
||||
this.tokenExpiryTime = undefined;
|
||||
|
||||
// 3. Fetch a new ID token.
|
||||
const client = await this.auth.getClient();
|
||||
const url = createIamApiUrl(this.targetServiceAccount);
|
||||
|
||||
let idToken: string;
|
||||
try {
|
||||
const res = await client.request<{ token: string }>({
|
||||
url,
|
||||
method: 'POST',
|
||||
data: {
|
||||
audience: this.targetAudience,
|
||||
includeEmail: true,
|
||||
},
|
||||
});
|
||||
idToken = res.data.token;
|
||||
|
||||
if (!idToken || idToken.length === 0) {
|
||||
console.error('Failed to get ID token from Google');
|
||||
return undefined;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to fetch ID token from Google:', e);
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const expiryTime = this.parseTokenExpiry(idToken);
|
||||
// Note: We are placing the OIDC ID Token into the `access_token` field.
|
||||
// This is because the CLI uses this field to construct the
|
||||
// `Authorization: Bearer <token>` header, which is the correct way to
|
||||
// present an ID token.
|
||||
const newTokens: OAuthTokens = {
|
||||
access_token: idToken,
|
||||
token_type: 'Bearer',
|
||||
};
|
||||
|
||||
if (expiryTime) {
|
||||
this.tokenExpiryTime = expiryTime;
|
||||
this.cachedToken = newTokens;
|
||||
}
|
||||
|
||||
return newTokens;
|
||||
}
|
||||
|
||||
saveTokens(_tokens: OAuthTokens): void {
|
||||
// No-op
|
||||
}
|
||||
|
||||
redirectToAuthorization(_authorizationUrl: URL): void {
|
||||
// No-op
|
||||
}
|
||||
|
||||
saveCodeVerifier(_codeVerifier: string): void {
|
||||
// No-op
|
||||
}
|
||||
|
||||
codeVerifier(): string {
|
||||
// No-op
|
||||
return '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a JWT string to extract its expiry time.
|
||||
* @param idToken The JWT ID token.
|
||||
* @returns The expiry time in **milliseconds**, or undefined if parsing fails.
|
||||
*/
|
||||
private parseTokenExpiry(idToken: string): number | undefined {
|
||||
try {
|
||||
const payload = JSON.parse(
|
||||
Buffer.from(idToken.split('.')[1], 'base64').toString(),
|
||||
);
|
||||
|
||||
if (payload && typeof payload.exp === 'number') {
|
||||
return payload.exp * 1000; // Convert seconds to milliseconds
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to parse ID token for expiry time with error:', e);
|
||||
}
|
||||
|
||||
// Return undefined if try block fails or 'exp' is missing/invalid
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,7 @@ import { parse } from 'shell-quote';
|
||||
import type { Config, MCPServerConfig } from '../config/config.js';
|
||||
import { AuthProviderType } from '../config/config.js';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
|
||||
import type { FunctionDeclaration } from '@google/genai';
|
||||
@@ -440,6 +441,7 @@ async function createTransportWithOAuth(
|
||||
* @param toolRegistry The central registry where discovered tools will be registered.
|
||||
* @returns A promise that resolves when the discovery process has been attempted for all servers.
|
||||
*/
|
||||
|
||||
export async function discoverMcpTools(
|
||||
mcpServers: Record<string, MCPServerConfig>,
|
||||
mcpServerCommand: string | undefined,
|
||||
@@ -1171,6 +1173,34 @@ export async function createTransport(
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
debugMode: boolean,
|
||||
): Promise<Transport> {
|
||||
if (
|
||||
mcpServerConfig.authProviderType ===
|
||||
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
|
||||
) {
|
||||
const provider = new ServiceAccountImpersonationProvider(mcpServerConfig);
|
||||
const transportOptions:
|
||||
| StreamableHTTPClientTransportOptions
|
||||
| SSEClientTransportOptions = {
|
||||
authProvider: provider,
|
||||
};
|
||||
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
);
|
||||
} else if (mcpServerConfig.url) {
|
||||
// Default to SSE if only url is provided
|
||||
return new SSEClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
throw new Error(
|
||||
'No URL configured for ServiceAccountImpersonation MCP Server',
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
|
||||
) {
|
||||
|
||||
Reference in New Issue
Block a user