mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-27 21:44:25 -07:00
Fix oauth support for MCP servers (#10427)
This commit is contained in:
@@ -9,12 +9,16 @@ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
|||||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||||
import { AuthProviderType, type Config } from '../config/config.js';
|
import { AuthProviderType, type Config } from '../config/config.js';
|
||||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||||
|
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||||
|
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||||
|
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||||
import type { WorkspaceContext } from '../utils/workspaceContext.js';
|
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||||
import {
|
import {
|
||||||
|
connectToMcpServer,
|
||||||
createTransport,
|
createTransport,
|
||||||
hasNetworkTransport,
|
hasNetworkTransport,
|
||||||
isEnabled,
|
isEnabled,
|
||||||
@@ -22,14 +26,30 @@ import {
|
|||||||
populateMcpServerCommand,
|
populateMcpServerCommand,
|
||||||
} from './mcp-client.js';
|
} from './mcp-client.js';
|
||||||
import type { ToolRegistry } from './tool-registry.js';
|
import type { ToolRegistry } from './tool-registry.js';
|
||||||
|
import * as fs from 'node:fs';
|
||||||
|
import * as os from 'node:os';
|
||||||
|
import * as path from 'node:path';
|
||||||
|
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
||||||
vi.mock('@google/genai');
|
vi.mock('@google/genai');
|
||||||
vi.mock('../mcp/oauth-provider.js');
|
vi.mock('../mcp/oauth-provider.js');
|
||||||
vi.mock('../mcp/oauth-token-storage.js');
|
vi.mock('../mcp/oauth-token-storage.js');
|
||||||
|
vi.mock('../mcp/oauth-utils.js');
|
||||||
|
|
||||||
describe('mcp-client', () => {
|
describe('mcp-client', () => {
|
||||||
|
let workspaceContext: WorkspaceContext;
|
||||||
|
let testWorkspace: string;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
// create a tmp dir for this test
|
||||||
|
// Create a unique temporary directory for the workspace to avoid conflicts
|
||||||
|
testWorkspace = fs.mkdtempSync(
|
||||||
|
path.join(os.tmpdir(), 'gemini-agent-test-'),
|
||||||
|
);
|
||||||
|
workspaceContext = new WorkspaceContext(testWorkspace);
|
||||||
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
vi.restoreAllMocks();
|
vi.restoreAllMocks();
|
||||||
});
|
});
|
||||||
@@ -70,7 +90,7 @@ describe('mcp-client', () => {
|
|||||||
},
|
},
|
||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
{} as WorkspaceContext,
|
workspaceContext,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -133,7 +153,7 @@ describe('mcp-client', () => {
|
|||||||
},
|
},
|
||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
{} as WorkspaceContext,
|
workspaceContext,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -173,7 +193,7 @@ describe('mcp-client', () => {
|
|||||||
},
|
},
|
||||||
{} as ToolRegistry,
|
{} as ToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
{} as WorkspaceContext,
|
workspaceContext,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -214,7 +234,7 @@ describe('mcp-client', () => {
|
|||||||
},
|
},
|
||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
{} as WorkspaceContext,
|
workspaceContext,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -262,7 +282,7 @@ describe('mcp-client', () => {
|
|||||||
},
|
},
|
||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
{} as WorkspaceContext,
|
workspaceContext,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -521,3 +541,134 @@ describe('mcp-client', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('connectToMcpServer with OAuth', () => {
|
||||||
|
let mockedClient: ClientLib.Client;
|
||||||
|
let workspaceContext: WorkspaceContext;
|
||||||
|
let testWorkspace: string;
|
||||||
|
let mockAuthProvider: MCPOAuthProvider;
|
||||||
|
let mockTokenStorage: MCPOAuthTokenStorage;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockedClient = {
|
||||||
|
connect: vi.fn(),
|
||||||
|
close: vi.fn(),
|
||||||
|
registerCapabilities: vi.fn(),
|
||||||
|
setRequestHandler: vi.fn(),
|
||||||
|
onclose: vi.fn(),
|
||||||
|
notification: vi.fn(),
|
||||||
|
} as unknown as ClientLib.Client;
|
||||||
|
vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient);
|
||||||
|
|
||||||
|
testWorkspace = fs.mkdtempSync(
|
||||||
|
path.join(os.tmpdir(), 'gemini-agent-test-'),
|
||||||
|
);
|
||||||
|
workspaceContext = new WorkspaceContext(testWorkspace);
|
||||||
|
|
||||||
|
vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||||
|
vi.spyOn(console, 'warn').mockImplementation(() => {});
|
||||||
|
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||||
|
|
||||||
|
mockTokenStorage = {
|
||||||
|
getCredentials: vi.fn().mockResolvedValue({ clientId: 'test-client' }),
|
||||||
|
} as unknown as MCPOAuthTokenStorage;
|
||||||
|
vi.mocked(MCPOAuthTokenStorage).mockReturnValue(mockTokenStorage);
|
||||||
|
mockAuthProvider = {
|
||||||
|
authenticate: vi.fn().mockResolvedValue(undefined),
|
||||||
|
getValidToken: vi.fn().mockResolvedValue('test-access-token'),
|
||||||
|
tokenStorage: mockTokenStorage,
|
||||||
|
} as unknown as MCPOAuthProvider;
|
||||||
|
vi.mocked(MCPOAuthProvider).mockReturnValue(mockAuthProvider);
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle automatic OAuth flow on 401 with www-authenticate header', async () => {
|
||||||
|
const serverUrl = 'http://test-server.com/';
|
||||||
|
const authUrl = 'http://auth.example.com/auth';
|
||||||
|
const tokenUrl = 'http://auth.example.com/token';
|
||||||
|
const wwwAuthHeader = `Bearer realm="test", resource_metadata="http://test-server.com/.well-known/oauth-protected-resource"`;
|
||||||
|
|
||||||
|
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
||||||
|
new Error(`401 Unauthorized\nwww-authenticate: ${wwwAuthHeader}`),
|
||||||
|
);
|
||||||
|
|
||||||
|
vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({
|
||||||
|
authorizationUrl: authUrl,
|
||||||
|
tokenUrl,
|
||||||
|
scopes: ['test-scope'],
|
||||||
|
});
|
||||||
|
|
||||||
|
// We need this to be an any type because we dig into its private state.
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
let capturedTransport: any;
|
||||||
|
vi.mocked(mockedClient.connect).mockImplementationOnce(
|
||||||
|
async (transport) => {
|
||||||
|
capturedTransport = transport;
|
||||||
|
return Promise.resolve();
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
const client = await connectToMcpServer(
|
||||||
|
'test-server',
|
||||||
|
{ httpUrl: serverUrl },
|
||||||
|
false,
|
||||||
|
workspaceContext,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(client).toBe(mockedClient);
|
||||||
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||||
|
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||||
|
|
||||||
|
const authHeader =
|
||||||
|
capturedTransport._requestInit?.headers?.['Authorization'];
|
||||||
|
expect(authHeader).toBe('Bearer test-access-token');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should discover oauth config if not in www-authenticate header', async () => {
|
||||||
|
const serverUrl = 'http://test-server.com';
|
||||||
|
const authUrl = 'http://auth.example.com/auth';
|
||||||
|
const tokenUrl = 'http://auth.example.com/token';
|
||||||
|
|
||||||
|
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
||||||
|
new Error('401 Unauthorized'),
|
||||||
|
);
|
||||||
|
|
||||||
|
vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({
|
||||||
|
authorizationUrl: authUrl,
|
||||||
|
tokenUrl,
|
||||||
|
scopes: ['test-scope'],
|
||||||
|
});
|
||||||
|
vi.mocked(mockAuthProvider.getValidToken).mockResolvedValue(
|
||||||
|
'test-access-token-from-discovery',
|
||||||
|
);
|
||||||
|
|
||||||
|
// We need this to be an any type because we dig into its private state.
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
let capturedTransport: any;
|
||||||
|
vi.mocked(mockedClient.connect).mockImplementationOnce(
|
||||||
|
async (transport) => {
|
||||||
|
capturedTransport = transport;
|
||||||
|
return Promise.resolve();
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
const client = await connectToMcpServer(
|
||||||
|
'test-server',
|
||||||
|
{ httpUrl: serverUrl },
|
||||||
|
false,
|
||||||
|
workspaceContext,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(client).toBe(mockedClient);
|
||||||
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||||
|
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||||
|
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
|
||||||
|
|
||||||
|
const authHeader =
|
||||||
|
capturedTransport._requestInit?.headers?.['Authorization'];
|
||||||
|
expect(authHeader).toBe('Bearer test-access-token-from-discovery');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|||||||
@@ -55,6 +55,8 @@ export type DiscoveredMCPPrompt = Prompt & {
|
|||||||
export enum MCPServerStatus {
|
export enum MCPServerStatus {
|
||||||
/** Server is disconnected or experiencing errors */
|
/** Server is disconnected or experiencing errors */
|
||||||
DISCONNECTED = 'disconnected',
|
DISCONNECTED = 'disconnected',
|
||||||
|
/** Server is actively disconnecting */
|
||||||
|
DISCONNECTING = 'disconnecting',
|
||||||
/** Server is in the process of connecting */
|
/** Server is in the process of connecting */
|
||||||
CONNECTING = 'connecting',
|
CONNECTING = 'connecting',
|
||||||
/** Server is connected and ready to use */
|
/** Server is connected and ready to use */
|
||||||
@@ -80,10 +82,9 @@ export enum MCPDiscoveryState {
|
|||||||
* managing the state of a single MCP server.
|
* managing the state of a single MCP server.
|
||||||
*/
|
*/
|
||||||
export class McpClient {
|
export class McpClient {
|
||||||
private client: Client;
|
private client: Client | undefined;
|
||||||
private transport: Transport | undefined;
|
private transport: Transport | undefined;
|
||||||
private status: MCPServerStatus = MCPServerStatus.DISCONNECTED;
|
private status: MCPServerStatus = MCPServerStatus.DISCONNECTED;
|
||||||
private isDisconnecting = false;
|
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
private readonly serverName: string,
|
private readonly serverName: string,
|
||||||
@@ -92,51 +93,34 @@ export class McpClient {
|
|||||||
private readonly promptRegistry: PromptRegistry,
|
private readonly promptRegistry: PromptRegistry,
|
||||||
private readonly workspaceContext: WorkspaceContext,
|
private readonly workspaceContext: WorkspaceContext,
|
||||||
private readonly debugMode: boolean,
|
private readonly debugMode: boolean,
|
||||||
) {
|
) {}
|
||||||
this.client = new Client({
|
|
||||||
name: `gemini-cli-mcp-client-${this.serverName}`,
|
|
||||||
version: '0.0.1',
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Connects to the MCP server.
|
* Connects to the MCP server.
|
||||||
*/
|
*/
|
||||||
async connect(): Promise<void> {
|
async connect(): Promise<void> {
|
||||||
this.isDisconnecting = false;
|
if (this.status !== MCPServerStatus.DISCONNECTED) {
|
||||||
|
throw new Error(
|
||||||
|
`Can only connect when the client is disconnected, current state is ${this.status}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
this.updateStatus(MCPServerStatus.CONNECTING);
|
this.updateStatus(MCPServerStatus.CONNECTING);
|
||||||
try {
|
try {
|
||||||
this.transport = await this.createTransport();
|
this.client = await connectToMcpServer(
|
||||||
|
this.serverName,
|
||||||
|
this.serverConfig,
|
||||||
|
this.debugMode,
|
||||||
|
this.workspaceContext,
|
||||||
|
);
|
||||||
|
const originalOnError = this.client.onerror;
|
||||||
this.client.onerror = (error) => {
|
this.client.onerror = (error) => {
|
||||||
if (this.isDisconnecting) {
|
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (originalOnError) originalOnError(error);
|
||||||
console.error(`MCP ERROR (${this.serverName}):`, error.toString());
|
console.error(`MCP ERROR (${this.serverName}):`, error.toString());
|
||||||
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||||
};
|
};
|
||||||
|
|
||||||
this.client.registerCapabilities({
|
|
||||||
roots: {},
|
|
||||||
});
|
|
||||||
|
|
||||||
this.client.setRequestHandler(ListRootsRequestSchema, async () => {
|
|
||||||
const roots = [];
|
|
||||||
for (const dir of this.workspaceContext.getDirectories()) {
|
|
||||||
roots.push({
|
|
||||||
uri: pathToFileURL(dir).toString(),
|
|
||||||
name: basename(dir),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
roots,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
await this.client.connect(this.transport, {
|
|
||||||
timeout: this.serverConfig.timeout,
|
|
||||||
});
|
|
||||||
|
|
||||||
this.updateStatus(MCPServerStatus.CONNECTED);
|
this.updateStatus(MCPServerStatus.CONNECTED);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||||
@@ -168,11 +152,18 @@ export class McpClient {
|
|||||||
* Disconnects from the MCP server.
|
* Disconnects from the MCP server.
|
||||||
*/
|
*/
|
||||||
async disconnect(): Promise<void> {
|
async disconnect(): Promise<void> {
|
||||||
this.isDisconnecting = true;
|
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this.updateStatus(MCPServerStatus.DISCONNECTING);
|
||||||
|
const client = this.client;
|
||||||
|
this.client = undefined;
|
||||||
if (this.transport) {
|
if (this.transport) {
|
||||||
await this.transport.close();
|
await this.transport.close();
|
||||||
}
|
}
|
||||||
this.client.close();
|
if (client) {
|
||||||
|
await client.close();
|
||||||
|
}
|
||||||
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,21 +179,27 @@ export class McpClient {
|
|||||||
updateMCPServerStatus(this.serverName, status);
|
updateMCPServerStatus(this.serverName, status);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async createTransport(): Promise<Transport> {
|
private assertConnected(): void {
|
||||||
return createTransport(this.serverName, this.serverConfig, this.debugMode);
|
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||||
|
throw new Error(
|
||||||
|
`Client is not connected, must connect before interacting with the server. Current state is ${this.status}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async discoverTools(cliConfig: Config): Promise<DiscoveredMCPTool[]> {
|
private async discoverTools(cliConfig: Config): Promise<DiscoveredMCPTool[]> {
|
||||||
|
this.assertConnected();
|
||||||
return discoverTools(
|
return discoverTools(
|
||||||
this.serverName,
|
this.serverName,
|
||||||
this.serverConfig,
|
this.serverConfig,
|
||||||
this.client,
|
this.client!,
|
||||||
cliConfig,
|
cliConfig,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async discoverPrompts(): Promise<Prompt[]> {
|
private async discoverPrompts(): Promise<Prompt[]> {
|
||||||
return discoverPrompts(this.serverName, this.client, this.promptRegistry);
|
this.assertConnected();
|
||||||
|
return discoverPrompts(this.serverName, this.client!, this.promptRegistry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user