mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-26 21:14:35 -07:00
Fix oauth support for MCP servers (#10427)
This commit is contained in:
@@ -9,12 +9,16 @@ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { AuthProviderType, type Config } from '../config/config.js';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import type { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import {
|
||||
connectToMcpServer,
|
||||
createTransport,
|
||||
hasNetworkTransport,
|
||||
isEnabled,
|
||||
@@ -22,14 +26,30 @@ import {
|
||||
populateMcpServerCommand,
|
||||
} from './mcp-client.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import * as fs from 'node:fs';
|
||||
import * as os from 'node:os';
|
||||
import * as path from 'node:path';
|
||||
|
||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
||||
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
||||
vi.mock('@google/genai');
|
||||
vi.mock('../mcp/oauth-provider.js');
|
||||
vi.mock('../mcp/oauth-token-storage.js');
|
||||
vi.mock('../mcp/oauth-utils.js');
|
||||
|
||||
describe('mcp-client', () => {
|
||||
let workspaceContext: WorkspaceContext;
|
||||
let testWorkspace: string;
|
||||
|
||||
beforeEach(() => {
|
||||
// create a tmp dir for this test
|
||||
// Create a unique temporary directory for the workspace to avoid conflicts
|
||||
testWorkspace = fs.mkdtempSync(
|
||||
path.join(os.tmpdir(), 'gemini-agent-test-'),
|
||||
);
|
||||
workspaceContext = new WorkspaceContext(testWorkspace);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
@@ -70,7 +90,7 @@ describe('mcp-client', () => {
|
||||
},
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as WorkspaceContext,
|
||||
workspaceContext,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -133,7 +153,7 @@ describe('mcp-client', () => {
|
||||
},
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as WorkspaceContext,
|
||||
workspaceContext,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -173,7 +193,7 @@ describe('mcp-client', () => {
|
||||
},
|
||||
{} as ToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as WorkspaceContext,
|
||||
workspaceContext,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -214,7 +234,7 @@ describe('mcp-client', () => {
|
||||
},
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as WorkspaceContext,
|
||||
workspaceContext,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -262,7 +282,7 @@ describe('mcp-client', () => {
|
||||
},
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as WorkspaceContext,
|
||||
workspaceContext,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -521,3 +541,134 @@ describe('mcp-client', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('connectToMcpServer with OAuth', () => {
|
||||
let mockedClient: ClientLib.Client;
|
||||
let workspaceContext: WorkspaceContext;
|
||||
let testWorkspace: string;
|
||||
let mockAuthProvider: MCPOAuthProvider;
|
||||
let mockTokenStorage: MCPOAuthTokenStorage;
|
||||
|
||||
beforeEach(() => {
|
||||
mockedClient = {
|
||||
connect: vi.fn(),
|
||||
close: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
onclose: vi.fn(),
|
||||
notification: vi.fn(),
|
||||
} as unknown as ClientLib.Client;
|
||||
vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient);
|
||||
|
||||
testWorkspace = fs.mkdtempSync(
|
||||
path.join(os.tmpdir(), 'gemini-agent-test-'),
|
||||
);
|
||||
workspaceContext = new WorkspaceContext(testWorkspace);
|
||||
|
||||
vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
mockTokenStorage = {
|
||||
getCredentials: vi.fn().mockResolvedValue({ clientId: 'test-client' }),
|
||||
} as unknown as MCPOAuthTokenStorage;
|
||||
vi.mocked(MCPOAuthTokenStorage).mockReturnValue(mockTokenStorage);
|
||||
mockAuthProvider = {
|
||||
authenticate: vi.fn().mockResolvedValue(undefined),
|
||||
getValidToken: vi.fn().mockResolvedValue('test-access-token'),
|
||||
tokenStorage: mockTokenStorage,
|
||||
} as unknown as MCPOAuthProvider;
|
||||
vi.mocked(MCPOAuthProvider).mockReturnValue(mockAuthProvider);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should handle automatic OAuth flow on 401 with www-authenticate header', async () => {
|
||||
const serverUrl = 'http://test-server.com/';
|
||||
const authUrl = 'http://auth.example.com/auth';
|
||||
const tokenUrl = 'http://auth.example.com/token';
|
||||
const wwwAuthHeader = `Bearer realm="test", resource_metadata="http://test-server.com/.well-known/oauth-protected-resource"`;
|
||||
|
||||
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
||||
new Error(`401 Unauthorized\nwww-authenticate: ${wwwAuthHeader}`),
|
||||
);
|
||||
|
||||
vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({
|
||||
authorizationUrl: authUrl,
|
||||
tokenUrl,
|
||||
scopes: ['test-scope'],
|
||||
});
|
||||
|
||||
// We need this to be an any type because we dig into its private state.
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let capturedTransport: any;
|
||||
vi.mocked(mockedClient.connect).mockImplementationOnce(
|
||||
async (transport) => {
|
||||
capturedTransport = transport;
|
||||
return Promise.resolve();
|
||||
},
|
||||
);
|
||||
|
||||
const client = await connectToMcpServer(
|
||||
'test-server',
|
||||
{ httpUrl: serverUrl },
|
||||
false,
|
||||
workspaceContext,
|
||||
);
|
||||
|
||||
expect(client).toBe(mockedClient);
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||
|
||||
const authHeader =
|
||||
capturedTransport._requestInit?.headers?.['Authorization'];
|
||||
expect(authHeader).toBe('Bearer test-access-token');
|
||||
});
|
||||
|
||||
it('should discover oauth config if not in www-authenticate header', async () => {
|
||||
const serverUrl = 'http://test-server.com';
|
||||
const authUrl = 'http://auth.example.com/auth';
|
||||
const tokenUrl = 'http://auth.example.com/token';
|
||||
|
||||
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
||||
new Error('401 Unauthorized'),
|
||||
);
|
||||
|
||||
vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({
|
||||
authorizationUrl: authUrl,
|
||||
tokenUrl,
|
||||
scopes: ['test-scope'],
|
||||
});
|
||||
vi.mocked(mockAuthProvider.getValidToken).mockResolvedValue(
|
||||
'test-access-token-from-discovery',
|
||||
);
|
||||
|
||||
// We need this to be an any type because we dig into its private state.
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
let capturedTransport: any;
|
||||
vi.mocked(mockedClient.connect).mockImplementationOnce(
|
||||
async (transport) => {
|
||||
capturedTransport = transport;
|
||||
return Promise.resolve();
|
||||
},
|
||||
);
|
||||
|
||||
const client = await connectToMcpServer(
|
||||
'test-server',
|
||||
{ httpUrl: serverUrl },
|
||||
false,
|
||||
workspaceContext,
|
||||
);
|
||||
|
||||
expect(client).toBe(mockedClient);
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
|
||||
|
||||
const authHeader =
|
||||
capturedTransport._requestInit?.headers?.['Authorization'];
|
||||
expect(authHeader).toBe('Bearer test-access-token-from-discovery');
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user