diff --git a/packages/cli/src/config/extension.test.ts b/packages/cli/src/config/extension.test.ts index 0597d441ec..aa8416a574 100644 --- a/packages/cli/src/config/extension.test.ts +++ b/packages/cli/src/config/extension.test.ts @@ -94,6 +94,15 @@ vi.mock('child_process', async (importOriginal) => { }; }); +const mockQuestion = vi.hoisted(() => vi.fn()); +const mockClose = vi.hoisted(() => vi.fn()); +vi.mock('node:readline', () => ({ + createInterface: vi.fn(() => ({ + question: mockQuestion, + close: mockClose, + })), +})); + const EXTENSIONS_DIRECTORY_NAME = path.join(GEMINI_DIR, 'extensions'); describe('loadExtensions', () => { @@ -244,6 +253,7 @@ describe('loadExtensions', () => { source: sourceExtDir, type: 'link', }); + expect(extensionName).toEqual('my-linked-extension'); const extensions = loadExtensions(); expect(extensions).toHaveLength(1); @@ -434,6 +444,7 @@ describe('installExtension', () => { let userExtensionsDir: string; beforeEach(() => { + mockQuestion.mockImplementation((_query, callback) => callback('y')); tempHomeDir = fs.mkdtempSync( path.join(os.tmpdir(), 'gemini-cli-test-home-'), ); @@ -448,6 +459,8 @@ describe('installExtension', () => { }); afterEach(() => { + mockQuestion.mockClear(); + mockClose.mockClear(); fs.rmSync(tempHomeDir, { recursive: true, force: true }); fs.rmSync(userExtensionsDir, { recursive: true, force: true }); }); @@ -565,6 +578,56 @@ describe('installExtension', () => { const logger = ClearcutLogger.getInstance({} as Config); expect(logger?.logExtensionInstallEvent).toHaveBeenCalled(); }); + + it('should continue installation if user accepts prompt for local extension with mcp servers', async () => { + const sourceExtDir = createExtension({ + extensionsDir: tempHomeDir, + name: 'my-local-extension', + version: '1.0.0', + mcpServers: { + 'test-server': { + command: 'node', + args: ['server.js'], + }, + }, + }); + + mockQuestion.mockImplementation((_query, callback) => callback('y')); + + await expect( + installExtension({ source: sourceExtDir, type: 'local' }), + ).resolves.toBe('my-local-extension'); + + expect(mockQuestion).toHaveBeenCalledWith( + expect.stringContaining('Do you want to continue? (y/n)'), + expect.any(Function), + ); + }); + + it('should cancel installation if user declines prompt for local extension with mcp servers', async () => { + const sourceExtDir = createExtension({ + extensionsDir: tempHomeDir, + name: 'my-local-extension', + version: '1.0.0', + mcpServers: { + 'test-server': { + command: 'node', + args: ['server.js'], + }, + }, + }); + + mockQuestion.mockImplementation((_query, callback) => callback('n')); + + await expect( + installExtension({ source: sourceExtDir, type: 'local' }), + ).rejects.toThrow('Installation cancelled by user.'); + + expect(mockQuestion).toHaveBeenCalledWith( + expect.stringContaining('Do you want to continue? (y/n)'), + expect.any(Function), + ); + }); }); describe('uninstallExtension', () => { @@ -807,6 +870,7 @@ describe('updateExtension', () => { afterEach(() => { fs.rmSync(tempHomeDir, { recursive: true, force: true }); + mockClose.mockClear(); }); it('should update a git-installed extension', async () => { diff --git a/packages/cli/src/config/extension.ts b/packages/cli/src/config/extension.ts index b03a815621..8cdfaba52c 100644 --- a/packages/cli/src/config/extension.ts +++ b/packages/cli/src/config/extension.ts @@ -372,6 +372,26 @@ async function cloneFromGit( } } +/** + * Asks users a prompt and awaits for a y/n response + * @param prompt A yes/no prompt to ask the user + * @returns Whether or not the user answers 'y' (yes) + */ +async function promptForContinuation(prompt: string): Promise { + const readline = await import('node:readline'); + const rl = readline.createInterface({ + input: process.stdin, + output: process.stdout, + }); + + return new Promise((resolve) => { + rl.question(prompt, (answer) => { + rl.close(); + resolve(answer.toLowerCase() === 'y'); + }); + }); +} + export async function installExtension( installMetadata: ExtensionInstallMetadata, cwd: string = process.cwd(), @@ -443,6 +463,26 @@ export async function installExtension( ); } + const mcpServerEntries = Object.entries( + newExtensionConfig.mcpServers || {}, + ); + if (mcpServerEntries.length) { + console.info('This extension will run the following MCP servers: '); + for (const [key, value] of mcpServerEntries) { + console.info(` * ${key}: ${value.description}`); + } + console.info( + 'The extension will append info to your gemini.md context', + ); + + const shouldContinue = await promptForContinuation( + 'Do you want to continue? (y/n): ', + ); + if (!shouldContinue) { + throw new Error('Installation cancelled by user.'); + } + } + await fs.promises.mkdir(destinationPath, { recursive: true }); if (installMetadata.type === 'local' || installMetadata.type === 'git') { @@ -599,7 +639,6 @@ export async function updateExtension( await copyExtension(extension.path, tempDir); await uninstallExtension(extension.config.name, cwd); await installExtension(extension.installMetadata, cwd); - const updatedExtensionStorage = new ExtensionStorage(extension.config.name); const updatedExtension = loadExtension( updatedExtensionStorage.getExtensionDir(),