feat: Implement message bus and policy engine (#11523)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Allen Hutchison
2025-10-21 11:45:33 -07:00
committed by GitHub
parent 0658b4aa31
commit bf80263bd6
19 changed files with 339 additions and 94 deletions
+1 -1
View File
@@ -78,7 +78,7 @@ describe('replace', () => {
rig.createFile(fileName, originalContent); rig.createFile(fileName, originalContent);
await rig.run( await rig.run(
`In ${fileName}, delete the entire block from "## DELETE THIS ##" to "## END DELETE ##" including the markers.`, `In ${fileName}, delete the entire block from "## DELETE THIS ##" to "## END DELETE ##" including the markers and the newline that follows it.`,
); );
const foundToolCall = await rig.waitForToolCall('replace'); const foundToolCall = await rig.waitForToolCall('replace');
+60 -41
View File
@@ -14,16 +14,70 @@ import {
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
describe('createPolicyEngineConfig', () => { describe('createPolicyEngineConfig', () => {
it('should return ASK_USER for all tools by default', () => { it('should return ASK_USER for write tools and ALLOW for read-only tools by default', () => {
const settings: Settings = {}; const settings: Settings = {};
const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT); const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT);
expect(config.defaultDecision).toBe(PolicyDecision.ASK_USER); expect(config.defaultDecision).toBe(PolicyDecision.ASK_USER);
// The order of the rules is not guaranteed, so we sort them by tool name.
config.rules?.sort((a, b) =>
(a.toolName ?? '').localeCompare(b.toolName ?? ''),
);
expect(config.rules).toEqual([ expect(config.rules).toEqual([
{ toolName: 'replace', decision: 'ask_user', priority: 10 }, {
{ toolName: 'save_memory', decision: 'ask_user', priority: 10 }, toolName: 'glob',
{ toolName: 'run_shell_command', decision: 'ask_user', priority: 10 }, decision: PolicyDecision.ALLOW,
{ toolName: 'write_file', decision: 'ask_user', priority: 10 }, priority: 50,
{ toolName: WEB_FETCH_TOOL_NAME, decision: 'ask_user', priority: 10 }, },
{
toolName: 'google_web_search',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'list_directory',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'read_file',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'read_many_files',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'replace',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{
toolName: 'run_shell_command',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{
toolName: 'save_memory',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{
toolName: 'search_file_content',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'web_fetch',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{
toolName: 'write_file',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
]); ]);
}); });
@@ -159,18 +213,6 @@ describe('createPolicyEngineConfig', () => {
expect(excludedRule?.priority).toBe(195); expect(excludedRule?.priority).toBe(195);
}); });
it('should allow read-only tools if autoAccept is true', () => {
const settings: Settings = {
tools: { autoAccept: true },
};
const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT);
const rule = config.rules?.find(
(r) => r.toolName === 'glob' && r.decision === PolicyDecision.ALLOW,
);
expect(rule).toBeDefined();
expect(rule?.priority).toBe(50);
});
it('should allow all tools in YOLO mode', () => { it('should allow all tools in YOLO mode', () => {
const settings: Settings = {}; const settings: Settings = {};
const config = createPolicyEngineConfig(settings, ApprovalMode.YOLO); const config = createPolicyEngineConfig(settings, ApprovalMode.YOLO);
@@ -419,29 +461,6 @@ describe('createPolicyEngineConfig', () => {
// Exclude (195) should win over trust (90) when evaluated // Exclude (195) should win over trust (90) when evaluated
}); });
it('should create all read-only tool rules when autoAccept is enabled', () => {
const settings: Settings = {
tools: { autoAccept: true },
};
const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT);
// All read-only tools should have allow rules
const readOnlyTools = [
'glob',
'search_file_content',
'list_directory',
'read_file',
'read_many_files',
];
for (const tool of readOnlyTools) {
const rule = config.rules?.find(
(r) => r.toolName === tool && r.decision === PolicyDecision.ALLOW,
);
expect(rule).toBeDefined();
expect(rule?.priority).toBe(50);
}
});
it('should handle all approval modes correctly', () => { it('should handle all approval modes correctly', () => {
const settings: Settings = {}; const settings: Settings = {};
+25 -4
View File
@@ -22,8 +22,12 @@ import {
EDIT_TOOL_NAME, EDIT_TOOL_NAME,
MEMORY_TOOL_NAME, MEMORY_TOOL_NAME,
WEB_SEARCH_TOOL_NAME, WEB_SEARCH_TOOL_NAME,
type PolicyEngine,
type MessageBus,
MessageBusType,
type UpdatePolicy,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import type { Settings } from './settings.js'; import { type Settings } from './settings.js';
// READ_ONLY_TOOLS is a list of built-in tools that do not modify the user's // READ_ONLY_TOOLS is a list of built-in tools that do not modify the user's
// files or system state. // files or system state.
@@ -69,6 +73,7 @@ export function createPolicyEngineConfig(
// 90: MCP servers with trust=true // 90: MCP servers with trust=true
// 100: Explicitly allowed individual tools // 100: Explicitly allowed individual tools
// 195: Explicitly excluded MCP servers // 195: Explicitly excluded MCP servers
// 199: Tools that the user has selected as "Always Allow" in the interactive UI.
// 200: Explicitly excluded individual tools (highest priority) // 200: Explicitly excluded individual tools (highest priority)
// MCP servers that are explicitly allowed in settings.mcp.allowed // MCP servers that are explicitly allowed in settings.mcp.allowed
@@ -137,9 +142,8 @@ export function createPolicyEngineConfig(
} }
} }
// If auto-accept is enabled, allow all read-only tools. // Allow all read-only tools.
// Priority: 50 // Priority: 50
if (settings.tools?.autoAccept) {
for (const tool of READ_ONLY_TOOLS) { for (const tool of READ_ONLY_TOOLS) {
rules.push({ rules.push({
toolName: tool, toolName: tool,
@@ -147,7 +151,6 @@ export function createPolicyEngineConfig(
priority: 50, priority: 50,
}); });
} }
}
// Only add write tool rules if not in YOLO mode // Only add write tool rules if not in YOLO mode
// In YOLO mode, the wildcard ALLOW rule handles everything // In YOLO mode, the wildcard ALLOW rule handles everything
@@ -179,3 +182,21 @@ export function createPolicyEngineConfig(
defaultDecision: PolicyDecision.ASK_USER, defaultDecision: PolicyDecision.ASK_USER,
}; };
} }
export function createPolicyUpdater(
policyEngine: PolicyEngine,
messageBus: MessageBus,
) {
messageBus.subscribe(
MessageBusType.UPDATE_POLICY,
(message: UpdatePolicy) => {
const toolName = message.toolName;
policyEngine.addRule({
toolName,
decision: PolicyDecision.ALLOW,
priority: 199, // High priority, but lower than explicit DENY (200)
});
},
);
}
+8
View File
@@ -170,6 +170,10 @@ describe('gemini.tsx main function', () => {
getScreenReader: () => false, getScreenReader: () => false,
getGeminiMdFileCount: () => 0, getGeminiMdFileCount: () => 0,
getProjectRoot: () => '/', getProjectRoot: () => '/',
getPolicyEngine: vi.fn(),
getMessageBus: () => ({
subscribe: vi.fn(),
}),
} as unknown as Config; } as unknown as Config;
}); });
vi.mocked(loadSettings).mockReturnValue({ vi.mocked(loadSettings).mockReturnValue({
@@ -301,6 +305,10 @@ describe('gemini.tsx main function kitty protocol', () => {
getExperimentalZedIntegration: () => false, getExperimentalZedIntegration: () => false,
getScreenReader: () => false, getScreenReader: () => false,
getGeminiMdFileCount: () => 0, getGeminiMdFileCount: () => 0,
getPolicyEngine: vi.fn(),
getMessageBus: () => ({
subscribe: vi.fn(),
}),
} as unknown as Config); } as unknown as Config);
vi.mocked(loadSettings).mockReturnValue({ vi.mocked(loadSettings).mockReturnValue({
errors: [], errors: [],
+5
View File
@@ -67,6 +67,7 @@ import {
relaunchOnExitCode, relaunchOnExitCode,
} from './utils/relaunch.js'; } from './utils/relaunch.js';
import { loadSandboxConfig } from './config/sandboxConfig.js'; import { loadSandboxConfig } from './config/sandboxConfig.js';
import { createPolicyUpdater } from './config/policy.js';
import { ExtensionEnablementManager } from './config/extensions/extensionEnablement.js'; import { ExtensionEnablementManager } from './config/extensions/extensionEnablement.js';
export function validateDnsResolutionOrder( export function validateDnsResolutionOrder(
@@ -370,6 +371,10 @@ export async function main() {
argv, argv,
); );
const policyEngine = config.getPolicyEngine();
const messageBus = config.getMessageBus();
createPolicyUpdater(policyEngine, messageBus);
// Cleanup sessions after config initialization // Cleanup sessions after config initialization
await cleanupExpiredSessions(config, settings.merged); await cleanupExpiredSessions(config, settings.merged);
+1 -1
View File
@@ -491,7 +491,7 @@ export class Config {
this.fileExclusions = new FileExclusions(this); this.fileExclusions = new FileExclusions(this);
this.eventEmitter = params.eventEmitter; this.eventEmitter = params.eventEmitter;
this.policyEngine = new PolicyEngine(params.policyEngineConfig); this.policyEngine = new PolicyEngine(params.policyEngineConfig);
this.messageBus = new MessageBus(this.policyEngine); this.messageBus = new MessageBus(this.policyEngine, this.debugMode);
this.outputSettings = { this.outputSettings = {
format: params.output?.format ?? OutputFormat.TEXT, format: params.output?.format ?? OutputFormat.TEXT,
}; };
@@ -11,8 +11,12 @@ import { MessageBusType, type Message } from './types.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js'; import { safeJsonStringify } from '../utils/safeJsonStringify.js';
export class MessageBus extends EventEmitter { export class MessageBus extends EventEmitter {
constructor(private readonly policyEngine: PolicyEngine) { constructor(
private readonly policyEngine: PolicyEngine,
private readonly debug = false,
) {
super(); super();
this.debug = debug;
} }
private isValidMessage(message: Message): boolean { private isValidMessage(message: Message): boolean {
@@ -35,6 +39,9 @@ export class MessageBus extends EventEmitter {
} }
publish(message: Message): void { publish(message: Message): void {
if (this.debug) {
console.debug(`[MESSAGE_BUS] publish: ${safeJsonStringify(message)}`);
}
try { try {
if (!this.isValidMessage(message)) { if (!this.isValidMessage(message)) {
throw new Error( throw new Error(
+8 -1
View File
@@ -12,6 +12,7 @@ export enum MessageBusType {
TOOL_POLICY_REJECTION = 'tool-policy-rejection', TOOL_POLICY_REJECTION = 'tool-policy-rejection',
TOOL_EXECUTION_SUCCESS = 'tool-execution-success', TOOL_EXECUTION_SUCCESS = 'tool-execution-success',
TOOL_EXECUTION_FAILURE = 'tool-execution-failure', TOOL_EXECUTION_FAILURE = 'tool-execution-failure',
UPDATE_POLICY = 'update-policy',
} }
export interface ToolConfirmationRequest { export interface ToolConfirmationRequest {
@@ -31,6 +32,11 @@ export interface ToolConfirmationResponse {
requiresUserConfirmation?: boolean; requiresUserConfirmation?: boolean;
} }
export interface UpdatePolicy {
type: MessageBusType.UPDATE_POLICY;
toolName: string;
}
export interface ToolPolicyRejection { export interface ToolPolicyRejection {
type: MessageBusType.TOOL_POLICY_REJECTION; type: MessageBusType.TOOL_POLICY_REJECTION;
toolCall: FunctionCall; toolCall: FunctionCall;
@@ -53,4 +59,5 @@ export type Message =
| ToolConfirmationResponse | ToolConfirmationResponse
| ToolPolicyRejection | ToolPolicyRejection
| ToolExecutionSuccess | ToolExecutionSuccess
| ToolExecutionFailure; | ToolExecutionFailure
| UpdatePolicy;
+2
View File
@@ -11,6 +11,8 @@ export * from './output/json-formatter.js';
export * from './output/stream-json-formatter.js'; export * from './output/stream-json-formatter.js';
export * from './policy/types.js'; export * from './policy/types.js';
export * from './policy/policy-engine.js'; export * from './policy/policy-engine.js';
export * from './confirmation-bus/types.js';
export * from './confirmation-bus/message-bus.js';
// Export Core Logic // Export Core Logic
export * from './core/client.js'; export * from './core/client.js';
+22 -4
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs'; import fs from 'node:fs';
import path from 'node:path'; import path from 'node:path';
import { glob, escape } from 'glob'; import { glob, escape } from 'glob';
@@ -88,8 +89,11 @@ class GlobToolInvocation extends BaseToolInvocation<
constructor( constructor(
private config: Config, private config: Config,
params: GlobToolParams, params: GlobToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) { ) {
super(params); super(params, messageBus, _toolName, _toolDisplayName);
} }
getDescription(): string { getDescription(): string {
@@ -261,8 +265,10 @@ class GlobToolInvocation extends BaseToolInvocation<
*/ */
export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> { export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
static readonly Name = GLOB_TOOL_NAME; static readonly Name = GLOB_TOOL_NAME;
constructor(
constructor(private config: Config) { private config: Config,
messageBus?: MessageBus,
) {
super( super(
GlobTool.Name, GlobTool.Name,
'FindFiles', 'FindFiles',
@@ -299,6 +305,9 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
required: ['pattern'], required: ['pattern'],
type: 'object', type: 'object',
}, },
true,
false,
messageBus,
); );
} }
@@ -344,7 +353,16 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
protected createInvocation( protected createInvocation(
params: GlobToolParams, params: GlobToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<GlobToolParams, ToolResult> { ): ToolInvocation<GlobToolParams, ToolResult> {
return new GlobToolInvocation(this.config, params); return new GlobToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
} }
} }
+22 -4
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs'; import fs from 'node:fs';
import fsPromises from 'node:fs/promises'; import fsPromises from 'node:fs/promises';
import path from 'node:path'; import path from 'node:path';
@@ -61,8 +62,11 @@ class GrepToolInvocation extends BaseToolInvocation<
constructor( constructor(
private readonly config: Config, private readonly config: Config,
params: GrepToolParams, params: GrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) { ) {
super(params); super(params, messageBus, _toolName, _toolDisplayName);
this.fileExclusions = config.getFileExclusions(); this.fileExclusions = config.getFileExclusions();
} }
@@ -565,8 +569,10 @@ class GrepToolInvocation extends BaseToolInvocation<
*/ */
export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> { export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
static readonly Name = GREP_TOOL_NAME; static readonly Name = GREP_TOOL_NAME;
constructor(
constructor(private readonly config: Config) { private readonly config: Config,
messageBus?: MessageBus,
) {
super( super(
GrepTool.Name, GrepTool.Name,
'SearchText', 'SearchText',
@@ -593,6 +599,9 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
required: ['pattern'], required: ['pattern'],
type: 'object', type: 'object',
}, },
true,
false,
messageBus,
); );
} }
@@ -665,7 +674,16 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
protected createInvocation( protected createInvocation(
params: GrepToolParams, params: GrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<GrepToolParams, ToolResult> { ): ToolInvocation<GrepToolParams, ToolResult> {
return new GrepToolInvocation(this.config, params); return new GrepToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
} }
} }
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs/promises'; import fs from 'node:fs/promises';
import path from 'node:path'; import path from 'node:path';
import type { ToolInvocation, ToolResult } from './tools.js'; import type { ToolInvocation, ToolResult } from './tools.js';
@@ -71,8 +72,11 @@ class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
constructor( constructor(
private readonly config: Config, private readonly config: Config,
params: LSToolParams, params: LSToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) { ) {
super(params); super(params, messageBus, _toolName, _toolDisplayName);
} }
/** /**
@@ -255,7 +259,10 @@ class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> { export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
static readonly Name = LS_TOOL_NAME; static readonly Name = LS_TOOL_NAME;
constructor(private config: Config) { constructor(
private config: Config,
messageBus?: MessageBus,
) {
super( super(
LSTool.Name, LSTool.Name,
'ReadFolder', 'ReadFolder',
@@ -296,6 +303,9 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
required: ['path'], required: ['path'],
type: 'object', type: 'object',
}, },
true,
false,
messageBus,
); );
} }
@@ -323,7 +333,16 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
protected createInvocation( protected createInvocation(
params: LSToolParams, params: LSToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<LSToolParams, ToolResult> { ): ToolInvocation<LSToolParams, ToolResult> {
return new LSToolInvocation(this.config, params); return new LSToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
} }
} }
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import path from 'node:path'; import path from 'node:path';
import { makeRelative, shortenPath } from '../utils/paths.js'; import { makeRelative, shortenPath } from '../utils/paths.js';
import type { ToolInvocation, ToolLocation, ToolResult } from './tools.js'; import type { ToolInvocation, ToolLocation, ToolResult } from './tools.js';
@@ -48,8 +49,11 @@ class ReadFileToolInvocation extends BaseToolInvocation<
constructor( constructor(
private config: Config, private config: Config,
params: ReadFileToolParams, params: ReadFileToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) { ) {
super(params); super(params, messageBus, _toolName, _toolDisplayName);
} }
getDescription(): string { getDescription(): string {
@@ -138,7 +142,10 @@ export class ReadFileTool extends BaseDeclarativeTool<
> { > {
static readonly Name = READ_FILE_TOOL_NAME; static readonly Name = READ_FILE_TOOL_NAME;
constructor(private config: Config) { constructor(
private config: Config,
messageBus?: MessageBus,
) {
super( super(
ReadFileTool.Name, ReadFileTool.Name,
'ReadFile', 'ReadFile',
@@ -165,6 +172,9 @@ export class ReadFileTool extends BaseDeclarativeTool<
required: ['absolute_path'], required: ['absolute_path'],
type: 'object', type: 'object',
}, },
true,
false,
messageBus,
); );
} }
@@ -209,7 +219,16 @@ export class ReadFileTool extends BaseDeclarativeTool<
protected createInvocation( protected createInvocation(
params: ReadFileToolParams, params: ReadFileToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<ReadFileToolParams, ToolResult> { ): ToolInvocation<ReadFileToolParams, ToolResult> {
return new ReadFileToolInvocation(this.config, params); return new ReadFileToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
} }
} }
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { ToolInvocation, ToolResult } from './tools.js'; import type { ToolInvocation, ToolResult } from './tools.js';
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js'; import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { getErrorMessage } from '../utils/errors.js'; import { getErrorMessage } from '../utils/errors.js';
@@ -114,8 +115,11 @@ class ReadManyFilesToolInvocation extends BaseToolInvocation<
constructor( constructor(
private readonly config: Config, private readonly config: Config,
params: ReadManyFilesParams, params: ReadManyFilesParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) { ) {
super(params); super(params, messageBus, _toolName, _toolDisplayName);
} }
getDescription(): string { getDescription(): string {
@@ -477,7 +481,10 @@ export class ReadManyFilesTool extends BaseDeclarativeTool<
> { > {
static readonly Name = READ_MANY_FILES_TOOL_NAME; static readonly Name = READ_MANY_FILES_TOOL_NAME;
constructor(private config: Config) { constructor(
private config: Config,
messageBus?: MessageBus,
) {
const parameterSchema = { const parameterSchema = {
type: 'object', type: 'object',
properties: { properties: {
@@ -559,12 +566,24 @@ This tool is useful when you need to understand or analyze a collection of files
Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. The tool inserts a '--- End of content ---' after the last file. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`, Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. The tool inserts a '--- End of content ---' after the last file. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`,
Kind.Read, Kind.Read,
parameterSchema, parameterSchema,
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
); );
} }
protected createInvocation( protected createInvocation(
params: ReadManyFilesParams, params: ReadManyFilesParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<ReadManyFilesParams, ToolResult> { ): ToolInvocation<ReadManyFilesParams, ToolResult> {
return new ReadManyFilesToolInvocation(this.config, params); return new ReadManyFilesToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
} }
} }
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs'; import fs from 'node:fs';
import path from 'node:path'; import path from 'node:path';
import { EOL } from 'node:os'; import { EOL } from 'node:os';
@@ -110,8 +111,11 @@ class GrepToolInvocation extends BaseToolInvocation<
constructor( constructor(
private readonly config: Config, private readonly config: Config,
params: RipGrepToolParams, params: RipGrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) { ) {
super(params); super(params, messageBus, _toolName, _toolDisplayName);
} }
/** /**
@@ -449,7 +453,10 @@ export class RipGrepTool extends BaseDeclarativeTool<
> { > {
static readonly Name = GREP_TOOL_NAME; static readonly Name = GREP_TOOL_NAME;
constructor(private readonly config: Config) { constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super( super(
RipGrepTool.Name, RipGrepTool.Name,
'SearchText', 'SearchText',
@@ -476,6 +483,9 @@ export class RipGrepTool extends BaseDeclarativeTool<
required: ['pattern'], required: ['pattern'],
type: 'object', type: 'object',
}, },
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
); );
} }
@@ -548,7 +558,16 @@ export class RipGrepTool extends BaseDeclarativeTool<
protected createInvocation( protected createInvocation(
params: RipGrepToolParams, params: RipGrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<RipGrepToolParams, ToolResult> { ): ToolInvocation<RipGrepToolParams, ToolResult> {
return new GrepToolInvocation(this.config, params); return new GrepToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
} }
} }
+48 -13
View File
@@ -76,13 +76,9 @@ export abstract class BaseToolInvocation<
constructor( constructor(
readonly params: TParams, readonly params: TParams,
protected readonly messageBus?: MessageBus, protected readonly messageBus?: MessageBus,
) { readonly _toolName?: string,
if (this.messageBus) { readonly _toolDisplayName?: string,
console.debug( ) {}
`[DEBUG] Tool ${this.constructor.name} created with messageBus: YES`,
);
}
}
abstract getDescription(): string; abstract getDescription(): string;
@@ -90,11 +86,43 @@ export abstract class BaseToolInvocation<
return []; return [];
} }
shouldConfirmExecute( async shouldConfirmExecute(
_abortSignal: AbortSignal, abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> { ): Promise<ToolCallConfirmationDetails | false> {
// Default implementation for tools that don't override it. if (this.messageBus) {
return Promise.resolve(false); const decision = await this.getMessageBusDecision(abortSignal);
if (decision === 'ALLOW') {
return false;
}
if (decision === 'DENY') {
throw new Error(
`Tool execution for "${
this._toolDisplayName || this._toolName
}" denied by policy.`,
);
}
if (decision === 'ASK_USER') {
const confirmationDetails: ToolCallConfirmationDetails = {
type: 'info',
title: `Confirm: ${this._toolDisplayName || this._toolName}`,
prompt: this.getDescription(),
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
if (this.messageBus && this._toolName) {
this.messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: this._toolName,
});
}
}
},
};
return confirmationDetails;
}
}
return false;
} }
protected getMessageBusDecision( protected getMessageBusDecision(
@@ -108,7 +136,7 @@ export abstract class BaseToolInvocation<
const correlationId = randomUUID(); const correlationId = randomUUID();
const toolCall = { const toolCall = {
name: this.constructor.name, name: this._toolName || this.constructor.name,
args: this.params as Record<string, unknown>, args: this.params as Record<string, unknown>,
}; };
@@ -385,7 +413,12 @@ export abstract class BaseDeclarativeTool<
if (validationError) { if (validationError) {
throw new Error(validationError); throw new Error(validationError);
} }
return this.createInvocation(params, this.messageBus); return this.createInvocation(
params,
this.messageBus,
this.name,
this.displayName,
);
} }
override validateToolParams(params: TParams): string | null { override validateToolParams(params: TParams): string | null {
@@ -408,6 +441,8 @@ export abstract class BaseDeclarativeTool<
protected abstract createInvocation( protected abstract createInvocation(
params: TParams, params: TParams,
messageBus?: MessageBus, messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<TParams, TResult>; ): ToolInvocation<TParams, TResult>;
} }
+1 -1
View File
@@ -471,7 +471,7 @@ describe('WebFetchTool', () => {
expect(publishSpy).toHaveBeenCalledWith({ expect(publishSpy).toHaveBeenCalledWith({
type: MessageBusType.TOOL_CONFIRMATION_REQUEST, type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: { toolCall: {
name: 'WebFetchToolInvocation', name: 'web_fetch',
args: { prompt: 'fetch https://example.com' }, args: { prompt: 'fetch https://example.com' },
}, },
correlationId: 'test-correlation-id', correlationId: 'test-correlation-id',
+12 -2
View File
@@ -110,8 +110,10 @@ class WebFetchToolInvocation extends BaseToolInvocation<
private readonly config: Config, private readonly config: Config,
params: WebFetchToolParams, params: WebFetchToolParams,
messageBus?: MessageBus, messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) { ) {
super(params, messageBus); super(params, messageBus, _toolName, _toolDisplayName);
} }
private async executeFallback(signal: AbortSignal): Promise<ToolResult> { private async executeFallback(signal: AbortSignal): Promise<ToolResult> {
@@ -450,7 +452,15 @@ export class WebFetchTool extends BaseDeclarativeTool<
protected createInvocation( protected createInvocation(
params: WebFetchToolParams, params: WebFetchToolParams,
messageBus?: MessageBus, messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<WebFetchToolParams, ToolResult> { ): ToolInvocation<WebFetchToolParams, ToolResult> {
return new WebFetchToolInvocation(this.config, params, messageBus); return new WebFetchToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
} }
} }
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { WEB_SEARCH_TOOL_NAME } from './tool-names.js'; import { WEB_SEARCH_TOOL_NAME } from './tool-names.js';
import type { GroundingMetadata } from '@google/genai'; import type { GroundingMetadata } from '@google/genai';
import type { ToolInvocation, ToolResult } from './tools.js'; import type { ToolInvocation, ToolResult } from './tools.js';
@@ -64,8 +65,11 @@ class WebSearchToolInvocation extends BaseToolInvocation<
constructor( constructor(
private readonly config: Config, private readonly config: Config,
params: WebSearchToolParams, params: WebSearchToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) { ) {
super(params); super(params, messageBus, _toolName, _toolDisplayName);
} }
override getDescription(): string { override getDescription(): string {
@@ -187,7 +191,10 @@ export class WebSearchTool extends BaseDeclarativeTool<
> { > {
static readonly Name = WEB_SEARCH_TOOL_NAME; static readonly Name = WEB_SEARCH_TOOL_NAME;
constructor(private readonly config: Config) { constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super( super(
WebSearchTool.Name, WebSearchTool.Name,
'GoogleSearch', 'GoogleSearch',
@@ -203,6 +210,9 @@ export class WebSearchTool extends BaseDeclarativeTool<
}, },
required: ['query'], required: ['query'],
}, },
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
); );
} }
@@ -222,7 +232,16 @@ export class WebSearchTool extends BaseDeclarativeTool<
protected createInvocation( protected createInvocation(
params: WebSearchToolParams, params: WebSearchToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<WebSearchToolParams, WebSearchToolResult> { ): ToolInvocation<WebSearchToolParams, WebSearchToolResult> {
return new WebSearchToolInvocation(this.config, params); return new WebSearchToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
} }
} }