mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 14:10:37 -07:00
chore(core): add token caching in google auth provider (#11946)
This commit is contained in:
@@ -82,31 +82,76 @@ describe('GoogleCredentialProvider', () => {
|
||||
|
||||
describe('with provider instance', () => {
|
||||
let provider: GoogleCredentialProvider;
|
||||
let mockGetAccessToken: Mock;
|
||||
let mockClient: {
|
||||
getAccessToken: Mock;
|
||||
credentials?: { expiry_date: number | null };
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
// clear and reset mock client before each test
|
||||
mockGetAccessToken = vi.fn();
|
||||
mockClient = {
|
||||
getAccessToken: mockGetAccessToken,
|
||||
};
|
||||
(GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient);
|
||||
provider = new GoogleCredentialProvider(validConfig);
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should return credentials', async () => {
|
||||
const mockClient = {
|
||||
getAccessToken: vi.fn().mockResolvedValue({ token: 'test-token' }),
|
||||
};
|
||||
(GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient);
|
||||
mockGetAccessToken.mockResolvedValue({ token: 'test-token' });
|
||||
|
||||
const credentials = await provider.tokens();
|
||||
|
||||
expect(credentials?.access_token).toBe('test-token');
|
||||
});
|
||||
|
||||
it('should return undefined if access token is not available', async () => {
|
||||
const mockClient = {
|
||||
getAccessToken: vi.fn().mockResolvedValue({ token: null }),
|
||||
};
|
||||
(GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient);
|
||||
mockGetAccessToken.mockResolvedValue({ token: null });
|
||||
|
||||
const credentials = await provider.tokens();
|
||||
expect(credentials).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return a cached token if it is not expired', async () => {
|
||||
vi.useFakeTimers();
|
||||
mockClient.credentials = { expiry_date: Date.now() + 3600 * 1000 }; // 1 hour
|
||||
mockGetAccessToken.mockResolvedValue({ token: 'test-token' });
|
||||
|
||||
// first call
|
||||
const firstTokens = await provider.tokens();
|
||||
expect(firstTokens?.access_token).toBe('test-token');
|
||||
expect(mockGetAccessToken).toHaveBeenCalledTimes(1);
|
||||
|
||||
// second call
|
||||
vi.advanceTimersByTime(1800 * 1000); // Advance time by 30 minutes
|
||||
const secondTokens = await provider.tokens();
|
||||
expect(secondTokens).toBe(firstTokens);
|
||||
expect(mockGetAccessToken).toHaveBeenCalledTimes(1); // Should not be called again
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should fetch a new token if the cached token is expired', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
// first call
|
||||
mockClient.credentials = { expiry_date: Date.now() + 1000 }; // Expires in 1 second
|
||||
mockGetAccessToken.mockResolvedValue({ token: 'expired-token' });
|
||||
|
||||
const firstTokens = await provider.tokens();
|
||||
expect(firstTokens?.access_token).toBe('expired-token');
|
||||
expect(mockGetAccessToken).toHaveBeenCalledTimes(1);
|
||||
|
||||
// second call
|
||||
vi.advanceTimersByTime(1001); // Advance time past expiry
|
||||
mockClient.credentials = { expiry_date: Date.now() + 3600 * 1000 }; // New expiry
|
||||
mockGetAccessToken.mockResolvedValue({ token: 'new-token' });
|
||||
|
||||
const newTokens = await provider.tokens();
|
||||
expect(newTokens?.access_token).toBe('new-token');
|
||||
expect(mockGetAccessToken).toHaveBeenCalledTimes(2); // new fetch
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -13,11 +13,14 @@ import type {
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
import { FIVE_MIN_BUFFER_MS } from './oauth-utils.js';
|
||||
|
||||
const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, /^(.*\.)?luci\.app$/];
|
||||
|
||||
export class GoogleCredentialProvider implements OAuthClientProvider {
|
||||
private readonly auth: GoogleAuth;
|
||||
private cachedToken?: OAuthTokens;
|
||||
private tokenExpiryTime?: number;
|
||||
|
||||
// Properties required by OAuthClientProvider, with no-op values
|
||||
readonly redirectUrl = '';
|
||||
@@ -65,6 +68,19 @@ export class GoogleCredentialProvider implements OAuthClientProvider {
|
||||
}
|
||||
|
||||
async tokens(): Promise<OAuthTokens | undefined> {
|
||||
// check for a valid, non-expired cached token.
|
||||
if (
|
||||
this.cachedToken &&
|
||||
this.tokenExpiryTime &&
|
||||
Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS
|
||||
) {
|
||||
return this.cachedToken;
|
||||
}
|
||||
|
||||
// Clear invalid/expired cache.
|
||||
this.cachedToken = undefined;
|
||||
this.tokenExpiryTime = undefined;
|
||||
|
||||
const client = await this.auth.getClient();
|
||||
const accessTokenResponse = await client.getAccessToken();
|
||||
|
||||
@@ -73,11 +89,18 @@ export class GoogleCredentialProvider implements OAuthClientProvider {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const tokens: OAuthTokens = {
|
||||
const newToken: OAuthTokens = {
|
||||
access_token: accessTokenResponse.token,
|
||||
token_type: 'Bearer',
|
||||
};
|
||||
return tokens;
|
||||
|
||||
const expiryTime = client.credentials?.expiry_date;
|
||||
if (expiryTime) {
|
||||
this.tokenExpiryTime = expiryTime;
|
||||
this.cachedToken = newToken;
|
||||
}
|
||||
|
||||
return newToken;
|
||||
}
|
||||
|
||||
saveTokens(_tokens: OAuthTokens): void {
|
||||
|
||||
@@ -325,4 +325,41 @@ describe('OAuthUtils', () => {
|
||||
expect(() => OAuthUtils.buildResourceParameter('not-a-url')).toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseTokenExpiry', () => {
|
||||
it('should return the expiry time in milliseconds for a valid token', () => {
|
||||
// Corresponds to a date of 2100-01-01T00:00:00Z
|
||||
const expiry = 4102444800;
|
||||
const payload = { exp: expiry };
|
||||
const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
|
||||
const result = OAuthUtils.parseTokenExpiry(token);
|
||||
expect(result).toBe(expiry * 1000);
|
||||
});
|
||||
|
||||
it('should return undefined for a token without an expiry time', () => {
|
||||
const payload = { iat: 1678886400 };
|
||||
const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
|
||||
const result = OAuthUtils.parseTokenExpiry(token);
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined for a token with an invalid expiry time', () => {
|
||||
const payload = { exp: 'not-a-number' };
|
||||
const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
|
||||
const result = OAuthUtils.parseTokenExpiry(token);
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined for a malformed token', () => {
|
||||
const token = 'not-a-valid-token';
|
||||
const result = OAuthUtils.parseTokenExpiry(token);
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined for a token with invalid JSON in payload', () => {
|
||||
const token = `header.${Buffer.from('{ not valid json').toString('base64')}.signature`;
|
||||
const result = OAuthUtils.parseTokenExpiry(token);
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -38,6 +38,8 @@ export interface OAuthProtectedResourceMetadata {
|
||||
resource_encryption_enc_values_supported?: string[];
|
||||
}
|
||||
|
||||
export const FIVE_MIN_BUFFER_MS = 5 * 60 * 1000;
|
||||
|
||||
/**
|
||||
* Utility class for common OAuth operations.
|
||||
*/
|
||||
@@ -362,4 +364,26 @@ export class OAuthUtils {
|
||||
const url = new URL(endpointUrl);
|
||||
return `${url.protocol}//${url.host}${url.pathname}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a JWT string to extract its expiry time.
|
||||
* @param idToken The JWT ID token.
|
||||
* @returns The expiry time in **milliseconds**, or undefined if parsing fails.
|
||||
*/
|
||||
static parseTokenExpiry(idToken: string): number | undefined {
|
||||
try {
|
||||
const payload = JSON.parse(
|
||||
Buffer.from(idToken.split('.')[1], 'base64').toString(),
|
||||
);
|
||||
|
||||
if (payload && typeof payload.exp === 'number') {
|
||||
return payload.exp * 1000; // Convert seconds to milliseconds
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to parse ID token for expiry time with error:', e);
|
||||
}
|
||||
|
||||
// Return undefined if try block fails or 'exp' is missing/invalid
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,11 +11,10 @@ import type {
|
||||
OAuthTokens,
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
import { OAuthUtils, FIVE_MIN_BUFFER_MS } from './oauth-utils.js';
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
|
||||
const fiveMinBufferMs = 5 * 60 * 1000;
|
||||
|
||||
function createIamApiUrl(targetSA: string): string {
|
||||
return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`;
|
||||
}
|
||||
@@ -78,7 +77,7 @@ export class ServiceAccountImpersonationProvider
|
||||
if (
|
||||
this.cachedToken &&
|
||||
this.tokenExpiryTime &&
|
||||
Date.now() < this.tokenExpiryTime - fiveMinBufferMs
|
||||
Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS
|
||||
) {
|
||||
return this.cachedToken;
|
||||
}
|
||||
@@ -112,7 +111,7 @@ export class ServiceAccountImpersonationProvider
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const expiryTime = this.parseTokenExpiry(idToken);
|
||||
const expiryTime = OAuthUtils.parseTokenExpiry(idToken);
|
||||
// Note: We are placing the OIDC ID Token into the `access_token` field.
|
||||
// This is because the CLI uses this field to construct the
|
||||
// `Authorization: Bearer <token>` header, which is the correct way to
|
||||
@@ -146,26 +145,4 @@ export class ServiceAccountImpersonationProvider
|
||||
// No-op
|
||||
return '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a JWT string to extract its expiry time.
|
||||
* @param idToken The JWT ID token.
|
||||
* @returns The expiry time in **milliseconds**, or undefined if parsing fails.
|
||||
*/
|
||||
private parseTokenExpiry(idToken: string): number | undefined {
|
||||
try {
|
||||
const payload = JSON.parse(
|
||||
Buffer.from(idToken.split('.')[1], 'base64').toString(),
|
||||
);
|
||||
|
||||
if (payload && typeof payload.exp === 'number') {
|
||||
return payload.exp * 1000; // Convert seconds to milliseconds
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Failed to parse ID token for expiry time with error:', e);
|
||||
}
|
||||
|
||||
// Return undefined if try block fails or 'exp' is missing/invalid
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user