mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 13:22:35 -07:00
Implement support for tool input modification (#15492)
This commit is contained in:
committed by
GitHub
parent
15c9f88da6
commit
90eb1e0281
@@ -0,0 +1,131 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { executeToolWithHooks } from './coreToolHookTriggers.js';
|
||||
import {
|
||||
BaseToolInvocation,
|
||||
type ToolResult,
|
||||
type AnyDeclarativeTool,
|
||||
} from '../tools/tools.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type HookExecutionResponse,
|
||||
} from '../confirmation-bus/types.js';
|
||||
|
||||
class MockInvocation extends BaseToolInvocation<{ key: string }, ToolResult> {
|
||||
constructor(params: { key: string }) {
|
||||
super(params);
|
||||
}
|
||||
getDescription() {
|
||||
return 'mock';
|
||||
}
|
||||
async execute() {
|
||||
return {
|
||||
llmContent: `key: ${this.params.key}`,
|
||||
returnDisplay: `key: ${this.params.key}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
describe('executeToolWithHooks', () => {
|
||||
let messageBus: MessageBus;
|
||||
let mockTool: AnyDeclarativeTool;
|
||||
|
||||
beforeEach(() => {
|
||||
messageBus = {
|
||||
request: vi.fn(),
|
||||
} as unknown as MessageBus;
|
||||
mockTool = {
|
||||
build: vi.fn().mockImplementation((params) => new MockInvocation(params)),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
});
|
||||
|
||||
it('should apply modified tool input from BeforeTool hook', async () => {
|
||||
const params = { key: 'original' };
|
||||
const invocation = new MockInvocation(params);
|
||||
const toolName = 'test-tool';
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
// Capture arguments to verify what was passed before modification
|
||||
const requestSpy = vi.fn().mockImplementation(async (request) => {
|
||||
if (request.eventName === 'BeforeTool') {
|
||||
// Verify input is original before we return modification instruction
|
||||
expect(request.input.tool_input.key).toBe('original');
|
||||
return {
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: 'test-id',
|
||||
success: true,
|
||||
output: {
|
||||
hookSpecificOutput: {
|
||||
hookEventName: 'BeforeTool',
|
||||
tool_input: { key: 'modified' },
|
||||
},
|
||||
},
|
||||
} as HookExecutionResponse;
|
||||
}
|
||||
return {
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: 'test-id',
|
||||
success: true,
|
||||
output: {},
|
||||
} as HookExecutionResponse;
|
||||
});
|
||||
messageBus.request = requestSpy;
|
||||
|
||||
const result = await executeToolWithHooks(
|
||||
invocation,
|
||||
toolName,
|
||||
abortSignal,
|
||||
messageBus,
|
||||
true, // hooksEnabled
|
||||
mockTool,
|
||||
);
|
||||
|
||||
// Verify result reflects modified input
|
||||
expect(result.llmContent).toBe(
|
||||
'key: modified\n\n[System] Tool input parameters (key) were modified by a hook before execution.',
|
||||
);
|
||||
// Verify params object was modified in place
|
||||
expect(invocation.params.key).toBe('modified');
|
||||
|
||||
expect(requestSpy).toHaveBeenCalled();
|
||||
expect(mockTool.build).toHaveBeenCalledWith({ key: 'modified' });
|
||||
});
|
||||
|
||||
it('should not modify input if hook does not provide tool_input', async () => {
|
||||
const params = { key: 'original' };
|
||||
const invocation = new MockInvocation(params);
|
||||
const toolName = 'test-tool';
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
vi.mocked(messageBus.request).mockResolvedValue({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: 'test-id',
|
||||
success: true,
|
||||
output: {
|
||||
hookSpecificOutput: {
|
||||
hookEventName: 'BeforeTool',
|
||||
// No tool_input
|
||||
},
|
||||
},
|
||||
} as HookExecutionResponse);
|
||||
|
||||
const result = await executeToolWithHooks(
|
||||
invocation,
|
||||
toolName,
|
||||
abortSignal,
|
||||
messageBus,
|
||||
true, // hooksEnabled
|
||||
mockTool,
|
||||
);
|
||||
|
||||
expect(result.llmContent).toBe('key: original');
|
||||
expect(invocation.params.key).toBe('original');
|
||||
expect(mockTool.build).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -14,10 +14,12 @@ import {
|
||||
createHookOutput,
|
||||
NotificationType,
|
||||
type DefaultHookOutput,
|
||||
BeforeToolHookOutput,
|
||||
} from '../hooks/types.js';
|
||||
import type {
|
||||
ToolCallConfirmationDetails,
|
||||
ToolResult,
|
||||
AnyDeclarativeTool,
|
||||
} from '../tools/tools.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
@@ -254,11 +256,14 @@ export async function executeToolWithHooks(
|
||||
signal: AbortSignal,
|
||||
messageBus: MessageBus | undefined,
|
||||
hooksEnabled: boolean,
|
||||
tool: AnyDeclarativeTool,
|
||||
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
setPidCallback?: (pid: number) => void,
|
||||
): Promise<ToolResult> {
|
||||
const toolInput = (invocation.params || {}) as Record<string, unknown>;
|
||||
let inputWasModified = false;
|
||||
let modifiedKeys: string[] = [];
|
||||
|
||||
// Fire BeforeTool hook through MessageBus (only if hooks are enabled)
|
||||
if (hooksEnabled && messageBus) {
|
||||
@@ -293,6 +298,38 @@ export async function executeToolWithHooks(
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Check if hook requested to update tool input
|
||||
if (beforeOutput instanceof BeforeToolHookOutput) {
|
||||
const modifiedInput = beforeOutput.getModifiedToolInput();
|
||||
if (modifiedInput) {
|
||||
// We modify the toolInput object in-place, which should be the same reference as invocation.params
|
||||
// We use Object.assign to update properties
|
||||
Object.assign(invocation.params, modifiedInput);
|
||||
debugLogger.debug(`Tool input modified by hook for ${toolName}`);
|
||||
inputWasModified = true;
|
||||
modifiedKeys = Object.keys(modifiedInput);
|
||||
|
||||
// Recreate the invocation with the new parameters
|
||||
// to ensure any derived state (like resolvedPath in ReadFileTool) is updated.
|
||||
try {
|
||||
// We use the tool's build method to validate and create the invocation
|
||||
// This ensures consistent behavior with the initial creation
|
||||
invocation = tool.build(invocation.params);
|
||||
} catch (error) {
|
||||
return {
|
||||
llmContent: `Tool parameter modification by hook failed validation: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
returnDisplay: `Tool parameter modification by hook failed validation.`,
|
||||
error: {
|
||||
type: ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
message: String(error),
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the actual tool
|
||||
@@ -312,6 +349,24 @@ export async function executeToolWithHooks(
|
||||
);
|
||||
}
|
||||
|
||||
// Append notification if parameters were modified
|
||||
if (inputWasModified) {
|
||||
const modificationMsg = `\n\n[System] Tool input parameters (${modifiedKeys.join(
|
||||
', ',
|
||||
)}) were modified by a hook before execution.`;
|
||||
if (typeof toolResult.llmContent === 'string') {
|
||||
toolResult.llmContent += modificationMsg;
|
||||
} else if (Array.isArray(toolResult.llmContent)) {
|
||||
toolResult.llmContent.push({ text: modificationMsg });
|
||||
} else if (toolResult.llmContent) {
|
||||
// Handle single Part case by converting to an array
|
||||
toolResult.llmContent = [
|
||||
toolResult.llmContent,
|
||||
{ text: modificationMsg },
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
// Fire AfterTool hook through MessageBus (only if hooks are enabled)
|
||||
if (hooksEnabled && messageBus) {
|
||||
const afterOutput = await fireAfterToolHook(
|
||||
|
||||
@@ -900,6 +900,7 @@ export class CoreToolScheduler {
|
||||
signal,
|
||||
messageBus,
|
||||
hooksEnabled,
|
||||
toolCall.tool,
|
||||
liveOutputCallback,
|
||||
shellExecutionConfig,
|
||||
setPidCallback,
|
||||
@@ -911,6 +912,7 @@ export class CoreToolScheduler {
|
||||
signal,
|
||||
messageBus,
|
||||
hooksEnabled,
|
||||
toolCall.tool,
|
||||
liveOutputCallback,
|
||||
shellExecutionConfig,
|
||||
);
|
||||
|
||||
@@ -158,6 +158,14 @@ export class HookAggregator {
|
||||
merged.suppressOutput = true;
|
||||
}
|
||||
|
||||
// Merge hookSpecificOutput
|
||||
if (output.hookSpecificOutput) {
|
||||
merged.hookSpecificOutput = {
|
||||
...(merged.hookSpecificOutput || {}),
|
||||
...output.hookSpecificOutput,
|
||||
};
|
||||
}
|
||||
|
||||
// Collect additional context from hook-specific outputs
|
||||
this.extractAdditionalContext(output, additionalContexts);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import type {
|
||||
BeforeAgentInput,
|
||||
BeforeModelInput,
|
||||
BeforeModelOutput,
|
||||
BeforeToolInput,
|
||||
} from './types.js';
|
||||
import type { LLMRequest } from './hookTranslator.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
@@ -190,6 +191,20 @@ export class HookRunner {
|
||||
}
|
||||
break;
|
||||
|
||||
case HookEventName.BeforeTool:
|
||||
if ('tool_input' in hookOutput.hookSpecificOutput) {
|
||||
const newToolInput = hookOutput.hookSpecificOutput[
|
||||
'tool_input'
|
||||
] as Record<string, unknown>;
|
||||
if (newToolInput && 'tool_input' in modifiedInput) {
|
||||
(modifiedInput as BeforeToolInput).tool_input = {
|
||||
...(modifiedInput as BeforeToolInput).tool_input,
|
||||
...newToolInput,
|
||||
};
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
// For other events, no special input modification is needed
|
||||
break;
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
AfterModelHookOutput,
|
||||
HookEventName,
|
||||
HookType,
|
||||
BeforeToolHookOutput,
|
||||
} from './types.js';
|
||||
import { defaultHookTranslator } from './hookTranslator.js';
|
||||
import type {
|
||||
@@ -92,6 +93,11 @@ describe('Hook Output Classes', () => {
|
||||
const output = createHookOutput(HookEventName.BeforeToolSelection, {});
|
||||
expect(output).toBeInstanceOf(BeforeToolSelectionHookOutput);
|
||||
});
|
||||
|
||||
it('should return BeforeToolHookOutput for BeforeTool event', () => {
|
||||
const output = createHookOutput(HookEventName.BeforeTool, {});
|
||||
expect(output).toBeInstanceOf(BeforeToolHookOutput);
|
||||
});
|
||||
});
|
||||
|
||||
describe('DefaultHookOutput', () => {
|
||||
|
||||
@@ -133,6 +133,8 @@ export function createHookOutput(
|
||||
return new AfterModelHookOutput(data);
|
||||
case 'BeforeToolSelection':
|
||||
return new BeforeToolSelectionHookOutput(data);
|
||||
case 'BeforeTool':
|
||||
return new BeforeToolHookOutput(data);
|
||||
default:
|
||||
return new DefaultHookOutput(data);
|
||||
}
|
||||
@@ -236,7 +238,24 @@ export class DefaultHookOutput implements HookOutput {
|
||||
/**
|
||||
* Specific hook output class for BeforeTool events.
|
||||
*/
|
||||
export class BeforeToolHookOutput extends DefaultHookOutput {}
|
||||
export class BeforeToolHookOutput extends DefaultHookOutput {
|
||||
/**
|
||||
* Get modified tool input if provided by hook
|
||||
*/
|
||||
getModifiedToolInput(): Record<string, unknown> | undefined {
|
||||
if (this.hookSpecificOutput && 'tool_input' in this.hookSpecificOutput) {
|
||||
const input = this.hookSpecificOutput['tool_input'];
|
||||
if (
|
||||
typeof input === 'object' &&
|
||||
input !== null &&
|
||||
!Array.isArray(input)
|
||||
) {
|
||||
return input as Record<string, unknown>;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Specific hook output class for BeforeModel events
|
||||
@@ -368,6 +387,7 @@ export interface BeforeToolInput extends HookInput {
|
||||
export interface BeforeToolOutput extends HookOutput {
|
||||
hookSpecificOutput?: {
|
||||
hookEventName: 'BeforeTool';
|
||||
tool_input?: Record<string, unknown>;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user