From 638dd2f6c032adf985b977f465fecb9e4260208e Mon Sep 17 00:00:00 2001 From: Megha Bansal Date: Thu, 13 Nov 2025 21:28:01 -0800 Subject: [PATCH] Improve test code coverage for cli/command/extensions package (#12994) --- .../src/commands/extensions/disable.test.ts | 303 ++++++++++++++++++ .../src/commands/extensions/enable.test.ts | 208 ++++++++++++ .../examples/mcp-server/example.test.ts | 143 +++++++++ .../cli/src/commands/extensions/link.test.ts | 174 ++++++++++ .../cli/src/commands/extensions/list.test.ts | 137 ++++++++ .../src/commands/extensions/uninstall.test.ts | 175 +++++++++- .../src/commands/extensions/update.test.ts | 226 +++++++++++++ packages/cli/src/gemini.test.tsx | 3 +- packages/cli/src/zed-integration/acp.test.ts | 296 +++++++++++++++++ packages/cli/src/zed-integration/acp.ts | 234 +------------- .../cli/src/zed-integration/connection.ts | 229 +++++++++++++ packages/cli/src/zed-integration/schema.ts | 27 +- 12 files changed, 1904 insertions(+), 251 deletions(-) create mode 100644 packages/cli/src/commands/extensions/disable.test.ts create mode 100644 packages/cli/src/commands/extensions/enable.test.ts create mode 100644 packages/cli/src/commands/extensions/examples/mcp-server/example.test.ts create mode 100644 packages/cli/src/commands/extensions/link.test.ts create mode 100644 packages/cli/src/commands/extensions/list.test.ts create mode 100644 packages/cli/src/commands/extensions/update.test.ts create mode 100644 packages/cli/src/zed-integration/acp.test.ts create mode 100644 packages/cli/src/zed-integration/connection.ts diff --git a/packages/cli/src/commands/extensions/disable.test.ts b/packages/cli/src/commands/extensions/disable.test.ts new file mode 100644 index 0000000000..73d1eec135 --- /dev/null +++ b/packages/cli/src/commands/extensions/disable.test.ts @@ -0,0 +1,303 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + vi, + describe, + it, + expect, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; + +import { type CommandModule, type Argv } from 'yargs'; + +import { handleDisable, disableCommand } from './disable.js'; + +import { ExtensionManager } from '../../config/extension-manager.js'; + +import { + loadSettings, + SettingScope, + type LoadedSettings, +} from '../../config/settings.js'; +import { getErrorMessage } from '../../utils/errors.js'; + +// Mock dependencies + +vi.mock('../../config/extension-manager.js'); + +vi.mock('../../config/settings.js'); + +vi.mock('../../utils/errors.js'); + +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + + return { + ...actual, + + debugLogger: { + log: vi.fn(), + + error: vi.fn(), + }, + }; +}); + +vi.mock('../../config/extensions/consent.js', () => ({ + requestConsentNonInteractive: vi.fn(), +})); + +vi.mock('../../config/extensions/extensionSettings.js', () => ({ + promptForSetting: vi.fn(), +})); + +describe('extensions disable command', () => { + const mockLoadSettings = vi.mocked(loadSettings); + + const mockGetErrorMessage = vi.mocked(getErrorMessage); + + const mockExtensionManager = vi.mocked(ExtensionManager); + + interface MockDebugLogger { + log: Mock; + error: Mock; + } + let mockDebugLogger: MockDebugLogger; + + beforeEach(async () => { + vi.clearAllMocks(); + + // We need to re-import the mocked module to get the fresh mock + + mockDebugLogger = (await import('@google/gemini-cli-core')) + .debugLogger as unknown as MockDebugLogger; + + mockLoadSettings.mockReturnValue({ + merged: {}, + } as unknown as LoadedSettings); + + mockExtensionManager.prototype.loadExtensions = vi + + .fn() + + .mockResolvedValue(undefined); + + mockExtensionManager.prototype.disableExtension = vi + + .fn() + + .mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('handleDisable', () => { + it.each([ + { + name: 'my-extension', + scope: undefined, + expectedScope: SettingScope.User, + expectedLog: + 'Extension "my-extension" successfully disabled for scope "undefined".', + }, + { + name: 'my-extension', + scope: 'user', + expectedScope: SettingScope.User, + expectedLog: + 'Extension "my-extension" successfully disabled for scope "user".', + }, + { + name: 'my-extension', + scope: 'workspace', + expectedScope: SettingScope.Workspace, + expectedLog: + 'Extension "my-extension" successfully disabled for scope "workspace".', + }, + ])( + 'should disable an extension in the $expectedScope scope when scope is $scope', + async ({ name, scope, expectedScope, expectedLog }) => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + + await handleDisable({ name, scope }); + + expect(mockExtensionManager).toHaveBeenCalledWith( + expect.objectContaining({ + workspaceDir: '/test/dir', + }), + ); + + expect( + mockExtensionManager.prototype.loadExtensions, + ).toHaveBeenCalled(); + + expect( + mockExtensionManager.prototype.disableExtension, + ).toHaveBeenCalledWith(name, expectedScope); + + expect(mockDebugLogger.log).toHaveBeenCalledWith(expectedLog); + + mockCwd.mockRestore(); + }, + ); + + it('should log an error message and exit with code 1 when extension disabling fails', async () => { + const mockProcessExit = vi + + .spyOn(process, 'exit') + + .mockImplementation((() => {}) as ( + code?: string | number | null | undefined, + ) => never); + + const error = new Error('Disable failed'); + + ( + mockExtensionManager.prototype.disableExtension as Mock + ).mockRejectedValue(error); + + mockGetErrorMessage.mockReturnValue('Disable failed message'); + + await handleDisable({ name: 'my-extension' }); + + expect(mockDebugLogger.error).toHaveBeenCalledWith( + 'Disable failed message', + ); + + expect(mockProcessExit).toHaveBeenCalledWith(1); + + mockProcessExit.mockRestore(); + }); + }); + + describe('disableCommand', () => { + const command = disableCommand as CommandModule; + + it('should have correct command and describe', () => { + expect(command.command).toBe('disable [--scope] '); + + expect(command.describe).toBe('Disables an extension.'); + }); + + describe('builder', () => { + interface MockYargs { + positional: Mock; + + option: Mock; + + check: Mock; + } + + let yargsMock: MockYargs; + + beforeEach(() => { + yargsMock = { + positional: vi.fn().mockReturnThis(), + + option: vi.fn().mockReturnThis(), + + check: vi.fn().mockReturnThis(), + }; + }); + + it('should configure positional and option arguments', () => { + (command.builder as (yargs: Argv) => Argv)( + yargsMock as unknown as Argv, + ); + + expect(yargsMock.positional).toHaveBeenCalledWith('name', { + describe: 'The name of the extension to disable.', + + type: 'string', + }); + + expect(yargsMock.option).toHaveBeenCalledWith('scope', { + describe: 'The scope to disable the extension in.', + + type: 'string', + + default: SettingScope.User, + }); + + expect(yargsMock.check).toHaveBeenCalled(); + }); + + it('check function should throw for invalid scope', () => { + (command.builder as (yargs: Argv) => Argv)( + yargsMock as unknown as Argv, + ); + + const checkCallback = yargsMock.check.mock.calls[0][0]; + + const expectedError = `Invalid scope: invalid. Please use one of ${Object.values( + SettingScope, + ) + + .map((s) => s.toLowerCase()) + + .join(', ')}.`; + + expect(() => checkCallback({ scope: 'invalid' })).toThrow( + expectedError, + ); + }); + + it.each(['user', 'workspace', 'USER', 'WorkSpace'])( + 'check function should return true for valid scope "%s"', + (scope) => { + (command.builder as (yargs: Argv) => Argv)( + yargsMock as unknown as Argv, + ); + + const checkCallback = yargsMock.check.mock.calls[0][0]; + + expect(checkCallback({ scope })).toBe(true); + }, + ); + }); + + it('handler should trigger extension disabling', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + + interface TestArgv { + name: string; + scope: string; + [key: string]: unknown; + } + const argv: TestArgv = { + name: 'test-ext', + scope: 'workspace', + _: [], + $0: '', + }; + + await (command.handler as unknown as (args: TestArgv) => void)(argv); + expect(mockExtensionManager).toHaveBeenCalledWith( + expect.objectContaining({ + workspaceDir: '/test/dir', + }), + ); + + expect(mockExtensionManager.prototype.loadExtensions).toHaveBeenCalled(); + + expect( + mockExtensionManager.prototype.disableExtension, + ).toHaveBeenCalledWith('test-ext', SettingScope.Workspace); + + expect(mockDebugLogger.log).toHaveBeenCalledWith( + 'Extension "test-ext" successfully disabled for scope "workspace".', + ); + + mockCwd.mockRestore(); + }); + }); +}); diff --git a/packages/cli/src/commands/extensions/enable.test.ts b/packages/cli/src/commands/extensions/enable.test.ts new file mode 100644 index 0000000000..84d323a15c --- /dev/null +++ b/packages/cli/src/commands/extensions/enable.test.ts @@ -0,0 +1,208 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + vi, + describe, + it, + expect, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import { type CommandModule, type Argv } from 'yargs'; +import { handleEnable, enableCommand } from './enable.js'; +import { ExtensionManager } from '../../config/extension-manager.js'; +import { + loadSettings, + SettingScope, + type LoadedSettings, +} from '../../config/settings.js'; +import { FatalConfigError } from '@google/gemini-cli-core'; + +// Mock dependencies +vi.mock('../../config/extension-manager.js'); +vi.mock('../../config/settings.js'); +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + debugLogger: { + log: vi.fn(), + error: vi.fn(), + }, + getErrorMessage: vi.fn((error: { message: string }) => error.message), + FatalConfigError: class extends Error { + constructor(message: string) { + super(message); + this.name = 'FatalConfigError'; + } + }, + }; +}); +vi.mock('../../config/extensions/consent.js'); +vi.mock('../../config/extensions/extensionSettings.js'); + +describe('extensions enable command', () => { + const mockLoadSettings = vi.mocked(loadSettings); + const mockExtensionManager = vi.mocked(ExtensionManager); + interface MockDebugLogger { + log: Mock; + error: Mock; + } + let mockDebugLogger: MockDebugLogger; + + beforeEach(async () => { + vi.clearAllMocks(); + mockDebugLogger = (await import('@google/gemini-cli-core')) + .debugLogger as unknown as MockDebugLogger; + mockLoadSettings.mockReturnValue({ + merged: {}, + } as unknown as LoadedSettings); + mockExtensionManager.prototype.loadExtensions = vi + .fn() + .mockResolvedValue(undefined); + mockExtensionManager.prototype.enableExtension = vi.fn(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('handleEnable', () => { + it.each([ + { + name: 'my-extension', + scope: undefined, + expectedScope: SettingScope.User, + expectedLog: + 'Extension "my-extension" successfully enabled in all scopes.', + }, + { + name: 'my-extension', + scope: 'workspace', + expectedScope: SettingScope.Workspace, + expectedLog: + 'Extension "my-extension" successfully enabled for scope "workspace".', + }, + ])( + 'should enable an extension in the $expectedScope scope when scope is $scope', + async ({ name, scope, expectedScope, expectedLog }) => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + await handleEnable({ name, scope }); + + expect(mockExtensionManager).toHaveBeenCalledWith( + expect.objectContaining({ + workspaceDir: '/test/dir', + }), + ); + expect( + mockExtensionManager.prototype.loadExtensions, + ).toHaveBeenCalled(); + expect( + mockExtensionManager.prototype.enableExtension, + ).toHaveBeenCalledWith(name, expectedScope); + expect(mockDebugLogger.log).toHaveBeenCalledWith(expectedLog); + mockCwd.mockRestore(); + }, + ); + + it('should throw FatalConfigError when extension enabling fails', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + const error = new Error('Enable failed'); + ( + mockExtensionManager.prototype.enableExtension as Mock + ).mockImplementation(() => { + throw error; + }); + + const promise = handleEnable({ name: 'my-extension' }); + await expect(promise).rejects.toThrow(FatalConfigError); + await expect(promise).rejects.toThrow('Enable failed'); + + mockCwd.mockRestore(); + }); + }); + + describe('enableCommand', () => { + const command = enableCommand as CommandModule; + + it('should have correct command and describe', () => { + expect(command.command).toBe('enable [--scope] '); + expect(command.describe).toBe('Enables an extension.'); + }); + + describe('builder', () => { + interface MockYargs { + positional: Mock; + option: Mock; + check: Mock; + } + + let yargsMock: MockYargs; + beforeEach(() => { + yargsMock = { + positional: vi.fn().mockReturnThis(), + option: vi.fn().mockReturnThis(), + check: vi.fn().mockReturnThis(), + }; + }); + + it('should configure positional and option arguments', () => { + (command.builder as (yargs: Argv) => Argv)( + yargsMock as unknown as Argv, + ); + expect(yargsMock.positional).toHaveBeenCalledWith('name', { + describe: 'The name of the extension to enable.', + type: 'string', + }); + expect(yargsMock.option).toHaveBeenCalledWith('scope', { + describe: + 'The scope to enable the extension in. If not set, will be enabled in all scopes.', + type: 'string', + }); + expect(yargsMock.check).toHaveBeenCalled(); + }); + + it('check function should throw for invalid scope', () => { + (command.builder as (yargs: Argv) => Argv)( + yargsMock as unknown as Argv, + ); + const checkCallback = yargsMock.check.mock.calls[0][0]; + const expectedError = `Invalid scope: invalid. Please use one of ${Object.values( + SettingScope, + ) + .map((s) => s.toLowerCase()) + .join(', ')}.`; + expect(() => checkCallback({ scope: 'invalid' })).toThrow( + expectedError, + ); + }); + }); + + it('handler should call handleEnable', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + interface TestArgv { + name: string; + scope: string; + [key: string]: unknown; + } + const argv: TestArgv = { + name: 'test-ext', + scope: 'workspace', + _: [], + $0: '', + }; + await (command.handler as unknown as (args: TestArgv) => void)(argv); + + expect( + mockExtensionManager.prototype.enableExtension, + ).toHaveBeenCalledWith('test-ext', SettingScope.Workspace); + mockCwd.mockRestore(); + }); + }); +}); diff --git a/packages/cli/src/commands/extensions/examples/mcp-server/example.test.ts b/packages/cli/src/commands/extensions/examples/mcp-server/example.test.ts new file mode 100644 index 0000000000..6b732ae981 --- /dev/null +++ b/packages/cli/src/commands/extensions/examples/mcp-server/example.test.ts @@ -0,0 +1,143 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; +import { z } from 'zod'; + +// Mock the MCP server and transport +const mockRegisterTool = vi.fn(); +const mockRegisterPrompt = vi.fn(); +const mockConnect = vi.fn(); + +vi.mock('@modelcontextprotocol/sdk/server/mcp.js', () => ({ + McpServer: vi.fn().mockImplementation(() => ({ + registerTool: mockRegisterTool, + registerPrompt: mockRegisterPrompt, + connect: mockConnect, + })), +})); + +vi.mock('@modelcontextprotocol/sdk/server/stdio.js', () => ({ + StdioServerTransport: vi.fn(), +})); + +describe('MCP Server Example', () => { + beforeEach(async () => { + // Dynamically import the server setup after mocks are in place + await import('./example.js'); + }); + + afterEach(() => { + vi.clearAllMocks(); + vi.resetModules(); + }); + + it('should create an McpServer with the correct name and version', () => { + expect(McpServer).toHaveBeenCalledWith({ + name: 'prompt-server', + version: '1.0.0', + }); + }); + + it('should register the "fetch_posts" tool', () => { + expect(mockRegisterTool).toHaveBeenCalledWith( + 'fetch_posts', + { + description: 'Fetches a list of posts from a public API.', + inputSchema: z.object({}).shape, + }, + expect.any(Function), + ); + }); + + it('should register the "poem-writer" prompt', () => { + expect(mockRegisterPrompt).toHaveBeenCalledWith( + 'poem-writer', + { + title: 'Poem Writer', + description: 'Write a nice haiku', + argsSchema: expect.any(Object), + }, + expect.any(Function), + ); + }); + + it('should connect the server to an StdioServerTransport', () => { + expect(StdioServerTransport).toHaveBeenCalled(); + expect(mockConnect).toHaveBeenCalledWith(expect.any(StdioServerTransport)); + }); + + describe('fetch_posts tool implementation', () => { + it('should fetch posts and return a formatted response', async () => { + const mockPosts = [ + { id: 1, title: 'Post 1' }, + { id: 2, title: 'Post 2' }, + ]; + global.fetch = vi.fn().mockResolvedValue({ + json: vi.fn().mockResolvedValue(mockPosts), + }); + + const toolFn = (mockRegisterTool as Mock).mock.calls[0][2]; + const result = await toolFn(); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://jsonplaceholder.typicode.com/posts', + ); + expect(result).toEqual({ + content: [ + { + type: 'text', + text: JSON.stringify({ posts: mockPosts }), + }, + ], + }); + }); + }); + + describe('poem-writer prompt implementation', () => { + it('should generate a prompt with a title', () => { + const promptFn = (mockRegisterPrompt as Mock).mock.calls[0][2]; + const result = promptFn({ title: 'My Poem' }); + expect(result).toEqual({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: 'Write a haiku called My Poem. Note that a haiku is 5 syllables followed by 7 syllables followed by 5 syllables ', + }, + }, + ], + }); + }); + + it('should generate a prompt with a title and mood', () => { + const promptFn = (mockRegisterPrompt as Mock).mock.calls[0][2]; + const result = promptFn({ title: 'My Poem', mood: 'sad' }); + expect(result).toEqual({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: 'Write a haiku with the mood sad called My Poem. Note that a haiku is 5 syllables followed by 7 syllables followed by 5 syllables ', + }, + }, + ], + }); + }); + }); +}); diff --git a/packages/cli/src/commands/extensions/link.test.ts b/packages/cli/src/commands/extensions/link.test.ts new file mode 100644 index 0000000000..550e7d6024 --- /dev/null +++ b/packages/cli/src/commands/extensions/link.test.ts @@ -0,0 +1,174 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + vi, + describe, + it, + expect, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import { type CommandModule, type Argv } from 'yargs'; +import { handleLink, linkCommand } from './link.js'; +import { ExtensionManager } from '../../config/extension-manager.js'; +import { loadSettings, type LoadedSettings } from '../../config/settings.js'; +import { getErrorMessage } from '../../utils/errors.js'; + +// Mock dependencies +vi.mock('../../config/extension-manager.js'); +vi.mock('../../config/settings.js'); +vi.mock('../../utils/errors.js'); +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + debugLogger: { + log: vi.fn(), + error: vi.fn(), + }, + }; +}); +vi.mock('../../config/extensions/consent.js', () => ({ + requestConsentNonInteractive: vi.fn(), +})); +vi.mock('../../config/extensions/extensionSettings.js', () => ({ + promptForSetting: vi.fn(), +})); + +describe('extensions link command', () => { + const mockLoadSettings = vi.mocked(loadSettings); + const mockGetErrorMessage = vi.mocked(getErrorMessage); + const mockExtensionManager = vi.mocked(ExtensionManager); + interface MockDebugLogger { + log: Mock; + error: Mock; + } + let mockDebugLogger: MockDebugLogger; + + beforeEach(async () => { + vi.clearAllMocks(); + mockDebugLogger = (await import('@google/gemini-cli-core')) + .debugLogger as unknown as MockDebugLogger; + mockLoadSettings.mockReturnValue({ + merged: {}, + } as unknown as LoadedSettings); + mockExtensionManager.prototype.loadExtensions = vi + .fn() + .mockResolvedValue(undefined); + mockExtensionManager.prototype.installOrUpdateExtension = vi + .fn() + .mockResolvedValue({ name: 'my-linked-extension' }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('handleLink', () => { + it('should link an extension from a local path', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + await handleLink({ path: '/local/path/to/extension' }); + + expect(mockExtensionManager).toHaveBeenCalledWith( + expect.objectContaining({ + workspaceDir: '/test/dir', + }), + ); + expect(mockExtensionManager.prototype.loadExtensions).toHaveBeenCalled(); + expect( + mockExtensionManager.prototype.installOrUpdateExtension, + ).toHaveBeenCalledWith({ + source: '/local/path/to/extension', + type: 'link', + }); + expect(mockDebugLogger.log).toHaveBeenCalledWith( + 'Extension "my-linked-extension" linked successfully and enabled.', + ); + mockCwd.mockRestore(); + }); + + it('should log an error message and exit with code 1 when linking fails', async () => { + const mockProcessExit = vi + .spyOn(process, 'exit') + .mockImplementation((() => {}) as ( + code?: string | number | null | undefined, + ) => never); + const error = new Error('Link failed'); + ( + mockExtensionManager.prototype.installOrUpdateExtension as Mock + ).mockRejectedValue(error); + mockGetErrorMessage.mockReturnValue('Link failed message'); + + await handleLink({ path: '/local/path/to/extension' }); + + expect(mockDebugLogger.error).toHaveBeenCalledWith('Link failed message'); + expect(mockProcessExit).toHaveBeenCalledWith(1); + mockProcessExit.mockRestore(); + }); + }); + + describe('linkCommand', () => { + const command = linkCommand as CommandModule; + + it('should have correct command and describe', () => { + expect(command.command).toBe('link '); + expect(command.describe).toBe( + 'Links an extension from a local path. Updates made to the local path will always be reflected.', + ); + }); + + describe('builder', () => { + interface MockYargs { + positional: Mock; + check: Mock; + } + + let yargsMock: MockYargs; + beforeEach(() => { + yargsMock = { + positional: vi.fn().mockReturnThis(), + check: vi.fn().mockReturnThis(), + }; + }); + + it('should configure positional argument', () => { + (command.builder as (yargs: Argv) => Argv)( + yargsMock as unknown as Argv, + ); + expect(yargsMock.positional).toHaveBeenCalledWith('path', { + describe: 'The name of the extension to link.', + type: 'string', + }); + expect(yargsMock.check).toHaveBeenCalled(); + }); + }); + + it('handler should call handleLink', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + interface TestArgv { + path: string; + [key: string]: unknown; + } + const argv: TestArgv = { + path: '/local/path/to/extension', + _: [], + $0: '', + }; + await (command.handler as unknown as (args: TestArgv) => void)(argv); + + expect( + mockExtensionManager.prototype.installOrUpdateExtension, + ).toHaveBeenCalledWith({ + source: '/local/path/to/extension', + type: 'link', + }); + mockCwd.mockRestore(); + }); + }); +}); diff --git a/packages/cli/src/commands/extensions/list.test.ts b/packages/cli/src/commands/extensions/list.test.ts new file mode 100644 index 0000000000..283a34f1e7 --- /dev/null +++ b/packages/cli/src/commands/extensions/list.test.ts @@ -0,0 +1,137 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + vi, + describe, + it, + expect, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import { type CommandModule } from 'yargs'; +import { handleList, listCommand } from './list.js'; +import { ExtensionManager } from '../../config/extension-manager.js'; +import { loadSettings, type LoadedSettings } from '../../config/settings.js'; +import { getErrorMessage } from '../../utils/errors.js'; + +// Mock dependencies +vi.mock('../../config/extension-manager.js'); +vi.mock('../../config/settings.js'); +vi.mock('../../utils/errors.js'); +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + debugLogger: { + log: vi.fn(), + error: vi.fn(), + }, + }; +}); +vi.mock('../../config/extensions/consent.js', () => ({ + requestConsentNonInteractive: vi.fn(), +})); +vi.mock('../../config/extensions/extensionSettings.js', () => ({ + promptForSetting: vi.fn(), +})); + +describe('extensions list command', () => { + const mockLoadSettings = vi.mocked(loadSettings); + const mockGetErrorMessage = vi.mocked(getErrorMessage); + const mockExtensionManager = vi.mocked(ExtensionManager); + interface MockDebugLogger { + log: Mock; + error: Mock; + } + let mockDebugLogger: MockDebugLogger; + + beforeEach(async () => { + vi.clearAllMocks(); + mockDebugLogger = (await import('@google/gemini-cli-core')) + .debugLogger as unknown as MockDebugLogger; + mockLoadSettings.mockReturnValue({ + merged: {}, + } as unknown as LoadedSettings); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('handleList', () => { + it('should log a message if no extensions are installed', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + mockExtensionManager.prototype.loadExtensions = vi + .fn() + .mockResolvedValue([]); + await handleList(); + + expect(mockDebugLogger.log).toHaveBeenCalledWith( + 'No extensions installed.', + ); + mockCwd.mockRestore(); + }); + + it('should list all installed extensions', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + const extensions = [ + { name: 'ext1', version: '1.0.0' }, + { name: 'ext2', version: '2.0.0' }, + ]; + mockExtensionManager.prototype.loadExtensions = vi + .fn() + .mockResolvedValue(extensions); + mockExtensionManager.prototype.toOutputString = vi.fn( + (ext) => `${ext.name}@${ext.version}`, + ); + await handleList(); + + expect(mockDebugLogger.log).toHaveBeenCalledWith( + 'ext1@1.0.0\n\next2@2.0.0', + ); + mockCwd.mockRestore(); + }); + + it('should log an error message and exit with code 1 when listing fails', async () => { + const mockProcessExit = vi + .spyOn(process, 'exit') + .mockImplementation((() => {}) as ( + code?: string | number | null | undefined, + ) => never); + const error = new Error('List failed'); + mockExtensionManager.prototype.loadExtensions = vi + .fn() + .mockRejectedValue(error); + mockGetErrorMessage.mockReturnValue('List failed message'); + + await handleList(); + + expect(mockDebugLogger.error).toHaveBeenCalledWith('List failed message'); + expect(mockProcessExit).toHaveBeenCalledWith(1); + mockProcessExit.mockRestore(); + }); + }); + + describe('listCommand', () => { + const command = listCommand as CommandModule; + + it('should have correct command and describe', () => { + expect(command.command).toBe('list'); + expect(command.describe).toBe('Lists installed extensions.'); + }); + + it('handler should call handleList', async () => { + mockExtensionManager.prototype.loadExtensions = vi + .fn() + .mockResolvedValue([]); + await (command.handler as () => Promise)(); + expect(mockExtensionManager.prototype.loadExtensions).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/cli/src/commands/extensions/uninstall.test.ts b/packages/cli/src/commands/extensions/uninstall.test.ts index e202845878..d639b7442f 100644 --- a/packages/cli/src/commands/extensions/uninstall.test.ts +++ b/packages/cli/src/commands/extensions/uninstall.test.ts @@ -4,18 +4,171 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect } from 'vitest'; -import { uninstallCommand } from './uninstall.js'; -import yargs from 'yargs'; +import { + vi, + describe, + it, + expect, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import { type CommandModule, type Argv } from 'yargs'; +import { handleUninstall, uninstallCommand } from './uninstall.js'; +import { ExtensionManager } from '../../config/extension-manager.js'; +import { loadSettings, type LoadedSettings } from '../../config/settings.js'; +import { getErrorMessage } from '../../utils/errors.js'; + +// Mock dependencies +vi.mock('../../config/extension-manager.js'); +vi.mock('../../config/settings.js'); +vi.mock('../../utils/errors.js'); +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + debugLogger: { + log: vi.fn(), + error: vi.fn(), + }, + }; +}); +vi.mock('../../config/extensions/consent.js', () => ({ + requestConsentNonInteractive: vi.fn(), +})); +vi.mock('../../config/extensions/extensionSettings.js', () => ({ + promptForSetting: vi.fn(), +})); describe('extensions uninstall command', () => { - it('should fail if no source is provided', () => { - const validationParser = yargs([]) - .command(uninstallCommand) - .fail(false) - .locale('en'); - expect(() => validationParser.parse('uninstall')).toThrow( - 'Not enough non-option arguments: got 0, need at least 1', - ); + const mockLoadSettings = vi.mocked(loadSettings); + const mockGetErrorMessage = vi.mocked(getErrorMessage); + const mockExtensionManager = vi.mocked(ExtensionManager); + interface MockDebugLogger { + log: Mock; + error: Mock; + } + let mockDebugLogger: MockDebugLogger; + + beforeEach(async () => { + vi.clearAllMocks(); + mockDebugLogger = (await import('@google/gemini-cli-core')) + .debugLogger as unknown as MockDebugLogger; + mockLoadSettings.mockReturnValue({ + merged: {}, + } as unknown as LoadedSettings); + mockExtensionManager.prototype.loadExtensions = vi + .fn() + .mockResolvedValue(undefined); + mockExtensionManager.prototype.uninstallExtension = vi + .fn() + .mockResolvedValue(undefined); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('handleUninstall', () => { + it('should uninstall an extension', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + await handleUninstall({ name: 'my-extension' }); + + expect(mockExtensionManager).toHaveBeenCalledWith( + expect.objectContaining({ + workspaceDir: '/test/dir', + }), + ); + expect(mockExtensionManager.prototype.loadExtensions).toHaveBeenCalled(); + expect( + mockExtensionManager.prototype.uninstallExtension, + ).toHaveBeenCalledWith('my-extension', false); + expect(mockDebugLogger.log).toHaveBeenCalledWith( + 'Extension "my-extension" successfully uninstalled.', + ); + mockCwd.mockRestore(); + }); + + it('should log an error message and exit with code 1 when uninstallation fails', async () => { + const mockProcessExit = vi + .spyOn(process, 'exit') + .mockImplementation((() => {}) as ( + code?: string | number | null | undefined, + ) => never); + const error = new Error('Uninstall failed'); + ( + mockExtensionManager.prototype.uninstallExtension as Mock + ).mockRejectedValue(error); + mockGetErrorMessage.mockReturnValue('Uninstall failed message'); + + await handleUninstall({ name: 'my-extension' }); + + expect(mockDebugLogger.error).toHaveBeenCalledWith( + 'Uninstall failed message', + ); + expect(mockProcessExit).toHaveBeenCalledWith(1); + mockProcessExit.mockRestore(); + }); + }); + + describe('uninstallCommand', () => { + const command = uninstallCommand as CommandModule; + + it('should have correct command and describe', () => { + expect(command.command).toBe('uninstall '); + expect(command.describe).toBe('Uninstalls an extension.'); + }); + + describe('builder', () => { + interface MockYargs { + positional: Mock; + check: Mock; + } + + let yargsMock: MockYargs; + beforeEach(() => { + yargsMock = { + positional: vi.fn().mockReturnThis(), + check: vi.fn().mockReturnThis(), + }; + }); + + it('should configure positional argument', () => { + (command.builder as (yargs: Argv) => Argv)( + yargsMock as unknown as Argv, + ); + expect(yargsMock.positional).toHaveBeenCalledWith('name', { + describe: 'The name or source path of the extension to uninstall.', + type: 'string', + }); + expect(yargsMock.check).toHaveBeenCalled(); + }); + + it('check function should throw for missing name', () => { + (command.builder as (yargs: Argv) => Argv)( + yargsMock as unknown as Argv, + ); + const checkCallback = yargsMock.check.mock.calls[0][0]; + expect(() => checkCallback({ name: '' })).toThrow( + 'Please include the name of the extension to uninstall as a positional argument.', + ); + }); + }); + + it('handler should call handleUninstall', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + interface TestArgv { + name: string; + [key: string]: unknown; + } + const argv: TestArgv = { name: 'my-extension', _: [], $0: '' }; + await (command.handler as unknown as (args: TestArgv) => void)(argv); + + expect( + mockExtensionManager.prototype.uninstallExtension, + ).toHaveBeenCalledWith('my-extension', false); + mockCwd.mockRestore(); + }); }); }); diff --git a/packages/cli/src/commands/extensions/update.test.ts b/packages/cli/src/commands/extensions/update.test.ts new file mode 100644 index 0000000000..a5109910c1 --- /dev/null +++ b/packages/cli/src/commands/extensions/update.test.ts @@ -0,0 +1,226 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + vi, + describe, + it, + expect, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import { type CommandModule, type Argv } from 'yargs'; +import { handleUpdate, updateCommand } from './update.js'; +import { ExtensionManager } from '../../config/extension-manager.js'; +import { loadSettings, type LoadedSettings } from '../../config/settings.js'; +import * as update from '../../config/extensions/update.js'; +import * as github from '../../config/extensions/github.js'; +import { ExtensionUpdateState } from '../../ui/state/extensions.js'; + +// Mock dependencies +vi.mock('../../config/extension-manager.js'); +vi.mock('../../config/settings.js'); +vi.mock('../../utils/errors.js'); +vi.mock('../../config/extensions/update.js'); +vi.mock('../../config/extensions/github.js'); +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + debugLogger: { + log: vi.fn(), + error: vi.fn(), + }, + }; +}); +vi.mock('../../config/extensions/consent.js', () => ({ + requestConsentNonInteractive: vi.fn(), +})); +vi.mock('../../config/extensions/extensionSettings.js', () => ({ + promptForSetting: vi.fn(), +})); + +describe('extensions update command', () => { + const mockLoadSettings = vi.mocked(loadSettings); + const mockExtensionManager = vi.mocked(ExtensionManager); + const mockUpdateExtension = vi.mocked(update.updateExtension); + const mockCheckForExtensionUpdate = vi.mocked(github.checkForExtensionUpdate); + const mockCheckForAllExtensionUpdates = vi.mocked( + update.checkForAllExtensionUpdates, + ); + const mockUpdateAllUpdatableExtensions = vi.mocked( + update.updateAllUpdatableExtensions, + ); + + interface MockDebugLogger { + log: Mock; + error: Mock; + } + let mockDebugLogger: MockDebugLogger; + + beforeEach(async () => { + vi.clearAllMocks(); + mockDebugLogger = (await import('@google/gemini-cli-core')) + .debugLogger as unknown as MockDebugLogger; + mockLoadSettings.mockReturnValue({ + merged: { experimental: { extensionReloading: true } }, + } as unknown as LoadedSettings); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('handleUpdate', () => { + it.each([ + { + state: ExtensionUpdateState.UPDATE_AVAILABLE, + expectedLog: + 'Extension "my-extension" successfully updated: 1.0.0 → 1.1.0.', + shouldCallUpdateExtension: true, + }, + { + state: ExtensionUpdateState.UP_TO_DATE, + expectedLog: 'Extension "my-extension" is already up to date.', + shouldCallUpdateExtension: false, + }, + ])( + 'should handle single extension update state: $state', + async ({ state, expectedLog, shouldCallUpdateExtension }) => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + const extensions = [{ name: 'my-extension', installMetadata: {} }]; + mockExtensionManager.prototype.loadExtensions = vi + .fn() + .mockResolvedValue(extensions); + mockCheckForExtensionUpdate.mockResolvedValue(state); + mockUpdateExtension.mockResolvedValue({ + name: 'my-extension', + originalVersion: '1.0.0', + updatedVersion: '1.1.0', + }); + + await handleUpdate({ name: 'my-extension' }); + + expect(mockDebugLogger.log).toHaveBeenCalledWith(expectedLog); + if (shouldCallUpdateExtension) { + expect(mockUpdateExtension).toHaveBeenCalled(); + } else { + expect(mockUpdateExtension).not.toHaveBeenCalled(); + } + mockCwd.mockRestore(); + }, + ); + + it.each([ + { + updatedExtensions: [ + { name: 'ext1', originalVersion: '1.0.0', updatedVersion: '1.1.0' }, + { name: 'ext2', originalVersion: '2.0.0', updatedVersion: '2.1.0' }, + ], + expectedLog: + 'Extension "ext1" successfully updated: 1.0.0 → 1.1.0.\nExtension "ext2" successfully updated: 2.0.0 → 2.1.0.', + }, + { + updatedExtensions: [], + expectedLog: 'No extensions to update.', + }, + ])( + 'should handle updating all extensions: %s', + async ({ updatedExtensions, expectedLog }) => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + mockExtensionManager.prototype.loadExtensions = vi + .fn() + .mockResolvedValue([]); + mockCheckForAllExtensionUpdates.mockResolvedValue(undefined); + mockUpdateAllUpdatableExtensions.mockResolvedValue(updatedExtensions); + + await handleUpdate({ all: true }); + + expect(mockDebugLogger.log).toHaveBeenCalledWith(expectedLog); + mockCwd.mockRestore(); + }, + ); + }); + + describe('updateCommand', () => { + const command = updateCommand as CommandModule; + + it('should have correct command and describe', () => { + expect(command.command).toBe('update [] [--all]'); + expect(command.describe).toBe( + 'Updates all extensions or a named extension to the latest version.', + ); + }); + + describe('builder', () => { + interface MockYargs { + positional: Mock; + option: Mock; + conflicts: Mock; + check: Mock; + } + + let yargsMock: MockYargs; + beforeEach(() => { + yargsMock = { + positional: vi.fn().mockReturnThis(), + option: vi.fn().mockReturnThis(), + conflicts: vi.fn().mockReturnThis(), + check: vi.fn().mockReturnThis(), + }; + }); + + it('should configure arguments', () => { + (command.builder as (yargs: Argv) => Argv)( + yargsMock as unknown as Argv, + ); + expect(yargsMock.positional).toHaveBeenCalledWith( + 'name', + expect.any(Object), + ); + expect(yargsMock.option).toHaveBeenCalledWith( + 'all', + expect.any(Object), + ); + expect(yargsMock.conflicts).toHaveBeenCalledWith('name', 'all'); + expect(yargsMock.check).toHaveBeenCalled(); + }); + + it('check function should throw an error if neither a name nor --all is provided', () => { + (command.builder as (yargs: Argv) => Argv)( + yargsMock as unknown as Argv, + ); + const checkCallback = yargsMock.check.mock.calls[0][0]; + expect(() => checkCallback({ name: undefined, all: false })).toThrow( + 'Either an extension name or --all must be provided', + ); + }); + }); + + it('handler should call handleUpdate', async () => { + const extensions = [{ name: 'my-extension', installMetadata: {} }]; + mockExtensionManager.prototype.loadExtensions = vi + .fn() + .mockResolvedValue(extensions); + mockCheckForExtensionUpdate.mockResolvedValue( + ExtensionUpdateState.UPDATE_AVAILABLE, + ); + mockUpdateExtension.mockResolvedValue({ + name: 'my-extension', + originalVersion: '1.0.0', + updatedVersion: '1.1.0', + }); + + await (command.handler as (args: object) => Promise)({ + name: 'my-extension', + }); + + expect(mockUpdateExtension).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/cli/src/gemini.test.tsx b/packages/cli/src/gemini.test.tsx index bc225c1ba1..b83f2274c2 100644 --- a/packages/cli/src/gemini.test.tsx +++ b/packages/cli/src/gemini.test.tsx @@ -451,7 +451,7 @@ describe('startInteractiveUI', () => { const mockConfig = { getProjectRoot: () => '/root', getScreenReader: () => false, - } as Config; + } as unknown as Config; const mockSettings = { merged: { ui: { @@ -477,7 +477,6 @@ describe('startInteractiveUI', () => { isKittyProtocolSupported: vi.fn(() => true), isKittyProtocolEnabled: vi.fn(() => true), })); - vi.mock('./ui/utils/updateCheck.js', () => ({ checkForUpdates: vi.fn(() => Promise.resolve(null)), })); diff --git a/packages/cli/src/zed-integration/acp.test.ts b/packages/cli/src/zed-integration/acp.test.ts new file mode 100644 index 0000000000..c796b484b0 --- /dev/null +++ b/packages/cli/src/zed-integration/acp.test.ts @@ -0,0 +1,296 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vi, describe, it, expect, beforeEach, type Mock } from 'vitest'; +import { + AgentSideConnection, + RequestError, + type Agent, + type Client, +} from './acp.js'; +import { type ErrorResponse } from './schema.js'; +import { type MethodHandler } from './connection.js'; +import { ReadableStream, WritableStream } from 'node:stream/web'; + +const mockConnectionConstructor = vi.hoisted(() => + vi.fn< + ( + arg1: MethodHandler, + arg2: WritableStream, + arg3: ReadableStream, + ) => { sendRequest: Mock; sendNotification: Mock } + >(() => ({ + sendRequest: vi.fn(), + sendNotification: vi.fn(), + })), +); + +vi.mock('./connection.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...(actual as object), + Connection: mockConnectionConstructor, + }; +}); + +describe('acp', () => { + describe('RequestError', () => { + it('should create a parse error', () => { + const error = RequestError.parseError('details'); + expect(error.code).toBe(-32700); + expect(error.message).toBe('Parse error'); + expect(error.data?.details).toBe('details'); + }); + + it('should create a method not found error', () => { + const error = RequestError.methodNotFound('details'); + expect(error.code).toBe(-32601); + expect(error.message).toBe('Method not found'); + expect(error.data?.details).toBe('details'); + }); + + it('should convert to a result', () => { + const error = RequestError.internalError('details'); + const result = error.toResult() as { error: ErrorResponse }; + expect(result.error.code).toBe(-32603); + expect(result.error.message).toBe('Internal error'); + expect(result.error.data).toEqual({ details: 'details' }); + }); + }); + + describe('AgentSideConnection', () => { + let mockAgent: Agent; + + let toAgent: WritableStream; + let fromAgent: ReadableStream; + let agentSideConnection: AgentSideConnection; + let connectionInstance: InstanceType; + + beforeEach(() => { + vi.clearAllMocks(); + + const initializeResponse = { + agentCapabilities: { loadSession: true }, + authMethods: [], + protocolVersion: 1, + }; + const newSessionResponse = { sessionId: 'session-1' }; + const loadSessionResponse = { sessionId: 'session-1' }; + + mockAgent = { + initialize: vi.fn().mockResolvedValue(initializeResponse), + newSession: vi.fn().mockResolvedValue(newSessionResponse), + loadSession: vi.fn().mockResolvedValue(loadSessionResponse), + authenticate: vi.fn(), + prompt: vi.fn(), + cancel: vi.fn(), + }; + + toAgent = new WritableStream(); + fromAgent = new ReadableStream(); + + agentSideConnection = new AgentSideConnection( + (_client: Client) => mockAgent, + toAgent, + fromAgent, + ); + + // Get the mocked Connection instance + connectionInstance = mockConnectionConstructor.mock.results[0].value; + }); + + it('should initialize Connection with the correct handler and streams', () => { + expect(mockConnectionConstructor).toHaveBeenCalledTimes(1); + expect(mockConnectionConstructor).toHaveBeenCalledWith( + expect.any(Function), + toAgent, + fromAgent, + ); + }); + + it('should call agent.initialize when Connection handler receives initialize method', async () => { + const initializeParams = { + clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } }, + protocolVersion: 1, + }; + const initializeResponse = { + agentCapabilities: { loadSession: true }, + authMethods: [], + protocolVersion: 1, + }; + const handler = mockConnectionConstructor.mock + .calls[0][0]! as MethodHandler; + const result = await handler('initialize', initializeParams); + + expect(mockAgent.initialize).toHaveBeenCalledWith(initializeParams); + expect(result).toEqual(initializeResponse); + }); + + it('should call agent.newSession when Connection handler receives session_new method', async () => { + const newSessionParams = { cwd: '/tmp', mcpServers: [] }; + const newSessionResponse = { sessionId: 'session-1' }; + const handler = mockConnectionConstructor.mock + .calls[0][0]! as MethodHandler; + const result = await handler('session/new', newSessionParams); + + expect(mockAgent.newSession).toHaveBeenCalledWith(newSessionParams); + expect(result).toEqual(newSessionResponse); + }); + + it('should call agent.loadSession when Connection handler receives session_load method', async () => { + const loadSessionParams = { + cwd: '/tmp', + mcpServers: [], + sessionId: 'session-1', + }; + const loadSessionResponse = { sessionId: 'session-1' }; + const handler = mockConnectionConstructor.mock + .calls[0][0]! as MethodHandler; + const result = await handler('session/load', loadSessionParams); + + expect(mockAgent.loadSession).toHaveBeenCalledWith(loadSessionParams); + expect(result).toEqual(loadSessionResponse); + }); + + it('should throw methodNotFound if agent.loadSession is not implemented', async () => { + mockAgent.loadSession = undefined; // Simulate not implemented + const loadSessionParams = { + cwd: '/tmp', + mcpServers: [], + sessionId: 'session-1', + }; + const handler = mockConnectionConstructor.mock + .calls[0][0]! as MethodHandler; + await expect(handler('session/load', loadSessionParams)).rejects.toThrow( + RequestError.methodNotFound().message, + ); + }); + + it('should call agent.authenticate when Connection handler receives authenticate method', async () => { + const authenticateParams = { + methodId: 'test-auth-method', + authMethod: { + id: 'test-auth', + name: 'Test Auth Method', + description: 'A test authentication method', + }, + }; + const handler = mockConnectionConstructor.mock + .calls[0][0]! as MethodHandler; + const result = await handler('authenticate', authenticateParams); + + expect(mockAgent.authenticate).toHaveBeenCalledWith(authenticateParams); + expect(result).toBeUndefined(); + }); + + it('should call agent.prompt when Connection handler receives session_prompt method', async () => { + const promptParams = { + prompt: [{ type: 'text', text: 'hi' }], + sessionId: 'session-1', + }; + const promptResponse = { + response: [{ type: 'text', text: 'hello' }], + traceId: 'trace-1', + }; + (mockAgent.prompt as Mock).mockResolvedValue(promptResponse); + const handler = mockConnectionConstructor.mock + .calls[0][0]! as MethodHandler; + const result = await handler('session/prompt', promptParams); + + expect(mockAgent.prompt).toHaveBeenCalledWith(promptParams); + expect(result).toEqual(promptResponse); + }); + + it('should call agent.cancel when Connection handler receives session_cancel method', async () => { + const cancelParams = { sessionId: 'session-1' }; + const handler = mockConnectionConstructor.mock + .calls[0][0]! as MethodHandler; + const result = await handler('session/cancel', cancelParams); + + expect(mockAgent.cancel).toHaveBeenCalledWith(cancelParams); + expect(result).toBeUndefined(); + }); + + it('should throw methodNotFound for unknown methods', async () => { + const handler = mockConnectionConstructor.mock + .calls[0][0]! as MethodHandler; + await expect(handler('unknown_method', {})).rejects.toThrow( + RequestError.methodNotFound().message, + ); + }); + + it('should send sessionUpdate notification via connection', async () => { + const params = { + sessionId: '123', + update: { + sessionUpdate: 'user_message_chunk' as const, + content: { type: 'text' as const, text: 'hello' }, + }, + }; + await agentSideConnection.sessionUpdate(params); + }); + + it('should send requestPermission request via connection', async () => { + const params = { + sessionId: '123', + toolCall: { + toolCallId: 'tool-1', + title: 'Test Tool', + kind: 'other' as const, + status: 'pending' as const, + }, + options: [ + { + optionId: 'option-1', + name: 'Allow', + kind: 'allow_once' as const, + }, + ], + }; + const response = { + outcome: { outcome: 'selected', optionId: 'option-1' }, + }; + (connectionInstance.sendRequest as Mock).mockResolvedValue(response); + + const result = await agentSideConnection.requestPermission(params); + expect(connectionInstance.sendRequest).toHaveBeenCalledWith( + 'session/request_permission', + params, + ); + expect(result).toEqual(response); + }); + + it('should send readTextFile request via connection', async () => { + const params = { path: '/a/b.txt', sessionId: 'session-1' }; + const response = { content: 'file content' }; + (connectionInstance.sendRequest as Mock).mockResolvedValue(response); + + const result = await agentSideConnection.readTextFile(params); + expect(connectionInstance.sendRequest).toHaveBeenCalledWith( + 'fs/read_text_file', + params, + ); + expect(result).toEqual(response); + }); + + it('should send writeTextFile request via connection', async () => { + const params = { + path: '/a/b.txt', + content: 'new content', + sessionId: 'session-1', + }; + const response = { success: true }; + (connectionInstance.sendRequest as Mock).mockResolvedValue(response); + + const result = await agentSideConnection.writeTextFile(params); + expect(connectionInstance.sendRequest).toHaveBeenCalledWith( + 'fs/write_text_file', + params, + ); + expect(result).toEqual(response); + }); + }); +}); diff --git a/packages/cli/src/zed-integration/acp.ts b/packages/cli/src/zed-integration/acp.ts index 10a993f5fa..f8a0c8eace 100644 --- a/packages/cli/src/zed-integration/acp.ts +++ b/packages/cli/src/zed-integration/acp.ts @@ -6,12 +6,12 @@ /* ACP defines a schema for a simple (experimental) JSON-RPC protocol that allows GUI applications to interact with agents. */ -import { z } from 'zod'; import * as schema from './schema.js'; export * from './schema.js'; import type { WritableStream, ReadableStream } from 'node:stream/web'; -import { coreEvents } from '@google/gemini-cli-core'; +import { Connection, RequestError } from './connection.js'; +export { RequestError }; export class AgentSideConnection implements Client { #connection: Connection; @@ -108,236 +108,6 @@ export class AgentSideConnection implements Client { } } -type AnyMessage = AnyRequest | AnyResponse | AnyNotification; - -type AnyRequest = { - jsonrpc: '2.0'; - id: string | number; - method: string; - params?: unknown; -}; - -type AnyResponse = { - jsonrpc: '2.0'; - id: string | number; -} & Result; - -type AnyNotification = { - jsonrpc: '2.0'; - method: string; - params?: unknown; -}; - -type Result = - | { - result: T; - } - | { - error: ErrorResponse; - }; - -type ErrorResponse = { - code: number; - message: string; - data?: unknown; -}; - -type PendingResponse = { - resolve: (response: unknown) => void; - reject: (error: ErrorResponse) => void; -}; - -type MethodHandler = (method: string, params: unknown) => Promise; - -class Connection { - #pendingResponses: Map = new Map(); - #nextRequestId: number = 0; - #handler: MethodHandler; - #peerInput: WritableStream; - #writeQueue: Promise = Promise.resolve(); - #textEncoder: TextEncoder; - - constructor( - handler: MethodHandler, - peerInput: WritableStream, - peerOutput: ReadableStream, - ) { - this.#handler = handler; - this.#peerInput = peerInput; - this.#textEncoder = new TextEncoder(); - this.#receive(peerOutput); - } - - async #receive(output: ReadableStream) { - let content = ''; - const decoder = new TextDecoder(); - for await (const chunk of output) { - content += decoder.decode(chunk, { stream: true }); - const lines = content.split('\n'); - content = lines.pop() || ''; - - for (const line of lines) { - const trimmedLine = line.trim(); - - if (trimmedLine) { - const message = JSON.parse(trimmedLine); - this.#processMessage(message); - } - } - } - } - - async #processMessage(message: AnyMessage) { - if ('method' in message && 'id' in message) { - // It's a request - const response = await this.#tryCallHandler( - message.method, - message.params, - ); - - await this.#sendMessage({ - jsonrpc: '2.0', - id: message.id, - ...response, - }); - } else if ('method' in message) { - // It's a notification - await this.#tryCallHandler(message.method, message.params); - } else if ('id' in message) { - // It's a response - this.#handleResponse(message as AnyResponse); - } - } - - async #tryCallHandler( - method: string, - params?: unknown, - ): Promise> { - try { - const result = await this.#handler(method, params); - return { result: result ?? null }; - } catch (error: unknown) { - if (error instanceof RequestError) { - return error.toResult(); - } - - if (error instanceof z.ZodError) { - return RequestError.invalidParams( - JSON.stringify(error.format(), undefined, 2), - ).toResult(); - } - - let details; - - if (error instanceof Error) { - details = error.message; - } else if ( - typeof error === 'object' && - error != null && - 'message' in error && - typeof error.message === 'string' - ) { - details = error.message; - } - - return RequestError.internalError(details).toResult(); - } - } - - #handleResponse(response: AnyResponse) { - const pendingResponse = this.#pendingResponses.get(response.id); - if (pendingResponse) { - if ('result' in response) { - pendingResponse.resolve(response.result); - } else if ('error' in response) { - pendingResponse.reject(response.error); - } - this.#pendingResponses.delete(response.id); - } - } - - async sendRequest(method: string, params?: Req): Promise { - const id = this.#nextRequestId++; - const responsePromise = new Promise((resolve, reject) => { - this.#pendingResponses.set(id, { resolve, reject }); - }); - await this.#sendMessage({ jsonrpc: '2.0', id, method, params }); - return responsePromise as Promise; - } - - async sendNotification(method: string, params?: N): Promise { - await this.#sendMessage({ jsonrpc: '2.0', method, params }); - } - - async #sendMessage(json: AnyMessage) { - const content = JSON.stringify(json) + '\n'; - this.#writeQueue = this.#writeQueue - .then(async () => { - const writer = this.#peerInput.getWriter(); - try { - await writer.write(this.#textEncoder.encode(content)); - } finally { - writer.releaseLock(); - } - }) - .catch((error) => { - // Continue processing writes on error - coreEvents.emitFeedback('error', 'ACP write error.', error); - }); - return this.#writeQueue; - } -} - -export class RequestError extends Error { - data?: { details?: string }; - - constructor( - public code: number, - message: string, - details?: string, - ) { - super(message); - this.name = 'RequestError'; - if (details) { - this.data = { details }; - } - } - - static parseError(details?: string): RequestError { - return new RequestError(-32700, 'Parse error', details); - } - - static invalidRequest(details?: string): RequestError { - return new RequestError(-32600, 'Invalid request', details); - } - - static methodNotFound(details?: string): RequestError { - return new RequestError(-32601, 'Method not found', details); - } - - static invalidParams(details?: string): RequestError { - return new RequestError(-32602, 'Invalid params', details); - } - - static internalError(details?: string): RequestError { - return new RequestError(-32603, 'Internal error', details); - } - - static authRequired(details?: string): RequestError { - return new RequestError(-32000, 'Authentication required', details); - } - - toResult(): Result { - return { - error: { - code: this.code, - message: this.message, - data: this.data, - }, - }; - } -} - export interface Client { requestPermission( params: schema.RequestPermissionRequest, diff --git a/packages/cli/src/zed-integration/connection.ts b/packages/cli/src/zed-integration/connection.ts new file mode 100644 index 0000000000..2ad9358627 --- /dev/null +++ b/packages/cli/src/zed-integration/connection.ts @@ -0,0 +1,229 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { z } from 'zod'; +import { coreEvents } from '@google/gemini-cli-core'; +import { type Result, type ErrorResponse } from './schema.js'; +import type { WritableStream, ReadableStream } from 'node:stream/web'; + +export class RequestError extends Error { + data?: { details?: string }; + + constructor( + public code: number, + message: string, + details?: string, + ) { + super(message); + this.name = 'RequestError'; + if (details) { + this.data = { details }; + } + } + + static parseError(details?: string): RequestError { + return new RequestError(-32700, 'Parse error', details); + } + + static invalidRequest(details?: string): RequestError { + return new RequestError(-32600, 'Invalid request', details); + } + + static methodNotFound(details?: string): RequestError { + return new RequestError(-32601, 'Method not found', details); + } + + static invalidParams(details?: string): RequestError { + return new RequestError(-32602, 'Invalid params', details); + } + + static internalError(details?: string): RequestError { + return new RequestError(-32603, 'Internal error', details); + } + + static authRequired(details?: string): RequestError { + return new RequestError(-32000, 'Authentication required', details); + } + + toResult(): Result { + return { + error: { + code: this.code, + message: this.message, + data: this.data, + }, + }; + } +} + +type AnyMessage = AnyRequest | AnyResponse | AnyNotification; + +type AnyRequest = { + jsonrpc: '2.0'; + id: string | number; + method: string; + params?: unknown; +}; + +type AnyResponse = { + jsonrpc: '2.0'; + id: string | number; +} & Result; + +type AnyNotification = { + jsonrpc: '2.0'; + method: string; + params?: unknown; +}; + +type PendingResponse = { + resolve: (response: unknown) => void; + reject: (error: ErrorResponse) => void; +}; + +export type MethodHandler = ( + method: string, + params: unknown, +) => Promise; + +export class Connection { + #pendingResponses: Map = new Map(); + #nextRequestId: number = 0; + #handler: MethodHandler; + #peerInput: WritableStream; + #writeQueue: Promise = Promise.resolve(); + #textEncoder: TextEncoder; + + constructor( + handler: MethodHandler, + peerInput: WritableStream, + peerOutput: ReadableStream, + ) { + this.#handler = handler; + this.#peerInput = peerInput; + this.#textEncoder = new TextEncoder(); + this.#receive(peerOutput); + } + + async #receive(output: ReadableStream) { + let content = ''; + const decoder = new TextDecoder(); + for await (const chunk of output) { + content += decoder.decode(chunk, { stream: true }); + const lines = content.split('\n'); + content = lines.pop() || ''; + + for (const line of lines) { + const trimmedLine = line.trim(); + + if (trimmedLine) { + const message = JSON.parse(trimmedLine); + this.#processMessage(message); + } + } + } + } + + async #processMessage(message: AnyMessage) { + if ('method' in message && 'id' in message) { + // It's a request + const response = await this.#tryCallHandler( + message.method, + message.params, + ); + + await this.#sendMessage({ + jsonrpc: '2.0', + id: message.id, + ...response, + }); + } else if ('method' in message) { + // It's a notification + await this.#tryCallHandler(message.method, message.params); + } else if ('id' in message) { + // It's a response + this.#handleResponse(message as AnyResponse); + } + } + + async #tryCallHandler( + method: string, + params?: unknown, + ): Promise> { + try { + const result = await this.#handler(method, params); + return { result: result ?? null }; + } catch (error: unknown) { + if (error instanceof RequestError) { + return error.toResult(); + } + + if (error instanceof z.ZodError) { + return RequestError.invalidParams( + JSON.stringify(error.format(), undefined, 2), + ).toResult(); + } + + let details; + + if (error instanceof Error) { + details = error.message; + } else if ( + typeof error === 'object' && + error != null && + 'message' in error && + typeof error.message === 'string' + ) { + details = error.message; + } + + return RequestError.internalError(details).toResult(); + } + } + + #handleResponse(response: AnyResponse) { + const pendingResponse = this.#pendingResponses.get(response.id); + if (pendingResponse) { + if ('result' in response) { + pendingResponse.resolve(response.result); + } else if ('error' in response) { + pendingResponse.reject(response.error); + } + this.#pendingResponses.delete(response.id); + } + } + + async sendRequest(method: string, params?: Req): Promise { + const id = this.#nextRequestId++; + const responsePromise = new Promise((resolve, reject) => { + this.#pendingResponses.set(id, { resolve, reject }); + }); + await this.#sendMessage({ jsonrpc: '2.0', id, method, params }); + return responsePromise as Promise; + } + + async sendNotification(method: string, params?: N): Promise { + await this.#sendMessage({ jsonrpc: '2.0', method, params }); + } + + async #sendMessage(json: AnyMessage) { + const content = JSON.stringify(json) + '\n'; + this.#writeQueue = this.#writeQueue + .then(async () => { + const writer = this.#peerInput.getWriter(); + try { + await writer.write(this.#textEncoder.encode(content)); + } finally { + writer.releaseLock(); + } + }) + .catch((error) => { + // Continue processing writes on error + coreEvents.emitFeedback('error', 'ACP write error.', error); + }); + return this.#writeQueue; + } +} diff --git a/packages/cli/src/zed-integration/schema.ts b/packages/cli/src/zed-integration/schema.ts index b35cc47d5c..ef6c1d76d4 100644 --- a/packages/cli/src/zed-integration/schema.ts +++ b/packages/cli/src/zed-integration/schema.ts @@ -24,6 +24,12 @@ export const CLIENT_METHODS = { export const PROTOCOL_VERSION = 1; +export const authMethodSchema = z.object({ + description: z.string().nullable(), + id: z.string(), + name: z.string(), +}); + export type WriteTextFileRequest = z.infer; export type ReadTextFileRequest = z.infer; @@ -128,6 +134,20 @@ export type AgentRequest = z.infer; export type AgentNotification = z.infer; +export type Result = + | { + result: T; + } + | { + error: ErrorResponse; + }; + +export type ErrorResponse = { + code: number; + message: string; + data?: unknown; +}; + export const writeTextFileRequestSchema = z.object({ content: z.string(), path: z.string(), @@ -203,6 +223,7 @@ export const cancelNotificationSchema = z.object({ export const authenticateRequestSchema = z.object({ methodId: z.string(), + authMethod: authMethodSchema, }); export const authenticateResponseSchema = z.null(); @@ -283,12 +304,6 @@ export const agentCapabilitiesSchema = z.object({ promptCapabilities: promptCapabilitiesSchema.optional(), }); -export const authMethodSchema = z.object({ - description: z.string().nullable(), - id: z.string(), - name: z.string(), -}); - export const clientResponseSchema = z.union([ writeTextFileResponseSchema, readTextFileResponseSchema,