diff --git a/packages/core/src/mcp/google-auth-provider.test.ts b/packages/core/src/mcp/google-auth-provider.test.ts index b568fa2ca7..efe959ff3c 100644 --- a/packages/core/src/mcp/google-auth-provider.test.ts +++ b/packages/core/src/mcp/google-auth-provider.test.ts @@ -82,31 +82,76 @@ describe('GoogleCredentialProvider', () => { describe('with provider instance', () => { let provider: GoogleCredentialProvider; + let mockGetAccessToken: Mock; + let mockClient: { + getAccessToken: Mock; + credentials?: { expiry_date: number | null }; + }; beforeEach(() => { + // clear and reset mock client before each test + mockGetAccessToken = vi.fn(); + mockClient = { + getAccessToken: mockGetAccessToken, + }; + (GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient); provider = new GoogleCredentialProvider(validConfig); - vi.clearAllMocks(); }); it('should return credentials', async () => { - const mockClient = { - getAccessToken: vi.fn().mockResolvedValue({ token: 'test-token' }), - }; - (GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient); + mockGetAccessToken.mockResolvedValue({ token: 'test-token' }); const credentials = await provider.tokens(); - expect(credentials?.access_token).toBe('test-token'); }); it('should return undefined if access token is not available', async () => { - const mockClient = { - getAccessToken: vi.fn().mockResolvedValue({ token: null }), - }; - (GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient); + mockGetAccessToken.mockResolvedValue({ token: null }); const credentials = await provider.tokens(); expect(credentials).toBeUndefined(); }); + + it('should return a cached token if it is not expired', async () => { + vi.useFakeTimers(); + mockClient.credentials = { expiry_date: Date.now() + 3600 * 1000 }; // 1 hour + mockGetAccessToken.mockResolvedValue({ token: 'test-token' }); + + // first call + const firstTokens = await provider.tokens(); + expect(firstTokens?.access_token).toBe('test-token'); + expect(mockGetAccessToken).toHaveBeenCalledTimes(1); + + // second call + vi.advanceTimersByTime(1800 * 1000); // Advance time by 30 minutes + const secondTokens = await provider.tokens(); + expect(secondTokens).toBe(firstTokens); + expect(mockGetAccessToken).toHaveBeenCalledTimes(1); // Should not be called again + + vi.useRealTimers(); + }); + + it('should fetch a new token if the cached token is expired', async () => { + vi.useFakeTimers(); + + // first call + mockClient.credentials = { expiry_date: Date.now() + 1000 }; // Expires in 1 second + mockGetAccessToken.mockResolvedValue({ token: 'expired-token' }); + + const firstTokens = await provider.tokens(); + expect(firstTokens?.access_token).toBe('expired-token'); + expect(mockGetAccessToken).toHaveBeenCalledTimes(1); + + // second call + vi.advanceTimersByTime(1001); // Advance time past expiry + mockClient.credentials = { expiry_date: Date.now() + 3600 * 1000 }; // New expiry + mockGetAccessToken.mockResolvedValue({ token: 'new-token' }); + + const newTokens = await provider.tokens(); + expect(newTokens?.access_token).toBe('new-token'); + expect(mockGetAccessToken).toHaveBeenCalledTimes(2); // new fetch + + vi.useRealTimers(); + }); }); }); diff --git a/packages/core/src/mcp/google-auth-provider.ts b/packages/core/src/mcp/google-auth-provider.ts index d761156225..d152b4d256 100644 --- a/packages/core/src/mcp/google-auth-provider.ts +++ b/packages/core/src/mcp/google-auth-provider.ts @@ -13,11 +13,14 @@ import type { } from '@modelcontextprotocol/sdk/shared/auth.js'; import { GoogleAuth } from 'google-auth-library'; import type { MCPServerConfig } from '../config/config.js'; +import { FIVE_MIN_BUFFER_MS } from './oauth-utils.js'; const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, /^(.*\.)?luci\.app$/]; export class GoogleCredentialProvider implements OAuthClientProvider { private readonly auth: GoogleAuth; + private cachedToken?: OAuthTokens; + private tokenExpiryTime?: number; // Properties required by OAuthClientProvider, with no-op values readonly redirectUrl = ''; @@ -65,6 +68,19 @@ export class GoogleCredentialProvider implements OAuthClientProvider { } async tokens(): Promise { + // check for a valid, non-expired cached token. + if ( + this.cachedToken && + this.tokenExpiryTime && + Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS + ) { + return this.cachedToken; + } + + // Clear invalid/expired cache. + this.cachedToken = undefined; + this.tokenExpiryTime = undefined; + const client = await this.auth.getClient(); const accessTokenResponse = await client.getAccessToken(); @@ -73,11 +89,18 @@ export class GoogleCredentialProvider implements OAuthClientProvider { return undefined; } - const tokens: OAuthTokens = { + const newToken: OAuthTokens = { access_token: accessTokenResponse.token, token_type: 'Bearer', }; - return tokens; + + const expiryTime = client.credentials?.expiry_date; + if (expiryTime) { + this.tokenExpiryTime = expiryTime; + this.cachedToken = newToken; + } + + return newToken; } saveTokens(_tokens: OAuthTokens): void { diff --git a/packages/core/src/mcp/oauth-utils.test.ts b/packages/core/src/mcp/oauth-utils.test.ts index 93aa507e21..bec8ef9f4b 100644 --- a/packages/core/src/mcp/oauth-utils.test.ts +++ b/packages/core/src/mcp/oauth-utils.test.ts @@ -325,4 +325,41 @@ describe('OAuthUtils', () => { expect(() => OAuthUtils.buildResourceParameter('not-a-url')).toThrow(); }); }); + + describe('parseTokenExpiry', () => { + it('should return the expiry time in milliseconds for a valid token', () => { + // Corresponds to a date of 2100-01-01T00:00:00Z + const expiry = 4102444800; + const payload = { exp: expiry }; + const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`; + const result = OAuthUtils.parseTokenExpiry(token); + expect(result).toBe(expiry * 1000); + }); + + it('should return undefined for a token without an expiry time', () => { + const payload = { iat: 1678886400 }; + const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`; + const result = OAuthUtils.parseTokenExpiry(token); + expect(result).toBeUndefined(); + }); + + it('should return undefined for a token with an invalid expiry time', () => { + const payload = { exp: 'not-a-number' }; + const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`; + const result = OAuthUtils.parseTokenExpiry(token); + expect(result).toBeUndefined(); + }); + + it('should return undefined for a malformed token', () => { + const token = 'not-a-valid-token'; + const result = OAuthUtils.parseTokenExpiry(token); + expect(result).toBeUndefined(); + }); + + it('should return undefined for a token with invalid JSON in payload', () => { + const token = `header.${Buffer.from('{ not valid json').toString('base64')}.signature`; + const result = OAuthUtils.parseTokenExpiry(token); + expect(result).toBeUndefined(); + }); + }); }); diff --git a/packages/core/src/mcp/oauth-utils.ts b/packages/core/src/mcp/oauth-utils.ts index cf6bfc289d..0f4bd0b24a 100644 --- a/packages/core/src/mcp/oauth-utils.ts +++ b/packages/core/src/mcp/oauth-utils.ts @@ -38,6 +38,8 @@ export interface OAuthProtectedResourceMetadata { resource_encryption_enc_values_supported?: string[]; } +export const FIVE_MIN_BUFFER_MS = 5 * 60 * 1000; + /** * Utility class for common OAuth operations. */ @@ -362,4 +364,26 @@ export class OAuthUtils { const url = new URL(endpointUrl); return `${url.protocol}//${url.host}${url.pathname}`; } + + /** + * Parses a JWT string to extract its expiry time. + * @param idToken The JWT ID token. + * @returns The expiry time in **milliseconds**, or undefined if parsing fails. + */ + static parseTokenExpiry(idToken: string): number | undefined { + try { + const payload = JSON.parse( + Buffer.from(idToken.split('.')[1], 'base64').toString(), + ); + + if (payload && typeof payload.exp === 'number') { + return payload.exp * 1000; // Convert seconds to milliseconds + } + } catch (e) { + console.error('Failed to parse ID token for expiry time with error:', e); + } + + // Return undefined if try block fails or 'exp' is missing/invalid + return undefined; + } } diff --git a/packages/core/src/mcp/sa-impersonation-provider.ts b/packages/core/src/mcp/sa-impersonation-provider.ts index e3336693d2..b9335e2622 100644 --- a/packages/core/src/mcp/sa-impersonation-provider.ts +++ b/packages/core/src/mcp/sa-impersonation-provider.ts @@ -11,11 +11,10 @@ import type { OAuthTokens, } from '@modelcontextprotocol/sdk/shared/auth.js'; import { GoogleAuth } from 'google-auth-library'; +import { OAuthUtils, FIVE_MIN_BUFFER_MS } from './oauth-utils.js'; import type { MCPServerConfig } from '../config/config.js'; import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'; -const fiveMinBufferMs = 5 * 60 * 1000; - function createIamApiUrl(targetSA: string): string { return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`; } @@ -78,7 +77,7 @@ export class ServiceAccountImpersonationProvider if ( this.cachedToken && this.tokenExpiryTime && - Date.now() < this.tokenExpiryTime - fiveMinBufferMs + Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS ) { return this.cachedToken; } @@ -112,7 +111,7 @@ export class ServiceAccountImpersonationProvider return undefined; } - const expiryTime = this.parseTokenExpiry(idToken); + const expiryTime = OAuthUtils.parseTokenExpiry(idToken); // Note: We are placing the OIDC ID Token into the `access_token` field. // This is because the CLI uses this field to construct the // `Authorization: Bearer ` header, which is the correct way to @@ -146,26 +145,4 @@ export class ServiceAccountImpersonationProvider // No-op return ''; } - - /** - * Parses a JWT string to extract its expiry time. - * @param idToken The JWT ID token. - * @returns The expiry time in **milliseconds**, or undefined if parsing fails. - */ - private parseTokenExpiry(idToken: string): number | undefined { - try { - const payload = JSON.parse( - Buffer.from(idToken.split('.')[1], 'base64').toString(), - ); - - if (payload && typeof payload.exp === 'number') { - return payload.exp * 1000; // Convert seconds to milliseconds - } - } catch (e) { - console.error('Failed to parse ID token for expiry time with error:', e); - } - - // Return undefined if try block fails or 'exp' is missing/invalid - return undefined; - } }