mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 05:12:55 -07:00
chore(core): add token caching in google auth provider (#11946)
This commit is contained in:
@@ -82,31 +82,76 @@ describe('GoogleCredentialProvider', () => {
|
|||||||
|
|
||||||
describe('with provider instance', () => {
|
describe('with provider instance', () => {
|
||||||
let provider: GoogleCredentialProvider;
|
let provider: GoogleCredentialProvider;
|
||||||
|
let mockGetAccessToken: Mock;
|
||||||
|
let mockClient: {
|
||||||
|
getAccessToken: Mock;
|
||||||
|
credentials?: { expiry_date: number | null };
|
||||||
|
};
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
|
// clear and reset mock client before each test
|
||||||
|
mockGetAccessToken = vi.fn();
|
||||||
|
mockClient = {
|
||||||
|
getAccessToken: mockGetAccessToken,
|
||||||
|
};
|
||||||
|
(GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient);
|
||||||
provider = new GoogleCredentialProvider(validConfig);
|
provider = new GoogleCredentialProvider(validConfig);
|
||||||
vi.clearAllMocks();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return credentials', async () => {
|
it('should return credentials', async () => {
|
||||||
const mockClient = {
|
mockGetAccessToken.mockResolvedValue({ token: 'test-token' });
|
||||||
getAccessToken: vi.fn().mockResolvedValue({ token: 'test-token' }),
|
|
||||||
};
|
|
||||||
(GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient);
|
|
||||||
|
|
||||||
const credentials = await provider.tokens();
|
const credentials = await provider.tokens();
|
||||||
|
|
||||||
expect(credentials?.access_token).toBe('test-token');
|
expect(credentials?.access_token).toBe('test-token');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return undefined if access token is not available', async () => {
|
it('should return undefined if access token is not available', async () => {
|
||||||
const mockClient = {
|
mockGetAccessToken.mockResolvedValue({ token: null });
|
||||||
getAccessToken: vi.fn().mockResolvedValue({ token: null }),
|
|
||||||
};
|
|
||||||
(GoogleAuth.prototype.getClient as Mock).mockResolvedValue(mockClient);
|
|
||||||
|
|
||||||
const credentials = await provider.tokens();
|
const credentials = await provider.tokens();
|
||||||
expect(credentials).toBeUndefined();
|
expect(credentials).toBeUndefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return a cached token if it is not expired', async () => {
|
||||||
|
vi.useFakeTimers();
|
||||||
|
mockClient.credentials = { expiry_date: Date.now() + 3600 * 1000 }; // 1 hour
|
||||||
|
mockGetAccessToken.mockResolvedValue({ token: 'test-token' });
|
||||||
|
|
||||||
|
// first call
|
||||||
|
const firstTokens = await provider.tokens();
|
||||||
|
expect(firstTokens?.access_token).toBe('test-token');
|
||||||
|
expect(mockGetAccessToken).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
// second call
|
||||||
|
vi.advanceTimersByTime(1800 * 1000); // Advance time by 30 minutes
|
||||||
|
const secondTokens = await provider.tokens();
|
||||||
|
expect(secondTokens).toBe(firstTokens);
|
||||||
|
expect(mockGetAccessToken).toHaveBeenCalledTimes(1); // Should not be called again
|
||||||
|
|
||||||
|
vi.useRealTimers();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should fetch a new token if the cached token is expired', async () => {
|
||||||
|
vi.useFakeTimers();
|
||||||
|
|
||||||
|
// first call
|
||||||
|
mockClient.credentials = { expiry_date: Date.now() + 1000 }; // Expires in 1 second
|
||||||
|
mockGetAccessToken.mockResolvedValue({ token: 'expired-token' });
|
||||||
|
|
||||||
|
const firstTokens = await provider.tokens();
|
||||||
|
expect(firstTokens?.access_token).toBe('expired-token');
|
||||||
|
expect(mockGetAccessToken).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
// second call
|
||||||
|
vi.advanceTimersByTime(1001); // Advance time past expiry
|
||||||
|
mockClient.credentials = { expiry_date: Date.now() + 3600 * 1000 }; // New expiry
|
||||||
|
mockGetAccessToken.mockResolvedValue({ token: 'new-token' });
|
||||||
|
|
||||||
|
const newTokens = await provider.tokens();
|
||||||
|
expect(newTokens?.access_token).toBe('new-token');
|
||||||
|
expect(mockGetAccessToken).toHaveBeenCalledTimes(2); // new fetch
|
||||||
|
|
||||||
|
vi.useRealTimers();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -13,11 +13,14 @@ import type {
|
|||||||
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||||
import { GoogleAuth } from 'google-auth-library';
|
import { GoogleAuth } from 'google-auth-library';
|
||||||
import type { MCPServerConfig } from '../config/config.js';
|
import type { MCPServerConfig } from '../config/config.js';
|
||||||
|
import { FIVE_MIN_BUFFER_MS } from './oauth-utils.js';
|
||||||
|
|
||||||
const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, /^(.*\.)?luci\.app$/];
|
const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, /^(.*\.)?luci\.app$/];
|
||||||
|
|
||||||
export class GoogleCredentialProvider implements OAuthClientProvider {
|
export class GoogleCredentialProvider implements OAuthClientProvider {
|
||||||
private readonly auth: GoogleAuth;
|
private readonly auth: GoogleAuth;
|
||||||
|
private cachedToken?: OAuthTokens;
|
||||||
|
private tokenExpiryTime?: number;
|
||||||
|
|
||||||
// Properties required by OAuthClientProvider, with no-op values
|
// Properties required by OAuthClientProvider, with no-op values
|
||||||
readonly redirectUrl = '';
|
readonly redirectUrl = '';
|
||||||
@@ -65,6 +68,19 @@ export class GoogleCredentialProvider implements OAuthClientProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async tokens(): Promise<OAuthTokens | undefined> {
|
async tokens(): Promise<OAuthTokens | undefined> {
|
||||||
|
// check for a valid, non-expired cached token.
|
||||||
|
if (
|
||||||
|
this.cachedToken &&
|
||||||
|
this.tokenExpiryTime &&
|
||||||
|
Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS
|
||||||
|
) {
|
||||||
|
return this.cachedToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear invalid/expired cache.
|
||||||
|
this.cachedToken = undefined;
|
||||||
|
this.tokenExpiryTime = undefined;
|
||||||
|
|
||||||
const client = await this.auth.getClient();
|
const client = await this.auth.getClient();
|
||||||
const accessTokenResponse = await client.getAccessToken();
|
const accessTokenResponse = await client.getAccessToken();
|
||||||
|
|
||||||
@@ -73,11 +89,18 @@ export class GoogleCredentialProvider implements OAuthClientProvider {
|
|||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
const tokens: OAuthTokens = {
|
const newToken: OAuthTokens = {
|
||||||
access_token: accessTokenResponse.token,
|
access_token: accessTokenResponse.token,
|
||||||
token_type: 'Bearer',
|
token_type: 'Bearer',
|
||||||
};
|
};
|
||||||
return tokens;
|
|
||||||
|
const expiryTime = client.credentials?.expiry_date;
|
||||||
|
if (expiryTime) {
|
||||||
|
this.tokenExpiryTime = expiryTime;
|
||||||
|
this.cachedToken = newToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
return newToken;
|
||||||
}
|
}
|
||||||
|
|
||||||
saveTokens(_tokens: OAuthTokens): void {
|
saveTokens(_tokens: OAuthTokens): void {
|
||||||
|
|||||||
@@ -325,4 +325,41 @@ describe('OAuthUtils', () => {
|
|||||||
expect(() => OAuthUtils.buildResourceParameter('not-a-url')).toThrow();
|
expect(() => OAuthUtils.buildResourceParameter('not-a-url')).toThrow();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('parseTokenExpiry', () => {
|
||||||
|
it('should return the expiry time in milliseconds for a valid token', () => {
|
||||||
|
// Corresponds to a date of 2100-01-01T00:00:00Z
|
||||||
|
const expiry = 4102444800;
|
||||||
|
const payload = { exp: expiry };
|
||||||
|
const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
|
||||||
|
const result = OAuthUtils.parseTokenExpiry(token);
|
||||||
|
expect(result).toBe(expiry * 1000);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return undefined for a token without an expiry time', () => {
|
||||||
|
const payload = { iat: 1678886400 };
|
||||||
|
const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
|
||||||
|
const result = OAuthUtils.parseTokenExpiry(token);
|
||||||
|
expect(result).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return undefined for a token with an invalid expiry time', () => {
|
||||||
|
const payload = { exp: 'not-a-number' };
|
||||||
|
const token = `header.${Buffer.from(JSON.stringify(payload)).toString('base64')}.signature`;
|
||||||
|
const result = OAuthUtils.parseTokenExpiry(token);
|
||||||
|
expect(result).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return undefined for a malformed token', () => {
|
||||||
|
const token = 'not-a-valid-token';
|
||||||
|
const result = OAuthUtils.parseTokenExpiry(token);
|
||||||
|
expect(result).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return undefined for a token with invalid JSON in payload', () => {
|
||||||
|
const token = `header.${Buffer.from('{ not valid json').toString('base64')}.signature`;
|
||||||
|
const result = OAuthUtils.parseTokenExpiry(token);
|
||||||
|
expect(result).toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -38,6 +38,8 @@ export interface OAuthProtectedResourceMetadata {
|
|||||||
resource_encryption_enc_values_supported?: string[];
|
resource_encryption_enc_values_supported?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const FIVE_MIN_BUFFER_MS = 5 * 60 * 1000;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Utility class for common OAuth operations.
|
* Utility class for common OAuth operations.
|
||||||
*/
|
*/
|
||||||
@@ -362,4 +364,26 @@ export class OAuthUtils {
|
|||||||
const url = new URL(endpointUrl);
|
const url = new URL(endpointUrl);
|
||||||
return `${url.protocol}//${url.host}${url.pathname}`;
|
return `${url.protocol}//${url.host}${url.pathname}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses a JWT string to extract its expiry time.
|
||||||
|
* @param idToken The JWT ID token.
|
||||||
|
* @returns The expiry time in **milliseconds**, or undefined if parsing fails.
|
||||||
|
*/
|
||||||
|
static parseTokenExpiry(idToken: string): number | undefined {
|
||||||
|
try {
|
||||||
|
const payload = JSON.parse(
|
||||||
|
Buffer.from(idToken.split('.')[1], 'base64').toString(),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (payload && typeof payload.exp === 'number') {
|
||||||
|
return payload.exp * 1000; // Convert seconds to milliseconds
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to parse ID token for expiry time with error:', e);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return undefined if try block fails or 'exp' is missing/invalid
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,11 +11,10 @@ import type {
|
|||||||
OAuthTokens,
|
OAuthTokens,
|
||||||
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||||
import { GoogleAuth } from 'google-auth-library';
|
import { GoogleAuth } from 'google-auth-library';
|
||||||
|
import { OAuthUtils, FIVE_MIN_BUFFER_MS } from './oauth-utils.js';
|
||||||
import type { MCPServerConfig } from '../config/config.js';
|
import type { MCPServerConfig } from '../config/config.js';
|
||||||
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||||
|
|
||||||
const fiveMinBufferMs = 5 * 60 * 1000;
|
|
||||||
|
|
||||||
function createIamApiUrl(targetSA: string): string {
|
function createIamApiUrl(targetSA: string): string {
|
||||||
return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`;
|
return `https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/${encodeURIComponent(targetSA)}:generateIdToken`;
|
||||||
}
|
}
|
||||||
@@ -78,7 +77,7 @@ export class ServiceAccountImpersonationProvider
|
|||||||
if (
|
if (
|
||||||
this.cachedToken &&
|
this.cachedToken &&
|
||||||
this.tokenExpiryTime &&
|
this.tokenExpiryTime &&
|
||||||
Date.now() < this.tokenExpiryTime - fiveMinBufferMs
|
Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS
|
||||||
) {
|
) {
|
||||||
return this.cachedToken;
|
return this.cachedToken;
|
||||||
}
|
}
|
||||||
@@ -112,7 +111,7 @@ export class ServiceAccountImpersonationProvider
|
|||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
const expiryTime = this.parseTokenExpiry(idToken);
|
const expiryTime = OAuthUtils.parseTokenExpiry(idToken);
|
||||||
// Note: We are placing the OIDC ID Token into the `access_token` field.
|
// Note: We are placing the OIDC ID Token into the `access_token` field.
|
||||||
// This is because the CLI uses this field to construct the
|
// This is because the CLI uses this field to construct the
|
||||||
// `Authorization: Bearer <token>` header, which is the correct way to
|
// `Authorization: Bearer <token>` header, which is the correct way to
|
||||||
@@ -146,26 +145,4 @@ export class ServiceAccountImpersonationProvider
|
|||||||
// No-op
|
// No-op
|
||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Parses a JWT string to extract its expiry time.
|
|
||||||
* @param idToken The JWT ID token.
|
|
||||||
* @returns The expiry time in **milliseconds**, or undefined if parsing fails.
|
|
||||||
*/
|
|
||||||
private parseTokenExpiry(idToken: string): number | undefined {
|
|
||||||
try {
|
|
||||||
const payload = JSON.parse(
|
|
||||||
Buffer.from(idToken.split('.')[1], 'base64').toString(),
|
|
||||||
);
|
|
||||||
|
|
||||||
if (payload && typeof payload.exp === 'number') {
|
|
||||||
return payload.exp * 1000; // Convert seconds to milliseconds
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.error('Failed to parse ID token for expiry time with error:', e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return undefined if try block fails or 'exp' is missing/invalid
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user