refactor(core): extract shared OAuth flow primitives from MCPOAuthProvider (#20895)

This commit is contained in:
Sandy Tao
2026-03-05 09:01:37 -08:00
committed by GitHub
parent ddafd79661
commit 0228c2b9f0
3 changed files with 1208 additions and 513 deletions
+58 -513
View File
@@ -4,9 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import * as http from 'node:http';
import * as crypto from 'node:crypto';
import type * as net from 'node:net';
import { URL } from 'node:url';
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
import type { OAuthToken } from './token-storage/types.js';
@@ -16,6 +14,23 @@ import { OAuthUtils, ResourceMismatchError } from './oauth-utils.js';
import { coreEvents } from '../utils/events.js';
import { debugLogger } from '../utils/debugLogger.js';
import { getConsentForOauth } from '../utils/authConsent.js';
import {
generatePKCEParams,
startCallbackServer,
getPortFromUrl,
buildAuthorizationUrl,
exchangeCodeForToken,
refreshAccessToken as refreshAccessTokenShared,
REDIRECT_PATH,
type OAuthFlowConfig,
type OAuthTokenResponse,
} from '../utils/oauth-flow.js';
// Re-export types that were moved to oauth-flow.ts for backward compatibility.
export type {
OAuthAuthorizationResponse,
OAuthTokenResponse,
} from '../utils/oauth-flow.js';
/**
* OAuth configuration for an MCP server.
@@ -34,25 +49,6 @@ export interface MCPOAuthConfig {
registrationUrl?: string;
}
/**
* OAuth authorization response.
*/
export interface OAuthAuthorizationResponse {
code: string;
state: string;
}
/**
* OAuth token response from the authorization server.
*/
export interface OAuthTokenResponse {
access_token: string;
token_type: string;
expires_in?: number;
refresh_token?: string;
scope?: string;
}
/**
* Dynamic client registration request (RFC 7591).
*/
@@ -80,18 +76,6 @@ export interface OAuthClientRegistrationResponse {
scope?: string;
}
/**
* PKCE (Proof Key for Code Exchange) parameters.
*/
interface PKCEParams {
codeVerifier: string;
codeChallenge: string;
state: string;
}
const REDIRECT_PATH = '/oauth/callback';
const HTTP_OK = 200;
/**
* Provider for handling OAuth authentication for MCP servers.
*/
@@ -239,375 +223,18 @@ export class MCPOAuthProvider {
}
/**
* Generate PKCE parameters for OAuth flow.
*
* @returns PKCE parameters including code verifier, challenge, and state
* Build the OAuth resource parameter from an MCP server URL, if available.
* Returns undefined if the URL is not provided or cannot be processed.
*/
private generatePKCEParams(): PKCEParams {
// Generate code verifier (43-128 characters)
// using 64 bytes results in ~86 characters, safely above the minimum of 43
const codeVerifier = crypto.randomBytes(64).toString('base64url');
// Generate code challenge using SHA256
const codeChallenge = crypto
.createHash('sha256')
.update(codeVerifier)
.digest('base64url');
// Generate state for CSRF protection
const state = crypto.randomBytes(16).toString('base64url');
return { codeVerifier, codeChallenge, state };
}
/**
* Start a local HTTP server to handle OAuth callback.
* The server will listen on the specified port (or port 0 for OS assignment).
*
* @param expectedState The state parameter to validate
* @returns Object containing the port (available immediately) and a promise for the auth response
*/
private startCallbackServer(
expectedState: string,
port?: number,
): {
port: Promise<number>;
response: Promise<OAuthAuthorizationResponse>;
} {
let portResolve: (port: number) => void;
let portReject: (error: Error) => void;
const portPromise = new Promise<number>((resolve, reject) => {
portResolve = resolve;
portReject = reject;
});
const responsePromise = new Promise<OAuthAuthorizationResponse>(
(resolve, reject) => {
let serverPort: number;
const server = http.createServer(
async (req: http.IncomingMessage, res: http.ServerResponse) => {
try {
const url = new URL(req.url!, `http://localhost:${serverPort}`);
if (url.pathname !== REDIRECT_PATH) {
res.writeHead(404);
res.end('Not found');
return;
}
const code = url.searchParams.get('code');
const state = url.searchParams.get('state');
const error = url.searchParams.get('error');
if (error) {
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
<h1>Authentication Failed</h1>
<p>Error: ${error.replace(/</g, '&lt;').replace(/>/g, '&gt;')}</p>
<p>${(url.searchParams.get('error_description') || '').replace(/</g, '&lt;').replace(/>/g, '&gt;')}</p>
<p>You can close this window.</p>
</body>
</html>
`);
server.close();
reject(new Error(`OAuth error: ${error}`));
return;
}
if (!code || !state) {
res.writeHead(400);
res.end('Missing code or state parameter');
return;
}
if (state !== expectedState) {
res.writeHead(400);
res.end('Invalid state parameter');
server.close();
reject(new Error('State mismatch - possible CSRF attack'));
return;
}
// Send success response to browser
res.writeHead(HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
<h1>Authentication Successful!</h1>
<p>You can close this window and return to Gemini CLI.</p>
<script>window.close();</script>
</body>
</html>
`);
server.close();
resolve({ code, state });
} catch (error) {
server.close();
reject(error);
}
},
);
server.on('error', (error) => {
portReject(error);
reject(error);
});
// Determine which port to use (env var, argument, or OS-assigned)
let listenPort = 0; // Default to OS-assigned port
const portStr = process.env['OAUTH_CALLBACK_PORT'];
if (portStr) {
const envPort = parseInt(portStr, 10);
if (isNaN(envPort) || envPort <= 0 || envPort > 65535) {
const error = new Error(
`Invalid value for OAUTH_CALLBACK_PORT: "${portStr}"`,
);
portReject(error);
reject(error);
return;
}
listenPort = envPort;
} else if (port !== undefined) {
listenPort = port;
}
server.listen(listenPort, () => {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const address = server.address() as net.AddressInfo;
serverPort = address.port;
debugLogger.log(
`OAuth callback server listening on port ${serverPort}`,
);
portResolve(serverPort); // Resolve port promise immediately
});
// Timeout after 5 minutes
setTimeout(
() => {
server.close();
reject(new Error('OAuth callback timeout'));
},
5 * 60 * 1000,
);
},
);
return { port: portPromise, response: responsePromise };
}
/**
* Extract the port number from a URL string if available and valid.
*
* @param urlString The URL string to parse
* @returns The port number or undefined if not found or invalid
*/
private getPortFromUrl(urlString?: string): number | undefined {
if (!urlString) {
return undefined;
}
private buildResourceParam(mcpServerUrl?: string): string | undefined {
if (!mcpServerUrl) return undefined;
try {
const url = new URL(urlString);
if (url.port) {
const parsedPort = parseInt(url.port, 10);
if (!isNaN(parsedPort) && parsedPort > 0 && parsedPort <= 65535) {
return parsedPort;
}
}
} catch {
// Ignore invalid URL
}
return undefined;
}
/**
* Build the authorization URL for the OAuth flow.
*
* @param config OAuth configuration
* @param pkceParams PKCE parameters
* @param redirectPort The port to use for the redirect URI
* @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The authorization URL
*/
private buildAuthorizationUrl(
config: MCPOAuthConfig,
pkceParams: PKCEParams,
redirectPort: number,
mcpServerUrl?: string,
): string {
const redirectUri =
config.redirectUri || `http://localhost:${redirectPort}${REDIRECT_PATH}`;
const params = new URLSearchParams({
client_id: config.clientId!,
response_type: 'code',
redirect_uri: redirectUri,
state: pkceParams.state,
code_challenge: pkceParams.codeChallenge,
code_challenge_method: 'S256',
});
if (config.scopes && config.scopes.length > 0) {
params.append('scope', config.scopes.join(' '));
}
if (config.audiences && config.audiences.length > 0) {
params.append('audience', config.audiences.join(' '));
}
// Add resource parameter for MCP OAuth spec compliance
// Only add if we have an MCP server URL (indicates MCP OAuth flow, not standard OAuth)
if (mcpServerUrl) {
try {
params.append(
'resource',
OAuthUtils.buildResourceParameter(mcpServerUrl),
);
} catch (error) {
debugLogger.warn(
`Could not add resource parameter: ${getErrorMessage(error)}`,
);
}
}
const url = new URL(config.authorizationUrl!);
params.forEach((value, key) => {
url.searchParams.append(key, value);
});
return url.toString();
}
/**
* Exchange authorization code for tokens.
*
* @param config OAuth configuration
* @param code Authorization code
* @param codeVerifier PKCE code verifier
* @param redirectPort The port to use for the redirect URI
* @param mcpServerUrl The MCP server URL to use as the resource parameter
* @returns The token response
*/
private async exchangeCodeForToken(
config: MCPOAuthConfig,
code: string,
codeVerifier: string,
redirectPort: number,
mcpServerUrl?: string,
): Promise<OAuthTokenResponse> {
const redirectUri =
config.redirectUri || `http://localhost:${redirectPort}${REDIRECT_PATH}`;
const params = new URLSearchParams({
grant_type: 'authorization_code',
code,
redirect_uri: redirectUri,
code_verifier: codeVerifier,
client_id: config.clientId!,
});
if (config.clientSecret) {
params.append('client_secret', config.clientSecret);
}
if (config.audiences && config.audiences.length > 0) {
params.append('audience', config.audiences.join(' '));
}
// Add resource parameter for MCP OAuth spec compliance
// Only add if we have an MCP server URL (indicates MCP OAuth flow, not standard OAuth)
if (mcpServerUrl) {
const resourceUrl = mcpServerUrl;
try {
params.append(
'resource',
OAuthUtils.buildResourceParameter(resourceUrl),
);
} catch (error) {
debugLogger.warn(
`Could not add resource parameter: ${getErrorMessage(error)}`,
);
}
}
const response = await fetch(config.tokenUrl!, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json, application/x-www-form-urlencoded',
},
body: params.toString(),
});
const responseText = await response.text();
const contentType = response.headers.get('content-type') || '';
if (!response.ok) {
// Try to parse error from form-urlencoded response
let errorMessage: string | null = null;
try {
const errorParams = new URLSearchParams(responseText);
const error = errorParams.get('error');
const errorDescription = errorParams.get('error_description');
if (error) {
errorMessage = `Token exchange failed: ${error} - ${errorDescription || 'No description'}`;
}
} catch {
// Fall back to raw error
}
throw new Error(
errorMessage ||
`Token exchange failed: ${response.status} - ${responseText}`,
);
}
// Log unexpected content types for debugging
if (
!contentType.includes('application/json') &&
!contentType.includes('application/x-www-form-urlencoded')
) {
return OAuthUtils.buildResourceParameter(mcpServerUrl);
} catch (error) {
debugLogger.warn(
`Token endpoint returned unexpected content-type: ${contentType}. ` +
`Expected application/json or application/x-www-form-urlencoded. ` +
`Will attempt to parse response.`,
`Could not add resource parameter: ${getErrorMessage(error)}`,
);
}
// Try to parse as JSON first, fall back to form-urlencoded
try {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
return JSON.parse(responseText) as OAuthTokenResponse;
} catch {
// Parse form-urlencoded response
const tokenParams = new URLSearchParams(responseText);
const accessToken = tokenParams.get('access_token');
const tokenType = tokenParams.get('token_type') || 'Bearer';
const expiresIn = tokenParams.get('expires_in');
const refreshToken = tokenParams.get('refresh_token');
const scope = tokenParams.get('scope');
if (!accessToken) {
// Check for error in response
const error = tokenParams.get('error');
const errorDescription = tokenParams.get('error_description');
throw new Error(
`Token exchange failed: ${error || 'no_access_token'} - ${errorDescription || responseText}`,
);
}
return {
access_token: accessToken,
token_type: tokenType,
expires_in: expiresIn ? parseInt(expiresIn, 10) : undefined,
refresh_token: refreshToken || undefined,
scope: scope || undefined,
} as OAuthTokenResponse;
return undefined;
}
}
@@ -626,112 +253,21 @@ export class MCPOAuthProvider {
tokenUrl: string,
mcpServerUrl?: string,
): Promise<OAuthTokenResponse> {
const params = new URLSearchParams({
grant_type: 'refresh_token',
refresh_token: refreshToken,
client_id: config.clientId!,
});
if (config.clientSecret) {
params.append('client_secret', config.clientSecret);
if (!config.clientId) {
throw new Error('Missing required clientId for token refresh');
}
if (config.scopes && config.scopes.length > 0) {
params.append('scope', config.scopes.join(' '));
}
if (config.audiences && config.audiences.length > 0) {
params.append('audience', config.audiences.join(' '));
}
// Add resource parameter for MCP OAuth spec compliance
// Only add if we have an MCP server URL (indicates MCP OAuth flow, not standard OAuth)
if (mcpServerUrl) {
try {
params.append(
'resource',
OAuthUtils.buildResourceParameter(mcpServerUrl),
);
} catch (error) {
debugLogger.warn(
`Could not add resource parameter: ${getErrorMessage(error)}`,
);
}
}
const response = await fetch(tokenUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json, application/x-www-form-urlencoded',
return refreshAccessTokenShared(
{
clientId: config.clientId,
clientSecret: config.clientSecret,
scopes: config.scopes,
audiences: config.audiences,
},
body: params.toString(),
});
const responseText = await response.text();
const contentType = response.headers.get('content-type') || '';
if (!response.ok) {
// Try to parse error from form-urlencoded response
let errorMessage: string | null = null;
try {
const errorParams = new URLSearchParams(responseText);
const error = errorParams.get('error');
const errorDescription = errorParams.get('error_description');
if (error) {
errorMessage = `Token refresh failed: ${error} - ${errorDescription || 'No description'}`;
}
} catch {
// Fall back to raw error
}
throw new Error(
errorMessage ||
`Token refresh failed: ${response.status} - ${responseText}`,
);
}
// Log unexpected content types for debugging
if (
!contentType.includes('application/json') &&
!contentType.includes('application/x-www-form-urlencoded')
) {
debugLogger.warn(
`Token refresh endpoint returned unexpected content-type: ${contentType}. ` +
`Expected application/json or application/x-www-form-urlencoded. ` +
`Will attempt to parse response.`,
);
}
// Try to parse as JSON first, fall back to form-urlencoded
try {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
return JSON.parse(responseText) as OAuthTokenResponse;
} catch {
// Parse form-urlencoded response
const tokenParams = new URLSearchParams(responseText);
const accessToken = tokenParams.get('access_token');
const tokenType = tokenParams.get('token_type') || 'Bearer';
const expiresIn = tokenParams.get('expires_in');
const refreshToken = tokenParams.get('refresh_token');
const scope = tokenParams.get('scope');
if (!accessToken) {
// Check for error in response
const error = tokenParams.get('error');
const errorDescription = tokenParams.get('error_description');
throw new Error(
`Token refresh failed: ${error || 'unknown_error'} - ${errorDescription || responseText}`,
);
}
return {
access_token: accessToken,
token_type: tokenType,
expires_in: expiresIn ? parseInt(expiresIn, 10) : undefined,
refresh_token: refreshToken || undefined,
scope: scope || undefined,
} as OAuthTokenResponse;
}
refreshToken,
tokenUrl,
this.buildResourceParam(mcpServerUrl),
);
}
/**
@@ -830,17 +366,14 @@ export class MCPOAuthProvider {
}
// Generate PKCE parameters
const pkceParams = this.generatePKCEParams();
const pkceParams = generatePKCEParams();
// Determine preferred port from redirectUri if available
const preferredPort = this.getPortFromUrl(config.redirectUri);
const preferredPort = getPortFromUrl(config.redirectUri);
// Start callback server first to allocate port
// This ensures we only create one server and eliminates race conditions
const callbackServer = this.startCallbackServer(
pkceParams.state,
preferredPort,
);
const callbackServer = startCallbackServer(pkceParams.state, preferredPort);
// Wait for server to start and get the allocated port
// We need this port for client registration and auth URL building
@@ -892,12 +425,24 @@ export class MCPOAuthProvider {
);
}
// Build flow config for shared utilities
const flowConfig: OAuthFlowConfig = {
clientId: config.clientId,
clientSecret: config.clientSecret,
authorizationUrl: config.authorizationUrl,
tokenUrl: config.tokenUrl,
scopes: config.scopes,
audiences: config.audiences,
redirectUri: config.redirectUri,
};
// Build authorization URL
const authUrl = this.buildAuthorizationUrl(
config,
const resource = this.buildResourceParam(mcpServerUrl);
const authUrl = buildAuthorizationUrl(
flowConfig,
pkceParams,
redirectPort,
mcpServerUrl,
resource,
);
const userConsent = await getConsentForOauth(
@@ -933,12 +478,12 @@ ${authUrl}
);
// Exchange code for tokens
const tokenResponse = await this.exchangeCodeForToken(
config,
const tokenResponse = await exchangeCodeForToken(
flowConfig,
code,
pkceParams.codeVerifier,
redirectPort,
mcpServerUrl,
resource,
);
// Convert to our token format