mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 05:12:55 -07:00
fix(core): resolve PKCE length issue and stabilize OAuth redirect port (#16815)
This commit is contained in:
@@ -1006,6 +1006,154 @@ describe('MCPOAuthProvider', () => {
|
|||||||
|
|
||||||
global.setTimeout = originalSetTimeout;
|
global.setTimeout = originalSetTimeout;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should use port from redirectUri if provided', async () => {
|
||||||
|
const configWithPort: MCPOAuthConfig = {
|
||||||
|
...mockConfig,
|
||||||
|
redirectUri: 'http://localhost:12345/oauth/callback',
|
||||||
|
};
|
||||||
|
|
||||||
|
let callbackHandler: unknown;
|
||||||
|
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||||
|
callbackHandler = handler;
|
||||||
|
return mockHttpServer as unknown as http.Server;
|
||||||
|
});
|
||||||
|
|
||||||
|
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||||
|
callback?.();
|
||||||
|
setTimeout(() => {
|
||||||
|
const mockReq = {
|
||||||
|
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||||
|
};
|
||||||
|
const mockRes = {
|
||||||
|
writeHead: vi.fn(),
|
||||||
|
end: vi.fn(),
|
||||||
|
};
|
||||||
|
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||||
|
mockReq,
|
||||||
|
mockRes,
|
||||||
|
);
|
||||||
|
}, 10);
|
||||||
|
});
|
||||||
|
mockHttpServer.address.mockReturnValue({
|
||||||
|
port: 12345,
|
||||||
|
address: '127.0.0.1',
|
||||||
|
family: 'IPv4',
|
||||||
|
});
|
||||||
|
|
||||||
|
mockFetch.mockResolvedValueOnce(
|
||||||
|
createMockResponse({
|
||||||
|
ok: true,
|
||||||
|
contentType: 'application/json',
|
||||||
|
text: JSON.stringify(mockTokenResponse),
|
||||||
|
json: mockTokenResponse,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
const authProvider = new MCPOAuthProvider();
|
||||||
|
await authProvider.authenticate('test-server', configWithPort);
|
||||||
|
|
||||||
|
expect(mockHttpServer.listen).toHaveBeenCalledWith(
|
||||||
|
12345,
|
||||||
|
expect.any(Function),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should ignore invalid ports in redirectUri', async () => {
|
||||||
|
const configWithInvalidPort: MCPOAuthConfig = {
|
||||||
|
...mockConfig,
|
||||||
|
redirectUri: 'http://localhost:invalid/oauth/callback',
|
||||||
|
};
|
||||||
|
|
||||||
|
let callbackHandler: unknown;
|
||||||
|
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||||
|
callbackHandler = handler;
|
||||||
|
return mockHttpServer as unknown as http.Server;
|
||||||
|
});
|
||||||
|
|
||||||
|
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||||
|
callback?.();
|
||||||
|
setTimeout(() => {
|
||||||
|
const mockReq = {
|
||||||
|
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||||
|
};
|
||||||
|
const mockRes = {
|
||||||
|
writeHead: vi.fn(),
|
||||||
|
end: vi.fn(),
|
||||||
|
};
|
||||||
|
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||||
|
mockReq,
|
||||||
|
mockRes,
|
||||||
|
);
|
||||||
|
}, 10);
|
||||||
|
});
|
||||||
|
|
||||||
|
mockFetch.mockResolvedValueOnce(
|
||||||
|
createMockResponse({
|
||||||
|
ok: true,
|
||||||
|
contentType: 'application/json',
|
||||||
|
text: JSON.stringify(mockTokenResponse),
|
||||||
|
json: mockTokenResponse,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
const authProvider = new MCPOAuthProvider();
|
||||||
|
await authProvider.authenticate('test-server', configWithInvalidPort);
|
||||||
|
|
||||||
|
// Should be called with 0 (OS assigned) because the port was invalid
|
||||||
|
expect(mockHttpServer.listen).toHaveBeenCalledWith(
|
||||||
|
0,
|
||||||
|
expect.any(Function),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not default to privileged ports when redirectUri has no port', async () => {
|
||||||
|
const configNoPort: MCPOAuthConfig = {
|
||||||
|
...mockConfig,
|
||||||
|
redirectUri: 'http://localhost/oauth/callback',
|
||||||
|
};
|
||||||
|
|
||||||
|
let callbackHandler: unknown;
|
||||||
|
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||||
|
callbackHandler = handler;
|
||||||
|
return mockHttpServer as unknown as http.Server;
|
||||||
|
});
|
||||||
|
|
||||||
|
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||||
|
callback?.();
|
||||||
|
setTimeout(() => {
|
||||||
|
const mockReq = {
|
||||||
|
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||||
|
};
|
||||||
|
const mockRes = {
|
||||||
|
writeHead: vi.fn(),
|
||||||
|
end: vi.fn(),
|
||||||
|
};
|
||||||
|
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||||
|
mockReq,
|
||||||
|
mockRes,
|
||||||
|
);
|
||||||
|
}, 10);
|
||||||
|
});
|
||||||
|
|
||||||
|
mockFetch.mockResolvedValueOnce(
|
||||||
|
createMockResponse({
|
||||||
|
ok: true,
|
||||||
|
contentType: 'application/json',
|
||||||
|
text: JSON.stringify(mockTokenResponse),
|
||||||
|
json: mockTokenResponse,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
const authProvider = new MCPOAuthProvider();
|
||||||
|
await authProvider.authenticate('test-server', configNoPort);
|
||||||
|
|
||||||
|
// Should be called with 0 (OS assigned), not 80
|
||||||
|
expect(mockHttpServer.listen).toHaveBeenCalledWith(
|
||||||
|
0,
|
||||||
|
expect.any(Function),
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('refreshAccessToken', () => {
|
describe('refreshAccessToken', () => {
|
||||||
@@ -1286,7 +1434,7 @@ describe('MCPOAuthProvider', () => {
|
|||||||
const authProvider = new MCPOAuthProvider();
|
const authProvider = new MCPOAuthProvider();
|
||||||
await authProvider.authenticate('test-server', mockConfig);
|
await authProvider.authenticate('test-server', mockConfig);
|
||||||
|
|
||||||
expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier
|
expect(crypto.randomBytes).toHaveBeenCalledWith(64); // code verifier
|
||||||
expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state
|
expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state
|
||||||
expect(crypto.createHash).toHaveBeenCalledWith('sha256');
|
expect(crypto.createHash).toHaveBeenCalledWith('sha256');
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -245,7 +245,8 @@ export class MCPOAuthProvider {
|
|||||||
*/
|
*/
|
||||||
private generatePKCEParams(): PKCEParams {
|
private generatePKCEParams(): PKCEParams {
|
||||||
// Generate code verifier (43-128 characters)
|
// Generate code verifier (43-128 characters)
|
||||||
const codeVerifier = crypto.randomBytes(32).toString('base64url');
|
// using 64 bytes results in ~86 characters, safely above the minimum of 43
|
||||||
|
const codeVerifier = crypto.randomBytes(64).toString('base64url');
|
||||||
|
|
||||||
// Generate code challenge using SHA256
|
// Generate code challenge using SHA256
|
||||||
const codeChallenge = crypto
|
const codeChallenge = crypto
|
||||||
@@ -266,7 +267,10 @@ export class MCPOAuthProvider {
|
|||||||
* @param expectedState The state parameter to validate
|
* @param expectedState The state parameter to validate
|
||||||
* @returns Object containing the port (available immediately) and a promise for the auth response
|
* @returns Object containing the port (available immediately) and a promise for the auth response
|
||||||
*/
|
*/
|
||||||
private startCallbackServer(expectedState: string): {
|
private startCallbackServer(
|
||||||
|
expectedState: string,
|
||||||
|
port?: number,
|
||||||
|
): {
|
||||||
port: Promise<number>;
|
port: Promise<number>;
|
||||||
response: Promise<OAuthAuthorizationResponse>;
|
response: Promise<OAuthAuthorizationResponse>;
|
||||||
} {
|
} {
|
||||||
@@ -353,9 +357,10 @@ export class MCPOAuthProvider {
|
|||||||
reject(error);
|
reject(error);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Determine which port to use (env var or OS-assigned)
|
// Determine which port to use (env var, argument, or OS-assigned)
|
||||||
const portStr = process.env['OAUTH_CALLBACK_PORT'];
|
|
||||||
let listenPort = 0; // Default to OS-assigned port
|
let listenPort = 0; // Default to OS-assigned port
|
||||||
|
|
||||||
|
const portStr = process.env['OAUTH_CALLBACK_PORT'];
|
||||||
if (portStr) {
|
if (portStr) {
|
||||||
const envPort = parseInt(portStr, 10);
|
const envPort = parseInt(portStr, 10);
|
||||||
if (isNaN(envPort) || envPort <= 0 || envPort > 65535) {
|
if (isNaN(envPort) || envPort <= 0 || envPort > 65535) {
|
||||||
@@ -367,6 +372,8 @@ export class MCPOAuthProvider {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
listenPort = envPort;
|
listenPort = envPort;
|
||||||
|
} else if (port !== undefined) {
|
||||||
|
listenPort = port;
|
||||||
}
|
}
|
||||||
|
|
||||||
server.listen(listenPort, () => {
|
server.listen(listenPort, () => {
|
||||||
@@ -393,7 +400,34 @@ export class MCPOAuthProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build the authorization URL with PKCE parameters.
|
* Extract the port number from a URL string if available and valid.
|
||||||
|
*
|
||||||
|
* @param urlString The URL string to parse
|
||||||
|
* @returns The port number or undefined if not found or invalid
|
||||||
|
*/
|
||||||
|
private getPortFromUrl(urlString?: string): number | undefined {
|
||||||
|
if (!urlString) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const url = new URL(urlString);
|
||||||
|
if (url.port) {
|
||||||
|
const parsedPort = parseInt(url.port, 10);
|
||||||
|
if (!isNaN(parsedPort) && parsedPort > 0 && parsedPort <= 65535) {
|
||||||
|
return parsedPort;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Ignore invalid URL
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build the authorization URL for the OAuth flow.
|
||||||
|
|
||||||
*
|
*
|
||||||
* @param config OAuth configuration
|
* @param config OAuth configuration
|
||||||
* @param pkceParams PKCE parameters
|
* @param pkceParams PKCE parameters
|
||||||
@@ -798,9 +832,15 @@ export class MCPOAuthProvider {
|
|||||||
// Generate PKCE parameters
|
// Generate PKCE parameters
|
||||||
const pkceParams = this.generatePKCEParams();
|
const pkceParams = this.generatePKCEParams();
|
||||||
|
|
||||||
|
// Determine preferred port from redirectUri if available
|
||||||
|
const preferredPort = this.getPortFromUrl(config.redirectUri);
|
||||||
|
|
||||||
// Start callback server first to allocate port
|
// Start callback server first to allocate port
|
||||||
// This ensures we only create one server and eliminates race conditions
|
// This ensures we only create one server and eliminates race conditions
|
||||||
const callbackServer = this.startCallbackServer(pkceParams.state);
|
const callbackServer = this.startCallbackServer(
|
||||||
|
pkceParams.state,
|
||||||
|
preferredPort,
|
||||||
|
);
|
||||||
|
|
||||||
// Wait for server to start and get the allocated port
|
// Wait for server to start and get the allocated port
|
||||||
// We need this port for client registration and auth URL building
|
// We need this port for client registration and auth URL building
|
||||||
|
|||||||
Reference in New Issue
Block a user