chore(core): add token caching in google auth provider (#11946)

This commit is contained in:
Adam Weidman
2025-10-24 16:26:42 +02:00
committed by GitHub
parent a889c15e38
commit c079084ca4
5 changed files with 144 additions and 38 deletions
@@ -82,31 +82,76 @@ describe('GoogleCredentialProvider', () => {
describe('with provider instance', () => { describe('with provider instance', () => {
let provider: GoogleCredentialProvider; let provider: GoogleCredentialProvider;
let mockGetAccessToken: Mock;
let mockClient: {
getAccessToken: Mock;
credentials?: { expiry_date: number | null };
};
beforeEach(() => { beforeEach(() => {
// clear and reset mock client before each test
mockGetAccessToken = vi.fn();
mockClient = {
getAccessToken: mockGetAccessToken,
};
(GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient);
provider = new GoogleCredentialProvider(validConfig); provider = new GoogleCredentialProvider(validConfig);
vi.clearAllMocks();
}); });
it('should return credentials', async () => { it('should return credentials', async () => {
const mockClient = { mockGetAccessToken.mockResolvedValue({ token: 'test-token' });
getAccessToken: vi.fn().mockResolvedValue({ token: 'test-token' }),
};
(GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient);
const credentials = await provider.tokens(); const credentials = await provider.tokens();
expect(credentials?.access_token).toBe('test-token'); expect(credentials?.access_token).toBe('test-token');
}); });
it('should return undefined if access token is not available', async () => { it('should return undefined if access token is not available', async () => {
const mockClient = { mockGetAccessToken.mockResolvedValue({ token: null });
getAccessToken: vi.fn().mockResolvedValue({ token: null }),
};
(GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient);
const credentials = await provider.tokens(); const credentials = await provider.tokens();
expect(credentials).toBeUndefined(); expect(credentials).toBeUndefined();
}); });
it('should return a cached token if it is not expired', async () => {
vi.useFakeTimers();
mockClient.credentials = { expiry_date: Date.now() + 3600 * 1000 }; // 1 hour
mockGetAccessToken.mockResolvedValue({ token: 'test-token' });
// first call
const firstTokens = await provider.tokens();
expect(firstTokens?.access_token).toBe('test-token');
expect(mockGetAccessToken).toHaveBeenCalledTimes(1);
// second call
vi.advanceTimersByTime(1800 * 1000); // Advance time by 30 minutes
const secondTokens = await provider.tokens();
expect(secondTokens).toBe(firstTokens);
expect(mockGetAccessToken).toHaveBeenCalledTimes(1); // Should not be called again
vi.useRealTimers();
});
it('should fetch a new token if the cached token is expired', async () => {
vi.useFakeTimers();
// first call
mockClient.credentials = { expiry_date: Date.now() + 1000 }; // Expires in 1 second
mockGetAccessToken.mockResolvedValue({ token: 'expired-token' });
const firstTokens = await provider.tokens();
expect(firstTokens?.access_token).toBe('expired-token');
expect(mockGetAccessToken).toHaveBeenCalledTimes(1);
// second call
vi.advanceTimersByTime(1001); // Advance time past expiry
mockClient.credentials = { expiry_date: Date.now() + 3600 * 1000 }; // New expiry
mockGetAccessToken.mockResolvedValue({ token: 'new-token' });
const newTokens = await provider.tokens();
expect(newTokens?.access_token).toBe('new-token');
expect(mockGetAccessToken).toHaveBeenCalledTimes(2); // new fetch
vi.useRealTimers();
});
}); });
}); });
+25 -2
View File
@@ -13,11 +13,14 @@ import type {
} from '@modelcontextprotocol/sdk/shared/auth.js'; } from '@modelcontextprotocol/sdk/shared/auth.js';
import { GoogleAuth } from 'google-auth-library'; import { GoogleAuth } from 'google-auth-library';
import type { MCPServerConfig } from '../config/config.js'; import type { MCPServerConfig } from '../config/config.js';
import { FIVE_MIN_BUFFER_MS } from './oauth-utils.js';
const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, /^(.*\.)?luci\.app$/]; const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, /^(.*\.)?luci\.app$/];
export class GoogleCredentialProvider implements OAuthClientProvider { export class GoogleCredentialProvider implements OAuthClientProvider {
private readonly auth: GoogleAuth; private readonly auth: GoogleAuth;
private cachedToken?: OAuthTokens;
private tokenExpiryTime?: number;
// Properties required by OAuthClientProvider, with no-op values // Properties required by OAuthClientProvider, with no-op values
readonly redirectUrl = ''; readonly redirectUrl = '';
@@ -65,6 +68,19 @@ export class GoogleCredentialProvider implements OAuthClientProvider {
} }
async tokens(): Promise<OAuthTokens | undefined> { async tokens(): Promise<OAuthTokens | undefined> {
// check for a valid, non-expired cached token.
if (
this.cachedToken &&
this.tokenExpiryTime &&
Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS
) {
return this.cachedToken;
}
// Clear invalid/expired cache.
this.cachedToken = undefined;
this.tokenExpiryTime = undefined;
const client = await this.auth.getClient(); const client = await this.auth.getClient();
const accessTokenResponse = await client.getAccessToken(); const accessTokenResponse = await client.getAccessToken();
@@ -73,11 +89,18 @@ export class GoogleCredentialProvider implements OAuthClientProvider {
return undefined; return undefined;
} }
const tokens: OAuthTokens = { const newToken: OAuthTokens = {
access_token: accessTokenResponse.token, access_token: accessTokenResponse.token,
token_type: 'Bearer', token_type: 'Bearer',
}; };
return tokens;
const expiryTime = client.credentials?.expiry_date;
if (expiryTime) {
this.tokenExpiryTime = expiryTime;
this.cachedToken = newToken;
}
return newToken;
} }
saveTokens(_tokens: OAuthTokens): void { saveTokens(_tokens: OAuthTokens): void {
+37
View File
@@ -325,4 +325,41 @@ describe('OAuthUtils', () => {
expect(() => OAuthUtils.buildResourceParameter('not-a-url')).toThrow(); expect(() => OAuthUtils.buildResourceParameter('not-a-url')).toThrow();
}); });
}); });
describe('parseTokenExpiry', () => {
it('should return the expiry time in milliseconds for a valid token', () => {
// Corresponds to a date of 2100-01-01T00:00:00Z
const expiry = 4102444800;
const payload = { exp: expiry };
const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
const result = OAuthUtils.parseTokenExpiry(token);
expect(result).toBe(expiry * 1000);
});
it('should return undefined for a token without an expiry time', () => {
const payload = { iat: 1678886400 };
const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
const result = OAuthUtils.parseTokenExpiry(token);
expect(result).toBeUndefined();
});
it('should return undefined for a token with an invalid expiry time', () => {
const payload = { exp: 'not-a-number' };
const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
const result = OAuthUtils.parseTokenExpiry(token);
expect(result).toBeUndefined();
});
it('should return undefined for a malformed token', () => {
const token = 'not-a-valid-token';
const result = OAuthUtils.parseTokenExpiry(token);
expect(result).toBeUndefined();
});
it('should return undefined for a token with invalid JSON in payload', () => {
const token = `header.${Buffer.from('{ not valid json').toString('base64')}.signature`;
const result = OAuthUtils.parseTokenExpiry(token);
expect(result).toBeUndefined();
});
});
}); });
+24
View File
@@ -38,6 +38,8 @@ export interface OAuthProtectedResourceMetadata {
resource_encryption_enc_values_supported?: string[]; resource_encryption_enc_values_supported?: string[];
} }
export const FIVE_MIN_BUFFER_MS = 5 * 60 * 1000;
/** /**
* Utility class for common OAuth operations. * Utility class for common OAuth operations.
*/ */
@@ -362,4 +364,26 @@ export class OAuthUtils {
const url = new URL(endpointUrl); const url = new URL(endpointUrl);
return `${url.protocol}//${url.host}${url.pathname}`; return `${url.protocol}//${url.host}${url.pathname}`;
} }
/**
* Parses a JWT string to extract its expiry time.
* @param idToken The JWT ID token.
* @returns The expiry time in **milliseconds**, or undefined if parsing fails.
*/
static parseTokenExpiry(idToken: string): number | undefined {
try {
const payload = JSON.parse(
Buffer.from(idToken.split('.')[1], 'base64').toString(),
);
if (payload && typeof payload.exp === 'number') {
return payload.exp * 1000; // Convert seconds to milliseconds
}
} catch (e) {
console.error('Failed to parse ID token for expiry time with error:', e);
}
// Return undefined if try block fails or 'exp' is missing/invalid
return undefined;
}
} }
@@ -11,11 +11,10 @@ import type {
OAuthTokens, OAuthTokens,
} from '@modelcontextprotocol/sdk/shared/auth.js'; } from '@modelcontextprotocol/sdk/shared/auth.js';
import { GoogleAuth } from 'google-auth-library'; import { GoogleAuth } from 'google-auth-library';
import { OAuthUtils, FIVE_MIN_BUFFER_MS } from './oauth-utils.js';
import type { MCPServerConfig } from '../config/config.js'; import type { MCPServerConfig } from '../config/config.js';
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'; import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
const fiveMinBufferMs = 5 * 60 * 1000;
function createIamApiUrl(targetSA: string): string { function createIamApiUrl(targetSA: string): string {
return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`; return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`;
} }
@@ -78,7 +77,7 @@ export class ServiceAccountImpersonationProvider
if ( if (
this.cachedToken && this.cachedToken &&
this.tokenExpiryTime && this.tokenExpiryTime &&
Date.now() < this.tokenExpiryTime - fiveMinBufferMs Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS
) { ) {
return this.cachedToken; return this.cachedToken;
} }
@@ -112,7 +111,7 @@ export class ServiceAccountImpersonationProvider
return undefined; return undefined;
} }
const expiryTime = this.parseTokenExpiry(idToken); const expiryTime = OAuthUtils.parseTokenExpiry(idToken);
// Note: We are placing the OIDC ID Token into the `access_token` field. // Note: We are placing the OIDC ID Token into the `access_token` field.
// This is because the CLI uses this field to construct the // This is because the CLI uses this field to construct the
// `Authorization: Bearer <token>` header, which is the correct way to // `Authorization: Bearer <token>` header, which is the correct way to
@@ -146,26 +145,4 @@ export class ServiceAccountImpersonationProvider
// No-op // No-op
return ''; return '';
} }
/**
* Parses a JWT string to extract its expiry time.
* @param idToken The JWT ID token.
* @returns The expiry time in **milliseconds**, or undefined if parsing fails.
*/
private parseTokenExpiry(idToken: string): number | undefined {
try {
const payload = JSON.parse(
Buffer.from(idToken.split('.')[1], 'base64').toString(),
);
if (payload && typeof payload.exp === 'number') {
return payload.exp * 1000; // Convert seconds to milliseconds
}
} catch (e) {
console.error('Failed to parse ID token for expiry time with error:', e);
}
// Return undefined if try block fails or 'exp' is missing/invalid
return undefined;
}
} }