mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
feat: consolidate remote MCP servers to use url in config (#13762)
This commit is contained in:
@@ -107,6 +107,7 @@ describe('mcp add command', () => {
|
||||
expect(mockSetValue).toHaveBeenCalledWith(SettingScope.User, 'mcpServers', {
|
||||
'sse-server': {
|
||||
url: 'https://example.com/sse-endpoint',
|
||||
type: 'sse',
|
||||
headers: { 'X-API-Key': 'your-key' },
|
||||
},
|
||||
});
|
||||
@@ -122,7 +123,8 @@ describe('mcp add command', () => {
|
||||
'mcpServers',
|
||||
{
|
||||
'http-server': {
|
||||
httpUrl: 'https://example.com/mcp',
|
||||
url: 'https://example.com/mcp',
|
||||
type: 'http',
|
||||
headers: { Authorization: 'Bearer your-token' },
|
||||
},
|
||||
},
|
||||
|
||||
@@ -69,6 +69,7 @@ async function addMcpServer(
|
||||
case 'sse':
|
||||
newServer = {
|
||||
url: commandOrUrl,
|
||||
type: 'sse',
|
||||
headers,
|
||||
timeout,
|
||||
trust,
|
||||
@@ -79,7 +80,8 @@ async function addMcpServer(
|
||||
break;
|
||||
case 'http':
|
||||
newServer = {
|
||||
httpUrl: commandOrUrl,
|
||||
url: commandOrUrl,
|
||||
type: 'http',
|
||||
headers,
|
||||
timeout,
|
||||
trust,
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
MCPServerStatus,
|
||||
getErrorMessage,
|
||||
MCPOAuthTokenStorage,
|
||||
mcpServerRequiresOAuth,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { appEvents, AppEvent } from '../../utils/events.js';
|
||||
import { MessageType, type HistoryItemMcpStatus } from '../types.js';
|
||||
@@ -47,12 +48,23 @@ const authCommand: SlashCommand = {
|
||||
const mcpServers = config.getMcpClientManager()?.getMcpServers() ?? {};
|
||||
|
||||
if (!serverName) {
|
||||
// List servers that support OAuth
|
||||
const oauthServers = Object.entries(mcpServers)
|
||||
// List servers that support OAuth from two sources:
|
||||
// 1. Servers with oauth.enabled in config
|
||||
// 2. Servers detected as requiring OAuth (returned 401)
|
||||
const configuredOAuthServers = Object.entries(mcpServers)
|
||||
.filter(([_, server]) => server.oauth?.enabled)
|
||||
.map(([name, _]) => name);
|
||||
|
||||
if (oauthServers.length === 0) {
|
||||
const detectedOAuthServers = Array.from(
|
||||
mcpServerRequiresOAuth.keys(),
|
||||
).filter((name) => mcpServers[name]); // Only include configured servers
|
||||
|
||||
// Combine and deduplicate
|
||||
const allOAuthServers = [
|
||||
...new Set([...configuredOAuthServers, ...detectedOAuthServers]),
|
||||
];
|
||||
|
||||
if (allOAuthServers.length === 0) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
@@ -63,7 +75,7 @@ const authCommand: SlashCommand = {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: `MCP servers with OAuth authentication:\n${oauthServers.map((s) => ` - ${s}`).join('\n')}\n\nUse /mcp auth <server-name> to authenticate.`,
|
||||
content: `MCP servers with OAuth authentication:\n${allOAuthServers.map((s) => ` - ${s}`).join('\n')}\n\nUse /mcp auth <server-name> to authenticate.`,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -220,7 +232,8 @@ const listAction = async (
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
for (const serverName of serverNames) {
|
||||
const server = mcpServers[serverName];
|
||||
if (server.oauth?.enabled) {
|
||||
// Check auth status for servers with oauth.enabled OR detected as requiring OAuth
|
||||
if (server.oauth?.enabled || mcpServerRequiresOAuth.has(serverName)) {
|
||||
const creds = await tokenStorage.getCredentials(serverName);
|
||||
if (creds) {
|
||||
if (creds.token.expiresAt && creds.token.expiresAt < Date.now()) {
|
||||
|
||||
@@ -190,6 +190,12 @@ export class MCPServerConfig {
|
||||
readonly headers?: Record<string, string>,
|
||||
// For websocket transport
|
||||
readonly tcp?: string,
|
||||
// Transport type (optional, for use with 'url' field)
|
||||
// When set to 'http', uses StreamableHTTPClientTransport
|
||||
// When set to 'sse', uses SSEClientTransport
|
||||
// When omitted, auto-detects transport type
|
||||
// Note: 'httpUrl' is deprecated in favor of 'url' + 'type'
|
||||
readonly type?: 'sse' | 'http',
|
||||
// Common
|
||||
readonly timeout?: number,
|
||||
readonly trust?: boolean,
|
||||
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
MCPDiscoveryState,
|
||||
populateMcpServerCommand,
|
||||
} from './mcp-client.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { getErrorMessage, isAuthenticationError } from '../utils/errors.js';
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
@@ -186,14 +186,17 @@ export class McpClientManager {
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
} catch (error) {
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
// Log the error but don't let a single failed server stop the others
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error during discovery for server '${name}': ${getErrorMessage(
|
||||
// Check if this is a 401/auth error - if so, don't show as red error
|
||||
// (the info message was already shown in mcp-client.ts)
|
||||
if (!isAuthenticationError(error)) {
|
||||
// Log the error but don't let a single failed server stop the others
|
||||
const errorMessage = getErrorMessage(error);
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error during discovery for MCP server '${name}': ${errorMessage}`,
|
||||
error,
|
||||
)}`,
|
||||
error,
|
||||
);
|
||||
);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
// This is required to update the content generator configuration with the
|
||||
|
||||
@@ -7,7 +7,10 @@
|
||||
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import {
|
||||
StreamableHTTPClientTransport,
|
||||
StreamableHTTPError,
|
||||
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { AuthProviderType, type Config } from '../config/config.js';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
@@ -490,16 +493,14 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(transport).toHaveProperty(
|
||||
'_url',
|
||||
new URL('http://test-server/'),
|
||||
);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: { headers: {} },
|
||||
});
|
||||
});
|
||||
|
||||
it('with headers', async () => {
|
||||
// We need this to be an any type because we dig into its private state.
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const transport: any = await createTransport(
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
httpUrl: 'http://test-server',
|
||||
@@ -507,13 +508,14 @@ describe('mcp-client', () => {
|
||||
},
|
||||
false,
|
||||
);
|
||||
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(transport).toHaveProperty(
|
||||
'_url',
|
||||
new URL('http://test-server/'),
|
||||
);
|
||||
const authHeader = transport._requestInit?.headers?.['Authorization'];
|
||||
expect(authHeader).toBe('derp');
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: {
|
||||
headers: { Authorization: 'derp' },
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -526,17 +528,15 @@ describe('mcp-client', () => {
|
||||
},
|
||||
false,
|
||||
);
|
||||
expect(transport).toBeInstanceOf(SSEClientTransport);
|
||||
expect(transport).toHaveProperty(
|
||||
'_url',
|
||||
new URL('http://test-server/'),
|
||||
);
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: { headers: {} },
|
||||
});
|
||||
});
|
||||
|
||||
it('with headers', async () => {
|
||||
// We need this to be an any type because we dig into its private state.
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const transport: any = await createTransport(
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
@@ -544,13 +544,122 @@ describe('mcp-client', () => {
|
||||
},
|
||||
false,
|
||||
);
|
||||
expect(transport).toBeInstanceOf(SSEClientTransport);
|
||||
expect(transport).toHaveProperty(
|
||||
'_url',
|
||||
new URL('http://test-server/'),
|
||||
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: {
|
||||
headers: { Authorization: 'derp' },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('with type="http" creates StreamableHTTPClientTransport', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
type: 'http',
|
||||
},
|
||||
false,
|
||||
);
|
||||
const authHeader = transport._requestInit?.headers?.['Authorization'];
|
||||
expect(authHeader).toBe('derp');
|
||||
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: { headers: {} },
|
||||
});
|
||||
});
|
||||
|
||||
it('with type="sse" creates SSEClientTransport', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
type: 'sse',
|
||||
},
|
||||
false,
|
||||
);
|
||||
|
||||
expect(transport).toBeInstanceOf(SSEClientTransport);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: { headers: {} },
|
||||
});
|
||||
});
|
||||
|
||||
it('without type defaults to StreamableHTTPClientTransport', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
},
|
||||
false,
|
||||
);
|
||||
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: { headers: {} },
|
||||
});
|
||||
});
|
||||
|
||||
it('with type="http" and headers applies headers correctly', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
type: 'http',
|
||||
headers: { Authorization: 'Bearer token' },
|
||||
},
|
||||
false,
|
||||
);
|
||||
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: {
|
||||
headers: { Authorization: 'Bearer token' },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('with type="sse" and headers applies headers correctly', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
type: 'sse',
|
||||
headers: { 'X-API-Key': 'key123' },
|
||||
},
|
||||
false,
|
||||
);
|
||||
|
||||
expect(transport).toBeInstanceOf(SSEClientTransport);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: {
|
||||
headers: { 'X-API-Key': 'key123' },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('httpUrl takes priority over url when both are present', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
httpUrl: 'http://test-server-http',
|
||||
url: 'http://test-server-url',
|
||||
},
|
||||
false,
|
||||
);
|
||||
|
||||
// httpUrl should take priority and create HTTP transport
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server-http'),
|
||||
_requestInit: { headers: {} },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -680,6 +789,7 @@ describe('mcp-client', () => {
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test.googleapis.com',
|
||||
type: 'sse',
|
||||
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
||||
oauth: {
|
||||
scopes: ['scope1'],
|
||||
@@ -839,7 +949,10 @@ describe('connectToMcpServer with OAuth', () => {
|
||||
const wwwAuthHeader = `Bearer realm="test", resource_metadata="http://test-server.com/.well-known/oauth-protected-resource"`;
|
||||
|
||||
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
||||
new Error(`401 Unauthorized\nwww-authenticate: ${wwwAuthHeader}`),
|
||||
new StreamableHTTPError(
|
||||
401,
|
||||
`Unauthorized\nwww-authenticate: ${wwwAuthHeader}`,
|
||||
),
|
||||
);
|
||||
|
||||
vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({
|
||||
@@ -860,7 +973,7 @@ describe('connectToMcpServer with OAuth', () => {
|
||||
|
||||
const client = await connectToMcpServer(
|
||||
'test-server',
|
||||
{ httpUrl: serverUrl },
|
||||
{ httpUrl: serverUrl, oauth: { enabled: true } },
|
||||
false,
|
||||
workspaceContext,
|
||||
);
|
||||
@@ -880,7 +993,7 @@ describe('connectToMcpServer with OAuth', () => {
|
||||
const tokenUrl = 'http://auth.example.com/token';
|
||||
|
||||
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
||||
new Error('401 Unauthorized'),
|
||||
new StreamableHTTPError(401, 'Unauthorized'),
|
||||
);
|
||||
|
||||
vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({
|
||||
@@ -904,7 +1017,7 @@ describe('connectToMcpServer with OAuth', () => {
|
||||
|
||||
const client = await connectToMcpServer(
|
||||
'test-server',
|
||||
{ httpUrl: serverUrl },
|
||||
{ httpUrl: serverUrl, oauth: { enabled: true } },
|
||||
false,
|
||||
workspaceContext,
|
||||
);
|
||||
@@ -919,3 +1032,193 @@ describe('connectToMcpServer with OAuth', () => {
|
||||
expect(authHeader).toBe('Bearer test-access-token-from-discovery');
|
||||
});
|
||||
});
|
||||
|
||||
describe('connectToMcpServer - HTTP→SSE fallback', () => {
|
||||
let mockedClient: ClientLib.Client;
|
||||
let workspaceContext: WorkspaceContext;
|
||||
let testWorkspace: string;
|
||||
|
||||
beforeEach(() => {
|
||||
mockedClient = {
|
||||
connect: vi.fn(),
|
||||
close: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
onclose: vi.fn(),
|
||||
notification: vi.fn(),
|
||||
} as unknown as ClientLib.Client;
|
||||
vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient);
|
||||
|
||||
testWorkspace = fs.mkdtempSync(
|
||||
path.join(os.tmpdir(), 'gemini-agent-test-'),
|
||||
);
|
||||
workspaceContext = new WorkspaceContext(testWorkspace);
|
||||
|
||||
vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should NOT trigger fallback when type="http" is explicit', async () => {
|
||||
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
||||
new Error('Connection failed'),
|
||||
);
|
||||
|
||||
await expect(
|
||||
connectToMcpServer(
|
||||
'test-server',
|
||||
{ url: 'http://test-server', type: 'http' },
|
||||
false,
|
||||
workspaceContext,
|
||||
),
|
||||
).rejects.toThrow('Connection failed');
|
||||
|
||||
// Should only try once (no fallback)
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should NOT trigger fallback when type="sse" is explicit', async () => {
|
||||
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
||||
new Error('Connection failed'),
|
||||
);
|
||||
|
||||
await expect(
|
||||
connectToMcpServer(
|
||||
'test-server',
|
||||
{ url: 'http://test-server', type: 'sse' },
|
||||
false,
|
||||
workspaceContext,
|
||||
),
|
||||
).rejects.toThrow('Connection failed');
|
||||
|
||||
// Should only try once (no fallback)
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should trigger fallback when url provided without type and HTTP fails', async () => {
|
||||
vi.mocked(mockedClient.connect)
|
||||
.mockRejectedValueOnce(new StreamableHTTPError(500, 'Server error'))
|
||||
.mockResolvedValueOnce(undefined);
|
||||
|
||||
const client = await connectToMcpServer(
|
||||
'test-server',
|
||||
{ url: 'http://test-server' },
|
||||
false,
|
||||
workspaceContext,
|
||||
);
|
||||
|
||||
expect(client).toBe(mockedClient);
|
||||
// First HTTP attempt fails, second SSE attempt succeeds
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should throw original HTTP error when both HTTP and SSE fail (non-401)', async () => {
|
||||
const httpError = new StreamableHTTPError(500, 'Server error');
|
||||
const sseError = new Error('SSE connection failed');
|
||||
|
||||
vi.mocked(mockedClient.connect)
|
||||
.mockRejectedValueOnce(httpError)
|
||||
.mockRejectedValueOnce(sseError);
|
||||
|
||||
await expect(
|
||||
connectToMcpServer(
|
||||
'test-server',
|
||||
{ url: 'http://test-server' },
|
||||
false,
|
||||
workspaceContext,
|
||||
),
|
||||
).rejects.toThrow('Server error');
|
||||
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should handle HTTP 404 followed by SSE success', async () => {
|
||||
vi.mocked(mockedClient.connect)
|
||||
.mockRejectedValueOnce(new StreamableHTTPError(404, 'Not Found'))
|
||||
.mockResolvedValueOnce(undefined);
|
||||
|
||||
const client = await connectToMcpServer(
|
||||
'test-server',
|
||||
{ url: 'http://test-server' },
|
||||
false,
|
||||
workspaceContext,
|
||||
);
|
||||
|
||||
expect(client).toBe(mockedClient);
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('connectToMcpServer - OAuth with transport fallback', () => {
|
||||
let mockedClient: ClientLib.Client;
|
||||
let workspaceContext: WorkspaceContext;
|
||||
let testWorkspace: string;
|
||||
let mockAuthProvider: MCPOAuthProvider;
|
||||
let mockTokenStorage: MCPOAuthTokenStorage;
|
||||
|
||||
beforeEach(() => {
|
||||
mockedClient = {
|
||||
connect: vi.fn(),
|
||||
close: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
onclose: vi.fn(),
|
||||
notification: vi.fn(),
|
||||
} as unknown as ClientLib.Client;
|
||||
vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient);
|
||||
|
||||
testWorkspace = fs.mkdtempSync(
|
||||
path.join(os.tmpdir(), 'gemini-agent-test-'),
|
||||
);
|
||||
workspaceContext = new WorkspaceContext(testWorkspace);
|
||||
|
||||
vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
mockTokenStorage = {
|
||||
getCredentials: vi.fn().mockResolvedValue({ clientId: 'test-client' }),
|
||||
} as unknown as MCPOAuthTokenStorage;
|
||||
vi.mocked(MCPOAuthTokenStorage).mockReturnValue(mockTokenStorage);
|
||||
|
||||
mockAuthProvider = {
|
||||
authenticate: vi.fn().mockResolvedValue(undefined),
|
||||
getValidToken: vi.fn().mockResolvedValue('test-access-token'),
|
||||
tokenStorage: mockTokenStorage,
|
||||
} as unknown as MCPOAuthProvider;
|
||||
vi.mocked(MCPOAuthProvider).mockReturnValue(mockAuthProvider);
|
||||
|
||||
vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({
|
||||
authorizationUrl: 'http://auth.example.com/auth',
|
||||
tokenUrl: 'http://auth.example.com/token',
|
||||
scopes: ['test-scope'],
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should handle HTTP 404 → SSE 401 → OAuth → SSE+OAuth succeeds', async () => {
|
||||
// Tests that OAuth flow works when SSE (not HTTP) requires auth
|
||||
vi.mocked(mockedClient.connect)
|
||||
.mockRejectedValueOnce(new StreamableHTTPError(404, 'Not Found'))
|
||||
.mockRejectedValueOnce(new StreamableHTTPError(401, 'Unauthorized'))
|
||||
.mockResolvedValueOnce(undefined);
|
||||
|
||||
const client = await connectToMcpServer(
|
||||
'test-server',
|
||||
{ url: 'http://test-server', oauth: { enabled: true } },
|
||||
false,
|
||||
workspaceContext,
|
||||
);
|
||||
|
||||
expect(client).toBe(mockedClient);
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(3);
|
||||
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -40,7 +40,11 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import {
|
||||
getErrorMessage,
|
||||
isAuthenticationError,
|
||||
UnauthorizedError,
|
||||
} from '../utils/errors.js';
|
||||
import type {
|
||||
Unsubscribe,
|
||||
WorkspaceContext,
|
||||
@@ -443,33 +447,6 @@ function createAuthProvider(
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a transport for URL based servers (remote servers).
|
||||
*
|
||||
* @param mcpServerConfig The MCP server configuration
|
||||
* @param transportOptions The transport options
|
||||
*/
|
||||
function createUrlTransport(
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
transportOptions:
|
||||
| StreamableHTTPClientTransportOptions
|
||||
| SSEClientTransportOptions,
|
||||
): StreamableHTTPClientTransport | SSEClientTransport {
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
if (mcpServerConfig.url) {
|
||||
return new SSEClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
throw new Error('No URL configured for MCP Server');
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a transport with OAuth token for the given server configuration.
|
||||
*
|
||||
@@ -493,7 +470,7 @@ async function createTransportWithOAuth(
|
||||
requestInit: createTransportRequestInit(mcpServerConfig, headers),
|
||||
};
|
||||
|
||||
return createUrlTransport(mcpServerConfig, transportOptions);
|
||||
return createUrlTransport(mcpServerName, mcpServerConfig, transportOptions);
|
||||
} catch (error) {
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
@@ -921,6 +898,156 @@ export function hasNetworkTransport(config: MCPServerConfig): boolean {
|
||||
return !!(config.url || config.httpUrl);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to retrieve a stored OAuth token for an MCP server.
|
||||
* Handles token validation and refresh automatically.
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
* @returns The valid access token, or null if no token is stored
|
||||
*/
|
||||
async function getStoredOAuthToken(serverName: string): Promise<string | null> {
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(serverName);
|
||||
if (!credentials) return null;
|
||||
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
return authProvider.getValidToken(serverName, {
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to create an SSE transport with optional OAuth authentication.
|
||||
*
|
||||
* @param config The MCP server configuration
|
||||
* @param accessToken Optional OAuth access token for authentication
|
||||
* @returns A configured SSE transport ready for connection
|
||||
*/
|
||||
function createSSETransportWithAuth(
|
||||
config: MCPServerConfig,
|
||||
accessToken?: string | null,
|
||||
): SSEClientTransport {
|
||||
const headers = {
|
||||
...config.headers,
|
||||
...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
|
||||
};
|
||||
|
||||
const options: SSEClientTransportOptions = {};
|
||||
if (Object.keys(headers).length > 0) {
|
||||
options.requestInit = { headers };
|
||||
}
|
||||
|
||||
return new SSEClientTransport(new URL(config.url!), options);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to connect a client using SSE transport with optional OAuth.
|
||||
*
|
||||
* @param client The MCP client to connect
|
||||
* @param config The MCP server configuration
|
||||
* @param accessToken Optional OAuth access token for authentication
|
||||
*/
|
||||
async function connectWithSSETransport(
|
||||
client: Client,
|
||||
config: MCPServerConfig,
|
||||
accessToken?: string | null,
|
||||
): Promise<void> {
|
||||
const transport = createSSETransportWithAuth(config, accessToken);
|
||||
await client.connect(transport, {
|
||||
timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to show authentication required message and throw error.
|
||||
* Checks if there's a stored token that was rejected (requires re-auth).
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
* @throws Always throws an error with authentication instructions
|
||||
*/
|
||||
async function showAuthRequiredMessage(serverName: string): Promise<never> {
|
||||
const hasRejectedToken = !!(await getStoredOAuthToken(serverName));
|
||||
|
||||
const message = hasRejectedToken
|
||||
? `MCP server '${serverName}' rejected stored OAuth token. Please re-authenticate using: /mcp auth ${serverName}`
|
||||
: `MCP server '${serverName}' requires authentication using: /mcp auth ${serverName}`;
|
||||
|
||||
coreEvents.emitFeedback('info', message);
|
||||
throw new UnauthorizedError(message);
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to retry connection with OAuth token after authentication.
|
||||
* Handles both HTTP and SSE transports based on what previously failed.
|
||||
*
|
||||
* @param client The MCP client to connect
|
||||
* @param serverName The name of the MCP server
|
||||
* @param config The MCP server configuration
|
||||
* @param accessToken The OAuth access token to use
|
||||
* @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server)
|
||||
*/
|
||||
async function retryWithOAuth(
|
||||
client: Client,
|
||||
serverName: string,
|
||||
config: MCPServerConfig,
|
||||
accessToken: string,
|
||||
httpReturned404: boolean,
|
||||
): Promise<void> {
|
||||
if (httpReturned404) {
|
||||
// HTTP returned 404, only try SSE
|
||||
debugLogger.log(
|
||||
`Retrying SSE connection to '${serverName}' with OAuth token...`,
|
||||
);
|
||||
await connectWithSSETransport(client, config, accessToken);
|
||||
debugLogger.log(
|
||||
`Successfully connected to '${serverName}' using SSE with OAuth.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// HTTP returned 401, try HTTP with OAuth first
|
||||
debugLogger.log(`Retrying connection to '${serverName}' with OAuth token...`);
|
||||
|
||||
const httpTransport = await createTransportWithOAuth(
|
||||
serverName,
|
||||
config,
|
||||
accessToken,
|
||||
);
|
||||
if (!httpTransport) {
|
||||
throw new Error(
|
||||
`Failed to create OAuth transport for server '${serverName}'`,
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
await client.connect(httpTransport, {
|
||||
timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
debugLogger.log(
|
||||
`Successfully connected to '${serverName}' using HTTP with OAuth.`,
|
||||
);
|
||||
} catch (httpError) {
|
||||
await httpTransport.close();
|
||||
|
||||
// If HTTP+OAuth returns 404 and auto-detection enabled, try SSE+OAuth
|
||||
if (
|
||||
String(httpError).includes('404') &&
|
||||
config.url &&
|
||||
!config.type &&
|
||||
!config.httpUrl
|
||||
) {
|
||||
debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`);
|
||||
await connectWithSSETransport(client, config, accessToken);
|
||||
debugLogger.log(
|
||||
`Successfully connected to '${serverName}' using SSE with OAuth.`,
|
||||
);
|
||||
} else {
|
||||
throw httpError;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates and connects an MCP client to a server based on the provided configuration.
|
||||
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
|
||||
@@ -993,6 +1120,10 @@ export async function connectToMcpServer(
|
||||
unlistenDirectories = undefined;
|
||||
};
|
||||
|
||||
let firstAttemptError: Error | null = null;
|
||||
let httpReturned404 = false; // Track if HTTP returned 404 to skip it in OAuth retry
|
||||
let sseError: Error | null = null; // Track SSE fallback error
|
||||
|
||||
try {
|
||||
const transport = await createTransport(
|
||||
mcpServerName,
|
||||
@@ -1006,52 +1137,79 @@ export async function connectToMcpServer(
|
||||
return mcpClient;
|
||||
} catch (error) {
|
||||
await transport.close();
|
||||
firstAttemptError = error as Error;
|
||||
throw error;
|
||||
}
|
||||
} catch (error) {
|
||||
} catch (initialError) {
|
||||
let error = initialError;
|
||||
|
||||
// Check if this is a 401 error FIRST (before attempting SSE fallback)
|
||||
// This ensures OAuth flow happens before we try SSE
|
||||
if (isAuthenticationError(error) && hasNetworkTransport(mcpServerConfig)) {
|
||||
// Continue to OAuth handling below (after SSE fallback section)
|
||||
} else if (
|
||||
// If not 401, and HTTP failed with url without explicit type, try SSE fallback
|
||||
firstAttemptError &&
|
||||
mcpServerConfig.url &&
|
||||
!mcpServerConfig.type &&
|
||||
!mcpServerConfig.httpUrl
|
||||
) {
|
||||
// Check if HTTP returned 404 - if so, we know it's not an HTTP server
|
||||
httpReturned404 = String(firstAttemptError).includes('404');
|
||||
|
||||
const logMessage = httpReturned404
|
||||
? `HTTP returned 404, trying SSE transport...`
|
||||
: `HTTP connection failed, attempting SSE fallback...`;
|
||||
debugLogger.log(`MCP server '${mcpServerName}': ${logMessage}`);
|
||||
|
||||
try {
|
||||
// Try SSE with stored OAuth token if available
|
||||
// This ensures that SSE fallback works for authenticated servers
|
||||
await connectWithSSETransport(
|
||||
mcpClient,
|
||||
mcpServerConfig,
|
||||
await getStoredOAuthToken(mcpServerName),
|
||||
);
|
||||
|
||||
debugLogger.log(
|
||||
`MCP server '${mcpServerName}': Successfully connected using SSE transport.`,
|
||||
);
|
||||
return mcpClient;
|
||||
} catch (sseFallbackError) {
|
||||
sseError = sseFallbackError as Error;
|
||||
|
||||
// If SSE also returned 401, handle OAuth below
|
||||
if (isAuthenticationError(sseError)) {
|
||||
debugLogger.log(
|
||||
`MCP server '${mcpServerName}': SSE returned 401, OAuth authentication required.`,
|
||||
);
|
||||
// Update error to be the SSE error for OAuth handling
|
||||
error = sseError;
|
||||
// Continue to OAuth handling below
|
||||
} else {
|
||||
debugLogger.log(
|
||||
`MCP server '${mcpServerName}': SSE fallback also failed.`,
|
||||
);
|
||||
// Both failed without 401, throw the original error
|
||||
throw firstAttemptError;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a 401 error that might indicate OAuth is required
|
||||
const errorString = String(error);
|
||||
if (errorString.includes('401') && hasNetworkTransport(mcpServerConfig)) {
|
||||
if (isAuthenticationError(error) && hasNetworkTransport(mcpServerConfig)) {
|
||||
mcpServerRequiresOAuth.set(mcpServerName, true);
|
||||
// Only trigger automatic OAuth discovery for HTTP servers or when OAuth is explicitly configured
|
||||
// For SSE servers, we should not trigger new OAuth flows automatically
|
||||
const shouldTriggerOAuth =
|
||||
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
|
||||
|
||||
// Only trigger automatic OAuth if explicitly enabled in config
|
||||
// Otherwise, show error and tell user to run /mcp auth command
|
||||
const shouldTriggerOAuth = mcpServerConfig.oauth?.enabled;
|
||||
|
||||
if (!shouldTriggerOAuth) {
|
||||
// For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately.
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
const hasStoredTokens = await authProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
},
|
||||
);
|
||||
if (hasStoredTokens) {
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` +
|
||||
`Please re-authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
} else {
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
throw new Error(
|
||||
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
await showAuthRequiredMessage(mcpServerName);
|
||||
}
|
||||
|
||||
// Try to extract www-authenticate header from the error
|
||||
const errorString = String(error);
|
||||
let wwwAuthenticate = extractWWWAuthenticateHeader(errorString);
|
||||
|
||||
// If we didn't get the header from the error string, try to get it from the server
|
||||
@@ -1061,12 +1219,27 @@ export async function connectToMcpServer(
|
||||
);
|
||||
try {
|
||||
const urlToFetch = mcpServerConfig.httpUrl || mcpServerConfig.url!;
|
||||
|
||||
// Determine correct Accept header based on what transport failed
|
||||
let acceptHeader: string;
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
acceptHeader = 'application/json';
|
||||
} else if (mcpServerConfig.type === 'http') {
|
||||
acceptHeader = 'application/json';
|
||||
} else if (mcpServerConfig.type === 'sse') {
|
||||
acceptHeader = 'text/event-stream';
|
||||
} else if (httpReturned404) {
|
||||
// HTTP failed with 404, SSE returned 401 - use SSE header
|
||||
acceptHeader = 'text/event-stream';
|
||||
} else {
|
||||
// HTTP returned 401 - use HTTP header
|
||||
acceptHeader = 'application/json';
|
||||
}
|
||||
|
||||
const response = await fetch(urlToFetch, {
|
||||
method: 'HEAD',
|
||||
headers: {
|
||||
Accept: mcpServerConfig.httpUrl
|
||||
? 'application/json'
|
||||
: 'text/event-stream',
|
||||
Accept: acceptHeader,
|
||||
},
|
||||
signal: AbortSignal.timeout(5000),
|
||||
});
|
||||
@@ -1101,52 +1274,21 @@ export async function connectToMcpServer(
|
||||
);
|
||||
if (oauthSuccess) {
|
||||
// Retry connection with OAuth token
|
||||
debugLogger.log(
|
||||
`Retrying connection to '${mcpServerName}' with OAuth token...`,
|
||||
);
|
||||
|
||||
// Get the valid token - we need to create a proper OAuth config
|
||||
// The token should already be available from the authentication process
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
const accessToken = await authProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
},
|
||||
);
|
||||
|
||||
if (accessToken) {
|
||||
// Create transport with OAuth token
|
||||
const oauthTransport = await createTransportWithOAuth(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
accessToken,
|
||||
);
|
||||
if (oauthTransport) {
|
||||
await mcpClient.connect(oauthTransport, {
|
||||
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
// Connection successful with OAuth
|
||||
return mcpClient;
|
||||
} else {
|
||||
throw new Error(
|
||||
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
throw new Error(
|
||||
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const accessToken = await getStoredOAuthToken(mcpServerName);
|
||||
if (!accessToken) {
|
||||
throw new Error(
|
||||
`Failed to get credentials for server '${mcpServerName}' after successful OAuth authentication`,
|
||||
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
|
||||
await retryWithOAuth(
|
||||
mcpClient,
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
accessToken,
|
||||
httpReturned404,
|
||||
);
|
||||
return mcpClient;
|
||||
} else {
|
||||
throw new Error(
|
||||
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
|
||||
@@ -1154,41 +1296,11 @@ export async function connectToMcpServer(
|
||||
}
|
||||
} else {
|
||||
// No www-authenticate header found, but we got a 401
|
||||
// Only try OAuth discovery for HTTP servers or when OAuth is explicitly configured
|
||||
// For SSE servers, we should not trigger new OAuth flows automatically
|
||||
const shouldTryDiscovery =
|
||||
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
|
||||
// Only try OAuth discovery when OAuth is explicitly enabled in config
|
||||
const shouldTryDiscovery = mcpServerConfig.oauth?.enabled;
|
||||
|
||||
if (!shouldTryDiscovery) {
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
const hasStoredTokens = await authProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
},
|
||||
);
|
||||
if (hasStoredTokens) {
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` +
|
||||
`Please re-authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
} else {
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
throw new Error(
|
||||
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
await showAuthRequiredMessage(mcpServerName);
|
||||
}
|
||||
|
||||
// For SSE/HTTP servers, try to discover OAuth configuration from the base URL
|
||||
@@ -1234,47 +1346,30 @@ export async function connectToMcpServer(
|
||||
);
|
||||
|
||||
// Retry connection with OAuth token
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials =
|
||||
await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
const accessToken = await authProvider.getValidToken(
|
||||
mcpServerName,
|
||||
{
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
},
|
||||
);
|
||||
if (accessToken) {
|
||||
// Create transport with OAuth token
|
||||
const oauthTransport = await createTransportWithOAuth(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
accessToken,
|
||||
);
|
||||
if (oauthTransport) {
|
||||
await mcpClient.connect(oauthTransport, {
|
||||
timeout:
|
||||
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
// Connection successful with OAuth
|
||||
return mcpClient;
|
||||
} else {
|
||||
throw new Error(
|
||||
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
throw new Error(
|
||||
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const accessToken = await getStoredOAuthToken(mcpServerName);
|
||||
if (!accessToken) {
|
||||
throw new Error(
|
||||
`Failed to get stored credentials for server '${mcpServerName}'`,
|
||||
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
|
||||
// Create transport with OAuth token
|
||||
const oauthTransport = await createTransportWithOAuth(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
accessToken,
|
||||
);
|
||||
if (!oauthTransport) {
|
||||
throw new Error(
|
||||
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
|
||||
await mcpClient.connect(oauthTransport, {
|
||||
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
// Connection successful with OAuth
|
||||
return mcpClient;
|
||||
} else {
|
||||
throw new Error(
|
||||
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||
@@ -1288,27 +1383,63 @@ export async function connectToMcpServer(
|
||||
}
|
||||
} else {
|
||||
// Handle other connection errors
|
||||
// Create a concise error message
|
||||
const errorMessage = (error as Error).message || String(error);
|
||||
const isNetworkError =
|
||||
errorMessage.includes('ENOTFOUND') ||
|
||||
errorMessage.includes('ECONNREFUSED');
|
||||
|
||||
let conciseError: string;
|
||||
if (isNetworkError) {
|
||||
conciseError = `Cannot connect to '${mcpServerName}' - server may be down or URL incorrect`;
|
||||
} else {
|
||||
conciseError = `Connection failed for '${mcpServerName}': ${errorMessage}`;
|
||||
}
|
||||
|
||||
if (process.env['SANDBOX']) {
|
||||
conciseError += ` (check sandbox availability)`;
|
||||
}
|
||||
|
||||
throw new Error(conciseError);
|
||||
// Re-throw the original error to preserve its structure
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to create the appropriate transport based on config
|
||||
* This handles the logic for httpUrl/url/type consistently
|
||||
*/
|
||||
function createUrlTransport(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
transportOptions:
|
||||
| StreamableHTTPClientTransportOptions
|
||||
| SSEClientTransportOptions,
|
||||
): StreamableHTTPClientTransport | SSEClientTransport {
|
||||
// Priority 1: httpUrl (deprecated)
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
if (mcpServerConfig.url) {
|
||||
debugLogger.warn(
|
||||
`MCP server '${mcpServerName}': Both 'httpUrl' and 'url' are configured. ` +
|
||||
`Using deprecated 'httpUrl'. Please migrate to 'url' with 'type: "http"'.`,
|
||||
);
|
||||
}
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
|
||||
// Priority 2 & 3: url with explicit type
|
||||
if (mcpServerConfig.url && mcpServerConfig.type) {
|
||||
if (mcpServerConfig.type === 'http') {
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
} else if (mcpServerConfig.type === 'sse') {
|
||||
return new SSEClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 4: url without type (default to HTTP)
|
||||
if (mcpServerConfig.url) {
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
|
||||
throw new Error(`No URL configured for MCP server '${mcpServerName}'`);
|
||||
}
|
||||
|
||||
/** Visible for Testing */
|
||||
export async function createTransport(
|
||||
mcpServerName: string,
|
||||
@@ -1333,7 +1464,6 @@ export async function createTransport(
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (mcpServerConfig.httpUrl || mcpServerConfig.url) {
|
||||
const authProvider = createAuthProvider(mcpServerConfig);
|
||||
const headers: Record<string, string> =
|
||||
@@ -1342,8 +1472,7 @@ export async function createTransport(
|
||||
if (authProvider === undefined) {
|
||||
// Check if we have OAuth configuration or stored tokens
|
||||
let accessToken: string | null = null;
|
||||
let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
|
||||
if (hasOAuthConfig && mcpServerConfig.oauth) {
|
||||
if (mcpServerConfig.oauth?.enabled && mcpServerConfig.oauth) {
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const mcpAuthProvider = new MCPOAuthProvider(tokenStorage);
|
||||
accessToken = await mcpAuthProvider.getValidToken(
|
||||
@@ -1352,31 +1481,22 @@ export async function createTransport(
|
||||
);
|
||||
|
||||
if (!accessToken) {
|
||||
throw new Error(
|
||||
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
|
||||
`Please authenticate using the /mcp auth command.`,
|
||||
// Emit info message (not error) since this is expected behavior
|
||||
coreEvents.emitFeedback(
|
||||
'info',
|
||||
`MCP server '${mcpServerName}' requires authentication using: /mcp auth ${mcpServerName}`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const mcpAuthProvider = new MCPOAuthProvider(tokenStorage);
|
||||
accessToken = await mcpAuthProvider.getValidToken(mcpServerName, {
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
});
|
||||
|
||||
if (accessToken) {
|
||||
hasOAuthConfig = true;
|
||||
debugLogger.log(
|
||||
`Found stored OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
accessToken = await getStoredOAuthToken(mcpServerName);
|
||||
if (accessToken) {
|
||||
debugLogger.log(
|
||||
`Found stored OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
}
|
||||
if (hasOAuthConfig && accessToken) {
|
||||
if (accessToken) {
|
||||
headers['Authorization'] = `Bearer ${accessToken}`;
|
||||
}
|
||||
}
|
||||
@@ -1388,7 +1508,7 @@ export async function createTransport(
|
||||
authProvider,
|
||||
};
|
||||
|
||||
return createUrlTransport(mcpServerConfig, transportOptions);
|
||||
return createUrlTransport(mcpServerName, mcpServerConfig, transportOptions);
|
||||
}
|
||||
|
||||
if (mcpServerConfig.command) {
|
||||
|
||||
42
packages/core/src/utils/errors.test.ts
Normal file
42
packages/core/src/utils/errors.test.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { isAuthenticationError, UnauthorizedError } from './errors.js';
|
||||
|
||||
describe('isAuthenticationError', () => {
|
||||
it('should detect error with code: 401 property (MCP SDK style)', () => {
|
||||
const error = { code: 401, message: 'Unauthorized' };
|
||||
expect(isAuthenticationError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect UnauthorizedError instance', () => {
|
||||
const error = new UnauthorizedError('Authentication required');
|
||||
expect(isAuthenticationError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for 404 errors', () => {
|
||||
const error = { code: 404, message: 'Not Found' };
|
||||
expect(isAuthenticationError(error)).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle null and undefined gracefully', () => {
|
||||
expect(isAuthenticationError(null)).toBe(false);
|
||||
expect(isAuthenticationError(undefined)).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle non-error objects', () => {
|
||||
expect(isAuthenticationError('string error')).toBe(false);
|
||||
expect(isAuthenticationError(123)).toBe(false);
|
||||
expect(isAuthenticationError({})).toBe(false);
|
||||
});
|
||||
|
||||
it('should detect 401 in various message formats', () => {
|
||||
expect(isAuthenticationError(new Error('401 Unauthorized'))).toBe(true);
|
||||
expect(isAuthenticationError(new Error('HTTP 401'))).toBe(true);
|
||||
expect(isAuthenticationError(new Error('Status code: 401'))).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -117,3 +117,42 @@ function parseResponseData(error: GaxiosError): ResponseData {
|
||||
}
|
||||
return error.response?.data as ResponseData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if an error is a 401 authentication error.
|
||||
* Uses structured error properties from MCP SDK errors.
|
||||
*
|
||||
* @param error The error to check
|
||||
* @returns true if this is a 401/authentication error
|
||||
*/
|
||||
export function isAuthenticationError(error: unknown): boolean {
|
||||
// Check for MCP SDK errors with code property
|
||||
// (SseError and StreamableHTTPError both have numeric 'code' property)
|
||||
if (error && typeof error === 'object' && 'code' in error) {
|
||||
const errorCode = (error as { code: unknown }).code;
|
||||
if (errorCode === 401) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Check for UnauthorizedError class (from MCP SDK or our own)
|
||||
if (
|
||||
error instanceof Error &&
|
||||
error.constructor.name === 'UnauthorizedError'
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (error instanceof UnauthorizedError) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Fallback: Check for MCP SDK's plain Error messages with HTTP 401
|
||||
// The SDK sometimes throws: new Error(`Error POSTing to endpoint (HTTP 401): ...`)
|
||||
const message = getErrorMessage(error);
|
||||
if (message.includes('401')) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user