mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-19 10:31:16 -07:00
fix(hooks): support 'ask' decision for BeforeTool hooks
This commit is contained in:
@@ -8,6 +8,7 @@ import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { TestRig, poll, normalizePath } from './test-helper.js';
|
||||
import { join } from 'node:path';
|
||||
import { writeFileSync } from 'node:fs';
|
||||
import os from 'node:os';
|
||||
|
||||
describe('Hooks System Integration', () => {
|
||||
let rig: TestRig;
|
||||
@@ -2230,7 +2231,7 @@ console.log(JSON.stringify({
|
||||
|
||||
// The hook should have stopped execution message (returned from tool)
|
||||
expect(result).toContain(
|
||||
'Agent execution stopped: Emergency Stop triggered by hook',
|
||||
'Agent execution stopped by hook: Emergency Stop triggered by hook',
|
||||
);
|
||||
|
||||
// Tool should NOT be called successfully (it was blocked/stopped)
|
||||
@@ -2242,4 +2243,166 @@ console.log(JSON.stringify({
|
||||
expect(writeFileCalls).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Hooks "ask" Decision Integration', () => {
|
||||
it(
|
||||
'should force confirmation prompt when hook returns "ask" decision even in YOLO mode',
|
||||
{ timeout: 20000 },
|
||||
async () => {
|
||||
const testName =
|
||||
'should force confirmation prompt when hook returns "ask" decision';
|
||||
|
||||
// 1. Setup hook script that returns 'ask' decision
|
||||
const hookOutput = {
|
||||
decision: 'ask',
|
||||
systemMessage: 'Confirmation forced by security hook',
|
||||
hookSpecificOutput: {
|
||||
hookEventName: 'BeforeTool',
|
||||
},
|
||||
};
|
||||
|
||||
const hookScript = `console.log(JSON.stringify(${JSON.stringify(
|
||||
hookOutput,
|
||||
)}));`;
|
||||
|
||||
// Create script path predictably
|
||||
const scriptPath = join(os.tmpdir(), 'gemini-cli-tests-ask-hook.js');
|
||||
writeFileSync(scriptPath, hookScript);
|
||||
|
||||
// 2. Setup rig with YOLO mode enabled but with the 'ask' hook
|
||||
rig.setup(testName, {
|
||||
fakeResponsesPath: join(
|
||||
import.meta.dirname,
|
||||
'hooks-system.allow-tool.responses',
|
||||
),
|
||||
settings: {
|
||||
debugMode: true,
|
||||
tools: {
|
||||
approval: 'yolo',
|
||||
},
|
||||
hooksConfig: {
|
||||
enabled: true,
|
||||
},
|
||||
hooks: {
|
||||
BeforeTool: [
|
||||
{
|
||||
matcher: 'write_file',
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: `node "${scriptPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// 3. Run interactive and verify prompt appears despite YOLO mode
|
||||
const run = await rig.runInteractive();
|
||||
|
||||
// Send prompt that will trigger write_file
|
||||
await run.type('Create a file called ask-test.txt with content "test"');
|
||||
await run.type('\r');
|
||||
|
||||
// Wait for the FORCED confirmation prompt to appear
|
||||
// It should contain the system message from the hook
|
||||
await run.expectText('Confirmation forced by security hook', 15000);
|
||||
await run.expectText('Allow', 5000);
|
||||
|
||||
// 4. Approve the permission
|
||||
await run.type('y');
|
||||
await run.type('\r');
|
||||
|
||||
// Wait for command to execute
|
||||
await run.expectText('approved.txt', 15000);
|
||||
|
||||
// Should find the tool call
|
||||
const foundWriteFile = await rig.waitForToolCall('write_file');
|
||||
expect(foundWriteFile).toBeTruthy();
|
||||
|
||||
// File should be created
|
||||
const fileContent = rig.readFile('approved.txt');
|
||||
expect(fileContent).toBe('Approved content');
|
||||
},
|
||||
);
|
||||
|
||||
it('should allow cancelling when hook forces "ask" decision', async () => {
|
||||
const testName =
|
||||
'should allow cancelling when hook forces "ask" decision';
|
||||
const hookOutput = {
|
||||
decision: 'ask',
|
||||
systemMessage: 'Confirmation forced for cancellation test',
|
||||
hookSpecificOutput: {
|
||||
hookEventName: 'BeforeTool',
|
||||
},
|
||||
};
|
||||
|
||||
const hookScript = `console.log(JSON.stringify(${JSON.stringify(
|
||||
hookOutput,
|
||||
)}));`;
|
||||
|
||||
const scriptPath = join(
|
||||
os.tmpdir(),
|
||||
'gemini-cli-tests-ask-cancel-hook.js',
|
||||
);
|
||||
writeFileSync(scriptPath, hookScript);
|
||||
|
||||
rig.setup(testName, {
|
||||
fakeResponsesPath: join(
|
||||
import.meta.dirname,
|
||||
'hooks-system.allow-tool.responses',
|
||||
),
|
||||
settings: {
|
||||
debugMode: true,
|
||||
tools: {
|
||||
approval: 'yolo',
|
||||
},
|
||||
hooksConfig: {
|
||||
enabled: true,
|
||||
},
|
||||
hooks: {
|
||||
BeforeTool: [
|
||||
{
|
||||
matcher: 'write_file',
|
||||
hooks: [
|
||||
{
|
||||
type: 'command',
|
||||
command: `node "${scriptPath}"`,
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const run = await rig.runInteractive();
|
||||
|
||||
await run.type(
|
||||
'Create a file called cancel-test.txt with content "test"',
|
||||
);
|
||||
await run.type('\r');
|
||||
|
||||
await run.expectText('Confirmation forced for cancellation test', 15000);
|
||||
|
||||
// 4. Deny the permission using option 4
|
||||
await run.type('4');
|
||||
await run.type('\r');
|
||||
|
||||
// Wait for cancellation message
|
||||
await run.expectText('Cancelled', 10000);
|
||||
|
||||
// Tool should NOT be called successfully
|
||||
const toolLogs = rig.readToolLogs();
|
||||
const writeFileCalls = toolLogs.filter(
|
||||
(t) =>
|
||||
t.toolRequest.name === 'write_file' && t.toolRequest.success === true,
|
||||
);
|
||||
expect(writeFileCalls).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -451,7 +451,11 @@ export class Task {
|
||||
'Auto-approving all tool calls.',
|
||||
);
|
||||
toolCalls.forEach((tc: ToolCall) => {
|
||||
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
|
||||
if (
|
||||
tc.status === 'awaiting_approval' &&
|
||||
tc.confirmationDetails &&
|
||||
!tc.request.forcedAsk
|
||||
) {
|
||||
const details = tc.confirmationDetails;
|
||||
if (isToolCallConfirmationDetails(details)) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
|
||||
@@ -679,6 +679,15 @@ export const ToolConfirmationMessage: React.FC<
|
||||
paddingTop={0}
|
||||
paddingBottom={handlesOwnUI ? 0 : 1}
|
||||
>
|
||||
{/* System message from hook */}
|
||||
{confirmationDetails.systemMessage && (
|
||||
<Box marginBottom={1}>
|
||||
<Text color={theme.status.warning}>
|
||||
{confirmationDetails.systemMessage}
|
||||
</Text>
|
||||
</Box>
|
||||
)}
|
||||
|
||||
{handlesOwnUI ? (
|
||||
bodyContent
|
||||
) : (
|
||||
|
||||
@@ -1577,7 +1577,7 @@ export const useGeminiStream = (
|
||||
) {
|
||||
let awaitingApprovalCalls = toolCalls.filter(
|
||||
(call): call is TrackedWaitingToolCall =>
|
||||
call.status === 'awaiting_approval',
|
||||
call.status === 'awaiting_approval' && !call.request.forcedAsk,
|
||||
);
|
||||
|
||||
// For AUTO_EDIT mode, only approve edit tools (replace, write_file)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import {
|
||||
BaseToolInvocation,
|
||||
type ForcedToolDecision,
|
||||
type ToolConfirmationOutcome,
|
||||
type ToolResult,
|
||||
type ToolCallConfirmationDetails,
|
||||
@@ -134,6 +135,7 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
_forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// For now, always require confirmation for remote agents until we have a policy system for them.
|
||||
return {
|
||||
|
||||
@@ -112,6 +112,7 @@ describe('SubAgentInvocation', () => {
|
||||
expect(result).toBe(false);
|
||||
expect(mockInnerInvocation.shouldConfirmExecute).toHaveBeenCalledWith(
|
||||
abortSignal,
|
||||
undefined,
|
||||
);
|
||||
expect(MockSubagentToolWrapper).toHaveBeenCalledWith(
|
||||
testDefinition,
|
||||
@@ -156,6 +157,7 @@ describe('SubAgentInvocation', () => {
|
||||
expect(result).toBe(confirmationDetails);
|
||||
expect(mockInnerInvocation.shouldConfirmExecute).toHaveBeenCalledWith(
|
||||
abortSignal,
|
||||
undefined,
|
||||
);
|
||||
expect(MockSubagentToolWrapper).toHaveBeenCalledWith(
|
||||
testRemoteDefinition,
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
Kind,
|
||||
type ToolInvocation,
|
||||
type ToolResult,
|
||||
type ForcedToolDecision,
|
||||
BaseToolInvocation,
|
||||
type ToolCallConfirmationDetails,
|
||||
isTool,
|
||||
@@ -145,12 +146,13 @@ class SubAgentInvocation extends BaseToolInvocation<AgentInputs, ToolResult> {
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
const invocation = this.buildSubInvocation(
|
||||
this.definition,
|
||||
this.withUserHints(this.params),
|
||||
);
|
||||
return invocation.shouldConfirmExecute(abortSignal);
|
||||
return invocation.shouldConfirmExecute(abortSignal, forcedDecision);
|
||||
}
|
||||
|
||||
async execute(
|
||||
|
||||
@@ -76,12 +76,14 @@ export type SerializableConfirmationDetails =
|
||||
| {
|
||||
type: 'info';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
prompt: string;
|
||||
urls?: string[];
|
||||
}
|
||||
| {
|
||||
type: 'edit';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
fileName: string;
|
||||
filePath: string;
|
||||
fileDiff: string;
|
||||
@@ -92,6 +94,7 @@ export type SerializableConfirmationDetails =
|
||||
| {
|
||||
type: 'exec';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
command: string;
|
||||
rootCommand: string;
|
||||
rootCommands: string[];
|
||||
@@ -100,6 +103,7 @@ export type SerializableConfirmationDetails =
|
||||
| {
|
||||
type: 'mcp';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
serverName: string;
|
||||
toolName: string;
|
||||
toolDisplayName: string;
|
||||
@@ -110,11 +114,13 @@ export type SerializableConfirmationDetails =
|
||||
| {
|
||||
type: 'ask_user';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
questions: Question[];
|
||||
}
|
||||
| {
|
||||
type: 'exit_plan_mode';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
planPath: string;
|
||||
};
|
||||
|
||||
|
||||
@@ -15,10 +15,7 @@ import {
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import type { HookSystem } from '../hooks/hookSystem.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import {
|
||||
type DefaultHookOutput,
|
||||
BeforeToolHookOutput,
|
||||
} from '../hooks/types.js';
|
||||
import { type DefaultHookOutput } from '../hooks/types.js';
|
||||
|
||||
class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> {
|
||||
constructor(params: { key?: string }, messageBus: MessageBus) {
|
||||
@@ -66,70 +63,11 @@ describe('executeToolWithHooks', () => {
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
});
|
||||
|
||||
it('should prioritize continue: false over decision: block in BeforeTool', async () => {
|
||||
const invocation = new MockInvocation({}, messageBus);
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
|
||||
shouldStopExecution: () => true,
|
||||
getEffectiveReason: () => 'Stop immediately',
|
||||
getBlockingError: () => ({
|
||||
blocked: false,
|
||||
reason: 'Should be ignored because continue is false',
|
||||
}),
|
||||
} as unknown as DefaultHookOutput);
|
||||
|
||||
const result = await executeToolWithHooks(
|
||||
invocation,
|
||||
'test_tool',
|
||||
abortSignal,
|
||||
mockTool,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
|
||||
expect(result.error?.message).toBe('Stop immediately');
|
||||
});
|
||||
|
||||
it('should block execution in BeforeTool if decision is block', async () => {
|
||||
const invocation = new MockInvocation({}, messageBus);
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
|
||||
shouldStopExecution: () => false,
|
||||
getEffectiveReason: () => '',
|
||||
getBlockingError: () => ({ blocked: true, reason: 'Execution blocked' }),
|
||||
} as unknown as DefaultHookOutput);
|
||||
|
||||
const result = await executeToolWithHooks(
|
||||
invocation,
|
||||
'test_tool',
|
||||
abortSignal,
|
||||
mockTool,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
|
||||
expect(result.error?.message).toBe('Execution blocked');
|
||||
});
|
||||
|
||||
it('should handle continue: false in AfterTool', async () => {
|
||||
const invocation = new MockInvocation({}, messageBus);
|
||||
const abortSignal = new AbortController().signal;
|
||||
const spy = vi.spyOn(invocation, 'execute');
|
||||
|
||||
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
|
||||
shouldStopExecution: () => false,
|
||||
getEffectiveReason: () => '',
|
||||
getBlockingError: () => ({ blocked: false, reason: '' }),
|
||||
} as unknown as DefaultHookOutput);
|
||||
|
||||
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({
|
||||
shouldStopExecution: () => true,
|
||||
getEffectiveReason: () => 'Stop after execution',
|
||||
@@ -156,12 +94,6 @@ describe('executeToolWithHooks', () => {
|
||||
const invocation = new MockInvocation({}, messageBus);
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
|
||||
shouldStopExecution: () => false,
|
||||
getEffectiveReason: () => '',
|
||||
getBlockingError: () => ({ blocked: false, reason: '' }),
|
||||
} as unknown as DefaultHookOutput);
|
||||
|
||||
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({
|
||||
shouldStopExecution: () => false,
|
||||
getEffectiveReason: () => '',
|
||||
@@ -182,80 +114,4 @@ describe('executeToolWithHooks', () => {
|
||||
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
|
||||
expect(result.error?.message).toBe('Result denied');
|
||||
});
|
||||
|
||||
it('should apply modified tool input from BeforeTool hook', async () => {
|
||||
const params = { key: 'original' };
|
||||
const invocation = new MockInvocation(params, messageBus);
|
||||
const toolName = 'test-tool';
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
const mockBeforeOutput = new BeforeToolHookOutput({
|
||||
continue: true,
|
||||
hookSpecificOutput: {
|
||||
hookEventName: 'BeforeTool',
|
||||
tool_input: { key: 'modified' },
|
||||
},
|
||||
});
|
||||
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue(
|
||||
mockBeforeOutput,
|
||||
);
|
||||
|
||||
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined);
|
||||
|
||||
const result = await executeToolWithHooks(
|
||||
invocation,
|
||||
toolName,
|
||||
abortSignal,
|
||||
mockTool,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
// Verify result reflects modified input
|
||||
expect(result.llmContent).toBe(
|
||||
'key: modified\n\n[System] Tool input parameters (key) were modified by a hook before execution.',
|
||||
);
|
||||
// Verify params object was modified in place
|
||||
expect(invocation.params.key).toBe('modified');
|
||||
|
||||
expect(mockHookSystem.fireBeforeToolEvent).toHaveBeenCalled();
|
||||
expect(mockTool.build).toHaveBeenCalledWith({ key: 'modified' });
|
||||
});
|
||||
|
||||
it('should not modify input if hook does not provide tool_input', async () => {
|
||||
const params = { key: 'original' };
|
||||
const invocation = new MockInvocation(params, messageBus);
|
||||
const toolName = 'test-tool';
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
const mockBeforeOutput = new BeforeToolHookOutput({
|
||||
continue: true,
|
||||
hookSpecificOutput: {
|
||||
hookEventName: 'BeforeTool',
|
||||
// No tool input
|
||||
},
|
||||
});
|
||||
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue(
|
||||
mockBeforeOutput,
|
||||
);
|
||||
|
||||
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined);
|
||||
|
||||
const result = await executeToolWithHooks(
|
||||
invocation,
|
||||
toolName,
|
||||
abortSignal,
|
||||
mockTool,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result.llmContent).toBe('key: original');
|
||||
expect(invocation.params.key).toBe('original');
|
||||
expect(mockTool.build).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type McpToolContext, BeforeToolHookOutput } from '../hooks/types.js';
|
||||
import { type McpToolContext } from '../hooks/types.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type {
|
||||
ToolResult,
|
||||
@@ -13,7 +13,6 @@ import type {
|
||||
ToolLiveOutput,
|
||||
} from '../tools/tools.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { ShellExecutionConfig } from '../index.js';
|
||||
import { ShellToolInvocation } from '../tools/shell.js';
|
||||
import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js';
|
||||
@@ -25,7 +24,7 @@ import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js';
|
||||
* @param config Config to look up server details
|
||||
* @returns MCP context if this is an MCP tool, undefined otherwise
|
||||
*/
|
||||
function extractMcpContext(
|
||||
export function extractMcpContext(
|
||||
invocation: ShellToolInvocation | AnyToolInvocation,
|
||||
config: Config,
|
||||
): McpToolContext | undefined {
|
||||
@@ -78,81 +77,12 @@ export async function executeToolWithHooks(
|
||||
config?: Config,
|
||||
originalRequestName?: string,
|
||||
): Promise<ToolResult> {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const toolInput = (invocation.params || {}) as Record<string, unknown>;
|
||||
let inputWasModified = false;
|
||||
let modifiedKeys: string[] = [];
|
||||
|
||||
// Extract MCP context if this is an MCP tool (only if config is provided)
|
||||
const mcpContext = config ? extractMcpContext(invocation, config) : undefined;
|
||||
|
||||
const hookSystem = config?.getHookSystem();
|
||||
if (hookSystem) {
|
||||
const beforeOutput = await hookSystem.fireBeforeToolEvent(
|
||||
toolName,
|
||||
toolInput,
|
||||
mcpContext,
|
||||
originalRequestName,
|
||||
);
|
||||
|
||||
// Check if hook requested to stop entire agent execution
|
||||
if (beforeOutput?.shouldStopExecution()) {
|
||||
const reason = beforeOutput.getEffectiveReason();
|
||||
return {
|
||||
llmContent: `Agent execution stopped by hook: ${reason}`,
|
||||
returnDisplay: `Agent execution stopped by hook: ${reason}`,
|
||||
error: {
|
||||
type: ToolErrorType.STOP_EXECUTION,
|
||||
message: reason,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Check if hook blocked the tool execution
|
||||
const blockingError = beforeOutput?.getBlockingError();
|
||||
if (blockingError?.blocked) {
|
||||
return {
|
||||
llmContent: `Tool execution blocked: ${blockingError.reason}`,
|
||||
returnDisplay: `Tool execution blocked: ${blockingError.reason}`,
|
||||
error: {
|
||||
type: ToolErrorType.EXECUTION_FAILED,
|
||||
message: blockingError.reason,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Check if hook requested to update tool input
|
||||
if (beforeOutput instanceof BeforeToolHookOutput) {
|
||||
const modifiedInput = beforeOutput.getModifiedToolInput();
|
||||
if (modifiedInput) {
|
||||
// We modify the toolInput object in-place, which should be the same reference as invocation.params
|
||||
// We use Object.assign to update properties
|
||||
Object.assign(invocation.params, modifiedInput);
|
||||
debugLogger.debug(`Tool input modified by hook for ${toolName}`);
|
||||
inputWasModified = true;
|
||||
modifiedKeys = Object.keys(modifiedInput);
|
||||
|
||||
// Recreate the invocation with the new parameters
|
||||
// to ensure any derived state (like resolvedPath in ReadFileTool) is updated.
|
||||
try {
|
||||
// We use the tool's build method to validate and create the invocation
|
||||
// This ensures consistent behavior with the initial creation
|
||||
invocation = tool.build(invocation.params);
|
||||
} catch (error) {
|
||||
return {
|
||||
llmContent: `Tool parameter modification by hook failed validation: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
returnDisplay: `Tool parameter modification by hook failed validation.`,
|
||||
error: {
|
||||
type: ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
message: String(error),
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const toolInput = (invocation.params || {}) as Record<string, unknown>;
|
||||
|
||||
// Execute the actual tool
|
||||
let toolResult: ToolResult;
|
||||
@@ -171,24 +101,6 @@ export async function executeToolWithHooks(
|
||||
);
|
||||
}
|
||||
|
||||
// Append notification if parameters were modified
|
||||
if (inputWasModified) {
|
||||
const modificationMsg = `\n\n[System] Tool input parameters (${modifiedKeys.join(
|
||||
', ',
|
||||
)}) were modified by a hook before execution.`;
|
||||
if (typeof toolResult.llmContent === 'string') {
|
||||
toolResult.llmContent += modificationMsg;
|
||||
} else if (Array.isArray(toolResult.llmContent)) {
|
||||
toolResult.llmContent.push({ text: modificationMsg });
|
||||
} else if (toolResult.llmContent) {
|
||||
// Handle single Part case by converting to an array
|
||||
toolResult.llmContent = [
|
||||
toolResult.llmContent,
|
||||
{ text: modificationMsg },
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
if (hookSystem) {
|
||||
const afterOutput = await hookSystem.fireAfterToolEvent(
|
||||
toolName,
|
||||
|
||||
@@ -285,6 +285,7 @@ function createMockConfig(overrides: Partial<Config> = {}): Config {
|
||||
getGeminiClient: () => null,
|
||||
getMessageBus: () => createMockMessageBus(),
|
||||
getEnableHooks: () => false,
|
||||
getHookSystem: () => undefined,
|
||||
getExperiments: () => {},
|
||||
} as unknown as Config;
|
||||
|
||||
@@ -1016,7 +1017,12 @@ describe('CoreToolScheduler YOLO mode', () => {
|
||||
|
||||
// Assert
|
||||
// 1. The tool's execute method was called directly.
|
||||
expect(executeFn).toHaveBeenCalledWith({ param: 'value' });
|
||||
expect(executeFn).toHaveBeenCalledWith(
|
||||
{ param: 'value' },
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
// 2. The tool call status never entered CoreToolCallStatus.AwaitingApproval.
|
||||
const statusUpdates = onToolCallsUpdate.mock.calls
|
||||
@@ -1119,7 +1125,12 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
);
|
||||
|
||||
// Ensure the second tool call hasn't been executed yet.
|
||||
expect(executeFn).toHaveBeenCalledWith({ a: 1 });
|
||||
expect(executeFn).toHaveBeenCalledWith(
|
||||
{ a: 1 },
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
// Complete the first tool call.
|
||||
resolveFirstCall!({
|
||||
@@ -1143,7 +1154,12 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
// Now the second tool call should have been executed.
|
||||
expect(executeFn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
expect(executeFn).toHaveBeenCalledWith({ b: 2 });
|
||||
expect(executeFn).toHaveBeenCalledWith(
|
||||
{ b: 2 },
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
// Wait for the second completion.
|
||||
await vi.waitFor(() => {
|
||||
@@ -1237,7 +1253,12 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
|
||||
// Assert
|
||||
// 1. The tool's execute method was called directly.
|
||||
expect(executeFn).toHaveBeenCalledWith({ param: 'value' });
|
||||
expect(executeFn).toHaveBeenCalledWith(
|
||||
{ param: 'value' },
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
// 2. The tool call status never entered CoreToolCallStatus.AwaitingApproval.
|
||||
const statusUpdates = onToolCallsUpdate.mock.calls
|
||||
@@ -1418,8 +1439,18 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
|
||||
// Ensure the tool was called twice with the correct arguments.
|
||||
expect(executeFn).toHaveBeenCalledTimes(2);
|
||||
expect(executeFn).toHaveBeenCalledWith({ a: 1 });
|
||||
expect(executeFn).toHaveBeenCalledWith({ b: 2 });
|
||||
expect(executeFn).toHaveBeenCalledWith(
|
||||
{ a: 1 },
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.anything(),
|
||||
);
|
||||
expect(executeFn).toHaveBeenCalledWith(
|
||||
{ b: 2 },
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
// Ensure completion callbacks were called twice.
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
|
||||
@@ -1776,8 +1807,18 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
||||
|
||||
// Check that execute was called for the first two tools only
|
||||
expect(executeFn).toHaveBeenCalledTimes(2);
|
||||
expect(executeFn).toHaveBeenCalledWith({ call: 1 });
|
||||
expect(executeFn).toHaveBeenCalledWith({ call: 2 });
|
||||
expect(executeFn).toHaveBeenCalledWith(
|
||||
{ call: 1 },
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.anything(),
|
||||
);
|
||||
expect(executeFn).toHaveBeenCalledWith(
|
||||
{ call: 2 },
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
|
||||
@@ -50,6 +50,8 @@ import { ToolExecutor } from '../scheduler/tool-executor.js';
|
||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||
import { getPolicyDenialError } from '../scheduler/policy.js';
|
||||
import { GeminiCliOperation } from '../telemetry/constants.js';
|
||||
import { extractMcpContext } from './coreToolHookTriggers.js';
|
||||
import { BeforeToolHookOutput } from '../hooks/types.js';
|
||||
|
||||
export type {
|
||||
ToolCall,
|
||||
@@ -604,7 +606,7 @@ export class CoreToolScheduler {
|
||||
return;
|
||||
}
|
||||
|
||||
const toolCall = this.toolCallQueue.shift()!;
|
||||
let toolCall = this.toolCallQueue.shift()!;
|
||||
|
||||
// This is now the single active tool call.
|
||||
this.toolCalls = [toolCall];
|
||||
@@ -620,7 +622,8 @@ export class CoreToolScheduler {
|
||||
|
||||
// This logic is moved from the old `for` loop in `_schedule`.
|
||||
if (toolCall.status === CoreToolCallStatus.Validating) {
|
||||
const { request: reqInfo, invocation } = toolCall;
|
||||
const { request: reqInfo } = toolCall;
|
||||
let { invocation } = toolCall;
|
||||
|
||||
try {
|
||||
if (signal.aborted) {
|
||||
@@ -635,6 +638,90 @@ export class CoreToolScheduler {
|
||||
return;
|
||||
}
|
||||
|
||||
// 1. Hook Check (BeforeTool)
|
||||
let hookDecision: 'ask' | 'block' | undefined;
|
||||
let hookSystemMessage: string | undefined;
|
||||
|
||||
const hookSystem = this.config.getHookSystem();
|
||||
if (hookSystem) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const toolInput = (invocation.params || {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const mcpContext = extractMcpContext(invocation, this.config);
|
||||
|
||||
const beforeOutput = await hookSystem.fireBeforeToolEvent(
|
||||
toolCall.request.name,
|
||||
toolInput,
|
||||
mcpContext,
|
||||
toolCall.request.originalRequestName,
|
||||
);
|
||||
|
||||
if (beforeOutput) {
|
||||
// Check if hook requested to stop entire agent execution
|
||||
if (beforeOutput.shouldStopExecution()) {
|
||||
const reason = beforeOutput.getEffectiveReason();
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
CoreToolCallStatus.Error,
|
||||
signal,
|
||||
createErrorResponse(
|
||||
reqInfo,
|
||||
new Error(`Agent execution stopped by hook: ${reason}`),
|
||||
ToolErrorType.STOP_EXECUTION,
|
||||
),
|
||||
);
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if hook blocked the tool execution
|
||||
const blockingError = beforeOutput.getBlockingError();
|
||||
if (blockingError?.blocked) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
CoreToolCallStatus.Error,
|
||||
signal,
|
||||
createErrorResponse(
|
||||
reqInfo,
|
||||
new Error(`Tool execution blocked: ${blockingError.reason}`),
|
||||
ToolErrorType.POLICY_VIOLATION,
|
||||
),
|
||||
);
|
||||
await this.checkAndNotifyCompletion(signal);
|
||||
return;
|
||||
}
|
||||
|
||||
if (beforeOutput.isAskDecision()) {
|
||||
hookDecision = 'ask';
|
||||
hookSystemMessage = beforeOutput.systemMessage;
|
||||
// Mark the request so UI knows not to auto-approve it
|
||||
toolCall.request.forcedAsk = true;
|
||||
}
|
||||
|
||||
// Check if hook requested to update tool input
|
||||
if (beforeOutput instanceof BeforeToolHookOutput) {
|
||||
const modifiedInput = beforeOutput.getModifiedToolInput();
|
||||
if (modifiedInput) {
|
||||
this.setArgsInternal(reqInfo.callId, modifiedInput);
|
||||
|
||||
// IMPORTANT: toolCall and invocation must be updated because setArgsInternal created a new one
|
||||
const updatedCall = this.toolCalls.find(
|
||||
(c) => c.request.callId === reqInfo.callId,
|
||||
);
|
||||
if (updatedCall) {
|
||||
toolCall = updatedCall;
|
||||
toolCall.request.inputModifiedByHook = true;
|
||||
if ('invocation' in updatedCall) {
|
||||
invocation = updatedCall.invocation;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Policy Check using PolicyEngine
|
||||
// We must reconstruct the FunctionCall format expected by PolicyEngine
|
||||
const toolCallForPolicy = {
|
||||
@@ -645,13 +732,18 @@ export class CoreToolScheduler {
|
||||
toolCall.tool instanceof DiscoveredMCPTool
|
||||
? toolCall.tool.serverName
|
||||
: undefined;
|
||||
const toolAnnotations = toolCall.tool.toolAnnotations;
|
||||
const toolAnnotations = toolCall.tool?.toolAnnotations;
|
||||
|
||||
const { decision, rule } = await this.config
|
||||
const { decision: policyDecision, rule } = await this.config
|
||||
.getPolicyEngine()
|
||||
.check(toolCallForPolicy, serverName, toolAnnotations);
|
||||
|
||||
if (decision === PolicyDecision.DENY) {
|
||||
let finalDecision = policyDecision;
|
||||
if (hookDecision === 'ask') {
|
||||
finalDecision = PolicyDecision.ASK_USER;
|
||||
}
|
||||
|
||||
if (finalDecision === PolicyDecision.DENY) {
|
||||
const { errorMessage, errorType } = getPolicyDenialError(
|
||||
this.config,
|
||||
rule,
|
||||
@@ -666,7 +758,7 @@ export class CoreToolScheduler {
|
||||
return;
|
||||
}
|
||||
|
||||
if (decision === PolicyDecision.ALLOW) {
|
||||
if (finalDecision === PolicyDecision.ALLOW) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
@@ -677,11 +769,13 @@ export class CoreToolScheduler {
|
||||
signal,
|
||||
);
|
||||
} else {
|
||||
// PolicyDecision.ASK_USER
|
||||
// PolicyDecision.ASK_USER or forced 'ask' by hook
|
||||
|
||||
// We need confirmation details to show to the user
|
||||
const confirmationDetails =
|
||||
await invocation.shouldConfirmExecute(signal);
|
||||
const confirmationDetails = await invocation.shouldConfirmExecute(
|
||||
signal,
|
||||
hookDecision === 'ask' ? 'ask_user' : undefined,
|
||||
);
|
||||
|
||||
if (!confirmationDetails) {
|
||||
this.setToolCallOutcome(
|
||||
@@ -697,11 +791,17 @@ export class CoreToolScheduler {
|
||||
if (!this.config.isInteractive()) {
|
||||
throw new Error(
|
||||
`Tool execution for "${
|
||||
toolCall.tool.displayName || toolCall.tool.name
|
||||
toolCall.tool?.displayName ||
|
||||
toolCall.tool?.name ||
|
||||
toolCall.request.name
|
||||
}" requires user confirmation, which is not supported in non-interactive mode.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (hookSystemMessage) {
|
||||
confirmationDetails.systemMessage = hookSystemMessage;
|
||||
}
|
||||
|
||||
// Fire Notification hook before showing confirmation to user
|
||||
const hookSystem = this.config.getHookSystem();
|
||||
if (hookSystem) {
|
||||
|
||||
300
packages/core/src/core/coreToolSchedulerHooks.test.ts
Normal file
300
packages/core/src/core/coreToolSchedulerHooks.test.ts
Normal file
@@ -0,0 +1,300 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { CoreToolScheduler } from './coreToolScheduler.js';
|
||||
import type { ToolCall, ErroredToolCall } from '../scheduler/types.js';
|
||||
import type { Config, ToolRegistry } from '../index.js';
|
||||
import {
|
||||
ApprovalMode,
|
||||
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
|
||||
} from '../index.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
||||
import type { PolicyEngine } from '../policy/policy-engine.js';
|
||||
import type { HookSystem } from '../hooks/hookSystem.js';
|
||||
import { BeforeToolHookOutput } from '../hooks/types.js';
|
||||
|
||||
function createMockConfig(overrides: Partial<Config> = {}): Config {
|
||||
const defaultToolRegistry = {
|
||||
getTool: () => undefined,
|
||||
getToolByName: () => undefined,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {},
|
||||
registerTool: () => {},
|
||||
getToolByDisplayName: () => undefined,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
getExperiments: () => {},
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const baseConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
isInteractive: () => true,
|
||||
getApprovalMode: () => ApprovalMode.DEFAULT,
|
||||
setApprovalMode: () => {},
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
terminalHeight: 30,
|
||||
sanitizationConfig: {
|
||||
enableEnvironmentVariableRedaction: true,
|
||||
allowedEnvironmentVariables: [],
|
||||
blockedEnvironmentVariables: [],
|
||||
},
|
||||
}),
|
||||
storage: {
|
||||
getProjectTempDir: () => '/tmp',
|
||||
},
|
||||
getTruncateToolOutputThreshold: () =>
|
||||
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
|
||||
getTruncateToolOutputLines: () => 1000,
|
||||
getToolRegistry: () => defaultToolRegistry,
|
||||
getActiveModel: () => DEFAULT_GEMINI_MODEL,
|
||||
getGeminiClient: () => null,
|
||||
getMessageBus: () => createMockMessageBus(),
|
||||
getEnableHooks: () => true, // Enabled for these tests
|
||||
getExperiments: () => {},
|
||||
getPolicyEngine: () =>
|
||||
({
|
||||
check: async () => ({ decision: 'allow' }), // Default allow for hook tests
|
||||
}) as unknown as PolicyEngine,
|
||||
} as unknown as Config;
|
||||
|
||||
return { ...baseConfig, ...overrides } as Config;
|
||||
}
|
||||
|
||||
describe('CoreToolScheduler Hooks', () => {
|
||||
it('should stop execution if BeforeTool hook requests stop', async () => {
|
||||
const executeFn = vi.fn().mockResolvedValue({
|
||||
llmContent: 'Tool executed',
|
||||
returnDisplay: 'Tool executed',
|
||||
});
|
||||
const mockTool = new MockTool({ name: 'mockTool', execute: executeFn });
|
||||
|
||||
const toolRegistry = {
|
||||
getTool: () => mockTool,
|
||||
getToolByName: () => mockTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {},
|
||||
registerTool: () => {},
|
||||
getToolByDisplayName: () => mockTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const mockMessageBus = createMockMessageBus();
|
||||
const mockHookSystem = {
|
||||
fireBeforeToolEvent: vi.fn().mockResolvedValue({
|
||||
shouldStopExecution: () => true,
|
||||
getEffectiveReason: () => 'Hook stopped execution',
|
||||
getBlockingError: () => ({ blocked: false }),
|
||||
isAskDecision: () => false,
|
||||
}),
|
||||
} as unknown as HookSystem;
|
||||
|
||||
const mockConfig = createMockConfig({
|
||||
getToolRegistry: () => toolRegistry,
|
||||
getMessageBus: () => mockMessageBus,
|
||||
getHookSystem: () => mockHookSystem,
|
||||
getApprovalMode: () => ApprovalMode.YOLO,
|
||||
});
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
onAllToolCallsComplete,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
});
|
||||
|
||||
const request = {
|
||||
callId: '1',
|
||||
name: 'mockTool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-1',
|
||||
};
|
||||
|
||||
await scheduler.schedule([request], new AbortController().signal);
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
expect(completedCalls[0].status).toBe('error');
|
||||
const erroredCall = completedCalls[0] as ErroredToolCall;
|
||||
|
||||
// Check error type/message
|
||||
expect(erroredCall.response.error?.message).toContain(
|
||||
'Hook stopped execution',
|
||||
);
|
||||
expect(executeFn).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should block tool execution if BeforeTool hook requests block', async () => {
|
||||
const executeFn = vi.fn();
|
||||
const mockTool = new MockTool({ name: 'mockTool', execute: executeFn });
|
||||
|
||||
const toolRegistry = {
|
||||
getTool: () => mockTool,
|
||||
getToolByName: () => mockTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {},
|
||||
registerTool: () => {},
|
||||
getToolByDisplayName: () => mockTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const mockMessageBus = createMockMessageBus();
|
||||
const mockHookSystem = {
|
||||
fireBeforeToolEvent: vi.fn().mockResolvedValue({
|
||||
shouldStopExecution: () => false,
|
||||
getBlockingError: () => ({
|
||||
blocked: true,
|
||||
reason: 'Hook blocked execution',
|
||||
}),
|
||||
isAskDecision: () => false,
|
||||
}),
|
||||
} as unknown as HookSystem;
|
||||
|
||||
const mockConfig = createMockConfig({
|
||||
getToolRegistry: () => toolRegistry,
|
||||
getMessageBus: () => mockMessageBus,
|
||||
getHookSystem: () => mockHookSystem,
|
||||
getApprovalMode: () => ApprovalMode.YOLO,
|
||||
});
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
onAllToolCallsComplete,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
});
|
||||
|
||||
const request = {
|
||||
callId: '1',
|
||||
name: 'mockTool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-1',
|
||||
};
|
||||
|
||||
await scheduler.schedule([request], new AbortController().signal);
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
expect(completedCalls[0].status).toBe('error');
|
||||
const erroredCall = completedCalls[0] as ErroredToolCall;
|
||||
expect(erroredCall.response.error?.message).toContain(
|
||||
'Hook blocked execution',
|
||||
);
|
||||
expect(executeFn).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should update tool input if BeforeTool hook provides modified input', async () => {
|
||||
const executeFn = vi.fn().mockResolvedValue({
|
||||
llmContent: 'Tool executed',
|
||||
returnDisplay: 'Tool executed',
|
||||
});
|
||||
const mockTool = new MockTool({ name: 'mockTool', execute: executeFn });
|
||||
|
||||
const toolRegistry = {
|
||||
getTool: () => mockTool,
|
||||
getToolByName: () => mockTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {},
|
||||
registerTool: () => {},
|
||||
getToolByDisplayName: () => mockTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const mockMessageBus = createMockMessageBus();
|
||||
const mockBeforeOutput = new BeforeToolHookOutput({
|
||||
continue: true,
|
||||
hookSpecificOutput: {
|
||||
hookEventName: 'BeforeTool',
|
||||
tool_input: { newParam: 'modifiedValue' },
|
||||
},
|
||||
});
|
||||
|
||||
const mockHookSystem = {
|
||||
fireBeforeToolEvent: vi.fn().mockResolvedValue(mockBeforeOutput),
|
||||
fireAfterToolEvent: vi.fn(),
|
||||
} as unknown as HookSystem;
|
||||
|
||||
const mockConfig = createMockConfig({
|
||||
getToolRegistry: () => toolRegistry,
|
||||
getMessageBus: () => mockMessageBus,
|
||||
getHookSystem: () => mockHookSystem,
|
||||
getApprovalMode: () => ApprovalMode.YOLO,
|
||||
});
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
onAllToolCallsComplete,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
});
|
||||
|
||||
const request = {
|
||||
callId: '1',
|
||||
name: 'mockTool',
|
||||
args: { originalParam: 'originalValue' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-1',
|
||||
};
|
||||
|
||||
await scheduler.schedule([request], new AbortController().signal);
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
expect(completedCalls[0].status).toBe('success');
|
||||
|
||||
// Verify execute was called with modified args
|
||||
expect(executeFn).toHaveBeenCalledWith(
|
||||
{ newParam: 'modifiedValue' },
|
||||
expect.anything(),
|
||||
undefined,
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
// Verify call request args were updated in the completion report
|
||||
expect(completedCalls[0].request.args).toEqual({
|
||||
newParam: 'modifiedValue',
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
import type {
|
||||
ToolCallConfirmationDetails,
|
||||
ToolResult,
|
||||
ForcedToolDecision,
|
||||
} from '../tools/tools.js';
|
||||
import { getResponseText } from '../utils/partUtils.js';
|
||||
import { reportError } from '../utils/errorReporting.js';
|
||||
@@ -46,6 +47,7 @@ export interface ServerTool {
|
||||
shouldConfirmExecute(
|
||||
params: Record<string, unknown>,
|
||||
abortSignal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false>;
|
||||
}
|
||||
|
||||
|
||||
@@ -125,6 +125,7 @@ export class HookAggregator {
|
||||
const additionalContexts: string[] = [];
|
||||
|
||||
let hasBlockDecision = false;
|
||||
let hasAskDecision = false;
|
||||
let hasContinueFalse = false;
|
||||
|
||||
for (const output of outputs) {
|
||||
@@ -142,6 +143,12 @@ export class HookAggregator {
|
||||
if (tempOutput.isBlockingDecision()) {
|
||||
hasBlockDecision = true;
|
||||
merged.decision = output.decision;
|
||||
} else if (tempOutput.isAskDecision()) {
|
||||
hasAskDecision = true;
|
||||
// Ask decision is only set if no blocking decision was found so far
|
||||
if (!hasBlockDecision) {
|
||||
merged.decision = output.decision;
|
||||
}
|
||||
}
|
||||
|
||||
// Collect messages
|
||||
@@ -180,8 +187,8 @@ export class HookAggregator {
|
||||
this.extractAdditionalContext(output, additionalContexts);
|
||||
}
|
||||
|
||||
// Set final decision if no blocking decision was found
|
||||
if (!hasBlockDecision && !hasContinueFalse) {
|
||||
// Set final decision if no blocking or ask decision was found
|
||||
if (!hasBlockDecision && !hasAskDecision && !hasContinueFalse) {
|
||||
merged.decision = 'allow';
|
||||
}
|
||||
|
||||
|
||||
@@ -197,12 +197,19 @@ export class DefaultHookOutput implements HookOutput {
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if this output represents a blocking decision
|
||||
* Check if this output represents a blocking decision (block or deny)
|
||||
*/
|
||||
isBlockingDecision(): boolean {
|
||||
return this.decision === 'block' || this.decision === 'deny';
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if this output represents an 'ask' decision
|
||||
*/
|
||||
isAskDecision(): boolean {
|
||||
return this.decision === 'ask';
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if this output requests to stop execution
|
||||
*/
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
ToolConfirmationOutcome,
|
||||
type ToolConfirmationPayload,
|
||||
type ToolCallConfirmationDetails,
|
||||
type ForcedToolDecision,
|
||||
} from '../tools/tools.js';
|
||||
import {
|
||||
type ValidatingToolCall,
|
||||
@@ -116,6 +117,8 @@ export async function resolveConfirmation(
|
||||
getPreferredEditor: () => EditorType | undefined;
|
||||
schedulerId: string;
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void;
|
||||
systemMessage?: string;
|
||||
forcedDecision?: ForcedToolDecision;
|
||||
},
|
||||
): Promise<ResolutionResult> {
|
||||
const { state, onWaitingForConfirmation } = deps;
|
||||
@@ -126,7 +129,7 @@ export async function resolveConfirmation(
|
||||
// Loop exists to allow the user to modify the parameters and see the new
|
||||
// diff.
|
||||
while (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
|
||||
if (signal.aborted) throw new Error('Operation cancelled');
|
||||
if (signal.aborted) throw new Error('Operation cancelled by user');
|
||||
|
||||
const currentCall = state.getToolCall(callId);
|
||||
if (!currentCall || !('invocation' in currentCall)) {
|
||||
@@ -134,12 +137,19 @@ export async function resolveConfirmation(
|
||||
}
|
||||
const currentInvocation = currentCall.invocation;
|
||||
|
||||
const details = await currentInvocation.shouldConfirmExecute(signal);
|
||||
const details = await currentInvocation.shouldConfirmExecute(
|
||||
signal,
|
||||
deps.forcedDecision,
|
||||
);
|
||||
if (!details) {
|
||||
outcome = ToolConfirmationOutcome.ProceedOnce;
|
||||
break;
|
||||
}
|
||||
|
||||
if (deps.systemMessage) {
|
||||
details.systemMessage = deps.systemMessage;
|
||||
}
|
||||
|
||||
await notifyHooks(deps, details);
|
||||
|
||||
const correlationId = randomUUID();
|
||||
|
||||
@@ -572,6 +572,7 @@ describe('Plan Mode Denial Consistency', () => {
|
||||
toolRegistry: mockToolRegistry,
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
|
||||
getHookSystem: vi.fn().mockReturnValue(undefined),
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(false),
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN), // Key: Plan Mode
|
||||
|
||||
@@ -170,6 +170,8 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
toolRegistry: mockToolRegistry,
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
getHookSystem: vi.fn().mockReturnValue(undefined),
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(true),
|
||||
setApprovalMode: vi.fn(),
|
||||
@@ -1316,6 +1318,7 @@ describe('Scheduler MCP Progress', () => {
|
||||
mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
getHookSystem: vi.fn().mockReturnValue(undefined),
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(true),
|
||||
setApprovalMode: vi.fn(),
|
||||
|
||||
@@ -25,6 +25,8 @@ import {
|
||||
type ScheduledToolCall,
|
||||
} from './types.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import { extractMcpContext } from '../core/coreToolHookTriggers.js';
|
||||
import { BeforeToolHookOutput } from '../hooks/types.js';
|
||||
import { PolicyDecision, type ApprovalMode } from '../policy/types.js';
|
||||
import {
|
||||
ToolConfirmationOutcome,
|
||||
@@ -562,8 +564,95 @@ export class Scheduler {
|
||||
): Promise<void> {
|
||||
const callId = toolCall.request.callId;
|
||||
|
||||
let hookDecision: 'ask' | 'block' | undefined;
|
||||
let hookSystemMessage: string | undefined;
|
||||
|
||||
const hookSystem = this.config.getHookSystem();
|
||||
if (hookSystem) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const toolInput = (toolCall.invocation.params || {}) as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const mcpContext = extractMcpContext(toolCall.invocation, this.config);
|
||||
|
||||
const beforeOutput = await hookSystem.fireBeforeToolEvent(
|
||||
toolCall.request.name,
|
||||
toolInput,
|
||||
mcpContext,
|
||||
toolCall.request.originalRequestName,
|
||||
);
|
||||
|
||||
if (beforeOutput) {
|
||||
if (beforeOutput.shouldStopExecution()) {
|
||||
this.state.updateStatus(
|
||||
callId,
|
||||
CoreToolCallStatus.Error,
|
||||
createErrorResponse(
|
||||
toolCall.request,
|
||||
new Error(
|
||||
`Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`,
|
||||
),
|
||||
ToolErrorType.STOP_EXECUTION,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const blockingError = beforeOutput.getBlockingError();
|
||||
if (blockingError?.blocked) {
|
||||
this.state.updateStatus(
|
||||
callId,
|
||||
CoreToolCallStatus.Error,
|
||||
createErrorResponse(
|
||||
toolCall.request,
|
||||
new Error(`Tool execution blocked: ${blockingError.reason}`),
|
||||
ToolErrorType.POLICY_VIOLATION,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (beforeOutput.isAskDecision()) {
|
||||
hookDecision = 'ask';
|
||||
hookSystemMessage = beforeOutput.systemMessage;
|
||||
}
|
||||
|
||||
if (beforeOutput instanceof BeforeToolHookOutput) {
|
||||
const modifiedInput = beforeOutput.getModifiedToolInput();
|
||||
if (modifiedInput) {
|
||||
toolCall.request.args = modifiedInput;
|
||||
toolCall.request.inputModifiedByHook = true;
|
||||
try {
|
||||
toolCall.invocation = toolCall.tool.build(modifiedInput);
|
||||
} catch (error) {
|
||||
this.state.updateStatus(
|
||||
callId,
|
||||
CoreToolCallStatus.Error,
|
||||
createErrorResponse(
|
||||
toolCall.request,
|
||||
new Error(
|
||||
`Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`,
|
||||
),
|
||||
ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Policy & Security
|
||||
const { decision, rule } = await checkPolicy(toolCall, this.config);
|
||||
const { decision: policyDecision, rule } = await checkPolicy(
|
||||
toolCall,
|
||||
this.config,
|
||||
);
|
||||
let decision = policyDecision;
|
||||
if (hookDecision === 'ask') {
|
||||
decision = PolicyDecision.ASK_USER;
|
||||
}
|
||||
|
||||
if (decision === PolicyDecision.DENY) {
|
||||
const { errorMessage, errorType } = getPolicyDenialError(
|
||||
@@ -596,6 +685,8 @@ export class Scheduler {
|
||||
getPreferredEditor: this.getPreferredEditor,
|
||||
schedulerId: this.schedulerId,
|
||||
onWaitingForConfirmation: this.onWaitingForConfirmation,
|
||||
systemMessage: hookSystemMessage,
|
||||
forcedDecision: hookDecision === 'ask' ? 'ask_user' : undefined,
|
||||
});
|
||||
outcome = result.outcome;
|
||||
lastDetails = result.lastDetails;
|
||||
|
||||
@@ -212,6 +212,8 @@ describe('Scheduler Parallel Execution', () => {
|
||||
mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
toolRegistry: mockToolRegistry,
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
getHookSystem: vi.fn().mockReturnValue(undefined),
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(true),
|
||||
setApprovalMode: vi.fn(),
|
||||
|
||||
@@ -133,6 +133,20 @@ export class ToolExecutor {
|
||||
|
||||
const toolResult: ToolResult = await promise;
|
||||
|
||||
if (call.request.inputModifiedByHook) {
|
||||
const modificationMsg = `\n\n[System] Tool input parameters were modified by a hook before execution.`;
|
||||
if (typeof toolResult.llmContent === 'string') {
|
||||
toolResult.llmContent += modificationMsg;
|
||||
} else if (Array.isArray(toolResult.llmContent)) {
|
||||
toolResult.llmContent.push({ text: modificationMsg });
|
||||
} else if (toolResult.llmContent) {
|
||||
toolResult.llmContent = [
|
||||
toolResult.llmContent,
|
||||
{ text: modificationMsg },
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
if (signal.aborted) {
|
||||
completedToolCall = await this.createCancelledResult(
|
||||
call,
|
||||
|
||||
@@ -47,6 +47,8 @@ export interface ToolCallRequestInfo {
|
||||
traceId?: string;
|
||||
parentCallId?: string;
|
||||
schedulerId?: string;
|
||||
inputModifiedByHook?: boolean;
|
||||
forcedAsk?: boolean;
|
||||
}
|
||||
|
||||
export interface ToolCallResponseInfo {
|
||||
|
||||
@@ -112,7 +112,7 @@ describe('conseca-logger', () => {
|
||||
'user prompt',
|
||||
'policy',
|
||||
'tool call',
|
||||
'ALLOW',
|
||||
'allow',
|
||||
'rationale',
|
||||
);
|
||||
|
||||
@@ -122,7 +122,7 @@ describe('conseca-logger', () => {
|
||||
expect(logs.getLogger).toHaveBeenCalled();
|
||||
expect(mockLogger.emit).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
body: 'Conseca Verdict: ALLOW.',
|
||||
body: 'Conseca Verdict: allow.',
|
||||
attributes: expect.objectContaining({
|
||||
'event.name': EVENT_CONSECA_VERDICT,
|
||||
}),
|
||||
|
||||
@@ -12,12 +12,15 @@ import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
type ForcedToolDecision,
|
||||
type ToolCallConfirmationDetails,
|
||||
type ToolInvocation,
|
||||
type ToolLiveOutput,
|
||||
type ToolResult,
|
||||
} from '../tools/tools.js';
|
||||
import { createMockMessageBus } from './mock-message-bus.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import type { ShellExecutionConfig } from 'src/services/shellExecutionService.js';
|
||||
|
||||
interface MockToolOptions {
|
||||
name: string;
|
||||
@@ -28,11 +31,13 @@ interface MockToolOptions {
|
||||
shouldConfirmExecute?: (
|
||||
params: { [key: string]: unknown },
|
||||
signal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
) => Promise<ToolCallConfirmationDetails | false>;
|
||||
execute?: (
|
||||
params: { [key: string]: unknown },
|
||||
signal?: AbortSignal,
|
||||
updateOutput?: (output: string) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
) => Promise<ToolResult>;
|
||||
params?: object;
|
||||
messageBus?: MessageBus;
|
||||
@@ -52,19 +57,26 @@ class MockToolInvocation extends BaseToolInvocation<
|
||||
|
||||
execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string) => void,
|
||||
updateOutput?: (output: ToolLiveOutput) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
): Promise<ToolResult> {
|
||||
if (updateOutput) {
|
||||
return this.tool.execute(this.params, signal, updateOutput);
|
||||
} else {
|
||||
return this.tool.execute(this.params);
|
||||
}
|
||||
return this.tool.execute(
|
||||
this.params,
|
||||
signal,
|
||||
updateOutput as ((output: string) => void) | undefined,
|
||||
shellExecutionConfig,
|
||||
);
|
||||
}
|
||||
|
||||
override shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
return this.tool.shouldConfirmExecute(this.params, abortSignal);
|
||||
return this.tool.shouldConfirmExecute(
|
||||
this.params,
|
||||
abortSignal,
|
||||
forcedDecision,
|
||||
);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
@@ -79,14 +91,17 @@ export class MockTool extends BaseDeclarativeTool<
|
||||
{ [key: string]: unknown },
|
||||
ToolResult
|
||||
> {
|
||||
shouldConfirmExecute: (
|
||||
readonly shouldConfirmExecute: (
|
||||
params: { [key: string]: unknown },
|
||||
signal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
) => Promise<ToolCallConfirmationDetails | false>;
|
||||
execute: (
|
||||
|
||||
readonly execute: (
|
||||
params: { [key: string]: unknown },
|
||||
signal?: AbortSignal,
|
||||
updateOutput?: (output: string) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
) => Promise<ToolResult>;
|
||||
|
||||
constructor(options: MockToolOptions) {
|
||||
@@ -162,6 +177,7 @@ export class MockModifiableToolInvocation extends BaseToolInvocation<
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
_abortSignal: AbortSignal,
|
||||
_forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.tool.shouldConfirm) {
|
||||
return {
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
type ForcedToolDecision,
|
||||
type ToolResult,
|
||||
Kind,
|
||||
type ToolAskUserConfirmationDetails,
|
||||
@@ -126,6 +127,7 @@ export class AskUserInvocation extends BaseToolInvocation<
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
_abortSignal: AbortSignal,
|
||||
_forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolAskUserConfirmationDetails | false> {
|
||||
const normalizedQuestions = this.params.questions.map((q) => ({
|
||||
...q,
|
||||
|
||||
@@ -163,7 +163,7 @@ describe('Tool Confirmation Policy Updates', () => {
|
||||
|
||||
// Mock getMessageBusDecision to trigger ASK_USER flow
|
||||
vi.spyOn(invocation as any, 'getMessageBusDecision').mockResolvedValue(
|
||||
'ASK_USER',
|
||||
'ask_user',
|
||||
);
|
||||
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
type ForcedToolDecision,
|
||||
type ToolCallConfirmationDetails,
|
||||
type ToolConfirmationOutcome,
|
||||
type ToolEditConfirmationDetails,
|
||||
@@ -705,8 +706,12 @@ class EditToolInvocation
|
||||
*/
|
||||
protected override async getConfirmationDetails(
|
||||
abortSignal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
if (
|
||||
this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT &&
|
||||
forcedDecision !== 'ask_user'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ describe('EnterPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('ASK_USER');
|
||||
).mockResolvedValue('ask_user');
|
||||
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
@@ -74,7 +74,7 @@ describe('EnterPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('ALLOW');
|
||||
).mockResolvedValue('allow');
|
||||
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
@@ -92,7 +92,7 @@ describe('EnterPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('DENY');
|
||||
).mockResolvedValue('deny');
|
||||
|
||||
await expect(
|
||||
invocation.shouldConfirmExecute(new AbortController().signal),
|
||||
@@ -136,7 +136,7 @@ describe('EnterPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('ASK_USER');
|
||||
).mockResolvedValue('ask_user');
|
||||
|
||||
const details = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
type ForcedToolDecision,
|
||||
type ToolResult,
|
||||
Kind,
|
||||
type ToolInfoConfirmationDetails,
|
||||
@@ -85,13 +86,15 @@ export class EnterPlanModeInvocation extends BaseToolInvocation<
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolInfoConfirmationDetails | false> {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
const decision =
|
||||
forcedDecision ?? (await this.getMessageBusDecision(abortSignal));
|
||||
if (decision === 'allow') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decision === 'DENY') {
|
||||
if (decision === 'deny') {
|
||||
throw new Error(
|
||||
`Tool execution for "${
|
||||
this._toolDisplayName || this._toolName
|
||||
@@ -99,7 +102,7 @@ export class EnterPlanModeInvocation extends BaseToolInvocation<
|
||||
);
|
||||
}
|
||||
|
||||
// ASK_USER
|
||||
// ask_user
|
||||
return {
|
||||
type: 'info',
|
||||
title: 'Enter Plan Mode',
|
||||
|
||||
@@ -58,7 +58,7 @@ describe('ExitPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('ASK_USER');
|
||||
).mockResolvedValue('ask_user');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -126,7 +126,7 @@ describe('ExitPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('ALLOW');
|
||||
).mockResolvedValue('allow');
|
||||
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
@@ -149,7 +149,7 @@ describe('ExitPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('DENY');
|
||||
).mockResolvedValue('deny');
|
||||
|
||||
await expect(
|
||||
invocation.shouldConfirmExecute(new AbortController().signal),
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
type ForcedToolDecision,
|
||||
type ToolResult,
|
||||
Kind,
|
||||
type ToolExitPlanModeConfirmationDetails,
|
||||
@@ -118,6 +119,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation<
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolExitPlanModeConfirmationDetails | false> {
|
||||
const resolvedPlanPath = this.getResolvedPlanPath();
|
||||
|
||||
@@ -137,8 +139,9 @@ export class ExitPlanModeInvocation extends BaseToolInvocation<
|
||||
return false;
|
||||
}
|
||||
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'DENY') {
|
||||
const decision =
|
||||
forcedDecision ?? (await this.getMessageBusDecision(abortSignal));
|
||||
if (decision === 'deny') {
|
||||
throw new Error(
|
||||
`Tool execution for "${
|
||||
this._toolDisplayName || this._toolName
|
||||
@@ -146,7 +149,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation<
|
||||
);
|
||||
}
|
||||
|
||||
if (decision === 'ALLOW') {
|
||||
if (decision === 'allow') {
|
||||
// If policy is allow, auto-approve with default settings and execute.
|
||||
this.confirmationOutcome = ToolConfirmationOutcome.ProceedOnce;
|
||||
this.approvalPayload = {
|
||||
@@ -156,7 +159,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation<
|
||||
return false;
|
||||
}
|
||||
|
||||
// decision is 'ASK_USER'
|
||||
// decision is 'ask_user'
|
||||
return {
|
||||
type: 'exit_plan_mode',
|
||||
title: 'Plan Approval',
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
Kind,
|
||||
type ToolInvocation,
|
||||
type ToolResult,
|
||||
type ForcedToolDecision,
|
||||
type ToolCallConfirmationDetails,
|
||||
} from './tools.js';
|
||||
import { GET_INTERNAL_DOCS_TOOL_NAME } from './tool-names.js';
|
||||
@@ -85,6 +86,7 @@ class GetInternalDocsInvocation extends BaseToolInvocation<
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
_abortSignal: AbortSignal,
|
||||
_forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
ToolConfirmationOutcome,
|
||||
type ForcedToolDecision,
|
||||
type ToolCallConfirmationDetails,
|
||||
type ToolInvocation,
|
||||
type ToolMcpConfirmationDetails,
|
||||
@@ -192,6 +193,7 @@ export class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
_forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
const serverAllowListKey = this.serverName;
|
||||
const toolAllowListKey = `${this.serverName}.${this.serverToolName}`;
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
ToolConfirmationOutcome,
|
||||
type ForcedToolDecision,
|
||||
type ToolEditConfirmationDetails,
|
||||
type ToolResult,
|
||||
} from './tools.js';
|
||||
@@ -163,6 +164,7 @@ class MemoryToolInvocation extends BaseToolInvocation<
|
||||
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
_forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolEditConfirmationDetails | false> {
|
||||
const memoryFilePath = getGlobalMemoryFilePath();
|
||||
const allowlistKey = memoryFilePath;
|
||||
|
||||
@@ -57,10 +57,10 @@ class TestToolInvocation extends BaseToolInvocation<TestParams, TestResult> {
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<false> {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
if (decision === 'allow') {
|
||||
return false;
|
||||
}
|
||||
if (decision === 'DENY') {
|
||||
if (decision === 'deny') {
|
||||
throw new Error('Tool execution denied by policy');
|
||||
}
|
||||
return false;
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
BaseToolInvocation,
|
||||
ToolConfirmationOutcome,
|
||||
Kind,
|
||||
type ForcedToolDecision,
|
||||
type ToolInvocation,
|
||||
type ToolResult,
|
||||
type ToolCallConfirmationDetails,
|
||||
@@ -109,6 +110,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
_forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
const command = stripShellWrapper(this.params.command);
|
||||
|
||||
|
||||
@@ -21,6 +21,11 @@ import {
|
||||
import { type ApprovalMode } from '../policy/types.js';
|
||||
import type { SubagentProgress } from '../agents/types.js';
|
||||
|
||||
/**
|
||||
* Supported decisions for forcing tool execution behavior.
|
||||
*/
|
||||
export type ForcedToolDecision = 'allow' | 'deny' | 'ask_user';
|
||||
|
||||
/**
|
||||
* Represents a validated and ready-to-execute tool call.
|
||||
* An instance of this is created by a `ToolBuilder`.
|
||||
@@ -53,9 +58,10 @@ export interface ToolInvocation<
|
||||
* @param abortSignal An AbortSignal that can be used to cancel the confirmation request.
|
||||
* @returns A ToolCallConfirmationDetails object if confirmation is required, or false if not.
|
||||
*/
|
||||
shouldConfirmExecute(
|
||||
shouldConfirmExecute: (
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false>;
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
) => Promise<ToolCallConfirmationDetails | false>;
|
||||
|
||||
/**
|
||||
* Executes the tool with the validated parameters.
|
||||
@@ -103,13 +109,15 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
const decision =
|
||||
forcedDecision ?? (await this.getMessageBusDecision(abortSignal));
|
||||
if (decision === 'allow') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decision === 'DENY') {
|
||||
if (decision === 'deny') {
|
||||
throw new Error(
|
||||
`Tool execution for "${
|
||||
this._toolDisplayName || this._toolName
|
||||
@@ -117,12 +125,12 @@ export abstract class BaseToolInvocation<
|
||||
);
|
||||
}
|
||||
|
||||
if (decision === 'ASK_USER') {
|
||||
return this.getConfirmationDetails(abortSignal);
|
||||
if (decision === 'ask_user') {
|
||||
return this.getConfirmationDetails(abortSignal, forcedDecision);
|
||||
}
|
||||
|
||||
// Default to confirmation details if decision is unknown (should not happen with exhaustive policy)
|
||||
return this.getConfirmationDetails(abortSignal);
|
||||
return this.getConfirmationDetails(abortSignal, forcedDecision);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -161,11 +169,12 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
/**
|
||||
* Subclasses should override this method to provide custom confirmation UI
|
||||
* when the policy engine's decision is 'ASK_USER'.
|
||||
* when the policy engine's decision is 'ask_user'.
|
||||
* The base implementation provides a generic confirmation prompt.
|
||||
*/
|
||||
protected async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
_forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (!this.messageBus) {
|
||||
return false;
|
||||
@@ -184,11 +193,11 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
protected getMessageBusDecision(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<'ALLOW' | 'DENY' | 'ASK_USER'> {
|
||||
): Promise<ForcedToolDecision> {
|
||||
if (!this.messageBus || !this._toolName) {
|
||||
// If there's no message bus, we can't make a decision, so we allow.
|
||||
// The legacy confirmation flow will still apply if the tool needs it.
|
||||
return Promise.resolve('ALLOW');
|
||||
return Promise.resolve('allow');
|
||||
}
|
||||
|
||||
const correlationId = randomUUID();
|
||||
@@ -204,9 +213,9 @@ export abstract class BaseToolInvocation<
|
||||
toolAnnotations: this._toolAnnotations,
|
||||
};
|
||||
|
||||
return new Promise<'ALLOW' | 'DENY' | 'ASK_USER'>((resolve) => {
|
||||
return new Promise<ForcedToolDecision>((resolve) => {
|
||||
if (!this.messageBus) {
|
||||
resolve('ALLOW');
|
||||
resolve('allow');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -227,11 +236,11 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
const abortHandler = () => {
|
||||
cleanup();
|
||||
resolve('DENY');
|
||||
resolve('deny');
|
||||
};
|
||||
|
||||
if (abortSignal.aborted) {
|
||||
resolve('DENY');
|
||||
resolve('deny');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -239,11 +248,11 @@ export abstract class BaseToolInvocation<
|
||||
if (response.correlationId === correlationId) {
|
||||
cleanup();
|
||||
if (response.requiresUserConfirmation) {
|
||||
resolve('ASK_USER');
|
||||
resolve('ask_user');
|
||||
} else if (response.confirmed) {
|
||||
resolve('ALLOW');
|
||||
resolve('allow');
|
||||
} else {
|
||||
resolve('DENY');
|
||||
resolve('deny');
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -252,7 +261,7 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
timeoutId = setTimeout(() => {
|
||||
cleanup();
|
||||
resolve('ASK_USER'); // Default to ASK_USER on timeout
|
||||
resolve('ask_user'); // Default to ask_user on timeout
|
||||
}, 30000);
|
||||
|
||||
this.messageBus.subscribe(
|
||||
@@ -270,7 +279,7 @@ export abstract class BaseToolInvocation<
|
||||
void this.messageBus.publish(request);
|
||||
} catch (_error) {
|
||||
cleanup();
|
||||
resolve('ALLOW');
|
||||
resolve('allow');
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -729,6 +738,7 @@ export interface DiffStat {
|
||||
export interface ToolEditConfirmationDetails {
|
||||
type: 'edit';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
payload?: ToolConfirmationPayload,
|
||||
@@ -767,6 +777,7 @@ export type ToolConfirmationPayload =
|
||||
export interface ToolExecuteConfirmationDetails {
|
||||
type: 'exec';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||
command: string;
|
||||
rootCommand: string;
|
||||
@@ -777,6 +788,7 @@ export interface ToolExecuteConfirmationDetails {
|
||||
export interface ToolMcpConfirmationDetails {
|
||||
type: 'mcp';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
serverName: string;
|
||||
toolName: string;
|
||||
toolDisplayName: string;
|
||||
@@ -789,6 +801,7 @@ export interface ToolMcpConfirmationDetails {
|
||||
export interface ToolInfoConfirmationDetails {
|
||||
type: 'info';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||
prompt: string;
|
||||
urls?: string[];
|
||||
@@ -797,6 +810,7 @@ export interface ToolInfoConfirmationDetails {
|
||||
export interface ToolAskUserConfirmationDetails {
|
||||
type: 'ask_user';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
questions: Question[];
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
@@ -807,6 +821,7 @@ export interface ToolAskUserConfirmationDetails {
|
||||
export interface ToolExitPlanModeConfirmationDetails {
|
||||
type: 'exit_plan_mode';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
planPath: string;
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
type ForcedToolDecision,
|
||||
type ToolCallConfirmationDetails,
|
||||
type ToolInvocation,
|
||||
type ToolResult,
|
||||
@@ -293,6 +294,7 @@ ${textContent}
|
||||
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
_forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// Check for AUTO_EDIT approval mode. This tool has a specific behavior
|
||||
// where ProceedAlways switches the entire session to AUTO_EDIT.
|
||||
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
type ForcedToolDecision,
|
||||
type FileDiff,
|
||||
type ToolCallConfirmationDetails,
|
||||
type ToolEditConfirmationDetails,
|
||||
@@ -174,8 +175,12 @@ class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
|
||||
protected override async getConfirmationDetails(
|
||||
abortSignal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
if (
|
||||
this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT &&
|
||||
forcedDecision !== 'ask_user'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user