From d1dc4902fd5b9f3bc87d122490eaf83d15cf4046 Mon Sep 17 00:00:00 2001 From: Christian Gunderman Date: Sat, 21 Mar 2026 03:52:39 +0000 Subject: [PATCH] fix(hooks): support 'ask' decision for BeforeTool hooks (#21146) --- integration-tests/hooks-system.test.ts | 217 +++++++++++- .../messages/ToolConfirmationMessage.tsx | 9 + packages/cli/src/ui/hooks/useGeminiStream.ts | 10 +- .../core/src/confirmation-bus/message-bus.ts | 4 +- packages/core/src/confirmation-bus/types.ts | 10 + .../src/core/coreToolHookTriggers.test.ts | 19 +- .../core/src/core/coreToolHookTriggers.ts | 9 +- .../core/src/core/coreToolScheduler.test.ts | 57 +++- packages/core/src/core/coreToolScheduler.ts | 69 +++- .../src/core/coreToolSchedulerHooks.test.ts | 312 ++++++++++++++++++ packages/core/src/hooks/hookAggregator.ts | 11 +- packages/core/src/hooks/types.ts | 9 +- packages/core/src/scheduler/confirmation.ts | 14 +- packages/core/src/scheduler/hook-utils.ts | 109 ++++++ packages/core/src/scheduler/policy.test.ts | 1 + packages/core/src/scheduler/scheduler.test.ts | 3 + packages/core/src/scheduler/scheduler.ts | 41 ++- .../src/scheduler/scheduler_parallel.test.ts | 2 + packages/core/src/scheduler/tool-executor.ts | 15 + packages/core/src/scheduler/types.ts | 2 + .../core/src/telemetry/conseca-logger.test.ts | 4 +- packages/core/src/test-utils/mock-tool.ts | 29 +- .../src/tools/confirmation-policy.test.ts | 36 +- packages/core/src/tools/edit.ts | 16 +- .../core/src/tools/enter-plan-mode.test.ts | 8 +- packages/core/src/tools/enter-plan-mode.ts | 6 +- .../core/src/tools/exit-plan-mode.test.ts | 6 +- packages/core/src/tools/exit-plan-mode.ts | 6 +- .../src/tools/message-bus-integration.test.ts | 4 +- packages/core/src/tools/tools.ts | 61 +++- packages/core/src/tools/web-fetch.ts | 18 +- packages/core/src/tools/write-file.ts | 16 +- 32 files changed, 1016 insertions(+), 117 deletions(-) create mode 100644 packages/core/src/core/coreToolSchedulerHooks.test.ts create mode 100644 packages/core/src/scheduler/hook-utils.ts diff --git a/integration-tests/hooks-system.test.ts b/integration-tests/hooks-system.test.ts index 479851957b..4fe63a3ab6 100644 --- a/integration-tests/hooks-system.test.ts +++ b/integration-tests/hooks-system.test.ts @@ -7,9 +7,10 @@ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; import { TestRig, poll, normalizePath } from './test-helper.js'; import { join } from 'node:path'; -import { writeFileSync } from 'node:fs'; +import { writeFileSync, existsSync, mkdirSync } from 'node:fs'; +import os from 'node:os'; -describe('Hooks System Integration', () => { +describe('Hooks System Integration', { timeout: 120000 }, () => { let rig: TestRig; beforeEach(() => { @@ -2016,6 +2017,10 @@ console.log(JSON.stringify({ // 3. Final setup with full settings rig.setup('Hook Disabling Multiple Ops', { + fakeResponsesPath: join( + import.meta.dirname, + 'hooks-system.disabled-via-command.responses', + ), settings: { hooksConfig: { enabled: true, @@ -2230,7 +2235,7 @@ console.log(JSON.stringify({ // The hook should have stopped execution message (returned from tool) expect(result).toContain( - 'Agent execution stopped: Emergency Stop triggered by hook', + 'Agent execution stopped by hook: Emergency Stop triggered by hook', ); // Tool should NOT be called successfully (it was blocked/stopped) @@ -2242,4 +2247,210 @@ console.log(JSON.stringify({ expect(writeFileCalls).toHaveLength(0); }); }); + + describe('Hooks "ask" Decision Integration', () => { + it( + 'should force confirmation prompt when hook returns "ask" decision even in YOLO mode', + { timeout: 60000 }, + async () => { + const testName = + 'should force confirmation prompt when hook returns "ask" decision even in YOLO mode'; + + // 1. Setup hook script that returns 'ask' decision + const hookOutput = { + decision: 'ask', + systemMessage: 'Confirmation forced by security hook', + hookSpecificOutput: { + hookEventName: 'BeforeTool', + }, + }; + + const hookScript = `console.log(JSON.stringify(${JSON.stringify( + hookOutput, + )}));`; + + // Create script path predictably + const scriptPath = join(os.tmpdir(), 'gemini-cli-tests-ask-hook.js'); + writeFileSync(scriptPath, hookScript); + + // 2. Setup rig with YOLO mode enabled but with the 'ask' hook + rig.setup(testName, { + fakeResponsesPath: join( + import.meta.dirname, + 'hooks-system.allow-tool.responses', + ), + settings: { + debugMode: true, + tools: { + approval: 'yolo', + }, + general: { + enableAutoUpdateNotification: false, + }, + hooksConfig: { + enabled: true, + }, + hooks: { + BeforeTool: [ + { + matcher: 'write_file', + hooks: [ + { + type: 'command', + command: `node "${scriptPath}"`, + timeout: 5000, + }, + ], + }, + ], + }, + }, + }); + + // Bypass terminal setup prompt and other startup banners + const stateDir = join(rig.homeDir!, '.gemini'); + if (!existsSync(stateDir)) mkdirSync(stateDir, { recursive: true }); + writeFileSync( + join(stateDir, 'state.json'), + JSON.stringify({ + terminalSetupPromptShown: true, + hasSeenScreenReaderNudge: true, + tipsShown: 100, + }), + ); + + // 3. Run interactive and verify prompt appears despite YOLO mode + const run = await rig.runInteractive(); + + // Wait for prompt to appear + await run.expectText('Type your message', 30000); + + // Send prompt that will trigger write_file + await run.type('Create a file called ask-test.txt with content "test"'); + await run.type('\r'); + + // Wait for the FORCED confirmation prompt to appear + // It should contain the system message from the hook + await run.expectText('Confirmation forced by security hook', 30000); + await run.expectText('Allow', 5000); + + // 4. Approve the permission + await run.type('y'); + await run.type('\r'); + + // Wait for command to execute + await run.expectText('approved.txt', 30000); + + // Should find the tool call + const foundWriteFile = await rig.waitForToolCall('write_file'); + expect(foundWriteFile).toBeTruthy(); + + // File should be created + const fileContent = rig.readFile('approved.txt'); + expect(fileContent).toBe('Approved content'); + }, + ); + + it( + 'should allow cancelling when hook forces "ask" decision', + { timeout: 60000 }, + async () => { + const testName = + 'should allow cancelling when hook forces "ask" decision'; + const hookOutput = { + decision: 'ask', + systemMessage: 'Confirmation forced for cancellation test', + hookSpecificOutput: { + hookEventName: 'BeforeTool', + }, + }; + + const hookScript = `console.log(JSON.stringify(${JSON.stringify( + hookOutput, + )}));`; + + const scriptPath = join( + os.tmpdir(), + 'gemini-cli-tests-ask-cancel-hook.js', + ); + writeFileSync(scriptPath, hookScript); + + rig.setup(testName, { + fakeResponsesPath: join( + import.meta.dirname, + 'hooks-system.allow-tool.responses', + ), + settings: { + debugMode: true, + tools: { + approval: 'yolo', + }, + general: { + enableAutoUpdateNotification: false, + }, + hooksConfig: { + enabled: true, + }, + hooks: { + BeforeTool: [ + { + matcher: 'write_file', + hooks: [ + { + type: 'command', + command: `node "${scriptPath}"`, + timeout: 5000, + }, + ], + }, + ], + }, + }, + }); + + // Bypass terminal setup prompt and other startup banners + const stateDir = join(rig.homeDir!, '.gemini'); + if (!existsSync(stateDir)) mkdirSync(stateDir, { recursive: true }); + writeFileSync( + join(stateDir, 'state.json'), + JSON.stringify({ + terminalSetupPromptShown: true, + hasSeenScreenReaderNudge: true, + tipsShown: 100, + }), + ); + + const run = await rig.runInteractive(); + + // Wait for prompt to appear + await run.expectText('Type your message', 30000); + + await run.type( + 'Create a file called cancel-test.txt with content "test"', + ); + await run.type('\r'); + + await run.expectText( + 'Confirmation forced for cancellation test', + 30000, + ); + + // 4. Deny the permission using option 4 + await run.type('4'); + await run.type('\r'); + + // Wait for cancellation message + await run.expectText('Cancelled', 15000); + + // Tool should NOT be called successfully + const toolLogs = rig.readToolLogs(); + const writeFileCalls = toolLogs.filter( + (t) => + t.toolRequest.name === 'write_file' && + t.toolRequest.success === true, + ); + expect(writeFileCalls).toHaveLength(0); + }, + ); + }); }); diff --git a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx index 45584a9d46..6d6d85780c 100644 --- a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx @@ -735,6 +735,15 @@ export const ToolConfirmationMessage: React.FC< paddingTop={0} paddingBottom={handlesOwnUI ? 0 : 1} > + {/* System message from hook */} + {confirmationDetails.systemMessage && ( + + + {confirmationDetails.systemMessage} + + + )} + {handlesOwnUI ? ( bodyContent ) : ( diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 2034e14b87..f82e32a6c1 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -548,11 +548,9 @@ export const useGeminiStream = ( if (tc.request.name === ASK_USER_TOOL_NAME && isInProgress) { return false; } - return ( - tc.status !== 'scheduled' && - tc.status !== 'validating' && - tc.status !== 'awaiting_approval' - ); + // ToolGroupMessage now shows all non-canceled tools, so they are visible + // in pending and we need to draw the closing border for them. + return true; }); if ( @@ -1658,7 +1656,7 @@ export const useGeminiStream = ( ) { let awaitingApprovalCalls = toolCalls.filter( (call): call is TrackedWaitingToolCall => - call.status === 'awaiting_approval', + call.status === 'awaiting_approval' && !call.request.forcedAsk, ); // For AUTO_EDIT mode, only approve edit tools (replace, write_file) diff --git a/packages/core/src/confirmation-bus/message-bus.ts b/packages/core/src/confirmation-bus/message-bus.ts index 5495996d25..72f1c1c15a 100644 --- a/packages/core/src/confirmation-bus/message-bus.ts +++ b/packages/core/src/confirmation-bus/message-bus.ts @@ -83,13 +83,15 @@ export class MessageBus extends EventEmitter { } if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { - const { decision } = await this.policyEngine.check( + const { decision: policyDecision } = await this.policyEngine.check( message.toolCall, message.serverName, message.toolAnnotations, message.subagent, ); + const decision = message.forcedDecision ?? policyDecision; + switch (decision) { case PolicyDecision.ALLOW: // Directly emit the response instead of recursive publish diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index 91aeab8308..ceb1c96296 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -46,6 +46,10 @@ export interface ToolConfirmationRequest { * Optional rich details for the confirmation UI (diffs, counts, etc.) */ details?: SerializableConfirmationDetails; + /** + * Optional decision to force for this tool call, bypassing the policy engine. + */ + forcedDecision?: 'allow' | 'deny' | 'ask_user'; } export interface ToolConfirmationResponse { @@ -76,12 +80,14 @@ export type SerializableConfirmationDetails = | { type: 'info'; title: string; + systemMessage?: string; prompt: string; urls?: string[]; } | { type: 'edit'; title: string; + systemMessage?: string; fileName: string; filePath: string; fileDiff: string; @@ -92,6 +98,7 @@ export type SerializableConfirmationDetails = | { type: 'exec'; title: string; + systemMessage?: string; command: string; rootCommand: string; rootCommands: string[]; @@ -100,6 +107,7 @@ export type SerializableConfirmationDetails = | { type: 'mcp'; title: string; + systemMessage?: string; serverName: string; toolName: string; toolDisplayName: string; @@ -110,11 +118,13 @@ export type SerializableConfirmationDetails = | { type: 'ask_user'; title: string; + systemMessage?: string; questions: Question[]; } | { type: 'exit_plan_mode'; title: string; + systemMessage?: string; planPath: string; }; diff --git a/packages/core/src/core/coreToolHookTriggers.test.ts b/packages/core/src/core/coreToolHookTriggers.test.ts index 414064ff85..60c6836452 100644 --- a/packages/core/src/core/coreToolHookTriggers.test.ts +++ b/packages/core/src/core/coreToolHookTriggers.test.ts @@ -16,10 +16,8 @@ import { import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { HookSystem } from '../hooks/hookSystem.js'; import type { Config } from '../config/config.js'; -import { - type DefaultHookOutput, - BeforeToolHookOutput, -} from '../hooks/types.js'; +import type { DefaultHookOutput } from '../hooks/types.js'; +import { BeforeToolHookOutput } from '../hooks/types.js'; class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> { constructor(params: { key?: string }, messageBus: MessageBus) { @@ -140,18 +138,11 @@ describe('executeToolWithHooks', () => { expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED); expect(result.error?.message).toBe('Execution blocked'); }); - it('should handle continue: false in AfterTool', async () => { const invocation = new MockInvocation({}, messageBus); const abortSignal = new AbortController().signal; const spy = vi.spyOn(invocation, 'execute'); - vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({ - shouldStopExecution: () => false, - getEffectiveReason: () => '', - getBlockingError: () => ({ blocked: false, reason: '' }), - } as unknown as DefaultHookOutput); - vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({ shouldStopExecution: () => true, getEffectiveReason: () => 'Stop after execution', @@ -177,12 +168,6 @@ describe('executeToolWithHooks', () => { const invocation = new MockInvocation({}, messageBus); const abortSignal = new AbortController().signal; - vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({ - shouldStopExecution: () => false, - getEffectiveReason: () => '', - getBlockingError: () => ({ blocked: false, reason: '' }), - } as unknown as DefaultHookOutput); - vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({ shouldStopExecution: () => false, getEffectiveReason: () => '', diff --git a/packages/core/src/core/coreToolHookTriggers.ts b/packages/core/src/core/coreToolHookTriggers.ts index 6bff4cfdd5..c2748cbd0a 100644 --- a/packages/core/src/core/coreToolHookTriggers.ts +++ b/packages/core/src/core/coreToolHookTriggers.ts @@ -14,8 +14,8 @@ import type { ExecuteOptions, } from '../tools/tools.js'; import { ToolErrorType } from '../tools/tool-error.js'; -import { debugLogger } from '../utils/debugLogger.js'; import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js'; +import { debugLogger } from '../utils/debugLogger.js'; /** * Extracts MCP context from a tool invocation if it's an MCP tool. @@ -24,7 +24,7 @@ import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js'; * @param config Config to look up server details * @returns MCP context if this is an MCP tool, undefined otherwise */ -function extractMcpContext( +export function extractMcpContext( invocation: AnyToolInvocation, config: Config, ): McpToolContext | undefined { @@ -74,6 +74,7 @@ export async function executeToolWithHooks( options?: ExecuteOptions, config?: Config, originalRequestName?: string, + skipBeforeHook?: boolean, ): Promise { // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const toolInput = (invocation.params || {}) as Record; @@ -82,9 +83,9 @@ export async function executeToolWithHooks( // Extract MCP context if this is an MCP tool (only if config is provided) const mcpContext = config ? extractMcpContext(invocation, config) : undefined; - const hookSystem = config?.getHookSystem(); - if (hookSystem) { + + if (hookSystem && !skipBeforeHook) { const beforeOutput = await hookSystem.fireBeforeToolEvent( toolName, toolInput, diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 3a9d0e2e92..c897e4ed30 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -287,6 +287,7 @@ function createMockConfig(overrides: Partial = {}): Config { getGeminiClient: () => null, getMessageBus: () => createMockMessageBus(), getEnableHooks: () => false, + getHookSystem: () => undefined, getExperiments: () => {}, } as unknown as Config; @@ -1028,7 +1029,12 @@ describe('CoreToolScheduler YOLO mode', () => { // Assert // 1. The tool's execute method was called directly. - expect(executeFn).toHaveBeenCalledWith({ param: 'value' }); + expect(executeFn).toHaveBeenCalledWith( + { param: 'value' }, + expect.anything(), + undefined, + expect.anything(), + ); // 2. The tool call status never entered CoreToolCallStatus.AwaitingApproval. const statusUpdates = onToolCallsUpdate.mock.calls @@ -1131,7 +1137,12 @@ describe('CoreToolScheduler request queueing', () => { ); // Ensure the second tool call hasn't been executed yet. - expect(executeFn).toHaveBeenCalledWith({ a: 1 }); + expect(executeFn).toHaveBeenCalledWith( + { a: 1 }, + expect.anything(), + undefined, + expect.anything(), + ); // Complete the first tool call. resolveFirstCall!({ @@ -1155,7 +1166,12 @@ describe('CoreToolScheduler request queueing', () => { // Now the second tool call should have been executed. expect(executeFn).toHaveBeenCalledTimes(2); }); - expect(executeFn).toHaveBeenCalledWith({ b: 2 }); + expect(executeFn).toHaveBeenCalledWith( + { b: 2 }, + expect.anything(), + undefined, + expect.anything(), + ); // Wait for the second completion. await vi.waitFor(() => { @@ -1250,7 +1266,12 @@ describe('CoreToolScheduler request queueing', () => { // Assert // 1. The tool's execute method was called directly. - expect(executeFn).toHaveBeenCalledWith({ param: 'value' }); + expect(executeFn).toHaveBeenCalledWith( + { param: 'value' }, + expect.anything(), + undefined, + expect.anything(), + ); // 2. The tool call status never entered CoreToolCallStatus.AwaitingApproval. const statusUpdates = onToolCallsUpdate.mock.calls @@ -1432,8 +1453,18 @@ describe('CoreToolScheduler request queueing', () => { // Ensure the tool was called twice with the correct arguments. expect(executeFn).toHaveBeenCalledTimes(2); - expect(executeFn).toHaveBeenCalledWith({ a: 1 }); - expect(executeFn).toHaveBeenCalledWith({ b: 2 }); + expect(executeFn).toHaveBeenCalledWith( + { a: 1 }, + expect.anything(), + undefined, + expect.anything(), + ); + expect(executeFn).toHaveBeenCalledWith( + { b: 2 }, + expect.anything(), + undefined, + expect.anything(), + ); // Ensure completion callbacks were called twice. expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2); @@ -1790,8 +1821,18 @@ describe('CoreToolScheduler Sequential Execution', () => { // Check that execute was called for the first two tools only expect(executeFn).toHaveBeenCalledTimes(2); - expect(executeFn).toHaveBeenCalledWith({ call: 1 }); - expect(executeFn).toHaveBeenCalledWith({ call: 2 }); + expect(executeFn).toHaveBeenCalledWith( + { call: 1 }, + expect.anything(), + undefined, + expect.anything(), + ); + expect(executeFn).toHaveBeenCalledWith( + { call: 2 }, + expect.anything(), + undefined, + expect.anything(), + ); const completedCalls = onAllToolCallsComplete.mock .calls[0][0] as ToolCall[]; diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 1ecae4ef33..8aabd709c2 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -49,6 +49,7 @@ import { ToolExecutor } from '../scheduler/tool-executor.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { getPolicyDenialError } from '../scheduler/policy.js'; import { GeminiCliOperation } from '../telemetry/constants.js'; +import { evaluateBeforeToolHook } from '../scheduler/hook-utils.js'; import type { AgentLoopContext } from '../config/agent-loop-context.js'; export type { @@ -602,7 +603,7 @@ export class CoreToolScheduler { return; } - const toolCall = this.toolCallQueue.shift()!; + let toolCall = this.toolCallQueue.shift()!; // This is now the single active tool call. this.toolCalls = [toolCall]; @@ -618,7 +619,7 @@ export class CoreToolScheduler { // This logic is moved from the old `for` loop in `_schedule`. if (toolCall.status === CoreToolCallStatus.Validating) { - const { request: reqInfo, invocation } = toolCall; + let { request: reqInfo } = toolCall; try { if (signal.aborted) { @@ -633,7 +634,49 @@ export class CoreToolScheduler { return; } - // Policy Check using PolicyEngine + // 1. Hook Check (BeforeTool) + const hookResult = await evaluateBeforeToolHook( + this.context.config, + toolCall.tool, + toolCall.request, + toolCall.invocation, + ); + + if (hookResult.status === 'error') { + this.setStatusInternal( + reqInfo.callId, + CoreToolCallStatus.Error, + signal, + createErrorResponse( + toolCall.request, + hookResult.error, + hookResult.errorType, + ), + ); + await this.checkAndNotifyCompletion(signal); + return; + } + + const { hookDecision, hookSystemMessage, modifiedArgs, newInvocation } = + hookResult; + + if (modifiedArgs && newInvocation) { + this.setArgsInternal(reqInfo.callId, modifiedArgs); + // Re-retrieve toolCall as it was updated in the array by setArgsInternal + const updatedCall = this.toolCalls.find( + (c) => c.request.callId === reqInfo.callId, + ); + if ( + updatedCall && + updatedCall.status === CoreToolCallStatus.Validating + ) { + toolCall = updatedCall; + } + toolCall.request.inputModifiedByHook = true; + reqInfo = toolCall.request; + } + + // 2. Policy Check using PolicyEngine // We must reconstruct the FunctionCall format expected by PolicyEngine const toolCallForPolicy = { name: toolCall.request.name, @@ -645,11 +688,16 @@ export class CoreToolScheduler { : undefined; const toolAnnotations = toolCall.tool.toolAnnotations; - const { decision, rule } = await this.context.config + const { decision: policyDecision, rule } = await this.context.config .getPolicyEngine() .check(toolCallForPolicy, serverName, toolAnnotations); - if (decision === PolicyDecision.DENY) { + let finalDecision = policyDecision; + if (hookDecision === 'ask') { + finalDecision = PolicyDecision.ASK_USER; + } + + if (finalDecision === PolicyDecision.DENY) { const { errorMessage, errorType } = getPolicyDenialError( this.context.config, rule, @@ -664,7 +712,7 @@ export class CoreToolScheduler { return; } - if (decision === PolicyDecision.ALLOW) { + if (finalDecision === PolicyDecision.ALLOW) { this.setToolCallOutcome( reqInfo.callId, ToolConfirmationOutcome.ProceedAlways, @@ -679,7 +727,10 @@ export class CoreToolScheduler { // We need confirmation details to show to the user const confirmationDetails = - await invocation.shouldConfirmExecute(signal); + await toolCall.invocation.shouldConfirmExecute( + signal, + hookDecision === 'ask' ? 'ask_user' : undefined, + ); if (!confirmationDetails) { this.setToolCallOutcome( @@ -700,6 +751,10 @@ export class CoreToolScheduler { ); } + if (hookSystemMessage) { + confirmationDetails.systemMessage = hookSystemMessage; + } + // Fire Notification hook before showing confirmation to user const hookSystem = this.context.config.getHookSystem(); if (hookSystem) { diff --git a/packages/core/src/core/coreToolSchedulerHooks.test.ts b/packages/core/src/core/coreToolSchedulerHooks.test.ts new file mode 100644 index 0000000000..63c22e0b11 --- /dev/null +++ b/packages/core/src/core/coreToolSchedulerHooks.test.ts @@ -0,0 +1,312 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi } from 'vitest'; +import { CoreToolScheduler } from './coreToolScheduler.js'; +import type { ToolCall, ErroredToolCall } from '../scheduler/types.js'; +import type { Config, ToolRegistry, AgentLoopContext } from '../index.js'; +import { + ApprovalMode, + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, +} from '../index.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; +import { MockTool } from '../test-utils/mock-tool.js'; +import { DEFAULT_GEMINI_MODEL } from '../config/models.js'; +import type { PolicyEngine } from '../policy/policy-engine.js'; +import type { HookSystem } from '../hooks/hookSystem.js'; +import { BeforeToolHookOutput } from '../hooks/types.js'; + +function createMockConfig(overrides: Partial = {}): Config { + const defaultToolRegistry = { + getTool: () => undefined, + getToolByName: () => undefined, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => undefined, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + getExperiments: () => {}, + } as unknown as ToolRegistry; + + const baseConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + isInteractive: () => true, + getApprovalMode: () => ApprovalMode.DEFAULT, + setApprovalMode: () => {}, + getAllowedTools: () => [], + getContentGeneratorConfig: () => ({ + model: 'test-model', + authType: 'oauth-personal', + }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + sanitizationConfig: { + enableEnvironmentVariableRedaction: true, + allowedEnvironmentVariables: [], + blockedEnvironmentVariables: [], + }, + }), + storage: { + getProjectTempDir: () => '/tmp', + }, + getTruncateToolOutputThreshold: () => + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + getTruncateToolOutputLines: () => 1000, + getToolRegistry: () => defaultToolRegistry, + getActiveModel: () => DEFAULT_GEMINI_MODEL, + getGeminiClient: () => null, + getMessageBus: () => createMockMessageBus(), + getEnableHooks: () => true, // Enabled for these tests + getExperiments: () => {}, + getPolicyEngine: () => + ({ + check: async () => ({ decision: 'allow' }), // Default allow for hook tests + }) as unknown as PolicyEngine, + } as unknown as Config; + + return { ...baseConfig, ...overrides } as Config; +} + +describe('CoreToolScheduler Hooks', () => { + it('should stop execution if BeforeTool hook requests stop', async () => { + const executeFn = vi.fn().mockResolvedValue({ + llmContent: 'Tool executed', + returnDisplay: 'Tool executed', + }); + const mockTool = new MockTool({ name: 'mockTool', execute: executeFn }); + + const toolRegistry = { + getTool: () => mockTool, + getToolByName: () => mockTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => mockTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const mockMessageBus = createMockMessageBus(); + const mockHookSystem = { + fireBeforeToolEvent: vi.fn().mockResolvedValue({ + shouldStopExecution: () => true, + getEffectiveReason: () => 'Hook stopped execution', + getBlockingError: () => ({ blocked: false }), + isAskDecision: () => false, + }), + } as unknown as HookSystem; + + const mockConfig = createMockConfig({ + getToolRegistry: () => toolRegistry, + getMessageBus: () => mockMessageBus, + getHookSystem: () => mockHookSystem, + getApprovalMode: () => ApprovalMode.YOLO, + }); + + const onAllToolCallsComplete = vi.fn(); + const scheduler = new CoreToolScheduler({ + context: { + config: mockConfig, + messageBus: mockMessageBus, + toolRegistry, + } as unknown as AgentLoopContext, + onAllToolCallsComplete, + getPreferredEditor: () => 'vscode', + }); + + const request = { + callId: '1', + name: 'mockTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + + await scheduler.schedule([request], new AbortController().signal); + + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + }); + + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls[0].status).toBe('error'); + const erroredCall = completedCalls[0] as ErroredToolCall; + + // Check error type/message + expect(erroredCall.response.error?.message).toContain( + 'Hook stopped execution', + ); + expect(executeFn).not.toHaveBeenCalled(); + }); + + it('should block tool execution if BeforeTool hook requests block', async () => { + const executeFn = vi.fn(); + const mockTool = new MockTool({ name: 'mockTool', execute: executeFn }); + + const toolRegistry = { + getTool: () => mockTool, + getToolByName: () => mockTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => mockTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const mockMessageBus = createMockMessageBus(); + const mockHookSystem = { + fireBeforeToolEvent: vi.fn().mockResolvedValue({ + shouldStopExecution: () => false, + getBlockingError: () => ({ + blocked: true, + reason: 'Hook blocked execution', + }), + isAskDecision: () => false, + }), + } as unknown as HookSystem; + + const mockConfig = createMockConfig({ + getToolRegistry: () => toolRegistry, + getMessageBus: () => mockMessageBus, + getHookSystem: () => mockHookSystem, + getApprovalMode: () => ApprovalMode.YOLO, + }); + + const onAllToolCallsComplete = vi.fn(); + const scheduler = new CoreToolScheduler({ + context: { + config: mockConfig, + messageBus: mockMessageBus, + toolRegistry, + } as unknown as AgentLoopContext, + onAllToolCallsComplete, + getPreferredEditor: () => 'vscode', + }); + + const request = { + callId: '1', + name: 'mockTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + + await scheduler.schedule([request], new AbortController().signal); + + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + }); + + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls[0].status).toBe('error'); + const erroredCall = completedCalls[0] as ErroredToolCall; + expect(erroredCall.response.error?.message).toContain( + 'Hook blocked execution', + ); + expect(executeFn).not.toHaveBeenCalled(); + }); + + it('should update tool input if BeforeTool hook provides modified input', async () => { + const executeFn = vi.fn().mockResolvedValue({ + llmContent: 'Tool executed', + returnDisplay: 'Tool executed', + }); + const mockTool = new MockTool({ name: 'mockTool', execute: executeFn }); + + const toolRegistry = { + getTool: () => mockTool, + getToolByName: () => mockTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => mockTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const mockMessageBus = createMockMessageBus(); + const mockBeforeOutput = new BeforeToolHookOutput({ + continue: true, + hookSpecificOutput: { + hookEventName: 'BeforeTool', + tool_input: { newParam: 'modifiedValue' }, + }, + }); + + const mockHookSystem = { + fireBeforeToolEvent: vi.fn().mockResolvedValue(mockBeforeOutput), + fireAfterToolEvent: vi.fn(), + } as unknown as HookSystem; + + const mockConfig = createMockConfig({ + getToolRegistry: () => toolRegistry, + getMessageBus: () => mockMessageBus, + getHookSystem: () => mockHookSystem, + getApprovalMode: () => ApprovalMode.YOLO, + }); + + const onAllToolCallsComplete = vi.fn(); + const scheduler = new CoreToolScheduler({ + context: { + config: mockConfig, + messageBus: mockMessageBus, + toolRegistry, + } as unknown as AgentLoopContext, + onAllToolCallsComplete, + getPreferredEditor: () => 'vscode', + }); + + const request = { + callId: '1', + name: 'mockTool', + args: { originalParam: 'originalValue' }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + + await scheduler.schedule([request], new AbortController().signal); + + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + }); + + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls[0].status).toBe('success'); + + // Verify execute was called with modified args + expect(executeFn).toHaveBeenCalledWith( + { newParam: 'modifiedValue' }, + expect.anything(), + undefined, + expect.anything(), + ); + + // Verify call request args were updated in the completion report + expect(completedCalls[0].request.args).toEqual({ + newParam: 'modifiedValue', + }); + }); +}); diff --git a/packages/core/src/hooks/hookAggregator.ts b/packages/core/src/hooks/hookAggregator.ts index 73e814702e..b67266edf5 100644 --- a/packages/core/src/hooks/hookAggregator.ts +++ b/packages/core/src/hooks/hookAggregator.ts @@ -125,6 +125,7 @@ export class HookAggregator { const additionalContexts: string[] = []; let hasBlockDecision = false; + let hasAskDecision = false; let hasContinueFalse = false; for (const output of outputs) { @@ -142,6 +143,12 @@ export class HookAggregator { if (tempOutput.isBlockingDecision()) { hasBlockDecision = true; merged.decision = output.decision; + } else if (tempOutput.isAskDecision()) { + hasAskDecision = true; + // Ask decision is only set if no blocking decision was found so far + if (!hasBlockDecision) { + merged.decision = output.decision; + } } // Collect messages @@ -180,8 +187,8 @@ export class HookAggregator { this.extractAdditionalContext(output, additionalContexts); } - // Set final decision if no blocking decision was found - if (!hasBlockDecision && !hasContinueFalse) { + // Set final decision if no blocking or ask decision was found + if (!hasBlockDecision && !hasAskDecision && !hasContinueFalse) { merged.decision = 'allow'; } diff --git a/packages/core/src/hooks/types.ts b/packages/core/src/hooks/types.ts index 9c6217ffa4..c1a35384ae 100644 --- a/packages/core/src/hooks/types.ts +++ b/packages/core/src/hooks/types.ts @@ -197,12 +197,19 @@ export class DefaultHookOutput implements HookOutput { } /** - * Check if this output represents a blocking decision + * Check if this output represents a blocking decision (block or deny) */ isBlockingDecision(): boolean { return this.decision === 'block' || this.decision === 'deny'; } + /** + * Check if this output represents an 'ask' decision + */ + isAskDecision(): boolean { + return this.decision === 'ask'; + } + /** * Check if this output requests to stop execution */ diff --git a/packages/core/src/scheduler/confirmation.ts b/packages/core/src/scheduler/confirmation.ts index 67ae26d2eb..7db7a0b48f 100644 --- a/packages/core/src/scheduler/confirmation.ts +++ b/packages/core/src/scheduler/confirmation.ts @@ -16,6 +16,7 @@ import { ToolConfirmationOutcome, type ToolConfirmationPayload, type ToolCallConfirmationDetails, + type ForcedToolDecision, } from '../tools/tools.js'; import { type ValidatingToolCall, @@ -116,6 +117,8 @@ export async function resolveConfirmation( getPreferredEditor: () => EditorType | undefined; schedulerId: string; onWaitingForConfirmation?: (waiting: boolean) => void; + systemMessage?: string; + forcedDecision?: ForcedToolDecision; }, ): Promise { const { state, onWaitingForConfirmation } = deps; @@ -126,7 +129,7 @@ export async function resolveConfirmation( // Loop exists to allow the user to modify the parameters and see the new // diff. while (outcome === ToolConfirmationOutcome.ModifyWithEditor) { - if (signal.aborted) throw new Error('Operation cancelled'); + if (signal.aborted) throw new Error('Operation cancelled by user'); const currentCall = state.getToolCall(callId); if (!currentCall || !('invocation' in currentCall)) { @@ -134,12 +137,19 @@ export async function resolveConfirmation( } const currentInvocation = currentCall.invocation; - const details = await currentInvocation.shouldConfirmExecute(signal); + const details = await currentInvocation.shouldConfirmExecute( + signal, + deps.forcedDecision, + ); if (!details) { outcome = ToolConfirmationOutcome.ProceedOnce; break; } + if (deps.systemMessage) { + details.systemMessage = deps.systemMessage; + } + await notifyHooks(deps, details); const correlationId = randomUUID(); diff --git a/packages/core/src/scheduler/hook-utils.ts b/packages/core/src/scheduler/hook-utils.ts new file mode 100644 index 0000000000..78d5aeaa53 --- /dev/null +++ b/packages/core/src/scheduler/hook-utils.ts @@ -0,0 +1,109 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../config/config.js'; +import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js'; +import type { ToolCallRequestInfo } from './types.js'; +import { extractMcpContext } from '../core/coreToolHookTriggers.js'; +import { BeforeToolHookOutput } from '../hooks/types.js'; +import { ToolErrorType } from '../tools/tool-error.js'; + +export type HookEvaluationResult = + | { + status: 'continue'; + hookDecision?: 'ask' | 'block'; + hookSystemMessage?: string; + modifiedArgs?: Record; + newInvocation?: AnyToolInvocation; + } + | { + status: 'error'; + error: Error; + errorType: ToolErrorType; + }; + +export async function evaluateBeforeToolHook( + config: Config, + tool: AnyDeclarativeTool, + request: ToolCallRequestInfo, + invocation: AnyToolInvocation, +): Promise { + const hookSystem = config.getHookSystem(); + if (!hookSystem) { + return { status: 'continue' }; + } + + const params = invocation.params || {}; + const toolInput: Record = { ...params }; + const mcpContext = extractMcpContext(invocation, config); + + const beforeOutput = await hookSystem.fireBeforeToolEvent( + request.name, + toolInput, + mcpContext, + request.originalRequestName, + ); + + if (!beforeOutput) { + return { status: 'continue' }; + } + + if (beforeOutput.shouldStopExecution()) { + return { + status: 'error', + error: new Error( + `Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`, + ), + errorType: ToolErrorType.STOP_EXECUTION, + }; + } + + const blockingError = beforeOutput.getBlockingError(); + if (blockingError?.blocked) { + return { + status: 'error', + error: new Error(`Tool execution blocked: ${blockingError.reason}`), + errorType: ToolErrorType.POLICY_VIOLATION, + }; + } + + let hookDecision: 'ask' | 'block' | undefined; + let hookSystemMessage: string | undefined; + + if (beforeOutput.isAskDecision()) { + hookDecision = 'ask'; + hookSystemMessage = beforeOutput.systemMessage; + } + + let modifiedArgs: Record | undefined; + let newInvocation: AnyToolInvocation | undefined; + + if (beforeOutput instanceof BeforeToolHookOutput) { + const modifiedInput = beforeOutput.getModifiedToolInput(); + if (modifiedInput) { + modifiedArgs = modifiedInput; + try { + newInvocation = tool.build(modifiedInput); + } catch (error) { + return { + status: 'error', + error: new Error( + `Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`, + ), + errorType: ToolErrorType.INVALID_TOOL_PARAMS, + }; + } + } + } + + return { + status: 'continue', + hookDecision, + hookSystemMessage, + modifiedArgs, + newInvocation, + }; +} diff --git a/packages/core/src/scheduler/policy.test.ts b/packages/core/src/scheduler/policy.test.ts index 32a92309e0..435fe6524d 100644 --- a/packages/core/src/scheduler/policy.test.ts +++ b/packages/core/src/scheduler/policy.test.ts @@ -824,6 +824,7 @@ describe('Plan Mode Denial Consistency', () => { toolRegistry: mockToolRegistry, getToolRegistry: () => mockToolRegistry, getMessageBus: vi.fn().mockReturnValue(mockMessageBus), + getHookSystem: vi.fn().mockReturnValue(undefined), isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(false), getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN), // Key: Plan Mode diff --git a/packages/core/src/scheduler/scheduler.test.ts b/packages/core/src/scheduler/scheduler.test.ts index 35cfdc3af7..3ad99c397b 100644 --- a/packages/core/src/scheduler/scheduler.test.ts +++ b/packages/core/src/scheduler/scheduler.test.ts @@ -170,6 +170,8 @@ describe('Scheduler (Orchestrator)', () => { mockConfig = { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), toolRegistry: mockToolRegistry, + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getHookSystem: vi.fn().mockReturnValue(undefined), isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(true), setApprovalMode: vi.fn(), @@ -1346,6 +1348,7 @@ describe('Scheduler MCP Progress', () => { mockConfig = { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getHookSystem: vi.fn().mockReturnValue(undefined), isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(true), setApprovalMode: vi.fn(), diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index cc14e3d875..db272213fa 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -10,6 +10,7 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { SchedulerStateManager } from './state-manager.js'; import { resolveConfirmation } from './confirmation.js'; import { checkPolicy, updatePolicy, getPolicyDenialError } from './policy.js'; +import { evaluateBeforeToolHook } from './hook-utils.js'; import { ToolExecutor } from './tool-executor.js'; import { ToolModificationHandler } from './tool-modifier.js'; import { @@ -572,12 +573,46 @@ export class Scheduler { ): Promise { const callId = toolCall.request.callId; - // Policy & Security - const { decision, rule } = await checkPolicy( + // 1. Hook Check (BeforeTool) + const hookResult = await evaluateBeforeToolHook( + this.config, + toolCall.tool, + toolCall.request, + toolCall.invocation, + ); + + if (hookResult.status === 'error') { + this.state.updateStatus( + callId, + CoreToolCallStatus.Error, + createErrorResponse( + toolCall.request, + hookResult.error, + hookResult.errorType, + ), + ); + return; + } + + const { hookDecision, hookSystemMessage, modifiedArgs, newInvocation } = + hookResult; + + if (modifiedArgs && newInvocation) { + toolCall.request.args = modifiedArgs; + toolCall.request.inputModifiedByHook = true; + toolCall.invocation = newInvocation; + } + + // 2. Policy & Security + const { decision: policyDecision, rule } = await checkPolicy( toolCall, this.config, this.subagent, ); + let decision = policyDecision; + if (hookDecision === 'ask') { + decision = PolicyDecision.ASK_USER; + } if (decision === PolicyDecision.DENY) { const { errorMessage, errorType } = getPolicyDenialError( @@ -610,6 +645,8 @@ export class Scheduler { getPreferredEditor: this.getPreferredEditor, schedulerId: this.schedulerId, onWaitingForConfirmation: this.onWaitingForConfirmation, + systemMessage: hookSystemMessage, + forcedDecision: hookDecision === 'ask' ? 'ask_user' : undefined, }); outcome = result.outcome; lastDetails = result.lastDetails; diff --git a/packages/core/src/scheduler/scheduler_parallel.test.ts b/packages/core/src/scheduler/scheduler_parallel.test.ts index 06b5e169df..1a9d3fe172 100644 --- a/packages/core/src/scheduler/scheduler_parallel.test.ts +++ b/packages/core/src/scheduler/scheduler_parallel.test.ts @@ -212,6 +212,8 @@ describe('Scheduler Parallel Execution', () => { mockConfig = { getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), toolRegistry: mockToolRegistry, + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getHookSystem: vi.fn().mockReturnValue(undefined), isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(true), setApprovalMode: vi.fn(), diff --git a/packages/core/src/scheduler/tool-executor.ts b/packages/core/src/scheduler/tool-executor.ts index 81232d39d9..91e4e49073 100644 --- a/packages/core/src/scheduler/tool-executor.ts +++ b/packages/core/src/scheduler/tool-executor.ts @@ -115,10 +115,25 @@ export class ToolExecutor { { shellExecutionConfig, setExecutionIdCallback }, this.config, request.originalRequestName, + true, // skipBeforeHook ); const toolResult: ToolResult = await promise; + if (call.request.inputModifiedByHook) { + const modificationMsg = `\n\n[System] Tool input parameters were modified by a hook before execution.`; + if (typeof toolResult.llmContent === 'string') { + toolResult.llmContent += modificationMsg; + } else if (Array.isArray(toolResult.llmContent)) { + toolResult.llmContent.push({ text: modificationMsg }); + } else if (toolResult.llmContent) { + toolResult.llmContent = [ + toolResult.llmContent, + { text: modificationMsg }, + ]; + } + } + if (signal.aborted) { completedToolCall = await this.createCancelledResult( call, diff --git a/packages/core/src/scheduler/types.ts b/packages/core/src/scheduler/types.ts index 9fedd48f41..a9cde87d27 100644 --- a/packages/core/src/scheduler/types.ts +++ b/packages/core/src/scheduler/types.ts @@ -47,6 +47,8 @@ export interface ToolCallRequestInfo { traceId?: string; parentCallId?: string; schedulerId?: string; + inputModifiedByHook?: boolean; + forcedAsk?: boolean; } export interface ToolCallResponseInfo { diff --git a/packages/core/src/telemetry/conseca-logger.test.ts b/packages/core/src/telemetry/conseca-logger.test.ts index e3ce85432e..0eac29276f 100644 --- a/packages/core/src/telemetry/conseca-logger.test.ts +++ b/packages/core/src/telemetry/conseca-logger.test.ts @@ -112,7 +112,7 @@ describe('conseca-logger', () => { 'user prompt', 'policy', 'tool call', - 'ALLOW', + 'allow', 'rationale', ); @@ -122,7 +122,7 @@ describe('conseca-logger', () => { expect(logs.getLogger).toHaveBeenCalled(); expect(mockLogger.emit).toHaveBeenCalledWith( expect.objectContaining({ - body: 'Conseca Verdict: ALLOW.', + body: 'Conseca Verdict: allow.', attributes: expect.objectContaining({ 'event.name': EVENT_CONSECA_VERDICT, }), diff --git a/packages/core/src/test-utils/mock-tool.ts b/packages/core/src/test-utils/mock-tool.ts index 5f89a506cd..a16f42093b 100644 --- a/packages/core/src/test-utils/mock-tool.ts +++ b/packages/core/src/test-utils/mock-tool.ts @@ -14,7 +14,9 @@ import { Kind, type ToolCallConfirmationDetails, type ToolInvocation, + type ToolLiveOutput, type ToolResult, + type ExecuteOptions, } from '../tools/tools.js'; import { createMockMessageBus } from './mock-message-bus.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; @@ -33,6 +35,7 @@ interface MockToolOptions { params: { [key: string]: unknown }, signal?: AbortSignal, updateOutput?: (output: string) => void, + options?: ExecuteOptions, ) => Promise; params?: object; messageBus?: MessageBus; @@ -52,13 +55,15 @@ class MockToolInvocation extends BaseToolInvocation< execute( signal: AbortSignal, - updateOutput?: (output: string) => void, + updateOutput?: (output: ToolLiveOutput) => void, + options?: ExecuteOptions, ): Promise { - if (updateOutput) { - return this.tool.execute(this.params, signal, updateOutput); - } else { - return this.tool.execute(this.params); - } + return this.tool.execute( + this.params, + signal, + updateOutput as ((output: string) => void) | undefined, + options, + ); } override shouldConfirmExecute( @@ -79,14 +84,16 @@ export class MockTool extends BaseDeclarativeTool< { [key: string]: unknown }, ToolResult > { - shouldConfirmExecute: ( + readonly shouldConfirmExecute: ( params: { [key: string]: unknown }, signal: AbortSignal, ) => Promise; - execute: ( + + readonly execute: ( params: { [key: string]: unknown }, signal?: AbortSignal, updateOutput?: (output: string) => void, + options?: ExecuteOptions, ) => Promise; constructor(options: MockToolOptions) { @@ -150,7 +157,11 @@ export class MockModifiableToolInvocation extends BaseToolInvocation< super(params, messageBus, tool.name, tool.displayName); } - async execute(_abortSignal: AbortSignal): Promise { + async execute( + _signal: AbortSignal, + _updateOutput?: (output: ToolLiveOutput) => void, + _options?: ExecuteOptions, + ): Promise { const result = this.tool.executeFn(this.params); return ( result ?? { diff --git a/packages/core/src/tools/confirmation-policy.test.ts b/packages/core/src/tools/confirmation-policy.test.ts index b18b1dd77e..af9f178b8b 100644 --- a/packages/core/src/tools/confirmation-policy.test.ts +++ b/packages/core/src/tools/confirmation-policy.test.ts @@ -166,7 +166,7 @@ describe('Tool Confirmation Policy Updates', () => { // Mock getMessageBusDecision to trigger ASK_USER flow vi.spyOn(invocation as any, 'getMessageBusDecision').mockResolvedValue( - 'ASK_USER', + 'ask_user', ); const confirmation = await invocation.shouldConfirmExecute( @@ -194,5 +194,39 @@ describe('Tool Confirmation Policy Updates', () => { } }, ); + + it('should skip confirmation in AUTO_EDIT mode', async () => { + vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue( + ApprovalMode.AUTO_EDIT, + ); + const tool = create(mockConfig, mockMessageBus); + const invocation = tool.build(params as any); + + const confirmation = await invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + expect(confirmation).toBe(false); + }); + + it('should NOT skip confirmation in AUTO_EDIT mode if forcedDecision is ask_user', async () => { + vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue( + ApprovalMode.AUTO_EDIT, + ); + const tool = create(mockConfig, mockMessageBus); + const invocation = tool.build(params as any); + + // Mock getMessageBusDecision to return ask_user + vi.spyOn(invocation as any, 'getMessageBusDecision').mockResolvedValue( + 'ask_user', + ); + + const confirmation = await invocation.shouldConfirmExecute( + new AbortController().signal, + 'ask_user', + ); + + expect(confirmation).not.toBe(false); + }); }); }); diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index bfa70565be..cbf36936a9 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -29,7 +29,6 @@ import { makeRelative, shortenPath } from '../utils/paths.js'; import { isNodeError } from '../utils/errors.js'; import { correctPath } from '../utils/pathCorrector.js'; import type { Config } from '../config/config.js'; -import { ApprovalMode } from '../policy/types.js'; import { CoreToolCallStatus } from '../scheduler/types.js'; import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js'; @@ -454,7 +453,16 @@ class EditToolInvocation toolName?: string, displayName?: string, ) { - super(params, messageBus, toolName, displayName); + super( + params, + messageBus, + toolName, + displayName, + undefined, + undefined, + true, + () => this.config.getApprovalMode(), + ); if (!path.isAbsolute(this.params.file_path)) { const result = correctPath(this.params.file_path, this.config); if (result.success) { @@ -732,10 +740,6 @@ class EditToolInvocation protected override async getConfirmationDetails( abortSignal: AbortSignal, ): Promise { - if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { - return false; - } - let editData: CalculatedEdit; try { editData = await this.calculateEdit(this.params, abortSignal); diff --git a/packages/core/src/tools/enter-plan-mode.test.ts b/packages/core/src/tools/enter-plan-mode.test.ts index 48bc5b494e..d14e1bfcdc 100644 --- a/packages/core/src/tools/enter-plan-mode.test.ts +++ b/packages/core/src/tools/enter-plan-mode.test.ts @@ -47,7 +47,7 @@ describe('EnterPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('ASK_USER'); + ).mockResolvedValue('ask_user'); const result = await invocation.shouldConfirmExecute( new AbortController().signal, @@ -74,7 +74,7 @@ describe('EnterPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('ALLOW'); + ).mockResolvedValue('allow'); const result = await invocation.shouldConfirmExecute( new AbortController().signal, @@ -92,7 +92,7 @@ describe('EnterPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('DENY'); + ).mockResolvedValue('deny'); await expect( invocation.shouldConfirmExecute(new AbortController().signal), @@ -136,7 +136,7 @@ describe('EnterPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('ASK_USER'); + ).mockResolvedValue('ask_user'); const details = await invocation.shouldConfirmExecute( new AbortController().signal, diff --git a/packages/core/src/tools/enter-plan-mode.ts b/packages/core/src/tools/enter-plan-mode.ts index d52c721aae..dee8569669 100644 --- a/packages/core/src/tools/enter-plan-mode.ts +++ b/packages/core/src/tools/enter-plan-mode.ts @@ -87,11 +87,11 @@ export class EnterPlanModeInvocation extends BaseToolInvocation< abortSignal: AbortSignal, ): Promise { const decision = await this.getMessageBusDecision(abortSignal); - if (decision === 'ALLOW') { + if (decision === 'allow') { return false; } - if (decision === 'DENY') { + if (decision === 'deny') { throw new Error( `Tool execution for "${ this._toolDisplayName || this._toolName @@ -99,7 +99,7 @@ export class EnterPlanModeInvocation extends BaseToolInvocation< ); } - // ASK_USER + // ask_user return { type: 'info', title: 'Enter Plan Mode', diff --git a/packages/core/src/tools/exit-plan-mode.test.ts b/packages/core/src/tools/exit-plan-mode.test.ts index 88e327ab34..855c5d2aba 100644 --- a/packages/core/src/tools/exit-plan-mode.test.ts +++ b/packages/core/src/tools/exit-plan-mode.test.ts @@ -59,7 +59,7 @@ describe('ExitPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('ASK_USER'); + ).mockResolvedValue('ask_user'); }); afterEach(() => { @@ -127,7 +127,7 @@ describe('ExitPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('ALLOW'); + ).mockResolvedValue('allow'); const result = await invocation.shouldConfirmExecute( new AbortController().signal, @@ -150,7 +150,7 @@ describe('ExitPlanModeTool', () => { getMessageBusDecision: () => Promise; }, 'getMessageBusDecision', - ).mockResolvedValue('DENY'); + ).mockResolvedValue('deny'); await expect( invocation.shouldConfirmExecute(new AbortController().signal), diff --git a/packages/core/src/tools/exit-plan-mode.ts b/packages/core/src/tools/exit-plan-mode.ts index aad95492c2..892e8926e0 100644 --- a/packages/core/src/tools/exit-plan-mode.ts +++ b/packages/core/src/tools/exit-plan-mode.ts @@ -138,7 +138,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation< } const decision = await this.getMessageBusDecision(abortSignal); - if (decision === 'DENY') { + if (decision === 'deny') { throw new Error( `Tool execution for "${ this._toolDisplayName || this._toolName @@ -146,7 +146,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation< ); } - if (decision === 'ALLOW') { + if (decision === 'allow') { // If policy is allow, auto-approve with default settings and execute. this.confirmationOutcome = ToolConfirmationOutcome.ProceedOnce; this.approvalPayload = { @@ -156,7 +156,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation< return false; } - // decision is 'ASK_USER' + // decision is 'ask_user' return { type: 'exit_plan_mode', title: 'Plan Approval', diff --git a/packages/core/src/tools/message-bus-integration.test.ts b/packages/core/src/tools/message-bus-integration.test.ts index bfc369b58b..91a2e30d94 100644 --- a/packages/core/src/tools/message-bus-integration.test.ts +++ b/packages/core/src/tools/message-bus-integration.test.ts @@ -57,10 +57,10 @@ class TestToolInvocation extends BaseToolInvocation { abortSignal: AbortSignal, ): Promise { const decision = await this.getMessageBusDecision(abortSignal); - if (decision === 'ALLOW') { + if (decision === 'allow') { return false; } - if (decision === 'DENY') { + if (decision === 'deny') { throw new Error('Tool execution denied by policy'); } return false; diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 3865aaf357..8b7d7223bd 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -19,9 +19,15 @@ import { type ToolConfirmationResponse, type Question, } from '../confirmation-bus/types.js'; -import { type ApprovalMode } from '../policy/types.js'; +import { ApprovalMode } from '../policy/types.js'; import type { SubagentProgress } from '../agents/types.js'; +/** +/** + * Supported decisions for forcing tool execution behavior. + */ +export type ForcedToolDecision = 'allow' | 'deny' | 'ask_user'; + /** * Options bag for tool execution, replacing positional parameters that are * only relevant to specific tool types. @@ -65,6 +71,7 @@ export interface ToolInvocation< */ shouldConfirmExecute( abortSignal: AbortSignal, + forcedDecision?: ForcedToolDecision, ): Promise; /** @@ -148,6 +155,8 @@ export abstract class BaseToolInvocation< readonly _toolDisplayName?: string, readonly _serverName?: string, readonly _toolAnnotations?: Record, + readonly respectsAutoEdit: boolean = false, + readonly getApprovalMode: () => ApprovalMode = () => ApprovalMode.DEFAULT, ) {} abstract getDescription(): string; @@ -158,13 +167,23 @@ export abstract class BaseToolInvocation< async shouldConfirmExecute( abortSignal: AbortSignal, + forcedDecision?: ForcedToolDecision, ): Promise { - const decision = await this.getMessageBusDecision(abortSignal); - if (decision === 'ALLOW') { + if ( + this.respectsAutoEdit && + this.getApprovalMode() === ApprovalMode.AUTO_EDIT && + forcedDecision !== 'ask_user' + ) { return false; } - if (decision === 'DENY') { + const decision = + forcedDecision ?? (await this.getMessageBusDecision(abortSignal)); + if (decision === 'allow') { + return false; + } + + if (decision === 'deny') { throw new Error( `Tool execution for "${ this._toolDisplayName || this._toolName @@ -172,7 +191,7 @@ export abstract class BaseToolInvocation< ); } - if (decision === 'ASK_USER') { + if (decision === 'ask_user') { return this.getConfirmationDetails(abortSignal); } @@ -216,7 +235,7 @@ export abstract class BaseToolInvocation< /** * Subclasses should override this method to provide custom confirmation UI - * when the policy engine's decision is 'ASK_USER'. + * when the policy engine's decision is 'ask_user'. * The base implementation provides a generic confirmation prompt. */ protected async getConfirmationDetails( @@ -239,11 +258,12 @@ export abstract class BaseToolInvocation< protected getMessageBusDecision( abortSignal: AbortSignal, - ): Promise<'ALLOW' | 'DENY' | 'ASK_USER'> { + forcedDecision?: ForcedToolDecision, + ): Promise { if (!this.messageBus || !this._toolName) { // If there's no message bus, we can't make a decision, so we allow. // The legacy confirmation flow will still apply if the tool needs it. - return Promise.resolve('ALLOW'); + return Promise.resolve('allow'); } const correlationId = randomUUID(); @@ -257,11 +277,12 @@ export abstract class BaseToolInvocation< }, serverName: this._serverName, toolAnnotations: this._toolAnnotations, + forcedDecision, }; - return new Promise<'ALLOW' | 'DENY' | 'ASK_USER'>((resolve) => { + return new Promise((resolve) => { if (!this.messageBus) { - resolve('ALLOW'); + resolve('allow'); return; } @@ -282,11 +303,11 @@ export abstract class BaseToolInvocation< const abortHandler = () => { cleanup(); - resolve('DENY'); + resolve('deny'); }; if (abortSignal.aborted) { - resolve('DENY'); + resolve('deny'); return; } @@ -294,11 +315,11 @@ export abstract class BaseToolInvocation< if (response.correlationId === correlationId) { cleanup(); if (response.requiresUserConfirmation) { - resolve('ASK_USER'); + resolve('ask_user'); } else if (response.confirmed) { - resolve('ALLOW'); + resolve('allow'); } else { - resolve('DENY'); + resolve('deny'); } } }; @@ -307,7 +328,7 @@ export abstract class BaseToolInvocation< timeoutId = setTimeout(() => { cleanup(); - resolve('ASK_USER'); // Default to ASK_USER on timeout + resolve('ask_user'); // Default to ask_user on timeout }, 30000); this.messageBus.subscribe( @@ -325,7 +346,7 @@ export abstract class BaseToolInvocation< void this.messageBus.publish(request); } catch (_error) { cleanup(); - resolve('ALLOW'); + resolve('allow'); } }); } @@ -859,6 +880,7 @@ export interface DiffStat { export interface ToolEditConfirmationDetails { type: 'edit'; title: string; + systemMessage?: string; onConfirm: ( outcome: ToolConfirmationOutcome, payload?: ToolConfirmationPayload, @@ -897,6 +919,7 @@ export type ToolConfirmationPayload = export interface ToolExecuteConfirmationDetails { type: 'exec'; title: string; + systemMessage?: string; onConfirm: (outcome: ToolConfirmationOutcome) => Promise; command: string; rootCommand: string; @@ -907,6 +930,7 @@ export interface ToolExecuteConfirmationDetails { export interface ToolMcpConfirmationDetails { type: 'mcp'; title: string; + systemMessage?: string; serverName: string; toolName: string; toolDisplayName: string; @@ -919,6 +943,7 @@ export interface ToolMcpConfirmationDetails { export interface ToolInfoConfirmationDetails { type: 'info'; title: string; + systemMessage?: string; onConfirm: (outcome: ToolConfirmationOutcome) => Promise; prompt: string; urls?: string[]; @@ -927,6 +952,7 @@ export interface ToolInfoConfirmationDetails { export interface ToolAskUserConfirmationDetails { type: 'ask_user'; title: string; + systemMessage?: string; questions: Question[]; onConfirm: ( outcome: ToolConfirmationOutcome, @@ -937,6 +963,7 @@ export interface ToolAskUserConfirmationDetails { export interface ToolExitPlanModeConfirmationDetails { type: 'exit_plan_mode'; title: string; + systemMessage?: string; planPath: string; onConfirm: ( outcome: ToolConfirmationOutcome, diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index 27a60c4259..5240da9451 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -18,7 +18,6 @@ import { buildParamArgsPattern } from '../policy/utils.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { ToolErrorType } from './tool-error.js'; import { getErrorMessage } from '../utils/errors.js'; -import { ApprovalMode } from '../policy/types.js'; import { getResponseText } from '../utils/partUtils.js'; import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js'; import { truncateString } from '../utils/textUtils.js'; @@ -231,7 +230,16 @@ class WebFetchToolInvocation extends BaseToolInvocation< _toolName?: string, _toolDisplayName?: string, ) { - super(params, messageBus, _toolName, _toolDisplayName); + super( + params, + messageBus, + _toolName, + _toolDisplayName, + undefined, + undefined, + true, + () => this.context.config.getApprovalMode(), + ); } private handleRetry(attempt: number, error: unknown, delayMs: number): void { @@ -516,12 +524,6 @@ ${aggregatedContent} protected override async getConfirmationDetails( _abortSignal: AbortSignal, ): Promise { - // Check for AUTO_EDIT approval mode. This tool has a specific behavior - // where ProceedAlways switches the entire session to AUTO_EDIT. - if (this.context.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { - return false; - } - let urls: string[] = []; let prompt = this.params.prompt || ''; diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index f725a21c43..8ba967114c 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -11,7 +11,6 @@ import os from 'node:os'; import * as Diff from 'diff'; import { WRITE_FILE_TOOL_NAME, WRITE_FILE_DISPLAY_NAME } from './tool-names.js'; import type { Config } from '../config/config.js'; -import { ApprovalMode } from '../policy/types.js'; import { BaseDeclarativeTool, @@ -156,7 +155,16 @@ class WriteFileToolInvocation extends BaseToolInvocation< toolName?: string, displayName?: string, ) { - super(params, messageBus, toolName, displayName); + super( + params, + messageBus, + toolName, + displayName, + undefined, + undefined, + true, + () => this.config.getApprovalMode(), + ); this.resolvedPath = path.resolve( this.config.getTargetDir(), this.params.file_path, @@ -186,10 +194,6 @@ class WriteFileToolInvocation extends BaseToolInvocation< protected override async getConfirmationDetails( abortSignal: AbortSignal, ): Promise { - if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { - return false; - } - const correctedContentResult = await getCorrectedFileContent( this.config, this.resolvedPath,