From e65208b24458978d1e7a3c1aec674f7d6a829a3d Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Thu, 22 Jan 2026 12:12:13 -0500 Subject: [PATCH] feat: add AskUser tool schema (#16988) --- packages/core/src/confirmation-bus/types.ts | 40 ++++ packages/core/src/tools/ask-user.test.ts | 227 ++++++++++++++++++++ packages/core/src/tools/ask-user.ts | 208 ++++++++++++++++++ packages/core/src/tools/tool-names.ts | 2 + 4 files changed, 477 insertions(+) create mode 100644 packages/core/src/tools/ask-user.test.ts create mode 100644 packages/core/src/tools/ask-user.ts diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index 0e7e7eae06..786894a972 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -22,6 +22,8 @@ export enum MessageBusType { HOOK_EXECUTION_RESPONSE = 'hook-execution-response', HOOK_POLICY_DECISION = 'hook-policy-decision', TOOL_CALLS_UPDATE = 'tool-calls-update', + ASK_USER_REQUEST = 'ask-user-request', + ASK_USER_RESPONSE = 'ask-user-response', } export interface ToolCallsUpdateMessage { @@ -141,6 +143,42 @@ export interface HookPolicyDecision { reason?: string; } +export interface QuestionOption { + label: string; + description: string; +} + +export enum QuestionType { + CHOICE = 'choice', + TEXT = 'text', + YESNO = 'yesno', +} + +export interface Question { + question: string; + header: string; + /** Question type: 'choice' renders selectable options, 'text' renders free-form input, 'yesno' renders a binary Yes/No choice. Defaults to 'choice'. */ + type?: QuestionType; + /** Available choices. Required when type is 'choice' (or omitted), ignored for 'text'. */ + options?: QuestionOption[]; + /** Allow multiple selections. Only applies to 'choice' type. */ + multiSelect?: boolean; + /** Placeholder hint text for 'text' type input field. */ + placeholder?: string; +} + +export interface AskUserRequest { + type: MessageBusType.ASK_USER_REQUEST; + questions: Question[]; + correlationId: string; +} + +export interface AskUserResponse { + type: MessageBusType.ASK_USER_RESPONSE; + correlationId: string; + answers: { [questionIndex: string]: string }; +} + export type Message = | ToolConfirmationRequest | ToolConfirmationResponse @@ -151,4 +189,6 @@ export type Message = | HookExecutionRequest | HookExecutionResponse | HookPolicyDecision + | AskUserRequest + | AskUserResponse | ToolCallsUpdateMessage; diff --git a/packages/core/src/tools/ask-user.test.ts b/packages/core/src/tools/ask-user.test.ts new file mode 100644 index 0000000000..05b64313b9 --- /dev/null +++ b/packages/core/src/tools/ask-user.test.ts @@ -0,0 +1,227 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { AskUserTool } from './ask-user.js'; +import { + MessageBusType, + QuestionType, + type Question, +} from '../confirmation-bus/types.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; + +describe('AskUserTool', () => { + let mockMessageBus: MessageBus; + let tool: AskUserTool; + + beforeEach(() => { + mockMessageBus = { + publish: vi.fn().mockResolvedValue(undefined), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + } as unknown as MessageBus; + tool = new AskUserTool(mockMessageBus); + }); + + it('should have correct metadata', () => { + expect(tool.name).toBe('ask_user'); + expect(tool.displayName).toBe('Ask User'); + }); + + describe('validateToolParams', () => { + it('should return error if questions is missing', () => { + // @ts-expect-error - Intentionally invalid params + const result = tool.validateToolParams({}); + expect(result).toContain("must have required property 'questions'"); + }); + + it('should return error if questions array is empty', () => { + const result = tool.validateToolParams({ questions: [] }); + expect(result).toContain('must NOT have fewer than 1 items'); + }); + + it('should return error if questions array exceeds max', () => { + const questions = Array(5).fill({ + question: 'Test?', + header: 'Test', + options: [ + { label: 'A', description: 'A' }, + { label: 'B', description: 'B' }, + ], + }); + const result = tool.validateToolParams({ questions }); + expect(result).toContain('must NOT have more than 4 items'); + }); + + it('should return error if question field is missing', () => { + const result = tool.validateToolParams({ + questions: [{ header: 'Test' } as unknown as Question], + }); + expect(result).toContain("must have required property 'question'"); + }); + + it('should return error if header field is missing', () => { + const result = tool.validateToolParams({ + questions: [{ question: 'Test?' } as unknown as Question], + }); + expect(result).toContain("must have required property 'header'"); + }); + + it('should return error if header exceeds max length', () => { + const result = tool.validateToolParams({ + questions: [{ question: 'Test?', header: 'This is way too long' }], + }); + expect(result).toContain('must NOT have more than 12 characters'); + }); + + it('should return error if options has fewer than 2 items', () => { + const result = tool.validateToolParams({ + questions: [ + { + question: 'Test?', + header: 'Test', + options: [{ label: 'A', description: 'A' }], + }, + ], + }); + expect(result).toContain('must NOT have fewer than 2 items'); + }); + + it('should return error if options has more than 4 items', () => { + const result = tool.validateToolParams({ + questions: [ + { + question: 'Test?', + header: 'Test', + options: [ + { label: 'A', description: 'A' }, + { label: 'B', description: 'B' }, + { label: 'C', description: 'C' }, + { label: 'D', description: 'D' }, + { label: 'E', description: 'E' }, + ], + }, + ], + }); + expect(result).toContain('must NOT have more than 4 items'); + }); + + it('should return null for valid params', () => { + const result = tool.validateToolParams({ + questions: [ + { + question: 'Which approach?', + header: 'Approach', + options: [ + { label: 'A', description: 'Option A' }, + { label: 'B', description: 'Option B' }, + ], + }, + ], + }); + expect(result).toBeNull(); + }); + }); + + it('should publish ASK_USER_REQUEST and wait for response', async () => { + const questions = [ + { + question: 'How should we proceed with this task?', + header: 'Approach', + options: [ + { + label: 'Quick fix (Recommended)', + description: + 'Apply the most direct solution to resolve the immediate issue.', + }, + { + label: 'Comprehensive refactor', + description: + 'Restructure the affected code for better long-term maintainability.', + }, + ], + multiSelect: false, + }, + ]; + + const invocation = tool.build({ questions }); + const executePromise = invocation.execute(new AbortController().signal); + + // Verify publish called with normalized questions (type defaults to CHOICE) + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.ASK_USER_REQUEST, + questions: questions.map((q) => ({ + ...q, + type: QuestionType.CHOICE, + })), + }), + ); + + // Get the correlation ID from the published message + const publishCall = vi.mocked(mockMessageBus.publish).mock.calls[0][0] as { + correlationId: string; + }; + const correlationId = publishCall.correlationId; + expect(correlationId).toBeDefined(); + + // Verify subscribe called + expect(mockMessageBus.subscribe).toHaveBeenCalledWith( + MessageBusType.ASK_USER_RESPONSE, + expect.any(Function), + ); + + // Simulate response + const subscribeCall = vi + .mocked(mockMessageBus.subscribe) + .mock.calls.find((call) => call[0] === MessageBusType.ASK_USER_RESPONSE); + const handler = subscribeCall![1]; + + const answers = { '0': 'Quick fix (Recommended)' }; + handler({ + type: MessageBusType.ASK_USER_RESPONSE, + correlationId, + answers, + }); + + const result = await executePromise; + expect(result.returnDisplay).toContain('User answered:'); + expect(result.returnDisplay).toContain( + ' Approach → Quick fix (Recommended)', + ); + expect(JSON.parse(result.llmContent as string)).toEqual({ answers }); + }); + + it('should handle cancellation', async () => { + const invocation = tool.build({ + questions: [ + { + question: 'Which sections of the documentation should be updated?', + header: 'Docs', + options: [ + { + label: 'User Guide', + description: 'Update the main user-facing documentation.', + }, + { + label: 'API Reference', + description: 'Update the detailed API documentation.', + }, + ], + multiSelect: true, + }, + ], + }); + + const controller = new AbortController(); + const executePromise = invocation.execute(controller.signal); + + controller.abort(); + + const result = await executePromise; + expect(result.error?.message).toBe('Cancelled'); + }); +}); diff --git a/packages/core/src/tools/ask-user.ts b/packages/core/src/tools/ask-user.ts new file mode 100644 index 0000000000..7075809f9f --- /dev/null +++ b/packages/core/src/tools/ask-user.ts @@ -0,0 +1,208 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + BaseDeclarativeTool, + BaseToolInvocation, + type ToolResult, + Kind, + type ToolCallConfirmationDetails, +} from './tools.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { + MessageBusType, + QuestionType, + type Question, + type AskUserRequest, + type AskUserResponse, +} from '../confirmation-bus/types.js'; +import { randomUUID } from 'node:crypto'; +import { ASK_USER_TOOL_NAME } from './tool-names.js'; + +export interface AskUserParams { + questions: Question[]; +} + +export class AskUserTool extends BaseDeclarativeTool< + AskUserParams, + ToolResult +> { + constructor(messageBus: MessageBus) { + super( + ASK_USER_TOOL_NAME, + 'Ask User', + 'Ask the user one or more questions to gather preferences, clarify requirements, or make decisions.', + Kind.Other, + { + type: 'object', + required: ['questions'], + properties: { + questions: { + type: 'array', + minItems: 1, + maxItems: 4, + items: { + type: 'object', + required: ['question', 'header'], + properties: { + question: { + type: 'string', + description: + 'The complete question to ask the user. Should be clear, specific, and end with a question mark.', + }, + header: { + type: 'string', + maxLength: 12, + description: + 'Very short label displayed as a chip/tag (max 12 chars). Examples: "Auth method", "Library", "Approach".', + }, + type: { + type: 'string', + enum: ['choice', 'text', 'yesno'], + description: + "Question type. 'choice' (default) shows selectable options, 'text' shows a free-form text input, 'yesno' shows a binary Yes/No choice.", + }, + options: { + type: 'array', + description: + "Required for 'choice' type, ignored for 'text' and 'yesno'. The available choices (2-4 options). Do NOT include an 'Other' option - one is automatically added for 'choice' type.", + minItems: 2, + maxItems: 4, + items: { + type: 'object', + required: ['label', 'description'], + properties: { + label: { + type: 'string', + description: + 'The display text for this option that the user will see and select. Should be concise (1-5 words) and clearly describe the choice.', + }, + description: { + type: 'string', + description: + 'Explanation of what this option means or what will happen if chosen. Useful for providing context about trade-offs or implications.', + }, + }, + }, + }, + multiSelect: { + type: 'boolean', + description: + "Only applies to 'choice' type. Set to true to allow multiple selections.", + }, + placeholder: { + type: 'string', + description: + "Optional hint text for 'text' type input field.", + }, + }, + }, + }, + }, + }, + messageBus, + ); + } + + protected createInvocation( + params: AskUserParams, + messageBus: MessageBus, + toolName: string, + toolDisplayName: string, + ): AskUserInvocation { + return new AskUserInvocation(params, messageBus, toolName, toolDisplayName); + } +} + +export class AskUserInvocation extends BaseToolInvocation< + AskUserParams, + ToolResult +> { + override async shouldConfirmExecute( + _abortSignal: AbortSignal, + ): Promise { + return false; + } + + getDescription(): string { + return `Asking user: ${this.params.questions.map((q) => q.question).join(', ')}`; + } + + async execute(signal: AbortSignal): Promise { + const correlationId = randomUUID(); + + const request: AskUserRequest = { + type: MessageBusType.ASK_USER_REQUEST, + questions: this.params.questions.map((q) => ({ + ...q, + type: q.type ?? QuestionType.CHOICE, + })), + correlationId, + }; + + return new Promise((resolve, reject) => { + const responseHandler = (response: AskUserResponse): void => { + if (response.correlationId === correlationId) { + cleanup(); + + // Build formatted key-value display + const formattedAnswers = Object.entries(response.answers) + .map(([index, answer]) => { + const question = this.params.questions[parseInt(index, 10)]; + const category = question?.header ?? `Q${index}`; + return ` ${category} → ${answer}`; + }) + .join('\n'); + + const returnDisplay = `User answered:\n${formattedAnswers}`; + + resolve({ + llmContent: JSON.stringify({ answers: response.answers }), + returnDisplay, + }); + } + }; + + const cleanup = () => { + if (responseHandler) { + this.messageBus.unsubscribe( + MessageBusType.ASK_USER_RESPONSE, + responseHandler, + ); + } + signal.removeEventListener('abort', abortHandler); + }; + + const abortHandler = () => { + cleanup(); + resolve({ + llmContent: 'Tool execution cancelled by user.', + returnDisplay: 'Cancelled', + error: { + message: 'Cancelled', + }, + }); + }; + + if (signal.aborted) { + abortHandler(); + return; + } + + signal.addEventListener('abort', abortHandler); + this.messageBus.subscribe( + MessageBusType.ASK_USER_RESPONSE, + responseHandler, + ); + + // Publish request + this.messageBus.publish(request).catch((err) => { + cleanup(); + reject(err); + }); + }); + } +} diff --git a/packages/core/src/tools/tool-names.ts b/packages/core/src/tools/tool-names.ts index 41e4be8dec..42ccadc877 100644 --- a/packages/core/src/tools/tool-names.ts +++ b/packages/core/src/tools/tool-names.ts @@ -24,6 +24,7 @@ export const GET_INTERNAL_DOCS_TOOL_NAME = 'get_internal_docs'; export const ACTIVATE_SKILL_TOOL_NAME = 'activate_skill'; export const EDIT_TOOL_NAMES = new Set([EDIT_TOOL_NAME, WRITE_FILE_TOOL_NAME]); export const DELEGATE_TO_AGENT_TOOL_NAME = 'delegate_to_agent'; +export const ASK_USER_TOOL_NAME = 'ask_user'; /** Prefix used for tools discovered via the toolDiscoveryCommand. */ export const DISCOVERED_TOOL_PREFIX = 'discovered_tool_'; @@ -46,6 +47,7 @@ export const ALL_BUILTIN_TOOL_NAMES = [ MEMORY_TOOL_NAME, ACTIVATE_SKILL_TOOL_NAME, DELEGATE_TO_AGENT_TOOL_NAME, + ASK_USER_TOOL_NAME, ] as const; /**