From 918ab3c2ecbabc718db8b404d79b49b7e4e8964a Mon Sep 17 00:00:00 2001 From: shishu314 Date: Fri, 5 Sep 2025 12:08:50 -0400 Subject: [PATCH] feat(security) - Make oauth token storage implement the shared interface (#7802) Co-authored-by: Shi Shu --- packages/cli/src/ui/commands/mcpCommand.ts | 2 +- packages/core/src/mcp/oauth-provider.test.ts | 34 +++++--- packages/core/src/mcp/oauth-provider.ts | 6 +- .../core/src/mcp/oauth-token-storage.test.ts | 36 ++++---- packages/core/src/mcp/oauth-token-storage.ts | 86 +++++++++++-------- packages/core/src/tools/mcp-client.ts | 11 +-- 6 files changed, 99 insertions(+), 76 deletions(-) diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index 7e2358be85..79f6bd3951 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -143,7 +143,7 @@ const getMcpStatus = async ( '@google/gemini-cli-core' ); const tokenStorage = new MCPOAuthTokenStorage(); - const hasToken = await tokenStorage.getToken(serverName); + const hasToken = await tokenStorage.getCredentials(serverName); if (hasToken) { const isExpired = tokenStorage.isTokenExpired(hasToken.token); if (isExpired) { diff --git a/packages/core/src/mcp/oauth-provider.test.ts b/packages/core/src/mcp/oauth-provider.test.ts index 48af4a4a2f..744d8e7db2 100644 --- a/packages/core/src/mcp/oauth-provider.test.ts +++ b/packages/core/src/mcp/oauth-provider.test.ts @@ -14,16 +14,16 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({ vi.mock('node:crypto'); vi.mock('./oauth-token-storage.js', () => { const mockSaveToken = vi.fn(); - const mockGetToken = vi.fn(); + const mockGetCredentials = vi.fn(); const mockIsTokenExpired = vi.fn(); - const mockRemoveToken = vi.fn(); + const mockdeleteCredentials = vi.fn(); return { MCPOAuthTokenStorage: vi.fn(() => ({ saveToken: mockSaveToken, - getToken: mockGetToken, + getCredentials: mockGetCredentials, isTokenExpired: mockIsTokenExpired, - removeToken: mockRemoveToken, + deleteCredentials: mockdeleteCredentials, })), }; }); @@ -163,7 +163,7 @@ describe('MCPOAuthProvider', () => { // Mock token storage const tokenStorage = new MCPOAuthTokenStorage(); vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined); - vi.mocked(tokenStorage.getToken).mockResolvedValue(null); + vi.mocked(tokenStorage.getCredentials).mockResolvedValue(null); }); afterEach(() => { @@ -638,7 +638,9 @@ describe('MCPOAuthProvider', () => { }; const tokenStorage = new MCPOAuthTokenStorage(); - vi.mocked(tokenStorage.getToken).mockResolvedValue(validCredentials); + vi.mocked(tokenStorage.getCredentials).mockResolvedValue( + validCredentials, + ); vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false); const authProvider = new MCPOAuthProvider(); @@ -660,7 +662,9 @@ describe('MCPOAuthProvider', () => { }; const tokenStorage = new MCPOAuthTokenStorage(); - vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials); + vi.mocked(tokenStorage.getCredentials).mockResolvedValue( + expiredCredentials, + ); vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true); const refreshResponse = { @@ -697,7 +701,7 @@ describe('MCPOAuthProvider', () => { it('should return null when no credentials exist', async () => { const tokenStorage = new MCPOAuthTokenStorage(); - vi.mocked(tokenStorage.getToken).mockResolvedValue(null); + vi.mocked(tokenStorage.getCredentials).mockResolvedValue(null); const authProvider = new MCPOAuthProvider(); const result = await authProvider.getValidToken( @@ -718,9 +722,11 @@ describe('MCPOAuthProvider', () => { }; const tokenStorage = new MCPOAuthTokenStorage(); - vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials); + vi.mocked(tokenStorage.getCredentials).mockResolvedValue( + expiredCredentials, + ); vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true); - vi.mocked(tokenStorage.removeToken).mockResolvedValue(undefined); + vi.mocked(tokenStorage.deleteCredentials).mockResolvedValue(undefined); mockFetch.mockResolvedValueOnce( createMockResponse({ @@ -738,7 +744,9 @@ describe('MCPOAuthProvider', () => { ); expect(result).toBeNull(); - expect(tokenStorage.removeToken).toHaveBeenCalledWith('test-server'); + expect(tokenStorage.deleteCredentials).toHaveBeenCalledWith( + 'test-server', + ); expect(console.error).toHaveBeenCalledWith( expect.stringContaining('Failed to refresh token'), ); @@ -758,7 +766,9 @@ describe('MCPOAuthProvider', () => { }; const tokenStorage = new MCPOAuthTokenStorage(); - vi.mocked(tokenStorage.getToken).mockResolvedValue(tokenWithoutRefresh); + vi.mocked(tokenStorage.getCredentials).mockResolvedValue( + tokenWithoutRefresh, + ); vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true); const authProvider = new MCPOAuthProvider(); diff --git a/packages/core/src/mcp/oauth-provider.ts b/packages/core/src/mcp/oauth-provider.ts index ef34b63a09..3f116b7d7e 100644 --- a/packages/core/src/mcp/oauth-provider.ts +++ b/packages/core/src/mcp/oauth-provider.ts @@ -798,7 +798,7 @@ export class MCPOAuthProvider { console.log('Authentication successful! Token saved.'); // Verify token was saved - const savedToken = await this.tokenStorage.getToken(serverName); + const savedToken = await this.tokenStorage.getCredentials(serverName); if (savedToken && savedToken.token && savedToken.token.accessToken) { const tokenPreview = savedToken.token.accessToken.length > 20 @@ -830,7 +830,7 @@ export class MCPOAuthProvider { config: MCPOAuthConfig, ): Promise { console.debug(`Getting valid token for server: ${serverName}`); - const credentials = await this.tokenStorage.getToken(serverName); + const credentials = await this.tokenStorage.getCredentials(serverName); if (!credentials) { console.debug(`No credentials found for server: ${serverName}`); @@ -884,7 +884,7 @@ export class MCPOAuthProvider { } catch (error) { console.error(`Failed to refresh token: ${getErrorMessage(error)}`); // Remove invalid token - await this.tokenStorage.removeToken(serverName); + await this.tokenStorage.deleteCredentials(serverName); } } diff --git a/packages/core/src/mcp/oauth-token-storage.test.ts b/packages/core/src/mcp/oauth-token-storage.test.ts index 066e8a81aa..6b7f9c8c3f 100644 --- a/packages/core/src/mcp/oauth-token-storage.test.ts +++ b/packages/core/src/mcp/oauth-token-storage.test.ts @@ -54,11 +54,11 @@ describe('MCPOAuthTokenStorage', () => { vi.restoreAllMocks(); }); - describe('loadTokens', () => { + describe('getAllCredentials', () => { it('should return empty map when token file does not exist', async () => { vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); - const tokens = await tokenStorage.loadTokens(); + const tokens = await tokenStorage.getAllCredentials(); expect(tokens.size).toBe(0); expect(console.error).not.toHaveBeenCalled(); @@ -68,7 +68,7 @@ describe('MCPOAuthTokenStorage', () => { const tokensArray = [mockCredentials]; vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray)); - const tokens = await tokenStorage.loadTokens(); + const tokens = await tokenStorage.getAllCredentials(); expect(tokens.size).toBe(1); expect(tokens.get('test-server')).toEqual(mockCredentials); @@ -81,7 +81,7 @@ describe('MCPOAuthTokenStorage', () => { it('should handle corrupted token file gracefully', async () => { vi.mocked(fs.readFile).mockResolvedValue('invalid json'); - const tokens = await tokenStorage.loadTokens(); + const tokens = await tokenStorage.getAllCredentials(); expect(tokens.size).toBe(0); expect(console.error).toHaveBeenCalledWith( @@ -93,7 +93,7 @@ describe('MCPOAuthTokenStorage', () => { const error = new Error('Permission denied'); vi.mocked(fs.readFile).mockRejectedValue(error); - const tokens = await tokenStorage.loadTokens(); + const tokens = await tokenStorage.getAllCredentials(); expect(tokens.size).toBe(0); expect(console.error).toHaveBeenCalledWith( @@ -163,13 +163,13 @@ describe('MCPOAuthTokenStorage', () => { }); }); - describe('getToken', () => { + describe('getCredentials', () => { it('should return token for existing server', async () => { vi.mocked(fs.readFile).mockResolvedValue( JSON.stringify([mockCredentials]), ); - const result = await tokenStorage.getToken('test-server'); + const result = await tokenStorage.getCredentials('test-server'); expect(result).toEqual(mockCredentials); }); @@ -179,7 +179,7 @@ describe('MCPOAuthTokenStorage', () => { JSON.stringify([mockCredentials]), ); - const result = await tokenStorage.getToken('non-existent'); + const result = await tokenStorage.getCredentials('non-existent'); expect(result).toBeNull(); }); @@ -187,13 +187,13 @@ describe('MCPOAuthTokenStorage', () => { it('should return null when no tokens file exists', async () => { vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); - const result = await tokenStorage.getToken('test-server'); + const result = await tokenStorage.getCredentials('test-server'); expect(result).toBeNull(); }); }); - describe('removeToken', () => { + describe('deleteCredentials', () => { it('should remove token for specific server', async () => { const credentials1 = { ...mockCredentials, serverName: 'server1' }; const credentials2 = { ...mockCredentials, serverName: 'server2' }; @@ -202,7 +202,7 @@ describe('MCPOAuthTokenStorage', () => { ); vi.mocked(fs.writeFile).mockResolvedValue(undefined); - await tokenStorage.removeToken('server1'); + await tokenStorage.deleteCredentials('server1'); const writeCall = vi.mocked(fs.writeFile).mock.calls[0]; const savedData = JSON.parse(writeCall[1] as string); @@ -217,7 +217,7 @@ describe('MCPOAuthTokenStorage', () => { ); vi.mocked(fs.unlink).mockResolvedValue(undefined); - await tokenStorage.removeToken('test-server'); + await tokenStorage.deleteCredentials('test-server'); expect(fs.unlink).toHaveBeenCalledWith( path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'), @@ -230,7 +230,7 @@ describe('MCPOAuthTokenStorage', () => { JSON.stringify([mockCredentials]), ); - await tokenStorage.removeToken('non-existent'); + await tokenStorage.deleteCredentials('non-existent'); expect(fs.writeFile).not.toHaveBeenCalled(); expect(fs.unlink).not.toHaveBeenCalled(); @@ -242,7 +242,7 @@ describe('MCPOAuthTokenStorage', () => { ); vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied')); - await tokenStorage.removeToken('test-server'); + await tokenStorage.deleteCredentials('test-server'); expect(console.error).toHaveBeenCalledWith( expect.stringContaining('Failed to remove MCP OAuth token'), @@ -294,11 +294,11 @@ describe('MCPOAuthTokenStorage', () => { }); }); - describe('clearAllTokens', () => { + describe('clearAll', () => { it('should remove token file successfully', async () => { vi.mocked(fs.unlink).mockResolvedValue(undefined); - await tokenStorage.clearAllTokens(); + await tokenStorage.clearAll(); expect(fs.unlink).toHaveBeenCalledWith( path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'), @@ -308,7 +308,7 @@ describe('MCPOAuthTokenStorage', () => { it('should handle non-existent file gracefully', async () => { vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' }); - await tokenStorage.clearAllTokens(); + await tokenStorage.clearAll(); expect(console.error).not.toHaveBeenCalled(); }); @@ -316,7 +316,7 @@ describe('MCPOAuthTokenStorage', () => { it('should handle other file errors gracefully', async () => { vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied')); - await tokenStorage.clearAllTokens(); + await tokenStorage.clearAll(); expect(console.error).toHaveBeenCalledWith( expect.stringContaining('Failed to clear MCP OAuth tokens'), diff --git a/packages/core/src/mcp/oauth-token-storage.ts b/packages/core/src/mcp/oauth-token-storage.ts index ee8ed00ea0..ba8af40229 100644 --- a/packages/core/src/mcp/oauth-token-storage.ts +++ b/packages/core/src/mcp/oauth-token-storage.ts @@ -8,12 +8,16 @@ import { promises as fs } from 'node:fs'; import * as path from 'node:path'; import { Storage } from '../config/storage.js'; import { getErrorMessage } from '../utils/errors.js'; -import type { OAuthToken, OAuthCredentials } from './token-storage/types.js'; +import type { + OAuthToken, + OAuthCredentials, + TokenStorage, +} from './token-storage/types.js'; /** * Class for managing MCP OAuth token storage and retrieval. */ -export class MCPOAuthTokenStorage { +export class MCPOAuthTokenStorage implements TokenStorage { /** * Get the path to the token storage file. * @@ -36,7 +40,7 @@ export class MCPOAuthTokenStorage { * * @returns A map of server names to credentials */ - async loadTokens(): Promise> { + async getAllCredentials(): Promise> { const tokenMap = new Map(); try { @@ -59,36 +63,14 @@ export class MCPOAuthTokenStorage { return tokenMap; } - /** - * Save a token for a specific MCP server. - * - * @param serverName The name of the MCP server - * @param token The OAuth token to save - * @param clientId Optional client ID used for this token - * @param tokenUrl Optional token URL used for this token - * @param mcpServerUrl Optional MCP server URL - */ - async saveToken( - serverName: string, - token: OAuthToken, - clientId?: string, - tokenUrl?: string, - mcpServerUrl?: string, - ): Promise { - await this.ensureConfigDir(); + async listServers(): Promise { + const tokens = await this.getAllCredentials(); + return Array.from(tokens.keys()); + } - const tokens = await this.loadTokens(); - - const credential: OAuthCredentials = { - serverName, - token, - clientId, - tokenUrl, - mcpServerUrl, - updatedAt: Date.now(), - }; - - tokens.set(serverName, credential); + async setCredentials(credentials: OAuthCredentials): Promise { + const tokens = await this.getAllCredentials(); + tokens.set(credentials.serverName, credentials); const tokenArray = Array.from(tokens.values()); const tokenFile = this.getTokenFilePath(); @@ -107,14 +89,44 @@ export class MCPOAuthTokenStorage { } } + /** + * Save a token for a specific MCP server. + * + * @param serverName The name of the MCP server + * @param token The OAuth token to save + * @param clientId Optional client ID used for this token + * @param tokenUrl Optional token URL used for this token + * @param mcpServerUrl Optional MCP server URL + */ + async saveToken( + serverName: string, + token: OAuthToken, + clientId?: string, + tokenUrl?: string, + mcpServerUrl?: string, + ): Promise { + await this.ensureConfigDir(); + + const credential: OAuthCredentials = { + serverName, + token, + clientId, + tokenUrl, + mcpServerUrl, + updatedAt: Date.now(), + }; + + await this.setCredentials(credential); + } + /** * Get a token for a specific MCP server. * * @param serverName The name of the MCP server * @returns The stored credentials or null if not found */ - async getToken(serverName: string): Promise { - const tokens = await this.loadTokens(); + async getCredentials(serverName: string): Promise { + const tokens = await this.getAllCredentials(); return tokens.get(serverName) || null; } @@ -123,8 +135,8 @@ export class MCPOAuthTokenStorage { * * @param serverName The name of the MCP server */ - async removeToken(serverName: string): Promise { - const tokens = await this.loadTokens(); + async deleteCredentials(serverName: string): Promise { + const tokens = await this.getAllCredentials(); if (tokens.delete(serverName)) { const tokenArray = Array.from(tokens.values()); @@ -166,7 +178,7 @@ export class MCPOAuthTokenStorage { /** * Clear all stored MCP OAuth tokens. */ - async clearAllTokens(): Promise { + async clearAll(): Promise { try { const tokenFile = this.getTokenFilePath(); await fs.unlink(tokenFile); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 25ec83a632..468d43139f 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -897,7 +897,7 @@ export async function connectToMcpServer( if (!shouldTriggerOAuth) { // For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately. const tokenStorage = new MCPOAuthTokenStorage(); - const credentials = await tokenStorage.getToken(mcpServerName); + const credentials = await tokenStorage.getCredentials(mcpServerName); if (credentials) { const authProvider = new MCPOAuthProvider(tokenStorage); const hasStoredTokens = await authProvider.getValidToken( @@ -982,7 +982,7 @@ export async function connectToMcpServer( // Get the valid token - we need to create a proper OAuth config // The token should already be available from the authentication process const tokenStorage = new MCPOAuthTokenStorage(); - const credentials = await tokenStorage.getToken(mcpServerName); + const credentials = await tokenStorage.getCredentials(mcpServerName); if (credentials) { const authProvider = new MCPOAuthProvider(tokenStorage); const accessToken = await authProvider.getValidToken( @@ -1057,7 +1057,7 @@ export async function connectToMcpServer( if (!shouldTryDiscovery) { const tokenStorage = new MCPOAuthTokenStorage(); - const credentials = await tokenStorage.getToken(mcpServerName); + const credentials = await tokenStorage.getCredentials(mcpServerName); if (credentials) { const authProvider = new MCPOAuthProvider(tokenStorage); const hasStoredTokens = await authProvider.getValidToken( @@ -1128,7 +1128,8 @@ export async function connectToMcpServer( // Retry connection with OAuth token const tokenStorage = new MCPOAuthTokenStorage(); - const credentials = await tokenStorage.getToken(mcpServerName); + const credentials = + await tokenStorage.getCredentials(mcpServerName); if (credentials) { const authProvider = new MCPOAuthProvider(tokenStorage); const accessToken = await authProvider.getValidToken( @@ -1286,7 +1287,7 @@ export async function createTransport( } else { // Check if we have stored OAuth tokens for this server (from previous authentication) const tokenStorage = new MCPOAuthTokenStorage(); - const credentials = await tokenStorage.getToken(mcpServerName); + const credentials = await tokenStorage.getCredentials(mcpServerName); if (credentials) { const authProvider = new MCPOAuthProvider(tokenStorage); accessToken = await authProvider.getValidToken(mcpServerName, {