feat: Implement slash command handling in ACP for /memory,/init,/extensions and /restore (#20528)

This commit is contained in:
Sri Pasumarthi
2026-03-03 13:29:14 -08:00
committed by GitHub
parent d6c560498b
commit 27d7aeb1ed
11 changed files with 1327 additions and 14 deletions

View File

@@ -93,6 +93,7 @@ describe('GeminiAgent Session Resume', () => {
},
getApprovalMode: vi.fn().mockReturnValue('default'),
isPlanEnabled: vi.fn().mockReturnValue(false),
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
} as unknown as Mocked<Config>;
mockSettings = {
merged: {

View File

@@ -0,0 +1,30 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { CommandHandler } from './commandHandler.js';
import { describe, it, expect } from 'vitest';
describe('CommandHandler', () => {
it('parses commands correctly', () => {
const handler = new CommandHandler();
// @ts-expect-error - testing private method
const parse = (query: string) => handler.parseSlashCommand(query);
const memShow = parse('/memory show');
expect(memShow.commandToExecute?.name).toBe('memory show');
expect(memShow.args).toBe('');
const memAdd = parse('/memory add hello world');
expect(memAdd.commandToExecute?.name).toBe('memory add');
expect(memAdd.args).toBe('hello world');
const extList = parse('/extensions list');
expect(extList.commandToExecute?.name).toBe('extensions list');
const init = parse('/init');
expect(init.commandToExecute?.name).toBe('init');
});
});

View File

@@ -0,0 +1,134 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Command, CommandContext } from './commands/types.js';
import { CommandRegistry } from './commands/commandRegistry.js';
import { MemoryCommand } from './commands/memory.js';
import { ExtensionsCommand } from './commands/extensions.js';
import { InitCommand } from './commands/init.js';
import { RestoreCommand } from './commands/restore.js';
export class CommandHandler {
private registry: CommandRegistry;
constructor() {
this.registry = CommandHandler.createRegistry();
}
private static createRegistry(): CommandRegistry {
const registry = new CommandRegistry();
registry.register(new MemoryCommand());
registry.register(new ExtensionsCommand());
registry.register(new InitCommand());
registry.register(new RestoreCommand());
return registry;
}
getAvailableCommands(): Array<{ name: string; description: string }> {
return this.registry.getAllCommands().map((cmd) => ({
name: cmd.name,
description: cmd.description,
}));
}
/**
* Parses and executes a command string if it matches a registered command.
* Returns true if a command was handled, false otherwise.
*/
async handleCommand(
commandText: string,
context: CommandContext,
): Promise<boolean> {
const { commandToExecute, args } = this.parseSlashCommand(commandText);
if (commandToExecute) {
await this.runCommand(commandToExecute, args, context);
return true;
}
return false;
}
private async runCommand(
commandToExecute: Command,
args: string,
context: CommandContext,
): Promise<void> {
try {
const result = await commandToExecute.execute(
context,
args ? args.split(/\s+/) : [],
);
let messageContent = '';
if (typeof result.data === 'string') {
messageContent = result.data;
} else if (
typeof result.data === 'object' &&
result.data !== null &&
'content' in result.data
) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
messageContent = (result.data as Record<string, any>)[
'content'
] as string;
} else {
messageContent = JSON.stringify(result.data, null, 2);
}
await context.sendMessage(messageContent);
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
await context.sendMessage(`Error: ${errorMessage}`);
}
}
/**
* Parses a raw slash command string into its matching headless command and arguments.
* Mirrors `packages/cli/src/utils/commands.ts` logic.
*/
private parseSlashCommand(query: string): {
commandToExecute: Command | undefined;
args: string;
} {
const trimmed = query.trim();
const parts = trimmed.substring(1).trim().split(/\s+/);
const commandPath = parts.filter((p) => p);
let currentCommands = this.registry.getAllCommands();
let commandToExecute: Command | undefined;
let pathIndex = 0;
for (const part of commandPath) {
const foundCommand = currentCommands.find((cmd) => {
const expectedName = commandPath.slice(0, pathIndex + 1).join(' ');
return (
cmd.name === part ||
cmd.name === expectedName ||
cmd.aliases?.includes(part) ||
cmd.aliases?.includes(expectedName)
);
});
if (foundCommand) {
commandToExecute = foundCommand;
pathIndex++;
if (foundCommand.subCommands) {
currentCommands = foundCommand.subCommands;
} else {
break;
}
} else {
break;
}
}
const args = parts.slice(pathIndex).join(' ');
return { commandToExecute, args };
}
}

View File

@@ -0,0 +1,33 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { debugLogger } from '@google/gemini-cli-core';
import type { Command } from './types.js';
export class CommandRegistry {
private readonly commands = new Map<string, Command>();
register(command: Command) {
if (this.commands.has(command.name)) {
debugLogger.warn(`Command ${command.name} already registered. Skipping.`);
return;
}
this.commands.set(command.name, command);
for (const subCommand of command.subCommands ?? []) {
this.register(subCommand);
}
}
get(commandName: string): Command | undefined {
return this.commands.get(commandName);
}
getAllCommands(): Command[] {
return [...this.commands.values()];
}
}

View File

@@ -0,0 +1,428 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { listExtensions } from '@google/gemini-cli-core';
import { SettingScope } from '../../config/settings.js';
import {
ExtensionManager,
inferInstallMetadata,
} from '../../config/extension-manager.js';
import { getErrorMessage } from '../../utils/errors.js';
import { McpServerEnablementManager } from '../../config/mcp/mcpServerEnablement.js';
import { stat } from 'node:fs/promises';
import type {
Command,
CommandContext,
CommandExecutionResponse,
} from './types.js';
import type { Config } from '@google/gemini-cli-core';
export class ExtensionsCommand implements Command {
readonly name = 'extensions';
readonly description = 'Manage extensions.';
readonly subCommands = [
new ListExtensionsCommand(),
new ExploreExtensionsCommand(),
new EnableExtensionCommand(),
new DisableExtensionCommand(),
new InstallExtensionCommand(),
new LinkExtensionCommand(),
new UninstallExtensionCommand(),
new RestartExtensionCommand(),
new UpdateExtensionCommand(),
];
async execute(
context: CommandContext,
_: string[],
): Promise<CommandExecutionResponse> {
return new ListExtensionsCommand().execute(context, _);
}
}
export class ListExtensionsCommand implements Command {
readonly name = 'extensions list';
readonly description = 'Lists all installed extensions.';
async execute(
context: CommandContext,
_: string[],
): Promise<CommandExecutionResponse> {
const extensions = listExtensions(context.config);
const data = extensions.length ? extensions : 'No extensions installed.';
return { name: this.name, data };
}
}
export class ExploreExtensionsCommand implements Command {
readonly name = 'extensions explore';
readonly description = 'Explore available extensions.';
async execute(
_context: CommandContext,
_: string[],
): Promise<CommandExecutionResponse> {
const extensionsUrl = 'https://geminicli.com/extensions/';
return {
name: this.name,
data: `View or install available extensions at ${extensionsUrl}`,
};
}
}
function getEnableDisableContext(
config: Config,
args: string[],
invocationName: string,
) {
const extensionManager = config.getExtensionLoader();
if (!(extensionManager instanceof ExtensionManager)) {
return {
error: `Cannot ${invocationName} extensions in this environment.`,
};
}
if (args.length === 0) {
return {
error: `Usage: /extensions ${invocationName} <extension> [--scope=<user|workspace|session>]`,
};
}
let scope = SettingScope.User;
if (args.includes('--scope=workspace') || args.includes('workspace')) {
scope = SettingScope.Workspace;
} else if (args.includes('--scope=session') || args.includes('session')) {
scope = SettingScope.Session;
}
const name = args.filter(
(a) =>
!a.startsWith('--scope') && !['user', 'workspace', 'session'].includes(a),
)[0];
let names: string[] = [];
if (name === '--all') {
let extensions = extensionManager.getExtensions();
if (invocationName === 'enable') {
extensions = extensions.filter((ext) => !ext.isActive);
}
if (invocationName === 'disable') {
extensions = extensions.filter((ext) => ext.isActive);
}
names = extensions.map((ext) => ext.name);
} else if (name) {
names = [name];
} else {
return { error: 'No extension name provided.' };
}
return { extensionManager, names, scope };
}
export class EnableExtensionCommand implements Command {
readonly name = 'extensions enable';
readonly description = 'Enable an extension.';
async execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse> {
const enableContext = getEnableDisableContext(
context.config,
args,
'enable',
);
if ('error' in enableContext) {
return { name: this.name, data: enableContext.error };
}
const { names, scope, extensionManager } = enableContext;
const output: string[] = [];
for (const name of names) {
try {
await extensionManager.enableExtension(name, scope);
output.push(`Extension "${name}" enabled for scope "${scope}".`);
const extension = extensionManager
.getExtensions()
.find((e) => e.name === name);
if (extension?.mcpServers) {
const mcpEnablementManager = McpServerEnablementManager.getInstance();
const mcpClientManager = context.config.getMcpClientManager();
const enabledServers = await mcpEnablementManager.autoEnableServers(
Object.keys(extension.mcpServers),
);
if (mcpClientManager && enabledServers.length > 0) {
const restartPromises = enabledServers.map((serverName) =>
mcpClientManager.restartServer(serverName).catch((error) => {
output.push(
`Failed to restart MCP server '${serverName}': ${getErrorMessage(error)}`,
);
}),
);
await Promise.all(restartPromises);
output.push(`Re-enabled MCP servers: ${enabledServers.join(', ')}`);
}
}
} catch (e) {
output.push(`Failed to enable "${name}": ${getErrorMessage(e)}`);
}
}
return { name: this.name, data: output.join('\n') || 'No action taken.' };
}
}
export class DisableExtensionCommand implements Command {
readonly name = 'extensions disable';
readonly description = 'Disable an extension.';
async execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse> {
const enableContext = getEnableDisableContext(
context.config,
args,
'disable',
);
if ('error' in enableContext) {
return { name: this.name, data: enableContext.error };
}
const { names, scope, extensionManager } = enableContext;
const output: string[] = [];
for (const name of names) {
try {
await extensionManager.disableExtension(name, scope);
output.push(`Extension "${name}" disabled for scope "${scope}".`);
} catch (e) {
output.push(`Failed to disable "${name}": ${getErrorMessage(e)}`);
}
}
return { name: this.name, data: output.join('\n') || 'No action taken.' };
}
}
export class InstallExtensionCommand implements Command {
readonly name = 'extensions install';
readonly description = 'Install an extension from a git repo or local path.';
async execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse> {
const extensionLoader = context.config.getExtensionLoader();
if (!(extensionLoader instanceof ExtensionManager)) {
return {
name: this.name,
data: 'Cannot install extensions in this environment.',
};
}
const source = args.join(' ').trim();
if (!source) {
return { name: this.name, data: `Usage: /extensions install <source>` };
}
if (/[;&|`'"]/.test(source)) {
return {
name: this.name,
data: `Invalid source: contains disallowed characters.`,
};
}
try {
const installMetadata = await inferInstallMetadata(source);
const extension =
await extensionLoader.installOrUpdateExtension(installMetadata);
return {
name: this.name,
data: `Extension "${extension.name}" installed successfully.`,
};
} catch (error) {
return {
name: this.name,
data: `Failed to install extension from "${source}": ${getErrorMessage(error)}`,
};
}
}
}
export class LinkExtensionCommand implements Command {
readonly name = 'extensions link';
readonly description = 'Link an extension from a local path.';
async execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse> {
const extensionLoader = context.config.getExtensionLoader();
if (!(extensionLoader instanceof ExtensionManager)) {
return {
name: this.name,
data: 'Cannot link extensions in this environment.',
};
}
const sourceFilepath = args.join(' ').trim();
if (!sourceFilepath) {
return { name: this.name, data: `Usage: /extensions link <source>` };
}
try {
await stat(sourceFilepath);
} catch (_error) {
return { name: this.name, data: `Invalid source: ${sourceFilepath}` };
}
try {
const extension = await extensionLoader.installOrUpdateExtension({
source: sourceFilepath,
type: 'link',
});
return {
name: this.name,
data: `Extension "${extension.name}" linked successfully.`,
};
} catch (error) {
return {
name: this.name,
data: `Failed to link extension: ${getErrorMessage(error)}`,
};
}
}
}
export class UninstallExtensionCommand implements Command {
readonly name = 'extensions uninstall';
readonly description = 'Uninstall an extension.';
async execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse> {
const extensionLoader = context.config.getExtensionLoader();
if (!(extensionLoader instanceof ExtensionManager)) {
return {
name: this.name,
data: 'Cannot uninstall extensions in this environment.',
};
}
const name = args.join(' ').trim();
if (!name) {
return {
name: this.name,
data: `Usage: /extensions uninstall <extension-name>`,
};
}
try {
await extensionLoader.uninstallExtension(name, false);
return {
name: this.name,
data: `Extension "${name}" uninstalled successfully.`,
};
} catch (error) {
return {
name: this.name,
data: `Failed to uninstall extension "${name}": ${getErrorMessage(error)}`,
};
}
}
}
export class RestartExtensionCommand implements Command {
readonly name = 'extensions restart';
readonly description = 'Restart an extension.';
async execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse> {
const extensionLoader = context.config.getExtensionLoader();
if (!(extensionLoader instanceof ExtensionManager)) {
return { name: this.name, data: 'Cannot restart extensions.' };
}
const all = args.includes('--all');
const names = all ? null : args.filter((a) => !!a);
if (!all && names?.length === 0) {
return {
name: this.name,
data: 'Usage: /extensions restart <extension-names>|--all',
};
}
let extensionsToRestart = extensionLoader
.getExtensions()
.filter((e) => e.isActive);
if (names) {
extensionsToRestart = extensionsToRestart.filter((e) =>
names.includes(e.name),
);
}
if (extensionsToRestart.length === 0) {
return {
name: this.name,
data: 'No active extensions matched the request.',
};
}
const output: string[] = [];
for (const extension of extensionsToRestart) {
try {
await extensionLoader.restartExtension(extension);
output.push(`Restarted "${extension.name}".`);
} catch (e) {
output.push(
`Failed to restart "${extension.name}": ${getErrorMessage(e)}`,
);
}
}
return { name: this.name, data: output.join('\n') };
}
}
export class UpdateExtensionCommand implements Command {
readonly name = 'extensions update';
readonly description = 'Update an extension.';
async execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse> {
const extensionLoader = context.config.getExtensionLoader();
if (!(extensionLoader instanceof ExtensionManager)) {
return { name: this.name, data: 'Cannot update extensions.' };
}
const all = args.includes('--all');
const names = all ? null : args.filter((a) => !!a);
if (!all && names?.length === 0) {
return {
name: this.name,
data: 'Usage: /extensions update <extension-names>|--all',
};
}
return {
name: this.name,
data: 'Headless extension updating requires internal UI dispatches. Please use `gemini extensions update` directly in the terminal.',
};
}
}

View File

@@ -0,0 +1,62 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'node:fs';
import * as path from 'node:path';
import { performInit } from '@google/gemini-cli-core';
import type {
Command,
CommandContext,
CommandExecutionResponse,
} from './types.js';
export class InitCommand implements Command {
name = 'init';
description = 'Analyzes the project and creates a tailored GEMINI.md file';
requiresWorkspace = true;
async execute(
context: CommandContext,
_args: string[] = [],
): Promise<CommandExecutionResponse> {
const targetDir = context.config.getTargetDir();
if (!targetDir) {
throw new Error('Command requires a workspace.');
}
const geminiMdPath = path.join(targetDir, 'GEMINI.md');
const result = performInit(fs.existsSync(geminiMdPath));
switch (result.type) {
case 'message':
return {
name: this.name,
data: result,
};
case 'submit_prompt':
fs.writeFileSync(geminiMdPath, '', 'utf8');
if (typeof result.content !== 'string') {
throw new Error('Init command content must be a string.');
}
// Inform the user since we can't trigger the UI-based interactive agent loop here directly.
// We output the prompt text they can use to re-trigger the generation manually,
// or just seed the GEMINI.md file as we've done above.
return {
name: this.name,
data: {
type: 'message',
messageType: 'info',
content: `A template GEMINI.md has been created at ${geminiMdPath}.\n\nTo populate it with project context, you can run the following prompt in a new chat:\n\n${result.content}`,
},
};
default:
throw new Error('Unknown result type from performInit');
}
}
}

View File

@@ -0,0 +1,121 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
addMemory,
listMemoryFiles,
refreshMemory,
showMemory,
} from '@google/gemini-cli-core';
import type {
Command,
CommandContext,
CommandExecutionResponse,
} from './types.js';
const DEFAULT_SANITIZATION_CONFIG = {
allowedEnvironmentVariables: [],
blockedEnvironmentVariables: [],
enableEnvironmentVariableRedaction: false,
};
export class MemoryCommand implements Command {
readonly name = 'memory';
readonly description = 'Manage memory.';
readonly subCommands = [
new ShowMemoryCommand(),
new RefreshMemoryCommand(),
new ListMemoryCommand(),
new AddMemoryCommand(),
];
readonly requiresWorkspace = true;
async execute(
context: CommandContext,
_: string[],
): Promise<CommandExecutionResponse> {
return new ShowMemoryCommand().execute(context, _);
}
}
export class ShowMemoryCommand implements Command {
readonly name = 'memory show';
readonly description = 'Shows the current memory contents.';
async execute(
context: CommandContext,
_: string[],
): Promise<CommandExecutionResponse> {
const result = showMemory(context.config);
return { name: this.name, data: result.content };
}
}
export class RefreshMemoryCommand implements Command {
readonly name = 'memory refresh';
readonly aliases = ['memory reload'];
readonly description = 'Refreshes the memory from the source.';
async execute(
context: CommandContext,
_: string[],
): Promise<CommandExecutionResponse> {
const result = await refreshMemory(context.config);
return { name: this.name, data: result.content };
}
}
export class ListMemoryCommand implements Command {
readonly name = 'memory list';
readonly description = 'Lists the paths of the GEMINI.md files in use.';
async execute(
context: CommandContext,
_: string[],
): Promise<CommandExecutionResponse> {
const result = listMemoryFiles(context.config);
return { name: this.name, data: result.content };
}
}
export class AddMemoryCommand implements Command {
readonly name = 'memory add';
readonly description = 'Add content to the memory.';
async execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse> {
const textToAdd = args.join(' ').trim();
const result = addMemory(textToAdd);
if (result.type === 'message') {
return { name: this.name, data: result.content };
}
const toolRegistry = context.config.getToolRegistry();
const tool = toolRegistry.getTool(result.toolName);
if (tool) {
const abortController = new AbortController();
const signal = abortController.signal;
await context.sendMessage(`Saving memory via ${result.toolName}...`);
await tool.buildAndExecute(result.toolArgs, signal, undefined, {
sanitizationConfig: DEFAULT_SANITIZATION_CONFIG,
});
await refreshMemory(context.config);
return {
name: this.name,
data: `Added memory: "${textToAdd}"`,
};
} else {
return {
name: this.name,
data: `Error: Tool ${result.toolName} not found.`,
};
}
}
}

View File

@@ -0,0 +1,178 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
getCheckpointInfoList,
getToolCallDataSchema,
isNodeError,
performRestore,
} from '@google/gemini-cli-core';
import * as fs from 'node:fs/promises';
import * as path from 'node:path';
import type {
Command,
CommandContext,
CommandExecutionResponse,
} from './types.js';
export class RestoreCommand implements Command {
readonly name = 'restore';
readonly description =
'Restore to a previous checkpoint, or list available checkpoints to restore. This will reset the conversation and file history to the state it was in when the checkpoint was created';
readonly requiresWorkspace = true;
readonly subCommands = [new ListCheckpointsCommand()];
async execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse> {
const { config, git: gitService } = context;
const argsStr = args.join(' ');
try {
if (!argsStr) {
return await new ListCheckpointsCommand().execute(context);
}
if (!config.getCheckpointingEnabled()) {
return {
name: this.name,
data: 'Checkpointing is not enabled. Please enable it in your settings (`general.checkpointing.enabled: true`) to use /restore.',
};
}
const selectedFile = argsStr.endsWith('.json')
? argsStr
: `${argsStr}.json`;
const checkpointDir = config.storage.getProjectTempCheckpointsDir();
const filePath = path.join(checkpointDir, selectedFile);
let data: string;
try {
data = await fs.readFile(filePath, 'utf-8');
} catch (error) {
if (isNodeError(error) && error.code === 'ENOENT') {
return {
name: this.name,
data: `File not found: ${selectedFile}`,
};
}
throw error;
}
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
const toolCallData = JSON.parse(data);
const ToolCallDataSchema = getToolCallDataSchema();
const parseResult = ToolCallDataSchema.safeParse(toolCallData);
if (!parseResult.success) {
return {
name: this.name,
data: 'Checkpoint file is invalid or corrupted.',
};
}
const restoreResultGenerator = performRestore(
parseResult.data,
gitService,
);
const restoreResult = [];
for await (const result of restoreResultGenerator) {
restoreResult.push(result);
}
// Format the result nicely since Zed just dumps data
const formattedResult = restoreResult
.map((r) => {
if (r.type === 'message') {
return `[${r.messageType.toUpperCase()}] ${r.content}`;
} else if (r.type === 'load_history') {
return `Loaded history with ${r.clientHistory.length} messages.`;
}
return `Restored: ${JSON.stringify(r)}`;
})
.join('\n');
return {
name: this.name,
data: formattedResult,
};
} catch (error) {
return {
name: this.name,
data: `An unexpected error occurred during restore: ${error}`,
};
}
}
}
export class ListCheckpointsCommand implements Command {
readonly name = 'restore list';
readonly description = 'Lists all available checkpoints.';
async execute(context: CommandContext): Promise<CommandExecutionResponse> {
const { config } = context;
try {
if (!config.getCheckpointingEnabled()) {
return {
name: this.name,
data: 'Checkpointing is not enabled. Please enable it in your settings (`general.checkpointing.enabled: true`) to use /restore.',
};
}
const checkpointDir = config.storage.getProjectTempCheckpointsDir();
try {
await fs.mkdir(checkpointDir, { recursive: true });
} catch (_e) {
// Ignore
}
const files = await fs.readdir(checkpointDir);
const jsonFiles = files.filter((file) => file.endsWith('.json'));
if (jsonFiles.length === 0) {
return { name: this.name, data: 'No checkpoints found.' };
}
const checkpointFiles = new Map<string, string>();
for (const file of jsonFiles) {
const filePath = path.join(checkpointDir, file);
const data = await fs.readFile(filePath, 'utf-8');
checkpointFiles.set(file, data);
}
const checkpointInfoList = getCheckpointInfoList(checkpointFiles);
const formatted = checkpointInfoList
.map((info) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const i = info as Record<string, any>;
const fileName = String(i['fileName'] || 'Unknown');
const toolName = String(i['toolName'] || 'Unknown');
const status = String(i['status'] || 'Unknown');
const timestamp = new Date(
Number(i['timestamp']) || 0,
).toLocaleString();
return `- **${fileName}**: ${toolName} (Status: ${status}) [${timestamp}]`;
})
.join('\n');
return {
name: this.name,
data: `Available Checkpoints:\n${formatted}`,
};
} catch (_error) {
return {
name: this.name,
data: 'An unexpected error occurred while listing checkpoints.',
};
}
}
}

View File

@@ -0,0 +1,40 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config, GitService } from '@google/gemini-cli-core';
import type { LoadedSettings } from '../../config/settings.js';
export interface CommandContext {
config: Config;
settings: LoadedSettings;
git?: GitService;
sendMessage: (text: string) => Promise<void>;
}
export interface CommandArgument {
readonly name: string;
readonly description: string;
readonly isRequired?: boolean;
}
export interface Command {
readonly name: string;
readonly aliases?: string[];
readonly description: string;
readonly arguments?: CommandArgument[];
readonly subCommands?: Command[];
readonly requiresWorkspace?: boolean;
execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse>;
}
export interface CommandExecutionResponse {
readonly name: string;
readonly data: unknown;
}

View File

@@ -15,6 +15,7 @@ import {
type Mocked,
} from 'vitest';
import { GeminiAgent, Session } from './zedIntegration.js';
import type { CommandHandler } from './commandHandler.js';
import * as acp from '@agentclientprotocol/sdk';
import {
AuthType,
@@ -26,6 +27,7 @@ import {
type Config,
type MessageBus,
LlmRole,
type GitService,
} from '@google/gemini-cli-core';
import {
SettingScope,
@@ -62,7 +64,33 @@ vi.mock('node:path', async (importOriginal) => {
};
});
// Mock ReadManyFilesTool
vi.mock('../ui/commands/memoryCommand.js', () => ({
memoryCommand: {
name: 'memory',
action: vi.fn(),
},
}));
vi.mock('../ui/commands/extensionsCommand.js', () => ({
extensionsCommand: vi.fn().mockReturnValue({
name: 'extensions',
action: vi.fn(),
}),
}));
vi.mock('../ui/commands/restoreCommand.js', () => ({
restoreCommand: vi.fn().mockReturnValue({
name: 'restore',
action: vi.fn(),
}),
}));
vi.mock('../ui/commands/initCommand.js', () => ({
initCommand: {
name: 'init',
action: vi.fn(),
},
}));
vi.mock(
'@google/gemini-cli-core',
async (
@@ -145,6 +173,7 @@ describe('GeminiAgent', () => {
}),
getApprovalMode: vi.fn().mockReturnValue('default'),
isPlanEnabled: vi.fn().mockReturnValue(false),
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
} as unknown as Mocked<Awaited<ReturnType<typeof loadCliConfig>>>;
mockSettings = {
merged: {
@@ -225,6 +254,7 @@ describe('GeminiAgent', () => {
});
it('should create a new session', async () => {
vi.useFakeTimers();
mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({
apiKey: 'test-key',
});
@@ -237,6 +267,17 @@ describe('GeminiAgent', () => {
expect(loadCliConfig).toHaveBeenCalled();
expect(mockConfig.initialize).toHaveBeenCalled();
expect(mockConfig.getGeminiClient).toHaveBeenCalled();
// Verify deferred call
await vi.runAllTimersAsync();
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
expect.objectContaining({
update: expect.objectContaining({
sessionUpdate: 'available_commands_update',
}),
}),
);
vi.useRealTimers();
});
it('should return modes without plan mode when plan is disabled', async () => {
@@ -477,6 +518,7 @@ describe('Session', () => {
getModel: vi.fn().mockReturnValue('gemini-pro'),
getActiveModel: vi.fn().mockReturnValue('gemini-pro'),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getMcpServers: vi.fn(),
getFileService: vi.fn().mockReturnValue({
shouldIgnoreFile: vi.fn().mockReturnValue(false),
}),
@@ -487,6 +529,8 @@ describe('Session', () => {
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
setApprovalMode: vi.fn(),
isPlanEnabled: vi.fn().mockReturnValue(false),
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
getGitService: vi.fn().mockResolvedValue({} as GitService),
waitForMcpInit: vi.fn(),
} as unknown as Mocked<Config>;
mockConnection = {
@@ -495,13 +539,38 @@ describe('Session', () => {
sendNotification: vi.fn(),
} as unknown as Mocked<acp.AgentSideConnection>;
session = new Session('session-1', mockChat, mockConfig, mockConnection);
session = new Session('session-1', mockChat, mockConfig, mockConnection, {
system: { settings: {} },
systemDefaults: { settings: {} },
user: { settings: {} },
workspace: { settings: {} },
merged: { settings: {} },
errors: [],
} as unknown as LoadedSettings);
});
afterEach(() => {
vi.clearAllMocks();
});
it('should send available commands', async () => {
await session.sendAvailableCommands();
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
expect.objectContaining({
update: expect.objectContaining({
sessionUpdate: 'available_commands_update',
availableCommands: expect.arrayContaining([
expect.objectContaining({ name: 'memory' }),
expect.objectContaining({ name: 'extensions' }),
expect.objectContaining({ name: 'restore' }),
expect.objectContaining({ name: 'init' }),
]),
}),
}),
);
});
it('should await MCP initialization before processing a prompt', async () => {
const stream = createMockStream([
{
@@ -551,6 +620,113 @@ describe('Session', () => {
expect(result).toEqual({ stopReason: 'end_turn' });
});
it('should handle /memory command', async () => {
const handleCommandSpy = vi
.spyOn(
(session as unknown as { commandHandler: CommandHandler })
.commandHandler,
'handleCommand',
)
.mockResolvedValue(true);
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: '/memory view' }],
});
expect(result).toEqual({ stopReason: 'end_turn' });
expect(handleCommandSpy).toHaveBeenCalledWith(
'/memory view',
expect.any(Object),
);
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
});
it('should handle /extensions command', async () => {
const handleCommandSpy = vi
.spyOn(
(session as unknown as { commandHandler: CommandHandler })
.commandHandler,
'handleCommand',
)
.mockResolvedValue(true);
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: '/extensions list' }],
});
expect(result).toEqual({ stopReason: 'end_turn' });
expect(handleCommandSpy).toHaveBeenCalledWith(
'/extensions list',
expect.any(Object),
);
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
});
it('should handle /extensions explore command', async () => {
const handleCommandSpy = vi
.spyOn(
(session as unknown as { commandHandler: CommandHandler })
.commandHandler,
'handleCommand',
)
.mockResolvedValue(true);
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: '/extensions explore' }],
});
expect(result).toEqual({ stopReason: 'end_turn' });
expect(handleCommandSpy).toHaveBeenCalledWith(
'/extensions explore',
expect.any(Object),
);
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
});
it('should handle /restore command', async () => {
const handleCommandSpy = vi
.spyOn(
(session as unknown as { commandHandler: CommandHandler })
.commandHandler,
'handleCommand',
)
.mockResolvedValue(true);
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: '/restore' }],
});
expect(result).toEqual({ stopReason: 'end_turn' });
expect(handleCommandSpy).toHaveBeenCalledWith(
'/restore',
expect.any(Object),
);
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
});
it('should handle /init command', async () => {
const handleCommandSpy = vi
.spyOn(
(session as unknown as { commandHandler: CommandHandler })
.commandHandler,
'handleCommand',
)
.mockResolvedValue(true);
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: '/init' }],
});
expect(result).toEqual({ stopReason: 'end_turn' });
expect(handleCommandSpy).toHaveBeenCalledWith('/init', expect.any(Object));
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
});
it('should handle tool calls', async () => {
const stream1 = createMockStream([
{
@@ -1207,4 +1383,24 @@ describe('Session', () => {
'Invalid or unavailable mode: invalid-mode',
);
});
it('should handle unquoted commands from autocomplete (with empty leading parts)', async () => {
// Mock handleCommand to verify it gets called
const handleCommandSpy = vi
.spyOn(
(session as unknown as { commandHandler: CommandHandler })
.commandHandler,
'handleCommand',
)
.mockResolvedValue(true);
await session.prompt({
sessionId: 'session-1',
prompt: [
{ type: 'text', text: '' },
{ type: 'text', text: '/memory' },
],
});
expect(handleCommandSpy).toHaveBeenCalledWith('/memory', expect.anything());
});
});

View File

@@ -4,15 +4,13 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type {
Config,
GeminiChat,
ToolResult,
ToolCallConfirmationDetails,
FilterFilesOptions,
ConversationRecord,
} from '@google/gemini-cli-core';
import {
type Config,
type GeminiChat,
type ToolResult,
type ToolCallConfirmationDetails,
type FilterFilesOptions,
type ConversationRecord,
CoreToolCallStatus,
AuthType,
logToolCall,
@@ -61,11 +59,14 @@ import { loadCliConfig } from '../config/config.js';
import { runExitCleanup } from '../utils/cleanup.js';
import { SessionSelector } from '../utils/sessionUtils.js';
import { CommandHandler } from './commandHandler.js';
export async function runZedIntegration(
config: Config,
settings: LoadedSettings,
argv: CliArgs,
) {
// ... (skip unchanged lines) ...
const { stdout: workingStdout } = createWorkingStdio();
const stdout = Writable.toWeb(workingStdout) as WritableStream;
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
@@ -240,9 +241,20 @@ export class GeminiAgent {
const geminiClient = config.getGeminiClient();
const chat = await geminiClient.startChat();
const session = new Session(sessionId, chat, config, this.connection);
const session = new Session(
sessionId,
chat,
config,
this.connection,
this.settings,
);
this.sessions.set(sessionId, session);
setTimeout(() => {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
session.sendAvailableCommands();
}, 0);
return {
sessionId,
modes: {
@@ -291,6 +303,7 @@ export class GeminiAgent {
geminiClient.getChat(),
config,
this.connection,
this.settings,
);
this.sessions.set(sessionId, session);
@@ -298,6 +311,11 @@ export class GeminiAgent {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
session.streamHistory(sessionData.messages);
setTimeout(() => {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
session.sendAvailableCommands();
}, 0);
return {
modes: {
availableModes: buildAvailableModes(config.isPlanEnabled()),
@@ -418,12 +436,14 @@ export class GeminiAgent {
export class Session {
private pendingPrompt: AbortController | null = null;
private commandHandler = new CommandHandler();
constructor(
private readonly id: string,
private readonly chat: GeminiChat,
private readonly config: Config,
private readonly connection: acp.AgentSideConnection,
private readonly settings: LoadedSettings,
) {}
async cancelPendingPrompt(): Promise<void> {
@@ -446,6 +466,22 @@ export class Session {
return {};
}
private getAvailableCommands() {
return this.commandHandler.getAvailableCommands();
}
async sendAvailableCommands(): Promise<void> {
const availableCommands = this.getAvailableCommands().map((command) => ({
name: command.name,
description: command.description,
}));
await this.sendUpdate({
sessionUpdate: 'available_commands_update',
availableCommands,
});
}
async streamHistory(messages: ConversationRecord['messages']): Promise<void> {
for (const msg of messages) {
const contentString = partListUnionToString(msg.content);
@@ -528,6 +564,41 @@ export class Session {
const parts = await this.#resolvePrompt(params.prompt, pendingSend.signal);
// Command interception
let commandText = '';
for (const part of parts) {
if (typeof part === 'object' && part !== null) {
if ('text' in part) {
// It is a text part
// eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-type-assertion
const text = (part as any).text;
if (typeof text === 'string') {
commandText += text;
}
} else {
// Non-text part (image, embedded resource)
// Stop looking for command
break;
}
}
}
commandText = commandText.trim();
if (
commandText &&
(commandText.startsWith('/') || commandText.startsWith('$'))
) {
// If we found a command, pass it to handleCommand
// Note: handleCommand currently expects `commandText` to be the command string
// It uses `parts` argument but effectively ignores it in current implementation
const handled = await this.handleCommand(commandText, parts);
if (handled) {
return { stopReason: 'end_turn' };
}
}
let nextMessage: Content | null = { role: 'user', parts };
while (nextMessage !== null) {
@@ -627,9 +698,28 @@ export class Session {
return { stopReason: 'end_turn' };
}
private async sendUpdate(
update: acp.SessionNotification['update'],
): Promise<void> {
private async handleCommand(
commandText: string,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
parts: Part[],
): Promise<boolean> {
const gitService = await this.config.getGitService();
const commandContext = {
config: this.config,
settings: this.settings,
git: gitService,
sendMessage: async (text: string) => {
await this.sendUpdate({
sessionUpdate: 'agent_message_chunk',
content: { type: 'text', text },
});
},
};
return this.commandHandler.handleCommand(commandText, commandContext);
}
private async sendUpdate(update: acp.SessionUpdate): Promise<void> {
const params: acp.SessionNotification = {
sessionId: this.id,
update,