From 2c4f61eca5f382c7af3ae8c9dec0f3d26f0eae1d Mon Sep 17 00:00:00 2001 From: James Date: Fri, 19 Sep 2025 13:49:35 +0000 Subject: [PATCH] feat(cli) Custom Commands work in Non-Interactive/Headless Mode (#8305) --- packages/cli/src/gemini.tsx | 2 +- packages/cli/src/nonInteractiveCli.test.ts | 248 +++++++++++++++++- packages/cli/src/nonInteractiveCli.ts | 55 ++-- packages/cli/src/nonInteractiveCliCommands.ts | 109 ++++++++ .../src/ui/noninteractive/nonInteractiveUi.ts | 29 ++ 5 files changed, 417 insertions(+), 26 deletions(-) create mode 100644 packages/cli/src/nonInteractiveCliCommands.ts create mode 100644 packages/cli/src/ui/noninteractive/nonInteractiveUi.ts diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 0da33fd992..03e98b4362 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -465,7 +465,7 @@ export async function main() { console.log('Session ID: %s', sessionId); } - await runNonInteractive(nonInteractiveConfig, input, prompt_id); + await runNonInteractive(nonInteractiveConfig, settings, input, prompt_id); // Call cleanup before process.exit, which causes cleanup to not run await runExitCleanup(); process.exit(0); diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index fcf80979eb..71e19b7074 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -22,6 +22,7 @@ import { import type { Part } from '@google/genai'; import { runNonInteractive } from './nonInteractiveCli.js'; import { vi } from 'vitest'; +import type { LoadedSettings } from './config/settings.js'; // Mock core modules vi.mock('./ui/hooks/atCommandProcessor.js'); @@ -48,8 +49,17 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { }; }); +const mockGetCommands = vi.hoisted(() => vi.fn()); +const mockCommandServiceCreate = vi.hoisted(() => vi.fn()); +vi.mock('./services/CommandService.js', () => ({ + CommandService: { + create: mockCommandServiceCreate, + }, +})); + describe('runNonInteractive', () => { let mockConfig: Config; + let mockSettings: LoadedSettings; let mockToolRegistry: ToolRegistry; let mockCoreExecuteToolCall: vi.Mock; let mockShutdownTelemetry: vi.Mock; @@ -64,6 +74,10 @@ describe('runNonInteractive', () => { mockCoreExecuteToolCall = vi.mocked(executeToolCall); mockShutdownTelemetry = vi.mocked(shutdownTelemetry); + mockCommandServiceCreate.mockResolvedValue({ + getCommands: mockGetCommands, + }); + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); processStdoutSpy = vi .spyOn(process.stdout, 'write') @@ -102,8 +116,30 @@ describe('runNonInteractive', () => { getContentGeneratorConfig: vi.fn().mockReturnValue({}), getDebugMode: vi.fn().mockReturnValue(false), getOutputFormat: vi.fn().mockReturnValue('text'), + getFolderTrustFeature: vi.fn().mockReturnValue(false), + getFolderTrust: vi.fn().mockReturnValue(false), } as unknown as Config; + mockSettings = { + system: { path: '', settings: {} }, + systemDefaults: { path: '', settings: {} }, + user: { path: '', settings: {} }, + workspace: { path: '', settings: {} }, + errors: [], + setValue: vi.fn(), + merged: { + security: { + auth: { + enforcedType: undefined, + }, + }, + }, + isTrusted: true, + migratedInMemorScopes: new Set(), + forScope: vi.fn(), + computeMergedSettings: vi.fn(), + } as unknown as LoadedSettings; + const { handleAtCommand } = await import( './ui/hooks/atCommandProcessor.js' ); @@ -138,7 +174,12 @@ describe('runNonInteractive', () => { createStreamFromEvents(events), ); - await runNonInteractive(mockConfig, 'Test input', 'prompt-id-1'); + await runNonInteractive( + mockConfig, + mockSettings, + 'Test input', + 'prompt-id-1', + ); expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( [{ text: 'Test input' }], @@ -178,7 +219,12 @@ describe('runNonInteractive', () => { .mockReturnValueOnce(createStreamFromEvents(firstCallEvents)) .mockReturnValueOnce(createStreamFromEvents(secondCallEvents)); - await runNonInteractive(mockConfig, 'Use a tool', 'prompt-id-2'); + await runNonInteractive( + mockConfig, + mockSettings, + 'Use a tool', + 'prompt-id-2', + ); expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( @@ -236,7 +282,12 @@ describe('runNonInteractive', () => { .mockReturnValueOnce(createStreamFromEvents([toolCallEvent])) .mockReturnValueOnce(createStreamFromEvents(finalResponse)); - await runNonInteractive(mockConfig, 'Trigger tool error', 'prompt-id-3'); + await runNonInteractive( + mockConfig, + mockSettings, + 'Trigger tool error', + 'prompt-id-3', + ); expect(mockCoreExecuteToolCall).toHaveBeenCalled(); expect(consoleErrorSpy).toHaveBeenCalledWith( @@ -268,7 +319,12 @@ describe('runNonInteractive', () => { }); await expect( - runNonInteractive(mockConfig, 'Initial fail', 'prompt-id-4'), + runNonInteractive( + mockConfig, + mockSettings, + 'Initial fail', + 'prompt-id-4', + ), ).rejects.toThrow(apiError); }); @@ -305,6 +361,7 @@ describe('runNonInteractive', () => { await runNonInteractive( mockConfig, + mockSettings, 'Trigger tool not found', 'prompt-id-5', ); @@ -322,7 +379,12 @@ describe('runNonInteractive', () => { it('should exit when max session turns are exceeded', async () => { vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(0); await expect( - runNonInteractive(mockConfig, 'Trigger loop', 'prompt-id-6'), + runNonInteractive( + mockConfig, + mockSettings, + 'Trigger loop', + 'prompt-id-6', + ), ).rejects.toThrow('process.exit(53) called'); }); @@ -361,7 +423,7 @@ describe('runNonInteractive', () => { ); // 4. Run the non-interactive mode with the raw input - await runNonInteractive(mockConfig, rawInput, 'prompt-id-7'); + await runNonInteractive(mockConfig, mockSettings, rawInput, 'prompt-id-7'); // 5. Assert that sendMessageStream was called with the PROCESSED parts, not the raw input expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( @@ -408,7 +470,12 @@ describe('runNonInteractive', () => { }; vi.mocked(uiTelemetryService.getMetrics).mockReturnValue(mockMetrics); - await runNonInteractive(mockConfig, 'Test input', 'prompt-id-1'); + await runNonInteractive( + mockConfig, + mockSettings, + 'Test input', + 'prompt-id-1', + ); expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( [{ text: 'Test input' }], @@ -495,6 +562,7 @@ describe('runNonInteractive', () => { await runNonInteractive( mockConfig, + mockSettings, 'Execute tool only', 'prompt-id-tool-only', ); @@ -548,6 +616,7 @@ describe('runNonInteractive', () => { await runNonInteractive( mockConfig, + mockSettings, 'Empty response test', 'prompt-id-empty', ); @@ -579,7 +648,12 @@ describe('runNonInteractive', () => { let thrownError: Error | null = null; try { - await runNonInteractive(mockConfig, 'Test input', 'prompt-id-error'); + await runNonInteractive( + mockConfig, + mockSettings, + 'Test input', + 'prompt-id-error', + ); // Should not reach here expect.fail('Expected process.exit to be called'); } catch (error) { @@ -619,7 +693,12 @@ describe('runNonInteractive', () => { let thrownError: Error | null = null; try { - await runNonInteractive(mockConfig, 'Invalid syntax', 'prompt-id-fatal'); + await runNonInteractive( + mockConfig, + mockSettings, + 'Invalid syntax', + 'prompt-id-fatal', + ); // Should not reach here expect.fail('Expected process.exit to be called'); } catch (error) { @@ -643,4 +722,155 @@ describe('runNonInteractive', () => { ), ); }); + + it('should execute a slash command that returns a prompt', async () => { + const mockCommand = { + name: 'testcommand', + description: 'a test command', + action: vi.fn().mockResolvedValue({ + type: 'submit_prompt', + content: [{ text: 'Prompt from command' }], + }), + }; + mockGetCommands.mockReturnValue([mockCommand]); + + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Response from command' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive( + mockConfig, + mockSettings, + '/testcommand', + 'prompt-id-slash', + ); + + // Ensure the prompt sent to the model is from the command, not the raw input + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: 'Prompt from command' }], + expect.any(AbortSignal), + 'prompt-id-slash', + ); + + expect(processStdoutSpy).toHaveBeenCalledWith('Response from command'); + }); + + it('should throw FatalInputError if a command requires confirmation', async () => { + const mockCommand = { + name: 'confirm', + description: 'a command that needs confirmation', + action: vi.fn().mockResolvedValue({ + type: 'confirm_shell_commands', + commands: ['rm -rf /'], + }), + }; + mockGetCommands.mockReturnValue([mockCommand]); + + await expect( + runNonInteractive( + mockConfig, + mockSettings, + '/confirm', + 'prompt-id-confirm', + ), + ).rejects.toThrow( + 'Exiting due to a confirmation prompt requested by the command.', + ); + }); + + it('should treat an unknown slash command as a regular prompt', async () => { + // No commands are mocked, so any slash command is "unknown" + mockGetCommands.mockReturnValue([]); + + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Response to unknown' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive( + mockConfig, + mockSettings, + '/unknowncommand', + 'prompt-id-unknown', + ); + + // Ensure the raw input is sent to the model + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: '/unknowncommand' }], + expect.any(AbortSignal), + 'prompt-id-unknown', + ); + + expect(processStdoutSpy).toHaveBeenCalledWith('Response to unknown'); + }); + + it('should throw for unhandled command result types', async () => { + const mockCommand = { + name: 'noaction', + description: 'unhandled type', + action: vi.fn().mockResolvedValue({ + type: 'unhandled', + }), + }; + mockGetCommands.mockReturnValue([mockCommand]); + + await expect( + runNonInteractive( + mockConfig, + mockSettings, + '/noaction', + 'prompt-id-unhandled', + ), + ).rejects.toThrow( + 'Exiting due to command result that is not supported in non-interactive mode.', + ); + }); + + it('should pass arguments to the slash command action', async () => { + const mockAction = vi.fn().mockResolvedValue({ + type: 'submit_prompt', + content: [{ text: 'Prompt from command' }], + }); + const mockCommand = { + name: 'testargs', + description: 'a test command', + action: mockAction, + }; + mockGetCommands.mockReturnValue([mockCommand]); + + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Acknowledged' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 1 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive( + mockConfig, + mockSettings, + '/testargs arg1 arg2', + 'prompt-id-args', + ); + + expect(mockAction).toHaveBeenCalledWith(expect.any(Object), 'arg1 arg2'); + + expect(processStdoutSpy).toHaveBeenCalledWith('Acknowledged'); + }); }); diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index d98b1c91c9..9a5e5fad09 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -5,6 +5,8 @@ */ import type { Config, ToolCallRequestInfo } from '@google/gemini-cli-core'; +import { isSlashCommand } from './ui/utils/commandUtils.js'; +import type { LoadedSettings } from './config/settings.js'; import { executeToolCall, shutdownTelemetry, @@ -16,8 +18,10 @@ import { JsonFormatter, uiTelemetryService, } from '@google/gemini-cli-core'; + import type { Content, Part } from '@google/genai'; +import { handleSlashCommand } from './nonInteractiveCliCommands.js'; import { ConsolePatcher } from './ui/utils/ConsolePatcher.js'; import { handleAtCommand } from './ui/hooks/atCommandProcessor.js'; import { @@ -29,6 +33,7 @@ import { export async function runNonInteractive( config: Config, + settings: LoadedSettings, input: string, prompt_id: string, ): Promise { @@ -52,26 +57,44 @@ export async function runNonInteractive( const abortController = new AbortController(); - const { processedQuery, shouldProceed } = await handleAtCommand({ - query: input, - config, - addItem: (_item, _timestamp) => 0, - onDebugMessage: () => {}, - messageId: Date.now(), - signal: abortController.signal, - }); + let query: Part[] | undefined; - if (!shouldProceed || !processedQuery) { - // An error occurred during @include processing (e.g., file not found). - // The error message is already logged by handleAtCommand. - throw new FatalInputError( - 'Exiting due to an error processing the @ command.', + if (isSlashCommand(input)) { + const slashCommandResult = await handleSlashCommand( + input, + abortController, + config, + settings, ); + // If a slash command is found and returns a prompt, use it. + // Otherwise, slashCommandResult fall through to the default prompt + // handling. + if (slashCommandResult) { + query = slashCommandResult as Part[]; + } } - let currentMessages: Content[] = [ - { role: 'user', parts: processedQuery as Part[] }, - ]; + if (!query) { + const { processedQuery, shouldProceed } = await handleAtCommand({ + query: input, + config, + addItem: (_item, _timestamp) => 0, + onDebugMessage: () => {}, + messageId: Date.now(), + signal: abortController.signal, + }); + + if (!shouldProceed || !processedQuery) { + // An error occurred during @include processing (e.g., file not found). + // The error message is already logged by handleAtCommand. + throw new FatalInputError( + 'Exiting due to an error processing the @ command.', + ); + } + query = processedQuery as Part[]; + } + + let currentMessages: Content[] = [{ role: 'user', parts: query }]; let turnCount = 0; while (true) { diff --git a/packages/cli/src/nonInteractiveCliCommands.ts b/packages/cli/src/nonInteractiveCliCommands.ts new file mode 100644 index 0000000000..31b748d786 --- /dev/null +++ b/packages/cli/src/nonInteractiveCliCommands.ts @@ -0,0 +1,109 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { PartListUnion } from '@google/genai'; +import { parseSlashCommand } from './utils/commands.js'; +import { + FatalInputError, + Logger, + uiTelemetryService, + type Config, +} from '@google/gemini-cli-core'; +import { CommandService } from './services/CommandService.js'; +import { FileCommandLoader } from './services/FileCommandLoader.js'; +import type { CommandContext } from './ui/commands/types.js'; +import { createNonInteractiveUI } from './ui/noninteractive/nonInteractiveUi.js'; +import type { LoadedSettings } from './config/settings.js'; +import type { SessionStatsState } from './ui/contexts/SessionContext.js'; + +/** + * Processes a slash command in a non-interactive environment. + * + * @returns A Promise that resolves to `PartListUnion` if a valid command is + * found and results in a prompt, or `undefined` otherwise. + * @throws {FatalInputError} if the command result is not supported in + * non-interactive mode. + */ +export const handleSlashCommand = async ( + rawQuery: string, + abortController: AbortController, + config: Config, + settings: LoadedSettings, +): Promise => { + const trimmed = rawQuery.trim(); + if (!trimmed.startsWith('/')) { + return; + } + + // Only custom commands are supported for now. + const loaders = [new FileCommandLoader(config)]; + const commandService = await CommandService.create( + loaders, + abortController.signal, + ); + const commands = commandService.getCommands(); + + const { commandToExecute, args } = parseSlashCommand(rawQuery, commands); + + if (commandToExecute) { + if (commandToExecute.action) { + // Not used by custom commands but may be in the future. + const sessionStats: SessionStatsState = { + sessionId: config?.getSessionId(), + sessionStartTime: new Date(), + metrics: uiTelemetryService.getMetrics(), + lastPromptTokenCount: 0, + promptCount: 1, + }; + + const logger = new Logger(config?.getSessionId() || '', config?.storage); + + const context: CommandContext = { + services: { + config, + settings, + git: undefined, + logger, + }, + ui: createNonInteractiveUI(), + session: { + stats: sessionStats, + sessionShellAllowlist: new Set(), + }, + invocation: { + raw: trimmed, + name: commandToExecute.name, + args, + }, + }; + + const result = await commandToExecute.action(context, args); + + if (result) { + switch (result.type) { + case 'submit_prompt': + return result.content; + case 'confirm_shell_commands': + // This result indicates a command attempted to confirm shell commands. + // However note that currently, ShellTool is excluded in non-interactive + // mode unless 'YOLO mode' is active, so confirmation actually won't + // occur because of YOLO mode. + // This ensures that if a command *does* request confirmation (e.g. + // in the future with more granular permissions), it's handled appropriately. + throw new FatalInputError( + 'Exiting due to a confirmation prompt requested by the command.', + ); + default: + throw new FatalInputError( + 'Exiting due to command result that is not supported in non-interactive mode.', + ); + } + } + } + } + + return; +}; diff --git a/packages/cli/src/ui/noninteractive/nonInteractiveUi.ts b/packages/cli/src/ui/noninteractive/nonInteractiveUi.ts new file mode 100644 index 0000000000..4e53aae57f --- /dev/null +++ b/packages/cli/src/ui/noninteractive/nonInteractiveUi.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { CommandContext } from '../commands/types.js'; + +/** + * Creates a UI context object with no-op functions. + * Useful for non-interactive environments where UI operations + * are not applicable. + */ +export function createNonInteractiveUI(): CommandContext['ui'] { + return { + addItem: (_item, _timestamp) => 0, + clear: () => {}, + setDebugMessage: (_message) => {}, + loadHistory: (_newHistory) => {}, + pendingItem: null, + setPendingItem: (_item) => {}, + toggleCorgiMode: () => {}, + toggleVimEnabled: async () => false, + setGeminiMdFileCount: (_count) => {}, + reloadCommands: () => {}, + extensionsUpdateState: new Map(), + setExtensionsUpdateState: (_updateState) => {}, + }; +}