mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
Improve test code coverage for cli/command/extensions package (#12994)
This commit is contained in:
303
packages/cli/src/commands/extensions/disable.test.ts
Normal file
303
packages/cli/src/commands/extensions/disable.test.ts
Normal file
@@ -0,0 +1,303 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
|
||||
import { type CommandModule, type Argv } from 'yargs';
|
||||
|
||||
import { handleDisable, disableCommand } from './disable.js';
|
||||
|
||||
import { ExtensionManager } from '../../config/extension-manager.js';
|
||||
|
||||
import {
|
||||
loadSettings,
|
||||
SettingScope,
|
||||
type LoadedSettings,
|
||||
} from '../../config/settings.js';
|
||||
import { getErrorMessage } from '../../utils/errors.js';
|
||||
|
||||
// Mock dependencies
|
||||
|
||||
vi.mock('../../config/extension-manager.js');
|
||||
|
||||
vi.mock('../../config/settings.js');
|
||||
|
||||
vi.mock('../../utils/errors.js');
|
||||
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
|
||||
return {
|
||||
...actual,
|
||||
|
||||
debugLogger: {
|
||||
log: vi.fn(),
|
||||
|
||||
error: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../../config/extensions/consent.js', () => ({
|
||||
requestConsentNonInteractive: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../../config/extensions/extensionSettings.js', () => ({
|
||||
promptForSetting: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('extensions disable command', () => {
|
||||
const mockLoadSettings = vi.mocked(loadSettings);
|
||||
|
||||
const mockGetErrorMessage = vi.mocked(getErrorMessage);
|
||||
|
||||
const mockExtensionManager = vi.mocked(ExtensionManager);
|
||||
|
||||
interface MockDebugLogger {
|
||||
log: Mock;
|
||||
error: Mock;
|
||||
}
|
||||
let mockDebugLogger: MockDebugLogger;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// We need to re-import the mocked module to get the fresh mock
|
||||
|
||||
mockDebugLogger = (await import('@google/gemini-cli-core'))
|
||||
.debugLogger as unknown as MockDebugLogger;
|
||||
|
||||
mockLoadSettings.mockReturnValue({
|
||||
merged: {},
|
||||
} as unknown as LoadedSettings);
|
||||
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
|
||||
.fn()
|
||||
|
||||
.mockResolvedValue(undefined);
|
||||
|
||||
mockExtensionManager.prototype.disableExtension = vi
|
||||
|
||||
.fn()
|
||||
|
||||
.mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('handleDisable', () => {
|
||||
it.each([
|
||||
{
|
||||
name: 'my-extension',
|
||||
scope: undefined,
|
||||
expectedScope: SettingScope.User,
|
||||
expectedLog:
|
||||
'Extension "my-extension" successfully disabled for scope "undefined".',
|
||||
},
|
||||
{
|
||||
name: 'my-extension',
|
||||
scope: 'user',
|
||||
expectedScope: SettingScope.User,
|
||||
expectedLog:
|
||||
'Extension "my-extension" successfully disabled for scope "user".',
|
||||
},
|
||||
{
|
||||
name: 'my-extension',
|
||||
scope: 'workspace',
|
||||
expectedScope: SettingScope.Workspace,
|
||||
expectedLog:
|
||||
'Extension "my-extension" successfully disabled for scope "workspace".',
|
||||
},
|
||||
])(
|
||||
'should disable an extension in the $expectedScope scope when scope is $scope',
|
||||
async ({ name, scope, expectedScope, expectedLog }) => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
|
||||
await handleDisable({ name, scope });
|
||||
|
||||
expect(mockExtensionManager).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
workspaceDir: '/test/dir',
|
||||
}),
|
||||
);
|
||||
|
||||
expect(
|
||||
mockExtensionManager.prototype.loadExtensions,
|
||||
).toHaveBeenCalled();
|
||||
|
||||
expect(
|
||||
mockExtensionManager.prototype.disableExtension,
|
||||
).toHaveBeenCalledWith(name, expectedScope);
|
||||
|
||||
expect(mockDebugLogger.log).toHaveBeenCalledWith(expectedLog);
|
||||
|
||||
mockCwd.mockRestore();
|
||||
},
|
||||
);
|
||||
|
||||
it('should log an error message and exit with code 1 when extension disabling fails', async () => {
|
||||
const mockProcessExit = vi
|
||||
|
||||
.spyOn(process, 'exit')
|
||||
|
||||
.mockImplementation((() => {}) as (
|
||||
code?: string | number | null | undefined,
|
||||
) => never);
|
||||
|
||||
const error = new Error('Disable failed');
|
||||
|
||||
(
|
||||
mockExtensionManager.prototype.disableExtension as Mock
|
||||
).mockRejectedValue(error);
|
||||
|
||||
mockGetErrorMessage.mockReturnValue('Disable failed message');
|
||||
|
||||
await handleDisable({ name: 'my-extension' });
|
||||
|
||||
expect(mockDebugLogger.error).toHaveBeenCalledWith(
|
||||
'Disable failed message',
|
||||
);
|
||||
|
||||
expect(mockProcessExit).toHaveBeenCalledWith(1);
|
||||
|
||||
mockProcessExit.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('disableCommand', () => {
|
||||
const command = disableCommand as CommandModule;
|
||||
|
||||
it('should have correct command and describe', () => {
|
||||
expect(command.command).toBe('disable [--scope] <name>');
|
||||
|
||||
expect(command.describe).toBe('Disables an extension.');
|
||||
});
|
||||
|
||||
describe('builder', () => {
|
||||
interface MockYargs {
|
||||
positional: Mock;
|
||||
|
||||
option: Mock;
|
||||
|
||||
check: Mock;
|
||||
}
|
||||
|
||||
let yargsMock: MockYargs;
|
||||
|
||||
beforeEach(() => {
|
||||
yargsMock = {
|
||||
positional: vi.fn().mockReturnThis(),
|
||||
|
||||
option: vi.fn().mockReturnThis(),
|
||||
|
||||
check: vi.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
it('should configure positional and option arguments', () => {
|
||||
(command.builder as (yargs: Argv) => Argv)(
|
||||
yargsMock as unknown as Argv,
|
||||
);
|
||||
|
||||
expect(yargsMock.positional).toHaveBeenCalledWith('name', {
|
||||
describe: 'The name of the extension to disable.',
|
||||
|
||||
type: 'string',
|
||||
});
|
||||
|
||||
expect(yargsMock.option).toHaveBeenCalledWith('scope', {
|
||||
describe: 'The scope to disable the extension in.',
|
||||
|
||||
type: 'string',
|
||||
|
||||
default: SettingScope.User,
|
||||
});
|
||||
|
||||
expect(yargsMock.check).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('check function should throw for invalid scope', () => {
|
||||
(command.builder as (yargs: Argv) => Argv)(
|
||||
yargsMock as unknown as Argv,
|
||||
);
|
||||
|
||||
const checkCallback = yargsMock.check.mock.calls[0][0];
|
||||
|
||||
const expectedError = `Invalid scope: invalid. Please use one of ${Object.values(
|
||||
SettingScope,
|
||||
)
|
||||
|
||||
.map((s) => s.toLowerCase())
|
||||
|
||||
.join(', ')}.`;
|
||||
|
||||
expect(() => checkCallback({ scope: 'invalid' })).toThrow(
|
||||
expectedError,
|
||||
);
|
||||
});
|
||||
|
||||
it.each(['user', 'workspace', 'USER', 'WorkSpace'])(
|
||||
'check function should return true for valid scope "%s"',
|
||||
(scope) => {
|
||||
(command.builder as (yargs: Argv) => Argv)(
|
||||
yargsMock as unknown as Argv,
|
||||
);
|
||||
|
||||
const checkCallback = yargsMock.check.mock.calls[0][0];
|
||||
|
||||
expect(checkCallback({ scope })).toBe(true);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it('handler should trigger extension disabling', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
|
||||
interface TestArgv {
|
||||
name: string;
|
||||
scope: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
const argv: TestArgv = {
|
||||
name: 'test-ext',
|
||||
scope: 'workspace',
|
||||
_: [],
|
||||
$0: '',
|
||||
};
|
||||
|
||||
await (command.handler as unknown as (args: TestArgv) => void)(argv);
|
||||
expect(mockExtensionManager).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
workspaceDir: '/test/dir',
|
||||
}),
|
||||
);
|
||||
|
||||
expect(mockExtensionManager.prototype.loadExtensions).toHaveBeenCalled();
|
||||
|
||||
expect(
|
||||
mockExtensionManager.prototype.disableExtension,
|
||||
).toHaveBeenCalledWith('test-ext', SettingScope.Workspace);
|
||||
|
||||
expect(mockDebugLogger.log).toHaveBeenCalledWith(
|
||||
'Extension "test-ext" successfully disabled for scope "workspace".',
|
||||
);
|
||||
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
208
packages/cli/src/commands/extensions/enable.test.ts
Normal file
208
packages/cli/src/commands/extensions/enable.test.ts
Normal file
@@ -0,0 +1,208 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { type CommandModule, type Argv } from 'yargs';
|
||||
import { handleEnable, enableCommand } from './enable.js';
|
||||
import { ExtensionManager } from '../../config/extension-manager.js';
|
||||
import {
|
||||
loadSettings,
|
||||
SettingScope,
|
||||
type LoadedSettings,
|
||||
} from '../../config/settings.js';
|
||||
import { FatalConfigError } from '@google/gemini-cli-core';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('../../config/extension-manager.js');
|
||||
vi.mock('../../config/settings.js');
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...actual,
|
||||
debugLogger: {
|
||||
log: vi.fn(),
|
||||
error: vi.fn(),
|
||||
},
|
||||
getErrorMessage: vi.fn((error: { message: string }) => error.message),
|
||||
FatalConfigError: class extends Error {
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = 'FatalConfigError';
|
||||
}
|
||||
},
|
||||
};
|
||||
});
|
||||
vi.mock('../../config/extensions/consent.js');
|
||||
vi.mock('../../config/extensions/extensionSettings.js');
|
||||
|
||||
describe('extensions enable command', () => {
|
||||
const mockLoadSettings = vi.mocked(loadSettings);
|
||||
const mockExtensionManager = vi.mocked(ExtensionManager);
|
||||
interface MockDebugLogger {
|
||||
log: Mock;
|
||||
error: Mock;
|
||||
}
|
||||
let mockDebugLogger: MockDebugLogger;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
mockDebugLogger = (await import('@google/gemini-cli-core'))
|
||||
.debugLogger as unknown as MockDebugLogger;
|
||||
mockLoadSettings.mockReturnValue({
|
||||
merged: {},
|
||||
} as unknown as LoadedSettings);
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
.fn()
|
||||
.mockResolvedValue(undefined);
|
||||
mockExtensionManager.prototype.enableExtension = vi.fn();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('handleEnable', () => {
|
||||
it.each([
|
||||
{
|
||||
name: 'my-extension',
|
||||
scope: undefined,
|
||||
expectedScope: SettingScope.User,
|
||||
expectedLog:
|
||||
'Extension "my-extension" successfully enabled in all scopes.',
|
||||
},
|
||||
{
|
||||
name: 'my-extension',
|
||||
scope: 'workspace',
|
||||
expectedScope: SettingScope.Workspace,
|
||||
expectedLog:
|
||||
'Extension "my-extension" successfully enabled for scope "workspace".',
|
||||
},
|
||||
])(
|
||||
'should enable an extension in the $expectedScope scope when scope is $scope',
|
||||
async ({ name, scope, expectedScope, expectedLog }) => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
await handleEnable({ name, scope });
|
||||
|
||||
expect(mockExtensionManager).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
workspaceDir: '/test/dir',
|
||||
}),
|
||||
);
|
||||
expect(
|
||||
mockExtensionManager.prototype.loadExtensions,
|
||||
).toHaveBeenCalled();
|
||||
expect(
|
||||
mockExtensionManager.prototype.enableExtension,
|
||||
).toHaveBeenCalledWith(name, expectedScope);
|
||||
expect(mockDebugLogger.log).toHaveBeenCalledWith(expectedLog);
|
||||
mockCwd.mockRestore();
|
||||
},
|
||||
);
|
||||
|
||||
it('should throw FatalConfigError when extension enabling fails', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
const error = new Error('Enable failed');
|
||||
(
|
||||
mockExtensionManager.prototype.enableExtension as Mock
|
||||
).mockImplementation(() => {
|
||||
throw error;
|
||||
});
|
||||
|
||||
const promise = handleEnable({ name: 'my-extension' });
|
||||
await expect(promise).rejects.toThrow(FatalConfigError);
|
||||
await expect(promise).rejects.toThrow('Enable failed');
|
||||
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('enableCommand', () => {
|
||||
const command = enableCommand as CommandModule;
|
||||
|
||||
it('should have correct command and describe', () => {
|
||||
expect(command.command).toBe('enable [--scope] <name>');
|
||||
expect(command.describe).toBe('Enables an extension.');
|
||||
});
|
||||
|
||||
describe('builder', () => {
|
||||
interface MockYargs {
|
||||
positional: Mock;
|
||||
option: Mock;
|
||||
check: Mock;
|
||||
}
|
||||
|
||||
let yargsMock: MockYargs;
|
||||
beforeEach(() => {
|
||||
yargsMock = {
|
||||
positional: vi.fn().mockReturnThis(),
|
||||
option: vi.fn().mockReturnThis(),
|
||||
check: vi.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
it('should configure positional and option arguments', () => {
|
||||
(command.builder as (yargs: Argv) => Argv)(
|
||||
yargsMock as unknown as Argv,
|
||||
);
|
||||
expect(yargsMock.positional).toHaveBeenCalledWith('name', {
|
||||
describe: 'The name of the extension to enable.',
|
||||
type: 'string',
|
||||
});
|
||||
expect(yargsMock.option).toHaveBeenCalledWith('scope', {
|
||||
describe:
|
||||
'The scope to enable the extension in. If not set, will be enabled in all scopes.',
|
||||
type: 'string',
|
||||
});
|
||||
expect(yargsMock.check).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('check function should throw for invalid scope', () => {
|
||||
(command.builder as (yargs: Argv) => Argv)(
|
||||
yargsMock as unknown as Argv,
|
||||
);
|
||||
const checkCallback = yargsMock.check.mock.calls[0][0];
|
||||
const expectedError = `Invalid scope: invalid. Please use one of ${Object.values(
|
||||
SettingScope,
|
||||
)
|
||||
.map((s) => s.toLowerCase())
|
||||
.join(', ')}.`;
|
||||
expect(() => checkCallback({ scope: 'invalid' })).toThrow(
|
||||
expectedError,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('handler should call handleEnable', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
interface TestArgv {
|
||||
name: string;
|
||||
scope: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
const argv: TestArgv = {
|
||||
name: 'test-ext',
|
||||
scope: 'workspace',
|
||||
_: [],
|
||||
$0: '',
|
||||
};
|
||||
await (command.handler as unknown as (args: TestArgv) => void)(argv);
|
||||
|
||||
expect(
|
||||
mockExtensionManager.prototype.enableExtension,
|
||||
).toHaveBeenCalledWith('test-ext', SettingScope.Workspace);
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,143 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
|
||||
import { z } from 'zod';
|
||||
|
||||
// Mock the MCP server and transport
|
||||
const mockRegisterTool = vi.fn();
|
||||
const mockRegisterPrompt = vi.fn();
|
||||
const mockConnect = vi.fn();
|
||||
|
||||
vi.mock('@modelcontextprotocol/sdk/server/mcp.js', () => ({
|
||||
McpServer: vi.fn().mockImplementation(() => ({
|
||||
registerTool: mockRegisterTool,
|
||||
registerPrompt: mockRegisterPrompt,
|
||||
connect: mockConnect,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('@modelcontextprotocol/sdk/server/stdio.js', () => ({
|
||||
StdioServerTransport: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('MCP Server Example', () => {
|
||||
beforeEach(async () => {
|
||||
// Dynamically import the server setup after mocks are in place
|
||||
await import('./example.js');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.resetModules();
|
||||
});
|
||||
|
||||
it('should create an McpServer with the correct name and version', () => {
|
||||
expect(McpServer).toHaveBeenCalledWith({
|
||||
name: 'prompt-server',
|
||||
version: '1.0.0',
|
||||
});
|
||||
});
|
||||
|
||||
it('should register the "fetch_posts" tool', () => {
|
||||
expect(mockRegisterTool).toHaveBeenCalledWith(
|
||||
'fetch_posts',
|
||||
{
|
||||
description: 'Fetches a list of posts from a public API.',
|
||||
inputSchema: z.object({}).shape,
|
||||
},
|
||||
expect.any(Function),
|
||||
);
|
||||
});
|
||||
|
||||
it('should register the "poem-writer" prompt', () => {
|
||||
expect(mockRegisterPrompt).toHaveBeenCalledWith(
|
||||
'poem-writer',
|
||||
{
|
||||
title: 'Poem Writer',
|
||||
description: 'Write a nice haiku',
|
||||
argsSchema: expect.any(Object),
|
||||
},
|
||||
expect.any(Function),
|
||||
);
|
||||
});
|
||||
|
||||
it('should connect the server to an StdioServerTransport', () => {
|
||||
expect(StdioServerTransport).toHaveBeenCalled();
|
||||
expect(mockConnect).toHaveBeenCalledWith(expect.any(StdioServerTransport));
|
||||
});
|
||||
|
||||
describe('fetch_posts tool implementation', () => {
|
||||
it('should fetch posts and return a formatted response', async () => {
|
||||
const mockPosts = [
|
||||
{ id: 1, title: 'Post 1' },
|
||||
{ id: 2, title: 'Post 2' },
|
||||
];
|
||||
global.fetch = vi.fn().mockResolvedValue({
|
||||
json: vi.fn().mockResolvedValue(mockPosts),
|
||||
});
|
||||
|
||||
const toolFn = (mockRegisterTool as Mock).mock.calls[0][2];
|
||||
const result = await toolFn();
|
||||
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'https://jsonplaceholder.typicode.com/posts',
|
||||
);
|
||||
expect(result).toEqual({
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({ posts: mockPosts }),
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('poem-writer prompt implementation', () => {
|
||||
it('should generate a prompt with a title', () => {
|
||||
const promptFn = (mockRegisterPrompt as Mock).mock.calls[0][2];
|
||||
const result = promptFn({ title: 'My Poem' });
|
||||
expect(result).toEqual({
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: {
|
||||
type: 'text',
|
||||
text: 'Write a haiku called My Poem. Note that a haiku is 5 syllables followed by 7 syllables followed by 5 syllables ',
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('should generate a prompt with a title and mood', () => {
|
||||
const promptFn = (mockRegisterPrompt as Mock).mock.calls[0][2];
|
||||
const result = promptFn({ title: 'My Poem', mood: 'sad' });
|
||||
expect(result).toEqual({
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: {
|
||||
type: 'text',
|
||||
text: 'Write a haiku with the mood sad called My Poem. Note that a haiku is 5 syllables followed by 7 syllables followed by 5 syllables ',
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
174
packages/cli/src/commands/extensions/link.test.ts
Normal file
174
packages/cli/src/commands/extensions/link.test.ts
Normal file
@@ -0,0 +1,174 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { type CommandModule, type Argv } from 'yargs';
|
||||
import { handleLink, linkCommand } from './link.js';
|
||||
import { ExtensionManager } from '../../config/extension-manager.js';
|
||||
import { loadSettings, type LoadedSettings } from '../../config/settings.js';
|
||||
import { getErrorMessage } from '../../utils/errors.js';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('../../config/extension-manager.js');
|
||||
vi.mock('../../config/settings.js');
|
||||
vi.mock('../../utils/errors.js');
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...actual,
|
||||
debugLogger: {
|
||||
log: vi.fn(),
|
||||
error: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
vi.mock('../../config/extensions/consent.js', () => ({
|
||||
requestConsentNonInteractive: vi.fn(),
|
||||
}));
|
||||
vi.mock('../../config/extensions/extensionSettings.js', () => ({
|
||||
promptForSetting: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('extensions link command', () => {
|
||||
const mockLoadSettings = vi.mocked(loadSettings);
|
||||
const mockGetErrorMessage = vi.mocked(getErrorMessage);
|
||||
const mockExtensionManager = vi.mocked(ExtensionManager);
|
||||
interface MockDebugLogger {
|
||||
log: Mock;
|
||||
error: Mock;
|
||||
}
|
||||
let mockDebugLogger: MockDebugLogger;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
mockDebugLogger = (await import('@google/gemini-cli-core'))
|
||||
.debugLogger as unknown as MockDebugLogger;
|
||||
mockLoadSettings.mockReturnValue({
|
||||
merged: {},
|
||||
} as unknown as LoadedSettings);
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
.fn()
|
||||
.mockResolvedValue(undefined);
|
||||
mockExtensionManager.prototype.installOrUpdateExtension = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ name: 'my-linked-extension' });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('handleLink', () => {
|
||||
it('should link an extension from a local path', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
await handleLink({ path: '/local/path/to/extension' });
|
||||
|
||||
expect(mockExtensionManager).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
workspaceDir: '/test/dir',
|
||||
}),
|
||||
);
|
||||
expect(mockExtensionManager.prototype.loadExtensions).toHaveBeenCalled();
|
||||
expect(
|
||||
mockExtensionManager.prototype.installOrUpdateExtension,
|
||||
).toHaveBeenCalledWith({
|
||||
source: '/local/path/to/extension',
|
||||
type: 'link',
|
||||
});
|
||||
expect(mockDebugLogger.log).toHaveBeenCalledWith(
|
||||
'Extension "my-linked-extension" linked successfully and enabled.',
|
||||
);
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
|
||||
it('should log an error message and exit with code 1 when linking fails', async () => {
|
||||
const mockProcessExit = vi
|
||||
.spyOn(process, 'exit')
|
||||
.mockImplementation((() => {}) as (
|
||||
code?: string | number | null | undefined,
|
||||
) => never);
|
||||
const error = new Error('Link failed');
|
||||
(
|
||||
mockExtensionManager.prototype.installOrUpdateExtension as Mock
|
||||
).mockRejectedValue(error);
|
||||
mockGetErrorMessage.mockReturnValue('Link failed message');
|
||||
|
||||
await handleLink({ path: '/local/path/to/extension' });
|
||||
|
||||
expect(mockDebugLogger.error).toHaveBeenCalledWith('Link failed message');
|
||||
expect(mockProcessExit).toHaveBeenCalledWith(1);
|
||||
mockProcessExit.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('linkCommand', () => {
|
||||
const command = linkCommand as CommandModule;
|
||||
|
||||
it('should have correct command and describe', () => {
|
||||
expect(command.command).toBe('link <path>');
|
||||
expect(command.describe).toBe(
|
||||
'Links an extension from a local path. Updates made to the local path will always be reflected.',
|
||||
);
|
||||
});
|
||||
|
||||
describe('builder', () => {
|
||||
interface MockYargs {
|
||||
positional: Mock;
|
||||
check: Mock;
|
||||
}
|
||||
|
||||
let yargsMock: MockYargs;
|
||||
beforeEach(() => {
|
||||
yargsMock = {
|
||||
positional: vi.fn().mockReturnThis(),
|
||||
check: vi.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
it('should configure positional argument', () => {
|
||||
(command.builder as (yargs: Argv) => Argv)(
|
||||
yargsMock as unknown as Argv,
|
||||
);
|
||||
expect(yargsMock.positional).toHaveBeenCalledWith('path', {
|
||||
describe: 'The name of the extension to link.',
|
||||
type: 'string',
|
||||
});
|
||||
expect(yargsMock.check).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it('handler should call handleLink', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
interface TestArgv {
|
||||
path: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
const argv: TestArgv = {
|
||||
path: '/local/path/to/extension',
|
||||
_: [],
|
||||
$0: '',
|
||||
};
|
||||
await (command.handler as unknown as (args: TestArgv) => void)(argv);
|
||||
|
||||
expect(
|
||||
mockExtensionManager.prototype.installOrUpdateExtension,
|
||||
).toHaveBeenCalledWith({
|
||||
source: '/local/path/to/extension',
|
||||
type: 'link',
|
||||
});
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
137
packages/cli/src/commands/extensions/list.test.ts
Normal file
137
packages/cli/src/commands/extensions/list.test.ts
Normal file
@@ -0,0 +1,137 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { type CommandModule } from 'yargs';
|
||||
import { handleList, listCommand } from './list.js';
|
||||
import { ExtensionManager } from '../../config/extension-manager.js';
|
||||
import { loadSettings, type LoadedSettings } from '../../config/settings.js';
|
||||
import { getErrorMessage } from '../../utils/errors.js';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('../../config/extension-manager.js');
|
||||
vi.mock('../../config/settings.js');
|
||||
vi.mock('../../utils/errors.js');
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...actual,
|
||||
debugLogger: {
|
||||
log: vi.fn(),
|
||||
error: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
vi.mock('../../config/extensions/consent.js', () => ({
|
||||
requestConsentNonInteractive: vi.fn(),
|
||||
}));
|
||||
vi.mock('../../config/extensions/extensionSettings.js', () => ({
|
||||
promptForSetting: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('extensions list command', () => {
|
||||
const mockLoadSettings = vi.mocked(loadSettings);
|
||||
const mockGetErrorMessage = vi.mocked(getErrorMessage);
|
||||
const mockExtensionManager = vi.mocked(ExtensionManager);
|
||||
interface MockDebugLogger {
|
||||
log: Mock;
|
||||
error: Mock;
|
||||
}
|
||||
let mockDebugLogger: MockDebugLogger;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
mockDebugLogger = (await import('@google/gemini-cli-core'))
|
||||
.debugLogger as unknown as MockDebugLogger;
|
||||
mockLoadSettings.mockReturnValue({
|
||||
merged: {},
|
||||
} as unknown as LoadedSettings);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('handleList', () => {
|
||||
it('should log a message if no extensions are installed', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
.fn()
|
||||
.mockResolvedValue([]);
|
||||
await handleList();
|
||||
|
||||
expect(mockDebugLogger.log).toHaveBeenCalledWith(
|
||||
'No extensions installed.',
|
||||
);
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
|
||||
it('should list all installed extensions', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
const extensions = [
|
||||
{ name: 'ext1', version: '1.0.0' },
|
||||
{ name: 'ext2', version: '2.0.0' },
|
||||
];
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
.fn()
|
||||
.mockResolvedValue(extensions);
|
||||
mockExtensionManager.prototype.toOutputString = vi.fn(
|
||||
(ext) => `${ext.name}@${ext.version}`,
|
||||
);
|
||||
await handleList();
|
||||
|
||||
expect(mockDebugLogger.log).toHaveBeenCalledWith(
|
||||
'ext1@1.0.0\n\next2@2.0.0',
|
||||
);
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
|
||||
it('should log an error message and exit with code 1 when listing fails', async () => {
|
||||
const mockProcessExit = vi
|
||||
.spyOn(process, 'exit')
|
||||
.mockImplementation((() => {}) as (
|
||||
code?: string | number | null | undefined,
|
||||
) => never);
|
||||
const error = new Error('List failed');
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
.fn()
|
||||
.mockRejectedValue(error);
|
||||
mockGetErrorMessage.mockReturnValue('List failed message');
|
||||
|
||||
await handleList();
|
||||
|
||||
expect(mockDebugLogger.error).toHaveBeenCalledWith('List failed message');
|
||||
expect(mockProcessExit).toHaveBeenCalledWith(1);
|
||||
mockProcessExit.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('listCommand', () => {
|
||||
const command = listCommand as CommandModule;
|
||||
|
||||
it('should have correct command and describe', () => {
|
||||
expect(command.command).toBe('list');
|
||||
expect(command.describe).toBe('Lists installed extensions.');
|
||||
});
|
||||
|
||||
it('handler should call handleList', async () => {
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
.fn()
|
||||
.mockResolvedValue([]);
|
||||
await (command.handler as () => Promise<void>)();
|
||||
expect(mockExtensionManager.prototype.loadExtensions).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -4,18 +4,171 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { uninstallCommand } from './uninstall.js';
|
||||
import yargs from 'yargs';
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { type CommandModule, type Argv } from 'yargs';
|
||||
import { handleUninstall, uninstallCommand } from './uninstall.js';
|
||||
import { ExtensionManager } from '../../config/extension-manager.js';
|
||||
import { loadSettings, type LoadedSettings } from '../../config/settings.js';
|
||||
import { getErrorMessage } from '../../utils/errors.js';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('../../config/extension-manager.js');
|
||||
vi.mock('../../config/settings.js');
|
||||
vi.mock('../../utils/errors.js');
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...actual,
|
||||
debugLogger: {
|
||||
log: vi.fn(),
|
||||
error: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
vi.mock('../../config/extensions/consent.js', () => ({
|
||||
requestConsentNonInteractive: vi.fn(),
|
||||
}));
|
||||
vi.mock('../../config/extensions/extensionSettings.js', () => ({
|
||||
promptForSetting: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('extensions uninstall command', () => {
|
||||
it('should fail if no source is provided', () => {
|
||||
const validationParser = yargs([])
|
||||
.command(uninstallCommand)
|
||||
.fail(false)
|
||||
.locale('en');
|
||||
expect(() => validationParser.parse('uninstall')).toThrow(
|
||||
'Not enough non-option arguments: got 0, need at least 1',
|
||||
);
|
||||
const mockLoadSettings = vi.mocked(loadSettings);
|
||||
const mockGetErrorMessage = vi.mocked(getErrorMessage);
|
||||
const mockExtensionManager = vi.mocked(ExtensionManager);
|
||||
interface MockDebugLogger {
|
||||
log: Mock;
|
||||
error: Mock;
|
||||
}
|
||||
let mockDebugLogger: MockDebugLogger;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
mockDebugLogger = (await import('@google/gemini-cli-core'))
|
||||
.debugLogger as unknown as MockDebugLogger;
|
||||
mockLoadSettings.mockReturnValue({
|
||||
merged: {},
|
||||
} as unknown as LoadedSettings);
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
.fn()
|
||||
.mockResolvedValue(undefined);
|
||||
mockExtensionManager.prototype.uninstallExtension = vi
|
||||
.fn()
|
||||
.mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('handleUninstall', () => {
|
||||
it('should uninstall an extension', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
await handleUninstall({ name: 'my-extension' });
|
||||
|
||||
expect(mockExtensionManager).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
workspaceDir: '/test/dir',
|
||||
}),
|
||||
);
|
||||
expect(mockExtensionManager.prototype.loadExtensions).toHaveBeenCalled();
|
||||
expect(
|
||||
mockExtensionManager.prototype.uninstallExtension,
|
||||
).toHaveBeenCalledWith('my-extension', false);
|
||||
expect(mockDebugLogger.log).toHaveBeenCalledWith(
|
||||
'Extension "my-extension" successfully uninstalled.',
|
||||
);
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
|
||||
it('should log an error message and exit with code 1 when uninstallation fails', async () => {
|
||||
const mockProcessExit = vi
|
||||
.spyOn(process, 'exit')
|
||||
.mockImplementation((() => {}) as (
|
||||
code?: string | number | null | undefined,
|
||||
) => never);
|
||||
const error = new Error('Uninstall failed');
|
||||
(
|
||||
mockExtensionManager.prototype.uninstallExtension as Mock
|
||||
).mockRejectedValue(error);
|
||||
mockGetErrorMessage.mockReturnValue('Uninstall failed message');
|
||||
|
||||
await handleUninstall({ name: 'my-extension' });
|
||||
|
||||
expect(mockDebugLogger.error).toHaveBeenCalledWith(
|
||||
'Uninstall failed message',
|
||||
);
|
||||
expect(mockProcessExit).toHaveBeenCalledWith(1);
|
||||
mockProcessExit.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('uninstallCommand', () => {
|
||||
const command = uninstallCommand as CommandModule;
|
||||
|
||||
it('should have correct command and describe', () => {
|
||||
expect(command.command).toBe('uninstall <name>');
|
||||
expect(command.describe).toBe('Uninstalls an extension.');
|
||||
});
|
||||
|
||||
describe('builder', () => {
|
||||
interface MockYargs {
|
||||
positional: Mock;
|
||||
check: Mock;
|
||||
}
|
||||
|
||||
let yargsMock: MockYargs;
|
||||
beforeEach(() => {
|
||||
yargsMock = {
|
||||
positional: vi.fn().mockReturnThis(),
|
||||
check: vi.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
it('should configure positional argument', () => {
|
||||
(command.builder as (yargs: Argv) => Argv)(
|
||||
yargsMock as unknown as Argv,
|
||||
);
|
||||
expect(yargsMock.positional).toHaveBeenCalledWith('name', {
|
||||
describe: 'The name or source path of the extension to uninstall.',
|
||||
type: 'string',
|
||||
});
|
||||
expect(yargsMock.check).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('check function should throw for missing name', () => {
|
||||
(command.builder as (yargs: Argv) => Argv)(
|
||||
yargsMock as unknown as Argv,
|
||||
);
|
||||
const checkCallback = yargsMock.check.mock.calls[0][0];
|
||||
expect(() => checkCallback({ name: '' })).toThrow(
|
||||
'Please include the name of the extension to uninstall as a positional argument.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('handler should call handleUninstall', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
interface TestArgv {
|
||||
name: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
const argv: TestArgv = { name: 'my-extension', _: [], $0: '' };
|
||||
await (command.handler as unknown as (args: TestArgv) => void)(argv);
|
||||
|
||||
expect(
|
||||
mockExtensionManager.prototype.uninstallExtension,
|
||||
).toHaveBeenCalledWith('my-extension', false);
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
226
packages/cli/src/commands/extensions/update.test.ts
Normal file
226
packages/cli/src/commands/extensions/update.test.ts
Normal file
@@ -0,0 +1,226 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { type CommandModule, type Argv } from 'yargs';
|
||||
import { handleUpdate, updateCommand } from './update.js';
|
||||
import { ExtensionManager } from '../../config/extension-manager.js';
|
||||
import { loadSettings, type LoadedSettings } from '../../config/settings.js';
|
||||
import * as update from '../../config/extensions/update.js';
|
||||
import * as github from '../../config/extensions/github.js';
|
||||
import { ExtensionUpdateState } from '../../ui/state/extensions.js';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('../../config/extension-manager.js');
|
||||
vi.mock('../../config/settings.js');
|
||||
vi.mock('../../utils/errors.js');
|
||||
vi.mock('../../config/extensions/update.js');
|
||||
vi.mock('../../config/extensions/github.js');
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...actual,
|
||||
debugLogger: {
|
||||
log: vi.fn(),
|
||||
error: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
vi.mock('../../config/extensions/consent.js', () => ({
|
||||
requestConsentNonInteractive: vi.fn(),
|
||||
}));
|
||||
vi.mock('../../config/extensions/extensionSettings.js', () => ({
|
||||
promptForSetting: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('extensions update command', () => {
|
||||
const mockLoadSettings = vi.mocked(loadSettings);
|
||||
const mockExtensionManager = vi.mocked(ExtensionManager);
|
||||
const mockUpdateExtension = vi.mocked(update.updateExtension);
|
||||
const mockCheckForExtensionUpdate = vi.mocked(github.checkForExtensionUpdate);
|
||||
const mockCheckForAllExtensionUpdates = vi.mocked(
|
||||
update.checkForAllExtensionUpdates,
|
||||
);
|
||||
const mockUpdateAllUpdatableExtensions = vi.mocked(
|
||||
update.updateAllUpdatableExtensions,
|
||||
);
|
||||
|
||||
interface MockDebugLogger {
|
||||
log: Mock;
|
||||
error: Mock;
|
||||
}
|
||||
let mockDebugLogger: MockDebugLogger;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
mockDebugLogger = (await import('@google/gemini-cli-core'))
|
||||
.debugLogger as unknown as MockDebugLogger;
|
||||
mockLoadSettings.mockReturnValue({
|
||||
merged: { experimental: { extensionReloading: true } },
|
||||
} as unknown as LoadedSettings);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('handleUpdate', () => {
|
||||
it.each([
|
||||
{
|
||||
state: ExtensionUpdateState.UPDATE_AVAILABLE,
|
||||
expectedLog:
|
||||
'Extension "my-extension" successfully updated: 1.0.0 → 1.1.0.',
|
||||
shouldCallUpdateExtension: true,
|
||||
},
|
||||
{
|
||||
state: ExtensionUpdateState.UP_TO_DATE,
|
||||
expectedLog: 'Extension "my-extension" is already up to date.',
|
||||
shouldCallUpdateExtension: false,
|
||||
},
|
||||
])(
|
||||
'should handle single extension update state: $state',
|
||||
async ({ state, expectedLog, shouldCallUpdateExtension }) => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
const extensions = [{ name: 'my-extension', installMetadata: {} }];
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
.fn()
|
||||
.mockResolvedValue(extensions);
|
||||
mockCheckForExtensionUpdate.mockResolvedValue(state);
|
||||
mockUpdateExtension.mockResolvedValue({
|
||||
name: 'my-extension',
|
||||
originalVersion: '1.0.0',
|
||||
updatedVersion: '1.1.0',
|
||||
});
|
||||
|
||||
await handleUpdate({ name: 'my-extension' });
|
||||
|
||||
expect(mockDebugLogger.log).toHaveBeenCalledWith(expectedLog);
|
||||
if (shouldCallUpdateExtension) {
|
||||
expect(mockUpdateExtension).toHaveBeenCalled();
|
||||
} else {
|
||||
expect(mockUpdateExtension).not.toHaveBeenCalled();
|
||||
}
|
||||
mockCwd.mockRestore();
|
||||
},
|
||||
);
|
||||
|
||||
it.each([
|
||||
{
|
||||
updatedExtensions: [
|
||||
{ name: 'ext1', originalVersion: '1.0.0', updatedVersion: '1.1.0' },
|
||||
{ name: 'ext2', originalVersion: '2.0.0', updatedVersion: '2.1.0' },
|
||||
],
|
||||
expectedLog:
|
||||
'Extension "ext1" successfully updated: 1.0.0 → 1.1.0.\nExtension "ext2" successfully updated: 2.0.0 → 2.1.0.',
|
||||
},
|
||||
{
|
||||
updatedExtensions: [],
|
||||
expectedLog: 'No extensions to update.',
|
||||
},
|
||||
])(
|
||||
'should handle updating all extensions: %s',
|
||||
async ({ updatedExtensions, expectedLog }) => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
.fn()
|
||||
.mockResolvedValue([]);
|
||||
mockCheckForAllExtensionUpdates.mockResolvedValue(undefined);
|
||||
mockUpdateAllUpdatableExtensions.mockResolvedValue(updatedExtensions);
|
||||
|
||||
await handleUpdate({ all: true });
|
||||
|
||||
expect(mockDebugLogger.log).toHaveBeenCalledWith(expectedLog);
|
||||
mockCwd.mockRestore();
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
describe('updateCommand', () => {
|
||||
const command = updateCommand as CommandModule;
|
||||
|
||||
it('should have correct command and describe', () => {
|
||||
expect(command.command).toBe('update [<name>] [--all]');
|
||||
expect(command.describe).toBe(
|
||||
'Updates all extensions or a named extension to the latest version.',
|
||||
);
|
||||
});
|
||||
|
||||
describe('builder', () => {
|
||||
interface MockYargs {
|
||||
positional: Mock;
|
||||
option: Mock;
|
||||
conflicts: Mock;
|
||||
check: Mock;
|
||||
}
|
||||
|
||||
let yargsMock: MockYargs;
|
||||
beforeEach(() => {
|
||||
yargsMock = {
|
||||
positional: vi.fn().mockReturnThis(),
|
||||
option: vi.fn().mockReturnThis(),
|
||||
conflicts: vi.fn().mockReturnThis(),
|
||||
check: vi.fn().mockReturnThis(),
|
||||
};
|
||||
});
|
||||
|
||||
it('should configure arguments', () => {
|
||||
(command.builder as (yargs: Argv) => Argv)(
|
||||
yargsMock as unknown as Argv,
|
||||
);
|
||||
expect(yargsMock.positional).toHaveBeenCalledWith(
|
||||
'name',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(yargsMock.option).toHaveBeenCalledWith(
|
||||
'all',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(yargsMock.conflicts).toHaveBeenCalledWith('name', 'all');
|
||||
expect(yargsMock.check).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('check function should throw an error if neither a name nor --all is provided', () => {
|
||||
(command.builder as (yargs: Argv) => Argv)(
|
||||
yargsMock as unknown as Argv,
|
||||
);
|
||||
const checkCallback = yargsMock.check.mock.calls[0][0];
|
||||
expect(() => checkCallback({ name: undefined, all: false })).toThrow(
|
||||
'Either an extension name or --all must be provided',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('handler should call handleUpdate', async () => {
|
||||
const extensions = [{ name: 'my-extension', installMetadata: {} }];
|
||||
mockExtensionManager.prototype.loadExtensions = vi
|
||||
.fn()
|
||||
.mockResolvedValue(extensions);
|
||||
mockCheckForExtensionUpdate.mockResolvedValue(
|
||||
ExtensionUpdateState.UPDATE_AVAILABLE,
|
||||
);
|
||||
mockUpdateExtension.mockResolvedValue({
|
||||
name: 'my-extension',
|
||||
originalVersion: '1.0.0',
|
||||
updatedVersion: '1.1.0',
|
||||
});
|
||||
|
||||
await (command.handler as (args: object) => Promise<void>)({
|
||||
name: 'my-extension',
|
||||
});
|
||||
|
||||
expect(mockUpdateExtension).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -451,7 +451,7 @@ describe('startInteractiveUI', () => {
|
||||
const mockConfig = {
|
||||
getProjectRoot: () => '/root',
|
||||
getScreenReader: () => false,
|
||||
} as Config;
|
||||
} as unknown as Config;
|
||||
const mockSettings = {
|
||||
merged: {
|
||||
ui: {
|
||||
@@ -477,7 +477,6 @@ describe('startInteractiveUI', () => {
|
||||
isKittyProtocolSupported: vi.fn(() => true),
|
||||
isKittyProtocolEnabled: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
vi.mock('./ui/utils/updateCheck.js', () => ({
|
||||
checkForUpdates: vi.fn(() => Promise.resolve(null)),
|
||||
}));
|
||||
|
||||
296
packages/cli/src/zed-integration/acp.test.ts
Normal file
296
packages/cli/src/zed-integration/acp.test.ts
Normal file
@@ -0,0 +1,296 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, type Mock } from 'vitest';
|
||||
import {
|
||||
AgentSideConnection,
|
||||
RequestError,
|
||||
type Agent,
|
||||
type Client,
|
||||
} from './acp.js';
|
||||
import { type ErrorResponse } from './schema.js';
|
||||
import { type MethodHandler } from './connection.js';
|
||||
import { ReadableStream, WritableStream } from 'node:stream/web';
|
||||
|
||||
const mockConnectionConstructor = vi.hoisted(() =>
|
||||
vi.fn<
|
||||
(
|
||||
arg1: MethodHandler,
|
||||
arg2: WritableStream<Uint8Array>,
|
||||
arg3: ReadableStream<Uint8Array>,
|
||||
) => { sendRequest: Mock; sendNotification: Mock }
|
||||
>(() => ({
|
||||
sendRequest: vi.fn(),
|
||||
sendNotification: vi.fn(),
|
||||
})),
|
||||
);
|
||||
|
||||
vi.mock('./connection.js', async (importOriginal) => {
|
||||
const actual = await importOriginal();
|
||||
return {
|
||||
...(actual as object),
|
||||
Connection: mockConnectionConstructor,
|
||||
};
|
||||
});
|
||||
|
||||
describe('acp', () => {
|
||||
describe('RequestError', () => {
|
||||
it('should create a parse error', () => {
|
||||
const error = RequestError.parseError('details');
|
||||
expect(error.code).toBe(-32700);
|
||||
expect(error.message).toBe('Parse error');
|
||||
expect(error.data?.details).toBe('details');
|
||||
});
|
||||
|
||||
it('should create a method not found error', () => {
|
||||
const error = RequestError.methodNotFound('details');
|
||||
expect(error.code).toBe(-32601);
|
||||
expect(error.message).toBe('Method not found');
|
||||
expect(error.data?.details).toBe('details');
|
||||
});
|
||||
|
||||
it('should convert to a result', () => {
|
||||
const error = RequestError.internalError('details');
|
||||
const result = error.toResult() as { error: ErrorResponse };
|
||||
expect(result.error.code).toBe(-32603);
|
||||
expect(result.error.message).toBe('Internal error');
|
||||
expect(result.error.data).toEqual({ details: 'details' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('AgentSideConnection', () => {
|
||||
let mockAgent: Agent;
|
||||
|
||||
let toAgent: WritableStream<Uint8Array>;
|
||||
let fromAgent: ReadableStream<Uint8Array>;
|
||||
let agentSideConnection: AgentSideConnection;
|
||||
let connectionInstance: InstanceType<typeof mockConnectionConstructor>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
const initializeResponse = {
|
||||
agentCapabilities: { loadSession: true },
|
||||
authMethods: [],
|
||||
protocolVersion: 1,
|
||||
};
|
||||
const newSessionResponse = { sessionId: 'session-1' };
|
||||
const loadSessionResponse = { sessionId: 'session-1' };
|
||||
|
||||
mockAgent = {
|
||||
initialize: vi.fn().mockResolvedValue(initializeResponse),
|
||||
newSession: vi.fn().mockResolvedValue(newSessionResponse),
|
||||
loadSession: vi.fn().mockResolvedValue(loadSessionResponse),
|
||||
authenticate: vi.fn(),
|
||||
prompt: vi.fn(),
|
||||
cancel: vi.fn(),
|
||||
};
|
||||
|
||||
toAgent = new WritableStream<Uint8Array>();
|
||||
fromAgent = new ReadableStream<Uint8Array>();
|
||||
|
||||
agentSideConnection = new AgentSideConnection(
|
||||
(_client: Client) => mockAgent,
|
||||
toAgent,
|
||||
fromAgent,
|
||||
);
|
||||
|
||||
// Get the mocked Connection instance
|
||||
connectionInstance = mockConnectionConstructor.mock.results[0].value;
|
||||
});
|
||||
|
||||
it('should initialize Connection with the correct handler and streams', () => {
|
||||
expect(mockConnectionConstructor).toHaveBeenCalledTimes(1);
|
||||
expect(mockConnectionConstructor).toHaveBeenCalledWith(
|
||||
expect.any(Function),
|
||||
toAgent,
|
||||
fromAgent,
|
||||
);
|
||||
});
|
||||
|
||||
it('should call agent.initialize when Connection handler receives initialize method', async () => {
|
||||
const initializeParams = {
|
||||
clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } },
|
||||
protocolVersion: 1,
|
||||
};
|
||||
const initializeResponse = {
|
||||
agentCapabilities: { loadSession: true },
|
||||
authMethods: [],
|
||||
protocolVersion: 1,
|
||||
};
|
||||
const handler = mockConnectionConstructor.mock
|
||||
.calls[0][0]! as MethodHandler;
|
||||
const result = await handler('initialize', initializeParams);
|
||||
|
||||
expect(mockAgent.initialize).toHaveBeenCalledWith(initializeParams);
|
||||
expect(result).toEqual(initializeResponse);
|
||||
});
|
||||
|
||||
it('should call agent.newSession when Connection handler receives session_new method', async () => {
|
||||
const newSessionParams = { cwd: '/tmp', mcpServers: [] };
|
||||
const newSessionResponse = { sessionId: 'session-1' };
|
||||
const handler = mockConnectionConstructor.mock
|
||||
.calls[0][0]! as MethodHandler;
|
||||
const result = await handler('session/new', newSessionParams);
|
||||
|
||||
expect(mockAgent.newSession).toHaveBeenCalledWith(newSessionParams);
|
||||
expect(result).toEqual(newSessionResponse);
|
||||
});
|
||||
|
||||
it('should call agent.loadSession when Connection handler receives session_load method', async () => {
|
||||
const loadSessionParams = {
|
||||
cwd: '/tmp',
|
||||
mcpServers: [],
|
||||
sessionId: 'session-1',
|
||||
};
|
||||
const loadSessionResponse = { sessionId: 'session-1' };
|
||||
const handler = mockConnectionConstructor.mock
|
||||
.calls[0][0]! as MethodHandler;
|
||||
const result = await handler('session/load', loadSessionParams);
|
||||
|
||||
expect(mockAgent.loadSession).toHaveBeenCalledWith(loadSessionParams);
|
||||
expect(result).toEqual(loadSessionResponse);
|
||||
});
|
||||
|
||||
it('should throw methodNotFound if agent.loadSession is not implemented', async () => {
|
||||
mockAgent.loadSession = undefined; // Simulate not implemented
|
||||
const loadSessionParams = {
|
||||
cwd: '/tmp',
|
||||
mcpServers: [],
|
||||
sessionId: 'session-1',
|
||||
};
|
||||
const handler = mockConnectionConstructor.mock
|
||||
.calls[0][0]! as MethodHandler;
|
||||
await expect(handler('session/load', loadSessionParams)).rejects.toThrow(
|
||||
RequestError.methodNotFound().message,
|
||||
);
|
||||
});
|
||||
|
||||
it('should call agent.authenticate when Connection handler receives authenticate method', async () => {
|
||||
const authenticateParams = {
|
||||
methodId: 'test-auth-method',
|
||||
authMethod: {
|
||||
id: 'test-auth',
|
||||
name: 'Test Auth Method',
|
||||
description: 'A test authentication method',
|
||||
},
|
||||
};
|
||||
const handler = mockConnectionConstructor.mock
|
||||
.calls[0][0]! as MethodHandler;
|
||||
const result = await handler('authenticate', authenticateParams);
|
||||
|
||||
expect(mockAgent.authenticate).toHaveBeenCalledWith(authenticateParams);
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should call agent.prompt when Connection handler receives session_prompt method', async () => {
|
||||
const promptParams = {
|
||||
prompt: [{ type: 'text', text: 'hi' }],
|
||||
sessionId: 'session-1',
|
||||
};
|
||||
const promptResponse = {
|
||||
response: [{ type: 'text', text: 'hello' }],
|
||||
traceId: 'trace-1',
|
||||
};
|
||||
(mockAgent.prompt as Mock).mockResolvedValue(promptResponse);
|
||||
const handler = mockConnectionConstructor.mock
|
||||
.calls[0][0]! as MethodHandler;
|
||||
const result = await handler('session/prompt', promptParams);
|
||||
|
||||
expect(mockAgent.prompt).toHaveBeenCalledWith(promptParams);
|
||||
expect(result).toEqual(promptResponse);
|
||||
});
|
||||
|
||||
it('should call agent.cancel when Connection handler receives session_cancel method', async () => {
|
||||
const cancelParams = { sessionId: 'session-1' };
|
||||
const handler = mockConnectionConstructor.mock
|
||||
.calls[0][0]! as MethodHandler;
|
||||
const result = await handler('session/cancel', cancelParams);
|
||||
|
||||
expect(mockAgent.cancel).toHaveBeenCalledWith(cancelParams);
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw methodNotFound for unknown methods', async () => {
|
||||
const handler = mockConnectionConstructor.mock
|
||||
.calls[0][0]! as MethodHandler;
|
||||
await expect(handler('unknown_method', {})).rejects.toThrow(
|
||||
RequestError.methodNotFound().message,
|
||||
);
|
||||
});
|
||||
|
||||
it('should send sessionUpdate notification via connection', async () => {
|
||||
const params = {
|
||||
sessionId: '123',
|
||||
update: {
|
||||
sessionUpdate: 'user_message_chunk' as const,
|
||||
content: { type: 'text' as const, text: 'hello' },
|
||||
},
|
||||
};
|
||||
await agentSideConnection.sessionUpdate(params);
|
||||
});
|
||||
|
||||
it('should send requestPermission request via connection', async () => {
|
||||
const params = {
|
||||
sessionId: '123',
|
||||
toolCall: {
|
||||
toolCallId: 'tool-1',
|
||||
title: 'Test Tool',
|
||||
kind: 'other' as const,
|
||||
status: 'pending' as const,
|
||||
},
|
||||
options: [
|
||||
{
|
||||
optionId: 'option-1',
|
||||
name: 'Allow',
|
||||
kind: 'allow_once' as const,
|
||||
},
|
||||
],
|
||||
};
|
||||
const response = {
|
||||
outcome: { outcome: 'selected', optionId: 'option-1' },
|
||||
};
|
||||
(connectionInstance.sendRequest as Mock).mockResolvedValue(response);
|
||||
|
||||
const result = await agentSideConnection.requestPermission(params);
|
||||
expect(connectionInstance.sendRequest).toHaveBeenCalledWith(
|
||||
'session/request_permission',
|
||||
params,
|
||||
);
|
||||
expect(result).toEqual(response);
|
||||
});
|
||||
|
||||
it('should send readTextFile request via connection', async () => {
|
||||
const params = { path: '/a/b.txt', sessionId: 'session-1' };
|
||||
const response = { content: 'file content' };
|
||||
(connectionInstance.sendRequest as Mock).mockResolvedValue(response);
|
||||
|
||||
const result = await agentSideConnection.readTextFile(params);
|
||||
expect(connectionInstance.sendRequest).toHaveBeenCalledWith(
|
||||
'fs/read_text_file',
|
||||
params,
|
||||
);
|
||||
expect(result).toEqual(response);
|
||||
});
|
||||
|
||||
it('should send writeTextFile request via connection', async () => {
|
||||
const params = {
|
||||
path: '/a/b.txt',
|
||||
content: 'new content',
|
||||
sessionId: 'session-1',
|
||||
};
|
||||
const response = { success: true };
|
||||
(connectionInstance.sendRequest as Mock).mockResolvedValue(response);
|
||||
|
||||
const result = await agentSideConnection.writeTextFile(params);
|
||||
expect(connectionInstance.sendRequest).toHaveBeenCalledWith(
|
||||
'fs/write_text_file',
|
||||
params,
|
||||
);
|
||||
expect(result).toEqual(response);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -6,12 +6,12 @@
|
||||
|
||||
/* ACP defines a schema for a simple (experimental) JSON-RPC protocol that allows GUI applications to interact with agents. */
|
||||
|
||||
import { z } from 'zod';
|
||||
import * as schema from './schema.js';
|
||||
export * from './schema.js';
|
||||
|
||||
import type { WritableStream, ReadableStream } from 'node:stream/web';
|
||||
import { coreEvents } from '@google/gemini-cli-core';
|
||||
import { Connection, RequestError } from './connection.js';
|
||||
export { RequestError };
|
||||
|
||||
export class AgentSideConnection implements Client {
|
||||
#connection: Connection;
|
||||
@@ -108,236 +108,6 @@ export class AgentSideConnection implements Client {
|
||||
}
|
||||
}
|
||||
|
||||
type AnyMessage = AnyRequest | AnyResponse | AnyNotification;
|
||||
|
||||
type AnyRequest = {
|
||||
jsonrpc: '2.0';
|
||||
id: string | number;
|
||||
method: string;
|
||||
params?: unknown;
|
||||
};
|
||||
|
||||
type AnyResponse = {
|
||||
jsonrpc: '2.0';
|
||||
id: string | number;
|
||||
} & Result<unknown>;
|
||||
|
||||
type AnyNotification = {
|
||||
jsonrpc: '2.0';
|
||||
method: string;
|
||||
params?: unknown;
|
||||
};
|
||||
|
||||
type Result<T> =
|
||||
| {
|
||||
result: T;
|
||||
}
|
||||
| {
|
||||
error: ErrorResponse;
|
||||
};
|
||||
|
||||
type ErrorResponse = {
|
||||
code: number;
|
||||
message: string;
|
||||
data?: unknown;
|
||||
};
|
||||
|
||||
type PendingResponse = {
|
||||
resolve: (response: unknown) => void;
|
||||
reject: (error: ErrorResponse) => void;
|
||||
};
|
||||
|
||||
type MethodHandler = (method: string, params: unknown) => Promise<unknown>;
|
||||
|
||||
class Connection {
|
||||
#pendingResponses: Map<string | number, PendingResponse> = new Map();
|
||||
#nextRequestId: number = 0;
|
||||
#handler: MethodHandler;
|
||||
#peerInput: WritableStream<Uint8Array>;
|
||||
#writeQueue: Promise<void> = Promise.resolve();
|
||||
#textEncoder: TextEncoder;
|
||||
|
||||
constructor(
|
||||
handler: MethodHandler,
|
||||
peerInput: WritableStream<Uint8Array>,
|
||||
peerOutput: ReadableStream<Uint8Array>,
|
||||
) {
|
||||
this.#handler = handler;
|
||||
this.#peerInput = peerInput;
|
||||
this.#textEncoder = new TextEncoder();
|
||||
this.#receive(peerOutput);
|
||||
}
|
||||
|
||||
async #receive(output: ReadableStream<Uint8Array>) {
|
||||
let content = '';
|
||||
const decoder = new TextDecoder();
|
||||
for await (const chunk of output) {
|
||||
content += decoder.decode(chunk, { stream: true });
|
||||
const lines = content.split('\n');
|
||||
content = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmedLine = line.trim();
|
||||
|
||||
if (trimmedLine) {
|
||||
const message = JSON.parse(trimmedLine);
|
||||
this.#processMessage(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async #processMessage(message: AnyMessage) {
|
||||
if ('method' in message && 'id' in message) {
|
||||
// It's a request
|
||||
const response = await this.#tryCallHandler(
|
||||
message.method,
|
||||
message.params,
|
||||
);
|
||||
|
||||
await this.#sendMessage({
|
||||
jsonrpc: '2.0',
|
||||
id: message.id,
|
||||
...response,
|
||||
});
|
||||
} else if ('method' in message) {
|
||||
// It's a notification
|
||||
await this.#tryCallHandler(message.method, message.params);
|
||||
} else if ('id' in message) {
|
||||
// It's a response
|
||||
this.#handleResponse(message as AnyResponse);
|
||||
}
|
||||
}
|
||||
|
||||
async #tryCallHandler(
|
||||
method: string,
|
||||
params?: unknown,
|
||||
): Promise<Result<unknown>> {
|
||||
try {
|
||||
const result = await this.#handler(method, params);
|
||||
return { result: result ?? null };
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof RequestError) {
|
||||
return error.toResult();
|
||||
}
|
||||
|
||||
if (error instanceof z.ZodError) {
|
||||
return RequestError.invalidParams(
|
||||
JSON.stringify(error.format(), undefined, 2),
|
||||
).toResult();
|
||||
}
|
||||
|
||||
let details;
|
||||
|
||||
if (error instanceof Error) {
|
||||
details = error.message;
|
||||
} else if (
|
||||
typeof error === 'object' &&
|
||||
error != null &&
|
||||
'message' in error &&
|
||||
typeof error.message === 'string'
|
||||
) {
|
||||
details = error.message;
|
||||
}
|
||||
|
||||
return RequestError.internalError(details).toResult();
|
||||
}
|
||||
}
|
||||
|
||||
#handleResponse(response: AnyResponse) {
|
||||
const pendingResponse = this.#pendingResponses.get(response.id);
|
||||
if (pendingResponse) {
|
||||
if ('result' in response) {
|
||||
pendingResponse.resolve(response.result);
|
||||
} else if ('error' in response) {
|
||||
pendingResponse.reject(response.error);
|
||||
}
|
||||
this.#pendingResponses.delete(response.id);
|
||||
}
|
||||
}
|
||||
|
||||
async sendRequest<Req, Resp>(method: string, params?: Req): Promise<Resp> {
|
||||
const id = this.#nextRequestId++;
|
||||
const responsePromise = new Promise((resolve, reject) => {
|
||||
this.#pendingResponses.set(id, { resolve, reject });
|
||||
});
|
||||
await this.#sendMessage({ jsonrpc: '2.0', id, method, params });
|
||||
return responsePromise as Promise<Resp>;
|
||||
}
|
||||
|
||||
async sendNotification<N>(method: string, params?: N): Promise<void> {
|
||||
await this.#sendMessage({ jsonrpc: '2.0', method, params });
|
||||
}
|
||||
|
||||
async #sendMessage(json: AnyMessage) {
|
||||
const content = JSON.stringify(json) + '\n';
|
||||
this.#writeQueue = this.#writeQueue
|
||||
.then(async () => {
|
||||
const writer = this.#peerInput.getWriter();
|
||||
try {
|
||||
await writer.write(this.#textEncoder.encode(content));
|
||||
} finally {
|
||||
writer.releaseLock();
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
// Continue processing writes on error
|
||||
coreEvents.emitFeedback('error', 'ACP write error.', error);
|
||||
});
|
||||
return this.#writeQueue;
|
||||
}
|
||||
}
|
||||
|
||||
export class RequestError extends Error {
|
||||
data?: { details?: string };
|
||||
|
||||
constructor(
|
||||
public code: number,
|
||||
message: string,
|
||||
details?: string,
|
||||
) {
|
||||
super(message);
|
||||
this.name = 'RequestError';
|
||||
if (details) {
|
||||
this.data = { details };
|
||||
}
|
||||
}
|
||||
|
||||
static parseError(details?: string): RequestError {
|
||||
return new RequestError(-32700, 'Parse error', details);
|
||||
}
|
||||
|
||||
static invalidRequest(details?: string): RequestError {
|
||||
return new RequestError(-32600, 'Invalid request', details);
|
||||
}
|
||||
|
||||
static methodNotFound(details?: string): RequestError {
|
||||
return new RequestError(-32601, 'Method not found', details);
|
||||
}
|
||||
|
||||
static invalidParams(details?: string): RequestError {
|
||||
return new RequestError(-32602, 'Invalid params', details);
|
||||
}
|
||||
|
||||
static internalError(details?: string): RequestError {
|
||||
return new RequestError(-32603, 'Internal error', details);
|
||||
}
|
||||
|
||||
static authRequired(details?: string): RequestError {
|
||||
return new RequestError(-32000, 'Authentication required', details);
|
||||
}
|
||||
|
||||
toResult<T>(): Result<T> {
|
||||
return {
|
||||
error: {
|
||||
code: this.code,
|
||||
message: this.message,
|
||||
data: this.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export interface Client {
|
||||
requestPermission(
|
||||
params: schema.RequestPermissionRequest,
|
||||
|
||||
229
packages/cli/src/zed-integration/connection.ts
Normal file
229
packages/cli/src/zed-integration/connection.ts
Normal file
@@ -0,0 +1,229 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { z } from 'zod';
|
||||
import { coreEvents } from '@google/gemini-cli-core';
|
||||
import { type Result, type ErrorResponse } from './schema.js';
|
||||
import type { WritableStream, ReadableStream } from 'node:stream/web';
|
||||
|
||||
export class RequestError extends Error {
|
||||
data?: { details?: string };
|
||||
|
||||
constructor(
|
||||
public code: number,
|
||||
message: string,
|
||||
details?: string,
|
||||
) {
|
||||
super(message);
|
||||
this.name = 'RequestError';
|
||||
if (details) {
|
||||
this.data = { details };
|
||||
}
|
||||
}
|
||||
|
||||
static parseError(details?: string): RequestError {
|
||||
return new RequestError(-32700, 'Parse error', details);
|
||||
}
|
||||
|
||||
static invalidRequest(details?: string): RequestError {
|
||||
return new RequestError(-32600, 'Invalid request', details);
|
||||
}
|
||||
|
||||
static methodNotFound(details?: string): RequestError {
|
||||
return new RequestError(-32601, 'Method not found', details);
|
||||
}
|
||||
|
||||
static invalidParams(details?: string): RequestError {
|
||||
return new RequestError(-32602, 'Invalid params', details);
|
||||
}
|
||||
|
||||
static internalError(details?: string): RequestError {
|
||||
return new RequestError(-32603, 'Internal error', details);
|
||||
}
|
||||
|
||||
static authRequired(details?: string): RequestError {
|
||||
return new RequestError(-32000, 'Authentication required', details);
|
||||
}
|
||||
|
||||
toResult<T>(): Result<T> {
|
||||
return {
|
||||
error: {
|
||||
code: this.code,
|
||||
message: this.message,
|
||||
data: this.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
type AnyMessage = AnyRequest | AnyResponse | AnyNotification;
|
||||
|
||||
type AnyRequest = {
|
||||
jsonrpc: '2.0';
|
||||
id: string | number;
|
||||
method: string;
|
||||
params?: unknown;
|
||||
};
|
||||
|
||||
type AnyResponse = {
|
||||
jsonrpc: '2.0';
|
||||
id: string | number;
|
||||
} & Result<unknown>;
|
||||
|
||||
type AnyNotification = {
|
||||
jsonrpc: '2.0';
|
||||
method: string;
|
||||
params?: unknown;
|
||||
};
|
||||
|
||||
type PendingResponse = {
|
||||
resolve: (response: unknown) => void;
|
||||
reject: (error: ErrorResponse) => void;
|
||||
};
|
||||
|
||||
export type MethodHandler = (
|
||||
method: string,
|
||||
params: unknown,
|
||||
) => Promise<unknown>;
|
||||
|
||||
export class Connection {
|
||||
#pendingResponses: Map<string | number, PendingResponse> = new Map();
|
||||
#nextRequestId: number = 0;
|
||||
#handler: MethodHandler;
|
||||
#peerInput: WritableStream<Uint8Array>;
|
||||
#writeQueue: Promise<void> = Promise.resolve();
|
||||
#textEncoder: TextEncoder;
|
||||
|
||||
constructor(
|
||||
handler: MethodHandler,
|
||||
peerInput: WritableStream<Uint8Array>,
|
||||
peerOutput: ReadableStream<Uint8Array>,
|
||||
) {
|
||||
this.#handler = handler;
|
||||
this.#peerInput = peerInput;
|
||||
this.#textEncoder = new TextEncoder();
|
||||
this.#receive(peerOutput);
|
||||
}
|
||||
|
||||
async #receive(output: ReadableStream<Uint8Array>) {
|
||||
let content = '';
|
||||
const decoder = new TextDecoder();
|
||||
for await (const chunk of output) {
|
||||
content += decoder.decode(chunk, { stream: true });
|
||||
const lines = content.split('\n');
|
||||
content = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmedLine = line.trim();
|
||||
|
||||
if (trimmedLine) {
|
||||
const message = JSON.parse(trimmedLine);
|
||||
this.#processMessage(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async #processMessage(message: AnyMessage) {
|
||||
if ('method' in message && 'id' in message) {
|
||||
// It's a request
|
||||
const response = await this.#tryCallHandler(
|
||||
message.method,
|
||||
message.params,
|
||||
);
|
||||
|
||||
await this.#sendMessage({
|
||||
jsonrpc: '2.0',
|
||||
id: message.id,
|
||||
...response,
|
||||
});
|
||||
} else if ('method' in message) {
|
||||
// It's a notification
|
||||
await this.#tryCallHandler(message.method, message.params);
|
||||
} else if ('id' in message) {
|
||||
// It's a response
|
||||
this.#handleResponse(message as AnyResponse);
|
||||
}
|
||||
}
|
||||
|
||||
async #tryCallHandler(
|
||||
method: string,
|
||||
params?: unknown,
|
||||
): Promise<Result<unknown>> {
|
||||
try {
|
||||
const result = await this.#handler(method, params);
|
||||
return { result: result ?? null };
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof RequestError) {
|
||||
return error.toResult();
|
||||
}
|
||||
|
||||
if (error instanceof z.ZodError) {
|
||||
return RequestError.invalidParams(
|
||||
JSON.stringify(error.format(), undefined, 2),
|
||||
).toResult();
|
||||
}
|
||||
|
||||
let details;
|
||||
|
||||
if (error instanceof Error) {
|
||||
details = error.message;
|
||||
} else if (
|
||||
typeof error === 'object' &&
|
||||
error != null &&
|
||||
'message' in error &&
|
||||
typeof error.message === 'string'
|
||||
) {
|
||||
details = error.message;
|
||||
}
|
||||
|
||||
return RequestError.internalError(details).toResult();
|
||||
}
|
||||
}
|
||||
|
||||
#handleResponse(response: AnyResponse) {
|
||||
const pendingResponse = this.#pendingResponses.get(response.id);
|
||||
if (pendingResponse) {
|
||||
if ('result' in response) {
|
||||
pendingResponse.resolve(response.result);
|
||||
} else if ('error' in response) {
|
||||
pendingResponse.reject(response.error);
|
||||
}
|
||||
this.#pendingResponses.delete(response.id);
|
||||
}
|
||||
}
|
||||
|
||||
async sendRequest<Req, Resp>(method: string, params?: Req): Promise<Resp> {
|
||||
const id = this.#nextRequestId++;
|
||||
const responsePromise = new Promise((resolve, reject) => {
|
||||
this.#pendingResponses.set(id, { resolve, reject });
|
||||
});
|
||||
await this.#sendMessage({ jsonrpc: '2.0', id, method, params });
|
||||
return responsePromise as Promise<Resp>;
|
||||
}
|
||||
|
||||
async sendNotification<N>(method: string, params?: N): Promise<void> {
|
||||
await this.#sendMessage({ jsonrpc: '2.0', method, params });
|
||||
}
|
||||
|
||||
async #sendMessage(json: AnyMessage) {
|
||||
const content = JSON.stringify(json) + '\n';
|
||||
this.#writeQueue = this.#writeQueue
|
||||
.then(async () => {
|
||||
const writer = this.#peerInput.getWriter();
|
||||
try {
|
||||
await writer.write(this.#textEncoder.encode(content));
|
||||
} finally {
|
||||
writer.releaseLock();
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
// Continue processing writes on error
|
||||
coreEvents.emitFeedback('error', 'ACP write error.', error);
|
||||
});
|
||||
return this.#writeQueue;
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,12 @@ export const CLIENT_METHODS = {
|
||||
|
||||
export const PROTOCOL_VERSION = 1;
|
||||
|
||||
export const authMethodSchema = z.object({
|
||||
description: z.string().nullable(),
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
});
|
||||
|
||||
export type WriteTextFileRequest = z.infer<typeof writeTextFileRequestSchema>;
|
||||
|
||||
export type ReadTextFileRequest = z.infer<typeof readTextFileRequestSchema>;
|
||||
@@ -128,6 +134,20 @@ export type AgentRequest = z.infer<typeof agentRequestSchema>;
|
||||
|
||||
export type AgentNotification = z.infer<typeof agentNotificationSchema>;
|
||||
|
||||
export type Result<T> =
|
||||
| {
|
||||
result: T;
|
||||
}
|
||||
| {
|
||||
error: ErrorResponse;
|
||||
};
|
||||
|
||||
export type ErrorResponse = {
|
||||
code: number;
|
||||
message: string;
|
||||
data?: unknown;
|
||||
};
|
||||
|
||||
export const writeTextFileRequestSchema = z.object({
|
||||
content: z.string(),
|
||||
path: z.string(),
|
||||
@@ -203,6 +223,7 @@ export const cancelNotificationSchema = z.object({
|
||||
|
||||
export const authenticateRequestSchema = z.object({
|
||||
methodId: z.string(),
|
||||
authMethod: authMethodSchema,
|
||||
});
|
||||
|
||||
export const authenticateResponseSchema = z.null();
|
||||
@@ -283,12 +304,6 @@ export const agentCapabilitiesSchema = z.object({
|
||||
promptCapabilities: promptCapabilitiesSchema.optional(),
|
||||
});
|
||||
|
||||
export const authMethodSchema = z.object({
|
||||
description: z.string().nullable(),
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
});
|
||||
|
||||
export const clientResponseSchema = z.union([
|
||||
writeTextFileResponseSchema,
|
||||
readTextFileResponseSchema,
|
||||
|
||||
Reference in New Issue
Block a user