Extensions MCP refactor (#12413)

This commit is contained in:
Jacob MacDonald
2025-11-04 07:51:18 -08:00
committed by GitHub
parent 2b77c1ded4
commit da4fa5ad75
28 changed files with 877 additions and 478 deletions
@@ -0,0 +1,108 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest';
import { SimpleExtensionLoader } from './extensionLoader.js';
import type { Config } from '../config/config.js';
import { type McpClientManager } from '../tools/mcp-client-manager.js';
describe('SimpleExtensionLoader', () => {
let mockConfig: Config;
let extensionReloadingEnabled: boolean;
let mockMcpClientManager: McpClientManager;
const activeExtension = {
name: 'test-extension',
isActive: true,
version: '1.0.0',
path: '/path/to/extension',
contextFiles: [],
id: '123',
};
const inactiveExtension = {
name: 'test-extension',
isActive: false,
version: '1.0.0',
path: '/path/to/extension',
contextFiles: [],
id: '123',
};
beforeEach(() => {
mockMcpClientManager = {
startExtension: vi.fn(),
stopExtension: vi.fn(),
} as unknown as McpClientManager;
extensionReloadingEnabled = false;
mockConfig = {
getMcpClientManager: () => mockMcpClientManager,
getEnableExtensionReloading: () => extensionReloadingEnabled,
} as unknown as Config;
});
afterEach(() => {
vi.restoreAllMocks();
});
it('should start active extensions', async () => {
const loader = new SimpleExtensionLoader([activeExtension]);
await loader.start(mockConfig);
expect(mockMcpClientManager.startExtension).toHaveBeenCalledExactlyOnceWith(
activeExtension,
);
});
it('should not start inactive extensions', async () => {
const loader = new SimpleExtensionLoader([inactiveExtension]);
await loader.start(mockConfig);
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
});
describe('interactive extension loading and unloading', () => {
it('should not call `start` or `stop` if the loader is not already started', async () => {
const loader = new SimpleExtensionLoader([]);
await loader.loadExtension(activeExtension);
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
await loader.unloadExtension(activeExtension);
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
});
it('should start extensions that were explicitly loaded prior to initializing the loader', async () => {
const loader = new SimpleExtensionLoader([]);
await loader.loadExtension(activeExtension);
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
await loader.start(mockConfig);
expect(
mockMcpClientManager.startExtension,
).toHaveBeenCalledExactlyOnceWith(activeExtension);
});
it.each([true, false])(
'should only call `start` and `stop` if extension reloading is enabled ($i)',
async (reloadingEnabled) => {
extensionReloadingEnabled = reloadingEnabled;
const loader = new SimpleExtensionLoader([]);
await loader.start(mockConfig);
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
await loader.loadExtension(activeExtension);
if (reloadingEnabled) {
expect(
mockMcpClientManager.startExtension,
).toHaveBeenCalledExactlyOnceWith(activeExtension);
} else {
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
}
await loader.unloadExtension(activeExtension);
if (reloadingEnabled) {
expect(
mockMcpClientManager.stopExtension,
).toHaveBeenCalledExactlyOnceWith(activeExtension);
} else {
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
}
},
);
});
});
+175 -26
View File
@@ -4,45 +4,194 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { EventEmitter } from 'node:events';
import type { GeminiCLIExtension } from '../config/config.js';
import type { EventEmitter } from 'node:events';
import type { Config, GeminiCLIExtension } from '../config/config.js';
export interface ExtensionLoader {
getExtensions(): GeminiCLIExtension[];
export abstract class ExtensionLoader {
// Assigned in `start`.
protected config: Config | undefined;
extensionEvents(): EventEmitter<ExtensionEvents>;
// Used to track the count of currently starting and stopping extensions and
// fire appropriate events.
protected startingCount: number = 0;
protected startCompletedCount: number = 0;
protected stoppingCount: number = 0;
protected stopCompletedCount: number = 0;
constructor(private readonly eventEmitter?: EventEmitter<ExtensionEvents>) {}
/**
* All currently known extensions, both active and inactive.
*/
abstract getExtensions(): GeminiCLIExtension[];
/**
* Fully initializes all active extensions.
*
* Called within `Config.initialize`, which must already have an
* McpClientManager, PromptRegistry, and GeminiChat set up.
*/
async start(config: Config): Promise<void> {
if (!this.config) {
this.config = config;
} else {
throw new Error('Already started, you may only call `start` once.');
}
await Promise.all(
this.getExtensions()
.filter((e) => e.isActive)
.map(this.startExtension.bind(this)),
);
}
/**
* Unconditionally starts an `extension` and loads all its MCP servers,
* context, custom commands, etc. Assumes that `start` has already been called
* and we have a Config object.
*
* This should typically only be called from `start`, most other calls should
* go through `maybeStartExtension` which will only start the extension if
* extension reloading is enabled and the `config` object is initialized.
*/
protected async startExtension(extension: GeminiCLIExtension) {
if (!this.config) {
throw new Error('Cannot call `startExtension` prior to calling `start`.');
}
this.startingCount++;
this.eventEmitter?.emit('extensionsStarting', {
total: this.startingCount,
completed: this.startCompletedCount,
});
try {
await this.config.getMcpClientManager()!.startExtension(extension);
// TODO: Move all extension features here, including at least:
// - context file loading
// - custom command loading
// - excluded tool configuration
} finally {
this.startCompletedCount++;
this.eventEmitter?.emit('extensionsStarting', {
total: this.startingCount,
completed: this.startCompletedCount,
});
if (this.startingCount === this.startCompletedCount) {
this.startingCount = 0;
this.startCompletedCount = 0;
}
}
}
/**
* If extension reloading is enabled and `start` has already been called,
* then calls `startExtension` to include all extension features into the
* program.
*/
protected maybeStartExtension(
extension: GeminiCLIExtension,
): Promise<void> | undefined {
if (this.config && this.config.getEnableExtensionReloading()) {
return this.startExtension(extension);
}
return;
}
/**
* Unconditionally stops an `extension` and unloads all its MCP servers,
* context, custom commands, etc. Assumes that `start` has already been called
* and we have a Config object.
*
* Most calls should go through `maybeStopExtension` which will only stop the
* extension if extension reloading is enabled and the `config` object is
* initialized.
*/
protected async stopExtension(extension: GeminiCLIExtension) {
if (!this.config) {
throw new Error('Cannot call `stopExtension` prior to calling `start`.');
}
this.stoppingCount++;
this.eventEmitter?.emit('extensionsStopping', {
total: this.stoppingCount,
completed: this.stopCompletedCount,
});
try {
await this.config.getMcpClientManager()!.stopExtension(extension);
// TODO: Remove all extension features here, including at least:
// - context files
// - custom commands
// - excluded tools
} finally {
this.stopCompletedCount++;
this.eventEmitter?.emit('extensionsStopping', {
total: this.stoppingCount,
completed: this.stopCompletedCount,
});
if (this.stoppingCount === this.stopCompletedCount) {
this.stoppingCount = 0;
this.stopCompletedCount = 0;
}
}
}
/**
* If extension reloading is enabled and `start` has already been called,
* then this also performs all necessary steps to remove all extension
* features from the rest of the system.
*/
protected maybeStopExtension(
extension: GeminiCLIExtension,
): Promise<void> | undefined {
if (this.config && this.config.getEnableExtensionReloading()) {
return this.stopExtension(extension);
}
return;
}
}
export interface ExtensionEvents {
extensionEnabled: ExtensionEnableEvent[];
extensionDisabled: ExtensionDisableEvent[];
extensionLoaded: ExtensionLoadEvent[];
extensionUnloaded: ExtensionUnloadEvent[];
extensionInstalled: ExtensionInstallEvent[];
extensionUninstalled: ExtensionUninstallEvent[];
extensionUpdated: ExtensionUpdateEvent[];
extensionsStarting: ExtensionsStartingEvent[];
extensionsStopping: ExtensionsStoppingEvent[];
}
interface BaseExtensionEvent {
extension: GeminiCLIExtension;
export interface ExtensionsStartingEvent {
total: number;
completed: number;
}
export type ExtensionDisableEvent = BaseExtensionEvent;
export type ExtensionEnableEvent = BaseExtensionEvent;
export type ExtensionInstallEvent = BaseExtensionEvent;
export type ExtensionLoadEvent = BaseExtensionEvent;
export type ExtensionUnloadEvent = BaseExtensionEvent;
export type ExtensionUninstallEvent = BaseExtensionEvent;
export type ExtensionUpdateEvent = BaseExtensionEvent;
export class SimpleExtensionLoader implements ExtensionLoader {
private _eventEmitter = new EventEmitter<ExtensionEvents>();
constructor(private readonly extensions: GeminiCLIExtension[]) {}
export interface ExtensionsStoppingEvent {
total: number;
completed: number;
}
extensionEvents(): EventEmitter<ExtensionEvents> {
return this._eventEmitter;
export class SimpleExtensionLoader extends ExtensionLoader {
constructor(
protected readonly extensions: GeminiCLIExtension[],
eventEmitter?: EventEmitter<ExtensionEvents>,
) {
super(eventEmitter);
}
getExtensions(): GeminiCLIExtension[] {
return this.extensions;
}
/// Adds `extension` to the list of extensions and calls
/// `maybeStartExtension`.
///
/// This is intended for dynamic loading of extensions after calling `start`.
async loadExtension(extension: GeminiCLIExtension) {
this.extensions.push(extension);
await this.maybeStartExtension(extension);
}
/// Removes `extension` from the list of extensions and calls
// `maybeStopExtension` if it was found.
///
/// This is intended for dynamic unloading of extensions after calling `start`.
async unloadExtension(extension: GeminiCLIExtension) {
const index = this.extensions.indexOf(extension);
if (index === -1) return;
this.extensions.splice(index, 1);
await this.maybeStopExtension(extension);
}
}