Extensions MCP refactor (#12413)

This commit is contained in:
Jacob MacDonald
2025-11-04 07:51:18 -08:00
committed by GitHub
parent 2b77c1ded4
commit da4fa5ad75
28 changed files with 877 additions and 478 deletions
+52 -12
View File
@@ -154,6 +154,7 @@ import {
type ExtensionLoader,
SimpleExtensionLoader,
} from '../utils/extensionLoader.js';
import { McpClientManager } from '../tools/mcp-client-manager.js';
export type { FileFilteringOptions };
export {
@@ -251,7 +252,8 @@ export interface ConfigParameters {
extensionLoader?: ExtensionLoader;
enabledExtensions?: string[];
enableExtensionReloading?: boolean;
blockedMcpServers?: Array<{ name: string; extensionName: string }>;
allowedMcpServers?: string[];
blockedMcpServers?: string[];
noBrowser?: boolean;
summarizeToolOutput?: Record<string, SummarizeToolOutputSettings>;
folderTrust?: boolean;
@@ -293,6 +295,9 @@ export interface ConfigParameters {
export class Config {
private toolRegistry!: ToolRegistry;
private mcpClientManager?: McpClientManager;
private allowedMcpServers: string[];
private blockedMcpServers: string[];
private promptRegistry!: PromptRegistry;
private agentRegistry!: AgentRegistry;
private readonly sessionId: string;
@@ -347,10 +352,6 @@ export class Config {
private readonly _extensionLoader: ExtensionLoader;
private readonly _enabledExtensions: string[];
private readonly enableExtensionReloading: boolean;
private readonly _blockedMcpServers: Array<{
name: string;
extensionName: string;
}>;
fallbackModelHandler?: FallbackModelHandler;
private quotaErrorOccurred: boolean = false;
private readonly summarizeToolOutput:
@@ -417,6 +418,8 @@ export class Config {
this.toolCallCommand = params.toolCallCommand;
this.mcpServerCommand = params.mcpServerCommand;
this.mcpServers = params.mcpServers;
this.allowedMcpServers = params.allowedMcpServers ?? [];
this.blockedMcpServers = params.blockedMcpServers ?? [];
this.userMemory = params.userMemory ?? '';
this.geminiMdFileCount = params.geminiMdFileCount ?? 0;
this.geminiMdFilePaths = params.geminiMdFilePaths ?? [];
@@ -458,7 +461,6 @@ export class Config {
this._extensionLoader =
params.extensionLoader ?? new SimpleExtensionLoader([]);
this._enabledExtensions = params.enabledExtensions ?? [];
this._blockedMcpServers = params.blockedMcpServers ?? [];
this.noBrowser = params.noBrowser ?? false;
this.summarizeToolOutput = params.summarizeToolOutput;
this.folderTrust = params.folderTrust ?? false;
@@ -572,6 +574,15 @@ export class Config {
await this.agentRegistry.initialize();
this.toolRegistry = await this.createToolRegistry();
this.mcpClientManager = new McpClientManager(
this.toolRegistry,
this,
this.eventEmitter,
);
await Promise.all([
await this.mcpClientManager.startConfiguredMcpServers(),
await this.getExtensionLoader().start(this),
]);
await this.geminiClient.initialize();
}
@@ -752,8 +763,23 @@ export class Config {
return this.allowedTools;
}
/**
* All the excluded tools from static configuration, loaded extensions, or
* other sources.
*
* May change over time.
*/
getExcludeTools(): string[] | undefined {
return this.excludeTools;
const excludeToolsSet = new Set([...(this.excludeTools ?? [])]);
for (const extension of this.getExtensionLoader().getExtensions()) {
if (!extension.isActive) {
continue;
}
for (const tool of extension.excludeTools || []) {
excludeToolsSet.add(tool);
}
}
return [...excludeToolsSet];
}
getToolDiscoveryCommand(): string | undefined {
@@ -768,10 +794,27 @@ export class Config {
return this.mcpServerCommand;
}
/**
* The user configured MCP servers (via gemini settings files).
*
* Does NOT include mcp servers configured by extensions.
*/
getMcpServers(): Record<string, MCPServerConfig> | undefined {
return this.mcpServers;
}
getMcpClientManager(): McpClientManager | undefined {
return this.mcpClientManager;
}
getAllowedMcpServers(): string[] | undefined {
return this.allowedMcpServers;
}
getBlockedMcpServers(): string[] | undefined {
return this.blockedMcpServers;
}
setMcpServers(mcpServers: Record<string, MCPServerConfig>): void {
this.mcpServers = mcpServers;
}
@@ -955,10 +998,6 @@ export class Config {
return this.enableExtensionReloading;
}
getBlockedMcpServers(): Array<{ name: string; extensionName: string }> {
return this._blockedMcpServers;
}
getNoBrowser(): boolean {
return this.noBrowser;
}
@@ -1155,7 +1194,7 @@ export class Config {
}
async createToolRegistry(): Promise<ToolRegistry> {
const registry = new ToolRegistry(this, this.eventEmitter);
const registry = new ToolRegistry(this);
// Set message bus on tool registry before discovery so MCP tools can access it
if (this.getEnableMessageBusIntegration()) {
@@ -1250,6 +1289,7 @@ export class Config {
if (definition) {
// We must respect the main allowed/exclude lists for agents too.
const excludeTools = this.getExcludeTools() || [];
const allowedTools = this.getAllowedTools();
const isExcluded = excludeTools.includes(definition.name);
+10 -5
View File
@@ -192,11 +192,9 @@ describe('loggers', () => {
getFileFilteringRespectGitIgnore: () => true,
getFileFilteringAllowBuildArtifacts: () => false,
getDebugMode: () => true,
getMcpServers: () => ({
'test-server': {
command: 'test-command',
},
}),
getMcpServers: () => {
throw new Error('Should not call');
},
getQuestion: () => 'test-question',
getTargetDir: () => 'target-dir',
getProxy: () => 'http://test.proxy.com:8080',
@@ -206,6 +204,13 @@ describe('loggers', () => {
{ name: 'ext-one', id: 'id-one' },
{ name: 'ext-two', id: 'id-two' },
] as GeminiCLIExtension[],
getMcpClientManager: () => ({
getMcpServers: () => ({
'test-server': {
command: 'test-command',
},
}),
}),
} as unknown as Config;
const startSessionEvent = new StartSessionEvent(mockConfig);
+2 -1
View File
@@ -74,7 +74,8 @@ export class StartSessionEvent implements BaseTelemetryEvent {
constructor(config: Config, toolRegistry?: ToolRegistry) {
const generatorConfig = config.getContentGeneratorConfig();
const mcpServers = config.getMcpServers();
const mcpServers =
config.getMcpClientManager()?.getMcpServers() ?? config.getMcpServers();
let useGemini = false;
let useVertex = false;
@@ -4,86 +4,193 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { afterEach, describe, expect, it, vi } from 'vitest';
import {
afterEach,
beforeEach,
describe,
expect,
it,
vi,
type MockedObject,
} from 'vitest';
import { McpClientManager } from './mcp-client-manager.js';
import { McpClient } from './mcp-client.js';
import type { ToolRegistry } from './tool-registry.js';
import type { Config } from '../config/config.js';
import { SimpleExtensionLoader } from '../utils/extensionLoader.js';
vi.mock('./mcp-client.js', async () => {
const originalModule = await vi.importActual('./mcp-client.js');
return {
...originalModule,
McpClient: vi.fn(),
populateMcpServerCommand: vi.fn(() => ({
'test-server': {},
})),
};
});
describe('McpClientManager', () => {
afterEach(() => {
vi.restoreAllMocks();
});
let mockedMcpClient: MockedObject<McpClient>;
let mockConfig: MockedObject<Config>;
it('should discover tools from all servers', async () => {
const mockedMcpClient = {
beforeEach(() => {
mockedMcpClient = vi.mockObject({
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
};
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);
const manager = new McpClientManager(
{} as ToolRegistry,
{
isTrustedFolder: () => true,
getExtensionLoader: () => new SimpleExtensionLoader([]),
getMcpServers: () => ({
'test-server': {},
}),
getMcpServerCommand: () => '',
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
getEnableExtensionReloading: () => false,
} as unknown as Config,
);
await manager.discoverAllMcpTools();
getServerConfig: vi.fn(),
} as unknown as McpClient);
vi.mocked(McpClient).mockReturnValue(mockedMcpClient);
mockConfig = vi.mockObject({
isTrustedFolder: vi.fn().mockReturnValue(true),
getMcpServers: vi.fn().mockReturnValue({}),
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
getAllowedMcpServers: vi.fn().mockReturnValue([]),
getBlockedMcpServers: vi.fn().mockReturnValue([]),
getMcpServerCommand: vi.fn().mockReturnValue(''),
getGeminiClient: vi.fn().mockReturnValue({
isInitialized: vi.fn(),
}),
} as unknown as Config);
});
afterEach(() => {
vi.restoreAllMocks();
});
it('should discover tools from all configured', async () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': {},
});
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
});
it('should not discover tools if folder is not trusted', async () => {
const mockedMcpClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
};
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);
const manager = new McpClientManager(
{} as ToolRegistry,
{
isTrustedFolder: () => false,
getExtensionLoader: () => new SimpleExtensionLoader([]),
getMcpServers: () => ({
'test-server': {},
}),
getMcpServerCommand: () => '',
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
getEnableExtensionReloading: () => false,
} as unknown as Config,
);
await manager.discoverAllMcpTools();
mockConfig.getMcpServers.mockReturnValue({
'test-server': {},
});
mockConfig.isTrustedFolder.mockReturnValue(false);
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
});
it('should not start blocked servers', async () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': {},
});
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
});
it('should only start allowed servers if allow list is not empty', async () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': {},
'another-server': {},
});
mockConfig.getAllowedMcpServers.mockReturnValue(['another-server']);
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
});
it('should start servers from extensions', async () => {
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
await manager.startExtension({
name: 'test-extension',
mcpServers: {
'test-server': {},
},
isActive: true,
version: '1.0.0',
path: '/some-path',
contextFiles: [],
id: '123',
});
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
});
it('should not start servers from disabled extensions', async () => {
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
await manager.startExtension({
name: 'test-extension',
mcpServers: {
'test-server': {},
},
isActive: false,
version: '1.0.0',
path: '/some-path',
contextFiles: [],
id: '123',
});
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
});
it('should add blocked servers to the blockedMcpServers list', async () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': {},
});
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
await manager.startConfiguredMcpServers();
expect(manager.getBlockedMcpServers()).toEqual([
{ name: 'test-server', extensionName: '' },
]);
});
describe('restart', () => {
it('should restart all running servers', async () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': {},
});
mockedMcpClient.getServerConfig.mockReturnValue({});
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
await manager.restart();
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
});
});
describe('restartServer', () => {
it('should restart the specified server', async () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': {},
});
mockedMcpClient.getServerConfig.mockReturnValue({});
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
await manager.restartServer('test-server');
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
});
it('should throw an error if the server does not exist', async () => {
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
await expect(manager.restartServer('non-existent')).rejects.toThrow(
'No MCP server registered with the name "non-existent"',
);
});
});
});
+140 -53
View File
@@ -33,6 +33,10 @@ export class McpClientManager {
private discoveryPromise: Promise<void> | undefined;
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
private readonly eventEmitter?: EventEmitter;
private readonly blockedMcpServers: Array<{
name: string;
extensionName: string;
}> = [];
constructor(
toolRegistry: ToolRegistry,
@@ -42,19 +46,10 @@ export class McpClientManager {
this.toolRegistry = toolRegistry;
this.cliConfig = cliConfig;
this.eventEmitter = eventEmitter;
if (this.cliConfig.getEnableExtensionReloading()) {
this.cliConfig
.getExtensionLoader()
.extensionEvents()
.on('extensionLoaded', (event) => this.loadExtension(event.extension))
.on('extensionEnabled', (event) => this.loadExtension(event.extension))
.on('extensionDisabled', (event) =>
this.unloadExtension(event.extension),
)
.on('extensionUnloaded', (event) =>
this.unloadExtension(event.extension),
);
}
}
getBlockedMcpServers() {
return this.blockedMcpServers;
}
/**
@@ -64,21 +59,13 @@ export class McpClientManager {
* - Disconnects all MCP clients from their servers.
* - Updates the Gemini chat configuration to load the new tools.
*/
private async unloadExtension(extension: GeminiCLIExtension) {
async stopExtension(extension: GeminiCLIExtension) {
debugLogger.log(`Unloading extension: ${extension.name}`);
await Promise.all(
Object.keys(extension.mcpServers ?? {}).map((name) => {
const newMcpServers = {
...this.cliConfig.getMcpServers(),
};
delete newMcpServers[name];
this.cliConfig.setMcpServers(newMcpServers);
return this.disconnectClient(name);
}),
Object.keys(extension.mcpServers ?? {}).map(
this.disconnectClient.bind(this),
),
);
// This is required to update the content generator configuration with the
// new tool configuration.
this.cliConfig.getGeminiClient().setTools();
}
/**
@@ -88,20 +75,36 @@ export class McpClientManager {
* - Connects MCP clients to each server and discovers their tools.
* - Updates the Gemini chat configuration to load the new tools.
*/
private async loadExtension(extension: GeminiCLIExtension) {
async startExtension(extension: GeminiCLIExtension) {
debugLogger.log(`Loading extension: ${extension.name}`);
await Promise.all(
Object.entries(extension.mcpServers ?? {}).map(([name, config]) => {
this.cliConfig.setMcpServers({
...this.cliConfig.getMcpServers(),
[name]: config,
});
return this.discoverMcpTools(name, config);
}),
Object.entries(extension.mcpServers ?? {}).map(([name, config]) =>
this.maybeDiscoverMcpServer(name, {
...config,
extension,
}),
),
);
// This is required to update the content generator configuration with the
// new tool configuration.
this.cliConfig.getGeminiClient().setTools();
}
private isAllowedMcpServer(name: string) {
const allowedNames = this.cliConfig.getAllowedMcpServers();
if (
allowedNames &&
allowedNames.length > 0 &&
allowedNames.indexOf(name) === -1
) {
return false;
}
const blockedNames = this.cliConfig.getBlockedMcpServers();
if (
blockedNames &&
blockedNames.length > 0 &&
blockedNames.indexOf(name) !== -1
) {
return false;
}
return true;
}
private async disconnectClient(name: string) {
@@ -115,36 +118,68 @@ export class McpClientManager {
debugLogger.warn(
`Error stopping client '${name}': ${getErrorMessage(error)}`,
);
} finally {
// This is required to update the content generator configuration with the
// new tool configuration.
const geminiClient = this.cliConfig.getGeminiClient();
if (geminiClient.isInitialized()) {
await geminiClient.setTools();
}
}
}
}
discoverMcpTools(
maybeDiscoverMcpServer(
name: string,
config: MCPServerConfig,
): Promise<void> | void {
if (!this.isAllowedMcpServer(name)) {
if (!this.blockedMcpServers.find((s) => s.name === name)) {
this.blockedMcpServers?.push({
name,
extensionName: config.extension?.name ?? '',
});
}
return;
}
if (!this.cliConfig.isTrustedFolder()) {
return;
}
if (config.extension && !config.extension.isActive) {
return;
}
const existing = this.clients.get(name);
if (existing && existing.getServerConfig().extension !== config.extension) {
const extensionText = config.extension
? ` from extension "${config.extension.name}"`
: '';
debugLogger.warn(
`Skipping MCP config for server with name "${name}"${extensionText} as it already exists.`,
);
return;
}
const currentDiscoveryPromise = new Promise<void>((resolve, _reject) => {
(async () => {
try {
await this.disconnectClient(name);
if (existing) {
await existing.disconnect();
}
const client = new McpClient(
name,
config,
this.toolRegistry,
this.cliConfig.getPromptRegistry(),
this.cliConfig.getWorkspaceContext(),
this.cliConfig.getDebugMode(),
);
this.clients.set(name, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);
const client =
existing ??
new McpClient(
name,
config,
this.toolRegistry,
this.cliConfig.getPromptRegistry(),
this.cliConfig.getWorkspaceContext(),
this.cliConfig.getDebugMode(),
);
if (!existing) {
this.clients.set(name, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);
}
try {
await client.connect();
await client.discover(this.cliConfig);
@@ -161,6 +196,12 @@ export class McpClientManager {
);
}
} finally {
// This is required to update the content generator configuration with the
// new tool configuration.
const geminiClient = this.cliConfig.getGeminiClient();
if (geminiClient.isInitialized()) {
await geminiClient.setTools();
}
resolve();
}
})();
@@ -174,6 +215,7 @@ export class McpClientManager {
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
this.discoveryPromise = currentDiscoveryPromise;
}
this.eventEmitter?.emit('mcp-client-update', this.clients);
const currentPromise = this.discoveryPromise;
currentPromise.then((_) => {
// If we are the last recorded discoveryPromise, then we are done, reset
@@ -187,15 +229,21 @@ export class McpClientManager {
}
/**
* Initiates the tool discovery process for all configured MCP servers.
* Initiates the tool discovery process for all configured MCP servers (via
* gemini settings or command line arguments).
*
* It connects to each server, discovers its available tools, and registers
* them with the `ToolRegistry`.
*
* For any server which is already connected, it will first be disconnected.
*
* This does NOT load extension MCP servers - this happens when the
* ExtensionLoader explicitly calls `loadExtension`.
*/
async discoverAllMcpTools(): Promise<void> {
async startConfiguredMcpServers(): Promise<void> {
if (!this.cliConfig.isTrustedFolder()) {
return;
}
await this.stop();
const servers = populateMcpServerCommand(
this.cliConfig.getMcpServers() || {},
@@ -204,12 +252,40 @@ export class McpClientManager {
this.eventEmitter?.emit('mcp-client-update', this.clients);
await Promise.all(
Object.entries(servers).map(async ([name, config]) =>
this.discoverMcpTools(name, config),
Object.entries(servers).map(([name, config]) =>
this.maybeDiscoverMcpServer(name, config),
),
);
}
/**
* Restarts all active MCP Clients.
*/
async restart(): Promise<void> {
await Promise.all(
Array.from(this.clients.entries()).map(async ([name, client]) => {
try {
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
} catch (error) {
debugLogger.error(
`Error restarting client '${name}': ${getErrorMessage(error)}`,
);
}
}),
);
}
/**
* Restart a single MCP server by name.
*/
async restartServer(name: string) {
const client = this.clients.get(name);
if (!client) {
throw new Error(`No MCP server registered with the name "${name}"`);
}
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
}
/**
* Stops all running local MCP servers and closes all client connections.
* This is the cleanup method to be called on application exit.
@@ -236,4 +312,15 @@ export class McpClientManager {
getDiscoveryState(): MCPDiscoveryState {
return this.discoveryState;
}
/**
* All of the MCP server configurations currently loaded.
*/
getMcpServers(): Record<string, MCPServerConfig> {
const mcpServers: Record<string, MCPServerConfig> = {};
for (const [name, client] of this.clients.entries()) {
mcpServers[name] = client.getServerConfig();
}
return mcpServers;
}
}
@@ -303,6 +303,71 @@ describe('mcp-client', () => {
expect(mockedMcpToTool).toHaveBeenCalledOnce();
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
});
it('should remove tools and prompts on disconnect', async () => {
const mockedClient = {
connect: vi.fn(),
close: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
getServerCapabilities: vi
.fn()
.mockReturnValue({ tools: {}, prompts: {} }),
request: vi.fn().mockResolvedValue({
prompts: [{ id: 'prompt1', text: 'a prompt' }],
}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'testTool',
description: 'A test tool',
},
],
}),
} as unknown as GenAiLib.CallableTool);
const mockedToolRegistry = {
registerTool: vi.fn(),
unregisterTool: vi.fn(),
getMessageBus: vi.fn().mockReturnValue(undefined),
removeMcpToolsByServer: vi.fn(),
} as unknown as ToolRegistry;
const mockedPromptRegistry = {
registerPrompt: vi.fn(),
unregisterPrompt: vi.fn(),
removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry;
const client = new McpClient(
'test-server',
{
command: 'test-command',
},
mockedToolRegistry,
mockedPromptRegistry,
workspaceContext,
false,
);
await client.connect();
await client.discover({} as Config);
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce();
await client.disconnect();
expect(mockedClient.close).toHaveBeenCalledOnce();
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledOnce();
expect(mockedPromptRegistry.removePromptsByServer).toHaveBeenCalledOnce();
});
});
describe('appendMcpServerCommand', () => {
it('should do nothing if no MCP servers or command are configured', () => {
+5
View File
@@ -161,6 +161,7 @@ export class McpClient {
return;
}
this.toolRegistry.removeMcpToolsByServer(this.serverName);
this.promptRegistry.removePromptsByServer(this.serverName);
this.updateStatus(MCPServerStatus.DISCONNECTING);
const client = this.client;
this.client = undefined;
@@ -208,6 +209,10 @@ export class McpClient {
this.assertConnected();
return discoverPrompts(this.serverName, this.client!, this.promptRegistry);
}
getServerConfig(): MCPServerConfig {
return this.serverConfig;
}
}
/**
@@ -19,20 +19,10 @@ import { spawn } from 'node:child_process';
import fs from 'node:fs';
import { MockTool } from '../test-utils/mock-tool.js';
import { McpClientManager } from './mcp-client-manager.js';
import { ToolErrorType } from './tool-error.js';
vi.mock('node:fs');
// Mock ./mcp-client.js to control its behavior within tool-registry tests
vi.mock('./mcp-client.js', async () => {
const originalModule = await vi.importActual('./mcp-client.js');
return {
...originalModule,
};
});
// Mock node:child_process
vi.mock('node:child_process', async () => {
const actual = await vi.importActual('node:child_process');
@@ -401,27 +391,6 @@ describe('ToolRegistry', () => {
expect(result.llmContent).toContain('Stderr: Something went wrong');
expect(result.llmContent).toContain('Exit Code: 1');
});
it('should discover tools using MCP servers defined in getMcpServers', async () => {
const discoverSpy = vi.spyOn(
McpClientManager.prototype,
'discoverAllMcpTools',
);
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
const mcpServerConfigVal = {
'my-mcp-server': {
command: 'mcp-server-cmd',
args: ['--port', '1234'],
trust: true,
},
};
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
await toolRegistry.discoverAllTools();
expect(discoverSpy).toHaveBeenCalled();
});
});
describe('DiscoveredToolInvocation', () => {
+1 -63
View File
@@ -14,13 +14,10 @@ import { Kind, BaseDeclarativeTool, BaseToolInvocation } from './tools.js';
import type { Config } from '../config/config.js';
import { spawn } from 'node:child_process';
import { StringDecoder } from 'node:string_decoder';
import { connectAndDiscover } from './mcp-client.js';
import { McpClientManager } from './mcp-client-manager.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { parse } from 'shell-quote';
import { ToolErrorType } from './tool-error.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import type { EventEmitter } from 'node:events';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { debugLogger } from '../utils/debugLogger.js';
import { coreEvents } from '../utils/events.js';
@@ -176,12 +173,10 @@ export class ToolRegistry {
// The tools keyed by tool name as seen by the LLM.
private tools: Map<string, AnyDeclarativeTool> = new Map();
private config: Config;
private mcpClientManager: McpClientManager;
private messageBus?: MessageBus;
constructor(config: Config, eventEmitter?: EventEmitter) {
constructor(config: Config) {
this.config = config;
this.mcpClientManager = new McpClientManager(this, config, eventEmitter);
}
setMessageBus(messageBus: MessageBus): void {
@@ -238,64 +233,7 @@ export class ToolRegistry {
async discoverAllTools(): Promise<void> {
// remove any previously discovered tools
this.removeDiscoveredTools();
this.config.getPromptRegistry().clear();
await this.discoverAndRegisterToolsFromCommand();
// discover tools using MCP servers, if configured
await this.mcpClientManager.discoverAllMcpTools();
}
/**
* Discovers tools from project (if available and configured).
* Can be called multiple times to update discovered tools.
* This will NOT discover tools from the command line, only from MCP servers.
*/
async discoverMcpTools(): Promise<void> {
// remove any previously discovered tools
this.removeDiscoveredTools();
this.config.getPromptRegistry().clear();
// discover tools using MCP servers, if configured
await this.mcpClientManager.discoverAllMcpTools();
}
/**
* Restarts all MCP servers and re-discovers tools.
*/
async restartMcpServers(): Promise<void> {
await this.discoverMcpTools();
}
/**
* Discover or re-discover tools for a single MCP server.
* @param serverName - The name of the server to discover tools from.
*/
async discoverToolsForServer(serverName: string): Promise<void> {
// Remove any previously discovered tools from this server
for (const [name, tool] of this.tools.entries()) {
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
this.tools.delete(name);
}
}
this.config.getPromptRegistry().removePromptsByServer(serverName);
const mcpServers = this.config.getMcpServers() ?? {};
const serverConfig = mcpServers[serverName];
if (serverConfig) {
await connectAndDiscover(
serverName,
serverConfig,
this,
this.config.getPromptRegistry(),
this.config.getDebugMode(),
this.config.getWorkspaceContext(),
this.config,
);
}
}
private async discoverAndRegisterToolsFromCommand(): Promise<void> {
@@ -0,0 +1,108 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest';
import { SimpleExtensionLoader } from './extensionLoader.js';
import type { Config } from '../config/config.js';
import { type McpClientManager } from '../tools/mcp-client-manager.js';
describe('SimpleExtensionLoader', () => {
let mockConfig: Config;
let extensionReloadingEnabled: boolean;
let mockMcpClientManager: McpClientManager;
const activeExtension = {
name: 'test-extension',
isActive: true,
version: '1.0.0',
path: '/path/to/extension',
contextFiles: [],
id: '123',
};
const inactiveExtension = {
name: 'test-extension',
isActive: false,
version: '1.0.0',
path: '/path/to/extension',
contextFiles: [],
id: '123',
};
beforeEach(() => {
mockMcpClientManager = {
startExtension: vi.fn(),
stopExtension: vi.fn(),
} as unknown as McpClientManager;
extensionReloadingEnabled = false;
mockConfig = {
getMcpClientManager: () => mockMcpClientManager,
getEnableExtensionReloading: () => extensionReloadingEnabled,
} as unknown as Config;
});
afterEach(() => {
vi.restoreAllMocks();
});
it('should start active extensions', async () => {
const loader = new SimpleExtensionLoader([activeExtension]);
await loader.start(mockConfig);
expect(mockMcpClientManager.startExtension).toHaveBeenCalledExactlyOnceWith(
activeExtension,
);
});
it('should not start inactive extensions', async () => {
const loader = new SimpleExtensionLoader([inactiveExtension]);
await loader.start(mockConfig);
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
});
describe('interactive extension loading and unloading', () => {
it('should not call `start` or `stop` if the loader is not already started', async () => {
const loader = new SimpleExtensionLoader([]);
await loader.loadExtension(activeExtension);
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
await loader.unloadExtension(activeExtension);
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
});
it('should start extensions that were explicitly loaded prior to initializing the loader', async () => {
const loader = new SimpleExtensionLoader([]);
await loader.loadExtension(activeExtension);
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
await loader.start(mockConfig);
expect(
mockMcpClientManager.startExtension,
).toHaveBeenCalledExactlyOnceWith(activeExtension);
});
it.each([true, false])(
'should only call `start` and `stop` if extension reloading is enabled ($i)',
async (reloadingEnabled) => {
extensionReloadingEnabled = reloadingEnabled;
const loader = new SimpleExtensionLoader([]);
await loader.start(mockConfig);
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
await loader.loadExtension(activeExtension);
if (reloadingEnabled) {
expect(
mockMcpClientManager.startExtension,
).toHaveBeenCalledExactlyOnceWith(activeExtension);
} else {
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
}
await loader.unloadExtension(activeExtension);
if (reloadingEnabled) {
expect(
mockMcpClientManager.stopExtension,
).toHaveBeenCalledExactlyOnceWith(activeExtension);
} else {
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
}
},
);
});
});
+175 -26
View File
@@ -4,45 +4,194 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { EventEmitter } from 'node:events';
import type { GeminiCLIExtension } from '../config/config.js';
import type { EventEmitter } from 'node:events';
import type { Config, GeminiCLIExtension } from '../config/config.js';
export interface ExtensionLoader {
getExtensions(): GeminiCLIExtension[];
export abstract class ExtensionLoader {
// Assigned in `start`.
protected config: Config | undefined;
extensionEvents(): EventEmitter<ExtensionEvents>;
// Used to track the count of currently starting and stopping extensions and
// fire appropriate events.
protected startingCount: number = 0;
protected startCompletedCount: number = 0;
protected stoppingCount: number = 0;
protected stopCompletedCount: number = 0;
constructor(private readonly eventEmitter?: EventEmitter<ExtensionEvents>) {}
/**
* All currently known extensions, both active and inactive.
*/
abstract getExtensions(): GeminiCLIExtension[];
/**
* Fully initializes all active extensions.
*
* Called within `Config.initialize`, which must already have an
* McpClientManager, PromptRegistry, and GeminiChat set up.
*/
async start(config: Config): Promise<void> {
if (!this.config) {
this.config = config;
} else {
throw new Error('Already started, you may only call `start` once.');
}
await Promise.all(
this.getExtensions()
.filter((e) => e.isActive)
.map(this.startExtension.bind(this)),
);
}
/**
* Unconditionally starts an `extension` and loads all its MCP servers,
* context, custom commands, etc. Assumes that `start` has already been called
* and we have a Config object.
*
* This should typically only be called from `start`, most other calls should
* go through `maybeStartExtension` which will only start the extension if
* extension reloading is enabled and the `config` object is initialized.
*/
protected async startExtension(extension: GeminiCLIExtension) {
if (!this.config) {
throw new Error('Cannot call `startExtension` prior to calling `start`.');
}
this.startingCount++;
this.eventEmitter?.emit('extensionsStarting', {
total: this.startingCount,
completed: this.startCompletedCount,
});
try {
await this.config.getMcpClientManager()!.startExtension(extension);
// TODO: Move all extension features here, including at least:
// - context file loading
// - custom command loading
// - excluded tool configuration
} finally {
this.startCompletedCount++;
this.eventEmitter?.emit('extensionsStarting', {
total: this.startingCount,
completed: this.startCompletedCount,
});
if (this.startingCount === this.startCompletedCount) {
this.startingCount = 0;
this.startCompletedCount = 0;
}
}
}
/**
* If extension reloading is enabled and `start` has already been called,
* then calls `startExtension` to include all extension features into the
* program.
*/
protected maybeStartExtension(
extension: GeminiCLIExtension,
): Promise<void> | undefined {
if (this.config && this.config.getEnableExtensionReloading()) {
return this.startExtension(extension);
}
return;
}
/**
* Unconditionally stops an `extension` and unloads all its MCP servers,
* context, custom commands, etc. Assumes that `start` has already been called
* and we have a Config object.
*
* Most calls should go through `maybeStopExtension` which will only stop the
* extension if extension reloading is enabled and the `config` object is
* initialized.
*/
protected async stopExtension(extension: GeminiCLIExtension) {
if (!this.config) {
throw new Error('Cannot call `stopExtension` prior to calling `start`.');
}
this.stoppingCount++;
this.eventEmitter?.emit('extensionsStopping', {
total: this.stoppingCount,
completed: this.stopCompletedCount,
});
try {
await this.config.getMcpClientManager()!.stopExtension(extension);
// TODO: Remove all extension features here, including at least:
// - context files
// - custom commands
// - excluded tools
} finally {
this.stopCompletedCount++;
this.eventEmitter?.emit('extensionsStopping', {
total: this.stoppingCount,
completed: this.stopCompletedCount,
});
if (this.stoppingCount === this.stopCompletedCount) {
this.stoppingCount = 0;
this.stopCompletedCount = 0;
}
}
}
/**
* If extension reloading is enabled and `start` has already been called,
* then this also performs all necessary steps to remove all extension
* features from the rest of the system.
*/
protected maybeStopExtension(
extension: GeminiCLIExtension,
): Promise<void> | undefined {
if (this.config && this.config.getEnableExtensionReloading()) {
return this.stopExtension(extension);
}
return;
}
}
export interface ExtensionEvents {
extensionEnabled: ExtensionEnableEvent[];
extensionDisabled: ExtensionDisableEvent[];
extensionLoaded: ExtensionLoadEvent[];
extensionUnloaded: ExtensionUnloadEvent[];
extensionInstalled: ExtensionInstallEvent[];
extensionUninstalled: ExtensionUninstallEvent[];
extensionUpdated: ExtensionUpdateEvent[];
extensionsStarting: ExtensionsStartingEvent[];
extensionsStopping: ExtensionsStoppingEvent[];
}
interface BaseExtensionEvent {
extension: GeminiCLIExtension;
export interface ExtensionsStartingEvent {
total: number;
completed: number;
}
export type ExtensionDisableEvent = BaseExtensionEvent;
export type ExtensionEnableEvent = BaseExtensionEvent;
export type ExtensionInstallEvent = BaseExtensionEvent;
export type ExtensionLoadEvent = BaseExtensionEvent;
export type ExtensionUnloadEvent = BaseExtensionEvent;
export type ExtensionUninstallEvent = BaseExtensionEvent;
export type ExtensionUpdateEvent = BaseExtensionEvent;
export class SimpleExtensionLoader implements ExtensionLoader {
private _eventEmitter = new EventEmitter<ExtensionEvents>();
constructor(private readonly extensions: GeminiCLIExtension[]) {}
export interface ExtensionsStoppingEvent {
total: number;
completed: number;
}
extensionEvents(): EventEmitter<ExtensionEvents> {
return this._eventEmitter;
export class SimpleExtensionLoader extends ExtensionLoader {
constructor(
protected readonly extensions: GeminiCLIExtension[],
eventEmitter?: EventEmitter<ExtensionEvents>,
) {
super(eventEmitter);
}
getExtensions(): GeminiCLIExtension[] {
return this.extensions;
}
/// Adds `extension` to the list of extensions and calls
/// `maybeStartExtension`.
///
/// This is intended for dynamic loading of extensions after calling `start`.
async loadExtension(extension: GeminiCLIExtension) {
this.extensions.push(extension);
await this.maybeStartExtension(extension);
}
/// Removes `extension` from the list of extensions and calls
// `maybeStopExtension` if it was found.
///
/// This is intended for dynamic unloading of extensions after calling `start`.
async unloadExtension(extension: GeminiCLIExtension) {
const index = this.extensions.indexOf(extension);
if (index === -1) return;
this.extensions.splice(index, 1);
await this.maybeStopExtension(extension);
}
}