diff --git a/packages/core/src/mcp/oauth-provider.test.ts b/packages/core/src/mcp/oauth-provider.test.ts index 251ccb4a5e..3caffe4a73 100644 --- a/packages/core/src/mcp/oauth-provider.test.ts +++ b/packages/core/src/mcp/oauth-provider.test.ts @@ -137,6 +137,22 @@ vi.mock('node:http', () => ({ createServer: vi.fn(() => mockHttpServer), })); +// Mock startCallbackServer to return what the new implementation returns +vi.mock('../utils/oauth-flow.js', async (importOriginal) => { + const actual = (await importOriginal()) as any; + return { + ...actual, + startCallbackServer: vi.fn((expectedState: string, port?: number) => { + const result = actual.startCallbackServer(expectedState, port); + // Ensure the mock server is used if createServer is mocked + if (vi.isMockFunction(http.createServer)) { + result.server = mockHttpServer; + } + return result; + }), + }; +}); + describe('MCPOAuthProvider', () => { const mockConfig: MCPOAuthConfig = { enabled: true, diff --git a/packages/core/src/mcp/oauth-provider.ts b/packages/core/src/mcp/oauth-provider.ts index 6aaafa6054..dccb2036c8 100644 --- a/packages/core/src/mcp/oauth-provider.ts +++ b/packages/core/src/mcp/oauth-provider.ts @@ -375,84 +375,87 @@ export class MCPOAuthProvider { // This ensures we only create one server and eliminates race conditions const callbackServer = startCallbackServer(pkceParams.state, preferredPort); - // Wait for server to start and get the allocated port - // We need this port for client registration and auth URL building - const redirectPort = await callbackServer.port; - debugLogger.debug(`Callback server listening on port ${redirectPort}`); + try { + // Wait for server to start and get the allocated port + // We need this port for client registration and auth URL building + const redirectPort = await callbackServer.port; + debugLogger.debug(`Callback server listening on port ${redirectPort}`); - // If no client ID is provided, try dynamic client registration - if (!config.clientId) { - let registrationUrl = config.registrationUrl; + // If no client ID is provided, try dynamic client registration + if (!config.clientId) { + let registrationUrl = config.registrationUrl; - // If no registration URL was previously discovered, try to discover it - if (!registrationUrl) { - // Use the issuer to discover registration endpoint - if (!config.issuer) { - throw new Error('Cannot perform dynamic registration without issuer'); + // If no registration URL was previously discovered, try to discover it + if (!registrationUrl) { + // Use the issuer to discover registration endpoint + if (!config.issuer) { + throw new Error( + 'Cannot perform dynamic registration without issuer', + ); + } + + debugLogger.debug('→ Attempting dynamic client registration...'); + const { metadata: authServerMetadata } = + await this.discoverAuthServerMetadataForRegistration(config.issuer); + registrationUrl = authServerMetadata.registration_endpoint; } - debugLogger.debug('→ Attempting dynamic client registration...'); - const { metadata: authServerMetadata } = - await this.discoverAuthServerMetadataForRegistration(config.issuer); - registrationUrl = authServerMetadata.registration_endpoint; + // Register client if registration endpoint is available + if (registrationUrl) { + const clientRegistration = await this.registerClient( + registrationUrl, + config, + redirectPort, + ); + + config.clientId = clientRegistration.client_id; + if (clientRegistration.client_secret) { + config.clientSecret = clientRegistration.client_secret; + } + + debugLogger.debug('✓ Dynamic client registration successful'); + } else { + throw new Error( + 'No client ID provided and dynamic registration not supported', + ); + } } - // Register client if registration endpoint is available - if (registrationUrl) { - const clientRegistration = await this.registerClient( - registrationUrl, - config, - redirectPort, - ); - - config.clientId = clientRegistration.client_id; - if (clientRegistration.client_secret) { - config.clientSecret = clientRegistration.client_secret; - } - - debugLogger.debug('✓ Dynamic client registration successful'); - } else { + // Validate configuration + if (!config.clientId || !config.authorizationUrl || !config.tokenUrl) { throw new Error( - 'No client ID provided and dynamic registration not supported', + 'Missing required OAuth configuration after discovery and registration', ); } - } - // Validate configuration - if (!config.clientId || !config.authorizationUrl || !config.tokenUrl) { - throw new Error( - 'Missing required OAuth configuration after discovery and registration', + // Build flow config for shared utilities + const flowConfig: OAuthFlowConfig = { + clientId: config.clientId, + clientSecret: config.clientSecret, + authorizationUrl: config.authorizationUrl, + tokenUrl: config.tokenUrl, + scopes: config.scopes, + audiences: config.audiences, + redirectUri: config.redirectUri, + }; + + // Build authorization URL + const resource = this.buildResourceParam(mcpServerUrl); + const authUrl = buildAuthorizationUrl( + flowConfig, + pkceParams, + redirectPort, + resource, ); - } - // Build flow config for shared utilities - const flowConfig: OAuthFlowConfig = { - clientId: config.clientId, - clientSecret: config.clientSecret, - authorizationUrl: config.authorizationUrl, - tokenUrl: config.tokenUrl, - scopes: config.scopes, - audiences: config.audiences, - redirectUri: config.redirectUri, - }; + const userConsent = await getConsentForOauth( + `Authentication required for MCP Server: '${serverName}.'`, + ); + if (!userConsent) { + throw new FatalCancellationError('Authentication cancelled by user.'); + } - // Build authorization URL - const resource = this.buildResourceParam(mcpServerUrl); - const authUrl = buildAuthorizationUrl( - flowConfig, - pkceParams, - redirectPort, - resource, - ); - - const userConsent = await getConsentForOauth( - `Authentication required for MCP Server: '${serverName}.'`, - ); - if (!userConsent) { - throw new FatalCancellationError('Authentication cancelled by user.'); - } - - displayMessage(`→ Opening your browser for OAuth sign-in... + displayMessage(`→ Opening your browser for OAuth sign-in... If the browser does not open, copy and paste this URL into your browser: ${authUrl} @@ -460,82 +463,85 @@ ${authUrl} 💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser. ⚠️ Make sure to copy the COMPLETE URL - it may wrap across multiple lines.`); - // Open browser securely (callback server is already running) - try { - await openBrowserSecurely(authUrl); - } catch (error) { - debugLogger.warn( - 'Failed to open browser automatically:', - getErrorMessage(error), - ); - } - - // Wait for callback - const { code } = await callbackServer.response; - - debugLogger.debug( - '✓ Authorization code received, exchanging for tokens...', - ); - - // Exchange code for tokens - const tokenResponse = await exchangeCodeForToken( - flowConfig, - code, - pkceParams.codeVerifier, - redirectPort, - resource, - ); - - // Convert to our token format - if (!tokenResponse.access_token) { - throw new Error('No access token received from token endpoint'); - } - - const token: OAuthToken = { - accessToken: tokenResponse.access_token, - tokenType: tokenResponse.token_type || 'Bearer', - refreshToken: tokenResponse.refresh_token, - scope: tokenResponse.scope, - }; - - if (tokenResponse.expires_in) { - token.expiresAt = Date.now() + tokenResponse.expires_in * 1000; - } - - // Save token - try { - await this.tokenStorage.saveToken( - serverName, - token, - config.clientId, - config.tokenUrl, - mcpServerUrl, - ); - debugLogger.debug('✓ Authentication successful! Token saved.'); - - // Verify token was saved - const savedToken = await this.tokenStorage.getCredentials(serverName); - if (savedToken && savedToken.token && savedToken.token.accessToken) { - // Avoid leaking token material; log a short SHA-256 fingerprint instead. - const tokenFingerprint = crypto - .createHash('sha256') - .update(savedToken.token.accessToken) - .digest('hex') - .slice(0, 8); - debugLogger.debug( - `✓ Token verification successful (fingerprint: ${tokenFingerprint})`, - ); - } else { + // Open browser securely (callback server is already running) + try { + await openBrowserSecurely(authUrl); + } catch (error) { debugLogger.warn( - 'Token verification failed: token not found or invalid after save', + 'Failed to open browser automatically:', + getErrorMessage(error), ); } - } catch (saveError) { - debugLogger.error('Failed to save auth token.', saveError); - throw saveError; - } - return token; + // Wait for callback + const { code } = await callbackServer.response; + + debugLogger.debug( + '✓ Authorization code received, exchanging for tokens...', + ); + + // Exchange code for tokens + const tokenResponse = await exchangeCodeForToken( + flowConfig, + code, + pkceParams.codeVerifier, + redirectPort, + resource, + ); + + // Convert to our token format + if (!tokenResponse.access_token) { + throw new Error('No access token received from token endpoint'); + } + + const token: OAuthToken = { + accessToken: tokenResponse.access_token, + tokenType: tokenResponse.token_type || 'Bearer', + refreshToken: tokenResponse.refresh_token, + scope: tokenResponse.scope, + }; + + if (tokenResponse.expires_in) { + token.expiresAt = Date.now() + tokenResponse.expires_in * 1000; + } + + // Save token + try { + await this.tokenStorage.saveToken( + serverName, + token, + config.clientId, + config.tokenUrl, + mcpServerUrl, + ); + debugLogger.debug('✓ Authentication successful! Token saved.'); + + // Verify token was saved + const savedToken = await this.tokenStorage.getCredentials(serverName); + if (savedToken && savedToken.token && savedToken.token.accessToken) { + // Avoid leaking token material; log a short SHA-256 fingerprint instead. + const tokenFingerprint = crypto + .createHash('sha256') + .update(savedToken.token.accessToken) + .digest('hex') + .slice(0, 8); + debugLogger.debug( + `✓ Token verification successful (fingerprint: ${tokenFingerprint})`, + ); + } else { + debugLogger.warn( + 'Token verification failed: token not found or invalid after save', + ); + } + } catch (saveError) { + debugLogger.error('Failed to save auth token.', saveError); + throw saveError; + } + + return token; + } finally { + callbackServer.close(); + } } /** diff --git a/packages/core/src/utils/oauth-flow-fix.test.ts b/packages/core/src/utils/oauth-flow-fix.test.ts new file mode 100644 index 0000000000..727e1fe063 --- /dev/null +++ b/packages/core/src/utils/oauth-flow-fix.test.ts @@ -0,0 +1,70 @@ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { startCallbackServer } from './oauth-flow.js'; + +describe('OAuth Flow Repro', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it('should not have an unhandled rejection when close() is called before timeout', async () => { + let unhandledRejection: any = null; + const handler = (reason: any) => { + unhandledRejection = reason; + }; + process.on('unhandledRejection', handler); + + try { + const server = startCallbackServer('test-state'); + await server.port; + + // Explicitly close the server + server.close(); + + // Fast forward past the default 5 minute timeout + vi.advanceTimersByTime(5 * 60 * 1000 + 100); + + // Give it a tick + await Promise.resolve(); + await Promise.resolve(); + + expect(unhandledRejection).toBeNull(); + } finally { + process.off('unhandledRejection', handler); + } + }); + + it('should not have an unhandled rejection even if NOT closed, due to internal catch', async () => { + let unhandledRejection: any = null; + const handler = (reason: any) => { + unhandledRejection = reason; + }; + process.on('unhandledRejection', handler); + + try { + const server = startCallbackServer('test-state'); + await server.port; + + // Abandon the server without closing it + + // Fast forward past the default 5 minute timeout + vi.advanceTimersByTime(5 * 60 * 1000 + 100); + + // Give it a tick + await Promise.resolve(); + await Promise.resolve(); + + // Should be null because startCallbackServer now has an internal .catch() + expect(unhandledRejection).toBeNull(); + + // Cleanup for the test + server.close(); + } finally { + process.off('unhandledRejection', handler); + } + }); +}); diff --git a/packages/core/src/utils/oauth-flow.ts b/packages/core/src/utils/oauth-flow.ts index 67062c9ec5..6666bcc610 100644 --- a/packages/core/src/utils/oauth-flow.ts +++ b/packages/core/src/utils/oauth-flow.ts @@ -108,6 +108,8 @@ export function startCallbackServer( ): { port: Promise; response: Promise; + close: () => void; + server: http.Server; } { let portResolve: (port: number) => void; let portReject: (error: Error) => void; @@ -117,136 +119,153 @@ export function startCallbackServer( }); let timeoutId: NodeJS.Timeout | undefined; + let serverPort: number; + let resolveResponse: (value: OAuthAuthorizationResponse) => void; + let rejectResponse: (reason: any) => void; const responsePromise = new Promise( (resolve, reject) => { - let serverPort: number; - - const server = http.createServer( - async (req: http.IncomingMessage, res: http.ServerResponse) => { - try { - const url = new URL(req.url!, `http://localhost:${serverPort}`); - - if (url.pathname !== REDIRECT_PATH) { - res.writeHead(404); - res.end('Not found'); - return; - } - - const code = url.searchParams.get('code'); - const state = url.searchParams.get('state'); - const error = url.searchParams.get('error'); - - if (error) { - res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' }); - res.end(` - - -

