feat: Implement direct command handling for /memory,/init,/extensions and /restore,and also send available commands to the client.

This commit is contained in:
Sri Pasumarthi
2026-02-23 12:49:24 -08:00
parent 74e15c3fab
commit 71607574eb
2 changed files with 253 additions and 227 deletions
@@ -26,7 +26,7 @@ import {
type Config,
type MessageBus,
LlmRole,
type MCPServerConfig,
type GitService,
} from '@google/gemini-cli-core';
import {
SettingScope,
@@ -63,7 +63,33 @@ vi.mock('node:path', async (importOriginal) => {
};
});
// Mock ReadManyFilesTool
vi.mock('../ui/commands/memoryCommand.js', () => ({
memoryCommand: {
name: 'memory',
action: vi.fn(),
},
}));
vi.mock('../ui/commands/extensionsCommand.js', () => ({
extensionsCommand: vi.fn().mockReturnValue({
name: 'extensions',
action: vi.fn(),
}),
}));
vi.mock('../ui/commands/restoreCommand.js', () => ({
restoreCommand: vi.fn().mockReturnValue({
name: 'restore',
action: vi.fn(),
}),
}));
vi.mock('../ui/commands/initCommand.js', () => ({
initCommand: {
name: 'init',
action: vi.fn(),
},
}));
vi.mock(
'@google/gemini-cli-core',
async (
@@ -197,6 +223,7 @@ describe('GeminiAgent', () => {
});
it('should create a new session', async () => {
vi.useFakeTimers();
mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({
apiKey: 'test-key',
});
@@ -209,6 +236,17 @@ describe('GeminiAgent', () => {
expect(loadCliConfig).toHaveBeenCalled();
expect(mockConfig.initialize).toHaveBeenCalled();
expect(mockConfig.getGeminiClient).toHaveBeenCalled();
// Verify deferred call
await vi.runAllTimersAsync();
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
expect.objectContaining({
update: expect.objectContaining({
sessionUpdate: 'available_commands_update',
}),
}),
);
vi.useRealTimers();
});
it('should return modes without plan mode when plan is disabled', async () => {
@@ -450,6 +488,7 @@ describe('Session', () => {
getActiveModel: vi.fn().mockReturnValue('gemini-pro'),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getMcpServers: vi.fn(),
getFileService: vi.fn().mockReturnValue({
shouldIgnoreFile: vi.fn().mockReturnValue(false),
}),
@@ -460,6 +499,7 @@ describe('Session', () => {
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
setApprovalMode: vi.fn(),
isPlanEnabled: vi.fn().mockReturnValue(false),
getGitService: vi.fn().mockResolvedValue({} as GitService),
} as unknown as Mocked<Config>;
mockConnection = {
sessionUpdate: vi.fn(),
@@ -467,7 +507,14 @@ describe('Session', () => {
sendNotification: vi.fn(),
} as unknown as Mocked<acp.AgentSideConnection>;
session = new Session('session-1', mockChat, mockConfig, mockConnection);
session = new Session('session-1', mockChat, mockConfig, mockConnection, {
system: { settings: {} },
systemDefaults: { settings: {} },
user: { settings: {} },
workspace: { settings: {} },
merged: { settings: {} },
errors: [],
} as unknown as LoadedSettings);
});
afterEach(() => {
@@ -482,10 +529,10 @@ describe('Session', () => {
update: expect.objectContaining({
sessionUpdate: 'available_commands_update',
availableCommands: expect.arrayContaining([
expect.objectContaining({ name: 'status' }),
expect.objectContaining({ name: 'mcp' }),
expect.objectContaining({ name: '$commit' }),
expect.objectContaining({ name: '$review-pr' }),
expect.objectContaining({ name: 'memory' }),
expect.objectContaining({ name: 'extensions' }),
expect.objectContaining({ name: 'restore' }),
expect.objectContaining({ name: 'init' }),
]),
}),
}),
@@ -519,161 +566,59 @@ describe('Session', () => {
expect(result).toEqual({ stopReason: 'end_turn' });
});
it('should handle /status command directly with newlines', async () => {
mockConfig.getActiveModel.mockReturnValue('gemini-1.5-pro-test');
it('should handle /memory command', async () => {
const { memoryCommand } = await import('../ui/commands/memoryCommand.js');
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: '/status\nTell me more' }],
prompt: [{ type: 'text', text: '/memory view' }],
});
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
expect.objectContaining({
update: expect.objectContaining({
sessionUpdate: 'agent_message_chunk',
}),
}),
);
expect(result).toEqual({ stopReason: 'end_turn' });
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
});
it('should handle /status command directly', async () => {
mockConfig.getActiveModel.mockReturnValue('gemini-1.5-pro-test');
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: '/status' }],
});
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
expect.objectContaining({
update: expect.objectContaining({
sessionUpdate: 'agent_message_chunk',
content: expect.objectContaining({
type: 'text',
text: expect.stringContaining('gemini-1.5-pro-test'),
}),
}),
}),
);
expect(result).toEqual({ stopReason: 'end_turn' });
// Chat should not be called
expect(memoryCommand.action).toHaveBeenCalled();
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
});
it('should handle /mcp command directly', async () => {
mockConfig.getMcpServers.mockReturnValue({
'test-mcp': {},
} as Record<string, MCPServerConfig>);
it('should handle /extensions command', async () => {
const { extensionsCommand } = await import(
'../ui/commands/extensionsCommand.js'
);
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: '/mcp' }],
prompt: [{ type: 'text', text: '/extensions list' }],
});
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
expect.objectContaining({
update: expect.objectContaining({
sessionUpdate: 'agent_message_chunk',
content: expect.objectContaining({
type: 'text',
text: expect.stringContaining('`test-mcp`'),
}),
}),
}),
);
expect(result).toEqual({ stopReason: 'end_turn' });
// extensionsCommand is a factory function, we mocked it to return an object with action
const cmd = extensionsCommand();
expect(cmd.action).toHaveBeenCalled();
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
});
it('should intercept $commit command and mutate prompt', async () => {
const stream = createMockStream([
{
type: StreamEventType.CHUNK,
value: {
candidates: [{ content: { parts: [{ text: 'Committing...' }] } }],
},
},
]);
mockChat.sendMessageStream.mockResolvedValue(stream);
await session.prompt({
it('should handle /restore command', async () => {
const { restoreCommand } = await import('../ui/commands/restoreCommand.js');
const result = await session.prompt({
sessionId: 'session-1',
// Should replace `$commit` with the instruction
prompt: [{ type: 'text', text: '$commit my cool changes' }],
prompt: [{ type: 'text', text: '/restore' }],
});
expect(mockChat.sendMessageStream).toHaveBeenCalledWith(
expect.anything(),
// The prompt text should be modified to include the commit instruction
expect.arrayContaining([
expect.objectContaining({
text: 'Create a git commit based on the current changes using the tools available. my cool changes',
}),
]),
expect.anything(),
expect.any(AbortSignal),
LlmRole.MAIN,
);
expect(result).toEqual({ stopReason: 'end_turn' });
// restoreCommand is a factory function
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const cmd = (restoreCommand as any)();
expect(cmd.action).toHaveBeenCalled();
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
});
it('should intercept $commit command with leading spaces and case insensitivity', async () => {
const stream = createMockStream([
{
type: StreamEventType.CHUNK,
value: {
candidates: [{ content: { parts: [{ text: 'Committing...' }] } }],
},
},
]);
mockChat.sendMessageStream.mockResolvedValue(stream);
await session.prompt({
it('should handle /init command', async () => {
const { initCommand } = await import('../ui/commands/initCommand.js');
const result = await session.prompt({
sessionId: 'session-1',
// Should replace `$commit` with the instruction
prompt: [{ type: 'text', text: ' \n$cOmMiT my cool changes' }],
prompt: [{ type: 'text', text: '/init' }],
});
expect(mockChat.sendMessageStream).toHaveBeenCalledWith(
expect.anything(),
// The prompt text should be modified to include the commit instruction
expect.arrayContaining([
expect.objectContaining({
text: 'Create a git commit based on the current changes using the tools available. my cool changes',
}),
]),
expect.anything(),
expect.any(AbortSignal),
LlmRole.MAIN,
);
});
it('should intercept $review-pr command and mutate prompt', async () => {
const stream = createMockStream([
{
type: StreamEventType.CHUNK,
value: {
candidates: [{ content: { parts: [{ text: 'Reviewing...' }] } }],
},
},
]);
mockChat.sendMessageStream.mockResolvedValue(stream);
await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: '$review-pr' }],
});
expect(mockChat.sendMessageStream).toHaveBeenCalledWith(
expect.anything(),
expect.arrayContaining([
expect.objectContaining({
text: 'Review the current pull request using the tools available.',
}),
]),
expect.anything(),
expect.any(AbortSignal),
LlmRole.MAIN,
);
expect(result).toEqual({ stopReason: 'end_turn' });
expect(initCommand.action).toHaveBeenCalled();
expect(mockChat.sendMessageStream).not.toHaveBeenCalled();
});
it('should handle tool calls', async () => {
@@ -1205,4 +1150,21 @@ describe('Session', () => {
'Invalid or unavailable mode: invalid-mode',
);
});
it('should handle unquoted commands from autocomplete (with empty leading parts)', async () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const handleCommandSpy = vi.spyOn(session as any, 'handleCommand');
// Mock runCommand to verify it gets called
// eslint-disable-next-line @typescript-eslint/no-explicit-any
vi.spyOn(session as any, 'runCommand').mockResolvedValue(undefined);
await session.prompt({
sessionId: 'session-1',
prompt: [
{ type: 'text', text: '' },
{ type: 'text', text: '/memory' },
],
});
expect(handleCommandSpy).toHaveBeenCalledWith('/memory', expect.anything());
});
});
@@ -4,15 +4,13 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type {
Config,
GeminiChat,
ToolResult,
ToolCallConfirmationDetails,
FilterFilesOptions,
ConversationRecord,
} from '@google/gemini-cli-core';
import {
type Config,
type GeminiChat,
type ToolResult,
type ToolCallConfirmationDetails,
type FilterFilesOptions,
type ConversationRecord,
CoreToolCallStatus,
AuthType,
logToolCall,
@@ -38,6 +36,7 @@ import {
LlmRole,
ApprovalMode,
convertSessionToClientHistory,
type SessionMetrics,
} from '@google/gemini-cli-core';
import * as acp from '@agentclientprotocol/sdk';
import { AcpFileSystemService } from './fileSystemService.js';
@@ -56,11 +55,21 @@ import { loadCliConfig } from '../config/config.js';
import { runExitCleanup } from '../utils/cleanup.js';
import { SessionSelector } from '../utils/sessionUtils.js';
import { memoryCommand } from '../ui/commands/memoryCommand.js';
import { extensionsCommand } from '../ui/commands/extensionsCommand.js';
import { restoreCommand } from '../ui/commands/restoreCommand.js';
import { parseSlashCommand } from '../utils/commands.js';
import { initCommand } from '../ui/commands/initCommand.js';
import type { SlashCommand, CommandContext } from '../ui/commands/types.js';
import type { HistoryItemWithoutId } from '../ui/types.js';
import type { SessionStatsState } from '../ui/contexts/SessionContext.js';
export async function runZedIntegration(
config: Config,
settings: LoadedSettings,
argv: CliArgs,
) {
// ... (skip unchanged lines) ...
const { stdout: workingStdout } = createWorkingStdio();
const stdout = Writable.toWeb(workingStdout) as WritableStream;
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
@@ -219,11 +228,19 @@ export class GeminiAgent {
const geminiClient = config.getGeminiClient();
const chat = await geminiClient.startChat();
const session = new Session(sessionId, chat, config, this.connection);
const session = new Session(
sessionId,
chat,
config,
this.connection,
this.settings,
);
this.sessions.set(sessionId, session);
// eslint-disable-next-line @typescript-eslint/no-floating-promises
session.sendAvailableCommands();
setTimeout(() => {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
session.sendAvailableCommands();
}, 0);
return {
sessionId,
@@ -273,6 +290,7 @@ export class GeminiAgent {
geminiClient.getChat(),
config,
this.connection,
this.settings,
);
this.sessions.set(sessionId, session);
@@ -280,8 +298,10 @@ export class GeminiAgent {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
session.streamHistory(sessionData.messages);
// eslint-disable-next-line @typescript-eslint/no-floating-promises
session.sendAvailableCommands();
setTimeout(() => {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
session.sendAvailableCommands();
}, 0);
return {
modes: {
@@ -409,6 +429,7 @@ export class Session {
private readonly chat: GeminiChat,
private readonly config: Config,
private readonly connection: acp.AgentSideConnection,
private readonly settings: LoadedSettings,
) {}
async cancelPendingPrompt(): Promise<void> {
@@ -436,20 +457,21 @@ export class Session {
sessionUpdate: 'available_commands_update',
availableCommands: [
{
name: 'status',
description: 'Display session configuration and token usage',
name: 'memory',
description: 'Commands for interacting with memory',
},
{
name: 'mcp',
description: 'List configured MCP tools',
name: 'extensions',
description: 'Manage extensions',
},
{
name: '$commit',
description: 'Create a git commit',
name: 'restore',
description: 'Restore a tool call',
},
{
name: '$review-pr',
description: 'Review a pull request',
name: 'init',
description:
'Analyzes the project and creates a tailored GEMINI.md file',
},
],
});
@@ -536,22 +558,40 @@ export class Session {
const parts = await this.#resolvePrompt(params.prompt, pendingSend.signal);
// Command interception
if (
parts.length > 0 &&
typeof parts[0] === 'object' &&
parts[0] !== null &&
'text' in parts[0] &&
parts[0].text
) {
const firstText = parts[0].text.trim();
if (firstText.startsWith('/') || firstText.startsWith('$')) {
const handled = await this.handleCommand(firstText, parts);
if (handled) {
return { stopReason: 'end_turn' };
let commandText = '';
for (const part of parts) {
if (typeof part === 'object' && part !== null) {
if ('text' in part) {
// It is a text part
// eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-type-assertion
const text = (part as any).text;
if (typeof text === 'string') {
commandText += text;
}
} else {
// Non-text part (image, embedded resource)
// Stop looking for command
break;
}
}
}
commandText = commandText.trim();
if (
commandText &&
(commandText.startsWith('/') || commandText.startsWith('$'))
) {
// If we found a command, pass it to handleCommand
// Note: handleCommand currently expects `commandText` to be the command string
// It uses `parts` argument but effectively ignores it in current implementation
const handled = await this.handleCommand(commandText, parts);
if (handled) {
return { stopReason: 'end_turn' };
}
}
let nextMessage: Content | null = { role: 'user', parts };
while (nextMessage !== null) {
@@ -653,73 +693,97 @@ export class Session {
private async handleCommand(
commandText: string,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
parts: Part[],
): Promise<boolean> {
const rawCommand = commandText.split(/\s+/)[0] || '';
const commandToMatch = rawCommand.toLowerCase();
const commands: SlashCommand[] = [
memoryCommand,
extensionsCommand(),
initCommand,
];
if (commandToMatch === '/status') {
const activeModel = this.config.getActiveModel();
const resolvedModel = resolveModel(activeModel);
const content = `**Session Status**\n\n- Active Model: \`${resolvedModel}\``;
const restore = restoreCommand(this.config);
if (restore) {
commands.push(restore);
}
await this.sendUpdate({
sessionUpdate: 'agent_message_chunk',
content: { type: 'text', text: content },
});
const { commandToExecute, args } = parseSlashCommand(commandText, commands);
if (commandToExecute) {
await this.runCommand(commandToExecute, commandText, args);
return true;
}
if (commandToMatch === '/mcp') {
const mcpServers = this.config.getMcpServers() || {};
let content = '**Configured MCP Servers**\n';
const serverNames = Object.keys(mcpServers);
if (serverNames.length === 0) {
content += '\nNo MCP servers configured.';
} else {
content += '\n' + serverNames.map((name) => `- \`${name}\``).join('\n');
}
await this.sendUpdate({
sessionUpdate: 'agent_message_chunk',
content: { type: 'text', text: content },
});
return true;
}
if (commandToMatch === '$commit') {
const textPart = parts[0];
if (textPart && 'text' in textPart && typeof textPart.text === 'string') {
textPart.text = textPart.text
.replace(
/^\s*\$commit/i,
'Create a git commit based on the current changes using the tools available.',
)
.trim();
}
return false; // Proceed with LLM execution
}
if (commandToMatch === '$review-pr') {
const textPart = parts[0];
if (textPart && 'text' in textPart && typeof textPart.text === 'string') {
textPart.text = textPart.text
.replace(
/^\s*\$review-pr/i,
'Review the current pull request using the tools available.',
)
.trim();
}
return false; // Proceed with LLM execution
}
return false;
}
private async sendUpdate(
update: acp.SessionNotification['update'],
private async runCommand(
command: SlashCommand,
commandText: string,
rawArgs: string,
): Promise<void> {
// Mock UI for capturing output
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const mockUi = {
addItem: (item: HistoryItemWithoutId) => {
if (item.text) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.sendUpdate({
sessionUpdate: 'agent_message_chunk',
content: { type: 'text', text: item.text },
});
}
},
dispatchExtensionStateUpdate: () => {},
setPendingItem: () => {},
reloadCommands: () => {},
removeComponent: () => {},
loadHistory: () => {},
} as unknown as CommandContext['ui'];
const gitService = await this.config.getGitService();
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const context: CommandContext = {
ui: mockUi,
invocation: {
raw: commandText,
name: command.name,
args: rawArgs,
},
services: {
config: this.config,
settings: this.settings,
git: gitService,
logger: debugLogger,
},
session: {
stats: {
sessionId: this.id,
sessionStartTime: new Date(),
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
metrics: {} as unknown as SessionMetrics,
lastPromptTokenCount: 0,
promptCount: 0,
} as SessionStatsState,
sessionShellAllowlist: new Set(),
},
} as unknown as CommandContext;
if (command.action) {
try {
await command.action(context, rawArgs);
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
await this.sendUpdate({
sessionUpdate: 'agent_message_chunk',
content: { type: 'text', text: `Error: ${errorMessage}` },
});
}
}
}
private async sendUpdate(update: acp.SessionUpdate): Promise<void> {
const params: acp.SessionNotification = {
sessionId: this.id,
update,