From bf80263bd64bfada12df13e3f8855758eb3ae866 Mon Sep 17 00:00:00 2001 From: Allen Hutchison Date: Tue, 21 Oct 2025 11:45:33 -0700 Subject: [PATCH] feat: Implement message bus and policy engine (#11523) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- integration-tests/replace.test.ts | 2 +- packages/cli/src/config/policy.test.ts | 101 +++++++++++------- packages/cli/src/config/policy.ts | 41 +++++-- packages/cli/src/gemini.test.tsx | 8 ++ packages/cli/src/gemini.tsx | 5 + packages/core/src/config/config.ts | 2 +- .../core/src/confirmation-bus/message-bus.ts | 9 +- packages/core/src/confirmation-bus/types.ts | 9 +- packages/core/src/index.ts | 2 + packages/core/src/tools/glob.ts | 26 ++++- packages/core/src/tools/grep.ts | 26 ++++- packages/core/src/tools/ls.ts | 25 ++++- packages/core/src/tools/read-file.ts | 25 ++++- packages/core/src/tools/read-many-files.ts | 25 ++++- packages/core/src/tools/ripGrep.ts | 25 ++++- packages/core/src/tools/tools.ts | 61 ++++++++--- packages/core/src/tools/web-fetch.test.ts | 2 +- packages/core/src/tools/web-fetch.ts | 14 ++- packages/core/src/tools/web-search.ts | 25 ++++- 19 files changed, 339 insertions(+), 94 deletions(-) diff --git a/integration-tests/replace.test.ts b/integration-tests/replace.test.ts index b9452dba8d..4b0eaeddff 100644 --- a/integration-tests/replace.test.ts +++ b/integration-tests/replace.test.ts @@ -78,7 +78,7 @@ describe('replace', () => { rig.createFile(fileName, originalContent); await rig.run( - `In ${fileName}, delete the entire block from "## DELETE THIS ##" to "## END DELETE ##" including the markers.`, + `In ${fileName}, delete the entire block from "## DELETE THIS ##" to "## END DELETE ##" including the markers and the newline that follows it.`, ); const foundToolCall = await rig.waitForToolCall('replace'); diff --git a/packages/cli/src/config/policy.test.ts b/packages/cli/src/config/policy.test.ts index 25c20cd761..f6c442a9e6 100644 --- a/packages/cli/src/config/policy.test.ts +++ b/packages/cli/src/config/policy.test.ts @@ -14,16 +14,70 @@ import { } from '@google/gemini-cli-core'; describe('createPolicyEngineConfig', () => { - it('should return ASK_USER for all tools by default', () => { + it('should return ASK_USER for write tools and ALLOW for read-only tools by default', () => { const settings: Settings = {}; const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT); expect(config.defaultDecision).toBe(PolicyDecision.ASK_USER); + // The order of the rules is not guaranteed, so we sort them by tool name. + config.rules?.sort((a, b) => + (a.toolName ?? '').localeCompare(b.toolName ?? ''), + ); expect(config.rules).toEqual([ - { toolName: 'replace', decision: 'ask_user', priority: 10 }, - { toolName: 'save_memory', decision: 'ask_user', priority: 10 }, - { toolName: 'run_shell_command', decision: 'ask_user', priority: 10 }, - { toolName: 'write_file', decision: 'ask_user', priority: 10 }, - { toolName: WEB_FETCH_TOOL_NAME, decision: 'ask_user', priority: 10 }, + { + toolName: 'glob', + decision: PolicyDecision.ALLOW, + priority: 50, + }, + { + toolName: 'google_web_search', + decision: PolicyDecision.ALLOW, + priority: 50, + }, + { + toolName: 'list_directory', + decision: PolicyDecision.ALLOW, + priority: 50, + }, + { + toolName: 'read_file', + decision: PolicyDecision.ALLOW, + priority: 50, + }, + { + toolName: 'read_many_files', + decision: PolicyDecision.ALLOW, + priority: 50, + }, + { + toolName: 'replace', + decision: PolicyDecision.ASK_USER, + priority: 10, + }, + { + toolName: 'run_shell_command', + decision: PolicyDecision.ASK_USER, + priority: 10, + }, + { + toolName: 'save_memory', + decision: PolicyDecision.ASK_USER, + priority: 10, + }, + { + toolName: 'search_file_content', + decision: PolicyDecision.ALLOW, + priority: 50, + }, + { + toolName: 'web_fetch', + decision: PolicyDecision.ASK_USER, + priority: 10, + }, + { + toolName: 'write_file', + decision: PolicyDecision.ASK_USER, + priority: 10, + }, ]); }); @@ -159,18 +213,6 @@ describe('createPolicyEngineConfig', () => { expect(excludedRule?.priority).toBe(195); }); - it('should allow read-only tools if autoAccept is true', () => { - const settings: Settings = { - tools: { autoAccept: true }, - }; - const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT); - const rule = config.rules?.find( - (r) => r.toolName === 'glob' && r.decision === PolicyDecision.ALLOW, - ); - expect(rule).toBeDefined(); - expect(rule?.priority).toBe(50); - }); - it('should allow all tools in YOLO mode', () => { const settings: Settings = {}; const config = createPolicyEngineConfig(settings, ApprovalMode.YOLO); @@ -419,29 +461,6 @@ describe('createPolicyEngineConfig', () => { // Exclude (195) should win over trust (90) when evaluated }); - it('should create all read-only tool rules when autoAccept is enabled', () => { - const settings: Settings = { - tools: { autoAccept: true }, - }; - const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT); - - // All read-only tools should have allow rules - const readOnlyTools = [ - 'glob', - 'search_file_content', - 'list_directory', - 'read_file', - 'read_many_files', - ]; - for (const tool of readOnlyTools) { - const rule = config.rules?.find( - (r) => r.toolName === tool && r.decision === PolicyDecision.ALLOW, - ); - expect(rule).toBeDefined(); - expect(rule?.priority).toBe(50); - } - }); - it('should handle all approval modes correctly', () => { const settings: Settings = {}; diff --git a/packages/cli/src/config/policy.ts b/packages/cli/src/config/policy.ts index 75e9b4315a..0ebe8f06e0 100644 --- a/packages/cli/src/config/policy.ts +++ b/packages/cli/src/config/policy.ts @@ -22,8 +22,12 @@ import { EDIT_TOOL_NAME, MEMORY_TOOL_NAME, WEB_SEARCH_TOOL_NAME, + type PolicyEngine, + type MessageBus, + MessageBusType, + type UpdatePolicy, } from '@google/gemini-cli-core'; -import type { Settings } from './settings.js'; +import { type Settings } from './settings.js'; // READ_ONLY_TOOLS is a list of built-in tools that do not modify the user's // files or system state. @@ -69,6 +73,7 @@ export function createPolicyEngineConfig( // 90: MCP servers with trust=true // 100: Explicitly allowed individual tools // 195: Explicitly excluded MCP servers + // 199: Tools that the user has selected as "Always Allow" in the interactive UI. // 200: Explicitly excluded individual tools (highest priority) // MCP servers that are explicitly allowed in settings.mcp.allowed @@ -137,16 +142,14 @@ export function createPolicyEngineConfig( } } - // If auto-accept is enabled, allow all read-only tools. + // Allow all read-only tools. // Priority: 50 - if (settings.tools?.autoAccept) { - for (const tool of READ_ONLY_TOOLS) { - rules.push({ - toolName: tool, - decision: PolicyDecision.ALLOW, - priority: 50, - }); - } + for (const tool of READ_ONLY_TOOLS) { + rules.push({ + toolName: tool, + decision: PolicyDecision.ALLOW, + priority: 50, + }); } // Only add write tool rules if not in YOLO mode @@ -179,3 +182,21 @@ export function createPolicyEngineConfig( defaultDecision: PolicyDecision.ASK_USER, }; } + +export function createPolicyUpdater( + policyEngine: PolicyEngine, + messageBus: MessageBus, +) { + messageBus.subscribe( + MessageBusType.UPDATE_POLICY, + (message: UpdatePolicy) => { + const toolName = message.toolName; + + policyEngine.addRule({ + toolName, + decision: PolicyDecision.ALLOW, + priority: 199, // High priority, but lower than explicit DENY (200) + }); + }, + ); +} diff --git a/packages/cli/src/gemini.test.tsx b/packages/cli/src/gemini.test.tsx index 1ccb148375..931e35a3b5 100644 --- a/packages/cli/src/gemini.test.tsx +++ b/packages/cli/src/gemini.test.tsx @@ -170,6 +170,10 @@ describe('gemini.tsx main function', () => { getScreenReader: () => false, getGeminiMdFileCount: () => 0, getProjectRoot: () => '/', + getPolicyEngine: vi.fn(), + getMessageBus: () => ({ + subscribe: vi.fn(), + }), } as unknown as Config; }); vi.mocked(loadSettings).mockReturnValue({ @@ -301,6 +305,10 @@ describe('gemini.tsx main function kitty protocol', () => { getExperimentalZedIntegration: () => false, getScreenReader: () => false, getGeminiMdFileCount: () => 0, + getPolicyEngine: vi.fn(), + getMessageBus: () => ({ + subscribe: vi.fn(), + }), } as unknown as Config); vi.mocked(loadSettings).mockReturnValue({ errors: [], diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 1f43ab7694..f22a9dda08 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -67,6 +67,7 @@ import { relaunchOnExitCode, } from './utils/relaunch.js'; import { loadSandboxConfig } from './config/sandboxConfig.js'; +import { createPolicyUpdater } from './config/policy.js'; import { ExtensionEnablementManager } from './config/extensions/extensionEnablement.js'; export function validateDnsResolutionOrder( @@ -370,6 +371,10 @@ export async function main() { argv, ); + const policyEngine = config.getPolicyEngine(); + const messageBus = config.getMessageBus(); + createPolicyUpdater(policyEngine, messageBus); + // Cleanup sessions after config initialization await cleanupExpiredSessions(config, settings.merged); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 584e928cb0..42c5008fb2 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -491,7 +491,7 @@ export class Config { this.fileExclusions = new FileExclusions(this); this.eventEmitter = params.eventEmitter; this.policyEngine = new PolicyEngine(params.policyEngineConfig); - this.messageBus = new MessageBus(this.policyEngine); + this.messageBus = new MessageBus(this.policyEngine, this.debugMode); this.outputSettings = { format: params.output?.format ?? OutputFormat.TEXT, }; diff --git a/packages/core/src/confirmation-bus/message-bus.ts b/packages/core/src/confirmation-bus/message-bus.ts index b9d66eff6a..b48129b412 100644 --- a/packages/core/src/confirmation-bus/message-bus.ts +++ b/packages/core/src/confirmation-bus/message-bus.ts @@ -11,8 +11,12 @@ import { MessageBusType, type Message } from './types.js'; import { safeJsonStringify } from '../utils/safeJsonStringify.js'; export class MessageBus extends EventEmitter { - constructor(private readonly policyEngine: PolicyEngine) { + constructor( + private readonly policyEngine: PolicyEngine, + private readonly debug = false, + ) { super(); + this.debug = debug; } private isValidMessage(message: Message): boolean { @@ -35,6 +39,9 @@ export class MessageBus extends EventEmitter { } publish(message: Message): void { + if (this.debug) { + console.debug(`[MESSAGE_BUS] publish: ${safeJsonStringify(message)}`); + } try { if (!this.isValidMessage(message)) { throw new Error( diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index f1c58a7c31..2b4bcf5685 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -12,6 +12,7 @@ export enum MessageBusType { TOOL_POLICY_REJECTION = 'tool-policy-rejection', TOOL_EXECUTION_SUCCESS = 'tool-execution-success', TOOL_EXECUTION_FAILURE = 'tool-execution-failure', + UPDATE_POLICY = 'update-policy', } export interface ToolConfirmationRequest { @@ -31,6 +32,11 @@ export interface ToolConfirmationResponse { requiresUserConfirmation?: boolean; } +export interface UpdatePolicy { + type: MessageBusType.UPDATE_POLICY; + toolName: string; +} + export interface ToolPolicyRejection { type: MessageBusType.TOOL_POLICY_REJECTION; toolCall: FunctionCall; @@ -53,4 +59,5 @@ export type Message = | ToolConfirmationResponse | ToolPolicyRejection | ToolExecutionSuccess - | ToolExecutionFailure; + | ToolExecutionFailure + | UpdatePolicy; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index bc8fb83308..9cb1662714 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -11,6 +11,8 @@ export * from './output/json-formatter.js'; export * from './output/stream-json-formatter.js'; export * from './policy/types.js'; export * from './policy/policy-engine.js'; +export * from './confirmation-bus/types.js'; +export * from './confirmation-bus/message-bus.js'; // Export Core Logic export * from './core/client.js'; diff --git a/packages/core/src/tools/glob.ts b/packages/core/src/tools/glob.ts index dadf282536..2ac2fd89c2 100644 --- a/packages/core/src/tools/glob.ts +++ b/packages/core/src/tools/glob.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { MessageBus } from '../confirmation-bus/message-bus.js'; import fs from 'node:fs'; import path from 'node:path'; import { glob, escape } from 'glob'; @@ -88,8 +89,11 @@ class GlobToolInvocation extends BaseToolInvocation< constructor( private config: Config, params: GlobToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ) { - super(params); + super(params, messageBus, _toolName, _toolDisplayName); } getDescription(): string { @@ -261,8 +265,10 @@ class GlobToolInvocation extends BaseToolInvocation< */ export class GlobTool extends BaseDeclarativeTool { static readonly Name = GLOB_TOOL_NAME; - - constructor(private config: Config) { + constructor( + private config: Config, + messageBus?: MessageBus, + ) { super( GlobTool.Name, 'FindFiles', @@ -299,6 +305,9 @@ export class GlobTool extends BaseDeclarativeTool { required: ['pattern'], type: 'object', }, + true, + false, + messageBus, ); } @@ -344,7 +353,16 @@ export class GlobTool extends BaseDeclarativeTool { protected createInvocation( params: GlobToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ): ToolInvocation { - return new GlobToolInvocation(this.config, params); + return new GlobToolInvocation( + this.config, + params, + messageBus, + _toolName, + _toolDisplayName, + ); } } diff --git a/packages/core/src/tools/grep.ts b/packages/core/src/tools/grep.ts index 00700476b0..cee8dd24f7 100644 --- a/packages/core/src/tools/grep.ts +++ b/packages/core/src/tools/grep.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { MessageBus } from '../confirmation-bus/message-bus.js'; import fs from 'node:fs'; import fsPromises from 'node:fs/promises'; import path from 'node:path'; @@ -61,8 +62,11 @@ class GrepToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: GrepToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ) { - super(params); + super(params, messageBus, _toolName, _toolDisplayName); this.fileExclusions = config.getFileExclusions(); } @@ -565,8 +569,10 @@ class GrepToolInvocation extends BaseToolInvocation< */ export class GrepTool extends BaseDeclarativeTool { static readonly Name = GREP_TOOL_NAME; - - constructor(private readonly config: Config) { + constructor( + private readonly config: Config, + messageBus?: MessageBus, + ) { super( GrepTool.Name, 'SearchText', @@ -593,6 +599,9 @@ export class GrepTool extends BaseDeclarativeTool { required: ['pattern'], type: 'object', }, + true, + false, + messageBus, ); } @@ -665,7 +674,16 @@ export class GrepTool extends BaseDeclarativeTool { protected createInvocation( params: GrepToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ): ToolInvocation { - return new GrepToolInvocation(this.config, params); + return new GrepToolInvocation( + this.config, + params, + messageBus, + _toolName, + _toolDisplayName, + ); } } diff --git a/packages/core/src/tools/ls.ts b/packages/core/src/tools/ls.ts index d34c91b275..9699be5d60 100644 --- a/packages/core/src/tools/ls.ts +++ b/packages/core/src/tools/ls.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { MessageBus } from '../confirmation-bus/message-bus.js'; import fs from 'node:fs/promises'; import path from 'node:path'; import type { ToolInvocation, ToolResult } from './tools.js'; @@ -71,8 +72,11 @@ class LSToolInvocation extends BaseToolInvocation { constructor( private readonly config: Config, params: LSToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ) { - super(params); + super(params, messageBus, _toolName, _toolDisplayName); } /** @@ -255,7 +259,10 @@ class LSToolInvocation extends BaseToolInvocation { export class LSTool extends BaseDeclarativeTool { static readonly Name = LS_TOOL_NAME; - constructor(private config: Config) { + constructor( + private config: Config, + messageBus?: MessageBus, + ) { super( LSTool.Name, 'ReadFolder', @@ -296,6 +303,9 @@ export class LSTool extends BaseDeclarativeTool { required: ['path'], type: 'object', }, + true, + false, + messageBus, ); } @@ -323,7 +333,16 @@ export class LSTool extends BaseDeclarativeTool { protected createInvocation( params: LSToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ): ToolInvocation { - return new LSToolInvocation(this.config, params); + return new LSToolInvocation( + this.config, + params, + messageBus, + _toolName, + _toolDisplayName, + ); } } diff --git a/packages/core/src/tools/read-file.ts b/packages/core/src/tools/read-file.ts index ac43590f7f..9584865746 100644 --- a/packages/core/src/tools/read-file.ts +++ b/packages/core/src/tools/read-file.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { MessageBus } from '../confirmation-bus/message-bus.js'; import path from 'node:path'; import { makeRelative, shortenPath } from '../utils/paths.js'; import type { ToolInvocation, ToolLocation, ToolResult } from './tools.js'; @@ -48,8 +49,11 @@ class ReadFileToolInvocation extends BaseToolInvocation< constructor( private config: Config, params: ReadFileToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ) { - super(params); + super(params, messageBus, _toolName, _toolDisplayName); } getDescription(): string { @@ -138,7 +142,10 @@ export class ReadFileTool extends BaseDeclarativeTool< > { static readonly Name = READ_FILE_TOOL_NAME; - constructor(private config: Config) { + constructor( + private config: Config, + messageBus?: MessageBus, + ) { super( ReadFileTool.Name, 'ReadFile', @@ -165,6 +172,9 @@ export class ReadFileTool extends BaseDeclarativeTool< required: ['absolute_path'], type: 'object', }, + true, + false, + messageBus, ); } @@ -209,7 +219,16 @@ export class ReadFileTool extends BaseDeclarativeTool< protected createInvocation( params: ReadFileToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ): ToolInvocation { - return new ReadFileToolInvocation(this.config, params); + return new ReadFileToolInvocation( + this.config, + params, + messageBus, + _toolName, + _toolDisplayName, + ); } } diff --git a/packages/core/src/tools/read-many-files.ts b/packages/core/src/tools/read-many-files.ts index e6d238dbe5..88d2660c1b 100644 --- a/packages/core/src/tools/read-many-files.ts +++ b/packages/core/src/tools/read-many-files.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { ToolInvocation, ToolResult } from './tools.js'; import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js'; import { getErrorMessage } from '../utils/errors.js'; @@ -114,8 +115,11 @@ class ReadManyFilesToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: ReadManyFilesParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ) { - super(params); + super(params, messageBus, _toolName, _toolDisplayName); } getDescription(): string { @@ -477,7 +481,10 @@ export class ReadManyFilesTool extends BaseDeclarativeTool< > { static readonly Name = READ_MANY_FILES_TOOL_NAME; - constructor(private config: Config) { + constructor( + private config: Config, + messageBus?: MessageBus, + ) { const parameterSchema = { type: 'object', properties: { @@ -559,12 +566,24 @@ This tool is useful when you need to understand or analyze a collection of files Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. The tool inserts a '--- End of content ---' after the last file. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`, Kind.Read, parameterSchema, + true, // isOutputMarkdown + false, // canUpdateOutput + messageBus, ); } protected createInvocation( params: ReadManyFilesParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ): ToolInvocation { - return new ReadManyFilesToolInvocation(this.config, params); + return new ReadManyFilesToolInvocation( + this.config, + params, + messageBus, + _toolName, + _toolDisplayName, + ); } } diff --git a/packages/core/src/tools/ripGrep.ts b/packages/core/src/tools/ripGrep.ts index 92aa110103..054f01b558 100644 --- a/packages/core/src/tools/ripGrep.ts +++ b/packages/core/src/tools/ripGrep.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { MessageBus } from '../confirmation-bus/message-bus.js'; import fs from 'node:fs'; import path from 'node:path'; import { EOL } from 'node:os'; @@ -110,8 +111,11 @@ class GrepToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: RipGrepToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ) { - super(params); + super(params, messageBus, _toolName, _toolDisplayName); } /** @@ -449,7 +453,10 @@ export class RipGrepTool extends BaseDeclarativeTool< > { static readonly Name = GREP_TOOL_NAME; - constructor(private readonly config: Config) { + constructor( + private readonly config: Config, + messageBus?: MessageBus, + ) { super( RipGrepTool.Name, 'SearchText', @@ -476,6 +483,9 @@ export class RipGrepTool extends BaseDeclarativeTool< required: ['pattern'], type: 'object', }, + true, // isOutputMarkdown + false, // canUpdateOutput + messageBus, ); } @@ -548,7 +558,16 @@ export class RipGrepTool extends BaseDeclarativeTool< protected createInvocation( params: RipGrepToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ): ToolInvocation { - return new GrepToolInvocation(this.config, params); + return new GrepToolInvocation( + this.config, + params, + messageBus, + _toolName, + _toolDisplayName, + ); } } diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 86690cd675..fbf58b2e4e 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -76,13 +76,9 @@ export abstract class BaseToolInvocation< constructor( readonly params: TParams, protected readonly messageBus?: MessageBus, - ) { - if (this.messageBus) { - console.debug( - `[DEBUG] Tool ${this.constructor.name} created with messageBus: YES`, - ); - } - } + readonly _toolName?: string, + readonly _toolDisplayName?: string, + ) {} abstract getDescription(): string; @@ -90,11 +86,43 @@ export abstract class BaseToolInvocation< return []; } - shouldConfirmExecute( - _abortSignal: AbortSignal, + async shouldConfirmExecute( + abortSignal: AbortSignal, ): Promise { - // Default implementation for tools that don't override it. - return Promise.resolve(false); + if (this.messageBus) { + const decision = await this.getMessageBusDecision(abortSignal); + if (decision === 'ALLOW') { + return false; + } + + if (decision === 'DENY') { + throw new Error( + `Tool execution for "${ + this._toolDisplayName || this._toolName + }" denied by policy.`, + ); + } + + if (decision === 'ASK_USER') { + const confirmationDetails: ToolCallConfirmationDetails = { + type: 'info', + title: `Confirm: ${this._toolDisplayName || this._toolName}`, + prompt: this.getDescription(), + onConfirm: async (outcome: ToolConfirmationOutcome) => { + if (outcome === ToolConfirmationOutcome.ProceedAlways) { + if (this.messageBus && this._toolName) { + this.messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName: this._toolName, + }); + } + } + }, + }; + return confirmationDetails; + } + } + return false; } protected getMessageBusDecision( @@ -108,7 +136,7 @@ export abstract class BaseToolInvocation< const correlationId = randomUUID(); const toolCall = { - name: this.constructor.name, + name: this._toolName || this.constructor.name, args: this.params as Record, }; @@ -385,7 +413,12 @@ export abstract class BaseDeclarativeTool< if (validationError) { throw new Error(validationError); } - return this.createInvocation(params, this.messageBus); + return this.createInvocation( + params, + this.messageBus, + this.name, + this.displayName, + ); } override validateToolParams(params: TParams): string | null { @@ -408,6 +441,8 @@ export abstract class BaseDeclarativeTool< protected abstract createInvocation( params: TParams, messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ): ToolInvocation; } diff --git a/packages/core/src/tools/web-fetch.test.ts b/packages/core/src/tools/web-fetch.test.ts index 2ae9d62425..69adeb23ac 100644 --- a/packages/core/src/tools/web-fetch.test.ts +++ b/packages/core/src/tools/web-fetch.test.ts @@ -471,7 +471,7 @@ describe('WebFetchTool', () => { expect(publishSpy).toHaveBeenCalledWith({ type: MessageBusType.TOOL_CONFIRMATION_REQUEST, toolCall: { - name: 'WebFetchToolInvocation', + name: 'web_fetch', args: { prompt: 'fetch https://example.com' }, }, correlationId: 'test-correlation-id', diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index 4971aa1abe..5e1835a13c 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -110,8 +110,10 @@ class WebFetchToolInvocation extends BaseToolInvocation< private readonly config: Config, params: WebFetchToolParams, messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ) { - super(params, messageBus); + super(params, messageBus, _toolName, _toolDisplayName); } private async executeFallback(signal: AbortSignal): Promise { @@ -450,7 +452,15 @@ export class WebFetchTool extends BaseDeclarativeTool< protected createInvocation( params: WebFetchToolParams, messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ): ToolInvocation { - return new WebFetchToolInvocation(this.config, params, messageBus); + return new WebFetchToolInvocation( + this.config, + params, + messageBus, + _toolName, + _toolDisplayName, + ); } } diff --git a/packages/core/src/tools/web-search.ts b/packages/core/src/tools/web-search.ts index 1aaa5ea02a..c1b21b6afa 100644 --- a/packages/core/src/tools/web-search.ts +++ b/packages/core/src/tools/web-search.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { WEB_SEARCH_TOOL_NAME } from './tool-names.js'; import type { GroundingMetadata } from '@google/genai'; import type { ToolInvocation, ToolResult } from './tools.js'; @@ -64,8 +65,11 @@ class WebSearchToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: WebSearchToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ) { - super(params); + super(params, messageBus, _toolName, _toolDisplayName); } override getDescription(): string { @@ -187,7 +191,10 @@ export class WebSearchTool extends BaseDeclarativeTool< > { static readonly Name = WEB_SEARCH_TOOL_NAME; - constructor(private readonly config: Config) { + constructor( + private readonly config: Config, + messageBus?: MessageBus, + ) { super( WebSearchTool.Name, 'GoogleSearch', @@ -203,6 +210,9 @@ export class WebSearchTool extends BaseDeclarativeTool< }, required: ['query'], }, + true, // isOutputMarkdown + false, // canUpdateOutput + messageBus, ); } @@ -222,7 +232,16 @@ export class WebSearchTool extends BaseDeclarativeTool< protected createInvocation( params: WebSearchToolParams, + messageBus?: MessageBus, + _toolName?: string, + _toolDisplayName?: string, ): ToolInvocation { - return new WebSearchToolInvocation(this.config, params); + return new WebSearchToolInvocation( + this.config, + params, + messageBus, + _toolName, + _toolDisplayName, + ); } }