mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 21:03:05 -07:00
Migrate beforeTool and afterTool hooks to hookSystem (#17204)
Co-authored-by: Ishaan Gupta <ishaankone@gmail.com> Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
@@ -13,10 +13,12 @@ import {
|
|||||||
type AnyDeclarativeTool,
|
type AnyDeclarativeTool,
|
||||||
} from '../tools/tools.js';
|
} from '../tools/tools.js';
|
||||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||||
|
import type { HookSystem } from '../hooks/hookSystem.js';
|
||||||
|
import type { Config } from '../config/config.js';
|
||||||
import {
|
import {
|
||||||
MessageBusType,
|
type DefaultHookOutput,
|
||||||
type HookExecutionResponse,
|
BeforeToolHookOutput,
|
||||||
} from '../confirmation-bus/types.js';
|
} from '../hooks/types.js';
|
||||||
|
|
||||||
class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> {
|
class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> {
|
||||||
constructor(params: { key?: string }, messageBus: MessageBus) {
|
constructor(params: { key?: string }, messageBus: MessageBus) {
|
||||||
@@ -38,6 +40,8 @@ class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> {
|
|||||||
describe('executeToolWithHooks', () => {
|
describe('executeToolWithHooks', () => {
|
||||||
let messageBus: MessageBus;
|
let messageBus: MessageBus;
|
||||||
let mockTool: AnyDeclarativeTool;
|
let mockTool: AnyDeclarativeTool;
|
||||||
|
let mockHookSystem: HookSystem;
|
||||||
|
let mockConfig: Config;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
messageBus = {
|
messageBus = {
|
||||||
@@ -46,6 +50,15 @@ describe('executeToolWithHooks', () => {
|
|||||||
subscribe: vi.fn(),
|
subscribe: vi.fn(),
|
||||||
unsubscribe: vi.fn(),
|
unsubscribe: vi.fn(),
|
||||||
} as unknown as MessageBus;
|
} as unknown as MessageBus;
|
||||||
|
mockHookSystem = {
|
||||||
|
fireBeforeToolEvent: vi.fn(),
|
||||||
|
fireAfterToolEvent: vi.fn(),
|
||||||
|
} as unknown as HookSystem;
|
||||||
|
mockConfig = {
|
||||||
|
getHookSystem: vi.fn().mockReturnValue(mockHookSystem),
|
||||||
|
getMcpClientManager: vi.fn().mockReturnValue(undefined),
|
||||||
|
getMcpServers: vi.fn().mockReturnValue({}),
|
||||||
|
} as unknown as Config;
|
||||||
mockTool = {
|
mockTool = {
|
||||||
build: vi
|
build: vi
|
||||||
.fn()
|
.fn()
|
||||||
@@ -57,25 +70,24 @@ describe('executeToolWithHooks', () => {
|
|||||||
const invocation = new MockInvocation({}, messageBus);
|
const invocation = new MockInvocation({}, messageBus);
|
||||||
const abortSignal = new AbortController().signal;
|
const abortSignal = new AbortController().signal;
|
||||||
|
|
||||||
vi.mocked(messageBus.request).mockResolvedValue({
|
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
|
||||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
shouldStopExecution: () => true,
|
||||||
correlationId: 'test-id',
|
getEffectiveReason: () => 'Stop immediately',
|
||||||
success: true,
|
getBlockingError: () => ({
|
||||||
output: {
|
blocked: false,
|
||||||
continue: false,
|
|
||||||
stopReason: 'Stop immediately',
|
|
||||||
decision: 'block',
|
|
||||||
reason: 'Should be ignored because continue is false',
|
reason: 'Should be ignored because continue is false',
|
||||||
},
|
}),
|
||||||
} as HookExecutionResponse);
|
} as unknown as DefaultHookOutput);
|
||||||
|
|
||||||
const result = await executeToolWithHooks(
|
const result = await executeToolWithHooks(
|
||||||
invocation,
|
invocation,
|
||||||
'test_tool',
|
'test_tool',
|
||||||
abortSignal,
|
abortSignal,
|
||||||
messageBus,
|
|
||||||
true,
|
|
||||||
mockTool,
|
mockTool,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
mockConfig,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
|
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
|
||||||
@@ -86,23 +98,21 @@ describe('executeToolWithHooks', () => {
|
|||||||
const invocation = new MockInvocation({}, messageBus);
|
const invocation = new MockInvocation({}, messageBus);
|
||||||
const abortSignal = new AbortController().signal;
|
const abortSignal = new AbortController().signal;
|
||||||
|
|
||||||
vi.mocked(messageBus.request).mockResolvedValue({
|
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
|
||||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
shouldStopExecution: () => false,
|
||||||
correlationId: 'test-id',
|
getEffectiveReason: () => '',
|
||||||
success: true,
|
getBlockingError: () => ({ blocked: true, reason: 'Execution blocked' }),
|
||||||
output: {
|
} as unknown as DefaultHookOutput);
|
||||||
decision: 'block',
|
|
||||||
reason: 'Execution blocked',
|
|
||||||
},
|
|
||||||
} as HookExecutionResponse);
|
|
||||||
|
|
||||||
const result = await executeToolWithHooks(
|
const result = await executeToolWithHooks(
|
||||||
invocation,
|
invocation,
|
||||||
'test_tool',
|
'test_tool',
|
||||||
abortSignal,
|
abortSignal,
|
||||||
messageBus,
|
|
||||||
true,
|
|
||||||
mockTool,
|
mockTool,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
mockConfig,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
|
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
|
||||||
@@ -114,32 +124,27 @@ describe('executeToolWithHooks', () => {
|
|||||||
const abortSignal = new AbortController().signal;
|
const abortSignal = new AbortController().signal;
|
||||||
const spy = vi.spyOn(invocation, 'execute');
|
const spy = vi.spyOn(invocation, 'execute');
|
||||||
|
|
||||||
// BeforeTool allow
|
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
|
||||||
vi.mocked(messageBus.request)
|
shouldStopExecution: () => false,
|
||||||
.mockResolvedValueOnce({
|
getEffectiveReason: () => '',
|
||||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
getBlockingError: () => ({ blocked: false, reason: '' }),
|
||||||
correlationId: 'test-id',
|
} as unknown as DefaultHookOutput);
|
||||||
success: true,
|
|
||||||
output: { decision: 'allow' },
|
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({
|
||||||
} as HookExecutionResponse)
|
shouldStopExecution: () => true,
|
||||||
// AfterTool stop
|
getEffectiveReason: () => 'Stop after execution',
|
||||||
.mockResolvedValueOnce({
|
getBlockingError: () => ({ blocked: false, reason: '' }),
|
||||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
} as unknown as DefaultHookOutput);
|
||||||
correlationId: 'test-id',
|
|
||||||
success: true,
|
|
||||||
output: {
|
|
||||||
continue: false,
|
|
||||||
stopReason: 'Stop after execution',
|
|
||||||
},
|
|
||||||
} as HookExecutionResponse);
|
|
||||||
|
|
||||||
const result = await executeToolWithHooks(
|
const result = await executeToolWithHooks(
|
||||||
invocation,
|
invocation,
|
||||||
'test_tool',
|
'test_tool',
|
||||||
abortSignal,
|
abortSignal,
|
||||||
messageBus,
|
|
||||||
true,
|
|
||||||
mockTool,
|
mockTool,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
mockConfig,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
|
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
|
||||||
@@ -151,32 +156,27 @@ describe('executeToolWithHooks', () => {
|
|||||||
const invocation = new MockInvocation({}, messageBus);
|
const invocation = new MockInvocation({}, messageBus);
|
||||||
const abortSignal = new AbortController().signal;
|
const abortSignal = new AbortController().signal;
|
||||||
|
|
||||||
// BeforeTool allow
|
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
|
||||||
vi.mocked(messageBus.request)
|
shouldStopExecution: () => false,
|
||||||
.mockResolvedValueOnce({
|
getEffectiveReason: () => '',
|
||||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
getBlockingError: () => ({ blocked: false, reason: '' }),
|
||||||
correlationId: 'test-id',
|
} as unknown as DefaultHookOutput);
|
||||||
success: true,
|
|
||||||
output: { decision: 'allow' },
|
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({
|
||||||
} as HookExecutionResponse)
|
shouldStopExecution: () => false,
|
||||||
// AfterTool deny
|
getEffectiveReason: () => '',
|
||||||
.mockResolvedValueOnce({
|
getBlockingError: () => ({ blocked: true, reason: 'Result denied' }),
|
||||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
} as unknown as DefaultHookOutput);
|
||||||
correlationId: 'test-id',
|
|
||||||
success: true,
|
|
||||||
output: {
|
|
||||||
decision: 'deny',
|
|
||||||
reason: 'Result denied',
|
|
||||||
},
|
|
||||||
} as HookExecutionResponse);
|
|
||||||
|
|
||||||
const result = await executeToolWithHooks(
|
const result = await executeToolWithHooks(
|
||||||
invocation,
|
invocation,
|
||||||
'test_tool',
|
'test_tool',
|
||||||
abortSignal,
|
abortSignal,
|
||||||
messageBus,
|
|
||||||
true,
|
|
||||||
mockTool,
|
mockTool,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
mockConfig,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
|
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
|
||||||
@@ -189,39 +189,28 @@ describe('executeToolWithHooks', () => {
|
|||||||
const toolName = 'test-tool';
|
const toolName = 'test-tool';
|
||||||
const abortSignal = new AbortController().signal;
|
const abortSignal = new AbortController().signal;
|
||||||
|
|
||||||
// Capture arguments to verify what was passed before modification
|
const mockBeforeOutput = new BeforeToolHookOutput({
|
||||||
const requestSpy = vi.fn().mockImplementation(async (request) => {
|
continue: true,
|
||||||
if (request.eventName === 'BeforeTool') {
|
|
||||||
// Verify input is original before we return modification instruction
|
|
||||||
expect(request.input.tool_input.key).toBe('original');
|
|
||||||
return {
|
|
||||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
|
||||||
correlationId: 'test-id',
|
|
||||||
success: true,
|
|
||||||
output: {
|
|
||||||
hookSpecificOutput: {
|
hookSpecificOutput: {
|
||||||
hookEventName: 'BeforeTool',
|
hookEventName: 'BeforeTool',
|
||||||
tool_input: { key: 'modified' },
|
tool_input: { key: 'modified' },
|
||||||
},
|
},
|
||||||
},
|
|
||||||
} as HookExecutionResponse;
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
|
||||||
correlationId: 'test-id',
|
|
||||||
success: true,
|
|
||||||
output: {},
|
|
||||||
} as HookExecutionResponse;
|
|
||||||
});
|
});
|
||||||
messageBus.request = requestSpy;
|
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue(
|
||||||
|
mockBeforeOutput,
|
||||||
|
);
|
||||||
|
|
||||||
|
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined);
|
||||||
|
|
||||||
const result = await executeToolWithHooks(
|
const result = await executeToolWithHooks(
|
||||||
invocation,
|
invocation,
|
||||||
toolName,
|
toolName,
|
||||||
abortSignal,
|
abortSignal,
|
||||||
messageBus,
|
|
||||||
true, // hooksEnabled
|
|
||||||
mockTool,
|
mockTool,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
mockConfig,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Verify result reflects modified input
|
// Verify result reflects modified input
|
||||||
@@ -231,7 +220,7 @@ describe('executeToolWithHooks', () => {
|
|||||||
// Verify params object was modified in place
|
// Verify params object was modified in place
|
||||||
expect(invocation.params.key).toBe('modified');
|
expect(invocation.params.key).toBe('modified');
|
||||||
|
|
||||||
expect(requestSpy).toHaveBeenCalled();
|
expect(mockHookSystem.fireBeforeToolEvent).toHaveBeenCalled();
|
||||||
expect(mockTool.build).toHaveBeenCalledWith({ key: 'modified' });
|
expect(mockTool.build).toHaveBeenCalledWith({ key: 'modified' });
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -241,25 +230,28 @@ describe('executeToolWithHooks', () => {
|
|||||||
const toolName = 'test-tool';
|
const toolName = 'test-tool';
|
||||||
const abortSignal = new AbortController().signal;
|
const abortSignal = new AbortController().signal;
|
||||||
|
|
||||||
vi.mocked(messageBus.request).mockResolvedValue({
|
const mockBeforeOutput = new BeforeToolHookOutput({
|
||||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
continue: true,
|
||||||
correlationId: 'test-id',
|
|
||||||
success: true,
|
|
||||||
output: {
|
|
||||||
hookSpecificOutput: {
|
hookSpecificOutput: {
|
||||||
hookEventName: 'BeforeTool',
|
hookEventName: 'BeforeTool',
|
||||||
// No tool_input
|
// No tool input
|
||||||
},
|
},
|
||||||
},
|
});
|
||||||
} as HookExecutionResponse);
|
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue(
|
||||||
|
mockBeforeOutput,
|
||||||
|
);
|
||||||
|
|
||||||
|
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined);
|
||||||
|
|
||||||
const result = await executeToolWithHooks(
|
const result = await executeToolWithHooks(
|
||||||
invocation,
|
invocation,
|
||||||
toolName,
|
toolName,
|
||||||
abortSignal,
|
abortSignal,
|
||||||
messageBus,
|
|
||||||
true, // hooksEnabled
|
|
||||||
mockTool,
|
mockTool,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
undefined,
|
||||||
|
mockConfig,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(result.llmContent).toBe('key: original');
|
expect(result.llmContent).toBe('key: original');
|
||||||
|
|||||||
@@ -11,9 +11,7 @@ import {
|
|||||||
type HookExecutionResponse,
|
type HookExecutionResponse,
|
||||||
} from '../confirmation-bus/types.js';
|
} from '../confirmation-bus/types.js';
|
||||||
import {
|
import {
|
||||||
createHookOutput,
|
|
||||||
NotificationType,
|
NotificationType,
|
||||||
type DefaultHookOutput,
|
|
||||||
type McpToolContext,
|
type McpToolContext,
|
||||||
BeforeToolHookOutput,
|
BeforeToolHookOutput,
|
||||||
} from '../hooks/types.js';
|
} from '../hooks/types.js';
|
||||||
@@ -194,103 +192,12 @@ function extractMcpContext(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Fires the BeforeTool hook and returns the hook output.
|
|
||||||
*
|
|
||||||
* @param messageBus The message bus to use for hook communication
|
|
||||||
* @param toolName The name of the tool being executed
|
|
||||||
* @param toolInput The input parameters for the tool
|
|
||||||
* @param mcpContext Optional MCP context for MCP tools
|
|
||||||
* @returns The hook output, or undefined if no hook was executed or on error
|
|
||||||
*/
|
|
||||||
export async function fireBeforeToolHook(
|
|
||||||
messageBus: MessageBus,
|
|
||||||
toolName: string,
|
|
||||||
toolInput: Record<string, unknown>,
|
|
||||||
mcpContext?: McpToolContext,
|
|
||||||
): Promise<DefaultHookOutput | undefined> {
|
|
||||||
try {
|
|
||||||
const response = await messageBus.request<
|
|
||||||
HookExecutionRequest,
|
|
||||||
HookExecutionResponse
|
|
||||||
>(
|
|
||||||
{
|
|
||||||
type: MessageBusType.HOOK_EXECUTION_REQUEST,
|
|
||||||
eventName: 'BeforeTool',
|
|
||||||
input: {
|
|
||||||
tool_name: toolName,
|
|
||||||
tool_input: toolInput,
|
|
||||||
...(mcpContext && { mcp_context: mcpContext }),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
|
||||||
);
|
|
||||||
|
|
||||||
return response.output
|
|
||||||
? createHookOutput('BeforeTool', response.output)
|
|
||||||
: undefined;
|
|
||||||
} catch (error) {
|
|
||||||
debugLogger.debug(`BeforeTool hook failed for ${toolName}:`, error);
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Fires the AfterTool hook and returns the hook output.
|
|
||||||
*
|
|
||||||
* @param messageBus The message bus to use for hook communication
|
|
||||||
* @param toolName The name of the tool that was executed
|
|
||||||
* @param toolInput The input parameters for the tool
|
|
||||||
* @param toolResponse The result from the tool execution
|
|
||||||
* @param mcpContext Optional MCP context for MCP tools
|
|
||||||
* @returns The hook output, or undefined if no hook was executed or on error
|
|
||||||
*/
|
|
||||||
export async function fireAfterToolHook(
|
|
||||||
messageBus: MessageBus,
|
|
||||||
toolName: string,
|
|
||||||
toolInput: Record<string, unknown>,
|
|
||||||
toolResponse: {
|
|
||||||
llmContent: ToolResult['llmContent'];
|
|
||||||
returnDisplay: ToolResult['returnDisplay'];
|
|
||||||
error: ToolResult['error'];
|
|
||||||
},
|
|
||||||
mcpContext?: McpToolContext,
|
|
||||||
): Promise<DefaultHookOutput | undefined> {
|
|
||||||
try {
|
|
||||||
const response = await messageBus.request<
|
|
||||||
HookExecutionRequest,
|
|
||||||
HookExecutionResponse
|
|
||||||
>(
|
|
||||||
{
|
|
||||||
type: MessageBusType.HOOK_EXECUTION_REQUEST,
|
|
||||||
eventName: 'AfterTool',
|
|
||||||
input: {
|
|
||||||
tool_name: toolName,
|
|
||||||
tool_input: toolInput,
|
|
||||||
tool_response: toolResponse,
|
|
||||||
...(mcpContext && { mcp_context: mcpContext }),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
|
||||||
);
|
|
||||||
|
|
||||||
return response.output
|
|
||||||
? createHookOutput('AfterTool', response.output)
|
|
||||||
: undefined;
|
|
||||||
} catch (error) {
|
|
||||||
debugLogger.debug(`AfterTool hook failed for ${toolName}:`, error);
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Execute a tool with BeforeTool and AfterTool hooks.
|
* Execute a tool with BeforeTool and AfterTool hooks.
|
||||||
*
|
*
|
||||||
* @param invocation The tool invocation to execute
|
* @param invocation The tool invocation to execute
|
||||||
* @param toolName The name of the tool
|
* @param toolName The name of the tool
|
||||||
* @param signal Abort signal for cancellation
|
* @param signal Abort signal for cancellation
|
||||||
* @param messageBus Optional message bus for hook communication
|
|
||||||
* @param hooksEnabled Whether hooks are enabled
|
|
||||||
* @param liveOutputCallback Optional callback for live output updates
|
* @param liveOutputCallback Optional callback for live output updates
|
||||||
* @param shellExecutionConfig Optional shell execution config
|
* @param shellExecutionConfig Optional shell execution config
|
||||||
* @param setPidCallback Optional callback to set the PID for shell invocations
|
* @param setPidCallback Optional callback to set the PID for shell invocations
|
||||||
@@ -301,8 +208,6 @@ export async function executeToolWithHooks(
|
|||||||
invocation: ShellToolInvocation | AnyToolInvocation,
|
invocation: ShellToolInvocation | AnyToolInvocation,
|
||||||
toolName: string,
|
toolName: string,
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
messageBus: MessageBus | undefined,
|
|
||||||
hooksEnabled: boolean,
|
|
||||||
tool: AnyDeclarativeTool,
|
tool: AnyDeclarativeTool,
|
||||||
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
|
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
|
||||||
shellExecutionConfig?: ShellExecutionConfig,
|
shellExecutionConfig?: ShellExecutionConfig,
|
||||||
@@ -316,10 +221,9 @@ export async function executeToolWithHooks(
|
|||||||
// Extract MCP context if this is an MCP tool (only if config is provided)
|
// Extract MCP context if this is an MCP tool (only if config is provided)
|
||||||
const mcpContext = config ? extractMcpContext(invocation, config) : undefined;
|
const mcpContext = config ? extractMcpContext(invocation, config) : undefined;
|
||||||
|
|
||||||
// Fire BeforeTool hook through MessageBus (only if hooks are enabled)
|
const hookSystem = config?.getHookSystem();
|
||||||
if (hooksEnabled && messageBus) {
|
if (hookSystem) {
|
||||||
const beforeOutput = await fireBeforeToolHook(
|
const beforeOutput = await hookSystem.fireBeforeToolEvent(
|
||||||
messageBus,
|
|
||||||
toolName,
|
toolName,
|
||||||
toolInput,
|
toolInput,
|
||||||
mcpContext,
|
mcpContext,
|
||||||
@@ -419,10 +323,8 @@ export async function executeToolWithHooks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fire AfterTool hook through MessageBus (only if hooks are enabled)
|
if (hookSystem) {
|
||||||
if (hooksEnabled && messageBus) {
|
const afterOutput = await hookSystem.fireAfterToolEvent(
|
||||||
const afterOutput = await fireAfterToolHook(
|
|
||||||
messageBus,
|
|
||||||
toolName,
|
toolName,
|
||||||
toolInput,
|
toolInput,
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1889,6 +1889,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
|||||||
}) as unknown as PolicyEngine,
|
}) as unknown as PolicyEngine,
|
||||||
isInteractive: () => false,
|
isInteractive: () => false,
|
||||||
});
|
});
|
||||||
|
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
|
||||||
|
|
||||||
const scheduler = new CoreToolScheduler({
|
const scheduler = new CoreToolScheduler({
|
||||||
config: mockConfig,
|
config: mockConfig,
|
||||||
@@ -2018,6 +2019,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
|||||||
getApprovalMode: () => ApprovalMode.YOLO,
|
getApprovalMode: () => ApprovalMode.YOLO,
|
||||||
isInteractive: () => false,
|
isInteractive: () => false,
|
||||||
});
|
});
|
||||||
|
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
|
||||||
|
|
||||||
const scheduler = new CoreToolScheduler({
|
const scheduler = new CoreToolScheduler({
|
||||||
config: mockConfig,
|
config: mockConfig,
|
||||||
@@ -2105,6 +2107,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
|||||||
check: async () => ({ decision: PolicyDecision.DENY }),
|
check: async () => ({ decision: PolicyDecision.DENY }),
|
||||||
}) as unknown as PolicyEngine,
|
}) as unknown as PolicyEngine,
|
||||||
});
|
});
|
||||||
|
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
|
||||||
|
|
||||||
const scheduler = new CoreToolScheduler({
|
const scheduler = new CoreToolScheduler({
|
||||||
config: mockConfig,
|
config: mockConfig,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import type {
|
|||||||
BeforeModelHookOutput,
|
BeforeModelHookOutput,
|
||||||
AfterModelHookOutput,
|
AfterModelHookOutput,
|
||||||
BeforeToolSelectionHookOutput,
|
BeforeToolSelectionHookOutput,
|
||||||
|
McpToolContext,
|
||||||
} from './types.js';
|
} from './types.js';
|
||||||
import type { AggregatedHookResult } from './hookAggregator.js';
|
import type { AggregatedHookResult } from './hookAggregator.js';
|
||||||
import type {
|
import type {
|
||||||
@@ -297,4 +298,46 @@ export class HookSystem {
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fireBeforeToolEvent(
|
||||||
|
toolName: string,
|
||||||
|
toolInput: Record<string, unknown>,
|
||||||
|
mcpContext?: McpToolContext,
|
||||||
|
): Promise<DefaultHookOutput | undefined> {
|
||||||
|
try {
|
||||||
|
const result = await this.hookEventHandler.fireBeforeToolEvent(
|
||||||
|
toolName,
|
||||||
|
toolInput,
|
||||||
|
mcpContext,
|
||||||
|
);
|
||||||
|
return result.finalOutput;
|
||||||
|
} catch (error) {
|
||||||
|
debugLogger.debug(`BeforeTool hook failed for ${toolName}:`, error);
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fireAfterToolEvent(
|
||||||
|
toolName: string,
|
||||||
|
toolInput: Record<string, unknown>,
|
||||||
|
toolResponse: {
|
||||||
|
llmContent: unknown;
|
||||||
|
returnDisplay: unknown;
|
||||||
|
error: unknown;
|
||||||
|
},
|
||||||
|
mcpContext?: McpToolContext,
|
||||||
|
): Promise<DefaultHookOutput | undefined> {
|
||||||
|
try {
|
||||||
|
const result = await this.hookEventHandler.fireAfterToolEvent(
|
||||||
|
toolName,
|
||||||
|
toolInput,
|
||||||
|
toolResponse as Record<string, unknown>,
|
||||||
|
mcpContext,
|
||||||
|
);
|
||||||
|
return result.finalOutput;
|
||||||
|
} catch (error) {
|
||||||
|
debugLogger.debug(`AfterTool hook failed for ${toolName}:`, error);
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -253,17 +253,7 @@ describe('ToolExecutor', () => {
|
|||||||
// 2. Mock executeToolWithHooks to trigger the PID callback
|
// 2. Mock executeToolWithHooks to trigger the PID callback
|
||||||
const testPid = 12345;
|
const testPid = 12345;
|
||||||
vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation(
|
vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation(
|
||||||
async (
|
async (_inv, _name, _sig, _tool, _liveCb, _shellCfg, setPidCallback) => {
|
||||||
_inv,
|
|
||||||
_name,
|
|
||||||
_sig,
|
|
||||||
_bus,
|
|
||||||
_hooks,
|
|
||||||
_tool,
|
|
||||||
_liveCb,
|
|
||||||
_shellCfg,
|
|
||||||
setPidCallback,
|
|
||||||
) => {
|
|
||||||
// Simulate the shell tool reporting a PID
|
// Simulate the shell tool reporting a PID
|
||||||
if (setPidCallback) {
|
if (setPidCallback) {
|
||||||
setPidCallback(testPid);
|
setPidCallback(testPid);
|
||||||
|
|||||||
@@ -66,8 +66,6 @@ export class ToolExecutor {
|
|||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
const shellExecutionConfig = this.config.getShellExecutionConfig();
|
const shellExecutionConfig = this.config.getShellExecutionConfig();
|
||||||
const hooksEnabled = this.config.getEnableHooks();
|
|
||||||
const messageBus = this.config.getMessageBus();
|
|
||||||
|
|
||||||
return runInDevTraceSpan(
|
return runInDevTraceSpan(
|
||||||
{
|
{
|
||||||
@@ -95,8 +93,6 @@ export class ToolExecutor {
|
|||||||
invocation,
|
invocation,
|
||||||
toolName,
|
toolName,
|
||||||
signal,
|
signal,
|
||||||
messageBus,
|
|
||||||
hooksEnabled,
|
|
||||||
tool,
|
tool,
|
||||||
liveOutputCallback,
|
liveOutputCallback,
|
||||||
shellExecutionConfig,
|
shellExecutionConfig,
|
||||||
@@ -108,8 +104,6 @@ export class ToolExecutor {
|
|||||||
invocation,
|
invocation,
|
||||||
toolName,
|
toolName,
|
||||||
signal,
|
signal,
|
||||||
messageBus,
|
|
||||||
hooksEnabled,
|
|
||||||
tool,
|
tool,
|
||||||
liveOutputCallback,
|
liveOutputCallback,
|
||||||
shellExecutionConfig,
|
shellExecutionConfig,
|
||||||
|
|||||||
Reference in New Issue
Block a user