diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index 212e6e2a77..bcd27f4d8d 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -22,17 +22,18 @@ import { Type } from '@google/genai'; vi.mock('@google/gemini-cli-core', async (importOriginal) => { const actual = await importOriginal(); + const mockAuthenticate = vi.fn(); return { ...actual, getMCPServerStatus: vi.fn(), getMCPDiscoveryState: vi.fn(), - MCPOAuthProvider: { - authenticate: vi.fn(), - }, - MCPOAuthTokenStorage: { + MCPOAuthProvider: vi.fn(() => ({ + authenticate: mockAuthenticate, + })), + MCPOAuthTokenStorage: vi.fn(() => ({ getToken: vi.fn(), isTokenExpired: vi.fn(), - }, + })), }; }); @@ -892,13 +893,14 @@ describe('mcpCommand', () => { context.ui.reloadCommands = vi.fn(); const { MCPOAuthProvider } = await import('@google/gemini-cli-core'); + const mockAuthProvider = new MCPOAuthProvider(); const authCommand = mcpCommand.subCommands?.find( (cmd) => cmd.name === 'auth', ); const result = await authCommand!.action!(context, 'test-server'); - expect(MCPOAuthProvider.authenticate).toHaveBeenCalledWith( + expect(mockAuthProvider.authenticate).toHaveBeenCalledWith( 'test-server', { enabled: true }, 'http://localhost:3000', @@ -928,9 +930,10 @@ describe('mcpCommand', () => { }); const { MCPOAuthProvider } = await import('@google/gemini-cli-core'); - ( - MCPOAuthProvider.authenticate as ReturnType - ).mockRejectedValue(new Error('Auth failed')); + const mockAuthProvider = new MCPOAuthProvider(); + vi.mocked(mockAuthProvider.authenticate).mockRejectedValue( + new Error('Auth failed'), + ); const authCommand = mcpCommand.subCommands?.find( (cmd) => cmd.name === 'auth', diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index 50e7ffcf87..7e2358be85 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -20,6 +20,7 @@ import { MCPServerStatus, mcpServerRequiresOAuth, getErrorMessage, + MCPOAuthTokenStorage, } from '@google/gemini-cli-core'; const COLOR_GREEN = '\u001b[32m'; @@ -141,9 +142,10 @@ const getMcpStatus = async ( const { MCPOAuthTokenStorage } = await import( '@google/gemini-cli-core' ); - const hasToken = await MCPOAuthTokenStorage.getToken(serverName); + const tokenStorage = new MCPOAuthTokenStorage(); + const hasToken = await tokenStorage.getToken(serverName); if (hasToken) { - const isExpired = MCPOAuthTokenStorage.isTokenExpired(hasToken.token); + const isExpired = tokenStorage.isTokenExpired(hasToken.token); if (isExpired) { message += ` ${COLOR_YELLOW}(OAuth token expired)${RESET_COLOR}`; } else { @@ -385,11 +387,8 @@ const authCommand: SlashCommand = { // Pass the MCP server URL for OAuth discovery const mcpServerUrl = server.httpUrl || server.url; - await MCPOAuthProvider.authenticate( - serverName, - oauthConfig, - mcpServerUrl, - ); + const authProvider = new MCPOAuthProvider(new MCPOAuthTokenStorage()); + await authProvider.authenticate(serverName, oauthConfig, mcpServerUrl); context.ui.addItem( { diff --git a/packages/core/src/mcp/oauth-provider.test.ts b/packages/core/src/mcp/oauth-provider.test.ts index c13d290c80..48af4a4a2f 100644 --- a/packages/core/src/mcp/oauth-provider.test.ts +++ b/packages/core/src/mcp/oauth-provider.test.ts @@ -12,7 +12,21 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({ openBrowserSecurely: mockOpenBrowserSecurely, })); vi.mock('node:crypto'); -vi.mock('./oauth-token-storage.js'); +vi.mock('./oauth-token-storage.js', () => { + const mockSaveToken = vi.fn(); + const mockGetToken = vi.fn(); + const mockIsTokenExpired = vi.fn(); + const mockRemoveToken = vi.fn(); + + return { + MCPOAuthTokenStorage: vi.fn(() => ({ + saveToken: mockSaveToken, + getToken: mockGetToken, + isTokenExpired: mockIsTokenExpired, + removeToken: mockRemoveToken, + })), + }; +}); import { describe, it, expect, beforeEach, afterEach } from 'vitest'; import * as http from 'node:http'; @@ -147,8 +161,9 @@ describe('MCPOAuthProvider', () => { }); // Mock token storage - vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined); - vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null); + const tokenStorage = new MCPOAuthTokenStorage(); + vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined); + vi.mocked(tokenStorage.getToken).mockResolvedValue(null); }); afterEach(() => { @@ -192,10 +207,8 @@ describe('MCPOAuthProvider', () => { }), ); - const result = await MCPOAuthProvider.authenticate( - 'test-server', - mockConfig, - ); + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.authenticate('test-server', mockConfig); expect(result).toEqual({ accessToken: 'access_token_123', @@ -208,7 +221,8 @@ describe('MCPOAuthProvider', () => { expect(mockOpenBrowserSecurely).toHaveBeenCalledWith( expect.stringContaining('authorize'), ); - expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith( + const tokenStorage = new MCPOAuthTokenStorage(); + expect(tokenStorage.saveToken).toHaveBeenCalledWith( 'test-server', expect.objectContaining({ accessToken: 'access_token_123' }), 'test-client-id', @@ -296,7 +310,8 @@ describe('MCPOAuthProvider', () => { }), ); - const result = await MCPOAuthProvider.authenticate( + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.authenticate( 'test-server', configWithoutAuth, 'https://api.example.com', @@ -385,7 +400,8 @@ describe('MCPOAuthProvider', () => { }), ); - const result = await MCPOAuthProvider.authenticate( + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.authenticate( 'test-server', configWithoutClient, ); @@ -424,8 +440,9 @@ describe('MCPOAuthProvider', () => { }, 10); }); + const authProvider = new MCPOAuthProvider(); await expect( - MCPOAuthProvider.authenticate('test-server', mockConfig), + authProvider.authenticate('test-server', mockConfig), ).rejects.toThrow('OAuth error: access_denied'); }); @@ -453,8 +470,9 @@ describe('MCPOAuthProvider', () => { }, 10); }); + const authProvider = new MCPOAuthProvider(); await expect( - MCPOAuthProvider.authenticate('test-server', mockConfig), + authProvider.authenticate('test-server', mockConfig), ).rejects.toThrow('State mismatch - possible CSRF attack'); }); @@ -491,8 +509,9 @@ describe('MCPOAuthProvider', () => { }), ); + const authProvider = new MCPOAuthProvider(); await expect( - MCPOAuthProvider.authenticate('test-server', mockConfig), + authProvider.authenticate('test-server', mockConfig), ).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant'); }); @@ -516,8 +535,9 @@ describe('MCPOAuthProvider', () => { return originalSetTimeout(callback, 0); }) as unknown as typeof setTimeout; + const authProvider = new MCPOAuthProvider(); await expect( - MCPOAuthProvider.authenticate('test-server', mockConfig), + authProvider.authenticate('test-server', mockConfig), ).rejects.toThrow('OAuth callback timeout'); global.setTimeout = originalSetTimeout; @@ -542,7 +562,8 @@ describe('MCPOAuthProvider', () => { }), ); - const result = await MCPOAuthProvider.refreshAccessToken( + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.refreshAccessToken( mockConfig, 'old_refresh_token', 'https://auth.example.com/token', @@ -572,7 +593,8 @@ describe('MCPOAuthProvider', () => { }), ); - await MCPOAuthProvider.refreshAccessToken( + const authProvider = new MCPOAuthProvider(); + await authProvider.refreshAccessToken( mockConfig, 'refresh_token', 'https://auth.example.com/token', @@ -592,8 +614,9 @@ describe('MCPOAuthProvider', () => { }), ); + const authProvider = new MCPOAuthProvider(); await expect( - MCPOAuthProvider.refreshAccessToken( + authProvider.refreshAccessToken( mockConfig, 'invalid_refresh_token', 'https://auth.example.com/token', @@ -614,12 +637,12 @@ describe('MCPOAuthProvider', () => { updatedAt: Date.now(), }; - vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( - validCredentials, - ); - vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false); + const tokenStorage = new MCPOAuthTokenStorage(); + vi.mocked(tokenStorage.getToken).mockResolvedValue(validCredentials); + vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false); - const result = await MCPOAuthProvider.getValidToken( + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.getValidToken( 'test-server', mockConfig, ); @@ -636,10 +659,9 @@ describe('MCPOAuthProvider', () => { updatedAt: Date.now(), }; - vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( - expiredCredentials, - ); - vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true); + const tokenStorage = new MCPOAuthTokenStorage(); + vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials); + vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true); const refreshResponse = { access_token: 'new_access_token', @@ -657,13 +679,14 @@ describe('MCPOAuthProvider', () => { }), ); - const result = await MCPOAuthProvider.getValidToken( + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.getValidToken( 'test-server', mockConfig, ); expect(result).toBe('new_access_token'); - expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith( + expect(tokenStorage.saveToken).toHaveBeenCalledWith( 'test-server', expect.objectContaining({ accessToken: 'new_access_token' }), 'test-client-id', @@ -673,9 +696,11 @@ describe('MCPOAuthProvider', () => { }); it('should return null when no credentials exist', async () => { - vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null); + const tokenStorage = new MCPOAuthTokenStorage(); + vi.mocked(tokenStorage.getToken).mockResolvedValue(null); - const result = await MCPOAuthProvider.getValidToken( + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.getValidToken( 'test-server', mockConfig, ); @@ -692,11 +717,10 @@ describe('MCPOAuthProvider', () => { updatedAt: Date.now(), }; - vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( - expiredCredentials, - ); - vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true); - vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined); + const tokenStorage = new MCPOAuthTokenStorage(); + vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials); + vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true); + vi.mocked(tokenStorage.removeToken).mockResolvedValue(undefined); mockFetch.mockResolvedValueOnce( createMockResponse({ @@ -707,15 +731,14 @@ describe('MCPOAuthProvider', () => { }), ); - const result = await MCPOAuthProvider.getValidToken( + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.getValidToken( 'test-server', mockConfig, ); expect(result).toBeNull(); - expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith( - 'test-server', - ); + expect(tokenStorage.removeToken).toHaveBeenCalledWith('test-server'); expect(console.error).toHaveBeenCalledWith( expect.stringContaining('Failed to refresh token'), ); @@ -734,12 +757,12 @@ describe('MCPOAuthProvider', () => { updatedAt: Date.now(), }; - vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( - tokenWithoutRefresh, - ); - vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true); + const tokenStorage = new MCPOAuthTokenStorage(); + vi.mocked(tokenStorage.getToken).mockResolvedValue(tokenWithoutRefresh); + vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true); - const result = await MCPOAuthProvider.getValidToken( + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.getValidToken( 'test-server', mockConfig, ); @@ -784,7 +807,8 @@ describe('MCPOAuthProvider', () => { }), ); - await MCPOAuthProvider.authenticate('test-server', mockConfig); + const authProvider = new MCPOAuthProvider(); + await authProvider.authenticate('test-server', mockConfig); expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state @@ -833,7 +857,8 @@ describe('MCPOAuthProvider', () => { }), ); - await MCPOAuthProvider.authenticate( + const authProvider = new MCPOAuthProvider(); + await authProvider.authenticate( 'test-server', mockConfig, 'https://auth.example.com', @@ -894,7 +919,8 @@ describe('MCPOAuthProvider', () => { authorizationUrl: 'https://auth.example.com/authorize?audience=1234', }; - await MCPOAuthProvider.authenticate('test-server', configWithParamsInUrl); + const authProvider = new MCPOAuthProvider(); + await authProvider.authenticate('test-server', configWithParamsInUrl); const url = new URL(capturedUrl!); expect(url.searchParams.get('audience')).toBe('1234'); @@ -947,7 +973,8 @@ describe('MCPOAuthProvider', () => { authorizationUrl: 'https://auth.example.com/authorize#login', }; - await MCPOAuthProvider.authenticate('test-server', configWithFragment); + const authProvider = new MCPOAuthProvider(); + await authProvider.authenticate('test-server', configWithFragment); const url = new URL(capturedUrl!); expect(url.searchParams.get('client_id')).toBe('test-client-id'); diff --git a/packages/core/src/mcp/oauth-provider.ts b/packages/core/src/mcp/oauth-provider.ts index 773d2062bf..ef34b63a09 100644 --- a/packages/core/src/mcp/oauth-provider.ts +++ b/packages/core/src/mcp/oauth-provider.ts @@ -85,13 +85,19 @@ interface PKCEParams { state: string; } +const REDIRECT_PORT = 7777; +const REDIRECT_PATH = '/oauth/callback'; +const HTTP_OK = 200; + /** * Provider for handling OAuth authentication for MCP servers. */ export class MCPOAuthProvider { - private static readonly REDIRECT_PORT = 7777; - private static readonly REDIRECT_PATH = '/oauth/callback'; - private static readonly HTTP_OK = 200; + private readonly tokenStorage: MCPOAuthTokenStorage; + + constructor(tokenStorage: MCPOAuthTokenStorage = new MCPOAuthTokenStorage()) { + this.tokenStorage = tokenStorage; + } /** * Register a client dynamically with the OAuth server. @@ -100,13 +106,12 @@ export class MCPOAuthProvider { * @param config OAuth configuration * @returns The registered client information */ - private static async registerClient( + private async registerClient( registrationUrl: string, config: MCPOAuthConfig, ): Promise { const redirectUri = - config.redirectUri || - `http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`; + config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`; const registrationRequest: OAuthClientRegistrationRequest = { client_name: 'Gemini CLI MCP Client', @@ -142,7 +147,7 @@ export class MCPOAuthProvider { * @param mcpServerUrl The MCP server URL * @returns OAuth configuration if discovered, null otherwise */ - private static async discoverOAuthFromMCPServer( + private async discoverOAuthFromMCPServer( mcpServerUrl: string, ): Promise { // Use the full URL with path preserved for OAuth discovery @@ -154,7 +159,7 @@ export class MCPOAuthProvider { * * @returns PKCE parameters including code verifier, challenge, and state */ - private static generatePKCEParams(): PKCEParams { + private generatePKCEParams(): PKCEParams { // Generate code verifier (43-128 characters) const codeVerifier = crypto.randomBytes(32).toString('base64url'); @@ -176,19 +181,16 @@ export class MCPOAuthProvider { * @param expectedState The state parameter to validate * @returns Promise that resolves with the authorization code */ - private static async startCallbackServer( + private async startCallbackServer( expectedState: string, ): Promise { return new Promise((resolve, reject) => { const server = http.createServer( async (req: http.IncomingMessage, res: http.ServerResponse) => { try { - const url = new URL( - req.url!, - `http://localhost:${this.REDIRECT_PORT}`, - ); + const url = new URL(req.url!, `http://localhost:${REDIRECT_PORT}`); - if (url.pathname !== this.REDIRECT_PATH) { + if (url.pathname !== REDIRECT_PATH) { res.writeHead(404); res.end('Not found'); return; @@ -199,7 +201,7 @@ export class MCPOAuthProvider { const error = url.searchParams.get('error'); if (error) { - res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' }); + res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' }); res.end(` @@ -230,7 +232,7 @@ export class MCPOAuthProvider { } // Send success response to browser - res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' }); + res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' }); res.end(` @@ -251,10 +253,8 @@ export class MCPOAuthProvider { ); server.on('error', reject); - server.listen(this.REDIRECT_PORT, () => { - console.log( - `OAuth callback server listening on port ${this.REDIRECT_PORT}`, - ); + server.listen(REDIRECT_PORT, () => { + console.log(`OAuth callback server listening on port ${REDIRECT_PORT}`); }); // Timeout after 5 minutes @@ -276,14 +276,13 @@ export class MCPOAuthProvider { * @param mcpServerUrl The MCP server URL to use as the resource parameter * @returns The authorization URL */ - private static buildAuthorizationUrl( + private buildAuthorizationUrl( config: MCPOAuthConfig, pkceParams: PKCEParams, mcpServerUrl?: string, ): string { const redirectUri = - config.redirectUri || - `http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`; + config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`; const params = new URLSearchParams({ client_id: config.clientId!, @@ -333,15 +332,14 @@ export class MCPOAuthProvider { * @param mcpServerUrl The MCP server URL to use as the resource parameter * @returns The token response */ - private static async exchangeCodeForToken( + private async exchangeCodeForToken( config: MCPOAuthConfig, code: string, codeVerifier: string, mcpServerUrl?: string, ): Promise { const redirectUri = - config.redirectUri || - `http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`; + config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`; const params = new URLSearchParams({ grant_type: 'authorization_code', @@ -458,7 +456,7 @@ export class MCPOAuthProvider { * @param mcpServerUrl The MCP server URL to use as the resource parameter * @returns The new token response */ - static async refreshAccessToken( + async refreshAccessToken( config: MCPOAuthConfig, refreshToken: string, tokenUrl: string, @@ -579,7 +577,7 @@ export class MCPOAuthProvider { * @param mcpServerUrl Optional MCP server URL for OAuth discovery * @returns The obtained OAuth token */ - static async authenticate( + async authenticate( serverName: string, config: MCPOAuthConfig, mcpServerUrl?: string, @@ -790,7 +788,7 @@ export class MCPOAuthProvider { // Save token try { - await MCPOAuthTokenStorage.saveToken( + await this.tokenStorage.saveToken( serverName, token, config.clientId, @@ -800,7 +798,7 @@ export class MCPOAuthProvider { console.log('Authentication successful! Token saved.'); // Verify token was saved - const savedToken = await MCPOAuthTokenStorage.getToken(serverName); + const savedToken = await this.tokenStorage.getToken(serverName); if (savedToken && savedToken.token && savedToken.token.accessToken) { const tokenPreview = savedToken.token.accessToken.length > 20 @@ -827,12 +825,12 @@ export class MCPOAuthProvider { * @param config OAuth configuration * @returns A valid access token or null if not authenticated */ - static async getValidToken( + async getValidToken( serverName: string, config: MCPOAuthConfig, ): Promise { console.debug(`Getting valid token for server: ${serverName}`); - const credentials = await MCPOAuthTokenStorage.getToken(serverName); + const credentials = await this.tokenStorage.getToken(serverName); if (!credentials) { console.debug(`No credentials found for server: ${serverName}`); @@ -841,11 +839,11 @@ export class MCPOAuthProvider { const { token } = credentials; console.debug( - `Found token for server: ${serverName}, expired: ${MCPOAuthTokenStorage.isTokenExpired(token)}`, + `Found token for server: ${serverName}, expired: ${this.tokenStorage.isTokenExpired(token)}`, ); // Check if token is expired - if (!MCPOAuthTokenStorage.isTokenExpired(token)) { + if (!this.tokenStorage.isTokenExpired(token)) { console.debug(`Returning valid token for server: ${serverName}`); return token.accessToken; } @@ -874,7 +872,7 @@ export class MCPOAuthProvider { newToken.expiresAt = Date.now() + newTokenResponse.expires_in * 1000; } - await MCPOAuthTokenStorage.saveToken( + await this.tokenStorage.saveToken( serverName, newToken, config.clientId, @@ -886,7 +884,7 @@ export class MCPOAuthProvider { } catch (error) { console.error(`Failed to refresh token: ${getErrorMessage(error)}`); // Remove invalid token - await MCPOAuthTokenStorage.removeToken(serverName); + await this.tokenStorage.removeToken(serverName); } } diff --git a/packages/core/src/mcp/oauth-token-storage.test.ts b/packages/core/src/mcp/oauth-token-storage.test.ts index b574c36b71..066e8a81aa 100644 --- a/packages/core/src/mcp/oauth-token-storage.test.ts +++ b/packages/core/src/mcp/oauth-token-storage.test.ts @@ -26,6 +26,8 @@ vi.mock('node:os', () => ({ })); describe('MCPOAuthTokenStorage', () => { + let tokenStorage: MCPOAuthTokenStorage; + const mockToken: OAuthToken = { accessToken: 'access_token_123', refreshToken: 'refresh_token_456', @@ -43,6 +45,7 @@ describe('MCPOAuthTokenStorage', () => { }; beforeEach(() => { + tokenStorage = new MCPOAuthTokenStorage(); vi.clearAllMocks(); vi.spyOn(console, 'error').mockImplementation(() => {}); }); @@ -55,7 +58,7 @@ describe('MCPOAuthTokenStorage', () => { it('should return empty map when token file does not exist', async () => { vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); - const tokens = await MCPOAuthTokenStorage.loadTokens(); + const tokens = await tokenStorage.loadTokens(); expect(tokens.size).toBe(0); expect(console.error).not.toHaveBeenCalled(); @@ -65,7 +68,7 @@ describe('MCPOAuthTokenStorage', () => { const tokensArray = [mockCredentials]; vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray)); - const tokens = await MCPOAuthTokenStorage.loadTokens(); + const tokens = await tokenStorage.loadTokens(); expect(tokens.size).toBe(1); expect(tokens.get('test-server')).toEqual(mockCredentials); @@ -78,7 +81,7 @@ describe('MCPOAuthTokenStorage', () => { it('should handle corrupted token file gracefully', async () => { vi.mocked(fs.readFile).mockResolvedValue('invalid json'); - const tokens = await MCPOAuthTokenStorage.loadTokens(); + const tokens = await tokenStorage.loadTokens(); expect(tokens.size).toBe(0); expect(console.error).toHaveBeenCalledWith( @@ -90,7 +93,7 @@ describe('MCPOAuthTokenStorage', () => { const error = new Error('Permission denied'); vi.mocked(fs.readFile).mockRejectedValue(error); - const tokens = await MCPOAuthTokenStorage.loadTokens(); + const tokens = await tokenStorage.loadTokens(); expect(tokens.size).toBe(0); expect(console.error).toHaveBeenCalledWith( @@ -105,7 +108,7 @@ describe('MCPOAuthTokenStorage', () => { vi.mocked(fs.mkdir).mockResolvedValue(undefined); vi.mocked(fs.writeFile).mockResolvedValue(undefined); - await MCPOAuthTokenStorage.saveToken( + await tokenStorage.saveToken( 'test-server', mockToken, 'client-id', @@ -134,7 +137,7 @@ describe('MCPOAuthTokenStorage', () => { vi.mocked(fs.writeFile).mockResolvedValue(undefined); const newToken = { ...mockToken, accessToken: 'new_access_token' }; - await MCPOAuthTokenStorage.saveToken('existing-server', newToken); + await tokenStorage.saveToken('existing-server', newToken); const writeCall = vi.mocked(fs.writeFile).mock.calls[0]; const savedData = JSON.parse(writeCall[1] as string); @@ -151,7 +154,7 @@ describe('MCPOAuthTokenStorage', () => { vi.mocked(fs.writeFile).mockRejectedValue(writeError); await expect( - MCPOAuthTokenStorage.saveToken('test-server', mockToken), + tokenStorage.saveToken('test-server', mockToken), ).rejects.toThrow('Disk full'); expect(console.error).toHaveBeenCalledWith( @@ -166,7 +169,7 @@ describe('MCPOAuthTokenStorage', () => { JSON.stringify([mockCredentials]), ); - const result = await MCPOAuthTokenStorage.getToken('test-server'); + const result = await tokenStorage.getToken('test-server'); expect(result).toEqual(mockCredentials); }); @@ -176,7 +179,7 @@ describe('MCPOAuthTokenStorage', () => { JSON.stringify([mockCredentials]), ); - const result = await MCPOAuthTokenStorage.getToken('non-existent'); + const result = await tokenStorage.getToken('non-existent'); expect(result).toBeNull(); }); @@ -184,7 +187,7 @@ describe('MCPOAuthTokenStorage', () => { it('should return null when no tokens file exists', async () => { vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); - const result = await MCPOAuthTokenStorage.getToken('test-server'); + const result = await tokenStorage.getToken('test-server'); expect(result).toBeNull(); }); @@ -199,7 +202,7 @@ describe('MCPOAuthTokenStorage', () => { ); vi.mocked(fs.writeFile).mockResolvedValue(undefined); - await MCPOAuthTokenStorage.removeToken('server1'); + await tokenStorage.removeToken('server1'); const writeCall = vi.mocked(fs.writeFile).mock.calls[0]; const savedData = JSON.parse(writeCall[1] as string); @@ -214,7 +217,7 @@ describe('MCPOAuthTokenStorage', () => { ); vi.mocked(fs.unlink).mockResolvedValue(undefined); - await MCPOAuthTokenStorage.removeToken('test-server'); + await tokenStorage.removeToken('test-server'); expect(fs.unlink).toHaveBeenCalledWith( path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'), @@ -227,7 +230,7 @@ describe('MCPOAuthTokenStorage', () => { JSON.stringify([mockCredentials]), ); - await MCPOAuthTokenStorage.removeToken('non-existent'); + await tokenStorage.removeToken('non-existent'); expect(fs.writeFile).not.toHaveBeenCalled(); expect(fs.unlink).not.toHaveBeenCalled(); @@ -239,7 +242,7 @@ describe('MCPOAuthTokenStorage', () => { ); vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied')); - await MCPOAuthTokenStorage.removeToken('test-server'); + await tokenStorage.removeToken('test-server'); expect(console.error).toHaveBeenCalledWith( expect.stringContaining('Failed to remove MCP OAuth token'), @@ -252,7 +255,7 @@ describe('MCPOAuthTokenStorage', () => { const tokenWithoutExpiry = { ...mockToken }; delete tokenWithoutExpiry.expiresAt; - const result = MCPOAuthTokenStorage.isTokenExpired(tokenWithoutExpiry); + const result = tokenStorage.isTokenExpired(tokenWithoutExpiry); expect(result).toBe(false); }); @@ -263,7 +266,7 @@ describe('MCPOAuthTokenStorage', () => { expiresAt: Date.now() + 3600000, // 1 hour from now }; - const result = MCPOAuthTokenStorage.isTokenExpired(futureToken); + const result = tokenStorage.isTokenExpired(futureToken); expect(result).toBe(false); }); @@ -274,7 +277,7 @@ describe('MCPOAuthTokenStorage', () => { expiresAt: Date.now() - 3600000, // 1 hour ago }; - const result = MCPOAuthTokenStorage.isTokenExpired(expiredToken); + const result = tokenStorage.isTokenExpired(expiredToken); expect(result).toBe(true); }); @@ -285,7 +288,7 @@ describe('MCPOAuthTokenStorage', () => { expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer) }; - const result = MCPOAuthTokenStorage.isTokenExpired(soonToExpireToken); + const result = tokenStorage.isTokenExpired(soonToExpireToken); expect(result).toBe(true); }); @@ -295,7 +298,7 @@ describe('MCPOAuthTokenStorage', () => { it('should remove token file successfully', async () => { vi.mocked(fs.unlink).mockResolvedValue(undefined); - await MCPOAuthTokenStorage.clearAllTokens(); + await tokenStorage.clearAllTokens(); expect(fs.unlink).toHaveBeenCalledWith( path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'), @@ -305,7 +308,7 @@ describe('MCPOAuthTokenStorage', () => { it('should handle non-existent file gracefully', async () => { vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' }); - await MCPOAuthTokenStorage.clearAllTokens(); + await tokenStorage.clearAllTokens(); expect(console.error).not.toHaveBeenCalled(); }); @@ -313,7 +316,7 @@ describe('MCPOAuthTokenStorage', () => { it('should handle other file errors gracefully', async () => { vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied')); - await MCPOAuthTokenStorage.clearAllTokens(); + await tokenStorage.clearAllTokens(); expect(console.error).toHaveBeenCalledWith( expect.stringContaining('Failed to clear MCP OAuth tokens'), diff --git a/packages/core/src/mcp/oauth-token-storage.ts b/packages/core/src/mcp/oauth-token-storage.ts index c2c27a2702..ee8ed00ea0 100644 --- a/packages/core/src/mcp/oauth-token-storage.ts +++ b/packages/core/src/mcp/oauth-token-storage.ts @@ -19,14 +19,14 @@ export class MCPOAuthTokenStorage { * * @returns The full path to the token storage file */ - private static getTokenFilePath(): string { + private getTokenFilePath(): string { return Storage.getMcpOAuthTokensPath(); } /** * Ensure the config directory exists. */ - private static async ensureConfigDir(): Promise { + private async ensureConfigDir(): Promise { const configDir = path.dirname(this.getTokenFilePath()); await fs.mkdir(configDir, { recursive: true }); } @@ -36,7 +36,7 @@ export class MCPOAuthTokenStorage { * * @returns A map of server names to credentials */ - static async loadTokens(): Promise> { + async loadTokens(): Promise> { const tokenMap = new Map(); try { @@ -68,7 +68,7 @@ export class MCPOAuthTokenStorage { * @param tokenUrl Optional token URL used for this token * @param mcpServerUrl Optional MCP server URL */ - static async saveToken( + async saveToken( serverName: string, token: OAuthToken, clientId?: string, @@ -113,7 +113,7 @@ export class MCPOAuthTokenStorage { * @param serverName The name of the MCP server * @returns The stored credentials or null if not found */ - static async getToken(serverName: string): Promise { + async getToken(serverName: string): Promise { const tokens = await this.loadTokens(); return tokens.get(serverName) || null; } @@ -123,7 +123,7 @@ export class MCPOAuthTokenStorage { * * @param serverName The name of the MCP server */ - static async removeToken(serverName: string): Promise { + async removeToken(serverName: string): Promise { const tokens = await this.loadTokens(); if (tokens.delete(serverName)) { @@ -153,7 +153,7 @@ export class MCPOAuthTokenStorage { * @param token The token to check * @returns True if the token is expired */ - static isTokenExpired(token: OAuthToken): boolean { + isTokenExpired(token: OAuthToken): boolean { if (!token.expiresAt) { return false; // No expiry, assume valid } @@ -166,7 +166,7 @@ export class MCPOAuthTokenStorage { /** * Clear all stored MCP OAuth tokens. */ - static async clearAllTokens(): Promise { + async clearAllTokens(): Promise { try { const tokenFile = this.getTokenFilePath(); await fs.unlink(tokenFile); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index e434ead39e..25ec83a632 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -365,11 +365,8 @@ async function handleAutomaticOAuth( console.log( `Starting OAuth authentication for server '${mcpServerName}'...`, ); - await MCPOAuthProvider.authenticate( - mcpServerName, - oauthAuthConfig, - serverUrl, - ); + const authProvider = new MCPOAuthProvider(new MCPOAuthTokenStorage()); + await authProvider.authenticate(mcpServerName, oauthAuthConfig, serverUrl); console.log( `OAuth authentication successful for server '${mcpServerName}'`, @@ -899,9 +896,11 @@ export async function connectToMcpServer( if (!shouldTriggerOAuth) { // For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately. - const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName); + const tokenStorage = new MCPOAuthTokenStorage(); + const credentials = await tokenStorage.getToken(mcpServerName); if (credentials) { - const hasStoredTokens = await MCPOAuthProvider.getValidToken( + const authProvider = new MCPOAuthProvider(tokenStorage); + const hasStoredTokens = await authProvider.getValidToken( mcpServerName, { // Pass client ID if available @@ -982,10 +981,11 @@ export async function connectToMcpServer( // Get the valid token - we need to create a proper OAuth config // The token should already be available from the authentication process - const credentials = - await MCPOAuthTokenStorage.getToken(mcpServerName); + const tokenStorage = new MCPOAuthTokenStorage(); + const credentials = await tokenStorage.getToken(mcpServerName); if (credentials) { - const accessToken = await MCPOAuthProvider.getValidToken( + const authProvider = new MCPOAuthProvider(tokenStorage); + const accessToken = await authProvider.getValidToken( mcpServerName, { // Pass client ID if available @@ -1056,10 +1056,11 @@ export async function connectToMcpServer( mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled; if (!shouldTryDiscovery) { - const credentials = - await MCPOAuthTokenStorage.getToken(mcpServerName); + const tokenStorage = new MCPOAuthTokenStorage(); + const credentials = await tokenStorage.getToken(mcpServerName); if (credentials) { - const hasStoredTokens = await MCPOAuthProvider.getValidToken( + const authProvider = new MCPOAuthProvider(tokenStorage); + const hasStoredTokens = await authProvider.getValidToken( mcpServerName, { // Pass client ID if available @@ -1116,17 +1117,21 @@ export async function connectToMcpServer( console.log( `Starting OAuth authentication for server '${mcpServerName}'...`, ); - await MCPOAuthProvider.authenticate( + const authProvider = new MCPOAuthProvider( + new MCPOAuthTokenStorage(), + ); + await authProvider.authenticate( mcpServerName, oauthAuthConfig, authServerUrl, ); // Retry connection with OAuth token - const credentials = - await MCPOAuthTokenStorage.getToken(mcpServerName); + const tokenStorage = new MCPOAuthTokenStorage(); + const credentials = await tokenStorage.getToken(mcpServerName); if (credentials) { - const accessToken = await MCPOAuthProvider.getValidToken( + const authProvider = new MCPOAuthProvider(tokenStorage); + const accessToken = await authProvider.getValidToken( mcpServerName, { // Pass client ID if available @@ -1261,7 +1266,9 @@ export async function createTransport( let hasOAuthConfig = mcpServerConfig.oauth?.enabled; if (hasOAuthConfig && mcpServerConfig.oauth) { - accessToken = await MCPOAuthProvider.getValidToken( + const tokenStorage = new MCPOAuthTokenStorage(); + const authProvider = new MCPOAuthProvider(tokenStorage); + accessToken = await authProvider.getValidToken( mcpServerName, mcpServerConfig.oauth, ); @@ -1278,9 +1285,11 @@ export async function createTransport( } } else { // Check if we have stored OAuth tokens for this server (from previous authentication) - const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName); + const tokenStorage = new MCPOAuthTokenStorage(); + const credentials = await tokenStorage.getToken(mcpServerName); if (credentials) { - accessToken = await MCPOAuthProvider.getValidToken(mcpServerName, { + const authProvider = new MCPOAuthProvider(tokenStorage); + accessToken = await authProvider.getValidToken(mcpServerName, { // Pass client ID if available clientId: credentials.clientId, });