From 5cbab75c7d40080e182a4a2561c91d4689352811 Mon Sep 17 00:00:00 2001 From: Luke Schlangen Date: Mon, 15 Sep 2025 10:13:21 -0500 Subject: [PATCH] fix: positional arguments for MCP prompts (#8034) --- .../cli/src/services/McpPromptLoader.test.ts | 273 +++++++++++++++++- packages/cli/src/services/McpPromptLoader.ts | 72 ++++- .../src/ui/hooks/useSlashCompletion.test.ts | 19 +- .../cli/src/ui/hooks/useSlashCompletion.ts | 22 +- 4 files changed, 368 insertions(+), 18 deletions(-) diff --git a/packages/cli/src/services/McpPromptLoader.test.ts b/packages/cli/src/services/McpPromptLoader.test.ts index 609e553362..3ba4c012ad 100644 --- a/packages/cli/src/services/McpPromptLoader.test.ts +++ b/packages/cli/src/services/McpPromptLoader.test.ts @@ -7,11 +7,42 @@ import { McpPromptLoader } from './McpPromptLoader.js'; import type { Config } from '@google/gemini-cli-core'; import type { PromptArgument } from '@modelcontextprotocol/sdk/types.js'; -import { describe, it, expect } from 'vitest'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { CommandKind, type CommandContext } from '../ui/commands/types.js'; +import * as cliCore from '@google/gemini-cli-core'; + +// Define the mock prompt data at a higher scope +const mockPrompt = { + name: 'test-prompt', + description: 'A test prompt.', + serverName: 'test-server', + arguments: [ + { name: 'name', required: true, description: "The animal's name." }, + { name: 'age', required: true, description: "The animal's age." }, + { name: 'species', required: true, description: "The animal's species." }, + { + name: 'enclosure', + required: false, + description: "The animal's enclosure.", + }, + { name: 'trail', required: false, description: "The animal's trail." }, + ], + invoke: vi.fn().mockResolvedValue({ + messages: [{ content: { text: 'Hello, world!' } }], + }), +}; describe('McpPromptLoader', () => { const mockConfig = {} as Config; + // Use a beforeEach to set up and clean a spy for each test + beforeEach(() => { + vi.clearAllMocks(); + vi.spyOn(cliCore, 'getMCPServerPrompts').mockReturnValue([mockPrompt]); + }); + + // --- `parseArgs` tests remain the same --- + describe('parseArgs', () => { it('should handle multi-word positional arguments', () => { const loader = new McpPromptLoader(mockConfig); @@ -125,4 +156,244 @@ describe('McpPromptLoader', () => { }); }); }); + + describe('loadCommands', () => { + const mockConfigWithPrompts = { + getMcpServers: () => ({ + 'test-server': { httpUrl: 'https://test-server.com' }, + }), + } as unknown as Config; + + it('should load prompts as slash commands', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands(new AbortController().signal); + expect(commands).toHaveLength(1); + expect(commands[0].name).toBe('test-prompt'); + expect(commands[0].description).toBe('A test prompt.'); + expect(commands[0].kind).toBe(CommandKind.MCP_PROMPT); + }); + + it('should handle prompt invocation successfully', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands(new AbortController().signal); + const action = commands[0].action!; + const context = {} as CommandContext; + const result = await action(context, 'test-name 123 tiger'); + expect(mockPrompt.invoke).toHaveBeenCalledWith({ + name: 'test-name', + age: '123', + species: 'tiger', + }); + expect(result).toEqual({ + type: 'submit_prompt', + content: JSON.stringify('Hello, world!'), + }); + }); + + it('should return an error for missing required arguments', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands(new AbortController().signal); + const action = commands[0].action!; + const context = {} as CommandContext; + const result = await action(context, 'test-name'); + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Missing required argument(s): --age, --species', + }); + }); + + it('should return an error message if prompt invocation fails', async () => { + vi.spyOn(mockPrompt, 'invoke').mockRejectedValue( + new Error('Invocation failed!'), + ); + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands(new AbortController().signal); + const action = commands[0].action!; + const context = {} as CommandContext; + const result = await action(context, 'test-name 123 tiger'); + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Error: Invocation failed!', + }); + }); + + it('should return an empty array if config is not available', async () => { + const loader = new McpPromptLoader(null); + const commands = await loader.loadCommands(new AbortController().signal); + expect(commands).toEqual([]); + }); + + describe('completion', () => { + it('should suggest no arguments when using positional arguments', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands( + new AbortController().signal, + ); + const completion = commands[0].completion!; + const context = {} as CommandContext; + const suggestions = await completion(context, 'test-name 6 tiger'); + expect(suggestions).toEqual([]); + }); + + it('should suggest all arguments when none are present', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands( + new AbortController().signal, + ); + const completion = commands[0].completion!; + const context = { + invocation: { + raw: '/find ', + name: 'find', + args: '', + }, + } as CommandContext; + const suggestions = await completion(context, ''); + expect(suggestions).toEqual([ + '--name="', + '--age="', + '--species="', + '--enclosure="', + '--trail="', + ]); + }); + + it('should suggest remaining arguments when some are present', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands( + new AbortController().signal, + ); + const completion = commands[0].completion!; + const context = { + invocation: { + raw: '/find --name="test-name" --age="6" ', + name: 'find', + args: '--name="test-name" --age="6"', + }, + } as CommandContext; + const suggestions = await completion(context, ''); + expect(suggestions).toEqual([ + '--species="', + '--enclosure="', + '--trail="', + ]); + }); + + it('should suggest no arguments when all are present', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands( + new AbortController().signal, + ); + const completion = commands[0].completion!; + const context = {} as CommandContext; + const suggestions = await completion( + context, + '--name="test-name" --age="6" --species="tiger" --enclosure="Tiger Den" --trail="Jungle"', + ); + expect(suggestions).toEqual([]); + }); + + it('should suggest nothing for prompts with no arguments', async () => { + // Temporarily override the mock to return a prompt with no args + vi.spyOn(cliCore, 'getMCPServerPrompts').mockReturnValue([ + { ...mockPrompt, arguments: [] }, + ]); + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands( + new AbortController().signal, + ); + const completion = commands[0].completion!; + const context = {} as CommandContext; + const suggestions = await completion(context, ''); + expect(suggestions).toEqual([]); + }); + + it('should suggest arguments matching a partial argument', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands( + new AbortController().signal, + ); + const completion = commands[0].completion!; + const context = { + invocation: { + raw: '/find --s', + name: 'find', + args: '--s', + }, + } as CommandContext; + const suggestions = await completion(context, '--s'); + expect(suggestions).toEqual(['--species="']); + }); + + it('should suggest arguments even when a partial argument is parsed as a value', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands( + new AbortController().signal, + ); + const completion = commands[0].completion!; + const context = { + invocation: { + raw: '/find --name="test" --a', + name: 'find', + args: '--name="test" --a', + }, + } as CommandContext; + const suggestions = await completion(context, '--a'); + expect(suggestions).toEqual(['--age="']); + }); + + it('should auto-close the quote for a named argument value', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands( + new AbortController().signal, + ); + const completion = commands[0].completion!; + const context = { + invocation: { + raw: '/find --name="test', + name: 'find', + args: '--name="test', + }, + } as CommandContext; + const suggestions = await completion(context, '--name="test'); + expect(suggestions).toEqual(['--name="test"']); + }); + + it('should auto-close the quote for an empty named argument value', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands( + new AbortController().signal, + ); + const completion = commands[0].completion!; + const context = { + invocation: { + raw: '/find --name="', + name: 'find', + args: '--name="', + }, + } as CommandContext; + const suggestions = await completion(context, '--name="'); + expect(suggestions).toEqual(['--name=""']); + }); + + it('should not add a quote if already present', async () => { + const loader = new McpPromptLoader(mockConfigWithPrompts); + const commands = await loader.loadCommands( + new AbortController().signal, + ); + const completion = commands[0].completion!; + const context = { + invocation: { + raw: '/find --name="test"', + name: 'find', + args: '--name="test"', + }, + } as CommandContext; + const suggestions = await completion(context, '--name="test"'); + expect(suggestions).toEqual([]); + }); + }); + }); }); diff --git a/packages/cli/src/services/McpPromptLoader.ts b/packages/cli/src/services/McpPromptLoader.ts index c24e7b5122..c402fa82e0 100644 --- a/packages/cli/src/services/McpPromptLoader.ts +++ b/packages/cli/src/services/McpPromptLoader.ts @@ -141,23 +141,69 @@ export class McpPromptLoader implements ICommandLoader { }; } }, - completion: async (_: CommandContext, partialArg: string) => { - if (!prompt || !prompt.arguments) { + completion: async ( + commandContext: CommandContext, + partialArg: string, + ) => { + const invocation = commandContext.invocation; + if (!prompt || !prompt.arguments || !invocation) { return []; } - - const suggestions: string[] = []; - const usedArgNames = new Set( - (partialArg.match(/--([^=]+)/g) || []).map((s) => s.substring(2)), - ); - - for (const arg of prompt.arguments) { - if (!usedArgNames.has(arg.name)) { - suggestions.push(`--${arg.name}=""`); - } + const indexOfFirstSpace = invocation.raw.indexOf(' ') + 1; + let promptInputs = + indexOfFirstSpace === 0 + ? {} + : this.parseArgs( + invocation.raw.substring(indexOfFirstSpace), + prompt.arguments, + ); + if (promptInputs instanceof Error) { + promptInputs = {}; } - return suggestions; + const providedArgNames = Object.keys(promptInputs); + const unusedArguments = + prompt.arguments + .filter((arg) => { + // If this arguments is not in the prompt inputs + // add it to unusedArguments + if (!providedArgNames.includes(arg.name)) { + return true; + } + + // The parseArgs method assigns the value + // at the end of the prompt as a final value + // The argument should still be suggested + // Example /add --numberOne="34" --num + // numberTwo would be assigned a value of --num + // numberTwo should still be considered unused + const argValue = promptInputs[arg.name]; + return argValue === partialArg; + }) + .map((argument) => `--${argument.name}="`) || []; + + const exactlyMatchingArgumentAtTheEnd = prompt.arguments + .map((argument) => `--${argument.name}="`) + .filter((flagArgument) => { + const regex = new RegExp(`${flagArgument}[^"]*$`); + return regex.test(invocation.raw); + }); + + if (exactlyMatchingArgumentAtTheEnd.length === 1) { + if (exactlyMatchingArgumentAtTheEnd[0] === partialArg) { + return [`${partialArg}"`]; + } + if (partialArg.endsWith('"')) { + return [partialArg]; + } + return [`${partialArg}"`]; + } + + const matchingArguments = unusedArguments.filter((flagArgument) => + flagArgument.startsWith(partialArg), + ); + + return matchingArguments; }, }; promptCommands.push(newPromptCommand); diff --git a/packages/cli/src/ui/hooks/useSlashCompletion.test.ts b/packages/cli/src/ui/hooks/useSlashCompletion.test.ts index b5568ce9b4..afebbc92db 100644 --- a/packages/cli/src/ui/hooks/useSlashCompletion.test.ts +++ b/packages/cli/src/ui/hooks/useSlashCompletion.test.ts @@ -549,7 +549,13 @@ describe('useSlashCompletion', () => { await waitFor(() => { expect(mockCompletionFn).toHaveBeenCalledWith( - mockCommandContext, + expect.objectContaining({ + invocation: { + raw: '/chat resume my-ch', + name: 'resume', + args: 'my-ch', + }, + }), 'my-ch', ); }); @@ -591,7 +597,16 @@ describe('useSlashCompletion', () => { ); await waitFor(() => { - expect(mockCompletionFn).toHaveBeenCalledWith(mockCommandContext, ''); + expect(mockCompletionFn).toHaveBeenCalledWith( + expect.objectContaining({ + invocation: { + raw: '/chat resume', + name: 'resume', + args: '', + }, + }), + '', + ); }); await waitFor(() => { diff --git a/packages/cli/src/ui/hooks/useSlashCompletion.ts b/packages/cli/src/ui/hooks/useSlashCompletion.ts index a284e0bc6e..bee4677ffa 100644 --- a/packages/cli/src/ui/hooks/useSlashCompletion.ts +++ b/packages/cli/src/ui/hooks/useSlashCompletion.ts @@ -7,7 +7,11 @@ import { useState, useEffect, useMemo } from 'react'; import { AsyncFzf } from 'fzf'; import type { Suggestion } from '../components/SuggestionsDisplay.js'; -import type { CommandContext, SlashCommand } from '../commands/types.js'; +import { + CommandKind, + type CommandContext, + type SlashCommand, +} from '../commands/types.js'; // Type alias for improved type safety based on actual fzf result structure type FzfCommandResult = { @@ -93,9 +97,13 @@ function useCommandParser( const found: SlashCommand | undefined = currentLevel.find((cmd) => matchesCommand(cmd, part), ); + if (found) { leafCommand = found; currentLevel = found.subCommands as readonly SlashCommand[] | undefined; + if (found.kind === CommandKind.MCP_PROMPT) { + break; + } } else { leafCommand = null; currentLevel = []; @@ -194,7 +202,17 @@ function useCommandSuggestions( const depth = commandPathParts.length; const argString = rawParts.slice(depth).join(' '); const results = - (await leafCommand.completion(commandContext, argString)) || []; + (await leafCommand.completion( + { + ...commandContext, + invocation: { + raw: `/${rawParts.join(' ')}`, + name: leafCommand.name, + args: argString, + }, + }, + argString, + )) || []; if (!signal.aborted) { const finalSuggestions = results.map((s) => ({