Feat(security) - Make the OAuthTokenStorage non static (#7716)

Co-authored-by: Shi Shu <shii@google.com>
This commit is contained in:
shishu314
2025-09-04 16:42:47 -04:00
committed by GitHub
parent e088c06a9a
commit 35a841f71a
7 changed files with 188 additions and 149 deletions

View File

@@ -22,17 +22,18 @@ import { Type } from '@google/genai';
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actual =
await importOriginal<typeof import('@google/gemini-cli-core')>();
const mockAuthenticate = vi.fn();
return {
...actual,
getMCPServerStatus: vi.fn(),
getMCPDiscoveryState: vi.fn(),
MCPOAuthProvider: {
authenticate: vi.fn(),
},
MCPOAuthTokenStorage: {
MCPOAuthProvider: vi.fn(() => ({
authenticate: mockAuthenticate,
})),
MCPOAuthTokenStorage: vi.fn(() => ({
getToken: vi.fn(),
isTokenExpired: vi.fn(),
},
})),
};
});
@@ -892,13 +893,14 @@ describe('mcpCommand', () => {
context.ui.reloadCommands = vi.fn();
const { MCPOAuthProvider } = await import('@google/gemini-cli-core');
const mockAuthProvider = new MCPOAuthProvider();
const authCommand = mcpCommand.subCommands?.find(
(cmd) => cmd.name === 'auth',
);
const result = await authCommand!.action!(context, 'test-server');
expect(MCPOAuthProvider.authenticate).toHaveBeenCalledWith(
expect(mockAuthProvider.authenticate).toHaveBeenCalledWith(
'test-server',
{ enabled: true },
'http://localhost:3000',
@@ -928,9 +930,10 @@ describe('mcpCommand', () => {
});
const { MCPOAuthProvider } = await import('@google/gemini-cli-core');
(
MCPOAuthProvider.authenticate as ReturnType<typeof vi.fn>
).mockRejectedValue(new Error('Auth failed'));
const mockAuthProvider = new MCPOAuthProvider();
vi.mocked(mockAuthProvider.authenticate).mockRejectedValue(
new Error('Auth failed'),
);
const authCommand = mcpCommand.subCommands?.find(
(cmd) => cmd.name === 'auth',

View File

@@ -20,6 +20,7 @@ import {
MCPServerStatus,
mcpServerRequiresOAuth,
getErrorMessage,
MCPOAuthTokenStorage,
} from '@google/gemini-cli-core';
const COLOR_GREEN = '\u001b[32m';
@@ -141,9 +142,10 @@ const getMcpStatus = async (
const { MCPOAuthTokenStorage } = await import(
'@google/gemini-cli-core'
);
const hasToken = await MCPOAuthTokenStorage.getToken(serverName);
const tokenStorage = new MCPOAuthTokenStorage();
const hasToken = await tokenStorage.getToken(serverName);
if (hasToken) {
const isExpired = MCPOAuthTokenStorage.isTokenExpired(hasToken.token);
const isExpired = tokenStorage.isTokenExpired(hasToken.token);
if (isExpired) {
message += ` ${COLOR_YELLOW}(OAuth token expired)${RESET_COLOR}`;
} else {
@@ -385,11 +387,8 @@ const authCommand: SlashCommand = {
// Pass the MCP server URL for OAuth discovery
const mcpServerUrl = server.httpUrl || server.url;
await MCPOAuthProvider.authenticate(
serverName,
oauthConfig,
mcpServerUrl,
);
const authProvider = new MCPOAuthProvider(new MCPOAuthTokenStorage());
await authProvider.authenticate(serverName, oauthConfig, mcpServerUrl);
context.ui.addItem(
{

View File

@@ -12,7 +12,21 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({
openBrowserSecurely: mockOpenBrowserSecurely,
}));
vi.mock('node:crypto');
vi.mock('./oauth-token-storage.js');
vi.mock('./oauth-token-storage.js', () => {
const mockSaveToken = vi.fn();
const mockGetToken = vi.fn();
const mockIsTokenExpired = vi.fn();
const mockRemoveToken = vi.fn();
return {
MCPOAuthTokenStorage: vi.fn(() => ({
saveToken: mockSaveToken,
getToken: mockGetToken,
isTokenExpired: mockIsTokenExpired,
removeToken: mockRemoveToken,
})),
};
});
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import * as http from 'node:http';
@@ -147,8 +161,9 @@ describe('MCPOAuthProvider', () => {
});
// Mock token storage
vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined);
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined);
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
});
afterEach(() => {
@@ -192,10 +207,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.authenticate(
'test-server',
mockConfig,
);
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate('test-server', mockConfig);
expect(result).toEqual({
accessToken: 'access_token_123',
@@ -208,7 +221,8 @@ describe('MCPOAuthProvider', () => {
expect(mockOpenBrowserSecurely).toHaveBeenCalledWith(
expect.stringContaining('authorize'),
);
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
const tokenStorage = new MCPOAuthTokenStorage();
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
'test-server',
expect.objectContaining({ accessToken: 'access_token_123' }),
'test-client-id',
@@ -296,7 +310,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.authenticate(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate(
'test-server',
configWithoutAuth,
'https://api.example.com',
@@ -385,7 +400,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.authenticate(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate(
'test-server',
configWithoutClient,
);
@@ -424,8 +440,9 @@ describe('MCPOAuthProvider', () => {
}, 10);
});
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('OAuth error: access_denied');
});
@@ -453,8 +470,9 @@ describe('MCPOAuthProvider', () => {
}, 10);
});
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('State mismatch - possible CSRF attack');
});
@@ -491,8 +509,9 @@ describe('MCPOAuthProvider', () => {
}),
);
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant');
});
@@ -516,8 +535,9 @@ describe('MCPOAuthProvider', () => {
return originalSetTimeout(callback, 0);
}) as unknown as typeof setTimeout;
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('OAuth callback timeout');
global.setTimeout = originalSetTimeout;
@@ -542,7 +562,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.refreshAccessToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.refreshAccessToken(
mockConfig,
'old_refresh_token',
'https://auth.example.com/token',
@@ -572,7 +593,8 @@ describe('MCPOAuthProvider', () => {
}),
);
await MCPOAuthProvider.refreshAccessToken(
const authProvider = new MCPOAuthProvider();
await authProvider.refreshAccessToken(
mockConfig,
'refresh_token',
'https://auth.example.com/token',
@@ -592,8 +614,9 @@ describe('MCPOAuthProvider', () => {
}),
);
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.refreshAccessToken(
authProvider.refreshAccessToken(
mockConfig,
'invalid_refresh_token',
'https://auth.example.com/token',
@@ -614,12 +637,12 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
validCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(validCredentials);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
@@ -636,10 +659,9 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
expiredCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
const refreshResponse = {
access_token: 'new_access_token',
@@ -657,13 +679,14 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBe('new_access_token');
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
'test-server',
expect.objectContaining({ accessToken: 'new_access_token' }),
'test-client-id',
@@ -673,9 +696,11 @@ describe('MCPOAuthProvider', () => {
});
it('should return null when no credentials exist', async () => {
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
@@ -692,11 +717,10 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
expiredCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(tokenStorage.removeToken).mockResolvedValue(undefined);
mockFetch.mockResolvedValueOnce(
createMockResponse({
@@ -707,15 +731,14 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBeNull();
expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith(
'test-server',
);
expect(tokenStorage.removeToken).toHaveBeenCalledWith('test-server');
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to refresh token'),
);
@@ -734,12 +757,12 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
tokenWithoutRefresh,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(tokenWithoutRefresh);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
@@ -784,7 +807,8 @@ describe('MCPOAuthProvider', () => {
}),
);
await MCPOAuthProvider.authenticate('test-server', mockConfig);
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', mockConfig);
expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier
expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state
@@ -833,7 +857,8 @@ describe('MCPOAuthProvider', () => {
}),
);
await MCPOAuthProvider.authenticate(
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate(
'test-server',
mockConfig,
'https://auth.example.com',
@@ -894,7 +919,8 @@ describe('MCPOAuthProvider', () => {
authorizationUrl: 'https://auth.example.com/authorize?audience=1234',
};
await MCPOAuthProvider.authenticate('test-server', configWithParamsInUrl);
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', configWithParamsInUrl);
const url = new URL(capturedUrl!);
expect(url.searchParams.get('audience')).toBe('1234');
@@ -947,7 +973,8 @@ describe('MCPOAuthProvider', () => {
authorizationUrl: 'https://auth.example.com/authorize#login',
};
await MCPOAuthProvider.authenticate('test-server', configWithFragment);
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', configWithFragment);
const url = new URL(capturedUrl!);
expect(url.searchParams.get('client_id')).toBe('test-client-id');

View File

@@ -85,13 +85,19 @@ interface PKCEParams {
state: string;
}
const REDIRECT_PORT = 7777;
const REDIRECT_PATH = '/oauth/callback';
const HTTP_OK = 200;
/**
* Provider for handling OAuth authentication for MCP servers.
*/
export class MCPOAuthProvider {
private static readonly REDIRECT_PORT = 7777;
private static readonly REDIRECT_PATH = '/oauth/callback';
private static readonly HTTP_OK = 200;
private readonly tokenStorage: MCPOAuthTokenStorage;
constructor(tokenStorage: MCPOAuthTokenStorage = new MCPOAuthTokenStorage()) {
this.tokenStorage = tokenStorage;
}
/**
* Register a client dynamically with the OAuth server.
@@ -100,13 +106,12 @@ export class MCPOAuthProvider {
* @param config OAuth configuration
* @returns The registered client information
*/
private static async registerClient(
private async registerClient(
registrationUrl: string,
config: MCPOAuthConfig,
): Promise<OAuthClientRegistrationResponse> {
const redirectUri =
config.redirectUri ||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
const registrationRequest: OAuthClientRegistrationRequest = {
client_name: 'Gemini CLI MCP Client',
@@ -142,7 +147,7 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL
* @returns OAuth configuration if discovered, null otherwise
*/
private static async discoverOAuthFromMCPServer(
private async discoverOAuthFromMCPServer(
mcpServerUrl: string,
): Promise<MCPOAuthConfig | null> {
// Use the full URL with path preserved for OAuth discovery
@@ -154,7 +159,7 @@ export class MCPOAuthProvider {
*
* @returns PKCE parameters including code verifier, challenge, and state
*/
private static generatePKCEParams(): PKCEParams {
private generatePKCEParams(): PKCEParams {
// Generate code verifier (43-128 characters)
const codeVerifier = crypto.randomBytes(32).toString('base64url');
@@ -176,19 +181,16 @@ export class MCPOAuthProvider {
* @param expectedState The state parameter to validate
* @returns Promise that resolves with the authorization code
*/
private static async startCallbackServer(
private async startCallbackServer(
expectedState: string,
): Promise<OAuthAuthorizationResponse> {
return new Promise((resolve, reject) => {
const server = http.createServer(
async (req: http.IncomingMessage, res: http.ServerResponse) => {
try {
const url = new URL(
req.url!,
`http://localhost:${this.REDIRECT_PORT}`,
);
const url = new URL(req.url!, `http://localhost:${REDIRECT_PORT}`);
if (url.pathname !== this.REDIRECT_PATH) {
if (url.pathname !== REDIRECT_PATH) {
res.writeHead(404);
res.end('Not found');
return;
@@ -199,7 +201,7 @@ export class MCPOAuthProvider {
const error = url.searchParams.get('error');
if (error) {
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' });
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
@@ -230,7 +232,7 @@ export class MCPOAuthProvider {
}
// Send success response to browser
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' });
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
@@ -251,10 +253,8 @@ export class MCPOAuthProvider {
);
server.on('error', reject);
server.listen(this.REDIRECT_PORT, () => {
console.log(
`OAuth callback server listening on port ${this.REDIRECT_PORT}`,
);
server.listen(REDIRECT_PORT, () => {
console.log(`OAuth callback server listening on port ${REDIRECT_PORT}`);
});
// Timeout after 5 minutes
@@ -276,14 +276,13 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The authorization URL
*/
private static buildAuthorizationUrl(
private buildAuthorizationUrl(
config: MCPOAuthConfig,
pkceParams: PKCEParams,
mcpServerUrl?: string,
): string {
const redirectUri =
config.redirectUri ||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
const params = new URLSearchParams({
client_id: config.clientId!,
@@ -333,15 +332,14 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The token response
*/
private static async exchangeCodeForToken(
private async exchangeCodeForToken(
config: MCPOAuthConfig,
code: string,
codeVerifier: string,
mcpServerUrl?: string,
): Promise<OAuthTokenResponse> {
const redirectUri =
config.redirectUri ||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
const params = new URLSearchParams({
grant_type: 'authorization_code',
@@ -458,7 +456,7 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The new token response
*/
static async refreshAccessToken(
async refreshAccessToken(
config: MCPOAuthConfig,
refreshToken: string,
tokenUrl: string,
@@ -579,7 +577,7 @@ export class MCPOAuthProvider {
* @param mcpServerUrl Optional MCP server URL for OAuth discovery
* @returns The obtained OAuth token
*/
static async authenticate(
async authenticate(
serverName: string,
config: MCPOAuthConfig,
mcpServerUrl?: string,
@@ -790,7 +788,7 @@ export class MCPOAuthProvider {
// Save token
try {
await MCPOAuthTokenStorage.saveToken(
await this.tokenStorage.saveToken(
serverName,
token,
config.clientId,
@@ -800,7 +798,7 @@ export class MCPOAuthProvider {
console.log('Authentication successful! Token saved.');
// Verify token was saved
const savedToken = await MCPOAuthTokenStorage.getToken(serverName);
const savedToken = await this.tokenStorage.getToken(serverName);
if (savedToken && savedToken.token && savedToken.token.accessToken) {
const tokenPreview =
savedToken.token.accessToken.length > 20
@@ -827,12 +825,12 @@ export class MCPOAuthProvider {
* @param config OAuth configuration
* @returns A valid access token or null if not authenticated
*/
static async getValidToken(
async getValidToken(
serverName: string,
config: MCPOAuthConfig,
): Promise<string | null> {
console.debug(`Getting valid token for server: ${serverName}`);
const credentials = await MCPOAuthTokenStorage.getToken(serverName);
const credentials = await this.tokenStorage.getToken(serverName);
if (!credentials) {
console.debug(`No credentials found for server: ${serverName}`);
@@ -841,11 +839,11 @@ export class MCPOAuthProvider {
const { token } = credentials;
console.debug(
`Found token for server: ${serverName}, expired: ${MCPOAuthTokenStorage.isTokenExpired(token)}`,
`Found token for server: ${serverName}, expired: ${this.tokenStorage.isTokenExpired(token)}`,
);
// Check if token is expired
if (!MCPOAuthTokenStorage.isTokenExpired(token)) {
if (!this.tokenStorage.isTokenExpired(token)) {
console.debug(`Returning valid token for server: ${serverName}`);
return token.accessToken;
}
@@ -874,7 +872,7 @@ export class MCPOAuthProvider {
newToken.expiresAt = Date.now() + newTokenResponse.expires_in * 1000;
}
await MCPOAuthTokenStorage.saveToken(
await this.tokenStorage.saveToken(
serverName,
newToken,
config.clientId,
@@ -886,7 +884,7 @@ export class MCPOAuthProvider {
} catch (error) {
console.error(`Failed to refresh token: ${getErrorMessage(error)}`);
// Remove invalid token
await MCPOAuthTokenStorage.removeToken(serverName);
await this.tokenStorage.removeToken(serverName);
}
}

View File

@@ -26,6 +26,8 @@ vi.mock('node:os', () => ({
}));
describe('MCPOAuthTokenStorage', () => {
let tokenStorage: MCPOAuthTokenStorage;
const mockToken: OAuthToken = {
accessToken: 'access_token_123',
refreshToken: 'refresh_token_456',
@@ -43,6 +45,7 @@ describe('MCPOAuthTokenStorage', () => {
};
beforeEach(() => {
tokenStorage = new MCPOAuthTokenStorage();
vi.clearAllMocks();
vi.spyOn(console, 'error').mockImplementation(() => {});
});
@@ -55,7 +58,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should return empty map when token file does not exist', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const tokens = await MCPOAuthTokenStorage.loadTokens();
const tokens = await tokenStorage.loadTokens();
expect(tokens.size).toBe(0);
expect(console.error).not.toHaveBeenCalled();
@@ -65,7 +68,7 @@ describe('MCPOAuthTokenStorage', () => {
const tokensArray = [mockCredentials];
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
const tokens = await MCPOAuthTokenStorage.loadTokens();
const tokens = await tokenStorage.loadTokens();
expect(tokens.size).toBe(1);
expect(tokens.get('test-server')).toEqual(mockCredentials);
@@ -78,7 +81,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should handle corrupted token file gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue('invalid json');
const tokens = await MCPOAuthTokenStorage.loadTokens();
const tokens = await tokenStorage.loadTokens();
expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith(
@@ -90,7 +93,7 @@ describe('MCPOAuthTokenStorage', () => {
const error = new Error('Permission denied');
vi.mocked(fs.readFile).mockRejectedValue(error);
const tokens = await MCPOAuthTokenStorage.loadTokens();
const tokens = await tokenStorage.loadTokens();
expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith(
@@ -105,7 +108,7 @@ describe('MCPOAuthTokenStorage', () => {
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.saveToken(
await tokenStorage.saveToken(
'test-server',
mockToken,
'client-id',
@@ -134,7 +137,7 @@ describe('MCPOAuthTokenStorage', () => {
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
const newToken = { ...mockToken, accessToken: 'new_access_token' };
await MCPOAuthTokenStorage.saveToken('existing-server', newToken);
await tokenStorage.saveToken('existing-server', newToken);
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(writeCall[1] as string);
@@ -151,7 +154,7 @@ describe('MCPOAuthTokenStorage', () => {
vi.mocked(fs.writeFile).mockRejectedValue(writeError);
await expect(
MCPOAuthTokenStorage.saveToken('test-server', mockToken),
tokenStorage.saveToken('test-server', mockToken),
).rejects.toThrow('Disk full');
expect(console.error).toHaveBeenCalledWith(
@@ -166,7 +169,7 @@ describe('MCPOAuthTokenStorage', () => {
JSON.stringify([mockCredentials]),
);
const result = await MCPOAuthTokenStorage.getToken('test-server');
const result = await tokenStorage.getToken('test-server');
expect(result).toEqual(mockCredentials);
});
@@ -176,7 +179,7 @@ describe('MCPOAuthTokenStorage', () => {
JSON.stringify([mockCredentials]),
);
const result = await MCPOAuthTokenStorage.getToken('non-existent');
const result = await tokenStorage.getToken('non-existent');
expect(result).toBeNull();
});
@@ -184,7 +187,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should return null when no tokens file exists', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const result = await MCPOAuthTokenStorage.getToken('test-server');
const result = await tokenStorage.getToken('test-server');
expect(result).toBeNull();
});
@@ -199,7 +202,7 @@ describe('MCPOAuthTokenStorage', () => {
);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.removeToken('server1');
await tokenStorage.removeToken('server1');
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(writeCall[1] as string);
@@ -214,7 +217,7 @@ describe('MCPOAuthTokenStorage', () => {
);
vi.mocked(fs.unlink).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.removeToken('test-server');
await tokenStorage.removeToken('test-server');
expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
@@ -227,7 +230,7 @@ describe('MCPOAuthTokenStorage', () => {
JSON.stringify([mockCredentials]),
);
await MCPOAuthTokenStorage.removeToken('non-existent');
await tokenStorage.removeToken('non-existent');
expect(fs.writeFile).not.toHaveBeenCalled();
expect(fs.unlink).not.toHaveBeenCalled();
@@ -239,7 +242,7 @@ describe('MCPOAuthTokenStorage', () => {
);
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await MCPOAuthTokenStorage.removeToken('test-server');
await tokenStorage.removeToken('test-server');
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to remove MCP OAuth token'),
@@ -252,7 +255,7 @@ describe('MCPOAuthTokenStorage', () => {
const tokenWithoutExpiry = { ...mockToken };
delete tokenWithoutExpiry.expiresAt;
const result = MCPOAuthTokenStorage.isTokenExpired(tokenWithoutExpiry);
const result = tokenStorage.isTokenExpired(tokenWithoutExpiry);
expect(result).toBe(false);
});
@@ -263,7 +266,7 @@ describe('MCPOAuthTokenStorage', () => {
expiresAt: Date.now() + 3600000, // 1 hour from now
};
const result = MCPOAuthTokenStorage.isTokenExpired(futureToken);
const result = tokenStorage.isTokenExpired(futureToken);
expect(result).toBe(false);
});
@@ -274,7 +277,7 @@ describe('MCPOAuthTokenStorage', () => {
expiresAt: Date.now() - 3600000, // 1 hour ago
};
const result = MCPOAuthTokenStorage.isTokenExpired(expiredToken);
const result = tokenStorage.isTokenExpired(expiredToken);
expect(result).toBe(true);
});
@@ -285,7 +288,7 @@ describe('MCPOAuthTokenStorage', () => {
expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer)
};
const result = MCPOAuthTokenStorage.isTokenExpired(soonToExpireToken);
const result = tokenStorage.isTokenExpired(soonToExpireToken);
expect(result).toBe(true);
});
@@ -295,7 +298,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should remove token file successfully', async () => {
vi.mocked(fs.unlink).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.clearAllTokens();
await tokenStorage.clearAllTokens();
expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
@@ -305,7 +308,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should handle non-existent file gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
await MCPOAuthTokenStorage.clearAllTokens();
await tokenStorage.clearAllTokens();
expect(console.error).not.toHaveBeenCalled();
});
@@ -313,7 +316,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should handle other file errors gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await MCPOAuthTokenStorage.clearAllTokens();
await tokenStorage.clearAllTokens();
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to clear MCP OAuth tokens'),

View File

@@ -19,14 +19,14 @@ export class MCPOAuthTokenStorage {
*
* @returns The full path to the token storage file
*/
private static getTokenFilePath(): string {
private getTokenFilePath(): string {
return Storage.getMcpOAuthTokensPath();
}
/**
* Ensure the config directory exists.
*/
private static async ensureConfigDir(): Promise<void> {
private async ensureConfigDir(): Promise<void> {
const configDir = path.dirname(this.getTokenFilePath());
await fs.mkdir(configDir, { recursive: true });
}
@@ -36,7 +36,7 @@ export class MCPOAuthTokenStorage {
*
* @returns A map of server names to credentials
*/
static async loadTokens(): Promise<Map<string, OAuthCredentials>> {
async loadTokens(): Promise<Map<string, OAuthCredentials>> {
const tokenMap = new Map<string, OAuthCredentials>();
try {
@@ -68,7 +68,7 @@ export class MCPOAuthTokenStorage {
* @param tokenUrl Optional token URL used for this token
* @param mcpServerUrl Optional MCP server URL
*/
static async saveToken(
async saveToken(
serverName: string,
token: OAuthToken,
clientId?: string,
@@ -113,7 +113,7 @@ export class MCPOAuthTokenStorage {
* @param serverName The name of the MCP server
* @returns The stored credentials or null if not found
*/
static async getToken(serverName: string): Promise<OAuthCredentials | null> {
async getToken(serverName: string): Promise<OAuthCredentials | null> {
const tokens = await this.loadTokens();
return tokens.get(serverName) || null;
}
@@ -123,7 +123,7 @@ export class MCPOAuthTokenStorage {
*
* @param serverName The name of the MCP server
*/
static async removeToken(serverName: string): Promise<void> {
async removeToken(serverName: string): Promise<void> {
const tokens = await this.loadTokens();
if (tokens.delete(serverName)) {
@@ -153,7 +153,7 @@ export class MCPOAuthTokenStorage {
* @param token The token to check
* @returns True if the token is expired
*/
static isTokenExpired(token: OAuthToken): boolean {
isTokenExpired(token: OAuthToken): boolean {
if (!token.expiresAt) {
return false; // No expiry, assume valid
}
@@ -166,7 +166,7 @@ export class MCPOAuthTokenStorage {
/**
* Clear all stored MCP OAuth tokens.
*/
static async clearAllTokens(): Promise<void> {
async clearAllTokens(): Promise<void> {
try {
const tokenFile = this.getTokenFilePath();
await fs.unlink(tokenFile);

View File

@@ -365,11 +365,8 @@ async function handleAutomaticOAuth(
console.log(
`Starting OAuth authentication for server '${mcpServerName}'...`,
);
await MCPOAuthProvider.authenticate(
mcpServerName,
oauthAuthConfig,
serverUrl,
);
const authProvider = new MCPOAuthProvider(new MCPOAuthTokenStorage());
await authProvider.authenticate(mcpServerName, oauthAuthConfig, serverUrl);
console.log(
`OAuth authentication successful for server '${mcpServerName}'`,
@@ -899,9 +896,11 @@ export async function connectToMcpServer(
if (!shouldTriggerOAuth) {
// For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately.
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName);
const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
if (credentials) {
const hasStoredTokens = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider(tokenStorage);
const hasStoredTokens = await authProvider.getValidToken(
mcpServerName,
{
// Pass client ID if available
@@ -982,10 +981,11 @@ export async function connectToMcpServer(
// Get the valid token - we need to create a proper OAuth config
// The token should already be available from the authentication process
const credentials =
await MCPOAuthTokenStorage.getToken(mcpServerName);
const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
if (credentials) {
const accessToken = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider(tokenStorage);
const accessToken = await authProvider.getValidToken(
mcpServerName,
{
// Pass client ID if available
@@ -1056,10 +1056,11 @@ export async function connectToMcpServer(
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
if (!shouldTryDiscovery) {
const credentials =
await MCPOAuthTokenStorage.getToken(mcpServerName);
const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
if (credentials) {
const hasStoredTokens = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider(tokenStorage);
const hasStoredTokens = await authProvider.getValidToken(
mcpServerName,
{
// Pass client ID if available
@@ -1116,17 +1117,21 @@ export async function connectToMcpServer(
console.log(
`Starting OAuth authentication for server '${mcpServerName}'...`,
);
await MCPOAuthProvider.authenticate(
const authProvider = new MCPOAuthProvider(
new MCPOAuthTokenStorage(),
);
await authProvider.authenticate(
mcpServerName,
oauthAuthConfig,
authServerUrl,
);
// Retry connection with OAuth token
const credentials =
await MCPOAuthTokenStorage.getToken(mcpServerName);
const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
if (credentials) {
const accessToken = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider(tokenStorage);
const accessToken = await authProvider.getValidToken(
mcpServerName,
{
// Pass client ID if available
@@ -1261,7 +1266,9 @@ export async function createTransport(
let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
if (hasOAuthConfig && mcpServerConfig.oauth) {
accessToken = await MCPOAuthProvider.getValidToken(
const tokenStorage = new MCPOAuthTokenStorage();
const authProvider = new MCPOAuthProvider(tokenStorage);
accessToken = await authProvider.getValidToken(
mcpServerName,
mcpServerConfig.oauth,
);
@@ -1278,9 +1285,11 @@ export async function createTransport(
}
} else {
// Check if we have stored OAuth tokens for this server (from previous authentication)
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName);
const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
if (credentials) {
accessToken = await MCPOAuthProvider.getValidToken(mcpServerName, {
const authProvider = new MCPOAuthProvider(tokenStorage);
accessToken = await authProvider.getValidToken(mcpServerName, {
// Pass client ID if available
clientId: credentials.clientId,
});