mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-15 16:41:11 -07:00
feat: Implement slash command handling in ACP for /memory,/init,/extensions and /restore (#20528)
This commit is contained in:
@@ -93,6 +93,7 @@ describe('GeminiAgent Session Resume', () => {
|
||||
},
|
||||
getApprovalMode: vi.fn().mockReturnValue('default'),
|
||||
isPlanEnabled: vi.fn().mockReturnValue(false),
|
||||
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
|
||||
} as unknown as Mocked<Config>;
|
||||
mockSettings = {
|
||||
merged: {
|
||||
|
||||
30
packages/cli/src/zed-integration/commandHandler.test.ts
Normal file
30
packages/cli/src/zed-integration/commandHandler.test.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { CommandHandler } from './commandHandler.js';
|
||||
import { describe, it, expect } from 'vitest';
|
||||
|
||||
describe('CommandHandler', () => {
|
||||
it('parses commands correctly', () => {
|
||||
const handler = new CommandHandler();
|
||||
// @ts-expect-error - testing private method
|
||||
const parse = (query: string) => handler.parseSlashCommand(query);
|
||||
|
||||
const memShow = parse('/memory show');
|
||||
expect(memShow.commandToExecute?.name).toBe('memory show');
|
||||
expect(memShow.args).toBe('');
|
||||
|
||||
const memAdd = parse('/memory add hello world');
|
||||
expect(memAdd.commandToExecute?.name).toBe('memory add');
|
||||
expect(memAdd.args).toBe('hello world');
|
||||
|
||||
const extList = parse('/extensions list');
|
||||
expect(extList.commandToExecute?.name).toBe('extensions list');
|
||||
|
||||
const init = parse('/init');
|
||||
expect(init.commandToExecute?.name).toBe('init');
|
||||
});
|
||||
});
|
||||
134
packages/cli/src/zed-integration/commandHandler.ts
Normal file
134
packages/cli/src/zed-integration/commandHandler.ts
Normal file
@@ -0,0 +1,134 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Command, CommandContext } from './commands/types.js';
|
||||
import { CommandRegistry } from './commands/commandRegistry.js';
|
||||
import { MemoryCommand } from './commands/memory.js';
|
||||
import { ExtensionsCommand } from './commands/extensions.js';
|
||||
import { InitCommand } from './commands/init.js';
|
||||
import { RestoreCommand } from './commands/restore.js';
|
||||
|
||||
export class CommandHandler {
|
||||
private registry: CommandRegistry;
|
||||
|
||||
constructor() {
|
||||
this.registry = CommandHandler.createRegistry();
|
||||
}
|
||||
|
||||
private static createRegistry(): CommandRegistry {
|
||||
const registry = new CommandRegistry();
|
||||
registry.register(new MemoryCommand());
|
||||
registry.register(new ExtensionsCommand());
|
||||
registry.register(new InitCommand());
|
||||
registry.register(new RestoreCommand());
|
||||
return registry;
|
||||
}
|
||||
|
||||
getAvailableCommands(): Array<{ name: string; description: string }> {
|
||||
return this.registry.getAllCommands().map((cmd) => ({
|
||||
name: cmd.name,
|
||||
description: cmd.description,
|
||||
}));
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses and executes a command string if it matches a registered command.
|
||||
* Returns true if a command was handled, false otherwise.
|
||||
*/
|
||||
async handleCommand(
|
||||
commandText: string,
|
||||
context: CommandContext,
|
||||
): Promise<boolean> {
|
||||
const { commandToExecute, args } = this.parseSlashCommand(commandText);
|
||||
|
||||
if (commandToExecute) {
|
||||
await this.runCommand(commandToExecute, args, context);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private async runCommand(
|
||||
commandToExecute: Command,
|
||||
args: string,
|
||||
context: CommandContext,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const result = await commandToExecute.execute(
|
||||
context,
|
||||
args ? args.split(/\s+/) : [],
|
||||
);
|
||||
|
||||
let messageContent = '';
|
||||
if (typeof result.data === 'string') {
|
||||
messageContent = result.data;
|
||||
} else if (
|
||||
typeof result.data === 'object' &&
|
||||
result.data !== null &&
|
||||
'content' in result.data
|
||||
) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
messageContent = (result.data as Record<string, any>)[
|
||||
'content'
|
||||
] as string;
|
||||
} else {
|
||||
messageContent = JSON.stringify(result.data, null, 2);
|
||||
}
|
||||
|
||||
await context.sendMessage(messageContent);
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
await context.sendMessage(`Error: ${errorMessage}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a raw slash command string into its matching headless command and arguments.
|
||||
* Mirrors `packages/cli/src/utils/commands.ts` logic.
|
||||
*/
|
||||
private parseSlashCommand(query: string): {
|
||||
commandToExecute: Command | undefined;
|
||||
args: string;
|
||||
} {
|
||||
const trimmed = query.trim();
|
||||
const parts = trimmed.substring(1).trim().split(/\s+/);
|
||||
const commandPath = parts.filter((p) => p);
|
||||
|
||||
let currentCommands = this.registry.getAllCommands();
|
||||
let commandToExecute: Command | undefined;
|
||||
let pathIndex = 0;
|
||||
|
||||
for (const part of commandPath) {
|
||||
const foundCommand = currentCommands.find((cmd) => {
|
||||
const expectedName = commandPath.slice(0, pathIndex + 1).join(' ');
|
||||
return (
|
||||
cmd.name === part ||
|
||||
cmd.name === expectedName ||
|
||||
cmd.aliases?.includes(part) ||
|
||||
cmd.aliases?.includes(expectedName)
|
||||
);
|
||||
});
|
||||
|
||||
if (foundCommand) {
|
||||
commandToExecute = foundCommand;
|
||||
pathIndex++;
|
||||
if (foundCommand.subCommands) {
|
||||
currentCommands = foundCommand.subCommands;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const args = parts.slice(pathIndex).join(' ');
|
||||
|
||||
return { commandToExecute, args };
|
||||
}
|
||||
}
|
||||
33
packages/cli/src/zed-integration/commands/commandRegistry.ts
Normal file
33
packages/cli/src/zed-integration/commands/commandRegistry.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { debugLogger } from '@google/gemini-cli-core';
|
||||
import type { Command } from './types.js';
|
||||
|
||||
export class CommandRegistry {
|
||||
private readonly commands = new Map<string, Command>();
|
||||
|
||||
register(command: Command) {
|
||||
if (this.commands.has(command.name)) {
|
||||
debugLogger.warn(`Command ${command.name} already registered. Skipping.`);
|
||||
return;
|
||||
}
|
||||
|
||||
this.commands.set(command.name, command);
|
||||
|
||||
for (const subCommand of command.subCommands ?? []) {
|
||||
this.register(subCommand);
|
||||
}
|
||||
}
|
||||
|
||||
get(commandName: string): Command | undefined {
|
||||
return this.commands.get(commandName);
|
||||
}
|
||||
|
||||
getAllCommands(): Command[] {
|
||||
return [...this.commands.values()];
|
||||
}
|
||||
}
|
||||
428
packages/cli/src/zed-integration/commands/extensions.ts
Normal file
428
packages/cli/src/zed-integration/commands/extensions.ts
Normal file
@@ -0,0 +1,428 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { listExtensions } from '@google/gemini-cli-core';
|
||||
import { SettingScope } from '../../config/settings.js';
|
||||
import {
|
||||
ExtensionManager,
|
||||
inferInstallMetadata,
|
||||
} from '../../config/extension-manager.js';
|
||||
import { getErrorMessage } from '../../utils/errors.js';
|
||||
import { McpServerEnablementManager } from '../../config/mcp/mcpServerEnablement.js';
|
||||
import { stat } from 'node:fs/promises';
|
||||
import type {
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandExecutionResponse,
|
||||
} from './types.js';
|
||||
import type { Config } from '@google/gemini-cli-core';
|
||||
|
||||
export class ExtensionsCommand implements Command {
|
||||
readonly name = 'extensions';
|
||||
readonly description = 'Manage extensions.';
|
||||
readonly subCommands = [
|
||||
new ListExtensionsCommand(),
|
||||
new ExploreExtensionsCommand(),
|
||||
new EnableExtensionCommand(),
|
||||
new DisableExtensionCommand(),
|
||||
new InstallExtensionCommand(),
|
||||
new LinkExtensionCommand(),
|
||||
new UninstallExtensionCommand(),
|
||||
new RestartExtensionCommand(),
|
||||
new UpdateExtensionCommand(),
|
||||
];
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
return new ListExtensionsCommand().execute(context, _);
|
||||
}
|
||||
}
|
||||
|
||||
export class ListExtensionsCommand implements Command {
|
||||
readonly name = 'extensions list';
|
||||
readonly description = 'Lists all installed extensions.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const extensions = listExtensions(context.config);
|
||||
const data = extensions.length ? extensions : 'No extensions installed.';
|
||||
|
||||
return { name: this.name, data };
|
||||
}
|
||||
}
|
||||
|
||||
export class ExploreExtensionsCommand implements Command {
|
||||
readonly name = 'extensions explore';
|
||||
readonly description = 'Explore available extensions.';
|
||||
|
||||
async execute(
|
||||
_context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const extensionsUrl = 'https://geminicli.com/extensions/';
|
||||
return {
|
||||
name: this.name,
|
||||
data: `View or install available extensions at ${extensionsUrl}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
function getEnableDisableContext(
|
||||
config: Config,
|
||||
args: string[],
|
||||
invocationName: string,
|
||||
) {
|
||||
const extensionManager = config.getExtensionLoader();
|
||||
if (!(extensionManager instanceof ExtensionManager)) {
|
||||
return {
|
||||
error: `Cannot ${invocationName} extensions in this environment.`,
|
||||
};
|
||||
}
|
||||
|
||||
if (args.length === 0) {
|
||||
return {
|
||||
error: `Usage: /extensions ${invocationName} <extension> [--scope=<user|workspace|session>]`,
|
||||
};
|
||||
}
|
||||
|
||||
let scope = SettingScope.User;
|
||||
if (args.includes('--scope=workspace') || args.includes('workspace')) {
|
||||
scope = SettingScope.Workspace;
|
||||
} else if (args.includes('--scope=session') || args.includes('session')) {
|
||||
scope = SettingScope.Session;
|
||||
}
|
||||
|
||||
const name = args.filter(
|
||||
(a) =>
|
||||
!a.startsWith('--scope') && !['user', 'workspace', 'session'].includes(a),
|
||||
)[0];
|
||||
|
||||
let names: string[] = [];
|
||||
if (name === '--all') {
|
||||
let extensions = extensionManager.getExtensions();
|
||||
if (invocationName === 'enable') {
|
||||
extensions = extensions.filter((ext) => !ext.isActive);
|
||||
}
|
||||
if (invocationName === 'disable') {
|
||||
extensions = extensions.filter((ext) => ext.isActive);
|
||||
}
|
||||
names = extensions.map((ext) => ext.name);
|
||||
} else if (name) {
|
||||
names = [name];
|
||||
} else {
|
||||
return { error: 'No extension name provided.' };
|
||||
}
|
||||
|
||||
return { extensionManager, names, scope };
|
||||
}
|
||||
|
||||
export class EnableExtensionCommand implements Command {
|
||||
readonly name = 'extensions enable';
|
||||
readonly description = 'Enable an extension.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const enableContext = getEnableDisableContext(
|
||||
context.config,
|
||||
args,
|
||||
'enable',
|
||||
);
|
||||
if ('error' in enableContext) {
|
||||
return { name: this.name, data: enableContext.error };
|
||||
}
|
||||
|
||||
const { names, scope, extensionManager } = enableContext;
|
||||
const output: string[] = [];
|
||||
|
||||
for (const name of names) {
|
||||
try {
|
||||
await extensionManager.enableExtension(name, scope);
|
||||
output.push(`Extension "${name}" enabled for scope "${scope}".`);
|
||||
|
||||
const extension = extensionManager
|
||||
.getExtensions()
|
||||
.find((e) => e.name === name);
|
||||
|
||||
if (extension?.mcpServers) {
|
||||
const mcpEnablementManager = McpServerEnablementManager.getInstance();
|
||||
const mcpClientManager = context.config.getMcpClientManager();
|
||||
const enabledServers = await mcpEnablementManager.autoEnableServers(
|
||||
Object.keys(extension.mcpServers),
|
||||
);
|
||||
|
||||
if (mcpClientManager && enabledServers.length > 0) {
|
||||
const restartPromises = enabledServers.map((serverName) =>
|
||||
mcpClientManager.restartServer(serverName).catch((error) => {
|
||||
output.push(
|
||||
`Failed to restart MCP server '${serverName}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
}),
|
||||
);
|
||||
await Promise.all(restartPromises);
|
||||
output.push(`Re-enabled MCP servers: ${enabledServers.join(', ')}`);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
output.push(`Failed to enable "${name}": ${getErrorMessage(e)}`);
|
||||
}
|
||||
}
|
||||
|
||||
return { name: this.name, data: output.join('\n') || 'No action taken.' };
|
||||
}
|
||||
}
|
||||
|
||||
export class DisableExtensionCommand implements Command {
|
||||
readonly name = 'extensions disable';
|
||||
readonly description = 'Disable an extension.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const enableContext = getEnableDisableContext(
|
||||
context.config,
|
||||
args,
|
||||
'disable',
|
||||
);
|
||||
if ('error' in enableContext) {
|
||||
return { name: this.name, data: enableContext.error };
|
||||
}
|
||||
|
||||
const { names, scope, extensionManager } = enableContext;
|
||||
const output: string[] = [];
|
||||
|
||||
for (const name of names) {
|
||||
try {
|
||||
await extensionManager.disableExtension(name, scope);
|
||||
output.push(`Extension "${name}" disabled for scope "${scope}".`);
|
||||
} catch (e) {
|
||||
output.push(`Failed to disable "${name}": ${getErrorMessage(e)}`);
|
||||
}
|
||||
}
|
||||
|
||||
return { name: this.name, data: output.join('\n') || 'No action taken.' };
|
||||
}
|
||||
}
|
||||
|
||||
export class InstallExtensionCommand implements Command {
|
||||
readonly name = 'extensions install';
|
||||
readonly description = 'Install an extension from a git repo or local path.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const extensionLoader = context.config.getExtensionLoader();
|
||||
if (!(extensionLoader instanceof ExtensionManager)) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'Cannot install extensions in this environment.',
|
||||
};
|
||||
}
|
||||
|
||||
const source = args.join(' ').trim();
|
||||
if (!source) {
|
||||
return { name: this.name, data: `Usage: /extensions install <source>` };
|
||||
}
|
||||
|
||||
if (/[;&|`'"]/.test(source)) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Invalid source: contains disallowed characters.`,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const installMetadata = await inferInstallMetadata(source);
|
||||
const extension =
|
||||
await extensionLoader.installOrUpdateExtension(installMetadata);
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Extension "${extension.name}" installed successfully.`,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Failed to install extension from "${source}": ${getErrorMessage(error)}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class LinkExtensionCommand implements Command {
|
||||
readonly name = 'extensions link';
|
||||
readonly description = 'Link an extension from a local path.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const extensionLoader = context.config.getExtensionLoader();
|
||||
if (!(extensionLoader instanceof ExtensionManager)) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'Cannot link extensions in this environment.',
|
||||
};
|
||||
}
|
||||
|
||||
const sourceFilepath = args.join(' ').trim();
|
||||
if (!sourceFilepath) {
|
||||
return { name: this.name, data: `Usage: /extensions link <source>` };
|
||||
}
|
||||
|
||||
try {
|
||||
await stat(sourceFilepath);
|
||||
} catch (_error) {
|
||||
return { name: this.name, data: `Invalid source: ${sourceFilepath}` };
|
||||
}
|
||||
|
||||
try {
|
||||
const extension = await extensionLoader.installOrUpdateExtension({
|
||||
source: sourceFilepath,
|
||||
type: 'link',
|
||||
});
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Extension "${extension.name}" linked successfully.`,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Failed to link extension: ${getErrorMessage(error)}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class UninstallExtensionCommand implements Command {
|
||||
readonly name = 'extensions uninstall';
|
||||
readonly description = 'Uninstall an extension.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const extensionLoader = context.config.getExtensionLoader();
|
||||
if (!(extensionLoader instanceof ExtensionManager)) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'Cannot uninstall extensions in this environment.',
|
||||
};
|
||||
}
|
||||
|
||||
const name = args.join(' ').trim();
|
||||
if (!name) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Usage: /extensions uninstall <extension-name>`,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
await extensionLoader.uninstallExtension(name, false);
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Extension "${name}" uninstalled successfully.`,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Failed to uninstall extension "${name}": ${getErrorMessage(error)}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class RestartExtensionCommand implements Command {
|
||||
readonly name = 'extensions restart';
|
||||
readonly description = 'Restart an extension.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const extensionLoader = context.config.getExtensionLoader();
|
||||
if (!(extensionLoader instanceof ExtensionManager)) {
|
||||
return { name: this.name, data: 'Cannot restart extensions.' };
|
||||
}
|
||||
|
||||
const all = args.includes('--all');
|
||||
const names = all ? null : args.filter((a) => !!a);
|
||||
|
||||
if (!all && names?.length === 0) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'Usage: /extensions restart <extension-names>|--all',
|
||||
};
|
||||
}
|
||||
|
||||
let extensionsToRestart = extensionLoader
|
||||
.getExtensions()
|
||||
.filter((e) => e.isActive);
|
||||
if (names) {
|
||||
extensionsToRestart = extensionsToRestart.filter((e) =>
|
||||
names.includes(e.name),
|
||||
);
|
||||
}
|
||||
|
||||
if (extensionsToRestart.length === 0) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'No active extensions matched the request.',
|
||||
};
|
||||
}
|
||||
|
||||
const output: string[] = [];
|
||||
for (const extension of extensionsToRestart) {
|
||||
try {
|
||||
await extensionLoader.restartExtension(extension);
|
||||
output.push(`Restarted "${extension.name}".`);
|
||||
} catch (e) {
|
||||
output.push(
|
||||
`Failed to restart "${extension.name}": ${getErrorMessage(e)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return { name: this.name, data: output.join('\n') };
|
||||
}
|
||||
}
|
||||
|
||||
export class UpdateExtensionCommand implements Command {
|
||||
readonly name = 'extensions update';
|
||||
readonly description = 'Update an extension.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const extensionLoader = context.config.getExtensionLoader();
|
||||
if (!(extensionLoader instanceof ExtensionManager)) {
|
||||
return { name: this.name, data: 'Cannot update extensions.' };
|
||||
}
|
||||
|
||||
const all = args.includes('--all');
|
||||
const names = all ? null : args.filter((a) => !!a);
|
||||
|
||||
if (!all && names?.length === 0) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'Usage: /extensions update <extension-names>|--all',
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'Headless extension updating requires internal UI dispatches. Please use `gemini extensions update` directly in the terminal.',
|
||||
};
|
||||
}
|
||||
}
|
||||
62
packages/cli/src/zed-integration/commands/init.ts
Normal file
62
packages/cli/src/zed-integration/commands/init.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as fs from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import { performInit } from '@google/gemini-cli-core';
|
||||
import type {
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandExecutionResponse,
|
||||
} from './types.js';
|
||||
|
||||
export class InitCommand implements Command {
|
||||
name = 'init';
|
||||
description = 'Analyzes the project and creates a tailored GEMINI.md file';
|
||||
requiresWorkspace = true;
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
_args: string[] = [],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const targetDir = context.config.getTargetDir();
|
||||
if (!targetDir) {
|
||||
throw new Error('Command requires a workspace.');
|
||||
}
|
||||
|
||||
const geminiMdPath = path.join(targetDir, 'GEMINI.md');
|
||||
const result = performInit(fs.existsSync(geminiMdPath));
|
||||
|
||||
switch (result.type) {
|
||||
case 'message':
|
||||
return {
|
||||
name: this.name,
|
||||
data: result,
|
||||
};
|
||||
case 'submit_prompt':
|
||||
fs.writeFileSync(geminiMdPath, '', 'utf8');
|
||||
|
||||
if (typeof result.content !== 'string') {
|
||||
throw new Error('Init command content must be a string.');
|
||||
}
|
||||
|
||||
// Inform the user since we can't trigger the UI-based interactive agent loop here directly.
|
||||
// We output the prompt text they can use to re-trigger the generation manually,
|
||||
// or just seed the GEMINI.md file as we've done above.
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: `A template GEMINI.md has been created at ${geminiMdPath}.\n\nTo populate it with project context, you can run the following prompt in a new chat:\n\n${result.content}`,
|
||||
},
|
||||
};
|
||||
|
||||
default:
|
||||
throw new Error('Unknown result type from performInit');
|
||||
}
|
||||
}
|
||||
}
|
||||
121
packages/cli/src/zed-integration/commands/memory.ts
Normal file
121
packages/cli/src/zed-integration/commands/memory.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
addMemory,
|
||||
listMemoryFiles,
|
||||
refreshMemory,
|
||||
showMemory,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type {
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandExecutionResponse,
|
||||
} from './types.js';
|
||||
|
||||
const DEFAULT_SANITIZATION_CONFIG = {
|
||||
allowedEnvironmentVariables: [],
|
||||
blockedEnvironmentVariables: [],
|
||||
enableEnvironmentVariableRedaction: false,
|
||||
};
|
||||
|
||||
export class MemoryCommand implements Command {
|
||||
readonly name = 'memory';
|
||||
readonly description = 'Manage memory.';
|
||||
readonly subCommands = [
|
||||
new ShowMemoryCommand(),
|
||||
new RefreshMemoryCommand(),
|
||||
new ListMemoryCommand(),
|
||||
new AddMemoryCommand(),
|
||||
];
|
||||
readonly requiresWorkspace = true;
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
return new ShowMemoryCommand().execute(context, _);
|
||||
}
|
||||
}
|
||||
|
||||
export class ShowMemoryCommand implements Command {
|
||||
readonly name = 'memory show';
|
||||
readonly description = 'Shows the current memory contents.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const result = showMemory(context.config);
|
||||
return { name: this.name, data: result.content };
|
||||
}
|
||||
}
|
||||
|
||||
export class RefreshMemoryCommand implements Command {
|
||||
readonly name = 'memory refresh';
|
||||
readonly aliases = ['memory reload'];
|
||||
readonly description = 'Refreshes the memory from the source.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const result = await refreshMemory(context.config);
|
||||
return { name: this.name, data: result.content };
|
||||
}
|
||||
}
|
||||
|
||||
export class ListMemoryCommand implements Command {
|
||||
readonly name = 'memory list';
|
||||
readonly description = 'Lists the paths of the GEMINI.md files in use.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const result = listMemoryFiles(context.config);
|
||||
return { name: this.name, data: result.content };
|
||||
}
|
||||
}
|
||||
|
||||
export class AddMemoryCommand implements Command {
|
||||
readonly name = 'memory add';
|
||||
readonly description = 'Add content to the memory.';
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const textToAdd = args.join(' ').trim();
|
||||
const result = addMemory(textToAdd);
|
||||
if (result.type === 'message') {
|
||||
return { name: this.name, data: result.content };
|
||||
}
|
||||
|
||||
const toolRegistry = context.config.getToolRegistry();
|
||||
const tool = toolRegistry.getTool(result.toolName);
|
||||
if (tool) {
|
||||
const abortController = new AbortController();
|
||||
const signal = abortController.signal;
|
||||
|
||||
await context.sendMessage(`Saving memory via ${result.toolName}...`);
|
||||
|
||||
await tool.buildAndExecute(result.toolArgs, signal, undefined, {
|
||||
sanitizationConfig: DEFAULT_SANITIZATION_CONFIG,
|
||||
});
|
||||
await refreshMemory(context.config);
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Added memory: "${textToAdd}"`,
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Error: Tool ${result.toolName} not found.`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
178
packages/cli/src/zed-integration/commands/restore.ts
Normal file
178
packages/cli/src/zed-integration/commands/restore.ts
Normal file
@@ -0,0 +1,178 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
getCheckpointInfoList,
|
||||
getToolCallDataSchema,
|
||||
isNodeError,
|
||||
performRestore,
|
||||
} from '@google/gemini-cli-core';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import type {
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandExecutionResponse,
|
||||
} from './types.js';
|
||||
|
||||
export class RestoreCommand implements Command {
|
||||
readonly name = 'restore';
|
||||
readonly description =
|
||||
'Restore to a previous checkpoint, or list available checkpoints to restore. This will reset the conversation and file history to the state it was in when the checkpoint was created';
|
||||
readonly requiresWorkspace = true;
|
||||
readonly subCommands = [new ListCheckpointsCommand()];
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const { config, git: gitService } = context;
|
||||
const argsStr = args.join(' ');
|
||||
|
||||
try {
|
||||
if (!argsStr) {
|
||||
return await new ListCheckpointsCommand().execute(context);
|
||||
}
|
||||
|
||||
if (!config.getCheckpointingEnabled()) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'Checkpointing is not enabled. Please enable it in your settings (`general.checkpointing.enabled: true`) to use /restore.',
|
||||
};
|
||||
}
|
||||
|
||||
const selectedFile = argsStr.endsWith('.json')
|
||||
? argsStr
|
||||
: `${argsStr}.json`;
|
||||
|
||||
const checkpointDir = config.storage.getProjectTempCheckpointsDir();
|
||||
const filePath = path.join(checkpointDir, selectedFile);
|
||||
|
||||
let data: string;
|
||||
try {
|
||||
data = await fs.readFile(filePath, 'utf-8');
|
||||
} catch (error) {
|
||||
if (isNodeError(error) && error.code === 'ENOENT') {
|
||||
return {
|
||||
name: this.name,
|
||||
data: `File not found: ${selectedFile}`,
|
||||
};
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
||||
const toolCallData = JSON.parse(data);
|
||||
const ToolCallDataSchema = getToolCallDataSchema();
|
||||
const parseResult = ToolCallDataSchema.safeParse(toolCallData);
|
||||
|
||||
if (!parseResult.success) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'Checkpoint file is invalid or corrupted.',
|
||||
};
|
||||
}
|
||||
|
||||
const restoreResultGenerator = performRestore(
|
||||
parseResult.data,
|
||||
gitService,
|
||||
);
|
||||
|
||||
const restoreResult = [];
|
||||
for await (const result of restoreResultGenerator) {
|
||||
restoreResult.push(result);
|
||||
}
|
||||
|
||||
// Format the result nicely since Zed just dumps data
|
||||
const formattedResult = restoreResult
|
||||
.map((r) => {
|
||||
if (r.type === 'message') {
|
||||
return `[${r.messageType.toUpperCase()}] ${r.content}`;
|
||||
} else if (r.type === 'load_history') {
|
||||
return `Loaded history with ${r.clientHistory.length} messages.`;
|
||||
}
|
||||
return `Restored: ${JSON.stringify(r)}`;
|
||||
})
|
||||
.join('\n');
|
||||
|
||||
return {
|
||||
name: this.name,
|
||||
data: formattedResult,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: `An unexpected error occurred during restore: ${error}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class ListCheckpointsCommand implements Command {
|
||||
readonly name = 'restore list';
|
||||
readonly description = 'Lists all available checkpoints.';
|
||||
|
||||
async execute(context: CommandContext): Promise<CommandExecutionResponse> {
|
||||
const { config } = context;
|
||||
|
||||
try {
|
||||
if (!config.getCheckpointingEnabled()) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'Checkpointing is not enabled. Please enable it in your settings (`general.checkpointing.enabled: true`) to use /restore.',
|
||||
};
|
||||
}
|
||||
|
||||
const checkpointDir = config.storage.getProjectTempCheckpointsDir();
|
||||
try {
|
||||
await fs.mkdir(checkpointDir, { recursive: true });
|
||||
} catch (_e) {
|
||||
// Ignore
|
||||
}
|
||||
|
||||
const files = await fs.readdir(checkpointDir);
|
||||
const jsonFiles = files.filter((file) => file.endsWith('.json'));
|
||||
|
||||
if (jsonFiles.length === 0) {
|
||||
return { name: this.name, data: 'No checkpoints found.' };
|
||||
}
|
||||
|
||||
const checkpointFiles = new Map<string, string>();
|
||||
for (const file of jsonFiles) {
|
||||
const filePath = path.join(checkpointDir, file);
|
||||
const data = await fs.readFile(filePath, 'utf-8');
|
||||
checkpointFiles.set(file, data);
|
||||
}
|
||||
|
||||
const checkpointInfoList = getCheckpointInfoList(checkpointFiles);
|
||||
|
||||
const formatted = checkpointInfoList
|
||||
.map((info) => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const i = info as Record<string, any>;
|
||||
const fileName = String(i['fileName'] || 'Unknown');
|
||||
const toolName = String(i['toolName'] || 'Unknown');
|
||||
const status = String(i['status'] || 'Unknown');
|
||||
const timestamp = new Date(
|
||||
Number(i['timestamp']) || 0,
|
||||
).toLocaleString();
|
||||
|
||||
return `- **${fileName}**: ${toolName} (Status: ${status}) [${timestamp}]`;
|
||||
})
|
||||
.join('\n');
|
||||
|
||||
return {
|
||||
name: this.name,
|
||||
data: `Available Checkpoints:\n${formatted}`,
|
||||
};
|
||||
} catch (_error) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: 'An unexpected error occurred while listing checkpoints.',
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
40
packages/cli/src/zed-integration/commands/types.ts
Normal file
40
packages/cli/src/zed-integration/commands/types.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config, GitService } from '@google/gemini-cli-core';
|
||||
import type { LoadedSettings } from '../../config/settings.js';
|
||||
|
||||
export interface CommandContext {
|
||||
config: Config;
|
||||
settings: LoadedSettings;
|
||||
git?: GitService;
|
||||
sendMessage: (text: string) => Promise<void>;
|
||||
}
|
||||
|
||||
export interface CommandArgument {
|
||||
readonly name: string;
|
||||
readonly description: string;
|
||||
readonly isRequired?: boolean;
|
||||
}
|
||||
|
||||
export interface Command {
|
||||
readonly name: string;
|
||||
readonly aliases?: string[];
|
||||
readonly description: string;
|
||||
readonly arguments?: CommandArgument[];
|
||||
readonly subCommands?: Command[];
|
||||
readonly requiresWorkspace?: boolean;
|
||||
|
||||
execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse>;
|
||||
}
|
||||
|
||||
export interface CommandExecutionResponse {
|
||||
readonly name: string;
|
||||
readonly data: unknown;
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
type Mocked,
|
||||
} from 'vitest';
|
||||
import { GeminiAgent, Session } from './zedIntegration.js';
|
||||
import type { CommandHandler } from './commandHandler.js';
|
||||
import * as acp from '@agentclientprotocol/sdk';
|
||||
import {
|
||||
AuthType,
|
||||
@@ -26,6 +27,7 @@ import {
|
||||
type Config,
|
||||
type MessageBus,
|
||||
LlmRole,
|
||||
type GitService,
|
||||
} from '@google/gemini-cli-core';
|
||||
import {
|
||||
SettingScope,
|
||||
@@ -62,7 +64,33 @@ vi.mock('node:path', async (importOriginal) => {
|
||||
};
|
||||
});
|
||||
|
||||
// Mock ReadManyFilesTool
|
||||
vi.mock('../ui/commands/memoryCommand.js', () => ({
|
||||
memoryCommand: {
|
||||
name: 'memory',
|
||||
action: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('../ui/commands/extensionsCommand.js', () => ({
|
||||
extensionsCommand: vi.fn().mockReturnValue({
|
||||
name: 'extensions',
|
||||
action: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock('../ui/commands/restoreCommand.js', () => ({
|
||||
restoreCommand: vi.fn().mockReturnValue({
|
||||
name: 'restore',
|
||||
action: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock('../ui/commands/initCommand.js', () => ({
|
||||
initCommand: {
|
||||
name: 'init',
|
||||
action: vi.fn(),
|
||||
},
|
||||
}));
|
||||
vi.mock(
|
||||
'@google/gemini-cli-core',
|
||||
async (
|
||||
@@ -145,6 +173,7 @@ describe('GeminiAgent', () => {
|
||||
}),
|
||||
getApprovalMode: vi.fn().mockReturnValue('default'),
|
||||
isPlanEnabled: vi.fn().mockReturnValue(false),
|
||||
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
|
||||
} as unknown as Mocked<Awaited<ReturnType<typeof loadCliConfig>>>;
|
||||
mockSettings = {
|
||||
merged: {
|
||||
@@ -225,6 +254,7 @@ describe('GeminiAgent', () => {
|
||||
});
|
||||
|
||||
it('should create a new session', async () => {
|
||||
vi.useFakeTimers();
|
||||
mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({
|
||||
apiKey: 'test-key',
|
||||
});
|
||||
@@ -237,6 +267,17 @@ describe('GeminiAgent', () => {
|
||||
expect(loadCliConfig).toHaveBeenCalled();
|
||||
expect(mockConfig.initialize).toHaveBeenCalled();
|
||||
expect(mockConfig.getGeminiClient).toHaveBeenCalled();
|
||||
|
||||
// Verify deferred call
|
||||
await vi.runAllTimersAsync();
|
||||
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
update: expect.objectContaining({
|
||||
sessionUpdate: 'available_commands_update',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should return modes without plan mode when plan is disabled', async () => {
|
||||
@@ -477,6 +518,7 @@ describe('Session', () => {
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
getActiveModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
getMcpServers: vi.fn(),
|
||||
getFileService: vi.fn().mockReturnValue({
|
||||
shouldIgnoreFile: vi.fn().mockReturnValue(false),
|
||||
}),
|
||||
@@ -487,6 +529,8 @@ describe('Session', () => {
|
||||
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
|
||||
setApprovalMode: vi.fn(),
|
||||
isPlanEnabled: vi.fn().mockReturnValue(false),
|
||||
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
|
||||
getGitService: vi.fn().mockResolvedValue({} as GitService),
|
||||
waitForMcpInit: vi.fn(),
|
||||
} as unknown as Mocked<Config>;
|
||||
mockConnection = {
|
||||
@@ -495,13 +539,38 @@ describe('Session', () => {
|
||||
sendNotification: vi.fn(),
|
||||
} as unknown as Mocked<acp.AgentSideConnection>;
|
||||
|
||||
session = new Session('session-1', mockChat, mockConfig, mockConnection);
|
||||
session = new Session('session-1', mockChat, mockConfig, mockConnection, {
|
||||
system: { settings: {} },
|
||||
systemDefaults: { settings: {} },
|
||||
user: { settings: {} },
|
||||
workspace: { settings: {} },
|
||||
merged: { settings: {} },
|
||||
errors: [],
|
||||
} as unknown as LoadedSettings);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should send available commands', async () => {
|
||||
await session.sendAvailableCommands();
|
||||
|
||||
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
update: expect.objectContaining({
|
||||
sessionUpdate: 'available_commands_update',
|
||||
availableCommands: expect.arrayContaining([
|
||||
expect.objectContaining({ name: 'memory' }),
|
||||
expect.objectContaining({ name: 'extensions' }),
|
||||
expect.objectContaining({ name: 'restore' }),
|
||||
expect.objectContaining({ name: 'init' }),
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should await MCP initialization before processing a prompt', async () => {
|
||||
const stream = createMockStream([
|
||||
{
|
||||
@@ -551,6 +620,113 @@ describe('Session', () => {
|
||||
expect(result).toEqual({ stopReason: 'end_turn' });
|
||||
});
|
||||
|
||||
it('should handle /memory command', async () => {
|
||||
const handleCommandSpy = vi
|
||||
.spyOn(
|
||||
(session as unknown as { commandHandler: CommandHandler })
|
||||
.commandHandler,
|
||||
'handleCommand',
|
||||
)
|
||||
.mockResolvedValue(true);
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: '/memory view' }],
|
||||
});
|
||||
|
||||
expect(result).toEqual({ stopReason: 'end_turn' });
|
||||
expect(handleCommandSpy).toHaveBeenCalledWith(
|
||||
'/memory view',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle /extensions command', async () => {
|
||||
const handleCommandSpy = vi
|
||||
.spyOn(
|
||||
(session as unknown as { commandHandler: CommandHandler })
|
||||
.commandHandler,
|
||||
'handleCommand',
|
||||
)
|
||||
.mockResolvedValue(true);
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: '/extensions list' }],
|
||||
});
|
||||
|
||||
expect(result).toEqual({ stopReason: 'end_turn' });
|
||||
expect(handleCommandSpy).toHaveBeenCalledWith(
|
||||
'/extensions list',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle /extensions explore command', async () => {
|
||||
const handleCommandSpy = vi
|
||||
.spyOn(
|
||||
(session as unknown as { commandHandler: CommandHandler })
|
||||
.commandHandler,
|
||||
'handleCommand',
|
||||
)
|
||||
.mockResolvedValue(true);
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: '/extensions explore' }],
|
||||
});
|
||||
|
||||
expect(result).toEqual({ stopReason: 'end_turn' });
|
||||
expect(handleCommandSpy).toHaveBeenCalledWith(
|
||||
'/extensions explore',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle /restore command', async () => {
|
||||
const handleCommandSpy = vi
|
||||
.spyOn(
|
||||
(session as unknown as { commandHandler: CommandHandler })
|
||||
.commandHandler,
|
||||
'handleCommand',
|
||||
)
|
||||
.mockResolvedValue(true);
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: '/restore' }],
|
||||
});
|
||||
|
||||
expect(result).toEqual({ stopReason: 'end_turn' });
|
||||
expect(handleCommandSpy).toHaveBeenCalledWith(
|
||||
'/restore',
|
||||
expect.any(Object),
|
||||
);
|
||||
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle /init command', async () => {
|
||||
const handleCommandSpy = vi
|
||||
.spyOn(
|
||||
(session as unknown as { commandHandler: CommandHandler })
|
||||
.commandHandler,
|
||||
'handleCommand',
|
||||
)
|
||||
.mockResolvedValue(true);
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: '/init' }],
|
||||
});
|
||||
|
||||
expect(result).toEqual({ stopReason: 'end_turn' });
|
||||
expect(handleCommandSpy).toHaveBeenCalledWith('/init', expect.any(Object));
|
||||
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle tool calls', async () => {
|
||||
const stream1 = createMockStream([
|
||||
{
|
||||
@@ -1207,4 +1383,24 @@ describe('Session', () => {
|
||||
'Invalid or unavailable mode: invalid-mode',
|
||||
);
|
||||
});
|
||||
it('should handle unquoted commands from autocomplete (with empty leading parts)', async () => {
|
||||
// Mock handleCommand to verify it gets called
|
||||
const handleCommandSpy = vi
|
||||
.spyOn(
|
||||
(session as unknown as { commandHandler: CommandHandler })
|
||||
.commandHandler,
|
||||
'handleCommand',
|
||||
)
|
||||
.mockResolvedValue(true);
|
||||
|
||||
await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [
|
||||
{ type: 'text', text: '' },
|
||||
{ type: 'text', text: '/memory' },
|
||||
],
|
||||
});
|
||||
|
||||
expect(handleCommandSpy).toHaveBeenCalledWith('/memory', expect.anything());
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,15 +4,13 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
Config,
|
||||
GeminiChat,
|
||||
ToolResult,
|
||||
ToolCallConfirmationDetails,
|
||||
FilterFilesOptions,
|
||||
ConversationRecord,
|
||||
} from '@google/gemini-cli-core';
|
||||
import {
|
||||
type Config,
|
||||
type GeminiChat,
|
||||
type ToolResult,
|
||||
type ToolCallConfirmationDetails,
|
||||
type FilterFilesOptions,
|
||||
type ConversationRecord,
|
||||
CoreToolCallStatus,
|
||||
AuthType,
|
||||
logToolCall,
|
||||
@@ -61,11 +59,14 @@ import { loadCliConfig } from '../config/config.js';
|
||||
import { runExitCleanup } from '../utils/cleanup.js';
|
||||
import { SessionSelector } from '../utils/sessionUtils.js';
|
||||
|
||||
import { CommandHandler } from './commandHandler.js';
|
||||
export async function runZedIntegration(
|
||||
config: Config,
|
||||
settings: LoadedSettings,
|
||||
argv: CliArgs,
|
||||
) {
|
||||
// ... (skip unchanged lines) ...
|
||||
|
||||
const { stdout: workingStdout } = createWorkingStdio();
|
||||
const stdout = Writable.toWeb(workingStdout) as WritableStream;
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
@@ -240,9 +241,20 @@ export class GeminiAgent {
|
||||
|
||||
const geminiClient = config.getGeminiClient();
|
||||
const chat = await geminiClient.startChat();
|
||||
const session = new Session(sessionId, chat, config, this.connection);
|
||||
const session = new Session(
|
||||
sessionId,
|
||||
chat,
|
||||
config,
|
||||
this.connection,
|
||||
this.settings,
|
||||
);
|
||||
this.sessions.set(sessionId, session);
|
||||
|
||||
setTimeout(() => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
session.sendAvailableCommands();
|
||||
}, 0);
|
||||
|
||||
return {
|
||||
sessionId,
|
||||
modes: {
|
||||
@@ -291,6 +303,7 @@ export class GeminiAgent {
|
||||
geminiClient.getChat(),
|
||||
config,
|
||||
this.connection,
|
||||
this.settings,
|
||||
);
|
||||
this.sessions.set(sessionId, session);
|
||||
|
||||
@@ -298,6 +311,11 @@ export class GeminiAgent {
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
session.streamHistory(sessionData.messages);
|
||||
|
||||
setTimeout(() => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
session.sendAvailableCommands();
|
||||
}, 0);
|
||||
|
||||
return {
|
||||
modes: {
|
||||
availableModes: buildAvailableModes(config.isPlanEnabled()),
|
||||
@@ -418,12 +436,14 @@ export class GeminiAgent {
|
||||
|
||||
export class Session {
|
||||
private pendingPrompt: AbortController | null = null;
|
||||
private commandHandler = new CommandHandler();
|
||||
|
||||
constructor(
|
||||
private readonly id: string,
|
||||
private readonly chat: GeminiChat,
|
||||
private readonly config: Config,
|
||||
private readonly connection: acp.AgentSideConnection,
|
||||
private readonly settings: LoadedSettings,
|
||||
) {}
|
||||
|
||||
async cancelPendingPrompt(): Promise<void> {
|
||||
@@ -446,6 +466,22 @@ export class Session {
|
||||
return {};
|
||||
}
|
||||
|
||||
private getAvailableCommands() {
|
||||
return this.commandHandler.getAvailableCommands();
|
||||
}
|
||||
|
||||
async sendAvailableCommands(): Promise<void> {
|
||||
const availableCommands = this.getAvailableCommands().map((command) => ({
|
||||
name: command.name,
|
||||
description: command.description,
|
||||
}));
|
||||
|
||||
await this.sendUpdate({
|
||||
sessionUpdate: 'available_commands_update',
|
||||
availableCommands,
|
||||
});
|
||||
}
|
||||
|
||||
async streamHistory(messages: ConversationRecord['messages']): Promise<void> {
|
||||
for (const msg of messages) {
|
||||
const contentString = partListUnionToString(msg.content);
|
||||
@@ -528,6 +564,41 @@ export class Session {
|
||||
|
||||
const parts = await this.#resolvePrompt(params.prompt, pendingSend.signal);
|
||||
|
||||
// Command interception
|
||||
let commandText = '';
|
||||
|
||||
for (const part of parts) {
|
||||
if (typeof part === 'object' && part !== null) {
|
||||
if ('text' in part) {
|
||||
// It is a text part
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-type-assertion
|
||||
const text = (part as any).text;
|
||||
if (typeof text === 'string') {
|
||||
commandText += text;
|
||||
}
|
||||
} else {
|
||||
// Non-text part (image, embedded resource)
|
||||
// Stop looking for command
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
commandText = commandText.trim();
|
||||
|
||||
if (
|
||||
commandText &&
|
||||
(commandText.startsWith('/') || commandText.startsWith('$'))
|
||||
) {
|
||||
// If we found a command, pass it to handleCommand
|
||||
// Note: handleCommand currently expects `commandText` to be the command string
|
||||
// It uses `parts` argument but effectively ignores it in current implementation
|
||||
const handled = await this.handleCommand(commandText, parts);
|
||||
if (handled) {
|
||||
return { stopReason: 'end_turn' };
|
||||
}
|
||||
}
|
||||
|
||||
let nextMessage: Content | null = { role: 'user', parts };
|
||||
|
||||
while (nextMessage !== null) {
|
||||
@@ -627,9 +698,28 @@ export class Session {
|
||||
return { stopReason: 'end_turn' };
|
||||
}
|
||||
|
||||
private async sendUpdate(
|
||||
update: acp.SessionNotification['update'],
|
||||
): Promise<void> {
|
||||
private async handleCommand(
|
||||
commandText: string,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
parts: Part[],
|
||||
): Promise<boolean> {
|
||||
const gitService = await this.config.getGitService();
|
||||
const commandContext = {
|
||||
config: this.config,
|
||||
settings: this.settings,
|
||||
git: gitService,
|
||||
sendMessage: async (text: string) => {
|
||||
await this.sendUpdate({
|
||||
sessionUpdate: 'agent_message_chunk',
|
||||
content: { type: 'text', text },
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
return this.commandHandler.handleCommand(commandText, commandContext);
|
||||
}
|
||||
|
||||
private async sendUpdate(update: acp.SessionUpdate): Promise<void> {
|
||||
const params: acp.SessionNotification = {
|
||||
sessionId: this.id,
|
||||
update,
|
||||
|
||||
Reference in New Issue
Block a user