Feat(security) - Make the OAuthTokenStorage non static (#7716)

Co-authored-by: Shi Shu <shii@google.com>
This commit is contained in:
shishu314
2025-09-04 16:42:47 -04:00
committed by GitHub
parent e088c06a9a
commit 35a841f71a
7 changed files with 188 additions and 149 deletions
@@ -22,17 +22,18 @@ import { Type } from '@google/genai';
vi.mock('@google/gemini-cli-core', async (importOriginal) => { vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actual = const actual =
await importOriginal<typeof import('@google/gemini-cli-core')>(); await importOriginal<typeof import('@google/gemini-cli-core')>();
const mockAuthenticate = vi.fn();
return { return {
...actual, ...actual,
getMCPServerStatus: vi.fn(), getMCPServerStatus: vi.fn(),
getMCPDiscoveryState: vi.fn(), getMCPDiscoveryState: vi.fn(),
MCPOAuthProvider: { MCPOAuthProvider: vi.fn(() => ({
authenticate: vi.fn(), authenticate: mockAuthenticate,
}, })),
MCPOAuthTokenStorage: { MCPOAuthTokenStorage: vi.fn(() => ({
getToken: vi.fn(), getToken: vi.fn(),
isTokenExpired: vi.fn(), isTokenExpired: vi.fn(),
}, })),
}; };
}); });
@@ -892,13 +893,14 @@ describe('mcpCommand', () => {
context.ui.reloadCommands = vi.fn(); context.ui.reloadCommands = vi.fn();
const { MCPOAuthProvider } = await import('@google/gemini-cli-core'); const { MCPOAuthProvider } = await import('@google/gemini-cli-core');
const mockAuthProvider = new MCPOAuthProvider();
const authCommand = mcpCommand.subCommands?.find( const authCommand = mcpCommand.subCommands?.find(
(cmd) => cmd.name === 'auth', (cmd) => cmd.name === 'auth',
); );
const result = await authCommand!.action!(context, 'test-server'); const result = await authCommand!.action!(context, 'test-server');
expect(MCPOAuthProvider.authenticate).toHaveBeenCalledWith( expect(mockAuthProvider.authenticate).toHaveBeenCalledWith(
'test-server', 'test-server',
{ enabled: true }, { enabled: true },
'http://localhost:3000', 'http://localhost:3000',
@@ -928,9 +930,10 @@ describe('mcpCommand', () => {
}); });
const { MCPOAuthProvider } = await import('@google/gemini-cli-core'); const { MCPOAuthProvider } = await import('@google/gemini-cli-core');
( const mockAuthProvider = new MCPOAuthProvider();
MCPOAuthProvider.authenticate as ReturnType<typeof vi.fn> vi.mocked(mockAuthProvider.authenticate).mockRejectedValue(
).mockRejectedValue(new Error('Auth failed')); new Error('Auth failed'),
);
const authCommand = mcpCommand.subCommands?.find( const authCommand = mcpCommand.subCommands?.find(
(cmd) => cmd.name === 'auth', (cmd) => cmd.name === 'auth',
+6 -7
View File
@@ -20,6 +20,7 @@ import {
MCPServerStatus, MCPServerStatus,
mcpServerRequiresOAuth, mcpServerRequiresOAuth,
getErrorMessage, getErrorMessage,
MCPOAuthTokenStorage,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
const COLOR_GREEN = '\u001b[32m'; const COLOR_GREEN = '\u001b[32m';
@@ -141,9 +142,10 @@ const getMcpStatus = async (
const { MCPOAuthTokenStorage } = await import( const { MCPOAuthTokenStorage } = await import(
'@google/gemini-cli-core' '@google/gemini-cli-core'
); );
const hasToken = await MCPOAuthTokenStorage.getToken(serverName); const tokenStorage = new MCPOAuthTokenStorage();
const hasToken = await tokenStorage.getToken(serverName);
if (hasToken) { if (hasToken) {
const isExpired = MCPOAuthTokenStorage.isTokenExpired(hasToken.token); const isExpired = tokenStorage.isTokenExpired(hasToken.token);
if (isExpired) { if (isExpired) {
message += ` ${COLOR_YELLOW}(OAuth token expired)${RESET_COLOR}`; message += ` ${COLOR_YELLOW}(OAuth token expired)${RESET_COLOR}`;
} else { } else {
@@ -385,11 +387,8 @@ const authCommand: SlashCommand = {
// Pass the MCP server URL for OAuth discovery // Pass the MCP server URL for OAuth discovery
const mcpServerUrl = server.httpUrl || server.url; const mcpServerUrl = server.httpUrl || server.url;
await MCPOAuthProvider.authenticate( const authProvider = new MCPOAuthProvider(new MCPOAuthTokenStorage());
serverName, await authProvider.authenticate(serverName, oauthConfig, mcpServerUrl);
oauthConfig,
mcpServerUrl,
);
context.ui.addItem( context.ui.addItem(
{ {
+75 -48
View File
@@ -12,7 +12,21 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({
openBrowserSecurely: mockOpenBrowserSecurely, openBrowserSecurely: mockOpenBrowserSecurely,
})); }));
vi.mock('node:crypto'); vi.mock('node:crypto');
vi.mock('./oauth-token-storage.js'); vi.mock('./oauth-token-storage.js', () => {
const mockSaveToken = vi.fn();
const mockGetToken = vi.fn();
const mockIsTokenExpired = vi.fn();
const mockRemoveToken = vi.fn();
return {
MCPOAuthTokenStorage: vi.fn(() => ({
saveToken: mockSaveToken,
getToken: mockGetToken,
isTokenExpired: mockIsTokenExpired,
removeToken: mockRemoveToken,
})),
};
});
import { describe, it, expect, beforeEach, afterEach } from 'vitest'; import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import * as http from 'node:http'; import * as http from 'node:http';
@@ -147,8 +161,9 @@ describe('MCPOAuthProvider', () => {
}); });
// Mock token storage // Mock token storage
vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined); const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null); vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined);
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
}); });
afterEach(() => { afterEach(() => {
@@ -192,10 +207,8 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
const result = await MCPOAuthProvider.authenticate( const authProvider = new MCPOAuthProvider();
'test-server', const result = await authProvider.authenticate('test-server', mockConfig);
mockConfig,
);
expect(result).toEqual({ expect(result).toEqual({
accessToken: 'access_token_123', accessToken: 'access_token_123',
@@ -208,7 +221,8 @@ describe('MCPOAuthProvider', () => {
expect(mockOpenBrowserSecurely).toHaveBeenCalledWith( expect(mockOpenBrowserSecurely).toHaveBeenCalledWith(
expect.stringContaining('authorize'), expect.stringContaining('authorize'),
); );
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith( const tokenStorage = new MCPOAuthTokenStorage();
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
'test-server', 'test-server',
expect.objectContaining({ accessToken: 'access_token_123' }), expect.objectContaining({ accessToken: 'access_token_123' }),
'test-client-id', 'test-client-id',
@@ -296,7 +310,8 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
const result = await MCPOAuthProvider.authenticate( const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate(
'test-server', 'test-server',
configWithoutAuth, configWithoutAuth,
'https://api.example.com', 'https://api.example.com',
@@ -385,7 +400,8 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
const result = await MCPOAuthProvider.authenticate( const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate(
'test-server', 'test-server',
configWithoutClient, configWithoutClient,
); );
@@ -424,8 +440,9 @@ describe('MCPOAuthProvider', () => {
}, 10); }, 10);
}); });
const authProvider = new MCPOAuthProvider();
await expect( await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig), authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('OAuth error: access_denied'); ).rejects.toThrow('OAuth error: access_denied');
}); });
@@ -453,8 +470,9 @@ describe('MCPOAuthProvider', () => {
}, 10); }, 10);
}); });
const authProvider = new MCPOAuthProvider();
await expect( await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig), authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('State mismatch - possible CSRF attack'); ).rejects.toThrow('State mismatch - possible CSRF attack');
}); });
@@ -491,8 +509,9 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
const authProvider = new MCPOAuthProvider();
await expect( await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig), authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant'); ).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant');
}); });
@@ -516,8 +535,9 @@ describe('MCPOAuthProvider', () => {
return originalSetTimeout(callback, 0); return originalSetTimeout(callback, 0);
}) as unknown as typeof setTimeout; }) as unknown as typeof setTimeout;
const authProvider = new MCPOAuthProvider();
await expect( await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig), authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('OAuth callback timeout'); ).rejects.toThrow('OAuth callback timeout');
global.setTimeout = originalSetTimeout; global.setTimeout = originalSetTimeout;
@@ -542,7 +562,8 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
const result = await MCPOAuthProvider.refreshAccessToken( const authProvider = new MCPOAuthProvider();
const result = await authProvider.refreshAccessToken(
mockConfig, mockConfig,
'old_refresh_token', 'old_refresh_token',
'https://auth.example.com/token', 'https://auth.example.com/token',
@@ -572,7 +593,8 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
await MCPOAuthProvider.refreshAccessToken( const authProvider = new MCPOAuthProvider();
await authProvider.refreshAccessToken(
mockConfig, mockConfig,
'refresh_token', 'refresh_token',
'https://auth.example.com/token', 'https://auth.example.com/token',
@@ -592,8 +614,9 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
const authProvider = new MCPOAuthProvider();
await expect( await expect(
MCPOAuthProvider.refreshAccessToken( authProvider.refreshAccessToken(
mockConfig, mockConfig,
'invalid_refresh_token', 'invalid_refresh_token',
'https://auth.example.com/token', 'https://auth.example.com/token',
@@ -614,12 +637,12 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(), updatedAt: Date.now(),
}; };
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( const tokenStorage = new MCPOAuthTokenStorage();
validCredentials, vi.mocked(tokenStorage.getToken).mockResolvedValue(validCredentials);
); vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false);
const result = await MCPOAuthProvider.getValidToken( const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server', 'test-server',
mockConfig, mockConfig,
); );
@@ -636,10 +659,9 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(), updatedAt: Date.now(),
}; };
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( const tokenStorage = new MCPOAuthTokenStorage();
expiredCredentials, vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
); vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
const refreshResponse = { const refreshResponse = {
access_token: 'new_access_token', access_token: 'new_access_token',
@@ -657,13 +679,14 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
const result = await MCPOAuthProvider.getValidToken( const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server', 'test-server',
mockConfig, mockConfig,
); );
expect(result).toBe('new_access_token'); expect(result).toBe('new_access_token');
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith( expect(tokenStorage.saveToken).toHaveBeenCalledWith(
'test-server', 'test-server',
expect.objectContaining({ accessToken: 'new_access_token' }), expect.objectContaining({ accessToken: 'new_access_token' }),
'test-client-id', 'test-client-id',
@@ -673,9 +696,11 @@ describe('MCPOAuthProvider', () => {
}); });
it('should return null when no credentials exist', async () => { it('should return null when no credentials exist', async () => {
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null); const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
const result = await MCPOAuthProvider.getValidToken( const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server', 'test-server',
mockConfig, mockConfig,
); );
@@ -692,11 +717,10 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(), updatedAt: Date.now(),
}; };
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( const tokenStorage = new MCPOAuthTokenStorage();
expiredCredentials, vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
); vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true); vi.mocked(tokenStorage.removeToken).mockResolvedValue(undefined);
vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined);
mockFetch.mockResolvedValueOnce( mockFetch.mockResolvedValueOnce(
createMockResponse({ createMockResponse({
@@ -707,15 +731,14 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
const result = await MCPOAuthProvider.getValidToken( const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server', 'test-server',
mockConfig, mockConfig,
); );
expect(result).toBeNull(); expect(result).toBeNull();
expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith( expect(tokenStorage.removeToken).toHaveBeenCalledWith('test-server');
'test-server',
);
expect(console.error).toHaveBeenCalledWith( expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to refresh token'), expect.stringContaining('Failed to refresh token'),
); );
@@ -734,12 +757,12 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(), updatedAt: Date.now(),
}; };
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( const tokenStorage = new MCPOAuthTokenStorage();
tokenWithoutRefresh, vi.mocked(tokenStorage.getToken).mockResolvedValue(tokenWithoutRefresh);
); vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
const result = await MCPOAuthProvider.getValidToken( const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server', 'test-server',
mockConfig, mockConfig,
); );
@@ -784,7 +807,8 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
await MCPOAuthProvider.authenticate('test-server', mockConfig); const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', mockConfig);
expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier
expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state
@@ -833,7 +857,8 @@ describe('MCPOAuthProvider', () => {
}), }),
); );
await MCPOAuthProvider.authenticate( const authProvider = new MCPOAuthProvider();
await authProvider.authenticate(
'test-server', 'test-server',
mockConfig, mockConfig,
'https://auth.example.com', 'https://auth.example.com',
@@ -894,7 +919,8 @@ describe('MCPOAuthProvider', () => {
authorizationUrl: 'https://auth.example.com/authorize?audience=1234', authorizationUrl: 'https://auth.example.com/authorize?audience=1234',
}; };
await MCPOAuthProvider.authenticate('test-server', configWithParamsInUrl); const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', configWithParamsInUrl);
const url = new URL(capturedUrl!); const url = new URL(capturedUrl!);
expect(url.searchParams.get('audience')).toBe('1234'); expect(url.searchParams.get('audience')).toBe('1234');
@@ -947,7 +973,8 @@ describe('MCPOAuthProvider', () => {
authorizationUrl: 'https://auth.example.com/authorize#login', authorizationUrl: 'https://auth.example.com/authorize#login',
}; };
await MCPOAuthProvider.authenticate('test-server', configWithFragment); const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', configWithFragment);
const url = new URL(capturedUrl!); const url = new URL(capturedUrl!);
expect(url.searchParams.get('client_id')).toBe('test-client-id'); expect(url.searchParams.get('client_id')).toBe('test-client-id');
+34 -36
View File
@@ -85,13 +85,19 @@ interface PKCEParams {
state: string; state: string;
} }
const REDIRECT_PORT = 7777;
const REDIRECT_PATH = '/oauth/callback';
const HTTP_OK = 200;
/** /**
* Provider for handling OAuth authentication for MCP servers. * Provider for handling OAuth authentication for MCP servers.
*/ */
export class MCPOAuthProvider { export class MCPOAuthProvider {
private static readonly REDIRECT_PORT = 7777; private readonly tokenStorage: MCPOAuthTokenStorage;
private static readonly REDIRECT_PATH = '/oauth/callback';
private static readonly HTTP_OK = 200; constructor(tokenStorage: MCPOAuthTokenStorage = new MCPOAuthTokenStorage()) {
this.tokenStorage = tokenStorage;
}
/** /**
* Register a client dynamically with the OAuth server. * Register a client dynamically with the OAuth server.
@@ -100,13 +106,12 @@ export class MCPOAuthProvider {
* @param config OAuth configuration * @param config OAuth configuration
* @returns The registered client information * @returns The registered client information
*/ */
private static async registerClient( private async registerClient(
registrationUrl: string, registrationUrl: string,
config: MCPOAuthConfig, config: MCPOAuthConfig,
): Promise<OAuthClientRegistrationResponse> { ): Promise<OAuthClientRegistrationResponse> {
const redirectUri = const redirectUri =
config.redirectUri || config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
const registrationRequest: OAuthClientRegistrationRequest = { const registrationRequest: OAuthClientRegistrationRequest = {
client_name: 'Gemini CLI MCP Client', client_name: 'Gemini CLI MCP Client',
@@ -142,7 +147,7 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL * @param mcpServerUrl The MCP server URL
* @returns OAuth configuration if discovered, null otherwise * @returns OAuth configuration if discovered, null otherwise
*/ */
private static async discoverOAuthFromMCPServer( private async discoverOAuthFromMCPServer(
mcpServerUrl: string, mcpServerUrl: string,
): Promise<MCPOAuthConfig | null> { ): Promise<MCPOAuthConfig | null> {
// Use the full URL with path preserved for OAuth discovery // Use the full URL with path preserved for OAuth discovery
@@ -154,7 +159,7 @@ export class MCPOAuthProvider {
* *
* @returns PKCE parameters including code verifier, challenge, and state * @returns PKCE parameters including code verifier, challenge, and state
*/ */
private static generatePKCEParams(): PKCEParams { private generatePKCEParams(): PKCEParams {
// Generate code verifier (43-128 characters) // Generate code verifier (43-128 characters)
const codeVerifier = crypto.randomBytes(32).toString('base64url'); const codeVerifier = crypto.randomBytes(32).toString('base64url');
@@ -176,19 +181,16 @@ export class MCPOAuthProvider {
* @param expectedState The state parameter to validate * @param expectedState The state parameter to validate
* @returns Promise that resolves with the authorization code * @returns Promise that resolves with the authorization code
*/ */
private static async startCallbackServer( private async startCallbackServer(
expectedState: string, expectedState: string,
): Promise<OAuthAuthorizationResponse> { ): Promise<OAuthAuthorizationResponse> {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
const server = http.createServer( const server = http.createServer(
async (req: http.IncomingMessage, res: http.ServerResponse) => { async (req: http.IncomingMessage, res: http.ServerResponse) => {
try { try {
const url = new URL( const url = new URL(req.url!, `http://localhost:${REDIRECT_PORT}`);
req.url!,
`http://localhost:${this.REDIRECT_PORT}`,
);
if (url.pathname !== this.REDIRECT_PATH) { if (url.pathname !== REDIRECT_PATH) {
res.writeHead(404); res.writeHead(404);
res.end('Not found'); res.end('Not found');
return; return;
@@ -199,7 +201,7 @@ export class MCPOAuthProvider {
const error = url.searchParams.get('error'); const error = url.searchParams.get('error');
if (error) { if (error) {
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' }); res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(` res.end(`
<html> <html>
<body> <body>
@@ -230,7 +232,7 @@ export class MCPOAuthProvider {
} }
// Send success response to browser // Send success response to browser
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' }); res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(` res.end(`
<html> <html>
<body> <body>
@@ -251,10 +253,8 @@ export class MCPOAuthProvider {
); );
server.on('error', reject); server.on('error', reject);
server.listen(this.REDIRECT_PORT, () => { server.listen(REDIRECT_PORT, () => {
console.log( console.log(`OAuth callback server listening on port ${REDIRECT_PORT}`);
`OAuth callback server listening on port ${this.REDIRECT_PORT}`,
);
}); });
// Timeout after 5 minutes // Timeout after 5 minutes
@@ -276,14 +276,13 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL to use as the resource parameter * @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The authorization URL * @returns The authorization URL
*/ */
private static buildAuthorizationUrl( private buildAuthorizationUrl(
config: MCPOAuthConfig, config: MCPOAuthConfig,
pkceParams: PKCEParams, pkceParams: PKCEParams,
mcpServerUrl?: string, mcpServerUrl?: string,
): string { ): string {
const redirectUri = const redirectUri =
config.redirectUri || config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
const params = new URLSearchParams({ const params = new URLSearchParams({
client_id: config.clientId!, client_id: config.clientId!,
@@ -333,15 +332,14 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL to use as the resource parameter * @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The token response * @returns The token response
*/ */
private static async exchangeCodeForToken( private async exchangeCodeForToken(
config: MCPOAuthConfig, config: MCPOAuthConfig,
code: string, code: string,
codeVerifier: string, codeVerifier: string,
mcpServerUrl?: string, mcpServerUrl?: string,
): Promise<OAuthTokenResponse> { ): Promise<OAuthTokenResponse> {
const redirectUri = const redirectUri =
config.redirectUri || config.redirectUri || `http://localhost:${REDIRECT_PORT}${REDIRECT_PATH}`;
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
const params = new URLSearchParams({ const params = new URLSearchParams({
grant_type: 'authorization_code', grant_type: 'authorization_code',
@@ -458,7 +456,7 @@ export class MCPOAuthProvider {
* @param mcpServerUrl The MCP server URL to use as the resource parameter * @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The new token response * @returns The new token response
*/ */
static async refreshAccessToken( async refreshAccessToken(
config: MCPOAuthConfig, config: MCPOAuthConfig,
refreshToken: string, refreshToken: string,
tokenUrl: string, tokenUrl: string,
@@ -579,7 +577,7 @@ export class MCPOAuthProvider {
* @param mcpServerUrl Optional MCP server URL for OAuth discovery * @param mcpServerUrl Optional MCP server URL for OAuth discovery
* @returns The obtained OAuth token * @returns The obtained OAuth token
*/ */
static async authenticate( async authenticate(
serverName: string, serverName: string,
config: MCPOAuthConfig, config: MCPOAuthConfig,
mcpServerUrl?: string, mcpServerUrl?: string,
@@ -790,7 +788,7 @@ export class MCPOAuthProvider {
// Save token // Save token
try { try {
await MCPOAuthTokenStorage.saveToken( await this.tokenStorage.saveToken(
serverName, serverName,
token, token,
config.clientId, config.clientId,
@@ -800,7 +798,7 @@ export class MCPOAuthProvider {
console.log('Authentication successful! Token saved.'); console.log('Authentication successful! Token saved.');
// Verify token was saved // Verify token was saved
const savedToken = await MCPOAuthTokenStorage.getToken(serverName); const savedToken = await this.tokenStorage.getToken(serverName);
if (savedToken && savedToken.token && savedToken.token.accessToken) { if (savedToken && savedToken.token && savedToken.token.accessToken) {
const tokenPreview = const tokenPreview =
savedToken.token.accessToken.length > 20 savedToken.token.accessToken.length > 20
@@ -827,12 +825,12 @@ export class MCPOAuthProvider {
* @param config OAuth configuration * @param config OAuth configuration
* @returns A valid access token or null if not authenticated * @returns A valid access token or null if not authenticated
*/ */
static async getValidToken( async getValidToken(
serverName: string, serverName: string,
config: MCPOAuthConfig, config: MCPOAuthConfig,
): Promise<string | null> { ): Promise<string | null> {
console.debug(`Getting valid token for server: ${serverName}`); console.debug(`Getting valid token for server: ${serverName}`);
const credentials = await MCPOAuthTokenStorage.getToken(serverName); const credentials = await this.tokenStorage.getToken(serverName);
if (!credentials) { if (!credentials) {
console.debug(`No credentials found for server: ${serverName}`); console.debug(`No credentials found for server: ${serverName}`);
@@ -841,11 +839,11 @@ export class MCPOAuthProvider {
const { token } = credentials; const { token } = credentials;
console.debug( console.debug(
`Found token for server: ${serverName}, expired: ${MCPOAuthTokenStorage.isTokenExpired(token)}`, `Found token for server: ${serverName}, expired: ${this.tokenStorage.isTokenExpired(token)}`,
); );
// Check if token is expired // Check if token is expired
if (!MCPOAuthTokenStorage.isTokenExpired(token)) { if (!this.tokenStorage.isTokenExpired(token)) {
console.debug(`Returning valid token for server: ${serverName}`); console.debug(`Returning valid token for server: ${serverName}`);
return token.accessToken; return token.accessToken;
} }
@@ -874,7 +872,7 @@ export class MCPOAuthProvider {
newToken.expiresAt = Date.now() + newTokenResponse.expires_in * 1000; newToken.expiresAt = Date.now() + newTokenResponse.expires_in * 1000;
} }
await MCPOAuthTokenStorage.saveToken( await this.tokenStorage.saveToken(
serverName, serverName,
newToken, newToken,
config.clientId, config.clientId,
@@ -886,7 +884,7 @@ export class MCPOAuthProvider {
} catch (error) { } catch (error) {
console.error(`Failed to refresh token: ${getErrorMessage(error)}`); console.error(`Failed to refresh token: ${getErrorMessage(error)}`);
// Remove invalid token // Remove invalid token
await MCPOAuthTokenStorage.removeToken(serverName); await this.tokenStorage.removeToken(serverName);
} }
} }
@@ -26,6 +26,8 @@ vi.mock('node:os', () => ({
})); }));
describe('MCPOAuthTokenStorage', () => { describe('MCPOAuthTokenStorage', () => {
let tokenStorage: MCPOAuthTokenStorage;
const mockToken: OAuthToken = { const mockToken: OAuthToken = {
accessToken: 'access_token_123', accessToken: 'access_token_123',
refreshToken: 'refresh_token_456', refreshToken: 'refresh_token_456',
@@ -43,6 +45,7 @@ describe('MCPOAuthTokenStorage', () => {
}; };
beforeEach(() => { beforeEach(() => {
tokenStorage = new MCPOAuthTokenStorage();
vi.clearAllMocks(); vi.clearAllMocks();
vi.spyOn(console, 'error').mockImplementation(() => {}); vi.spyOn(console, 'error').mockImplementation(() => {});
}); });
@@ -55,7 +58,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should return empty map when token file does not exist', async () => { it('should return empty map when token file does not exist', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const tokens = await MCPOAuthTokenStorage.loadTokens(); const tokens = await tokenStorage.loadTokens();
expect(tokens.size).toBe(0); expect(tokens.size).toBe(0);
expect(console.error).not.toHaveBeenCalled(); expect(console.error).not.toHaveBeenCalled();
@@ -65,7 +68,7 @@ describe('MCPOAuthTokenStorage', () => {
const tokensArray = [mockCredentials]; const tokensArray = [mockCredentials];
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray)); vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
const tokens = await MCPOAuthTokenStorage.loadTokens(); const tokens = await tokenStorage.loadTokens();
expect(tokens.size).toBe(1); expect(tokens.size).toBe(1);
expect(tokens.get('test-server')).toEqual(mockCredentials); expect(tokens.get('test-server')).toEqual(mockCredentials);
@@ -78,7 +81,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should handle corrupted token file gracefully', async () => { it('should handle corrupted token file gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue('invalid json'); vi.mocked(fs.readFile).mockResolvedValue('invalid json');
const tokens = await MCPOAuthTokenStorage.loadTokens(); const tokens = await tokenStorage.loadTokens();
expect(tokens.size).toBe(0); expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith( expect(console.error).toHaveBeenCalledWith(
@@ -90,7 +93,7 @@ describe('MCPOAuthTokenStorage', () => {
const error = new Error('Permission denied'); const error = new Error('Permission denied');
vi.mocked(fs.readFile).mockRejectedValue(error); vi.mocked(fs.readFile).mockRejectedValue(error);
const tokens = await MCPOAuthTokenStorage.loadTokens(); const tokens = await tokenStorage.loadTokens();
expect(tokens.size).toBe(0); expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith( expect(console.error).toHaveBeenCalledWith(
@@ -105,7 +108,7 @@ describe('MCPOAuthTokenStorage', () => {
vi.mocked(fs.mkdir).mockResolvedValue(undefined); vi.mocked(fs.mkdir).mockResolvedValue(undefined);
vi.mocked(fs.writeFile).mockResolvedValue(undefined); vi.mocked(fs.writeFile).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.saveToken( await tokenStorage.saveToken(
'test-server', 'test-server',
mockToken, mockToken,
'client-id', 'client-id',
@@ -134,7 +137,7 @@ describe('MCPOAuthTokenStorage', () => {
vi.mocked(fs.writeFile).mockResolvedValue(undefined); vi.mocked(fs.writeFile).mockResolvedValue(undefined);
const newToken = { ...mockToken, accessToken: 'new_access_token' }; const newToken = { ...mockToken, accessToken: 'new_access_token' };
await MCPOAuthTokenStorage.saveToken('existing-server', newToken); await tokenStorage.saveToken('existing-server', newToken);
const writeCall = vi.mocked(fs.writeFile).mock.calls[0]; const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(writeCall[1] as string); const savedData = JSON.parse(writeCall[1] as string);
@@ -151,7 +154,7 @@ describe('MCPOAuthTokenStorage', () => {
vi.mocked(fs.writeFile).mockRejectedValue(writeError); vi.mocked(fs.writeFile).mockRejectedValue(writeError);
await expect( await expect(
MCPOAuthTokenStorage.saveToken('test-server', mockToken), tokenStorage.saveToken('test-server', mockToken),
).rejects.toThrow('Disk full'); ).rejects.toThrow('Disk full');
expect(console.error).toHaveBeenCalledWith( expect(console.error).toHaveBeenCalledWith(
@@ -166,7 +169,7 @@ describe('MCPOAuthTokenStorage', () => {
JSON.stringify([mockCredentials]), JSON.stringify([mockCredentials]),
); );
const result = await MCPOAuthTokenStorage.getToken('test-server'); const result = await tokenStorage.getToken('test-server');
expect(result).toEqual(mockCredentials); expect(result).toEqual(mockCredentials);
}); });
@@ -176,7 +179,7 @@ describe('MCPOAuthTokenStorage', () => {
JSON.stringify([mockCredentials]), JSON.stringify([mockCredentials]),
); );
const result = await MCPOAuthTokenStorage.getToken('non-existent'); const result = await tokenStorage.getToken('non-existent');
expect(result).toBeNull(); expect(result).toBeNull();
}); });
@@ -184,7 +187,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should return null when no tokens file exists', async () => { it('should return null when no tokens file exists', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const result = await MCPOAuthTokenStorage.getToken('test-server'); const result = await tokenStorage.getToken('test-server');
expect(result).toBeNull(); expect(result).toBeNull();
}); });
@@ -199,7 +202,7 @@ describe('MCPOAuthTokenStorage', () => {
); );
vi.mocked(fs.writeFile).mockResolvedValue(undefined); vi.mocked(fs.writeFile).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.removeToken('server1'); await tokenStorage.removeToken('server1');
const writeCall = vi.mocked(fs.writeFile).mock.calls[0]; const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(writeCall[1] as string); const savedData = JSON.parse(writeCall[1] as string);
@@ -214,7 +217,7 @@ describe('MCPOAuthTokenStorage', () => {
); );
vi.mocked(fs.unlink).mockResolvedValue(undefined); vi.mocked(fs.unlink).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.removeToken('test-server'); await tokenStorage.removeToken('test-server');
expect(fs.unlink).toHaveBeenCalledWith( expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'), path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
@@ -227,7 +230,7 @@ describe('MCPOAuthTokenStorage', () => {
JSON.stringify([mockCredentials]), JSON.stringify([mockCredentials]),
); );
await MCPOAuthTokenStorage.removeToken('non-existent'); await tokenStorage.removeToken('non-existent');
expect(fs.writeFile).not.toHaveBeenCalled(); expect(fs.writeFile).not.toHaveBeenCalled();
expect(fs.unlink).not.toHaveBeenCalled(); expect(fs.unlink).not.toHaveBeenCalled();
@@ -239,7 +242,7 @@ describe('MCPOAuthTokenStorage', () => {
); );
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied')); vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await MCPOAuthTokenStorage.removeToken('test-server'); await tokenStorage.removeToken('test-server');
expect(console.error).toHaveBeenCalledWith( expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to remove MCP OAuth token'), expect.stringContaining('Failed to remove MCP OAuth token'),
@@ -252,7 +255,7 @@ describe('MCPOAuthTokenStorage', () => {
const tokenWithoutExpiry = { ...mockToken }; const tokenWithoutExpiry = { ...mockToken };
delete tokenWithoutExpiry.expiresAt; delete tokenWithoutExpiry.expiresAt;
const result = MCPOAuthTokenStorage.isTokenExpired(tokenWithoutExpiry); const result = tokenStorage.isTokenExpired(tokenWithoutExpiry);
expect(result).toBe(false); expect(result).toBe(false);
}); });
@@ -263,7 +266,7 @@ describe('MCPOAuthTokenStorage', () => {
expiresAt: Date.now() + 3600000, // 1 hour from now expiresAt: Date.now() + 3600000, // 1 hour from now
}; };
const result = MCPOAuthTokenStorage.isTokenExpired(futureToken); const result = tokenStorage.isTokenExpired(futureToken);
expect(result).toBe(false); expect(result).toBe(false);
}); });
@@ -274,7 +277,7 @@ describe('MCPOAuthTokenStorage', () => {
expiresAt: Date.now() - 3600000, // 1 hour ago expiresAt: Date.now() - 3600000, // 1 hour ago
}; };
const result = MCPOAuthTokenStorage.isTokenExpired(expiredToken); const result = tokenStorage.isTokenExpired(expiredToken);
expect(result).toBe(true); expect(result).toBe(true);
}); });
@@ -285,7 +288,7 @@ describe('MCPOAuthTokenStorage', () => {
expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer) expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer)
}; };
const result = MCPOAuthTokenStorage.isTokenExpired(soonToExpireToken); const result = tokenStorage.isTokenExpired(soonToExpireToken);
expect(result).toBe(true); expect(result).toBe(true);
}); });
@@ -295,7 +298,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should remove token file successfully', async () => { it('should remove token file successfully', async () => {
vi.mocked(fs.unlink).mockResolvedValue(undefined); vi.mocked(fs.unlink).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.clearAllTokens(); await tokenStorage.clearAllTokens();
expect(fs.unlink).toHaveBeenCalledWith( expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'), path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
@@ -305,7 +308,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should handle non-existent file gracefully', async () => { it('should handle non-existent file gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' }); vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
await MCPOAuthTokenStorage.clearAllTokens(); await tokenStorage.clearAllTokens();
expect(console.error).not.toHaveBeenCalled(); expect(console.error).not.toHaveBeenCalled();
}); });
@@ -313,7 +316,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should handle other file errors gracefully', async () => { it('should handle other file errors gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied')); vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await MCPOAuthTokenStorage.clearAllTokens(); await tokenStorage.clearAllTokens();
expect(console.error).toHaveBeenCalledWith( expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to clear MCP OAuth tokens'), expect.stringContaining('Failed to clear MCP OAuth tokens'),
+8 -8
View File
@@ -19,14 +19,14 @@ export class MCPOAuthTokenStorage {
* *
* @returns The full path to the token storage file * @returns The full path to the token storage file
*/ */
private static getTokenFilePath(): string { private getTokenFilePath(): string {
return Storage.getMcpOAuthTokensPath(); return Storage.getMcpOAuthTokensPath();
} }
/** /**
* Ensure the config directory exists. * Ensure the config directory exists.
*/ */
private static async ensureConfigDir(): Promise<void> { private async ensureConfigDir(): Promise<void> {
const configDir = path.dirname(this.getTokenFilePath()); const configDir = path.dirname(this.getTokenFilePath());
await fs.mkdir(configDir, { recursive: true }); await fs.mkdir(configDir, { recursive: true });
} }
@@ -36,7 +36,7 @@ export class MCPOAuthTokenStorage {
* *
* @returns A map of server names to credentials * @returns A map of server names to credentials
*/ */
static async loadTokens(): Promise<Map<string, OAuthCredentials>> { async loadTokens(): Promise<Map<string, OAuthCredentials>> {
const tokenMap = new Map<string, OAuthCredentials>(); const tokenMap = new Map<string, OAuthCredentials>();
try { try {
@@ -68,7 +68,7 @@ export class MCPOAuthTokenStorage {
* @param tokenUrl Optional token URL used for this token * @param tokenUrl Optional token URL used for this token
* @param mcpServerUrl Optional MCP server URL * @param mcpServerUrl Optional MCP server URL
*/ */
static async saveToken( async saveToken(
serverName: string, serverName: string,
token: OAuthToken, token: OAuthToken,
clientId?: string, clientId?: string,
@@ -113,7 +113,7 @@ export class MCPOAuthTokenStorage {
* @param serverName The name of the MCP server * @param serverName The name of the MCP server
* @returns The stored credentials or null if not found * @returns The stored credentials or null if not found
*/ */
static async getToken(serverName: string): Promise<OAuthCredentials | null> { async getToken(serverName: string): Promise<OAuthCredentials | null> {
const tokens = await this.loadTokens(); const tokens = await this.loadTokens();
return tokens.get(serverName) || null; return tokens.get(serverName) || null;
} }
@@ -123,7 +123,7 @@ export class MCPOAuthTokenStorage {
* *
* @param serverName The name of the MCP server * @param serverName The name of the MCP server
*/ */
static async removeToken(serverName: string): Promise<void> { async removeToken(serverName: string): Promise<void> {
const tokens = await this.loadTokens(); const tokens = await this.loadTokens();
if (tokens.delete(serverName)) { if (tokens.delete(serverName)) {
@@ -153,7 +153,7 @@ export class MCPOAuthTokenStorage {
* @param token The token to check * @param token The token to check
* @returns True if the token is expired * @returns True if the token is expired
*/ */
static isTokenExpired(token: OAuthToken): boolean { isTokenExpired(token: OAuthToken): boolean {
if (!token.expiresAt) { if (!token.expiresAt) {
return false; // No expiry, assume valid return false; // No expiry, assume valid
} }
@@ -166,7 +166,7 @@ export class MCPOAuthTokenStorage {
/** /**
* Clear all stored MCP OAuth tokens. * Clear all stored MCP OAuth tokens.
*/ */
static async clearAllTokens(): Promise<void> { async clearAllTokens(): Promise<void> {
try { try {
const tokenFile = this.getTokenFilePath(); const tokenFile = this.getTokenFilePath();
await fs.unlink(tokenFile); await fs.unlink(tokenFile);
+29 -20
View File
@@ -365,11 +365,8 @@ async function handleAutomaticOAuth(
console.log( console.log(
`Starting OAuth authentication for server '${mcpServerName}'...`, `Starting OAuth authentication for server '${mcpServerName}'...`,
); );
await MCPOAuthProvider.authenticate( const authProvider = new MCPOAuthProvider(new MCPOAuthTokenStorage());
mcpServerName, await authProvider.authenticate(mcpServerName, oauthAuthConfig, serverUrl);
oauthAuthConfig,
serverUrl,
);
console.log( console.log(
`OAuth authentication successful for server '${mcpServerName}'`, `OAuth authentication successful for server '${mcpServerName}'`,
@@ -899,9 +896,11 @@ export async function connectToMcpServer(
if (!shouldTriggerOAuth) { if (!shouldTriggerOAuth) {
// For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately. // For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately.
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName); const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
if (credentials) { if (credentials) {
const hasStoredTokens = await MCPOAuthProvider.getValidToken( const authProvider = new MCPOAuthProvider(tokenStorage);
const hasStoredTokens = await authProvider.getValidToken(
mcpServerName, mcpServerName,
{ {
// Pass client ID if available // Pass client ID if available
@@ -982,10 +981,11 @@ export async function connectToMcpServer(
// Get the valid token - we need to create a proper OAuth config // Get the valid token - we need to create a proper OAuth config
// The token should already be available from the authentication process // The token should already be available from the authentication process
const credentials = const tokenStorage = new MCPOAuthTokenStorage();
await MCPOAuthTokenStorage.getToken(mcpServerName); const credentials = await tokenStorage.getToken(mcpServerName);
if (credentials) { if (credentials) {
const accessToken = await MCPOAuthProvider.getValidToken( const authProvider = new MCPOAuthProvider(tokenStorage);
const accessToken = await authProvider.getValidToken(
mcpServerName, mcpServerName,
{ {
// Pass client ID if available // Pass client ID if available
@@ -1056,10 +1056,11 @@ export async function connectToMcpServer(
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled; mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
if (!shouldTryDiscovery) { if (!shouldTryDiscovery) {
const credentials = const tokenStorage = new MCPOAuthTokenStorage();
await MCPOAuthTokenStorage.getToken(mcpServerName); const credentials = await tokenStorage.getToken(mcpServerName);
if (credentials) { if (credentials) {
const hasStoredTokens = await MCPOAuthProvider.getValidToken( const authProvider = new MCPOAuthProvider(tokenStorage);
const hasStoredTokens = await authProvider.getValidToken(
mcpServerName, mcpServerName,
{ {
// Pass client ID if available // Pass client ID if available
@@ -1116,17 +1117,21 @@ export async function connectToMcpServer(
console.log( console.log(
`Starting OAuth authentication for server '${mcpServerName}'...`, `Starting OAuth authentication for server '${mcpServerName}'...`,
); );
await MCPOAuthProvider.authenticate( const authProvider = new MCPOAuthProvider(
new MCPOAuthTokenStorage(),
);
await authProvider.authenticate(
mcpServerName, mcpServerName,
oauthAuthConfig, oauthAuthConfig,
authServerUrl, authServerUrl,
); );
// Retry connection with OAuth token // Retry connection with OAuth token
const credentials = const tokenStorage = new MCPOAuthTokenStorage();
await MCPOAuthTokenStorage.getToken(mcpServerName); const credentials = await tokenStorage.getToken(mcpServerName);
if (credentials) { if (credentials) {
const accessToken = await MCPOAuthProvider.getValidToken( const authProvider = new MCPOAuthProvider(tokenStorage);
const accessToken = await authProvider.getValidToken(
mcpServerName, mcpServerName,
{ {
// Pass client ID if available // Pass client ID if available
@@ -1261,7 +1266,9 @@ export async function createTransport(
let hasOAuthConfig = mcpServerConfig.oauth?.enabled; let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
if (hasOAuthConfig && mcpServerConfig.oauth) { if (hasOAuthConfig && mcpServerConfig.oauth) {
accessToken = await MCPOAuthProvider.getValidToken( const tokenStorage = new MCPOAuthTokenStorage();
const authProvider = new MCPOAuthProvider(tokenStorage);
accessToken = await authProvider.getValidToken(
mcpServerName, mcpServerName,
mcpServerConfig.oauth, mcpServerConfig.oauth,
); );
@@ -1278,9 +1285,11 @@ export async function createTransport(
} }
} else { } else {
// Check if we have stored OAuth tokens for this server (from previous authentication) // Check if we have stored OAuth tokens for this server (from previous authentication)
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName); const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
if (credentials) { if (credentials) {
accessToken = await MCPOAuthProvider.getValidToken(mcpServerName, { const authProvider = new MCPOAuthProvider(tokenStorage);
accessToken = await authProvider.getValidToken(mcpServerName, {
// Pass client ID if available // Pass client ID if available
clientId: credentials.clientId, clientId: credentials.clientId,
}); });