From 3f79d7e5bbb3bdd5364cd72cc2737dee368de5fe Mon Sep 17 00:00:00 2001 From: Jacob MacDonald Date: Fri, 3 Oct 2025 09:23:55 -0700 Subject: [PATCH] Fix oauth support for MCP servers (#10427) --- packages/core/src/tools/mcp-client.test.ts | 165 ++++++++++++++++++++- packages/core/src/tools/mcp-client.ts | 77 +++++----- 2 files changed, 195 insertions(+), 47 deletions(-) diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 85e6ff1371..fe755db7bc 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -9,12 +9,16 @@ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; -import { afterEach, describe, expect, it, vi } from 'vitest'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { AuthProviderType, type Config } from '../config/config.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; +import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; +import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; +import { OAuthUtils } from '../mcp/oauth-utils.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js'; -import type { WorkspaceContext } from '../utils/workspaceContext.js'; +import { WorkspaceContext } from '../utils/workspaceContext.js'; import { + connectToMcpServer, createTransport, hasNetworkTransport, isEnabled, @@ -22,14 +26,30 @@ import { populateMcpServerCommand, } from './mcp-client.js'; import type { ToolRegistry } from './tool-registry.js'; +import * as fs from 'node:fs'; +import * as os from 'node:os'; +import * as path from 'node:path'; vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js'); vi.mock('@google/genai'); vi.mock('../mcp/oauth-provider.js'); vi.mock('../mcp/oauth-token-storage.js'); +vi.mock('../mcp/oauth-utils.js'); describe('mcp-client', () => { + let workspaceContext: WorkspaceContext; + let testWorkspace: string; + + beforeEach(() => { + // create a tmp dir for this test + // Create a unique temporary directory for the workspace to avoid conflicts + testWorkspace = fs.mkdtempSync( + path.join(os.tmpdir(), 'gemini-agent-test-'), + ); + workspaceContext = new WorkspaceContext(testWorkspace); + }); + afterEach(() => { vi.restoreAllMocks(); }); @@ -70,7 +90,7 @@ describe('mcp-client', () => { }, mockedToolRegistry, {} as PromptRegistry, - {} as WorkspaceContext, + workspaceContext, false, ); await client.connect(); @@ -133,7 +153,7 @@ describe('mcp-client', () => { }, mockedToolRegistry, {} as PromptRegistry, - {} as WorkspaceContext, + workspaceContext, false, ); await client.connect(); @@ -173,7 +193,7 @@ describe('mcp-client', () => { }, {} as ToolRegistry, {} as PromptRegistry, - {} as WorkspaceContext, + workspaceContext, false, ); await client.connect(); @@ -214,7 +234,7 @@ describe('mcp-client', () => { }, mockedToolRegistry, {} as PromptRegistry, - {} as WorkspaceContext, + workspaceContext, false, ); await client.connect(); @@ -262,7 +282,7 @@ describe('mcp-client', () => { }, mockedToolRegistry, {} as PromptRegistry, - {} as WorkspaceContext, + workspaceContext, false, ); await client.connect(); @@ -521,3 +541,134 @@ describe('mcp-client', () => { }); }); }); + +describe('connectToMcpServer with OAuth', () => { + let mockedClient: ClientLib.Client; + let workspaceContext: WorkspaceContext; + let testWorkspace: string; + let mockAuthProvider: MCPOAuthProvider; + let mockTokenStorage: MCPOAuthTokenStorage; + + beforeEach(() => { + mockedClient = { + connect: vi.fn(), + close: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + onclose: vi.fn(), + notification: vi.fn(), + } as unknown as ClientLib.Client; + vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient); + + testWorkspace = fs.mkdtempSync( + path.join(os.tmpdir(), 'gemini-agent-test-'), + ); + workspaceContext = new WorkspaceContext(testWorkspace); + + vi.spyOn(console, 'log').mockImplementation(() => {}); + vi.spyOn(console, 'warn').mockImplementation(() => {}); + vi.spyOn(console, 'error').mockImplementation(() => {}); + + mockTokenStorage = { + getCredentials: vi.fn().mockResolvedValue({ clientId: 'test-client' }), + } as unknown as MCPOAuthTokenStorage; + vi.mocked(MCPOAuthTokenStorage).mockReturnValue(mockTokenStorage); + mockAuthProvider = { + authenticate: vi.fn().mockResolvedValue(undefined), + getValidToken: vi.fn().mockResolvedValue('test-access-token'), + tokenStorage: mockTokenStorage, + } as unknown as MCPOAuthProvider; + vi.mocked(MCPOAuthProvider).mockReturnValue(mockAuthProvider); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should handle automatic OAuth flow on 401 with www-authenticate header', async () => { + const serverUrl = 'http://test-server.com/'; + const authUrl = 'http://auth.example.com/auth'; + const tokenUrl = 'http://auth.example.com/token'; + const wwwAuthHeader = `Bearer realm="test", resource_metadata="http://test-server.com/.well-known/oauth-protected-resource"`; + + vi.mocked(mockedClient.connect).mockRejectedValueOnce( + new Error(`401 Unauthorized\nwww-authenticate: ${wwwAuthHeader}`), + ); + + vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({ + authorizationUrl: authUrl, + tokenUrl, + scopes: ['test-scope'], + }); + + // We need this to be an any type because we dig into its private state. + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let capturedTransport: any; + vi.mocked(mockedClient.connect).mockImplementationOnce( + async (transport) => { + capturedTransport = transport; + return Promise.resolve(); + }, + ); + + const client = await connectToMcpServer( + 'test-server', + { httpUrl: serverUrl }, + false, + workspaceContext, + ); + + expect(client).toBe(mockedClient); + expect(mockedClient.connect).toHaveBeenCalledTimes(2); + expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); + + const authHeader = + capturedTransport._requestInit?.headers?.['Authorization']; + expect(authHeader).toBe('Bearer test-access-token'); + }); + + it('should discover oauth config if not in www-authenticate header', async () => { + const serverUrl = 'http://test-server.com'; + const authUrl = 'http://auth.example.com/auth'; + const tokenUrl = 'http://auth.example.com/token'; + + vi.mocked(mockedClient.connect).mockRejectedValueOnce( + new Error('401 Unauthorized'), + ); + + vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({ + authorizationUrl: authUrl, + tokenUrl, + scopes: ['test-scope'], + }); + vi.mocked(mockAuthProvider.getValidToken).mockResolvedValue( + 'test-access-token-from-discovery', + ); + + // We need this to be an any type because we dig into its private state. + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let capturedTransport: any; + vi.mocked(mockedClient.connect).mockImplementationOnce( + async (transport) => { + capturedTransport = transport; + return Promise.resolve(); + }, + ); + + const client = await connectToMcpServer( + 'test-server', + { httpUrl: serverUrl }, + false, + workspaceContext, + ); + + expect(client).toBe(mockedClient); + expect(mockedClient.connect).toHaveBeenCalledTimes(2); + expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); + expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl); + + const authHeader = + capturedTransport._requestInit?.headers?.['Authorization']; + expect(authHeader).toBe('Bearer test-access-token-from-discovery'); + }); +}); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index ca25bde453..dc49f9386b 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -55,6 +55,8 @@ export type DiscoveredMCPPrompt = Prompt & { export enum MCPServerStatus { /** Server is disconnected or experiencing errors */ DISCONNECTED = 'disconnected', + /** Server is actively disconnecting */ + DISCONNECTING = 'disconnecting', /** Server is in the process of connecting */ CONNECTING = 'connecting', /** Server is connected and ready to use */ @@ -80,10 +82,9 @@ export enum MCPDiscoveryState { * managing the state of a single MCP server. */ export class McpClient { - private client: Client; + private client: Client | undefined; private transport: Transport | undefined; private status: MCPServerStatus = MCPServerStatus.DISCONNECTED; - private isDisconnecting = false; constructor( private readonly serverName: string, @@ -92,51 +93,34 @@ export class McpClient { private readonly promptRegistry: PromptRegistry, private readonly workspaceContext: WorkspaceContext, private readonly debugMode: boolean, - ) { - this.client = new Client({ - name: `gemini-cli-mcp-client-${this.serverName}`, - version: '0.0.1', - }); - } + ) {} /** * Connects to the MCP server. */ async connect(): Promise { - this.isDisconnecting = false; + if (this.status !== MCPServerStatus.DISCONNECTED) { + throw new Error( + `Can only connect when the client is disconnected, current state is ${this.status}`, + ); + } this.updateStatus(MCPServerStatus.CONNECTING); try { - this.transport = await this.createTransport(); - + this.client = await connectToMcpServer( + this.serverName, + this.serverConfig, + this.debugMode, + this.workspaceContext, + ); + const originalOnError = this.client.onerror; this.client.onerror = (error) => { - if (this.isDisconnecting) { + if (this.status !== MCPServerStatus.CONNECTED) { return; } + if (originalOnError) originalOnError(error); console.error(`MCP ERROR (${this.serverName}):`, error.toString()); this.updateStatus(MCPServerStatus.DISCONNECTED); }; - - this.client.registerCapabilities({ - roots: {}, - }); - - this.client.setRequestHandler(ListRootsRequestSchema, async () => { - const roots = []; - for (const dir of this.workspaceContext.getDirectories()) { - roots.push({ - uri: pathToFileURL(dir).toString(), - name: basename(dir), - }); - } - return { - roots, - }; - }); - - await this.client.connect(this.transport, { - timeout: this.serverConfig.timeout, - }); - this.updateStatus(MCPServerStatus.CONNECTED); } catch (error) { this.updateStatus(MCPServerStatus.DISCONNECTED); @@ -168,11 +152,18 @@ export class McpClient { * Disconnects from the MCP server. */ async disconnect(): Promise { - this.isDisconnecting = true; + if (this.status !== MCPServerStatus.CONNECTED) { + return; + } + this.updateStatus(MCPServerStatus.DISCONNECTING); + const client = this.client; + this.client = undefined; if (this.transport) { await this.transport.close(); } - this.client.close(); + if (client) { + await client.close(); + } this.updateStatus(MCPServerStatus.DISCONNECTED); } @@ -188,21 +179,27 @@ export class McpClient { updateMCPServerStatus(this.serverName, status); } - private async createTransport(): Promise { - return createTransport(this.serverName, this.serverConfig, this.debugMode); + private assertConnected(): void { + if (this.status !== MCPServerStatus.CONNECTED) { + throw new Error( + `Client is not connected, must connect before interacting with the server. Current state is ${this.status}`, + ); + } } private async discoverTools(cliConfig: Config): Promise { + this.assertConnected(); return discoverTools( this.serverName, this.serverConfig, - this.client, + this.client!, cliConfig, ); } private async discoverPrompts(): Promise { - return discoverPrompts(this.serverName, this.client, this.promptRegistry); + this.assertConnected(); + return discoverPrompts(this.serverName, this.client!, this.promptRegistry); } }