diff --git a/packages/cli/src/zed-integration/acpResume.test.ts b/packages/cli/src/zed-integration/acpResume.test.ts index 54c04a0ff3..9addafd369 100644 --- a/packages/cli/src/zed-integration/acpResume.test.ts +++ b/packages/cli/src/zed-integration/acpResume.test.ts @@ -93,6 +93,7 @@ describe('GeminiAgent Session Resume', () => { }, getApprovalMode: vi.fn().mockReturnValue('default'), isPlanEnabled: vi.fn().mockReturnValue(false), + getCheckpointingEnabled: vi.fn().mockReturnValue(false), } as unknown as Mocked; mockSettings = { merged: { diff --git a/packages/cli/src/zed-integration/commandHandler.test.ts b/packages/cli/src/zed-integration/commandHandler.test.ts new file mode 100644 index 0000000000..8e04f014f3 --- /dev/null +++ b/packages/cli/src/zed-integration/commandHandler.test.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { CommandHandler } from './commandHandler.js'; +import { describe, it, expect } from 'vitest'; + +describe('CommandHandler', () => { + it('parses commands correctly', () => { + const handler = new CommandHandler(); + // @ts-expect-error - testing private method + const parse = (query: string) => handler.parseSlashCommand(query); + + const memShow = parse('/memory show'); + expect(memShow.commandToExecute?.name).toBe('memory show'); + expect(memShow.args).toBe(''); + + const memAdd = parse('/memory add hello world'); + expect(memAdd.commandToExecute?.name).toBe('memory add'); + expect(memAdd.args).toBe('hello world'); + + const extList = parse('/extensions list'); + expect(extList.commandToExecute?.name).toBe('extensions list'); + + const init = parse('/init'); + expect(init.commandToExecute?.name).toBe('init'); + }); +}); diff --git a/packages/cli/src/zed-integration/commandHandler.ts b/packages/cli/src/zed-integration/commandHandler.ts new file mode 100644 index 0000000000..836cdf7736 --- /dev/null +++ b/packages/cli/src/zed-integration/commandHandler.ts @@ -0,0 +1,134 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Command, CommandContext } from './commands/types.js'; +import { CommandRegistry } from './commands/commandRegistry.js'; +import { MemoryCommand } from './commands/memory.js'; +import { ExtensionsCommand } from './commands/extensions.js'; +import { InitCommand } from './commands/init.js'; +import { RestoreCommand } from './commands/restore.js'; + +export class CommandHandler { + private registry: CommandRegistry; + + constructor() { + this.registry = CommandHandler.createRegistry(); + } + + private static createRegistry(): CommandRegistry { + const registry = new CommandRegistry(); + registry.register(new MemoryCommand()); + registry.register(new ExtensionsCommand()); + registry.register(new InitCommand()); + registry.register(new RestoreCommand()); + return registry; + } + + getAvailableCommands(): Array<{ name: string; description: string }> { + return this.registry.getAllCommands().map((cmd) => ({ + name: cmd.name, + description: cmd.description, + })); + } + + /** + * Parses and executes a command string if it matches a registered command. + * Returns true if a command was handled, false otherwise. + */ + async handleCommand( + commandText: string, + context: CommandContext, + ): Promise { + const { commandToExecute, args } = this.parseSlashCommand(commandText); + + if (commandToExecute) { + await this.runCommand(commandToExecute, args, context); + return true; + } + + return false; + } + + private async runCommand( + commandToExecute: Command, + args: string, + context: CommandContext, + ): Promise { + try { + const result = await commandToExecute.execute( + context, + args ? args.split(/\s+/) : [], + ); + + let messageContent = ''; + if (typeof result.data === 'string') { + messageContent = result.data; + } else if ( + typeof result.data === 'object' && + result.data !== null && + 'content' in result.data + ) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any + messageContent = (result.data as Record)[ + 'content' + ] as string; + } else { + messageContent = JSON.stringify(result.data, null, 2); + } + + await context.sendMessage(messageContent); + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : String(error); + await context.sendMessage(`Error: ${errorMessage}`); + } + } + + /** + * Parses a raw slash command string into its matching headless command and arguments. + * Mirrors `packages/cli/src/utils/commands.ts` logic. + */ + private parseSlashCommand(query: string): { + commandToExecute: Command | undefined; + args: string; + } { + const trimmed = query.trim(); + const parts = trimmed.substring(1).trim().split(/\s+/); + const commandPath = parts.filter((p) => p); + + let currentCommands = this.registry.getAllCommands(); + let commandToExecute: Command | undefined; + let pathIndex = 0; + + for (const part of commandPath) { + const foundCommand = currentCommands.find((cmd) => { + const expectedName = commandPath.slice(0, pathIndex + 1).join(' '); + return ( + cmd.name === part || + cmd.name === expectedName || + cmd.aliases?.includes(part) || + cmd.aliases?.includes(expectedName) + ); + }); + + if (foundCommand) { + commandToExecute = foundCommand; + pathIndex++; + if (foundCommand.subCommands) { + currentCommands = foundCommand.subCommands; + } else { + break; + } + } else { + break; + } + } + + const args = parts.slice(pathIndex).join(' '); + + return { commandToExecute, args }; + } +} diff --git a/packages/cli/src/zed-integration/commands/commandRegistry.ts b/packages/cli/src/zed-integration/commands/commandRegistry.ts new file mode 100644 index 0000000000..b689d5d602 --- /dev/null +++ b/packages/cli/src/zed-integration/commands/commandRegistry.ts @@ -0,0 +1,33 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { debugLogger } from '@google/gemini-cli-core'; +import type { Command } from './types.js'; + +export class CommandRegistry { + private readonly commands = new Map(); + + register(command: Command) { + if (this.commands.has(command.name)) { + debugLogger.warn(`Command ${command.name} already registered. Skipping.`); + return; + } + + this.commands.set(command.name, command); + + for (const subCommand of command.subCommands ?? []) { + this.register(subCommand); + } + } + + get(commandName: string): Command | undefined { + return this.commands.get(commandName); + } + + getAllCommands(): Command[] { + return [...this.commands.values()]; + } +} diff --git a/packages/cli/src/zed-integration/commands/extensions.ts b/packages/cli/src/zed-integration/commands/extensions.ts new file mode 100644 index 0000000000..b9a3ad81ab --- /dev/null +++ b/packages/cli/src/zed-integration/commands/extensions.ts @@ -0,0 +1,428 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { listExtensions } from '@google/gemini-cli-core'; +import { SettingScope } from '../../config/settings.js'; +import { + ExtensionManager, + inferInstallMetadata, +} from '../../config/extension-manager.js'; +import { getErrorMessage } from '../../utils/errors.js'; +import { McpServerEnablementManager } from '../../config/mcp/mcpServerEnablement.js'; +import { stat } from 'node:fs/promises'; +import type { + Command, + CommandContext, + CommandExecutionResponse, +} from './types.js'; +import type { Config } from '@google/gemini-cli-core'; + +export class ExtensionsCommand implements Command { + readonly name = 'extensions'; + readonly description = 'Manage extensions.'; + readonly subCommands = [ + new ListExtensionsCommand(), + new ExploreExtensionsCommand(), + new EnableExtensionCommand(), + new DisableExtensionCommand(), + new InstallExtensionCommand(), + new LinkExtensionCommand(), + new UninstallExtensionCommand(), + new RestartExtensionCommand(), + new UpdateExtensionCommand(), + ]; + + async execute( + context: CommandContext, + _: string[], + ): Promise { + return new ListExtensionsCommand().execute(context, _); + } +} + +export class ListExtensionsCommand implements Command { + readonly name = 'extensions list'; + readonly description = 'Lists all installed extensions.'; + + async execute( + context: CommandContext, + _: string[], + ): Promise { + const extensions = listExtensions(context.config); + const data = extensions.length ? extensions : 'No extensions installed.'; + + return { name: this.name, data }; + } +} + +export class ExploreExtensionsCommand implements Command { + readonly name = 'extensions explore'; + readonly description = 'Explore available extensions.'; + + async execute( + _context: CommandContext, + _: string[], + ): Promise { + const extensionsUrl = 'https://geminicli.com/extensions/'; + return { + name: this.name, + data: `View or install available extensions at ${extensionsUrl}`, + }; + } +} + +function getEnableDisableContext( + config: Config, + args: string[], + invocationName: string, +) { + const extensionManager = config.getExtensionLoader(); + if (!(extensionManager instanceof ExtensionManager)) { + return { + error: `Cannot ${invocationName} extensions in this environment.`, + }; + } + + if (args.length === 0) { + return { + error: `Usage: /extensions ${invocationName} [--scope=]`, + }; + } + + let scope = SettingScope.User; + if (args.includes('--scope=workspace') || args.includes('workspace')) { + scope = SettingScope.Workspace; + } else if (args.includes('--scope=session') || args.includes('session')) { + scope = SettingScope.Session; + } + + const name = args.filter( + (a) => + !a.startsWith('--scope') && !['user', 'workspace', 'session'].includes(a), + )[0]; + + let names: string[] = []; + if (name === '--all') { + let extensions = extensionManager.getExtensions(); + if (invocationName === 'enable') { + extensions = extensions.filter((ext) => !ext.isActive); + } + if (invocationName === 'disable') { + extensions = extensions.filter((ext) => ext.isActive); + } + names = extensions.map((ext) => ext.name); + } else if (name) { + names = [name]; + } else { + return { error: 'No extension name provided.' }; + } + + return { extensionManager, names, scope }; +} + +export class EnableExtensionCommand implements Command { + readonly name = 'extensions enable'; + readonly description = 'Enable an extension.'; + + async execute( + context: CommandContext, + args: string[], + ): Promise { + const enableContext = getEnableDisableContext( + context.config, + args, + 'enable', + ); + if ('error' in enableContext) { + return { name: this.name, data: enableContext.error }; + } + + const { names, scope, extensionManager } = enableContext; + const output: string[] = []; + + for (const name of names) { + try { + await extensionManager.enableExtension(name, scope); + output.push(`Extension "${name}" enabled for scope "${scope}".`); + + const extension = extensionManager + .getExtensions() + .find((e) => e.name === name); + + if (extension?.mcpServers) { + const mcpEnablementManager = McpServerEnablementManager.getInstance(); + const mcpClientManager = context.config.getMcpClientManager(); + const enabledServers = await mcpEnablementManager.autoEnableServers( + Object.keys(extension.mcpServers), + ); + + if (mcpClientManager && enabledServers.length > 0) { + const restartPromises = enabledServers.map((serverName) => + mcpClientManager.restartServer(serverName).catch((error) => { + output.push( + `Failed to restart MCP server '${serverName}': ${getErrorMessage(error)}`, + ); + }), + ); + await Promise.all(restartPromises); + output.push(`Re-enabled MCP servers: ${enabledServers.join(', ')}`); + } + } + } catch (e) { + output.push(`Failed to enable "${name}": ${getErrorMessage(e)}`); + } + } + + return { name: this.name, data: output.join('\n') || 'No action taken.' }; + } +} + +export class DisableExtensionCommand implements Command { + readonly name = 'extensions disable'; + readonly description = 'Disable an extension.'; + + async execute( + context: CommandContext, + args: string[], + ): Promise { + const enableContext = getEnableDisableContext( + context.config, + args, + 'disable', + ); + if ('error' in enableContext) { + return { name: this.name, data: enableContext.error }; + } + + const { names, scope, extensionManager } = enableContext; + const output: string[] = []; + + for (const name of names) { + try { + await extensionManager.disableExtension(name, scope); + output.push(`Extension "${name}" disabled for scope "${scope}".`); + } catch (e) { + output.push(`Failed to disable "${name}": ${getErrorMessage(e)}`); + } + } + + return { name: this.name, data: output.join('\n') || 'No action taken.' }; + } +} + +export class InstallExtensionCommand implements Command { + readonly name = 'extensions install'; + readonly description = 'Install an extension from a git repo or local path.'; + + async execute( + context: CommandContext, + args: string[], + ): Promise { + const extensionLoader = context.config.getExtensionLoader(); + if (!(extensionLoader instanceof ExtensionManager)) { + return { + name: this.name, + data: 'Cannot install extensions in this environment.', + }; + } + + const source = args.join(' ').trim(); + if (!source) { + return { name: this.name, data: `Usage: /extensions install ` }; + } + + if (/[;&|`'"]/.test(source)) { + return { + name: this.name, + data: `Invalid source: contains disallowed characters.`, + }; + } + + try { + const installMetadata = await inferInstallMetadata(source); + const extension = + await extensionLoader.installOrUpdateExtension(installMetadata); + return { + name: this.name, + data: `Extension "${extension.name}" installed successfully.`, + }; + } catch (error) { + return { + name: this.name, + data: `Failed to install extension from "${source}": ${getErrorMessage(error)}`, + }; + } + } +} + +export class LinkExtensionCommand implements Command { + readonly name = 'extensions link'; + readonly description = 'Link an extension from a local path.'; + + async execute( + context: CommandContext, + args: string[], + ): Promise { + const extensionLoader = context.config.getExtensionLoader(); + if (!(extensionLoader instanceof ExtensionManager)) { + return { + name: this.name, + data: 'Cannot link extensions in this environment.', + }; + } + + const sourceFilepath = args.join(' ').trim(); + if (!sourceFilepath) { + return { name: this.name, data: `Usage: /extensions link ` }; + } + + try { + await stat(sourceFilepath); + } catch (_error) { + return { name: this.name, data: `Invalid source: ${sourceFilepath}` }; + } + + try { + const extension = await extensionLoader.installOrUpdateExtension({ + source: sourceFilepath, + type: 'link', + }); + return { + name: this.name, + data: `Extension "${extension.name}" linked successfully.`, + }; + } catch (error) { + return { + name: this.name, + data: `Failed to link extension: ${getErrorMessage(error)}`, + }; + } + } +} + +export class UninstallExtensionCommand implements Command { + readonly name = 'extensions uninstall'; + readonly description = 'Uninstall an extension.'; + + async execute( + context: CommandContext, + args: string[], + ): Promise { + const extensionLoader = context.config.getExtensionLoader(); + if (!(extensionLoader instanceof ExtensionManager)) { + return { + name: this.name, + data: 'Cannot uninstall extensions in this environment.', + }; + } + + const name = args.join(' ').trim(); + if (!name) { + return { + name: this.name, + data: `Usage: /extensions uninstall `, + }; + } + + try { + await extensionLoader.uninstallExtension(name, false); + return { + name: this.name, + data: `Extension "${name}" uninstalled successfully.`, + }; + } catch (error) { + return { + name: this.name, + data: `Failed to uninstall extension "${name}": ${getErrorMessage(error)}`, + }; + } + } +} + +export class RestartExtensionCommand implements Command { + readonly name = 'extensions restart'; + readonly description = 'Restart an extension.'; + + async execute( + context: CommandContext, + args: string[], + ): Promise { + const extensionLoader = context.config.getExtensionLoader(); + if (!(extensionLoader instanceof ExtensionManager)) { + return { name: this.name, data: 'Cannot restart extensions.' }; + } + + const all = args.includes('--all'); + const names = all ? null : args.filter((a) => !!a); + + if (!all && names?.length === 0) { + return { + name: this.name, + data: 'Usage: /extensions restart |--all', + }; + } + + let extensionsToRestart = extensionLoader + .getExtensions() + .filter((e) => e.isActive); + if (names) { + extensionsToRestart = extensionsToRestart.filter((e) => + names.includes(e.name), + ); + } + + if (extensionsToRestart.length === 0) { + return { + name: this.name, + data: 'No active extensions matched the request.', + }; + } + + const output: string[] = []; + for (const extension of extensionsToRestart) { + try { + await extensionLoader.restartExtension(extension); + output.push(`Restarted "${extension.name}".`); + } catch (e) { + output.push( + `Failed to restart "${extension.name}": ${getErrorMessage(e)}`, + ); + } + } + + return { name: this.name, data: output.join('\n') }; + } +} + +export class UpdateExtensionCommand implements Command { + readonly name = 'extensions update'; + readonly description = 'Update an extension.'; + + async execute( + context: CommandContext, + args: string[], + ): Promise { + const extensionLoader = context.config.getExtensionLoader(); + if (!(extensionLoader instanceof ExtensionManager)) { + return { name: this.name, data: 'Cannot update extensions.' }; + } + + const all = args.includes('--all'); + const names = all ? null : args.filter((a) => !!a); + + if (!all && names?.length === 0) { + return { + name: this.name, + data: 'Usage: /extensions update |--all', + }; + } + + return { + name: this.name, + data: 'Headless extension updating requires internal UI dispatches. Please use `gemini extensions update` directly in the terminal.', + }; + } +} diff --git a/packages/cli/src/zed-integration/commands/init.ts b/packages/cli/src/zed-integration/commands/init.ts new file mode 100644 index 0000000000..5c4197f84c --- /dev/null +++ b/packages/cli/src/zed-integration/commands/init.ts @@ -0,0 +1,62 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as fs from 'node:fs'; +import * as path from 'node:path'; +import { performInit } from '@google/gemini-cli-core'; +import type { + Command, + CommandContext, + CommandExecutionResponse, +} from './types.js'; + +export class InitCommand implements Command { + name = 'init'; + description = 'Analyzes the project and creates a tailored GEMINI.md file'; + requiresWorkspace = true; + + async execute( + context: CommandContext, + _args: string[] = [], + ): Promise { + const targetDir = context.config.getTargetDir(); + if (!targetDir) { + throw new Error('Command requires a workspace.'); + } + + const geminiMdPath = path.join(targetDir, 'GEMINI.md'); + const result = performInit(fs.existsSync(geminiMdPath)); + + switch (result.type) { + case 'message': + return { + name: this.name, + data: result, + }; + case 'submit_prompt': + fs.writeFileSync(geminiMdPath, '', 'utf8'); + + if (typeof result.content !== 'string') { + throw new Error('Init command content must be a string.'); + } + + // Inform the user since we can't trigger the UI-based interactive agent loop here directly. + // We output the prompt text they can use to re-trigger the generation manually, + // or just seed the GEMINI.md file as we've done above. + return { + name: this.name, + data: { + type: 'message', + messageType: 'info', + content: `A template GEMINI.md has been created at ${geminiMdPath}.\n\nTo populate it with project context, you can run the following prompt in a new chat:\n\n${result.content}`, + }, + }; + + default: + throw new Error('Unknown result type from performInit'); + } + } +} diff --git a/packages/cli/src/zed-integration/commands/memory.ts b/packages/cli/src/zed-integration/commands/memory.ts new file mode 100644 index 0000000000..9460af7ad1 --- /dev/null +++ b/packages/cli/src/zed-integration/commands/memory.ts @@ -0,0 +1,121 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + addMemory, + listMemoryFiles, + refreshMemory, + showMemory, +} from '@google/gemini-cli-core'; +import type { + Command, + CommandContext, + CommandExecutionResponse, +} from './types.js'; + +const DEFAULT_SANITIZATION_CONFIG = { + allowedEnvironmentVariables: [], + blockedEnvironmentVariables: [], + enableEnvironmentVariableRedaction: false, +}; + +export class MemoryCommand implements Command { + readonly name = 'memory'; + readonly description = 'Manage memory.'; + readonly subCommands = [ + new ShowMemoryCommand(), + new RefreshMemoryCommand(), + new ListMemoryCommand(), + new AddMemoryCommand(), + ]; + readonly requiresWorkspace = true; + + async execute( + context: CommandContext, + _: string[], + ): Promise { + return new ShowMemoryCommand().execute(context, _); + } +} + +export class ShowMemoryCommand implements Command { + readonly name = 'memory show'; + readonly description = 'Shows the current memory contents.'; + + async execute( + context: CommandContext, + _: string[], + ): Promise { + const result = showMemory(context.config); + return { name: this.name, data: result.content }; + } +} + +export class RefreshMemoryCommand implements Command { + readonly name = 'memory refresh'; + readonly aliases = ['memory reload']; + readonly description = 'Refreshes the memory from the source.'; + + async execute( + context: CommandContext, + _: string[], + ): Promise { + const result = await refreshMemory(context.config); + return { name: this.name, data: result.content }; + } +} + +export class ListMemoryCommand implements Command { + readonly name = 'memory list'; + readonly description = 'Lists the paths of the GEMINI.md files in use.'; + + async execute( + context: CommandContext, + _: string[], + ): Promise { + const result = listMemoryFiles(context.config); + return { name: this.name, data: result.content }; + } +} + +export class AddMemoryCommand implements Command { + readonly name = 'memory add'; + readonly description = 'Add content to the memory.'; + + async execute( + context: CommandContext, + args: string[], + ): Promise { + const textToAdd = args.join(' ').trim(); + const result = addMemory(textToAdd); + if (result.type === 'message') { + return { name: this.name, data: result.content }; + } + + const toolRegistry = context.config.getToolRegistry(); + const tool = toolRegistry.getTool(result.toolName); + if (tool) { + const abortController = new AbortController(); + const signal = abortController.signal; + + await context.sendMessage(`Saving memory via ${result.toolName}...`); + + await tool.buildAndExecute(result.toolArgs, signal, undefined, { + sanitizationConfig: DEFAULT_SANITIZATION_CONFIG, + }); + await refreshMemory(context.config); + return { + name: this.name, + data: `Added memory: "${textToAdd}"`, + }; + } else { + return { + name: this.name, + data: `Error: Tool ${result.toolName} not found.`, + }; + } + } +} diff --git a/packages/cli/src/zed-integration/commands/restore.ts b/packages/cli/src/zed-integration/commands/restore.ts new file mode 100644 index 0000000000..ec9166ed84 --- /dev/null +++ b/packages/cli/src/zed-integration/commands/restore.ts @@ -0,0 +1,178 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + getCheckpointInfoList, + getToolCallDataSchema, + isNodeError, + performRestore, +} from '@google/gemini-cli-core'; +import * as fs from 'node:fs/promises'; +import * as path from 'node:path'; +import type { + Command, + CommandContext, + CommandExecutionResponse, +} from './types.js'; + +export class RestoreCommand implements Command { + readonly name = 'restore'; + readonly description = + 'Restore to a previous checkpoint, or list available checkpoints to restore. This will reset the conversation and file history to the state it was in when the checkpoint was created'; + readonly requiresWorkspace = true; + readonly subCommands = [new ListCheckpointsCommand()]; + + async execute( + context: CommandContext, + args: string[], + ): Promise { + const { config, git: gitService } = context; + const argsStr = args.join(' '); + + try { + if (!argsStr) { + return await new ListCheckpointsCommand().execute(context); + } + + if (!config.getCheckpointingEnabled()) { + return { + name: this.name, + data: 'Checkpointing is not enabled. Please enable it in your settings (`general.checkpointing.enabled: true`) to use /restore.', + }; + } + + const selectedFile = argsStr.endsWith('.json') + ? argsStr + : `${argsStr}.json`; + + const checkpointDir = config.storage.getProjectTempCheckpointsDir(); + const filePath = path.join(checkpointDir, selectedFile); + + let data: string; + try { + data = await fs.readFile(filePath, 'utf-8'); + } catch (error) { + if (isNodeError(error) && error.code === 'ENOENT') { + return { + name: this.name, + data: `File not found: ${selectedFile}`, + }; + } + throw error; + } + + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + const toolCallData = JSON.parse(data); + const ToolCallDataSchema = getToolCallDataSchema(); + const parseResult = ToolCallDataSchema.safeParse(toolCallData); + + if (!parseResult.success) { + return { + name: this.name, + data: 'Checkpoint file is invalid or corrupted.', + }; + } + + const restoreResultGenerator = performRestore( + parseResult.data, + gitService, + ); + + const restoreResult = []; + for await (const result of restoreResultGenerator) { + restoreResult.push(result); + } + + // Format the result nicely since Zed just dumps data + const formattedResult = restoreResult + .map((r) => { + if (r.type === 'message') { + return `[${r.messageType.toUpperCase()}] ${r.content}`; + } else if (r.type === 'load_history') { + return `Loaded history with ${r.clientHistory.length} messages.`; + } + return `Restored: ${JSON.stringify(r)}`; + }) + .join('\n'); + + return { + name: this.name, + data: formattedResult, + }; + } catch (error) { + return { + name: this.name, + data: `An unexpected error occurred during restore: ${error}`, + }; + } + } +} + +export class ListCheckpointsCommand implements Command { + readonly name = 'restore list'; + readonly description = 'Lists all available checkpoints.'; + + async execute(context: CommandContext): Promise { + const { config } = context; + + try { + if (!config.getCheckpointingEnabled()) { + return { + name: this.name, + data: 'Checkpointing is not enabled. Please enable it in your settings (`general.checkpointing.enabled: true`) to use /restore.', + }; + } + + const checkpointDir = config.storage.getProjectTempCheckpointsDir(); + try { + await fs.mkdir(checkpointDir, { recursive: true }); + } catch (_e) { + // Ignore + } + + const files = await fs.readdir(checkpointDir); + const jsonFiles = files.filter((file) => file.endsWith('.json')); + + if (jsonFiles.length === 0) { + return { name: this.name, data: 'No checkpoints found.' }; + } + + const checkpointFiles = new Map(); + for (const file of jsonFiles) { + const filePath = path.join(checkpointDir, file); + const data = await fs.readFile(filePath, 'utf-8'); + checkpointFiles.set(file, data); + } + + const checkpointInfoList = getCheckpointInfoList(checkpointFiles); + + const formatted = checkpointInfoList + .map((info) => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const i = info as Record; + const fileName = String(i['fileName'] || 'Unknown'); + const toolName = String(i['toolName'] || 'Unknown'); + const status = String(i['status'] || 'Unknown'); + const timestamp = new Date( + Number(i['timestamp']) || 0, + ).toLocaleString(); + + return `- **${fileName}**: ${toolName} (Status: ${status}) [${timestamp}]`; + }) + .join('\n'); + + return { + name: this.name, + data: `Available Checkpoints:\n${formatted}`, + }; + } catch (_error) { + return { + name: this.name, + data: 'An unexpected error occurred while listing checkpoints.', + }; + } + } +} diff --git a/packages/cli/src/zed-integration/commands/types.ts b/packages/cli/src/zed-integration/commands/types.ts new file mode 100644 index 0000000000..099f0c923f --- /dev/null +++ b/packages/cli/src/zed-integration/commands/types.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config, GitService } from '@google/gemini-cli-core'; +import type { LoadedSettings } from '../../config/settings.js'; + +export interface CommandContext { + config: Config; + settings: LoadedSettings; + git?: GitService; + sendMessage: (text: string) => Promise; +} + +export interface CommandArgument { + readonly name: string; + readonly description: string; + readonly isRequired?: boolean; +} + +export interface Command { + readonly name: string; + readonly aliases?: string[]; + readonly description: string; + readonly arguments?: CommandArgument[]; + readonly subCommands?: Command[]; + readonly requiresWorkspace?: boolean; + + execute( + context: CommandContext, + args: string[], + ): Promise; +} + +export interface CommandExecutionResponse { + readonly name: string; + readonly data: unknown; +} diff --git a/packages/cli/src/zed-integration/zedIntegration.test.ts b/packages/cli/src/zed-integration/zedIntegration.test.ts index e8e5355dc0..23ba8b8ab8 100644 --- a/packages/cli/src/zed-integration/zedIntegration.test.ts +++ b/packages/cli/src/zed-integration/zedIntegration.test.ts @@ -15,6 +15,7 @@ import { type Mocked, } from 'vitest'; import { GeminiAgent, Session } from './zedIntegration.js'; +import type { CommandHandler } from './commandHandler.js'; import * as acp from '@agentclientprotocol/sdk'; import { AuthType, @@ -26,6 +27,7 @@ import { type Config, type MessageBus, LlmRole, + type GitService, } from '@google/gemini-cli-core'; import { SettingScope, @@ -62,7 +64,33 @@ vi.mock('node:path', async (importOriginal) => { }; }); -// Mock ReadManyFilesTool +vi.mock('../ui/commands/memoryCommand.js', () => ({ + memoryCommand: { + name: 'memory', + action: vi.fn(), + }, +})); + +vi.mock('../ui/commands/extensionsCommand.js', () => ({ + extensionsCommand: vi.fn().mockReturnValue({ + name: 'extensions', + action: vi.fn(), + }), +})); + +vi.mock('../ui/commands/restoreCommand.js', () => ({ + restoreCommand: vi.fn().mockReturnValue({ + name: 'restore', + action: vi.fn(), + }), +})); + +vi.mock('../ui/commands/initCommand.js', () => ({ + initCommand: { + name: 'init', + action: vi.fn(), + }, +})); vi.mock( '@google/gemini-cli-core', async ( @@ -145,6 +173,7 @@ describe('GeminiAgent', () => { }), getApprovalMode: vi.fn().mockReturnValue('default'), isPlanEnabled: vi.fn().mockReturnValue(false), + getCheckpointingEnabled: vi.fn().mockReturnValue(false), } as unknown as Mocked>>; mockSettings = { merged: { @@ -225,6 +254,7 @@ describe('GeminiAgent', () => { }); it('should create a new session', async () => { + vi.useFakeTimers(); mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ apiKey: 'test-key', }); @@ -237,6 +267,17 @@ describe('GeminiAgent', () => { expect(loadCliConfig).toHaveBeenCalled(); expect(mockConfig.initialize).toHaveBeenCalled(); expect(mockConfig.getGeminiClient).toHaveBeenCalled(); + + // Verify deferred call + await vi.runAllTimersAsync(); + expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'available_commands_update', + }), + }), + ); + vi.useRealTimers(); }); it('should return modes without plan mode when plan is disabled', async () => { @@ -477,6 +518,7 @@ describe('Session', () => { getModel: vi.fn().mockReturnValue('gemini-pro'), getActiveModel: vi.fn().mockReturnValue('gemini-pro'), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getMcpServers: vi.fn(), getFileService: vi.fn().mockReturnValue({ shouldIgnoreFile: vi.fn().mockReturnValue(false), }), @@ -487,6 +529,8 @@ describe('Session', () => { getMessageBus: vi.fn().mockReturnValue(mockMessageBus), setApprovalMode: vi.fn(), isPlanEnabled: vi.fn().mockReturnValue(false), + getCheckpointingEnabled: vi.fn().mockReturnValue(false), + getGitService: vi.fn().mockResolvedValue({} as GitService), waitForMcpInit: vi.fn(), } as unknown as Mocked; mockConnection = { @@ -495,13 +539,38 @@ describe('Session', () => { sendNotification: vi.fn(), } as unknown as Mocked; - session = new Session('session-1', mockChat, mockConfig, mockConnection); + session = new Session('session-1', mockChat, mockConfig, mockConnection, { + system: { settings: {} }, + systemDefaults: { settings: {} }, + user: { settings: {} }, + workspace: { settings: {} }, + merged: { settings: {} }, + errors: [], + } as unknown as LoadedSettings); }); afterEach(() => { vi.clearAllMocks(); }); + it('should send available commands', async () => { + await session.sendAvailableCommands(); + + expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'available_commands_update', + availableCommands: expect.arrayContaining([ + expect.objectContaining({ name: 'memory' }), + expect.objectContaining({ name: 'extensions' }), + expect.objectContaining({ name: 'restore' }), + expect.objectContaining({ name: 'init' }), + ]), + }), + }), + ); + }); + it('should await MCP initialization before processing a prompt', async () => { const stream = createMockStream([ { @@ -551,6 +620,113 @@ describe('Session', () => { expect(result).toEqual({ stopReason: 'end_turn' }); }); + it('should handle /memory command', async () => { + const handleCommandSpy = vi + .spyOn( + (session as unknown as { commandHandler: CommandHandler }) + .commandHandler, + 'handleCommand', + ) + .mockResolvedValue(true); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: '/memory view' }], + }); + + expect(result).toEqual({ stopReason: 'end_turn' }); + expect(handleCommandSpy).toHaveBeenCalledWith( + '/memory view', + expect.any(Object), + ); + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + }); + + it('should handle /extensions command', async () => { + const handleCommandSpy = vi + .spyOn( + (session as unknown as { commandHandler: CommandHandler }) + .commandHandler, + 'handleCommand', + ) + .mockResolvedValue(true); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: '/extensions list' }], + }); + + expect(result).toEqual({ stopReason: 'end_turn' }); + expect(handleCommandSpy).toHaveBeenCalledWith( + '/extensions list', + expect.any(Object), + ); + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + }); + + it('should handle /extensions explore command', async () => { + const handleCommandSpy = vi + .spyOn( + (session as unknown as { commandHandler: CommandHandler }) + .commandHandler, + 'handleCommand', + ) + .mockResolvedValue(true); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: '/extensions explore' }], + }); + + expect(result).toEqual({ stopReason: 'end_turn' }); + expect(handleCommandSpy).toHaveBeenCalledWith( + '/extensions explore', + expect.any(Object), + ); + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + }); + + it('should handle /restore command', async () => { + const handleCommandSpy = vi + .spyOn( + (session as unknown as { commandHandler: CommandHandler }) + .commandHandler, + 'handleCommand', + ) + .mockResolvedValue(true); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: '/restore' }], + }); + + expect(result).toEqual({ stopReason: 'end_turn' }); + expect(handleCommandSpy).toHaveBeenCalledWith( + '/restore', + expect.any(Object), + ); + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + }); + + it('should handle /init command', async () => { + const handleCommandSpy = vi + .spyOn( + (session as unknown as { commandHandler: CommandHandler }) + .commandHandler, + 'handleCommand', + ) + .mockResolvedValue(true); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: '/init' }], + }); + + expect(result).toEqual({ stopReason: 'end_turn' }); + expect(handleCommandSpy).toHaveBeenCalledWith('/init', expect.any(Object)); + expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); + }); + it('should handle tool calls', async () => { const stream1 = createMockStream([ { @@ -1207,4 +1383,24 @@ describe('Session', () => { 'Invalid or unavailable mode: invalid-mode', ); }); + it('should handle unquoted commands from autocomplete (with empty leading parts)', async () => { + // Mock handleCommand to verify it gets called + const handleCommandSpy = vi + .spyOn( + (session as unknown as { commandHandler: CommandHandler }) + .commandHandler, + 'handleCommand', + ) + .mockResolvedValue(true); + + await session.prompt({ + sessionId: 'session-1', + prompt: [ + { type: 'text', text: '' }, + { type: 'text', text: '/memory' }, + ], + }); + + expect(handleCommandSpy).toHaveBeenCalledWith('/memory', expect.anything()); + }); }); diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 98c9efdc75..30bf8551f0 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -4,15 +4,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { - Config, - GeminiChat, - ToolResult, - ToolCallConfirmationDetails, - FilterFilesOptions, - ConversationRecord, -} from '@google/gemini-cli-core'; import { + type Config, + type GeminiChat, + type ToolResult, + type ToolCallConfirmationDetails, + type FilterFilesOptions, + type ConversationRecord, CoreToolCallStatus, AuthType, logToolCall, @@ -61,11 +59,14 @@ import { loadCliConfig } from '../config/config.js'; import { runExitCleanup } from '../utils/cleanup.js'; import { SessionSelector } from '../utils/sessionUtils.js'; +import { CommandHandler } from './commandHandler.js'; export async function runZedIntegration( config: Config, settings: LoadedSettings, argv: CliArgs, ) { + // ... (skip unchanged lines) ... + const { stdout: workingStdout } = createWorkingStdio(); const stdout = Writable.toWeb(workingStdout) as WritableStream; // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion @@ -240,9 +241,20 @@ export class GeminiAgent { const geminiClient = config.getGeminiClient(); const chat = await geminiClient.startChat(); - const session = new Session(sessionId, chat, config, this.connection); + const session = new Session( + sessionId, + chat, + config, + this.connection, + this.settings, + ); this.sessions.set(sessionId, session); + setTimeout(() => { + // eslint-disable-next-line @typescript-eslint/no-floating-promises + session.sendAvailableCommands(); + }, 0); + return { sessionId, modes: { @@ -291,6 +303,7 @@ export class GeminiAgent { geminiClient.getChat(), config, this.connection, + this.settings, ); this.sessions.set(sessionId, session); @@ -298,6 +311,11 @@ export class GeminiAgent { // eslint-disable-next-line @typescript-eslint/no-floating-promises session.streamHistory(sessionData.messages); + setTimeout(() => { + // eslint-disable-next-line @typescript-eslint/no-floating-promises + session.sendAvailableCommands(); + }, 0); + return { modes: { availableModes: buildAvailableModes(config.isPlanEnabled()), @@ -418,12 +436,14 @@ export class GeminiAgent { export class Session { private pendingPrompt: AbortController | null = null; + private commandHandler = new CommandHandler(); constructor( private readonly id: string, private readonly chat: GeminiChat, private readonly config: Config, private readonly connection: acp.AgentSideConnection, + private readonly settings: LoadedSettings, ) {} async cancelPendingPrompt(): Promise { @@ -446,6 +466,22 @@ export class Session { return {}; } + private getAvailableCommands() { + return this.commandHandler.getAvailableCommands(); + } + + async sendAvailableCommands(): Promise { + const availableCommands = this.getAvailableCommands().map((command) => ({ + name: command.name, + description: command.description, + })); + + await this.sendUpdate({ + sessionUpdate: 'available_commands_update', + availableCommands, + }); + } + async streamHistory(messages: ConversationRecord['messages']): Promise { for (const msg of messages) { const contentString = partListUnionToString(msg.content); @@ -528,6 +564,41 @@ export class Session { const parts = await this.#resolvePrompt(params.prompt, pendingSend.signal); + // Command interception + let commandText = ''; + + for (const part of parts) { + if (typeof part === 'object' && part !== null) { + if ('text' in part) { + // It is a text part + // eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-type-assertion + const text = (part as any).text; + if (typeof text === 'string') { + commandText += text; + } + } else { + // Non-text part (image, embedded resource) + // Stop looking for command + break; + } + } + } + + commandText = commandText.trim(); + + if ( + commandText && + (commandText.startsWith('/') || commandText.startsWith('$')) + ) { + // If we found a command, pass it to handleCommand + // Note: handleCommand currently expects `commandText` to be the command string + // It uses `parts` argument but effectively ignores it in current implementation + const handled = await this.handleCommand(commandText, parts); + if (handled) { + return { stopReason: 'end_turn' }; + } + } + let nextMessage: Content | null = { role: 'user', parts }; while (nextMessage !== null) { @@ -627,9 +698,28 @@ export class Session { return { stopReason: 'end_turn' }; } - private async sendUpdate( - update: acp.SessionNotification['update'], - ): Promise { + private async handleCommand( + commandText: string, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + parts: Part[], + ): Promise { + const gitService = await this.config.getGitService(); + const commandContext = { + config: this.config, + settings: this.settings, + git: gitService, + sendMessage: async (text: string) => { + await this.sendUpdate({ + sessionUpdate: 'agent_message_chunk', + content: { type: 'text', text }, + }); + }, + }; + + return this.commandHandler.handleCommand(commandText, commandContext); + } + + private async sendUpdate(update: acp.SessionUpdate): Promise { const params: acp.SessionNotification = { sessionId: this.id, update,