fix(core): prevent unhandled rejection on OAuth callback timeout

This commit is contained in:
Coco Sheng
2026-05-14 12:56:45 -04:00
parent 8fb1b5aa01
commit fc6a115655
4 changed files with 369 additions and 258 deletions
@@ -137,6 +137,22 @@ vi.mock('node:http', () => ({
createServer: vi.fn(() => mockHttpServer),
}));
// Mock startCallbackServer to return what the new implementation returns
vi.mock('../utils/oauth-flow.js', async (importOriginal) => {
const actual = (await importOriginal()) as any;
return {
...actual,
startCallbackServer: vi.fn((expectedState: string, port?: number) => {
const result = actual.startCallbackServer(expectedState, port);
// Ensure the mock server is used if createServer is mocked
if (vi.isMockFunction(http.createServer)) {
result.server = mockHttpServer;
}
return result;
}),
};
});
describe('MCPOAuthProvider', () => {
const mockConfig: MCPOAuthConfig = {
enabled: true,
+143 -137
View File
@@ -375,84 +375,87 @@ export class MCPOAuthProvider {
// This ensures we only create one server and eliminates race conditions
const callbackServer = startCallbackServer(pkceParams.state, preferredPort);
// Wait for server to start and get the allocated port
// We need this port for client registration and auth URL building
const redirectPort = await callbackServer.port;
debugLogger.debug(`Callback server listening on port ${redirectPort}`);
try {
// Wait for server to start and get the allocated port
// We need this port for client registration and auth URL building
const redirectPort = await callbackServer.port;
debugLogger.debug(`Callback server listening on port ${redirectPort}`);
// If no client ID is provided, try dynamic client registration
if (!config.clientId) {
let registrationUrl = config.registrationUrl;
// If no client ID is provided, try dynamic client registration
if (!config.clientId) {
let registrationUrl = config.registrationUrl;
// If no registration URL was previously discovered, try to discover it
if (!registrationUrl) {
// Use the issuer to discover registration endpoint
if (!config.issuer) {
throw new Error('Cannot perform dynamic registration without issuer');
// If no registration URL was previously discovered, try to discover it
if (!registrationUrl) {
// Use the issuer to discover registration endpoint
if (!config.issuer) {
throw new Error(
'Cannot perform dynamic registration without issuer',
);
}
debugLogger.debug('→ Attempting dynamic client registration...');
const { metadata: authServerMetadata } =
await this.discoverAuthServerMetadataForRegistration(config.issuer);
registrationUrl = authServerMetadata.registration_endpoint;
}
debugLogger.debug('→ Attempting dynamic client registration...');
const { metadata: authServerMetadata } =
await this.discoverAuthServerMetadataForRegistration(config.issuer);
registrationUrl = authServerMetadata.registration_endpoint;
// Register client if registration endpoint is available
if (registrationUrl) {
const clientRegistration = await this.registerClient(
registrationUrl,
config,
redirectPort,
);
config.clientId = clientRegistration.client_id;
if (clientRegistration.client_secret) {
config.clientSecret = clientRegistration.client_secret;
}
debugLogger.debug('✓ Dynamic client registration successful');
} else {
throw new Error(
'No client ID provided and dynamic registration not supported',
);
}
}
// Register client if registration endpoint is available
if (registrationUrl) {
const clientRegistration = await this.registerClient(
registrationUrl,
config,
redirectPort,
);
config.clientId = clientRegistration.client_id;
if (clientRegistration.client_secret) {
config.clientSecret = clientRegistration.client_secret;
}
debugLogger.debug('✓ Dynamic client registration successful');
} else {
// Validate configuration
if (!config.clientId || !config.authorizationUrl || !config.tokenUrl) {
throw new Error(
'No client ID provided and dynamic registration not supported',
'Missing required OAuth configuration after discovery and registration',
);
}
}
// Validate configuration
if (!config.clientId || !config.authorizationUrl || !config.tokenUrl) {
throw new Error(
'Missing required OAuth configuration after discovery and registration',
// Build flow config for shared utilities
const flowConfig: OAuthFlowConfig = {
clientId: config.clientId,
clientSecret: config.clientSecret,
authorizationUrl: config.authorizationUrl,
tokenUrl: config.tokenUrl,
scopes: config.scopes,
audiences: config.audiences,
redirectUri: config.redirectUri,
};
// Build authorization URL
const resource = this.buildResourceParam(mcpServerUrl);
const authUrl = buildAuthorizationUrl(
flowConfig,
pkceParams,
redirectPort,
resource,
);
}
// Build flow config for shared utilities
const flowConfig: OAuthFlowConfig = {
clientId: config.clientId,
clientSecret: config.clientSecret,
authorizationUrl: config.authorizationUrl,
tokenUrl: config.tokenUrl,
scopes: config.scopes,
audiences: config.audiences,
redirectUri: config.redirectUri,
};
const userConsent = await getConsentForOauth(
`Authentication required for MCP Server: '${serverName}.'`,
);
if (!userConsent) {
throw new FatalCancellationError('Authentication cancelled by user.');
}
// Build authorization URL
const resource = this.buildResourceParam(mcpServerUrl);
const authUrl = buildAuthorizationUrl(
flowConfig,
pkceParams,
redirectPort,
resource,
);
const userConsent = await getConsentForOauth(
`Authentication required for MCP Server: '${serverName}.'`,
);
if (!userConsent) {
throw new FatalCancellationError('Authentication cancelled by user.');
}
displayMessage(`→ Opening your browser for OAuth sign-in...
displayMessage(`→ Opening your browser for OAuth sign-in...
If the browser does not open, copy and paste this URL into your browser:
${authUrl}
@@ -460,82 +463,85 @@ ${authUrl}
💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser.
Make sure to copy the COMPLETE URL - it may wrap across multiple lines.`);
// Open browser securely (callback server is already running)
try {
await openBrowserSecurely(authUrl);
} catch (error) {
debugLogger.warn(
'Failed to open browser automatically:',
getErrorMessage(error),
);
}
// Wait for callback
const { code } = await callbackServer.response;
debugLogger.debug(
'✓ Authorization code received, exchanging for tokens...',
);
// Exchange code for tokens
const tokenResponse = await exchangeCodeForToken(
flowConfig,
code,
pkceParams.codeVerifier,
redirectPort,
resource,
);
// Convert to our token format
if (!tokenResponse.access_token) {
throw new Error('No access token received from token endpoint');
}
const token: OAuthToken = {
accessToken: tokenResponse.access_token,
tokenType: tokenResponse.token_type || 'Bearer',
refreshToken: tokenResponse.refresh_token,
scope: tokenResponse.scope,
};
if (tokenResponse.expires_in) {
token.expiresAt = Date.now() + tokenResponse.expires_in * 1000;
}
// Save token
try {
await this.tokenStorage.saveToken(
serverName,
token,
config.clientId,
config.tokenUrl,
mcpServerUrl,
);
debugLogger.debug('✓ Authentication successful! Token saved.');
// Verify token was saved
const savedToken = await this.tokenStorage.getCredentials(serverName);
if (savedToken && savedToken.token && savedToken.token.accessToken) {
// Avoid leaking token material; log a short SHA-256 fingerprint instead.
const tokenFingerprint = crypto
.createHash('sha256')
.update(savedToken.token.accessToken)
.digest('hex')
.slice(0, 8);
debugLogger.debug(
`✓ Token verification successful (fingerprint: ${tokenFingerprint})`,
);
} else {
// Open browser securely (callback server is already running)
try {
await openBrowserSecurely(authUrl);
} catch (error) {
debugLogger.warn(
'Token verification failed: token not found or invalid after save',
'Failed to open browser automatically:',
getErrorMessage(error),
);
}
} catch (saveError) {
debugLogger.error('Failed to save auth token.', saveError);
throw saveError;
}
return token;
// Wait for callback
const { code } = await callbackServer.response;
debugLogger.debug(
'✓ Authorization code received, exchanging for tokens...',
);
// Exchange code for tokens
const tokenResponse = await exchangeCodeForToken(
flowConfig,
code,
pkceParams.codeVerifier,
redirectPort,
resource,
);
// Convert to our token format
if (!tokenResponse.access_token) {
throw new Error('No access token received from token endpoint');
}
const token: OAuthToken = {
accessToken: tokenResponse.access_token,
tokenType: tokenResponse.token_type || 'Bearer',
refreshToken: tokenResponse.refresh_token,
scope: tokenResponse.scope,
};
if (tokenResponse.expires_in) {
token.expiresAt = Date.now() + tokenResponse.expires_in * 1000;
}
// Save token
try {
await this.tokenStorage.saveToken(
serverName,
token,
config.clientId,
config.tokenUrl,
mcpServerUrl,
);
debugLogger.debug('✓ Authentication successful! Token saved.');
// Verify token was saved
const savedToken = await this.tokenStorage.getCredentials(serverName);
if (savedToken && savedToken.token && savedToken.token.accessToken) {
// Avoid leaking token material; log a short SHA-256 fingerprint instead.
const tokenFingerprint = crypto
.createHash('sha256')
.update(savedToken.token.accessToken)
.digest('hex')
.slice(0, 8);
debugLogger.debug(
`✓ Token verification successful (fingerprint: ${tokenFingerprint})`,
);
} else {
debugLogger.warn(
'Token verification failed: token not found or invalid after save',
);
}
} catch (saveError) {
debugLogger.error('Failed to save auth token.', saveError);
throw saveError;
}
return token;
} finally {
callbackServer.close();
}
}
/**
@@ -0,0 +1,70 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { startCallbackServer } from './oauth-flow.js';
describe('OAuth Flow Repro', () => {
beforeEach(() => {
vi.useFakeTimers();
});
afterEach(() => {
vi.useRealTimers();
});
it('should not have an unhandled rejection when close() is called before timeout', async () => {
let unhandledRejection: any = null;
const handler = (reason: any) => {
unhandledRejection = reason;
};
process.on('unhandledRejection', handler);
try {
const server = startCallbackServer('test-state');
await server.port;
// Explicitly close the server
server.close();
// Fast forward past the default 5 minute timeout
vi.advanceTimersByTime(5 * 60 * 1000 + 100);
// Give it a tick
await Promise.resolve();
await Promise.resolve();
expect(unhandledRejection).toBeNull();
} finally {
process.off('unhandledRejection', handler);
}
});
it('should not have an unhandled rejection even if NOT closed, due to internal catch', async () => {
let unhandledRejection: any = null;
const handler = (reason: any) => {
unhandledRejection = reason;
};
process.on('unhandledRejection', handler);
try {
const server = startCallbackServer('test-state');
await server.port;
// Abandon the server without closing it
// Fast forward past the default 5 minute timeout
vi.advanceTimersByTime(5 * 60 * 1000 + 100);
// Give it a tick
await Promise.resolve();
await Promise.resolve();
// Should be null because startCallbackServer now has an internal .catch()
expect(unhandledRejection).toBeNull();
// Cleanup for the test
server.close();
} finally {
process.off('unhandledRejection', handler);
}
});
});
+140 -121
View File
@@ -108,6 +108,8 @@ export function startCallbackServer(
): {
port: Promise<number>;
response: Promise<OAuthAuthorizationResponse>;
close: () => void;
server: http.Server;
} {
let portResolve: (port: number) => void;
let portReject: (error: Error) => void;
@@ -117,136 +119,153 @@ export function startCallbackServer(
});
let timeoutId: NodeJS.Timeout | undefined;
let serverPort: number;
let resolveResponse: (value: OAuthAuthorizationResponse) => void;
let rejectResponse: (reason: any) => void;
const responsePromise = new Promise<OAuthAuthorizationResponse>(
(resolve, reject) => {
let serverPort: number;
const server = http.createServer(
async (req: http.IncomingMessage, res: http.ServerResponse) => {
try {
const url = new URL(req.url!, `http://localhost:${serverPort}`);
if (url.pathname !== REDIRECT_PATH) {
res.writeHead(404);
res.end('Not found');
return;
}
const code = url.searchParams.get('code');
const state = url.searchParams.get('state');
const error = url.searchParams.get('error');
if (error) {
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
<h1>Authentication Failed</h1>
<p>Error: ${error.replace(/</g, '&lt;').replace(/>/g, '&gt;')}</p>
<p>${(url.searchParams.get('error_description') || '').replace(/</g, '&lt;').replace(/>/g, '&gt;')}</p>
<p>You can close this window.</p>
</body>
</html>
`);
server.close();
reject(new Error(`OAuth error: ${error}`));
return;
}
if (!code || !state) {
res.writeHead(400);
res.end('Missing code or state parameter');
return;
}
if (state !== expectedState) {
res.writeHead(400);
res.end('Invalid state parameter');
server.close();
reject(new Error('State mismatch - possible CSRF attack'));
return;
}
// Send success response to browser
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
<h1>Authentication Successful!</h1>
<p>You can close this window and return to Gemini CLI.</p>
<script>window.close();</script>
</body>
</html>
`);
server.close();
resolve({ code, state });
} catch (error) {
server.close();
reject(error);
}
},
);
server.on('error', (error) => {
portReject(error);
reject(error);
});
// Determine which port to use (env var, argument, or OS-assigned)
let listenPort = 0; // Default to OS-assigned port
const portStr = process.env['OAUTH_CALLBACK_PORT'];
if (portStr) {
const envPort = parseInt(portStr, 10);
if (isNaN(envPort) || envPort <= 0 || envPort > 65535) {
const error = new Error(
`Invalid value for OAUTH_CALLBACK_PORT: "${portStr}"`,
);
portReject(error);
reject(error);
return;
}
listenPort = envPort;
} else if (port !== undefined) {
listenPort = port;
}
server.listen(listenPort, () => {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const address = server.address() as net.AddressInfo;
serverPort = address.port;
debugLogger.log(
`OAuth callback server listening on port ${serverPort}`,
);
portResolve(serverPort); // Resolve port promise immediately
});
const abortController = new AbortController();
timeoutId = setTimeout(
() => {
abortController.abort(new Error('OAuth callback timeout'));
},
5 * 60 * 1000,
);
timeoutId.unref();
const onAbort = () => {
server.close();
reject(abortController.signal.reason);
};
abortController.signal.addEventListener('abort', onAbort, { once: true });
server.on('close', () => {
abortController.signal.removeEventListener('abort', onAbort);
});
resolveResponse = resolve;
rejectResponse = reject;
},
);
const server = http.createServer(
async (req: http.IncomingMessage, res: http.ServerResponse) => {
try {
const url = new URL(req.url!, `http://localhost:${serverPort}`);
if (url.pathname !== REDIRECT_PATH) {
res.writeHead(404);
res.end('Not found');
return;
}
const code = url.searchParams.get('code');
const state = url.searchParams.get('state');
const error = url.searchParams.get('error');
if (error) {
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
<h1>Authentication Failed</h1>
<p>Error: ${error.replace(/</g, '&lt;').replace(/>/g, '&gt;')}</p>
<p>${(url.searchParams.get('error_description') || '').replace(/</g, '&lt;').replace(/>/g, '&gt;')}</p>
<p>You can close this window.</p>
</body>
</html>
`);
server.close();
rejectResponse(new Error(`OAuth error: ${error}`));
return;
}
if (!code || !state) {
res.writeHead(400);
res.end('Missing code or state parameter');
return;
}
if (state !== expectedState) {
res.writeHead(400);
res.end('Invalid state parameter');
server.close();
rejectResponse(new Error('State mismatch - possible CSRF attack'));
return;
}
// Send success response to browser
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
<h1>Authentication Successful!</h1>
<p>You can close this window and return to Gemini CLI.</p>
<script>window.close();</script>
</body>
</html>
`);
server.close();
resolveResponse({ code, state });
} catch (error) {
server.close();
rejectResponse(error);
}
},
);
server.on('error', (error) => {
portReject(error);
rejectResponse(error);
});
// Determine which port to use (env var, argument, or OS-assigned)
let listenPort = 0; // Default to OS-assigned port
const portStr = process.env['OAUTH_CALLBACK_PORT'];
if (portStr) {
const envPort = parseInt(portStr, 10);
if (isNaN(envPort) || envPort <= 0 || envPort > 65535) {
const error = new Error(
`Invalid value for OAUTH_CALLBACK_PORT: "${portStr}"`,
);
portReject(error);
rejectResponse(error);
// We still return the object, but the promises will be rejected
} else {
listenPort = envPort;
}
} else if (port !== undefined) {
listenPort = port;
}
if (listenPort !== undefined || !portStr) {
server.listen(listenPort, () => {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const address = server.address() as net.AddressInfo;
serverPort = address.port;
debugLogger.log(`OAuth callback server listening on port ${serverPort}`);
portResolve(serverPort); // Resolve port promise immediately
});
}
const abortController = new AbortController();
timeoutId = setTimeout(
() => {
abortController.abort(new Error('OAuth callback timeout'));
},
5 * 60 * 1000,
);
timeoutId.unref();
const onAbort = () => {
server.close();
rejectResponse(abortController.signal.reason);
};
abortController.signal.addEventListener('abort', onAbort, { once: true });
server.on('close', () => {
abortController.signal.removeEventListener('abort', onAbort);
});
// Attach a no-op catch to prevent unhandled rejections if the promise is abandoned.
// The caller can still await it and catch their own errors.
responsePromise.catch(() => {});
return {
port: portPromise,
response: responsePromise,
close: () => {
if (timeoutId) {
clearTimeout(timeoutId);
timeoutId = undefined;
}
server.close();
},
server,
};
}