mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-13 23:51:16 -07:00
2128 lines
68 KiB
TypeScript
2128 lines
68 KiB
TypeScript
/**
|
|
* @license
|
|
* Copyright 2025 Google LLC
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
|
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
|
import {
|
|
StreamableHTTPClientTransport,
|
|
StreamableHTTPError,
|
|
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
|
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
|
import { AuthProviderType, type Config } from '../config/config.js';
|
|
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
|
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
|
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
|
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
|
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
|
import { ToolListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
|
|
|
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
|
import {
|
|
connectToMcpServer,
|
|
createTransport,
|
|
hasNetworkTransport,
|
|
isEnabled,
|
|
McpClient,
|
|
populateMcpServerCommand,
|
|
} from './mcp-client.js';
|
|
import type { ToolRegistry } from './tool-registry.js';
|
|
import type { ResourceRegistry } from '../resources/resource-registry.js';
|
|
import * as fs from 'node:fs';
|
|
import * as os from 'node:os';
|
|
import * as path from 'node:path';
|
|
import { coreEvents } from '../utils/events.js';
|
|
import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js';
|
|
|
|
const EMPTY_CONFIG: EnvironmentSanitizationConfig = {
|
|
enableEnvironmentVariableRedaction: true,
|
|
allowedEnvironmentVariables: [],
|
|
blockedEnvironmentVariables: [],
|
|
};
|
|
|
|
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
|
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
|
vi.mock('@google/genai');
|
|
vi.mock('../mcp/oauth-provider.js');
|
|
vi.mock('../mcp/oauth-token-storage.js');
|
|
vi.mock('../mcp/oauth-utils.js');
|
|
vi.mock('google-auth-library');
|
|
import { GoogleAuth } from 'google-auth-library';
|
|
|
|
vi.mock('../utils/events.js', () => ({
|
|
coreEvents: {
|
|
emitFeedback: vi.fn(),
|
|
emitConsoleLog: vi.fn(),
|
|
},
|
|
}));
|
|
|
|
describe('mcp-client', () => {
|
|
let workspaceContext: WorkspaceContext;
|
|
let testWorkspace: string;
|
|
|
|
beforeEach(() => {
|
|
// create a tmp dir for this test
|
|
// Create a unique temporary directory for the workspace to avoid conflicts
|
|
testWorkspace = fs.mkdtempSync(
|
|
path.join(os.tmpdir(), 'gemini-agent-test-'),
|
|
);
|
|
workspaceContext = new WorkspaceContext(testWorkspace);
|
|
});
|
|
|
|
afterEach(() => {
|
|
vi.restoreAllMocks();
|
|
});
|
|
|
|
describe('McpClient', () => {
|
|
it('should discover tools', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
discover: vi.fn(),
|
|
disconnect: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
setNotificationHandler: vi.fn(),
|
|
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
|
listTools: vi.fn().mockResolvedValue({
|
|
tools: [
|
|
{
|
|
name: 'testFunction',
|
|
inputSchema: {
|
|
type: 'object',
|
|
properties: {},
|
|
},
|
|
},
|
|
],
|
|
}),
|
|
listPrompts: vi.fn().mockResolvedValue({
|
|
prompts: [],
|
|
}),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
};
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
const mockedToolRegistry = {
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
const promptRegistry = {
|
|
registerPrompt: vi.fn(),
|
|
removePromptsByServer: vi.fn(),
|
|
} as unknown as PromptRegistry;
|
|
const resourceRegistry = {
|
|
setResourcesForServer: vi.fn(),
|
|
removeResourcesByServer: vi.fn(),
|
|
} as unknown as ResourceRegistry;
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
mockedToolRegistry,
|
|
promptRegistry,
|
|
resourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
await client.connect();
|
|
await client.discover({} as Config);
|
|
expect(mockedClient.listTools).toHaveBeenCalledWith(
|
|
{},
|
|
{ timeout: 600000 },
|
|
);
|
|
});
|
|
|
|
it('should not skip tools even if a parameter is missing a type', async () => {
|
|
const consoleWarnSpy = vi
|
|
.spyOn(console, 'warn')
|
|
.mockImplementation(() => {});
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
discover: vi.fn(),
|
|
disconnect: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
setNotificationHandler: vi.fn(),
|
|
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
|
|
|
listTools: vi.fn().mockResolvedValue({
|
|
tools: [
|
|
{
|
|
name: 'validTool',
|
|
inputSchema: {
|
|
type: 'object',
|
|
properties: {
|
|
param1: { type: 'string' },
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: 'invalidTool',
|
|
inputSchema: {
|
|
type: 'object',
|
|
properties: {
|
|
param1: { description: 'a param with no type' },
|
|
},
|
|
},
|
|
},
|
|
],
|
|
}),
|
|
listPrompts: vi.fn().mockResolvedValue({
|
|
prompts: [],
|
|
}),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
};
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
const mockedToolRegistry = {
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
const promptRegistry = {
|
|
registerPrompt: vi.fn(),
|
|
removePromptsByServer: vi.fn(),
|
|
} as unknown as PromptRegistry;
|
|
const resourceRegistry = {
|
|
setResourcesForServer: vi.fn(),
|
|
removeResourcesByServer: vi.fn(),
|
|
} as unknown as ResourceRegistry;
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
mockedToolRegistry,
|
|
promptRegistry,
|
|
resourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
await client.connect();
|
|
await client.discover({} as Config);
|
|
expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2);
|
|
expect(consoleWarnSpy).not.toHaveBeenCalled();
|
|
consoleWarnSpy.mockRestore();
|
|
});
|
|
|
|
it('should propagate errors when discovering prompts', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
discover: vi.fn(),
|
|
disconnect: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
setNotificationHandler: vi.fn(),
|
|
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
|
|
listTools: vi.fn().mockResolvedValue({ tools: [] }),
|
|
listPrompts: vi.fn().mockRejectedValue(new Error('Test error')),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
};
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
const mockedToolRegistry = {
|
|
registerTool: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
const promptRegistry = {
|
|
registerPrompt: vi.fn(),
|
|
removePromptsByServer: vi.fn(),
|
|
} as unknown as PromptRegistry;
|
|
const resourceRegistry = {
|
|
setResourcesForServer: vi.fn(),
|
|
removeResourcesByServer: vi.fn(),
|
|
} as unknown as ResourceRegistry;
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
mockedToolRegistry,
|
|
promptRegistry,
|
|
resourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
await client.connect();
|
|
await expect(client.discover({} as Config)).rejects.toThrow('Test error');
|
|
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
|
'error',
|
|
`Error discovering prompts from test-server: Test error`,
|
|
expect.any(Error),
|
|
);
|
|
});
|
|
|
|
it('should not discover tools if server does not support them', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
discover: vi.fn(),
|
|
disconnect: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
setNotificationHandler: vi.fn(),
|
|
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
|
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
};
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
const mockedToolRegistry = {
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
const promptRegistry = {
|
|
registerPrompt: vi.fn(),
|
|
removePromptsByServer: vi.fn(),
|
|
} as unknown as PromptRegistry;
|
|
const resourceRegistry = {
|
|
setResourcesForServer: vi.fn(),
|
|
removeResourcesByServer: vi.fn(),
|
|
} as unknown as ResourceRegistry;
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
mockedToolRegistry,
|
|
promptRegistry,
|
|
resourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
await client.connect();
|
|
await expect(client.discover({} as Config)).rejects.toThrow(
|
|
'No prompts, tools, or resources found on the server.',
|
|
);
|
|
});
|
|
|
|
it('should discover tools if server supports them', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
discover: vi.fn(),
|
|
disconnect: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
setNotificationHandler: vi.fn(),
|
|
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
|
listTools: vi.fn().mockResolvedValue({
|
|
tools: [
|
|
{
|
|
name: 'testTool',
|
|
description: 'A test tool',
|
|
inputSchema: { type: 'object', properties: {} },
|
|
},
|
|
],
|
|
}),
|
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
};
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
const mockedToolRegistry = {
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
const promptRegistry = {
|
|
registerPrompt: vi.fn(),
|
|
removePromptsByServer: vi.fn(),
|
|
} as unknown as PromptRegistry;
|
|
const resourceRegistry = {
|
|
setResourcesForServer: vi.fn(),
|
|
removeResourcesByServer: vi.fn(),
|
|
} as unknown as ResourceRegistry;
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
mockedToolRegistry,
|
|
promptRegistry,
|
|
resourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
await client.connect();
|
|
await client.discover({} as Config);
|
|
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
|
});
|
|
|
|
it('should discover tools with $defs and $ref in schema', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
discover: vi.fn(),
|
|
disconnect: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
setNotificationHandler: vi.fn(),
|
|
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
|
listTools: vi.fn().mockResolvedValue({
|
|
tools: [
|
|
{
|
|
name: 'toolWithDefs',
|
|
description: 'A tool using $defs',
|
|
inputSchema: {
|
|
type: 'object',
|
|
properties: {
|
|
param1: {
|
|
$ref: '#/$defs/MyType',
|
|
},
|
|
},
|
|
$defs: {
|
|
MyType: {
|
|
type: 'string',
|
|
description: 'A defined type',
|
|
},
|
|
},
|
|
},
|
|
},
|
|
],
|
|
}),
|
|
listPrompts: vi.fn().mockResolvedValue({
|
|
prompts: [],
|
|
}),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
};
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
const mockedToolRegistry = {
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
const promptRegistry = {
|
|
registerPrompt: vi.fn(),
|
|
removePromptsByServer: vi.fn(),
|
|
} as unknown as PromptRegistry;
|
|
const resourceRegistry = {
|
|
setResourcesForServer: vi.fn(),
|
|
removeResourcesByServer: vi.fn(),
|
|
} as unknown as ResourceRegistry;
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
mockedToolRegistry,
|
|
promptRegistry,
|
|
resourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
await client.connect();
|
|
await client.discover({} as Config);
|
|
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
|
const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock
|
|
.calls[0][0];
|
|
expect(registeredTool.schema.parametersJsonSchema).toEqual({
|
|
type: 'object',
|
|
properties: {
|
|
param1: {
|
|
$ref: '#/$defs/MyType',
|
|
},
|
|
},
|
|
$defs: {
|
|
MyType: {
|
|
type: 'string',
|
|
description: 'A defined type',
|
|
},
|
|
},
|
|
});
|
|
});
|
|
|
|
it('should discover resources when a server only exposes resources', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
discover: vi.fn(),
|
|
disconnect: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
setNotificationHandler: vi.fn(),
|
|
getServerCapabilities: vi.fn().mockReturnValue({ resources: {} }),
|
|
request: vi.fn().mockImplementation(({ method }) => {
|
|
if (method === 'resources/list') {
|
|
return Promise.resolve({
|
|
resources: [
|
|
{
|
|
uri: 'file:///tmp/resource.txt',
|
|
name: 'resource',
|
|
description: 'Test Resource',
|
|
mimeType: 'text/plain',
|
|
},
|
|
],
|
|
});
|
|
}
|
|
return Promise.resolve({ prompts: [] });
|
|
}),
|
|
} as unknown as ClientLib.Client;
|
|
vi.mocked(ClientLib.Client).mockReturnValue(mockedClient);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
const mockedToolRegistry = {
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
const promptRegistry = {
|
|
registerPrompt: vi.fn(),
|
|
removePromptsByServer: vi.fn(),
|
|
} as unknown as PromptRegistry;
|
|
const resourceRegistry = {
|
|
setResourcesForServer: vi.fn(),
|
|
removeResourcesByServer: vi.fn(),
|
|
} as unknown as ResourceRegistry;
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
mockedToolRegistry,
|
|
promptRegistry,
|
|
resourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
await client.connect();
|
|
await client.discover({} as Config);
|
|
expect(resourceRegistry.setResourcesForServer).toHaveBeenCalledWith(
|
|
'test-server',
|
|
[
|
|
expect.objectContaining({
|
|
uri: 'file:///tmp/resource.txt',
|
|
name: 'resource',
|
|
}),
|
|
],
|
|
);
|
|
});
|
|
|
|
it('refreshes registry when resource list change notification is received', async () => {
|
|
let listCallCount = 0;
|
|
let resourceListHandler:
|
|
| ((notification: unknown) => Promise<void> | void)
|
|
| undefined;
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
discover: vi.fn(),
|
|
disconnect: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
setNotificationHandler: vi.fn((_, handler) => {
|
|
resourceListHandler = handler;
|
|
}),
|
|
getServerCapabilities: vi
|
|
.fn()
|
|
.mockReturnValue({ resources: { listChanged: true } }),
|
|
request: vi.fn().mockImplementation(({ method }) => {
|
|
if (method === 'resources/list') {
|
|
listCallCount += 1;
|
|
if (listCallCount === 1) {
|
|
return Promise.resolve({
|
|
resources: [
|
|
{
|
|
uri: 'file:///tmp/one.txt',
|
|
},
|
|
],
|
|
});
|
|
}
|
|
return Promise.resolve({
|
|
resources: [
|
|
{
|
|
uri: 'file:///tmp/two.txt',
|
|
},
|
|
],
|
|
});
|
|
}
|
|
return Promise.resolve({ prompts: [] });
|
|
}),
|
|
} as unknown as ClientLib.Client;
|
|
vi.mocked(ClientLib.Client).mockReturnValue(mockedClient);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
const mockedToolRegistry = {
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
const promptRegistry = {
|
|
registerPrompt: vi.fn(),
|
|
removePromptsByServer: vi.fn(),
|
|
} as unknown as PromptRegistry;
|
|
const resourceRegistry = {
|
|
setResourcesForServer: vi.fn(),
|
|
removeResourcesByServer: vi.fn(),
|
|
} as unknown as ResourceRegistry;
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
mockedToolRegistry,
|
|
promptRegistry,
|
|
resourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
await client.connect();
|
|
await client.discover({} as Config);
|
|
|
|
expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce();
|
|
expect(resourceListHandler).toBeDefined();
|
|
|
|
await resourceListHandler?.({
|
|
method: 'notifications/resources/list_changed',
|
|
});
|
|
|
|
expect(resourceRegistry.setResourcesForServer).toHaveBeenLastCalledWith(
|
|
'test-server',
|
|
[expect.objectContaining({ uri: 'file:///tmp/two.txt' })],
|
|
);
|
|
|
|
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
|
'info',
|
|
'Resources updated for server: test-server',
|
|
);
|
|
});
|
|
|
|
it('refreshes prompts when prompt list change notification is received', async () => {
|
|
let listCallCount = 0;
|
|
let promptListHandler:
|
|
| ((notification: unknown) => Promise<void> | void)
|
|
| undefined;
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
discover: vi.fn(),
|
|
disconnect: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
setNotificationHandler: vi.fn((_, handler) => {
|
|
promptListHandler = handler;
|
|
}),
|
|
getServerCapabilities: vi
|
|
.fn()
|
|
.mockReturnValue({ prompts: { listChanged: true } }),
|
|
listPrompts: vi.fn().mockImplementation(() => {
|
|
listCallCount += 1;
|
|
if (listCallCount === 1) {
|
|
return Promise.resolve({
|
|
prompts: [{ name: 'one', description: 'first' }],
|
|
});
|
|
}
|
|
return Promise.resolve({
|
|
prompts: [{ name: 'two', description: 'second' }],
|
|
});
|
|
}),
|
|
request: vi.fn().mockResolvedValue({ prompts: [] }),
|
|
} as unknown as ClientLib.Client;
|
|
vi.mocked(ClientLib.Client).mockReturnValue(mockedClient);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
const mockedToolRegistry = {
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
const promptRegistry = {
|
|
registerPrompt: vi.fn(),
|
|
removePromptsByServer: vi.fn(),
|
|
} as unknown as PromptRegistry;
|
|
const resourceRegistry = {
|
|
setResourcesForServer: vi.fn(),
|
|
removeResourcesByServer: vi.fn(),
|
|
} as unknown as ResourceRegistry;
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
mockedToolRegistry,
|
|
promptRegistry,
|
|
resourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
await client.connect();
|
|
await client.discover({ sanitizationConfig: EMPTY_CONFIG } as Config);
|
|
|
|
expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce();
|
|
expect(promptListHandler).toBeDefined();
|
|
|
|
await promptListHandler?.({
|
|
method: 'notifications/prompts/list_changed',
|
|
});
|
|
|
|
expect(promptRegistry.removePromptsByServer).toHaveBeenCalledWith(
|
|
'test-server',
|
|
);
|
|
expect(promptRegistry.registerPrompt).toHaveBeenLastCalledWith(
|
|
expect.objectContaining({ name: 'two' }),
|
|
);
|
|
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
|
'info',
|
|
'Prompts updated for server: test-server',
|
|
);
|
|
});
|
|
|
|
it('should remove tools and prompts on disconnect', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
close: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
setNotificationHandler: vi.fn(),
|
|
getServerCapabilities: vi
|
|
.fn()
|
|
.mockReturnValue({ tools: {}, prompts: {} }),
|
|
listPrompts: vi.fn().mockResolvedValue({
|
|
prompts: [{ id: 'prompt1', text: 'a prompt' }],
|
|
}),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
listTools: vi.fn().mockResolvedValue({
|
|
tools: [
|
|
{
|
|
name: 'testTool',
|
|
description: 'A test tool',
|
|
inputSchema: { type: 'object', properties: {} },
|
|
},
|
|
],
|
|
}),
|
|
};
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
const mockedToolRegistry = {
|
|
registerTool: vi.fn(),
|
|
unregisterTool: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
removeMcpToolsByServer: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
} as unknown as ToolRegistry;
|
|
const mockedPromptRegistry = {
|
|
registerPrompt: vi.fn(),
|
|
unregisterPrompt: vi.fn(),
|
|
removePromptsByServer: vi.fn(),
|
|
} as unknown as PromptRegistry;
|
|
const resourceRegistry = {
|
|
setResourcesForServer: vi.fn(),
|
|
removeResourcesByServer: vi.fn(),
|
|
} as unknown as ResourceRegistry;
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
mockedToolRegistry,
|
|
mockedPromptRegistry,
|
|
resourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
await client.connect();
|
|
await client.discover({} as Config);
|
|
|
|
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
|
expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce();
|
|
|
|
await client.disconnect();
|
|
|
|
expect(mockedClient.close).toHaveBeenCalledOnce();
|
|
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledOnce();
|
|
expect(mockedPromptRegistry.removePromptsByServer).toHaveBeenCalledOnce();
|
|
expect(resourceRegistry.removeResourcesByServer).toHaveBeenCalledOnce();
|
|
});
|
|
});
|
|
|
|
describe('Dynamic Tool Updates', () => {
|
|
it('should set up notification handler if server supports tool list changes', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
getStatus: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
// Capability enables the listener
|
|
getServerCapabilities: vi
|
|
.fn()
|
|
.mockReturnValue({ tools: { listChanged: true } }),
|
|
setNotificationHandler: vi.fn(),
|
|
listTools: vi.fn().mockResolvedValue({ tools: [] }),
|
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
};
|
|
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{ command: 'test-command' },
|
|
{} as ToolRegistry,
|
|
{} as PromptRegistry,
|
|
{} as ResourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
|
|
await client.connect();
|
|
|
|
expect(mockedClient.setNotificationHandler).toHaveBeenCalledWith(
|
|
ToolListChangedNotificationSchema,
|
|
expect.any(Function),
|
|
);
|
|
});
|
|
|
|
it('should NOT set up notification handler if server lacks capability', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), // No listChanged
|
|
setNotificationHandler: vi.fn(),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
|
};
|
|
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{ command: 'test-command' },
|
|
{} as ToolRegistry,
|
|
{} as PromptRegistry,
|
|
{} as ResourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
|
|
await client.connect();
|
|
|
|
expect(mockedClient.setNotificationHandler).not.toHaveBeenCalled();
|
|
});
|
|
|
|
it('should refresh tools and notify manager when notification is received', async () => {
|
|
// Setup mocks
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
getServerCapabilities: vi
|
|
.fn()
|
|
.mockReturnValue({ tools: { listChanged: true } }),
|
|
setNotificationHandler: vi.fn(),
|
|
listTools: vi.fn().mockResolvedValue({
|
|
tools: [
|
|
{
|
|
name: 'newTool',
|
|
inputSchema: { type: 'object', properties: {} },
|
|
},
|
|
],
|
|
}),
|
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
|
};
|
|
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
|
|
const mockedToolRegistry = {
|
|
removeMcpToolsByServer: vi.fn(),
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
|
|
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
|
|
|
// Initialize client with onToolsUpdated callback
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{ command: 'test-command' },
|
|
mockedToolRegistry,
|
|
{} as PromptRegistry,
|
|
{} as ResourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
onToolsUpdatedSpy,
|
|
);
|
|
|
|
// 1. Connect (sets up listener)
|
|
await client.connect();
|
|
|
|
// 2. Extract the callback passed to setNotificationHandler
|
|
const notificationCallback =
|
|
mockedClient.setNotificationHandler.mock.calls[0][1];
|
|
|
|
// 3. Trigger the notification manually
|
|
await notificationCallback();
|
|
|
|
// 4. Assertions
|
|
// It should clear old tools
|
|
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith(
|
|
'test-server',
|
|
);
|
|
|
|
// It should fetch new tools (listTools called inside discoverTools)
|
|
expect(mockedClient.listTools).toHaveBeenCalled();
|
|
|
|
// It should register the new tool
|
|
expect(mockedToolRegistry.registerTool).toHaveBeenCalled();
|
|
|
|
// It should notify the manager
|
|
expect(onToolsUpdatedSpy).toHaveBeenCalled();
|
|
|
|
// It should emit feedback event
|
|
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
|
'info',
|
|
'Tools updated for server: test-server',
|
|
);
|
|
});
|
|
|
|
it('should handle errors during tool refresh gracefully', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
getServerCapabilities: vi
|
|
.fn()
|
|
.mockReturnValue({ tools: { listChanged: true } }),
|
|
setNotificationHandler: vi.fn(),
|
|
// Simulate error during discovery
|
|
listTools: vi.fn().mockRejectedValue(new Error('Network blip')),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
|
};
|
|
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
|
|
const mockedToolRegistry = {
|
|
removeMcpToolsByServer: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{ command: 'test-command' },
|
|
mockedToolRegistry,
|
|
{} as PromptRegistry,
|
|
{} as ResourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
|
|
await client.connect();
|
|
|
|
const notificationCallback =
|
|
mockedClient.setNotificationHandler.mock.calls[0][1];
|
|
|
|
// Trigger notification - should fail internally but catch the error
|
|
await notificationCallback();
|
|
|
|
// Should try to remove tools
|
|
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalled();
|
|
|
|
// Should NOT emit success feedback
|
|
expect(coreEvents.emitFeedback).not.toHaveBeenCalledWith(
|
|
'info',
|
|
expect.stringContaining('Tools updated'),
|
|
);
|
|
});
|
|
|
|
it('should handle concurrent updates from multiple servers', async () => {
|
|
const createMockSdkClient = (toolName: string) => ({
|
|
connect: vi.fn(),
|
|
getServerCapabilities: vi
|
|
.fn()
|
|
.mockReturnValue({ tools: { listChanged: true } }),
|
|
setNotificationHandler: vi.fn(),
|
|
listTools: vi.fn().mockResolvedValue({
|
|
tools: [
|
|
{
|
|
name: toolName,
|
|
inputSchema: { type: 'object', properties: {} },
|
|
},
|
|
],
|
|
}),
|
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
|
});
|
|
|
|
const mockClientA = createMockSdkClient('tool-from-A');
|
|
const mockClientB = createMockSdkClient('tool-from-B');
|
|
|
|
vi.mocked(ClientLib.Client)
|
|
.mockReturnValueOnce(mockClientA as unknown as ClientLib.Client)
|
|
.mockReturnValueOnce(mockClientB as unknown as ClientLib.Client);
|
|
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
|
|
const mockedToolRegistry = {
|
|
removeMcpToolsByServer: vi.fn(),
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
|
|
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
|
|
|
const clientA = new McpClient(
|
|
'server-A',
|
|
{ command: 'cmd-a' },
|
|
mockedToolRegistry,
|
|
{} as PromptRegistry,
|
|
{} as ResourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
onToolsUpdatedSpy,
|
|
);
|
|
|
|
const clientB = new McpClient(
|
|
'server-B',
|
|
{ command: 'cmd-b' },
|
|
mockedToolRegistry,
|
|
{} as PromptRegistry,
|
|
{} as ResourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
onToolsUpdatedSpy,
|
|
);
|
|
|
|
await clientA.connect();
|
|
await clientB.connect();
|
|
|
|
const handlerA = mockClientA.setNotificationHandler.mock.calls[0][1];
|
|
const handlerB = mockClientB.setNotificationHandler.mock.calls[0][1];
|
|
|
|
// Trigger burst updates simultaneously
|
|
await Promise.all([handlerA(), handlerB()]);
|
|
|
|
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith(
|
|
'server-A',
|
|
);
|
|
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith(
|
|
'server-B',
|
|
);
|
|
|
|
// Verify fetching happened on both clients
|
|
expect(mockClientA.listTools).toHaveBeenCalled();
|
|
expect(mockClientB.listTools).toHaveBeenCalled();
|
|
|
|
// Verify tools from both servers were registered (2 total calls)
|
|
expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2);
|
|
|
|
// Verify the update callback was triggered for both
|
|
expect(onToolsUpdatedSpy).toHaveBeenCalledTimes(2);
|
|
});
|
|
|
|
it('should abort discovery and log error if timeout is exceeded during refresh', async () => {
|
|
vi.useFakeTimers();
|
|
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
getServerCapabilities: vi
|
|
.fn()
|
|
.mockReturnValue({ tools: { listChanged: true } }),
|
|
setNotificationHandler: vi.fn(),
|
|
// Mock listTools to simulate a long running process that respects the abort signal
|
|
listTools: vi.fn().mockImplementation(
|
|
async (params, options) =>
|
|
new Promise((resolve, reject) => {
|
|
if (options?.signal?.aborted) {
|
|
return reject(new Error('Operation aborted'));
|
|
}
|
|
options?.signal?.addEventListener(
|
|
'abort',
|
|
() => {
|
|
reject(new Error('Operation aborted'));
|
|
},
|
|
{ once: true },
|
|
);
|
|
}),
|
|
),
|
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
|
};
|
|
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
|
|
const mockedToolRegistry = {
|
|
removeMcpToolsByServer: vi.fn(),
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
|
|
const client = new McpClient(
|
|
'test-server',
|
|
// Set a short timeout
|
|
{ command: 'test-command', timeout: 100 },
|
|
mockedToolRegistry,
|
|
{} as PromptRegistry,
|
|
{} as ResourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
);
|
|
|
|
await client.connect();
|
|
|
|
const notificationCallback =
|
|
mockedClient.setNotificationHandler.mock.calls[0][1];
|
|
|
|
const refreshPromise = notificationCallback();
|
|
|
|
vi.advanceTimersByTime(150);
|
|
|
|
await refreshPromise;
|
|
|
|
expect(mockedClient.listTools).toHaveBeenCalledWith(
|
|
expect.anything(),
|
|
expect.objectContaining({
|
|
signal: expect.any(AbortSignal),
|
|
}),
|
|
);
|
|
|
|
expect(mockedToolRegistry.registerTool).not.toHaveBeenCalled();
|
|
|
|
vi.useRealTimers();
|
|
});
|
|
|
|
it('should pass abort signal to onToolsUpdated callback', async () => {
|
|
const mockedClient = {
|
|
connect: vi.fn(),
|
|
getServerCapabilities: vi
|
|
.fn()
|
|
.mockReturnValue({ tools: { listChanged: true } }),
|
|
setNotificationHandler: vi.fn(),
|
|
listTools: vi.fn().mockResolvedValue({ tools: [] }),
|
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
|
request: vi.fn().mockResolvedValue({}),
|
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
|
};
|
|
|
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
mockedClient as unknown as ClientLib.Client,
|
|
);
|
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
{} as SdkClientStdioLib.StdioClientTransport,
|
|
);
|
|
|
|
const mockedToolRegistry = {
|
|
removeMcpToolsByServer: vi.fn(),
|
|
registerTool: vi.fn(),
|
|
sortTools: vi.fn(),
|
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
|
} as unknown as ToolRegistry;
|
|
|
|
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
|
|
|
const client = new McpClient(
|
|
'test-server',
|
|
{ command: 'test-command' },
|
|
mockedToolRegistry,
|
|
{} as PromptRegistry,
|
|
{} as ResourceRegistry,
|
|
workspaceContext,
|
|
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
|
false,
|
|
'0.0.1',
|
|
onToolsUpdatedSpy,
|
|
);
|
|
|
|
await client.connect();
|
|
|
|
const notificationCallback =
|
|
mockedClient.setNotificationHandler.mock.calls[0][1];
|
|
|
|
await notificationCallback();
|
|
|
|
expect(onToolsUpdatedSpy).toHaveBeenCalledWith(expect.any(AbortSignal));
|
|
|
|
// Verify the signal passed was not aborted (happy path)
|
|
const signal = onToolsUpdatedSpy.mock.calls[0][0];
|
|
expect(signal.aborted).toBe(false);
|
|
});
|
|
});
|
|
|
|
describe('appendMcpServerCommand', () => {
|
|
it('should do nothing if no MCP servers or command are configured', () => {
|
|
const out = populateMcpServerCommand({}, undefined);
|
|
expect(out).toEqual({});
|
|
});
|
|
|
|
it('should discover tools via mcpServerCommand', () => {
|
|
const commandString = 'command --arg1 value1';
|
|
const out = populateMcpServerCommand({}, commandString);
|
|
expect(out).toEqual({
|
|
mcp: {
|
|
command: 'command',
|
|
args: ['--arg1', 'value1'],
|
|
},
|
|
});
|
|
});
|
|
|
|
it('should handle error if mcpServerCommand parsing fails', () => {
|
|
expect(() => populateMcpServerCommand({}, 'derp && herp')).toThrowError();
|
|
});
|
|
});
|
|
|
|
describe('createTransport', () => {
|
|
describe('should connect via httpUrl', () => {
|
|
it('without headers', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
httpUrl: 'http://test-server',
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
expect(transport).toMatchObject({
|
|
_url: new URL('http://test-server'),
|
|
_requestInit: { headers: {} },
|
|
});
|
|
});
|
|
|
|
it('with headers', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
httpUrl: 'http://test-server',
|
|
headers: { Authorization: 'derp' },
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
expect(transport).toMatchObject({
|
|
_url: new URL('http://test-server'),
|
|
_requestInit: {
|
|
headers: { Authorization: 'derp' },
|
|
},
|
|
});
|
|
});
|
|
});
|
|
|
|
describe('should connect via url', () => {
|
|
it('without headers', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
url: 'http://test-server',
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
expect(transport).toMatchObject({
|
|
_url: new URL('http://test-server'),
|
|
_requestInit: { headers: {} },
|
|
});
|
|
});
|
|
|
|
it('with headers', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
url: 'http://test-server',
|
|
headers: { Authorization: 'derp' },
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
expect(transport).toMatchObject({
|
|
_url: new URL('http://test-server'),
|
|
_requestInit: {
|
|
headers: { Authorization: 'derp' },
|
|
},
|
|
});
|
|
});
|
|
|
|
it('with type="http" creates StreamableHTTPClientTransport', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
url: 'http://test-server',
|
|
type: 'http',
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
expect(transport).toMatchObject({
|
|
_url: new URL('http://test-server'),
|
|
_requestInit: { headers: {} },
|
|
});
|
|
});
|
|
|
|
it('with type="sse" creates SSEClientTransport', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
url: 'http://test-server',
|
|
type: 'sse',
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(SSEClientTransport);
|
|
expect(transport).toMatchObject({
|
|
_url: new URL('http://test-server'),
|
|
_requestInit: { headers: {} },
|
|
});
|
|
});
|
|
|
|
it('without type defaults to StreamableHTTPClientTransport', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
url: 'http://test-server',
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
expect(transport).toMatchObject({
|
|
_url: new URL('http://test-server'),
|
|
_requestInit: { headers: {} },
|
|
});
|
|
});
|
|
|
|
it('with type="http" and headers applies headers correctly', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
url: 'http://test-server',
|
|
type: 'http',
|
|
headers: { Authorization: 'Bearer token' },
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
expect(transport).toMatchObject({
|
|
_url: new URL('http://test-server'),
|
|
_requestInit: {
|
|
headers: { Authorization: 'Bearer token' },
|
|
},
|
|
});
|
|
});
|
|
|
|
it('with type="sse" and headers applies headers correctly', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
url: 'http://test-server',
|
|
type: 'sse',
|
|
headers: { 'X-API-Key': 'key123' },
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(SSEClientTransport);
|
|
expect(transport).toMatchObject({
|
|
_url: new URL('http://test-server'),
|
|
_requestInit: {
|
|
headers: { 'X-API-Key': 'key123' },
|
|
},
|
|
});
|
|
});
|
|
|
|
it('httpUrl takes priority over url when both are present', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
httpUrl: 'http://test-server-http',
|
|
url: 'http://test-server-url',
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
// httpUrl should take priority and create HTTP transport
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
expect(transport).toMatchObject({
|
|
_url: new URL('http://test-server-http'),
|
|
_requestInit: { headers: {} },
|
|
});
|
|
});
|
|
});
|
|
|
|
it('should connect via command', async () => {
|
|
const mockedTransport = vi
|
|
.spyOn(SdkClientStdioLib, 'StdioClientTransport')
|
|
.mockReturnValue({} as SdkClientStdioLib.StdioClientTransport);
|
|
|
|
await createTransport(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
args: ['--foo', 'bar'],
|
|
env: { GEMINI_CLI_FOO: 'bar' },
|
|
cwd: 'test/cwd',
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(mockedTransport).toHaveBeenCalledWith({
|
|
command: 'test-command',
|
|
args: ['--foo', 'bar'],
|
|
cwd: 'test/cwd',
|
|
env: expect.objectContaining({ GEMINI_CLI_FOO: 'bar' }),
|
|
stderr: 'pipe',
|
|
});
|
|
});
|
|
|
|
it('should redact sensitive environment variables for command transport', async () => {
|
|
const mockedTransport = vi
|
|
.spyOn(SdkClientStdioLib, 'StdioClientTransport')
|
|
.mockReturnValue({} as SdkClientStdioLib.StdioClientTransport);
|
|
|
|
const originalEnv = process.env;
|
|
process.env = {
|
|
...originalEnv,
|
|
GEMINI_API_KEY: 'sensitive-key',
|
|
GEMINI_CLI_SAFE_VAR: 'safe-value',
|
|
};
|
|
// Ensure strict sanitization is not triggered for this test
|
|
delete process.env['GITHUB_SHA'];
|
|
delete process.env['SURFACE'];
|
|
|
|
try {
|
|
await createTransport(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
const callArgs = mockedTransport.mock.calls[0][0];
|
|
expect(callArgs.env).toBeDefined();
|
|
expect(callArgs.env!['GEMINI_CLI_SAFE_VAR']).toBe('safe-value');
|
|
expect(callArgs.env!['GEMINI_API_KEY']).toBeUndefined();
|
|
} finally {
|
|
process.env = originalEnv;
|
|
}
|
|
});
|
|
|
|
it('should include extension settings in environment', async () => {
|
|
const mockedTransport = vi
|
|
.spyOn(SdkClientStdioLib, 'StdioClientTransport')
|
|
.mockReturnValue({} as SdkClientStdioLib.StdioClientTransport);
|
|
|
|
await createTransport(
|
|
'test-server',
|
|
{
|
|
command: 'test-command',
|
|
extension: {
|
|
name: 'test-ext',
|
|
resolvedSettings: [
|
|
{
|
|
envVar: 'GEMINI_CLI_EXT_VAR',
|
|
value: 'ext-value',
|
|
sensitive: false,
|
|
name: 'ext-setting',
|
|
},
|
|
],
|
|
version: '',
|
|
isActive: false,
|
|
path: '',
|
|
contextFiles: [],
|
|
id: '',
|
|
},
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
const callArgs = mockedTransport.mock.calls[0][0];
|
|
expect(callArgs.env).toBeDefined();
|
|
expect(callArgs.env!['GEMINI_CLI_EXT_VAR']).toBe('ext-value');
|
|
});
|
|
|
|
describe('useGoogleCredentialProvider', () => {
|
|
beforeEach(() => {
|
|
// Mock GoogleAuth client
|
|
const mockClient = {
|
|
getAccessToken: vi.fn().mockResolvedValue({ token: 'test-token' }),
|
|
quotaProjectId: 'myproject',
|
|
};
|
|
|
|
GoogleAuth.prototype.getClient = vi.fn().mockResolvedValue(mockClient);
|
|
});
|
|
|
|
it('should use GoogleCredentialProvider when specified', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
httpUrl: 'http://test.googleapis.com',
|
|
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
|
oauth: {
|
|
scopes: ['scope1'],
|
|
},
|
|
headers: {
|
|
'X-Goog-User-Project': 'myproject',
|
|
},
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
const authProvider = (transport as any)._authProvider;
|
|
expect(authProvider).toBeInstanceOf(GoogleCredentialProvider);
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
const googUserProject = (transport as any)._requestInit?.headers?.[
|
|
'X-Goog-User-Project'
|
|
];
|
|
expect(googUserProject).toBe('myproject');
|
|
});
|
|
|
|
it('should use headers from GoogleCredentialProvider', async () => {
|
|
const mockGetRequestHeaders = vi.fn().mockResolvedValue({
|
|
'X-Goog-User-Project': 'provider-project',
|
|
});
|
|
vi.spyOn(
|
|
GoogleCredentialProvider.prototype,
|
|
'getRequestHeaders',
|
|
).mockImplementation(mockGetRequestHeaders);
|
|
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
httpUrl: 'http://test.googleapis.com',
|
|
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
|
oauth: {
|
|
scopes: ['scope1'],
|
|
},
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
expect(mockGetRequestHeaders).toHaveBeenCalled();
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
const headers = (transport as any)._requestInit?.headers;
|
|
expect(headers['X-Goog-User-Project']).toBe('provider-project');
|
|
});
|
|
|
|
it('should prioritize provider headers over config headers', async () => {
|
|
const mockGetRequestHeaders = vi.fn().mockResolvedValue({
|
|
'X-Goog-User-Project': 'provider-project',
|
|
});
|
|
vi.spyOn(
|
|
GoogleCredentialProvider.prototype,
|
|
'getRequestHeaders',
|
|
).mockImplementation(mockGetRequestHeaders);
|
|
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
httpUrl: 'http://test.googleapis.com',
|
|
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
|
oauth: {
|
|
scopes: ['scope1'],
|
|
},
|
|
headers: {
|
|
'X-Goog-User-Project': 'config-project',
|
|
},
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
const headers = (transport as any)._requestInit?.headers;
|
|
expect(headers['X-Goog-User-Project']).toBe('provider-project');
|
|
});
|
|
|
|
it('should use GoogleCredentialProvider with SSE transport', async () => {
|
|
const transport = await createTransport(
|
|
'test-server',
|
|
{
|
|
url: 'http://test.googleapis.com',
|
|
type: 'sse',
|
|
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
|
oauth: {
|
|
scopes: ['scope1'],
|
|
},
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(transport).toBeInstanceOf(SSEClientTransport);
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
const authProvider = (transport as any)._authProvider;
|
|
expect(authProvider).toBeInstanceOf(GoogleCredentialProvider);
|
|
});
|
|
|
|
it('should throw an error if no URL is provided with GoogleCredentialProvider', async () => {
|
|
await expect(
|
|
createTransport(
|
|
'test-server',
|
|
{
|
|
authProviderType: AuthProviderType.GOOGLE_CREDENTIALS,
|
|
oauth: {
|
|
scopes: ['scope1'],
|
|
},
|
|
},
|
|
false,
|
|
EMPTY_CONFIG,
|
|
),
|
|
).rejects.toThrow(
|
|
'URL must be provided in the config for Google Credentials provider',
|
|
);
|
|
});
|
|
});
|
|
});
|
|
describe('isEnabled', () => {
|
|
const funcDecl = { name: 'myTool' };
|
|
const serverName = 'myServer';
|
|
|
|
it('should return true if no include or exclude lists are provided', () => {
|
|
const mcpServerConfig = {};
|
|
expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true);
|
|
});
|
|
|
|
it('should return false if the tool is in the exclude list', () => {
|
|
const mcpServerConfig = { excludeTools: ['myTool'] };
|
|
expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false);
|
|
});
|
|
|
|
it('should return true if the tool is in the include list', () => {
|
|
const mcpServerConfig = { includeTools: ['myTool'] };
|
|
expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true);
|
|
});
|
|
|
|
it('should return true if the tool is in the include list with parentheses', () => {
|
|
const mcpServerConfig = { includeTools: ['myTool()'] };
|
|
expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true);
|
|
});
|
|
|
|
it('should return false if the include list exists but does not contain the tool', () => {
|
|
const mcpServerConfig = { includeTools: ['anotherTool'] };
|
|
expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false);
|
|
});
|
|
|
|
it('should return false if the tool is in both the include and exclude lists', () => {
|
|
const mcpServerConfig = {
|
|
includeTools: ['myTool'],
|
|
excludeTools: ['myTool'],
|
|
};
|
|
expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false);
|
|
});
|
|
|
|
it('should return false if the function declaration has no name', () => {
|
|
const namelessFuncDecl = {};
|
|
const mcpServerConfig = {};
|
|
expect(isEnabled(namelessFuncDecl, serverName, mcpServerConfig)).toBe(
|
|
false,
|
|
);
|
|
});
|
|
});
|
|
|
|
describe('hasNetworkTransport', () => {
|
|
it('should return true if only url is provided', () => {
|
|
const config = { url: 'http://example.com' };
|
|
expect(hasNetworkTransport(config)).toBe(true);
|
|
});
|
|
|
|
it('should return true if only httpUrl is provided', () => {
|
|
const config = { httpUrl: 'http://example.com' };
|
|
expect(hasNetworkTransport(config)).toBe(true);
|
|
});
|
|
|
|
it('should return true if both url and httpUrl are provided', () => {
|
|
const config = {
|
|
url: 'http://example.com/sse',
|
|
httpUrl: 'http://example.com/http',
|
|
};
|
|
expect(hasNetworkTransport(config)).toBe(true);
|
|
});
|
|
|
|
it('should return false if neither url nor httpUrl is provided', () => {
|
|
const config = { command: 'do-something' };
|
|
expect(hasNetworkTransport(config)).toBe(false);
|
|
});
|
|
|
|
it('should return false for an empty config object', () => {
|
|
const config = {};
|
|
expect(hasNetworkTransport(config)).toBe(false);
|
|
});
|
|
});
|
|
});
|
|
|
|
describe('connectToMcpServer with OAuth', () => {
|
|
let mockedClient: ClientLib.Client;
|
|
let workspaceContext: WorkspaceContext;
|
|
let testWorkspace: string;
|
|
let mockAuthProvider: MCPOAuthProvider;
|
|
let mockTokenStorage: MCPOAuthTokenStorage;
|
|
|
|
beforeEach(() => {
|
|
mockedClient = {
|
|
connect: vi.fn(),
|
|
close: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
onclose: vi.fn(),
|
|
notification: vi.fn(),
|
|
} as unknown as ClientLib.Client;
|
|
vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient);
|
|
|
|
testWorkspace = fs.mkdtempSync(
|
|
path.join(os.tmpdir(), 'gemini-agent-test-'),
|
|
);
|
|
workspaceContext = new WorkspaceContext(testWorkspace);
|
|
|
|
vi.spyOn(console, 'log').mockImplementation(() => {});
|
|
vi.spyOn(console, 'warn').mockImplementation(() => {});
|
|
vi.spyOn(console, 'error').mockImplementation(() => {});
|
|
|
|
mockTokenStorage = {
|
|
getCredentials: vi.fn().mockResolvedValue({ clientId: 'test-client' }),
|
|
} as unknown as MCPOAuthTokenStorage;
|
|
vi.mocked(MCPOAuthTokenStorage).mockReturnValue(mockTokenStorage);
|
|
mockAuthProvider = {
|
|
authenticate: vi.fn().mockResolvedValue(undefined),
|
|
getValidToken: vi.fn().mockResolvedValue('test-access-token'),
|
|
tokenStorage: mockTokenStorage,
|
|
} as unknown as MCPOAuthProvider;
|
|
vi.mocked(MCPOAuthProvider).mockReturnValue(mockAuthProvider);
|
|
});
|
|
|
|
afterEach(() => {
|
|
vi.clearAllMocks();
|
|
});
|
|
|
|
it('should handle automatic OAuth flow on 401 with www-authenticate header', async () => {
|
|
const serverUrl = 'http://test-server.com/';
|
|
const authUrl = 'http://auth.example.com/auth';
|
|
const tokenUrl = 'http://auth.example.com/token';
|
|
const wwwAuthHeader = `Bearer realm="test", resource_metadata="http://test-server.com/.well-known/oauth-protected-resource"`;
|
|
|
|
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
|
new StreamableHTTPError(
|
|
401,
|
|
`Unauthorized\nwww-authenticate: ${wwwAuthHeader}`,
|
|
),
|
|
);
|
|
|
|
vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({
|
|
authorizationUrl: authUrl,
|
|
tokenUrl,
|
|
scopes: ['test-scope'],
|
|
});
|
|
|
|
// We need this to be an any type because we dig into its private state.
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
let capturedTransport: any;
|
|
vi.mocked(mockedClient.connect).mockImplementationOnce(
|
|
async (transport) => {
|
|
capturedTransport = transport;
|
|
return Promise.resolve();
|
|
},
|
|
);
|
|
|
|
const client = await connectToMcpServer(
|
|
'0.0.1',
|
|
'test-server',
|
|
{ httpUrl: serverUrl, oauth: { enabled: true } },
|
|
false,
|
|
workspaceContext,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(client).toBe(mockedClient);
|
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
|
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
|
|
|
const authHeader =
|
|
capturedTransport._requestInit?.headers?.['Authorization'];
|
|
expect(authHeader).toBe('Bearer test-access-token');
|
|
});
|
|
|
|
it('should discover oauth config if not in www-authenticate header', async () => {
|
|
const serverUrl = 'http://test-server.com';
|
|
const authUrl = 'http://auth.example.com/auth';
|
|
const tokenUrl = 'http://auth.example.com/token';
|
|
|
|
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
|
new StreamableHTTPError(401, 'Unauthorized'),
|
|
);
|
|
|
|
vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({
|
|
authorizationUrl: authUrl,
|
|
tokenUrl,
|
|
scopes: ['test-scope'],
|
|
});
|
|
vi.mocked(mockAuthProvider.getValidToken).mockResolvedValue(
|
|
'test-access-token-from-discovery',
|
|
);
|
|
|
|
// We need this to be an any type because we dig into its private state.
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
let capturedTransport: any;
|
|
vi.mocked(mockedClient.connect).mockImplementationOnce(
|
|
async (transport) => {
|
|
capturedTransport = transport;
|
|
return Promise.resolve();
|
|
},
|
|
);
|
|
|
|
const client = await connectToMcpServer(
|
|
'0.0.1',
|
|
'test-server',
|
|
{ httpUrl: serverUrl, oauth: { enabled: true } },
|
|
false,
|
|
workspaceContext,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(client).toBe(mockedClient);
|
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
|
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
|
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
|
|
|
|
const authHeader =
|
|
capturedTransport._requestInit?.headers?.['Authorization'];
|
|
expect(authHeader).toBe('Bearer test-access-token-from-discovery');
|
|
});
|
|
});
|
|
|
|
describe('connectToMcpServer - HTTP→SSE fallback', () => {
|
|
let mockedClient: ClientLib.Client;
|
|
let workspaceContext: WorkspaceContext;
|
|
let testWorkspace: string;
|
|
|
|
beforeEach(() => {
|
|
mockedClient = {
|
|
connect: vi.fn(),
|
|
close: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
onclose: vi.fn(),
|
|
notification: vi.fn(),
|
|
} as unknown as ClientLib.Client;
|
|
vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient);
|
|
|
|
testWorkspace = fs.mkdtempSync(
|
|
path.join(os.tmpdir(), 'gemini-agent-test-'),
|
|
);
|
|
workspaceContext = new WorkspaceContext(testWorkspace);
|
|
|
|
vi.spyOn(console, 'log').mockImplementation(() => {});
|
|
vi.spyOn(console, 'warn').mockImplementation(() => {});
|
|
vi.spyOn(console, 'error').mockImplementation(() => {});
|
|
});
|
|
|
|
afterEach(() => {
|
|
vi.clearAllMocks();
|
|
});
|
|
|
|
it('should NOT trigger fallback when type="http" is explicit', async () => {
|
|
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
|
new Error('Connection failed'),
|
|
);
|
|
|
|
await expect(
|
|
connectToMcpServer(
|
|
'0.0.1',
|
|
'test-server',
|
|
{ url: 'http://test-server', type: 'http' },
|
|
false,
|
|
workspaceContext,
|
|
EMPTY_CONFIG,
|
|
),
|
|
).rejects.toThrow('Connection failed');
|
|
|
|
// Should only try once (no fallback)
|
|
expect(mockedClient.connect).toHaveBeenCalledTimes(1);
|
|
});
|
|
|
|
it('should NOT trigger fallback when type="sse" is explicit', async () => {
|
|
vi.mocked(mockedClient.connect).mockRejectedValueOnce(
|
|
new Error('Connection failed'),
|
|
);
|
|
|
|
await expect(
|
|
connectToMcpServer(
|
|
'0.0.1',
|
|
'test-server',
|
|
{ url: 'http://test-server', type: 'sse' },
|
|
false,
|
|
workspaceContext,
|
|
EMPTY_CONFIG,
|
|
),
|
|
).rejects.toThrow('Connection failed');
|
|
|
|
// Should only try once (no fallback)
|
|
expect(mockedClient.connect).toHaveBeenCalledTimes(1);
|
|
});
|
|
|
|
it('should trigger fallback when url provided without type and HTTP fails', async () => {
|
|
vi.mocked(mockedClient.connect)
|
|
.mockRejectedValueOnce(new StreamableHTTPError(500, 'Server error'))
|
|
.mockResolvedValueOnce(undefined);
|
|
|
|
const client = await connectToMcpServer(
|
|
'0.0.1',
|
|
'test-server',
|
|
{ url: 'http://test-server' },
|
|
false,
|
|
workspaceContext,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(client).toBe(mockedClient);
|
|
// First HTTP attempt fails, second SSE attempt succeeds
|
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
|
});
|
|
|
|
it('should throw original HTTP error when both HTTP and SSE fail (non-401)', async () => {
|
|
const httpError = new StreamableHTTPError(500, 'Server error');
|
|
const sseError = new Error('SSE connection failed');
|
|
|
|
vi.mocked(mockedClient.connect)
|
|
.mockRejectedValueOnce(httpError)
|
|
.mockRejectedValueOnce(sseError);
|
|
|
|
await expect(
|
|
connectToMcpServer(
|
|
'0.0.1',
|
|
'test-server',
|
|
{ url: 'http://test-server' },
|
|
false,
|
|
workspaceContext,
|
|
EMPTY_CONFIG,
|
|
),
|
|
).rejects.toThrow('Server error');
|
|
|
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
|
});
|
|
|
|
it('should handle HTTP 404 followed by SSE success', async () => {
|
|
vi.mocked(mockedClient.connect)
|
|
.mockRejectedValueOnce(new StreamableHTTPError(404, 'Not Found'))
|
|
.mockResolvedValueOnce(undefined);
|
|
|
|
const client = await connectToMcpServer(
|
|
'0.0.1',
|
|
'test-server',
|
|
{ url: 'http://test-server' },
|
|
false,
|
|
workspaceContext,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(client).toBe(mockedClient);
|
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
|
});
|
|
});
|
|
|
|
describe('connectToMcpServer - OAuth with transport fallback', () => {
|
|
let mockedClient: ClientLib.Client;
|
|
let workspaceContext: WorkspaceContext;
|
|
let testWorkspace: string;
|
|
let mockAuthProvider: MCPOAuthProvider;
|
|
let mockTokenStorage: MCPOAuthTokenStorage;
|
|
|
|
beforeEach(() => {
|
|
mockedClient = {
|
|
connect: vi.fn(),
|
|
close: vi.fn(),
|
|
registerCapabilities: vi.fn(),
|
|
setRequestHandler: vi.fn(),
|
|
onclose: vi.fn(),
|
|
notification: vi.fn(),
|
|
} as unknown as ClientLib.Client;
|
|
vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient);
|
|
|
|
testWorkspace = fs.mkdtempSync(
|
|
path.join(os.tmpdir(), 'gemini-agent-test-'),
|
|
);
|
|
workspaceContext = new WorkspaceContext(testWorkspace);
|
|
|
|
vi.spyOn(console, 'log').mockImplementation(() => {});
|
|
vi.spyOn(console, 'warn').mockImplementation(() => {});
|
|
vi.spyOn(console, 'error').mockImplementation(() => {});
|
|
|
|
// Mock fetch to prevent real network calls during OAuth discovery fallback.
|
|
// When a 401 error lacks a www-authenticate header, the code attempts to
|
|
// fetch the header directly from the server, which would hang without this mock.
|
|
vi.stubGlobal(
|
|
'fetch',
|
|
vi.fn().mockResolvedValue({
|
|
status: 401,
|
|
headers: new Headers({
|
|
'www-authenticate': `Bearer realm="test", resource_metadata="http://test-server/.well-known/oauth-protected-resource"`,
|
|
}),
|
|
}),
|
|
);
|
|
|
|
mockTokenStorage = {
|
|
getCredentials: vi.fn().mockResolvedValue({ clientId: 'test-client' }),
|
|
} as unknown as MCPOAuthTokenStorage;
|
|
vi.mocked(MCPOAuthTokenStorage).mockReturnValue(mockTokenStorage);
|
|
|
|
mockAuthProvider = {
|
|
authenticate: vi.fn().mockResolvedValue(undefined),
|
|
getValidToken: vi.fn().mockResolvedValue('test-access-token'),
|
|
tokenStorage: mockTokenStorage,
|
|
} as unknown as MCPOAuthProvider;
|
|
vi.mocked(MCPOAuthProvider).mockReturnValue(mockAuthProvider);
|
|
|
|
vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({
|
|
authorizationUrl: 'http://auth.example.com/auth',
|
|
tokenUrl: 'http://auth.example.com/token',
|
|
scopes: ['test-scope'],
|
|
});
|
|
});
|
|
|
|
afterEach(() => {
|
|
vi.clearAllMocks();
|
|
vi.unstubAllGlobals();
|
|
});
|
|
|
|
it('should handle HTTP 404 → SSE 401 → OAuth → SSE+OAuth succeeds', async () => {
|
|
// Tests that OAuth flow works when SSE (not HTTP) requires auth
|
|
vi.mocked(mockedClient.connect)
|
|
.mockRejectedValueOnce(new StreamableHTTPError(404, 'Not Found'))
|
|
.mockRejectedValueOnce(new StreamableHTTPError(401, 'Unauthorized'))
|
|
.mockResolvedValueOnce(undefined);
|
|
|
|
const client = await connectToMcpServer(
|
|
'0.0.1',
|
|
'test-server',
|
|
{ url: 'http://test-server', oauth: { enabled: true } },
|
|
false,
|
|
workspaceContext,
|
|
EMPTY_CONFIG,
|
|
);
|
|
|
|
expect(client).toBe(mockedClient);
|
|
expect(mockedClient.connect).toHaveBeenCalledTimes(3);
|
|
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
|
});
|
|
});
|