From eada8077314c990a00daffaf518968dc1b4cfc5d Mon Sep 17 00:00:00 2001 From: Christian Gunderman Date: Wed, 4 Mar 2026 12:06:21 -0800 Subject: [PATCH] fix(hooks): support 'ask' decision for BeforeTool hooks --- integration-tests/hooks-system.test.ts | 165 +++++++++- packages/a2a-server/src/agent/task.ts | 6 +- .../messages/ToolConfirmationMessage.tsx | 9 + packages/cli/src/ui/hooks/useGeminiStream.ts | 2 +- packages/core/src/agents/remote-invocation.ts | 2 + .../core/src/agents/subagent-tool.test.ts | 2 + packages/core/src/agents/subagent-tool.ts | 4 +- packages/core/src/confirmation-bus/types.ts | 6 + .../src/core/coreToolHookTriggers.test.ts | 146 +-------- .../core/src/core/coreToolHookTriggers.ts | 96 +----- .../core/src/core/coreToolScheduler.test.ts | 57 +++- packages/core/src/core/coreToolScheduler.ts | 120 ++++++- .../src/core/coreToolSchedulerHooks.test.ts | 300 ++++++++++++++++++ packages/core/src/core/turn.ts | 2 + packages/core/src/hooks/hookAggregator.ts | 11 +- packages/core/src/hooks/types.ts | 9 +- packages/core/src/scheduler/confirmation.ts | 14 +- packages/core/src/scheduler/policy.test.ts | 1 + packages/core/src/scheduler/scheduler.test.ts | 2 + packages/core/src/scheduler/scheduler.ts | 93 +++++- .../src/scheduler/scheduler_parallel.test.ts | 1 + packages/core/src/scheduler/tool-executor.ts | 14 + packages/core/src/scheduler/types.ts | 2 + .../core/src/telemetry/conseca-logger.test.ts | 4 +- packages/core/src/test-utils/mock-tool.ts | 34 +- packages/core/src/tools/ask-user.ts | 2 + .../src/tools/confirmation-policy.test.ts | 2 +- packages/core/src/tools/edit.ts | 7 +- .../core/src/tools/enter-plan-mode.test.ts | 8 +- packages/core/src/tools/enter-plan-mode.ts | 11 +- .../core/src/tools/exit-plan-mode.test.ts | 6 +- packages/core/src/tools/exit-plan-mode.ts | 11 +- packages/core/src/tools/get-internal-docs.ts | 2 + packages/core/src/tools/mcp-tool.ts | 2 + packages/core/src/tools/memoryTool.ts | 2 + .../src/tools/message-bus-integration.test.ts | 4 +- packages/core/src/tools/shell.ts | 2 + packages/core/src/tools/tools.ts | 55 ++-- packages/core/src/tools/web-fetch.ts | 2 + packages/core/src/tools/write-file.ts | 7 +- 40 files changed, 909 insertions(+), 316 deletions(-) create mode 100644 packages/core/src/core/coreToolSchedulerHooks.test.ts diff --git a/integration-tests/hooks-system.test.ts b/integration-tests/hooks-system.test.ts index 479851957b..a4be2880e0 100644 --- a/integration-tests/hooks-system.test.ts +++ b/integration-tests/hooks-system.test.ts @@ -8,6 +8,7 @@ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; import { TestRig, poll, normalizePath } from './test-helper.js'; import { join } from 'node:path'; import { writeFileSync } from 'node:fs'; +import os from 'node:os'; describe('Hooks System Integration', () => { let rig: TestRig; @@ -2230,7 +2231,7 @@ console.log(JSON.stringify({ // The hook should have stopped execution message (returned from tool) expect(result).toContain( - 'Agent execution stopped: Emergency Stop triggered by hook', + 'Agent execution stopped by hook: Emergency Stop triggered by hook', ); // Tool should NOT be called successfully (it was blocked/stopped) @@ -2242,4 +2243,166 @@ console.log(JSON.stringify({ expect(writeFileCalls).toHaveLength(0); }); }); + + describe('Hooks "ask" Decision Integration', () => { + it( + 'should force confirmation prompt when hook returns "ask" decision even in YOLO mode', + { timeout: 20000 }, + async () => { + const testName = + 'should force confirmation prompt when hook returns "ask" decision'; + + // 1. Setup hook script that returns 'ask' decision + const hookOutput = { + decision: 'ask', + systemMessage: 'Confirmation forced by security hook', + hookSpecificOutput: { + hookEventName: 'BeforeTool', + }, + }; + + const hookScript = `console.log(JSON.stringify(${JSON.stringify( + hookOutput, + )}));`; + + // Create script path predictably + const scriptPath = join(os.tmpdir(), 'gemini-cli-tests-ask-hook.js'); + writeFileSync(scriptPath, hookScript); + + // 2. Setup rig with YOLO mode enabled but with the 'ask' hook + rig.setup(testName, { + fakeResponsesPath: join( + import.meta.dirname, + 'hooks-system.allow-tool.responses', + ), + settings: { + debugMode: true, + tools: { + approval: 'yolo', + }, + hooksConfig: { + enabled: true, + }, + hooks: { + BeforeTool: [ + { + matcher: 'write_file', + hooks: [ + { + type: 'command', + command: `node "${scriptPath}"`, + timeout: 5000, + }, + ], + }, + ], + }, + }, + }); + + // 3. Run interactive and verify prompt appears despite YOLO mode + const run = await rig.runInteractive(); + + // Send prompt that will trigger write_file + await run.type('Create a file called ask-test.txt with content "test"'); + await run.type('\r'); + + // Wait for the FORCED confirmation prompt to appear + // It should contain the system message from the hook + await run.expectText('Confirmation forced by security hook', 15000); + await run.expectText('Allow', 5000); + + // 4. Approve the permission + await run.type('y'); + await run.type('\r'); + + // Wait for command to execute + await run.expectText('approved.txt', 15000); + + // Should find the tool call + const foundWriteFile = await rig.waitForToolCall('write_file'); + expect(foundWriteFile).toBeTruthy(); + + // File should be created + const fileContent = rig.readFile('approved.txt'); + expect(fileContent).toBe('Approved content'); + }, + ); + + it('should allow cancelling when hook forces "ask" decision', async () => { + const testName = + 'should allow cancelling when hook forces "ask" decision'; + const hookOutput = { + decision: 'ask', + systemMessage: 'Confirmation forced for cancellation test', + hookSpecificOutput: { + hookEventName: 'BeforeTool', + }, + }; + + const hookScript = `console.log(JSON.stringify(${JSON.stringify( + hookOutput, + )}));`; + + const scriptPath = join( + os.tmpdir(), + 'gemini-cli-tests-ask-cancel-hook.js', + ); + writeFileSync(scriptPath, hookScript); + + rig.setup(testName, { + fakeResponsesPath: join( + import.meta.dirname, + 'hooks-system.allow-tool.responses', + ), + settings: { + debugMode: true, + tools: { + approval: 'yolo', + }, + hooksConfig: { + enabled: true, + }, + hooks: { + BeforeTool: [ + { + matcher: 'write_file', + hooks: [ + { + type: 'command', + command: `node "${scriptPath}"`, + timeout: 5000, + }, + ], + }, + ], + }, + }, + }); + + const run = await rig.runInteractive(); + + await run.type( + 'Create a file called cancel-test.txt with content "test"', + ); + await run.type('\r'); + + await run.expectText('Confirmation forced for cancellation test', 15000); + + // 4. Deny the permission using option 4 + await run.type('4'); + await run.type('\r'); + + // Wait for cancellation message + await run.expectText('Cancelled', 10000); + + // Tool should NOT be called successfully + const toolLogs = rig.readToolLogs(); + const writeFileCalls = toolLogs.filter( + (t) => + t.toolRequest.name === 'write_file' && t.toolRequest.success === true, + ); + expect(writeFileCalls).toHaveLength(0); + }); + }); }); diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index c969e601c3..23d6ccac4b 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -443,7 +443,11 @@ export class Task { 'Auto-approving all tool calls.', ); toolCalls.forEach((tc: ToolCall) => { - if (tc.status === 'awaiting_approval' && tc.confirmationDetails) { + if ( + tc.status === 'awaiting_approval' && + tc.confirmationDetails && + !tc.request.forcedAsk + ) { const details = tc.confirmationDetails; if (isToolCallConfirmationDetails(details)) { // eslint-disable-next-line @typescript-eslint/no-floating-promises diff --git a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx index 022a68e953..6fec797943 100644 --- a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx @@ -668,6 +668,15 @@ export const ToolConfirmationMessage: React.FC< paddingTop={0} paddingBottom={handlesOwnUI ? 0 : 1} > + {/* System message from hook */} + {confirmationDetails.systemMessage && ( + + + {confirmationDetails.systemMessage} + + + )} + {handlesOwnUI ? ( bodyContent ) : ( diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 2a25359614..af868f0d44 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -1574,7 +1574,7 @@ export const useGeminiStream = ( ) { let awaitingApprovalCalls = toolCalls.filter( (call): call is TrackedWaitingToolCall => - call.status === 'awaiting_approval', + call.status === 'awaiting_approval' && !call.request.forcedAsk, ); // For AUTO_EDIT mode, only approve edit tools (replace, write_file) diff --git a/packages/core/src/agents/remote-invocation.ts b/packages/core/src/agents/remote-invocation.ts index a8c75ec51c..1d995fe2b2 100644 --- a/packages/core/src/agents/remote-invocation.ts +++ b/packages/core/src/agents/remote-invocation.ts @@ -6,6 +6,7 @@ import { BaseToolInvocation, + type ForcedToolDecision, type ToolConfirmationOutcome, type ToolResult, type ToolCallConfirmationDetails, @@ -133,6 +134,7 @@ export class RemoteAgentInvocation extends BaseToolInvocation< protected override async getConfirmationDetails( _abortSignal: AbortSignal, + _forcedDecision?: ForcedToolDecision, ): Promise { // For now, always require confirmation for remote agents until we have a policy system for them. return { diff --git a/packages/core/src/agents/subagent-tool.test.ts b/packages/core/src/agents/subagent-tool.test.ts index 622fd054f0..8664e3a959 100644 --- a/packages/core/src/agents/subagent-tool.test.ts +++ b/packages/core/src/agents/subagent-tool.test.ts @@ -112,6 +112,7 @@ describe('SubAgentInvocation', () => { expect(result).toBe(false); expect(mockInnerInvocation.shouldConfirmExecute).toHaveBeenCalledWith( abortSignal, + undefined, ); expect(MockSubagentToolWrapper).toHaveBeenCalledWith( testDefinition, @@ -156,6 +157,7 @@ describe('SubAgentInvocation', () => { expect(result).toBe(confirmationDetails); expect(mockInnerInvocation.shouldConfirmExecute).toHaveBeenCalledWith( abortSignal, + undefined, ); expect(MockSubagentToolWrapper).toHaveBeenCalledWith( testRemoteDefinition, diff --git a/packages/core/src/agents/subagent-tool.ts b/packages/core/src/agents/subagent-tool.ts index 21a3864160..c580576ac0 100644 --- a/packages/core/src/agents/subagent-tool.ts +++ b/packages/core/src/agents/subagent-tool.ts @@ -9,6 +9,7 @@ import { Kind, type ToolInvocation, type ToolResult, + type ForcedToolDecision, BaseToolInvocation, type ToolCallConfirmationDetails, isTool, @@ -145,12 +146,13 @@ class SubAgentInvocation extends BaseToolInvocation { override async shouldConfirmExecute( abortSignal: AbortSignal, + forcedDecision?: ForcedToolDecision, ): Promise { const invocation = this.buildSubInvocation( this.definition, this.withUserHints(this.params), ); - return invocation.shouldConfirmExecute(abortSignal); + return invocation.shouldConfirmExecute(abortSignal, forcedDecision); } async execute( diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index aefafe0fa0..6a693ef8fa 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -72,12 +72,14 @@ export type SerializableConfirmationDetails = | { type: 'info'; title: string; + systemMessage?: string; prompt: string; urls?: string[]; } | { type: 'edit'; title: string; + systemMessage?: string; fileName: string; filePath: string; fileDiff: string; @@ -88,6 +90,7 @@ export type SerializableConfirmationDetails = | { type: 'exec'; title: string; + systemMessage?: string; command: string; rootCommand: string; rootCommands: string[]; @@ -96,6 +99,7 @@ export type SerializableConfirmationDetails = | { type: 'mcp'; title: string; + systemMessage?: string; serverName: string; toolName: string; toolDisplayName: string; @@ -106,11 +110,13 @@ export type SerializableConfirmationDetails = | { type: 'ask_user'; title: string; + systemMessage?: string; questions: Question[]; } | { type: 'exit_plan_mode'; title: string; + systemMessage?: string; planPath: string; }; diff --git a/packages/core/src/core/coreToolHookTriggers.test.ts b/packages/core/src/core/coreToolHookTriggers.test.ts index 2a654042c6..05c7aa055b 100644 --- a/packages/core/src/core/coreToolHookTriggers.test.ts +++ b/packages/core/src/core/coreToolHookTriggers.test.ts @@ -15,10 +15,7 @@ import { import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { HookSystem } from '../hooks/hookSystem.js'; import type { Config } from '../config/config.js'; -import { - type DefaultHookOutput, - BeforeToolHookOutput, -} from '../hooks/types.js'; +import { type DefaultHookOutput } from '../hooks/types.js'; class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> { constructor(params: { key?: string }, messageBus: MessageBus) { @@ -66,70 +63,11 @@ describe('executeToolWithHooks', () => { } as unknown as AnyDeclarativeTool; }); - it('should prioritize continue: false over decision: block in BeforeTool', async () => { - const invocation = new MockInvocation({}, messageBus); - const abortSignal = new AbortController().signal; - - vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({ - shouldStopExecution: () => true, - getEffectiveReason: () => 'Stop immediately', - getBlockingError: () => ({ - blocked: false, - reason: 'Should be ignored because continue is false', - }), - } as unknown as DefaultHookOutput); - - const result = await executeToolWithHooks( - invocation, - 'test_tool', - abortSignal, - mockTool, - undefined, - undefined, - undefined, - mockConfig, - ); - - expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION); - expect(result.error?.message).toBe('Stop immediately'); - }); - - it('should block execution in BeforeTool if decision is block', async () => { - const invocation = new MockInvocation({}, messageBus); - const abortSignal = new AbortController().signal; - - vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({ - shouldStopExecution: () => false, - getEffectiveReason: () => '', - getBlockingError: () => ({ blocked: true, reason: 'Execution blocked' }), - } as unknown as DefaultHookOutput); - - const result = await executeToolWithHooks( - invocation, - 'test_tool', - abortSignal, - mockTool, - undefined, - undefined, - undefined, - mockConfig, - ); - - expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED); - expect(result.error?.message).toBe('Execution blocked'); - }); - it('should handle continue: false in AfterTool', async () => { const invocation = new MockInvocation({}, messageBus); const abortSignal = new AbortController().signal; const spy = vi.spyOn(invocation, 'execute'); - vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({ - shouldStopExecution: () => false, - getEffectiveReason: () => '', - getBlockingError: () => ({ blocked: false, reason: '' }), - } as unknown as DefaultHookOutput); - vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({ shouldStopExecution: () => true, getEffectiveReason: () => 'Stop after execution', @@ -156,12 +94,6 @@ describe('executeToolWithHooks', () => { const invocation = new MockInvocation({}, messageBus); const abortSignal = new AbortController().signal; - vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({ - shouldStopExecution: () => false, - getEffectiveReason: () => '', - getBlockingError: () => ({ blocked: false, reason: '' }), - } as unknown as DefaultHookOutput); - vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({ shouldStopExecution: () => false, getEffectiveReason: () => '', @@ -182,80 +114,4 @@ describe('executeToolWithHooks', () => { expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED); expect(result.error?.message).toBe('Result denied'); }); - - it('should apply modified tool input from BeforeTool hook', async () => { - const params = { key: 'original' }; - const invocation = new MockInvocation(params, messageBus); - const toolName = 'test-tool'; - const abortSignal = new AbortController().signal; - - const mockBeforeOutput = new BeforeToolHookOutput({ - continue: true, - hookSpecificOutput: { - hookEventName: 'BeforeTool', - tool_input: { key: 'modified' }, - }, - }); - vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue( - mockBeforeOutput, - ); - - vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined); - - const result = await executeToolWithHooks( - invocation, - toolName, - abortSignal, - mockTool, - undefined, - undefined, - undefined, - mockConfig, - ); - - // Verify result reflects modified input - expect(result.llmContent).toBe( - 'key: modified\n\n[System] Tool input parameters (key) were modified by a hook before execution.', - ); - // Verify params object was modified in place - expect(invocation.params.key).toBe('modified'); - - expect(mockHookSystem.fireBeforeToolEvent).toHaveBeenCalled(); - expect(mockTool.build).toHaveBeenCalledWith({ key: 'modified' }); - }); - - it('should not modify input if hook does not provide tool_input', async () => { - const params = { key: 'original' }; - const invocation = new MockInvocation(params, messageBus); - const toolName = 'test-tool'; - const abortSignal = new AbortController().signal; - - const mockBeforeOutput = new BeforeToolHookOutput({ - continue: true, - hookSpecificOutput: { - hookEventName: 'BeforeTool', - // No tool input - }, - }); - vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue( - mockBeforeOutput, - ); - - vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined); - - const result = await executeToolWithHooks( - invocation, - toolName, - abortSignal, - mockTool, - undefined, - undefined, - undefined, - mockConfig, - ); - - expect(result.llmContent).toBe('key: original'); - expect(invocation.params.key).toBe('original'); - expect(mockTool.build).not.toHaveBeenCalled(); - }); }); diff --git a/packages/core/src/core/coreToolHookTriggers.ts b/packages/core/src/core/coreToolHookTriggers.ts index cbd90e8039..aace3fe9f9 100644 --- a/packages/core/src/core/coreToolHookTriggers.ts +++ b/packages/core/src/core/coreToolHookTriggers.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { type McpToolContext, BeforeToolHookOutput } from '../hooks/types.js'; +import { type McpToolContext } from '../hooks/types.js'; import type { Config } from '../config/config.js'; import type { ToolResult, @@ -13,7 +13,6 @@ import type { ToolLiveOutput, } from '../tools/tools.js'; import { ToolErrorType } from '../tools/tool-error.js'; -import { debugLogger } from '../utils/debugLogger.js'; import type { ShellExecutionConfig } from '../index.js'; import { ShellToolInvocation } from '../tools/shell.js'; import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js'; @@ -25,7 +24,7 @@ import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js'; * @param config Config to look up server details * @returns MCP context if this is an MCP tool, undefined otherwise */ -function extractMcpContext( +export function extractMcpContext( invocation: ShellToolInvocation | AnyToolInvocation, config: Config, ): McpToolContext | undefined { @@ -78,81 +77,12 @@ export async function executeToolWithHooks( config?: Config, originalRequestName?: string, ): Promise { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const toolInput = (invocation.params || {}) as Record; - let inputWasModified = false; - let modifiedKeys: string[] = []; - // Extract MCP context if this is an MCP tool (only if config is provided) const mcpContext = config ? extractMcpContext(invocation, config) : undefined; - const hookSystem = config?.getHookSystem(); - if (hookSystem) { - const beforeOutput = await hookSystem.fireBeforeToolEvent( - toolName, - toolInput, - mcpContext, - originalRequestName, - ); - // Check if hook requested to stop entire agent execution - if (beforeOutput?.shouldStopExecution()) { - const reason = beforeOutput.getEffectiveReason(); - return { - llmContent: `Agent execution stopped by hook: ${reason}`, - returnDisplay: `Agent execution stopped by hook: ${reason}`, - error: { - type: ToolErrorType.STOP_EXECUTION, - message: reason, - }, - }; - } - - // Check if hook blocked the tool execution - const blockingError = beforeOutput?.getBlockingError(); - if (blockingError?.blocked) { - return { - llmContent: `Tool execution blocked: ${blockingError.reason}`, - returnDisplay: `Tool execution blocked: ${blockingError.reason}`, - error: { - type: ToolErrorType.EXECUTION_FAILED, - message: blockingError.reason, - }, - }; - } - - // Check if hook requested to update tool input - if (beforeOutput instanceof BeforeToolHookOutput) { - const modifiedInput = beforeOutput.getModifiedToolInput(); - if (modifiedInput) { - // We modify the toolInput object in-place, which should be the same reference as invocation.params - // We use Object.assign to update properties - Object.assign(invocation.params, modifiedInput); - debugLogger.debug(`Tool input modified by hook for ${toolName}`); - inputWasModified = true; - modifiedKeys = Object.keys(modifiedInput); - - // Recreate the invocation with the new parameters - // to ensure any derived state (like resolvedPath in ReadFileTool) is updated. - try { - // We use the tool's build method to validate and create the invocation - // This ensures consistent behavior with the initial creation - invocation = tool.build(invocation.params); - } catch (error) { - return { - llmContent: `Tool parameter modification by hook failed validation: ${ - error instanceof Error ? error.message : String(error) - }`, - returnDisplay: `Tool parameter modification by hook failed validation.`, - error: { - type: ToolErrorType.INVALID_TOOL_PARAMS, - message: String(error), - }, - }; - } - } - } - } + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const toolInput = (invocation.params || {}) as Record; // Execute the actual tool let toolResult: ToolResult; @@ -171,24 +101,6 @@ export async function executeToolWithHooks( ); } - // Append notification if parameters were modified - if (inputWasModified) { - const modificationMsg = `\n\n[System] Tool input parameters (${modifiedKeys.join( - ', ', - )}) were modified by a hook before execution.`; - if (typeof toolResult.llmContent === 'string') { - toolResult.llmContent += modificationMsg; - } else if (Array.isArray(toolResult.llmContent)) { - toolResult.llmContent.push({ text: modificationMsg }); - } else if (toolResult.llmContent) { - // Handle single Part case by converting to an array - toolResult.llmContent = [ - toolResult.llmContent, - { text: modificationMsg }, - ]; - } - } - if (hookSystem) { const afterOutput = await hookSystem.fireAfterToolEvent( toolName, diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index fcddc05a44..d770468b26 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -285,6 +285,7 @@ function createMockConfig(overrides: Partial = {}): Config { getGeminiClient: () => null, getMessageBus: () => createMockMessageBus(), getEnableHooks: () => false, + getHookSystem: () => undefined, getExperiments: () => {}, } as unknown as Config; @@ -1014,7 +1015,12 @@ describe('CoreToolScheduler YOLO mode', () => { // Assert // 1. The tool's execute method was called directly. - expect(executeFn).toHaveBeenCalledWith({ param: 'value' }); + expect(executeFn).toHaveBeenCalledWith( + { param: 'value' }, + expect.anything(), + undefined, + expect.anything(), + ); // 2. The tool call status never entered CoreToolCallStatus.AwaitingApproval. const statusUpdates = onToolCallsUpdate.mock.calls @@ -1117,7 +1123,12 @@ describe('CoreToolScheduler request queueing', () => { ); // Ensure the second tool call hasn't been executed yet. - expect(executeFn).toHaveBeenCalledWith({ a: 1 }); + expect(executeFn).toHaveBeenCalledWith( + { a: 1 }, + expect.anything(), + undefined, + expect.anything(), + ); // Complete the first tool call. resolveFirstCall!({ @@ -1141,7 +1152,12 @@ describe('CoreToolScheduler request queueing', () => { // Now the second tool call should have been executed. expect(executeFn).toHaveBeenCalledTimes(2); }); - expect(executeFn).toHaveBeenCalledWith({ b: 2 }); + expect(executeFn).toHaveBeenCalledWith( + { b: 2 }, + expect.anything(), + undefined, + expect.anything(), + ); // Wait for the second completion. await vi.waitFor(() => { @@ -1235,7 +1251,12 @@ describe('CoreToolScheduler request queueing', () => { // Assert // 1. The tool's execute method was called directly. - expect(executeFn).toHaveBeenCalledWith({ param: 'value' }); + expect(executeFn).toHaveBeenCalledWith( + { param: 'value' }, + expect.anything(), + undefined, + expect.anything(), + ); // 2. The tool call status never entered CoreToolCallStatus.AwaitingApproval. const statusUpdates = onToolCallsUpdate.mock.calls @@ -1416,8 +1437,18 @@ describe('CoreToolScheduler request queueing', () => { // Ensure the tool was called twice with the correct arguments. expect(executeFn).toHaveBeenCalledTimes(2); - expect(executeFn).toHaveBeenCalledWith({ a: 1 }); - expect(executeFn).toHaveBeenCalledWith({ b: 2 }); + expect(executeFn).toHaveBeenCalledWith( + { a: 1 }, + expect.anything(), + undefined, + expect.anything(), + ); + expect(executeFn).toHaveBeenCalledWith( + { b: 2 }, + expect.anything(), + undefined, + expect.anything(), + ); // Ensure completion callbacks were called twice. expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2); @@ -1774,8 +1805,18 @@ describe('CoreToolScheduler Sequential Execution', () => { // Check that execute was called for the first two tools only expect(executeFn).toHaveBeenCalledTimes(2); - expect(executeFn).toHaveBeenCalledWith({ call: 1 }); - expect(executeFn).toHaveBeenCalledWith({ call: 2 }); + expect(executeFn).toHaveBeenCalledWith( + { call: 1 }, + expect.anything(), + undefined, + expect.anything(), + ); + expect(executeFn).toHaveBeenCalledWith( + { call: 2 }, + expect.anything(), + undefined, + expect.anything(), + ); const completedCalls = onAllToolCallsComplete.mock .calls[0][0] as ToolCall[]; diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 23473e199d..ecb5452079 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -50,6 +50,8 @@ import { ToolExecutor } from '../scheduler/tool-executor.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { getPolicyDenialError } from '../scheduler/policy.js'; import { GeminiCliOperation } from '../telemetry/constants.js'; +import { extractMcpContext } from './coreToolHookTriggers.js'; +import { BeforeToolHookOutput } from '../hooks/types.js'; export type { ToolCall, @@ -604,7 +606,7 @@ export class CoreToolScheduler { return; } - const toolCall = this.toolCallQueue.shift()!; + let toolCall = this.toolCallQueue.shift()!; // This is now the single active tool call. this.toolCalls = [toolCall]; @@ -620,7 +622,8 @@ export class CoreToolScheduler { // This logic is moved from the old `for` loop in `_schedule`. if (toolCall.status === CoreToolCallStatus.Validating) { - const { request: reqInfo, invocation } = toolCall; + const { request: reqInfo } = toolCall; + let { invocation } = toolCall; try { if (signal.aborted) { @@ -635,6 +638,90 @@ export class CoreToolScheduler { return; } + // 1. Hook Check (BeforeTool) + let hookDecision: 'ask' | 'block' | undefined; + let hookSystemMessage: string | undefined; + + const hookSystem = this.config.getHookSystem(); + if (hookSystem) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const toolInput = (invocation.params || {}) as Record< + string, + unknown + >; + const mcpContext = extractMcpContext(invocation, this.config); + + const beforeOutput = await hookSystem.fireBeforeToolEvent( + toolCall.request.name, + toolInput, + mcpContext, + toolCall.request.originalRequestName, + ); + + if (beforeOutput) { + // Check if hook requested to stop entire agent execution + if (beforeOutput.shouldStopExecution()) { + const reason = beforeOutput.getEffectiveReason(); + this.setStatusInternal( + reqInfo.callId, + CoreToolCallStatus.Error, + signal, + createErrorResponse( + reqInfo, + new Error(`Agent execution stopped by hook: ${reason}`), + ToolErrorType.STOP_EXECUTION, + ), + ); + await this.checkAndNotifyCompletion(signal); + return; + } + + // Check if hook blocked the tool execution + const blockingError = beforeOutput.getBlockingError(); + if (blockingError?.blocked) { + this.setStatusInternal( + reqInfo.callId, + CoreToolCallStatus.Error, + signal, + createErrorResponse( + reqInfo, + new Error(`Tool execution blocked: ${blockingError.reason}`), + ToolErrorType.POLICY_VIOLATION, + ), + ); + await this.checkAndNotifyCompletion(signal); + return; + } + + if (beforeOutput.isAskDecision()) { + hookDecision = 'ask'; + hookSystemMessage = beforeOutput.systemMessage; + // Mark the request so UI knows not to auto-approve it + toolCall.request.forcedAsk = true; + } + + // Check if hook requested to update tool input + if (beforeOutput instanceof BeforeToolHookOutput) { + const modifiedInput = beforeOutput.getModifiedToolInput(); + if (modifiedInput) { + this.setArgsInternal(reqInfo.callId, modifiedInput); + + // IMPORTANT: toolCall and invocation must be updated because setArgsInternal created a new one + const updatedCall = this.toolCalls.find( + (c) => c.request.callId === reqInfo.callId, + ); + if (updatedCall) { + toolCall = updatedCall; + toolCall.request.inputModifiedByHook = true; + if ('invocation' in updatedCall) { + invocation = updatedCall.invocation; + } + } + } + } + } + } + // Policy Check using PolicyEngine // We must reconstruct the FunctionCall format expected by PolicyEngine const toolCallForPolicy = { @@ -645,13 +732,18 @@ export class CoreToolScheduler { toolCall.tool instanceof DiscoveredMCPTool ? toolCall.tool.serverName : undefined; - const toolAnnotations = toolCall.tool.toolAnnotations; + const toolAnnotations = toolCall.tool?.toolAnnotations; - const { decision, rule } = await this.config + const { decision: policyDecision, rule } = await this.config .getPolicyEngine() .check(toolCallForPolicy, serverName, toolAnnotations); - if (decision === PolicyDecision.DENY) { + let finalDecision = policyDecision; + if (hookDecision === 'ask') { + finalDecision = PolicyDecision.ASK_USER; + } + + if (finalDecision === PolicyDecision.DENY) { const { errorMessage, errorType } = getPolicyDenialError( this.config, rule, @@ -666,7 +758,7 @@ export class CoreToolScheduler { return; } - if (decision === PolicyDecision.ALLOW) { + if (finalDecision === PolicyDecision.ALLOW) { this.setToolCallOutcome( reqInfo.callId, ToolConfirmationOutcome.ProceedAlways, @@ -677,11 +769,13 @@ export class CoreToolScheduler { signal, ); } else { - // PolicyDecision.ASK_USER + // PolicyDecision.ASK_USER or forced 'ask' by hook // We need confirmation details to show to the user - const confirmationDetails = - await invocation.shouldConfirmExecute(signal); + const confirmationDetails = await invocation.shouldConfirmExecute( + signal, + hookDecision === 'ask' ? 'ask_user' : undefined, + ); if (!confirmationDetails) { this.setToolCallOutcome( @@ -697,11 +791,17 @@ export class CoreToolScheduler { if (!this.config.isInteractive()) { throw new Error( `Tool execution for "${ - toolCall.tool.displayName || toolCall.tool.name + toolCall.tool?.displayName || + toolCall.tool?.name || + toolCall.request.name }" requires user confirmation, which is not supported in non-interactive mode.`, ); } + if (hookSystemMessage) { + confirmationDetails.systemMessage = hookSystemMessage; + } + // Fire Notification hook before showing confirmation to user const hookSystem = this.config.getHookSystem(); if (hookSystem) { diff --git a/packages/core/src/core/coreToolSchedulerHooks.test.ts b/packages/core/src/core/coreToolSchedulerHooks.test.ts new file mode 100644 index 0000000000..d0872ad64d --- /dev/null +++ b/packages/core/src/core/coreToolSchedulerHooks.test.ts @@ -0,0 +1,300 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi } from 'vitest'; +import { CoreToolScheduler } from './coreToolScheduler.js'; +import type { ToolCall, ErroredToolCall } from '../scheduler/types.js'; +import type { Config, ToolRegistry } from '../index.js'; +import { + ApprovalMode, + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, +} from '../index.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; +import { MockTool } from '../test-utils/mock-tool.js'; +import { DEFAULT_GEMINI_MODEL } from '../config/models.js'; +import type { PolicyEngine } from '../policy/policy-engine.js'; +import type { HookSystem } from '../hooks/hookSystem.js'; +import { BeforeToolHookOutput } from '../hooks/types.js'; + +function createMockConfig(overrides: Partial = {}): Config { + const defaultToolRegistry = { + getTool: () => undefined, + getToolByName: () => undefined, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => undefined, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + getExperiments: () => {}, + } as unknown as ToolRegistry; + + const baseConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + isInteractive: () => true, + getApprovalMode: () => ApprovalMode.DEFAULT, + setApprovalMode: () => {}, + getAllowedTools: () => [], + getContentGeneratorConfig: () => ({ + model: 'test-model', + authType: 'oauth-personal', + }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + sanitizationConfig: { + enableEnvironmentVariableRedaction: true, + allowedEnvironmentVariables: [], + blockedEnvironmentVariables: [], + }, + }), + storage: { + getProjectTempDir: () => '/tmp', + }, + getTruncateToolOutputThreshold: () => + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + getTruncateToolOutputLines: () => 1000, + getToolRegistry: () => defaultToolRegistry, + getActiveModel: () => DEFAULT_GEMINI_MODEL, + getGeminiClient: () => null, + getMessageBus: () => createMockMessageBus(), + getEnableHooks: () => true, // Enabled for these tests + getExperiments: () => {}, + getPolicyEngine: () => + ({ + check: async () => ({ decision: 'allow' }), // Default allow for hook tests + }) as unknown as PolicyEngine, + } as unknown as Config; + + return { ...baseConfig, ...overrides } as Config; +} + +describe('CoreToolScheduler Hooks', () => { + it('should stop execution if BeforeTool hook requests stop', async () => { + const executeFn = vi.fn().mockResolvedValue({ + llmContent: 'Tool executed', + returnDisplay: 'Tool executed', + }); + const mockTool = new MockTool({ name: 'mockTool', execute: executeFn }); + + const toolRegistry = { + getTool: () => mockTool, + getToolByName: () => mockTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => mockTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const mockMessageBus = createMockMessageBus(); + const mockHookSystem = { + fireBeforeToolEvent: vi.fn().mockResolvedValue({ + shouldStopExecution: () => true, + getEffectiveReason: () => 'Hook stopped execution', + getBlockingError: () => ({ blocked: false }), + isAskDecision: () => false, + }), + } as unknown as HookSystem; + + const mockConfig = createMockConfig({ + getToolRegistry: () => toolRegistry, + getMessageBus: () => mockMessageBus, + getHookSystem: () => mockHookSystem, + getApprovalMode: () => ApprovalMode.YOLO, + }); + + const onAllToolCallsComplete = vi.fn(); + const scheduler = new CoreToolScheduler({ + config: mockConfig, + onAllToolCallsComplete, + getPreferredEditor: () => 'vscode', + }); + + const request = { + callId: '1', + name: 'mockTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + + await scheduler.schedule([request], new AbortController().signal); + + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + }); + + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls[0].status).toBe('error'); + const erroredCall = completedCalls[0] as ErroredToolCall; + + // Check error type/message + expect(erroredCall.response.error?.message).toContain( + 'Hook stopped execution', + ); + expect(executeFn).not.toHaveBeenCalled(); + }); + + it('should block tool execution if BeforeTool hook requests block', async () => { + const executeFn = vi.fn(); + const mockTool = new MockTool({ name: 'mockTool', execute: executeFn }); + + const toolRegistry = { + getTool: () => mockTool, + getToolByName: () => mockTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => mockTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const mockMessageBus = createMockMessageBus(); + const mockHookSystem = { + fireBeforeToolEvent: vi.fn().mockResolvedValue({ + shouldStopExecution: () => false, + getBlockingError: () => ({ + blocked: true, + reason: 'Hook blocked execution', + }), + isAskDecision: () => false, + }), + } as unknown as HookSystem; + + const mockConfig = createMockConfig({ + getToolRegistry: () => toolRegistry, + getMessageBus: () => mockMessageBus, + getHookSystem: () => mockHookSystem, + getApprovalMode: () => ApprovalMode.YOLO, + }); + + const onAllToolCallsComplete = vi.fn(); + const scheduler = new CoreToolScheduler({ + config: mockConfig, + onAllToolCallsComplete, + getPreferredEditor: () => 'vscode', + }); + + const request = { + callId: '1', + name: 'mockTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + + await scheduler.schedule([request], new AbortController().signal); + + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + }); + + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls[0].status).toBe('error'); + const erroredCall = completedCalls[0] as ErroredToolCall; + expect(erroredCall.response.error?.message).toContain( + 'Hook blocked execution', + ); + expect(executeFn).not.toHaveBeenCalled(); + }); + + it('should update tool input if BeforeTool hook provides modified input', async () => { + const executeFn = vi.fn().mockResolvedValue({ + llmContent: 'Tool executed', + returnDisplay: 'Tool executed', + }); + const mockTool = new MockTool({ name: 'mockTool', execute: executeFn }); + + const toolRegistry = { + getTool: () => mockTool, + getToolByName: () => mockTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => mockTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const mockMessageBus = createMockMessageBus(); + const mockBeforeOutput = new BeforeToolHookOutput({ + continue: true, + hookSpecificOutput: { + hookEventName: 'BeforeTool', + tool_input: { newParam: 'modifiedValue' }, + }, + }); + + const mockHookSystem = { + fireBeforeToolEvent: vi.fn().mockResolvedValue(mockBeforeOutput), + fireAfterToolEvent: vi.fn(), + } as unknown as HookSystem; + + const mockConfig = createMockConfig({ + getToolRegistry: () => toolRegistry, + getMessageBus: () => mockMessageBus, + getHookSystem: () => mockHookSystem, + getApprovalMode: () => ApprovalMode.YOLO, + }); + + const onAllToolCallsComplete = vi.fn(); + const scheduler = new CoreToolScheduler({ + config: mockConfig, + onAllToolCallsComplete, + getPreferredEditor: () => 'vscode', + }); + + const request = { + callId: '1', + name: 'mockTool', + args: { originalParam: 'originalValue' }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + + await scheduler.schedule([request], new AbortController().signal); + + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + }); + + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls[0].status).toBe('success'); + + // Verify execute was called with modified args + expect(executeFn).toHaveBeenCalledWith( + { newParam: 'modifiedValue' }, + expect.anything(), + undefined, + expect.anything(), + ); + + // Verify call request args were updated in the completion report + expect(completedCalls[0].request.args).toEqual({ + newParam: 'modifiedValue', + }); + }); +}); diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 4fd6af2185..e9da886552 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -16,6 +16,7 @@ import { import type { ToolCallConfirmationDetails, ToolResult, + ForcedToolDecision, } from '../tools/tools.js'; import { getResponseText } from '../utils/partUtils.js'; import { reportError } from '../utils/errorReporting.js'; @@ -46,6 +47,7 @@ export interface ServerTool { shouldConfirmExecute( params: Record, abortSignal: AbortSignal, + forcedDecision?: ForcedToolDecision, ): Promise; } diff --git a/packages/core/src/hooks/hookAggregator.ts b/packages/core/src/hooks/hookAggregator.ts index 523bc823fd..af7f93cfab 100644 --- a/packages/core/src/hooks/hookAggregator.ts +++ b/packages/core/src/hooks/hookAggregator.ts @@ -125,6 +125,7 @@ export class HookAggregator { const additionalContexts: string[] = []; let hasBlockDecision = false; + let hasAskDecision = false; let hasContinueFalse = false; for (const output of outputs) { @@ -142,6 +143,12 @@ export class HookAggregator { if (tempOutput.isBlockingDecision()) { hasBlockDecision = true; merged.decision = output.decision; + } else if (tempOutput.isAskDecision()) { + hasAskDecision = true; + // Ask decision is only set if no blocking decision was found so far + if (!hasBlockDecision) { + merged.decision = output.decision; + } } // Collect messages @@ -180,8 +187,8 @@ export class HookAggregator { this.extractAdditionalContext(output, additionalContexts); } - // Set final decision if no blocking decision was found - if (!hasBlockDecision && !hasContinueFalse) { + // Set final decision if no blocking or ask decision was found + if (!hasBlockDecision && !hasAskDecision && !hasContinueFalse) { merged.decision = 'allow'; } diff --git a/packages/core/src/hooks/types.ts b/packages/core/src/hooks/types.ts index 9c6217ffa4..c1a35384ae 100644 --- a/packages/core/src/hooks/types.ts +++ b/packages/core/src/hooks/types.ts @@ -197,12 +197,19 @@ export class DefaultHookOutput implements HookOutput { } /** - * Check if this output represents a blocking decision + * Check if this output represents a blocking decision (block or deny) */ isBlockingDecision(): boolean { return this.decision === 'block' || this.decision === 'deny'; } + /** + * Check if this output represents an 'ask' decision + */ + isAskDecision(): boolean { + return this.decision === 'ask'; + } + /** * Check if this output requests to stop execution */ diff --git a/packages/core/src/scheduler/confirmation.ts b/packages/core/src/scheduler/confirmation.ts index 67ae26d2eb..7db7a0b48f 100644 --- a/packages/core/src/scheduler/confirmation.ts +++ b/packages/core/src/scheduler/confirmation.ts @@ -16,6 +16,7 @@ import { ToolConfirmationOutcome, type ToolConfirmationPayload, type ToolCallConfirmationDetails, + type ForcedToolDecision, } from '../tools/tools.js'; import { type ValidatingToolCall, @@ -116,6 +117,8 @@ export async function resolveConfirmation( getPreferredEditor: () => EditorType | undefined; schedulerId: string; onWaitingForConfirmation?: (waiting: boolean) => void; + systemMessage?: string; + forcedDecision?: ForcedToolDecision; }, ): Promise { const { state, onWaitingForConfirmation } = deps; @@ -126,7 +129,7 @@ export async function resolveConfirmation( // Loop exists to allow the user to modify the parameters and see the new // diff. while (outcome === ToolConfirmationOutcome.ModifyWithEditor) { - if (signal.aborted) throw new Error('Operation cancelled'); + if (signal.aborted) throw new Error('Operation cancelled by user'); const currentCall = state.getToolCall(callId); if (!currentCall || !('invocation' in currentCall)) { @@ -134,12 +137,19 @@ export async function resolveConfirmation( } const currentInvocation = currentCall.invocation; - const details = await currentInvocation.shouldConfirmExecute(signal); + const details = await currentInvocation.shouldConfirmExecute( + signal, + deps.forcedDecision, + ); if (!details) { outcome = ToolConfirmationOutcome.ProceedOnce; break; } + if (deps.systemMessage) { + details.systemMessage = deps.systemMessage; + } + await notifyHooks(deps, details); const correlationId = randomUUID(); diff --git a/packages/core/src/scheduler/policy.test.ts b/packages/core/src/scheduler/policy.test.ts index 05f5b08a2f..ae3b143f0e 100644 --- a/packages/core/src/scheduler/policy.test.ts +++ b/packages/core/src/scheduler/policy.test.ts @@ -519,6 +519,7 @@ describe('Plan Mode Denial Consistency', () => { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), getMessageBus: vi.fn().mockReturnValue(mockMessageBus), + getHookSystem: vi.fn().mockReturnValue(undefined), isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(false), getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN), // Key: Plan Mode diff --git a/packages/core/src/scheduler/scheduler.test.ts b/packages/core/src/scheduler/scheduler.test.ts index ee5438c319..6fadcf78a9 100644 --- a/packages/core/src/scheduler/scheduler.test.ts +++ b/packages/core/src/scheduler/scheduler.test.ts @@ -170,6 +170,7 @@ describe('Scheduler (Orchestrator)', () => { mockConfig = { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getHookSystem: vi.fn().mockReturnValue(undefined), isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(true), setApprovalMode: vi.fn(), @@ -1314,6 +1315,7 @@ describe('Scheduler MCP Progress', () => { mockConfig = { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getHookSystem: vi.fn().mockReturnValue(undefined), isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(true), setApprovalMode: vi.fn(), diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index 38e001ea90..e983a11edc 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -24,6 +24,8 @@ import { type ScheduledToolCall, } from './types.js'; import { ToolErrorType } from '../tools/tool-error.js'; +import { extractMcpContext } from '../core/coreToolHookTriggers.js'; +import { BeforeToolHookOutput } from '../hooks/types.js'; import { PolicyDecision, type ApprovalMode } from '../policy/types.js'; import { ToolConfirmationOutcome, @@ -559,8 +561,95 @@ export class Scheduler { ): Promise { const callId = toolCall.request.callId; + let hookDecision: 'ask' | 'block' | undefined; + let hookSystemMessage: string | undefined; + + const hookSystem = this.config.getHookSystem(); + if (hookSystem) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const toolInput = (toolCall.invocation.params || {}) as Record< + string, + unknown + >; + const mcpContext = extractMcpContext(toolCall.invocation, this.config); + + const beforeOutput = await hookSystem.fireBeforeToolEvent( + toolCall.request.name, + toolInput, + mcpContext, + toolCall.request.originalRequestName, + ); + + if (beforeOutput) { + if (beforeOutput.shouldStopExecution()) { + this.state.updateStatus( + callId, + CoreToolCallStatus.Error, + createErrorResponse( + toolCall.request, + new Error( + `Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`, + ), + ToolErrorType.STOP_EXECUTION, + ), + ); + return; + } + + const blockingError = beforeOutput.getBlockingError(); + if (blockingError?.blocked) { + this.state.updateStatus( + callId, + CoreToolCallStatus.Error, + createErrorResponse( + toolCall.request, + new Error(`Tool execution blocked: ${blockingError.reason}`), + ToolErrorType.POLICY_VIOLATION, + ), + ); + return; + } + + if (beforeOutput.isAskDecision()) { + hookDecision = 'ask'; + hookSystemMessage = beforeOutput.systemMessage; + } + + if (beforeOutput instanceof BeforeToolHookOutput) { + const modifiedInput = beforeOutput.getModifiedToolInput(); + if (modifiedInput) { + toolCall.request.args = modifiedInput; + toolCall.request.inputModifiedByHook = true; + try { + toolCall.invocation = toolCall.tool.build(modifiedInput); + } catch (error) { + this.state.updateStatus( + callId, + CoreToolCallStatus.Error, + createErrorResponse( + toolCall.request, + new Error( + `Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`, + ), + ToolErrorType.INVALID_TOOL_PARAMS, + ), + ); + return; + } + } + } + } + } + // Policy & Security - const { decision, rule } = await checkPolicy(toolCall, this.config); + const { decision: policyDecision, rule } = await checkPolicy( + toolCall, + this.config, + ); + let decision = policyDecision; + if (hookDecision === 'ask') { + decision = PolicyDecision.ASK_USER; + } if (decision === PolicyDecision.DENY) { const { errorMessage, errorType } = getPolicyDenialError( @@ -593,6 +682,8 @@ export class Scheduler { getPreferredEditor: this.getPreferredEditor, schedulerId: this.schedulerId, onWaitingForConfirmation: this.onWaitingForConfirmation, + systemMessage: hookSystemMessage, + forcedDecision: hookDecision === 'ask' ? 'ask_user' : undefined, }); outcome = result.outcome; lastDetails = result.lastDetails; diff --git a/packages/core/src/scheduler/scheduler_parallel.test.ts b/packages/core/src/scheduler/scheduler_parallel.test.ts index 56e6e26243..60b185e81e 100644 --- a/packages/core/src/scheduler/scheduler_parallel.test.ts +++ b/packages/core/src/scheduler/scheduler_parallel.test.ts @@ -212,6 +212,7 @@ describe('Scheduler Parallel Execution', () => { mockConfig = { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getHookSystem: vi.fn().mockReturnValue(undefined), isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(true), setApprovalMode: vi.fn(), diff --git a/packages/core/src/scheduler/tool-executor.ts b/packages/core/src/scheduler/tool-executor.ts index 8269f1fc41..d16992c806 100644 --- a/packages/core/src/scheduler/tool-executor.ts +++ b/packages/core/src/scheduler/tool-executor.ts @@ -129,6 +129,20 @@ export class ToolExecutor { const toolResult: ToolResult = await promise; + if (call.request.inputModifiedByHook) { + const modificationMsg = `\n\n[System] Tool input parameters were modified by a hook before execution.`; + if (typeof toolResult.llmContent === 'string') { + toolResult.llmContent += modificationMsg; + } else if (Array.isArray(toolResult.llmContent)) { + toolResult.llmContent.push({ text: modificationMsg }); + } else if (toolResult.llmContent) { + toolResult.llmContent = [ + toolResult.llmContent, + { text: modificationMsg }, + ]; + } + } + if (signal.aborted) { completedToolCall = await this.createCancelledResult( call, diff --git a/packages/core/src/scheduler/types.ts b/packages/core/src/scheduler/types.ts index 9fedd48f41..a9cde87d27 100644 --- a/packages/core/src/scheduler/types.ts +++ b/packages/core/src/scheduler/types.ts @@ -47,6 +47,8 @@ export interface ToolCallRequestInfo { traceId?: string; parentCallId?: string; schedulerId?: string; + inputModifiedByHook?: boolean; + forcedAsk?: boolean; } export interface ToolCallResponseInfo { diff --git a/packages/core/src/telemetry/conseca-logger.test.ts b/packages/core/src/telemetry/conseca-logger.test.ts index e3ce85432e..0eac29276f 100644 --- a/packages/core/src/telemetry/conseca-logger.test.ts +++ b/packages/core/src/telemetry/conseca-logger.test.ts @@ -112,7 +112,7 @@ describe('conseca-logger', () => { 'user prompt', 'policy', 'tool call', - 'ALLOW', + 'allow', 'rationale', ); @@ -122,7 +122,7 @@ describe('conseca-logger', () => { expect(logs.getLogger).toHaveBeenCalled(); expect(mockLogger.emit).toHaveBeenCalledWith( expect.objectContaining({ - body: 'Conseca Verdict: ALLOW.', + body: 'Conseca Verdict: allow.', attributes: expect.objectContaining({ 'event.name': EVENT_CONSECA_VERDICT, }), diff --git a/packages/core/src/test-utils/mock-tool.ts b/packages/core/src/test-utils/mock-tool.ts index 5f89a506cd..0d49dfad29 100644 --- a/packages/core/src/test-utils/mock-tool.ts +++ b/packages/core/src/test-utils/mock-tool.ts @@ -12,12 +12,15 @@ import { BaseDeclarativeTool, BaseToolInvocation, Kind, + type ForcedToolDecision, type ToolCallConfirmationDetails, type ToolInvocation, + type ToolLiveOutput, type ToolResult, } from '../tools/tools.js'; import { createMockMessageBus } from './mock-message-bus.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import type { ShellExecutionConfig } from 'src/services/shellExecutionService.js'; interface MockToolOptions { name: string; @@ -28,11 +31,13 @@ interface MockToolOptions { shouldConfirmExecute?: ( params: { [key: string]: unknown }, signal: AbortSignal, + forcedDecision?: ForcedToolDecision, ) => Promise; execute?: ( params: { [key: string]: unknown }, signal?: AbortSignal, updateOutput?: (output: string) => void, + shellExecutionConfig?: ShellExecutionConfig, ) => Promise; params?: object; messageBus?: MessageBus; @@ -52,19 +57,26 @@ class MockToolInvocation extends BaseToolInvocation< execute( signal: AbortSignal, - updateOutput?: (output: string) => void, + updateOutput?: (output: ToolLiveOutput) => void, + shellExecutionConfig?: ShellExecutionConfig, ): Promise { - if (updateOutput) { - return this.tool.execute(this.params, signal, updateOutput); - } else { - return this.tool.execute(this.params); - } + return this.tool.execute( + this.params, + signal, + updateOutput as ((output: string) => void) | undefined, + shellExecutionConfig, + ); } override shouldConfirmExecute( abortSignal: AbortSignal, + forcedDecision?: ForcedToolDecision, ): Promise { - return this.tool.shouldConfirmExecute(this.params, abortSignal); + return this.tool.shouldConfirmExecute( + this.params, + abortSignal, + forcedDecision, + ); } getDescription(): string { @@ -79,14 +91,17 @@ export class MockTool extends BaseDeclarativeTool< { [key: string]: unknown }, ToolResult > { - shouldConfirmExecute: ( + readonly shouldConfirmExecute: ( params: { [key: string]: unknown }, signal: AbortSignal, + forcedDecision?: ForcedToolDecision, ) => Promise; - execute: ( + + readonly execute: ( params: { [key: string]: unknown }, signal?: AbortSignal, updateOutput?: (output: string) => void, + shellExecutionConfig?: ShellExecutionConfig, ) => Promise; constructor(options: MockToolOptions) { @@ -162,6 +177,7 @@ export class MockModifiableToolInvocation extends BaseToolInvocation< override async shouldConfirmExecute( _abortSignal: AbortSignal, + _forcedDecision?: ForcedToolDecision, ): Promise { if (this.tool.shouldConfirm) { return { diff --git a/packages/core/src/tools/ask-user.ts b/packages/core/src/tools/ask-user.ts index 621d4c10d1..cba0a4f6c8 100644 --- a/packages/core/src/tools/ask-user.ts +++ b/packages/core/src/tools/ask-user.ts @@ -7,6 +7,7 @@ import { BaseDeclarativeTool, BaseToolInvocation, + type ForcedToolDecision, type ToolResult, Kind, type ToolAskUserConfirmationDetails, @@ -126,6 +127,7 @@ export class AskUserInvocation extends BaseToolInvocation< override async shouldConfirmExecute( _abortSignal: AbortSignal, + _forcedDecision?: ForcedToolDecision, ): Promise { const normalizedQuestions = this.params.questions.map((q) => ({ ...q, diff --git a/packages/core/src/tools/confirmation-policy.test.ts b/packages/core/src/tools/confirmation-policy.test.ts index a20bb611e3..361ab3d689 100644 --- a/packages/core/src/tools/confirmation-policy.test.ts +++ b/packages/core/src/tools/confirmation-policy.test.ts @@ -163,7 +163,7 @@ describe('Tool Confirmation Policy Updates', () => { // Mock getMessageBusDecision to trigger ASK_USER flow vi.spyOn(invocation as any, 'getMessageBusDecision').mockResolvedValue( - 'ASK_USER', + 'ask_user', ); const confirmation = await invocation.shouldConfirmExecute( diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index 214875c574..8810a07330 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -13,6 +13,7 @@ import { BaseDeclarativeTool, BaseToolInvocation, Kind, + type ForcedToolDecision, type ToolCallConfirmationDetails, type ToolConfirmationOutcome, type ToolEditConfirmationDetails, @@ -705,8 +706,12 @@ class EditToolInvocation */ protected override async getConfirmationDetails( abortSignal: AbortSignal, + forcedDecision?: ForcedToolDecision, ): Promise { - if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { + if ( + this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT && + forcedDecision !== 'ask_user' + ) { return false; } diff --git a/packages/core/src/tools/enter-plan-mode.test.ts b/packages/core/src/tools/enter-plan-mode.test.ts index 48bc5b494e..d14e1bfcdc 100644 --- a/packages/core/src/tools/enter-plan-mode.test.ts +++ b/packages/core/src/tools/enter-plan-mode.test.ts @@ -47,7 +47,7 @@ describe('EnterPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('ASK_USER'); + ).mockResolvedValue('ask_user'); const result = await invocation.shouldConfirmExecute( new AbortController().signal, @@ -74,7 +74,7 @@ describe('EnterPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('ALLOW'); + ).mockResolvedValue('allow'); const result = await invocation.shouldConfirmExecute( new AbortController().signal, @@ -92,7 +92,7 @@ describe('EnterPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('DENY'); + ).mockResolvedValue('deny'); await expect( invocation.shouldConfirmExecute(new AbortController().signal), @@ -136,7 +136,7 @@ describe('EnterPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('ASK_USER'); + ).mockResolvedValue('ask_user'); const details = await invocation.shouldConfirmExecute( new AbortController().signal, diff --git a/packages/core/src/tools/enter-plan-mode.ts b/packages/core/src/tools/enter-plan-mode.ts index d52c721aae..6a6b03f5e6 100644 --- a/packages/core/src/tools/enter-plan-mode.ts +++ b/packages/core/src/tools/enter-plan-mode.ts @@ -7,6 +7,7 @@ import { BaseDeclarativeTool, BaseToolInvocation, + type ForcedToolDecision, type ToolResult, Kind, type ToolInfoConfirmationDetails, @@ -85,13 +86,15 @@ export class EnterPlanModeInvocation extends BaseToolInvocation< override async shouldConfirmExecute( abortSignal: AbortSignal, + forcedDecision?: ForcedToolDecision, ): Promise { - const decision = await this.getMessageBusDecision(abortSignal); - if (decision === 'ALLOW') { + const decision = + forcedDecision ?? (await this.getMessageBusDecision(abortSignal)); + if (decision === 'allow') { return false; } - if (decision === 'DENY') { + if (decision === 'deny') { throw new Error( `Tool execution for "${ this._toolDisplayName || this._toolName @@ -99,7 +102,7 @@ export class EnterPlanModeInvocation extends BaseToolInvocation< ); } - // ASK_USER + // ask_user return { type: 'info', title: 'Enter Plan Mode', diff --git a/packages/core/src/tools/exit-plan-mode.test.ts b/packages/core/src/tools/exit-plan-mode.test.ts index 22de81fc7f..4889872767 100644 --- a/packages/core/src/tools/exit-plan-mode.test.ts +++ b/packages/core/src/tools/exit-plan-mode.test.ts @@ -58,7 +58,7 @@ describe('ExitPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('ASK_USER'); + ).mockResolvedValue('ask_user'); }); afterEach(() => { @@ -126,7 +126,7 @@ describe('ExitPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('ALLOW'); + ).mockResolvedValue('allow'); const result = await invocation.shouldConfirmExecute( new AbortController().signal, @@ -149,7 +149,7 @@ describe('ExitPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('DENY'); + ).mockResolvedValue('deny'); await expect( invocation.shouldConfirmExecute(new AbortController().signal), diff --git a/packages/core/src/tools/exit-plan-mode.ts b/packages/core/src/tools/exit-plan-mode.ts index 442b00e5cb..a1d5e81472 100644 --- a/packages/core/src/tools/exit-plan-mode.ts +++ b/packages/core/src/tools/exit-plan-mode.ts @@ -7,6 +7,7 @@ import { BaseDeclarativeTool, BaseToolInvocation, + type ForcedToolDecision, type ToolResult, Kind, type ToolExitPlanModeConfirmationDetails, @@ -118,6 +119,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation< override async shouldConfirmExecute( abortSignal: AbortSignal, + forcedDecision?: ForcedToolDecision, ): Promise { const resolvedPlanPath = this.getResolvedPlanPath(); @@ -137,8 +139,9 @@ export class ExitPlanModeInvocation extends BaseToolInvocation< return false; } - const decision = await this.getMessageBusDecision(abortSignal); - if (decision === 'DENY') { + const decision = + forcedDecision ?? (await this.getMessageBusDecision(abortSignal)); + if (decision === 'deny') { throw new Error( `Tool execution for "${ this._toolDisplayName || this._toolName @@ -146,7 +149,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation< ); } - if (decision === 'ALLOW') { + if (decision === 'allow') { // If policy is allow, auto-approve with default settings and execute. this.confirmationOutcome = ToolConfirmationOutcome.ProceedOnce; this.approvalPayload = { @@ -156,7 +159,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation< return false; } - // decision is 'ASK_USER' + // decision is 'ask_user' return { type: 'exit_plan_mode', title: 'Plan Approval', diff --git a/packages/core/src/tools/get-internal-docs.ts b/packages/core/src/tools/get-internal-docs.ts index 23bda8f4dd..b185b24ae6 100644 --- a/packages/core/src/tools/get-internal-docs.ts +++ b/packages/core/src/tools/get-internal-docs.ts @@ -10,6 +10,7 @@ import { Kind, type ToolInvocation, type ToolResult, + type ForcedToolDecision, type ToolCallConfirmationDetails, } from './tools.js'; import { GET_INTERNAL_DOCS_TOOL_NAME } from './tool-names.js'; @@ -85,6 +86,7 @@ class GetInternalDocsInvocation extends BaseToolInvocation< override async shouldConfirmExecute( _abortSignal: AbortSignal, + _forcedDecision?: ForcedToolDecision, ): Promise { return false; } diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 2c52c72573..baf4ae302d 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -10,6 +10,7 @@ import { BaseToolInvocation, Kind, ToolConfirmationOutcome, + type ForcedToolDecision, type ToolCallConfirmationDetails, type ToolInvocation, type ToolMcpConfirmationDetails, @@ -117,6 +118,7 @@ export class DiscoveredMCPToolInvocation extends BaseToolInvocation< protected override async getConfirmationDetails( _abortSignal: AbortSignal, + _forcedDecision?: ForcedToolDecision, ): Promise { const serverAllowListKey = this.serverName; const toolAllowListKey = `${this.serverName}.${this.serverToolName}`; diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index 68a0942a53..41b1572623 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -9,6 +9,7 @@ import { BaseToolInvocation, Kind, ToolConfirmationOutcome, + type ForcedToolDecision, type ToolEditConfirmationDetails, type ToolResult, } from './tools.js'; @@ -163,6 +164,7 @@ class MemoryToolInvocation extends BaseToolInvocation< protected override async getConfirmationDetails( _abortSignal: AbortSignal, + _forcedDecision?: ForcedToolDecision, ): Promise { const memoryFilePath = getGlobalMemoryFilePath(); const allowlistKey = memoryFilePath; diff --git a/packages/core/src/tools/message-bus-integration.test.ts b/packages/core/src/tools/message-bus-integration.test.ts index bfc369b58b..91a2e30d94 100644 --- a/packages/core/src/tools/message-bus-integration.test.ts +++ b/packages/core/src/tools/message-bus-integration.test.ts @@ -57,10 +57,10 @@ class TestToolInvocation extends BaseToolInvocation { abortSignal: AbortSignal, ): Promise { const decision = await this.getMessageBusDecision(abortSignal); - if (decision === 'ALLOW') { + if (decision === 'allow') { return false; } - if (decision === 'DENY') { + if (decision === 'deny') { throw new Error('Tool execution denied by policy'); } return false; diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts index 4ea83b0af4..778f4b8f0f 100644 --- a/packages/core/src/tools/shell.ts +++ b/packages/core/src/tools/shell.ts @@ -16,6 +16,7 @@ import { BaseToolInvocation, ToolConfirmationOutcome, Kind, + type ForcedToolDecision, type ToolInvocation, type ToolResult, type ToolCallConfirmationDetails, @@ -109,6 +110,7 @@ export class ShellToolInvocation extends BaseToolInvocation< protected override async getConfirmationDetails( _abortSignal: AbortSignal, + _forcedDecision?: ForcedToolDecision, ): Promise { const command = stripShellWrapper(this.params.command); diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 0a82cc1510..b68579d6d7 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -21,6 +21,11 @@ import { import { type ApprovalMode } from '../policy/types.js'; import type { SubagentProgress } from '../agents/types.js'; +/** + * Supported decisions for forcing tool execution behavior. + */ +export type ForcedToolDecision = 'allow' | 'deny' | 'ask_user'; + /** * Represents a validated and ready-to-execute tool call. * An instance of this is created by a `ToolBuilder`. @@ -53,9 +58,10 @@ export interface ToolInvocation< * @param abortSignal An AbortSignal that can be used to cancel the confirmation request. * @returns A ToolCallConfirmationDetails object if confirmation is required, or false if not. */ - shouldConfirmExecute( + shouldConfirmExecute: ( abortSignal: AbortSignal, - ): Promise; + forcedDecision?: ForcedToolDecision, + ) => Promise; /** * Executes the tool with the validated parameters. @@ -103,13 +109,15 @@ export abstract class BaseToolInvocation< async shouldConfirmExecute( abortSignal: AbortSignal, + forcedDecision?: ForcedToolDecision, ): Promise { - const decision = await this.getMessageBusDecision(abortSignal); - if (decision === 'ALLOW') { + const decision = + forcedDecision ?? (await this.getMessageBusDecision(abortSignal)); + if (decision === 'allow') { return false; } - if (decision === 'DENY') { + if (decision === 'deny') { throw new Error( `Tool execution for "${ this._toolDisplayName || this._toolName @@ -117,12 +125,12 @@ export abstract class BaseToolInvocation< ); } - if (decision === 'ASK_USER') { - return this.getConfirmationDetails(abortSignal); + if (decision === 'ask_user') { + return this.getConfirmationDetails(abortSignal, forcedDecision); } // Default to confirmation details if decision is unknown (should not happen with exhaustive policy) - return this.getConfirmationDetails(abortSignal); + return this.getConfirmationDetails(abortSignal, forcedDecision); } /** @@ -161,11 +169,12 @@ export abstract class BaseToolInvocation< /** * Subclasses should override this method to provide custom confirmation UI - * when the policy engine's decision is 'ASK_USER'. + * when the policy engine's decision is 'ask_user'. * The base implementation provides a generic confirmation prompt. */ protected async getConfirmationDetails( _abortSignal: AbortSignal, + _forcedDecision?: ForcedToolDecision, ): Promise { if (!this.messageBus) { return false; @@ -184,11 +193,11 @@ export abstract class BaseToolInvocation< protected getMessageBusDecision( abortSignal: AbortSignal, - ): Promise<'ALLOW' | 'DENY' | 'ASK_USER'> { + ): Promise { if (!this.messageBus || !this._toolName) { // If there's no message bus, we can't make a decision, so we allow. // The legacy confirmation flow will still apply if the tool needs it. - return Promise.resolve('ALLOW'); + return Promise.resolve('allow'); } const correlationId = randomUUID(); @@ -204,9 +213,9 @@ export abstract class BaseToolInvocation< toolAnnotations: this._toolAnnotations, }; - return new Promise<'ALLOW' | 'DENY' | 'ASK_USER'>((resolve) => { + return new Promise((resolve) => { if (!this.messageBus) { - resolve('ALLOW'); + resolve('allow'); return; } @@ -227,11 +236,11 @@ export abstract class BaseToolInvocation< const abortHandler = () => { cleanup(); - resolve('DENY'); + resolve('deny'); }; if (abortSignal.aborted) { - resolve('DENY'); + resolve('deny'); return; } @@ -239,11 +248,11 @@ export abstract class BaseToolInvocation< if (response.correlationId === correlationId) { cleanup(); if (response.requiresUserConfirmation) { - resolve('ASK_USER'); + resolve('ask_user'); } else if (response.confirmed) { - resolve('ALLOW'); + resolve('allow'); } else { - resolve('DENY'); + resolve('deny'); } } }; @@ -252,7 +261,7 @@ export abstract class BaseToolInvocation< timeoutId = setTimeout(() => { cleanup(); - resolve('ASK_USER'); // Default to ASK_USER on timeout + resolve('ask_user'); // Default to ask_user on timeout }, 30000); this.messageBus.subscribe( @@ -270,7 +279,7 @@ export abstract class BaseToolInvocation< void this.messageBus.publish(request); } catch (_error) { cleanup(); - resolve('ALLOW'); + resolve('allow'); } }); } @@ -729,6 +738,7 @@ export interface DiffStat { export interface ToolEditConfirmationDetails { type: 'edit'; title: string; + systemMessage?: string; onConfirm: ( outcome: ToolConfirmationOutcome, payload?: ToolConfirmationPayload, @@ -767,6 +777,7 @@ export type ToolConfirmationPayload = export interface ToolExecuteConfirmationDetails { type: 'exec'; title: string; + systemMessage?: string; onConfirm: (outcome: ToolConfirmationOutcome) => Promise; command: string; rootCommand: string; @@ -777,6 +788,7 @@ export interface ToolExecuteConfirmationDetails { export interface ToolMcpConfirmationDetails { type: 'mcp'; title: string; + systemMessage?: string; serverName: string; toolName: string; toolDisplayName: string; @@ -789,6 +801,7 @@ export interface ToolMcpConfirmationDetails { export interface ToolInfoConfirmationDetails { type: 'info'; title: string; + systemMessage?: string; onConfirm: (outcome: ToolConfirmationOutcome) => Promise; prompt: string; urls?: string[]; @@ -797,6 +810,7 @@ export interface ToolInfoConfirmationDetails { export interface ToolAskUserConfirmationDetails { type: 'ask_user'; title: string; + systemMessage?: string; questions: Question[]; onConfirm: ( outcome: ToolConfirmationOutcome, @@ -807,6 +821,7 @@ export interface ToolAskUserConfirmationDetails { export interface ToolExitPlanModeConfirmationDetails { type: 'exit_plan_mode'; title: string; + systemMessage?: string; planPath: string; onConfirm: ( outcome: ToolConfirmationOutcome, diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index 3170227188..fb196794c4 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -8,6 +8,7 @@ import { BaseDeclarativeTool, BaseToolInvocation, Kind, + type ForcedToolDecision, type ToolCallConfirmationDetails, type ToolInvocation, type ToolResult, @@ -293,6 +294,7 @@ ${textContent} protected override async getConfirmationDetails( _abortSignal: AbortSignal, + _forcedDecision?: ForcedToolDecision, ): Promise { // Check for AUTO_EDIT approval mode. This tool has a specific behavior // where ProceedAlways switches the entire session to AUTO_EDIT. diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index 8ec660b661..4aa1cf48a0 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -17,6 +17,7 @@ import { BaseDeclarativeTool, BaseToolInvocation, Kind, + type ForcedToolDecision, type FileDiff, type ToolCallConfirmationDetails, type ToolEditConfirmationDetails, @@ -174,8 +175,12 @@ class WriteFileToolInvocation extends BaseToolInvocation< protected override async getConfirmationDetails( abortSignal: AbortSignal, + forcedDecision?: ForcedToolDecision, ): Promise { - if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { + if ( + this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT && + forcedDecision !== 'ask_user' + ) { return false; }