Authentication Failed

-

Error: ${error.replace(//g, '>')}

-

${(url.searchParams.get('error_description') || '').replace(//g, '>')}

-

You can close this window.

- - - `); - server.close(); - reject(new Error(`OAuth error: ${error}`)); - return; - } - - if (!code || !state) { - res.writeHead(400); - res.end('Missing code or state parameter'); - return; - } - - if (state !== expectedState) { - res.writeHead(400); - res.end('Invalid state parameter'); - server.close(); - reject(new Error('State mismatch - possible CSRF attack')); - return; - } - - // Send success response to browser - res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' }); - res.end(` - - -

Authentication Successful!

-

You can close this window and return to Gemini CLI.

- - - - `); - - server.close(); - resolve({ code, state }); - } catch (error) { - server.close(); - reject(error); - } - }, - ); - - server.on('error', (error) => { - portReject(error); - reject(error); - }); - - // Determine which port to use (env var, argument, or OS-assigned) - let listenPort = 0; // Default to OS-assigned port - - const portStr = process.env['OAUTH_CALLBACK_PORT']; - if (portStr) { - const envPort = parseInt(portStr, 10); - if (isNaN(envPort) || envPort <= 0 || envPort > 65535) { - const error = new Error( - `Invalid value for OAUTH_CALLBACK_PORT: "${portStr}"`, - ); - portReject(error); - reject(error); - return; - } - listenPort = envPort; - } else if (port !== undefined) { - listenPort = port; - } - - server.listen(listenPort, () => { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const address = server.address() as net.AddressInfo; - serverPort = address.port; - debugLogger.log( - `OAuth callback server listening on port ${serverPort}`, - ); - portResolve(serverPort); // Resolve port promise immediately - }); - - const abortController = new AbortController(); - timeoutId = setTimeout( - () => { - abortController.abort(new Error('OAuth callback timeout')); - }, - 5 * 60 * 1000, - ); - timeoutId.unref(); - - const onAbort = () => { - server.close(); - reject(abortController.signal.reason); - }; - abortController.signal.addEventListener('abort', onAbort, { once: true }); - - server.on('close', () => { - abortController.signal.removeEventListener('abort', onAbort); - }); + resolveResponse = resolve; + rejectResponse = reject; }, ); + const server = http.createServer( + async (req: http.IncomingMessage, res: http.ServerResponse) => { + try { + const url = new URL(req.url!, `http://localhost:${serverPort}`); + + if (url.pathname !== REDIRECT_PATH) { + res.writeHead(404); + res.end('Not found'); + return; + } + + const code = url.searchParams.get('code'); + const state = url.searchParams.get('state'); + const error = url.searchParams.get('error'); + + if (error) { + res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' }); + res.end(` + + +

Authentication Failed

+

Error: ${error.replace(//g, '>')}

+

${(url.searchParams.get('error_description') || '').replace(//g, '>')}

+

You can close this window.

+ + + `); + server.close(); + rejectResponse(new Error(`OAuth error: ${error}`)); + return; + } + + if (!code || !state) { + res.writeHead(400); + res.end('Missing code or state parameter'); + return; + } + + if (state !== expectedState) { + res.writeHead(400); + res.end('Invalid state parameter'); + server.close(); + rejectResponse(new Error('State mismatch - possible CSRF attack')); + return; + } + + // Send success response to browser + res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' }); + res.end(` + + +

Authentication Successful!

+

You can close this window and return to Gemini CLI.

+ + + + `); + + server.close(); + resolveResponse({ code, state }); + } catch (error) { + server.close(); + rejectResponse(error); + } + }, + ); + + server.on('error', (error) => { + portReject(error); + rejectResponse(error); + }); + + // Determine which port to use (env var, argument, or OS-assigned) + let listenPort = 0; // Default to OS-assigned port + + const portStr = process.env['OAUTH_CALLBACK_PORT']; + if (portStr) { + const envPort = parseInt(portStr, 10); + if (isNaN(envPort) || envPort <= 0 || envPort > 65535) { + const error = new Error( + `Invalid value for OAUTH_CALLBACK_PORT: "${portStr}"`, + ); + portReject(error); + rejectResponse(error); + // We still return the object, but the promises will be rejected + } else { + listenPort = envPort; + } + } else if (port !== undefined) { + listenPort = port; + } + + if (listenPort !== undefined || !portStr) { + server.listen(listenPort, () => { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const address = server.address() as net.AddressInfo; + serverPort = address.port; + debugLogger.log(`OAuth callback server listening on port ${serverPort}`); + portResolve(serverPort); // Resolve port promise immediately + }); + } + + const abortController = new AbortController(); + timeoutId = setTimeout( + () => { + abortController.abort(new Error('OAuth callback timeout')); + }, + 5 * 60 * 1000, + ); + timeoutId.unref(); + + const onAbort = () => { + server.close(); + rejectResponse(abortController.signal.reason); + }; + abortController.signal.addEventListener('abort', onAbort, { once: true }); + + server.on('close', () => { + abortController.signal.removeEventListener('abort', onAbort); + }); + + // Attach a no-op catch to prevent unhandled rejections if the promise is abandoned. + // The caller can still await it and catch their own errors. + responsePromise.catch(() => {}); + return { port: portPromise, response: responsePromise, + close: () => { + if (timeoutId) { + clearTimeout(timeoutId); + timeoutId = undefined; + } + server.close(); + }, + server, }; }