mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-25 04:24:51 -07:00
feat: Implement message bus and policy engine (#11523)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -14,16 +14,70 @@ import {
|
||||
} from '@google/gemini-cli-core';
|
||||
|
||||
describe('createPolicyEngineConfig', () => {
|
||||
it('should return ASK_USER for all tools by default', () => {
|
||||
it('should return ASK_USER for write tools and ALLOW for read-only tools by default', () => {
|
||||
const settings: Settings = {};
|
||||
const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT);
|
||||
expect(config.defaultDecision).toBe(PolicyDecision.ASK_USER);
|
||||
// The order of the rules is not guaranteed, so we sort them by tool name.
|
||||
config.rules?.sort((a, b) =>
|
||||
(a.toolName ?? '').localeCompare(b.toolName ?? ''),
|
||||
);
|
||||
expect(config.rules).toEqual([
|
||||
{ toolName: 'replace', decision: 'ask_user', priority: 10 },
|
||||
{ toolName: 'save_memory', decision: 'ask_user', priority: 10 },
|
||||
{ toolName: 'run_shell_command', decision: 'ask_user', priority: 10 },
|
||||
{ toolName: 'write_file', decision: 'ask_user', priority: 10 },
|
||||
{ toolName: WEB_FETCH_TOOL_NAME, decision: 'ask_user', priority: 10 },
|
||||
{
|
||||
toolName: 'glob',
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 50,
|
||||
},
|
||||
{
|
||||
toolName: 'google_web_search',
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 50,
|
||||
},
|
||||
{
|
||||
toolName: 'list_directory',
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 50,
|
||||
},
|
||||
{
|
||||
toolName: 'read_file',
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 50,
|
||||
},
|
||||
{
|
||||
toolName: 'read_many_files',
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 50,
|
||||
},
|
||||
{
|
||||
toolName: 'replace',
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
priority: 10,
|
||||
},
|
||||
{
|
||||
toolName: 'run_shell_command',
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
priority: 10,
|
||||
},
|
||||
{
|
||||
toolName: 'save_memory',
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
priority: 10,
|
||||
},
|
||||
{
|
||||
toolName: 'search_file_content',
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 50,
|
||||
},
|
||||
{
|
||||
toolName: 'web_fetch',
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
priority: 10,
|
||||
},
|
||||
{
|
||||
toolName: 'write_file',
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
priority: 10,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
@@ -159,18 +213,6 @@ describe('createPolicyEngineConfig', () => {
|
||||
expect(excludedRule?.priority).toBe(195);
|
||||
});
|
||||
|
||||
it('should allow read-only tools if autoAccept is true', () => {
|
||||
const settings: Settings = {
|
||||
tools: { autoAccept: true },
|
||||
};
|
||||
const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT);
|
||||
const rule = config.rules?.find(
|
||||
(r) => r.toolName === 'glob' && r.decision === PolicyDecision.ALLOW,
|
||||
);
|
||||
expect(rule).toBeDefined();
|
||||
expect(rule?.priority).toBe(50);
|
||||
});
|
||||
|
||||
it('should allow all tools in YOLO mode', () => {
|
||||
const settings: Settings = {};
|
||||
const config = createPolicyEngineConfig(settings, ApprovalMode.YOLO);
|
||||
@@ -419,29 +461,6 @@ describe('createPolicyEngineConfig', () => {
|
||||
// Exclude (195) should win over trust (90) when evaluated
|
||||
});
|
||||
|
||||
it('should create all read-only tool rules when autoAccept is enabled', () => {
|
||||
const settings: Settings = {
|
||||
tools: { autoAccept: true },
|
||||
};
|
||||
const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT);
|
||||
|
||||
// All read-only tools should have allow rules
|
||||
const readOnlyTools = [
|
||||
'glob',
|
||||
'search_file_content',
|
||||
'list_directory',
|
||||
'read_file',
|
||||
'read_many_files',
|
||||
];
|
||||
for (const tool of readOnlyTools) {
|
||||
const rule = config.rules?.find(
|
||||
(r) => r.toolName === tool && r.decision === PolicyDecision.ALLOW,
|
||||
);
|
||||
expect(rule).toBeDefined();
|
||||
expect(rule?.priority).toBe(50);
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle all approval modes correctly', () => {
|
||||
const settings: Settings = {};
|
||||
|
||||
|
||||
@@ -22,8 +22,12 @@ import {
|
||||
EDIT_TOOL_NAME,
|
||||
MEMORY_TOOL_NAME,
|
||||
WEB_SEARCH_TOOL_NAME,
|
||||
type PolicyEngine,
|
||||
type MessageBus,
|
||||
MessageBusType,
|
||||
type UpdatePolicy,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { Settings } from './settings.js';
|
||||
import { type Settings } from './settings.js';
|
||||
|
||||
// READ_ONLY_TOOLS is a list of built-in tools that do not modify the user's
|
||||
// files or system state.
|
||||
@@ -69,6 +73,7 @@ export function createPolicyEngineConfig(
|
||||
// 90: MCP servers with trust=true
|
||||
// 100: Explicitly allowed individual tools
|
||||
// 195: Explicitly excluded MCP servers
|
||||
// 199: Tools that the user has selected as "Always Allow" in the interactive UI.
|
||||
// 200: Explicitly excluded individual tools (highest priority)
|
||||
|
||||
// MCP servers that are explicitly allowed in settings.mcp.allowed
|
||||
@@ -137,16 +142,14 @@ export function createPolicyEngineConfig(
|
||||
}
|
||||
}
|
||||
|
||||
// If auto-accept is enabled, allow all read-only tools.
|
||||
// Allow all read-only tools.
|
||||
// Priority: 50
|
||||
if (settings.tools?.autoAccept) {
|
||||
for (const tool of READ_ONLY_TOOLS) {
|
||||
rules.push({
|
||||
toolName: tool,
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 50,
|
||||
});
|
||||
}
|
||||
for (const tool of READ_ONLY_TOOLS) {
|
||||
rules.push({
|
||||
toolName: tool,
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 50,
|
||||
});
|
||||
}
|
||||
|
||||
// Only add write tool rules if not in YOLO mode
|
||||
@@ -179,3 +182,21 @@ export function createPolicyEngineConfig(
|
||||
defaultDecision: PolicyDecision.ASK_USER,
|
||||
};
|
||||
}
|
||||
|
||||
export function createPolicyUpdater(
|
||||
policyEngine: PolicyEngine,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
messageBus.subscribe(
|
||||
MessageBusType.UPDATE_POLICY,
|
||||
(message: UpdatePolicy) => {
|
||||
const toolName = message.toolName;
|
||||
|
||||
policyEngine.addRule({
|
||||
toolName,
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 199, // High priority, but lower than explicit DENY (200)
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -170,6 +170,10 @@ describe('gemini.tsx main function', () => {
|
||||
getScreenReader: () => false,
|
||||
getGeminiMdFileCount: () => 0,
|
||||
getProjectRoot: () => '/',
|
||||
getPolicyEngine: vi.fn(),
|
||||
getMessageBus: () => ({
|
||||
subscribe: vi.fn(),
|
||||
}),
|
||||
} as unknown as Config;
|
||||
});
|
||||
vi.mocked(loadSettings).mockReturnValue({
|
||||
@@ -301,6 +305,10 @@ describe('gemini.tsx main function kitty protocol', () => {
|
||||
getExperimentalZedIntegration: () => false,
|
||||
getScreenReader: () => false,
|
||||
getGeminiMdFileCount: () => 0,
|
||||
getPolicyEngine: vi.fn(),
|
||||
getMessageBus: () => ({
|
||||
subscribe: vi.fn(),
|
||||
}),
|
||||
} as unknown as Config);
|
||||
vi.mocked(loadSettings).mockReturnValue({
|
||||
errors: [],
|
||||
|
||||
@@ -67,6 +67,7 @@ import {
|
||||
relaunchOnExitCode,
|
||||
} from './utils/relaunch.js';
|
||||
import { loadSandboxConfig } from './config/sandboxConfig.js';
|
||||
import { createPolicyUpdater } from './config/policy.js';
|
||||
import { ExtensionEnablementManager } from './config/extensions/extensionEnablement.js';
|
||||
|
||||
export function validateDnsResolutionOrder(
|
||||
@@ -370,6 +371,10 @@ export async function main() {
|
||||
argv,
|
||||
);
|
||||
|
||||
const policyEngine = config.getPolicyEngine();
|
||||
const messageBus = config.getMessageBus();
|
||||
createPolicyUpdater(policyEngine, messageBus);
|
||||
|
||||
// Cleanup sessions after config initialization
|
||||
await cleanupExpiredSessions(config, settings.merged);
|
||||
|
||||
|
||||
@@ -491,7 +491,7 @@ export class Config {
|
||||
this.fileExclusions = new FileExclusions(this);
|
||||
this.eventEmitter = params.eventEmitter;
|
||||
this.policyEngine = new PolicyEngine(params.policyEngineConfig);
|
||||
this.messageBus = new MessageBus(this.policyEngine);
|
||||
this.messageBus = new MessageBus(this.policyEngine, this.debugMode);
|
||||
this.outputSettings = {
|
||||
format: params.output?.format ?? OutputFormat.TEXT,
|
||||
};
|
||||
|
||||
@@ -11,8 +11,12 @@ import { MessageBusType, type Message } from './types.js';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
|
||||
export class MessageBus extends EventEmitter {
|
||||
constructor(private readonly policyEngine: PolicyEngine) {
|
||||
constructor(
|
||||
private readonly policyEngine: PolicyEngine,
|
||||
private readonly debug = false,
|
||||
) {
|
||||
super();
|
||||
this.debug = debug;
|
||||
}
|
||||
|
||||
private isValidMessage(message: Message): boolean {
|
||||
@@ -35,6 +39,9 @@ export class MessageBus extends EventEmitter {
|
||||
}
|
||||
|
||||
publish(message: Message): void {
|
||||
if (this.debug) {
|
||||
console.debug(`[MESSAGE_BUS] publish: ${safeJsonStringify(message)}`);
|
||||
}
|
||||
try {
|
||||
if (!this.isValidMessage(message)) {
|
||||
throw new Error(
|
||||
|
||||
@@ -12,6 +12,7 @@ export enum MessageBusType {
|
||||
TOOL_POLICY_REJECTION = 'tool-policy-rejection',
|
||||
TOOL_EXECUTION_SUCCESS = 'tool-execution-success',
|
||||
TOOL_EXECUTION_FAILURE = 'tool-execution-failure',
|
||||
UPDATE_POLICY = 'update-policy',
|
||||
}
|
||||
|
||||
export interface ToolConfirmationRequest {
|
||||
@@ -31,6 +32,11 @@ export interface ToolConfirmationResponse {
|
||||
requiresUserConfirmation?: boolean;
|
||||
}
|
||||
|
||||
export interface UpdatePolicy {
|
||||
type: MessageBusType.UPDATE_POLICY;
|
||||
toolName: string;
|
||||
}
|
||||
|
||||
export interface ToolPolicyRejection {
|
||||
type: MessageBusType.TOOL_POLICY_REJECTION;
|
||||
toolCall: FunctionCall;
|
||||
@@ -53,4 +59,5 @@ export type Message =
|
||||
| ToolConfirmationResponse
|
||||
| ToolPolicyRejection
|
||||
| ToolExecutionSuccess
|
||||
| ToolExecutionFailure;
|
||||
| ToolExecutionFailure
|
||||
| UpdatePolicy;
|
||||
|
||||
@@ -11,6 +11,8 @@ export * from './output/json-formatter.js';
|
||||
export * from './output/stream-json-formatter.js';
|
||||
export * from './policy/types.js';
|
||||
export * from './policy/policy-engine.js';
|
||||
export * from './confirmation-bus/types.js';
|
||||
export * from './confirmation-bus/message-bus.js';
|
||||
|
||||
// Export Core Logic
|
||||
export * from './core/client.js';
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import fs from 'node:fs';
|
||||
import path from 'node:path';
|
||||
import { glob, escape } from 'glob';
|
||||
@@ -88,8 +89,11 @@ class GlobToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private config: Config,
|
||||
params: GlobToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
super(params);
|
||||
super(params, messageBus, _toolName, _toolDisplayName);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
@@ -261,8 +265,10 @@ class GlobToolInvocation extends BaseToolInvocation<
|
||||
*/
|
||||
export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
|
||||
static readonly Name = GLOB_TOOL_NAME;
|
||||
|
||||
constructor(private config: Config) {
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
GlobTool.Name,
|
||||
'FindFiles',
|
||||
@@ -299,6 +305,9 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
|
||||
required: ['pattern'],
|
||||
type: 'object',
|
||||
},
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -344,7 +353,16 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
|
||||
|
||||
protected createInvocation(
|
||||
params: GlobToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<GlobToolParams, ToolResult> {
|
||||
return new GlobToolInvocation(this.config, params);
|
||||
return new GlobToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import fs from 'node:fs';
|
||||
import fsPromises from 'node:fs/promises';
|
||||
import path from 'node:path';
|
||||
@@ -61,8 +62,11 @@ class GrepToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: GrepToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
super(params);
|
||||
super(params, messageBus, _toolName, _toolDisplayName);
|
||||
this.fileExclusions = config.getFileExclusions();
|
||||
}
|
||||
|
||||
@@ -565,8 +569,10 @@ class GrepToolInvocation extends BaseToolInvocation<
|
||||
*/
|
||||
export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
|
||||
static readonly Name = GREP_TOOL_NAME;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
GrepTool.Name,
|
||||
'SearchText',
|
||||
@@ -593,6 +599,9 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
|
||||
required: ['pattern'],
|
||||
type: 'object',
|
||||
},
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -665,7 +674,16 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
|
||||
|
||||
protected createInvocation(
|
||||
params: GrepToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<GrepToolParams, ToolResult> {
|
||||
return new GrepToolInvocation(this.config, params);
|
||||
return new GrepToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import fs from 'node:fs/promises';
|
||||
import path from 'node:path';
|
||||
import type { ToolInvocation, ToolResult } from './tools.js';
|
||||
@@ -71,8 +72,11 @@ class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: LSToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
super(params);
|
||||
super(params, messageBus, _toolName, _toolDisplayName);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -255,7 +259,10 @@ class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
|
||||
export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
|
||||
static readonly Name = LS_TOOL_NAME;
|
||||
|
||||
constructor(private config: Config) {
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
LSTool.Name,
|
||||
'ReadFolder',
|
||||
@@ -296,6 +303,9 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
|
||||
required: ['path'],
|
||||
type: 'object',
|
||||
},
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -323,7 +333,16 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
|
||||
|
||||
protected createInvocation(
|
||||
params: LSToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<LSToolParams, ToolResult> {
|
||||
return new LSToolInvocation(this.config, params);
|
||||
return new LSToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import path from 'node:path';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import type { ToolInvocation, ToolLocation, ToolResult } from './tools.js';
|
||||
@@ -48,8 +49,11 @@ class ReadFileToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private config: Config,
|
||||
params: ReadFileToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
super(params);
|
||||
super(params, messageBus, _toolName, _toolDisplayName);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
@@ -138,7 +142,10 @@ export class ReadFileTool extends BaseDeclarativeTool<
|
||||
> {
|
||||
static readonly Name = READ_FILE_TOOL_NAME;
|
||||
|
||||
constructor(private config: Config) {
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
ReadFileTool.Name,
|
||||
'ReadFile',
|
||||
@@ -165,6 +172,9 @@ export class ReadFileTool extends BaseDeclarativeTool<
|
||||
required: ['absolute_path'],
|
||||
type: 'object',
|
||||
},
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -209,7 +219,16 @@ export class ReadFileTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: ReadFileToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<ReadFileToolParams, ToolResult> {
|
||||
return new ReadFileToolInvocation(this.config, params);
|
||||
return new ReadFileToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import type { ToolInvocation, ToolResult } from './tools.js';
|
||||
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
@@ -114,8 +115,11 @@ class ReadManyFilesToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: ReadManyFilesParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
super(params);
|
||||
super(params, messageBus, _toolName, _toolDisplayName);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
@@ -477,7 +481,10 @@ export class ReadManyFilesTool extends BaseDeclarativeTool<
|
||||
> {
|
||||
static readonly Name = READ_MANY_FILES_TOOL_NAME;
|
||||
|
||||
constructor(private config: Config) {
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
const parameterSchema = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
@@ -559,12 +566,24 @@ This tool is useful when you need to understand or analyze a collection of files
|
||||
Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. The tool inserts a '--- End of content ---' after the last file. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`,
|
||||
Kind.Read,
|
||||
parameterSchema,
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: ReadManyFilesParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<ReadManyFilesParams, ToolResult> {
|
||||
return new ReadManyFilesToolInvocation(this.config, params);
|
||||
return new ReadManyFilesToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import fs from 'node:fs';
|
||||
import path from 'node:path';
|
||||
import { EOL } from 'node:os';
|
||||
@@ -110,8 +111,11 @@ class GrepToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: RipGrepToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
super(params);
|
||||
super(params, messageBus, _toolName, _toolDisplayName);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -449,7 +453,10 @@ export class RipGrepTool extends BaseDeclarativeTool<
|
||||
> {
|
||||
static readonly Name = GREP_TOOL_NAME;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
RipGrepTool.Name,
|
||||
'SearchText',
|
||||
@@ -476,6 +483,9 @@ export class RipGrepTool extends BaseDeclarativeTool<
|
||||
required: ['pattern'],
|
||||
type: 'object',
|
||||
},
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -548,7 +558,16 @@ export class RipGrepTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: RipGrepToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<RipGrepToolParams, ToolResult> {
|
||||
return new GrepToolInvocation(this.config, params);
|
||||
return new GrepToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,13 +76,9 @@ export abstract class BaseToolInvocation<
|
||||
constructor(
|
||||
readonly params: TParams,
|
||||
protected readonly messageBus?: MessageBus,
|
||||
) {
|
||||
if (this.messageBus) {
|
||||
console.debug(
|
||||
`[DEBUG] Tool ${this.constructor.name} created with messageBus: YES`,
|
||||
);
|
||||
}
|
||||
}
|
||||
readonly _toolName?: string,
|
||||
readonly _toolDisplayName?: string,
|
||||
) {}
|
||||
|
||||
abstract getDescription(): string;
|
||||
|
||||
@@ -90,11 +86,43 @@ export abstract class BaseToolInvocation<
|
||||
return [];
|
||||
}
|
||||
|
||||
shouldConfirmExecute(
|
||||
_abortSignal: AbortSignal,
|
||||
async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// Default implementation for tools that don't override it.
|
||||
return Promise.resolve(false);
|
||||
if (this.messageBus) {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decision === 'DENY') {
|
||||
throw new Error(
|
||||
`Tool execution for "${
|
||||
this._toolDisplayName || this._toolName
|
||||
}" denied by policy.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (decision === 'ASK_USER') {
|
||||
const confirmationDetails: ToolCallConfirmationDetails = {
|
||||
type: 'info',
|
||||
title: `Confirm: ${this._toolDisplayName || this._toolName}`,
|
||||
prompt: this.getDescription(),
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
if (this.messageBus && this._toolName) {
|
||||
this.messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName: this._toolName,
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
protected getMessageBusDecision(
|
||||
@@ -108,7 +136,7 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
const correlationId = randomUUID();
|
||||
const toolCall = {
|
||||
name: this.constructor.name,
|
||||
name: this._toolName || this.constructor.name,
|
||||
args: this.params as Record<string, unknown>,
|
||||
};
|
||||
|
||||
@@ -385,7 +413,12 @@ export abstract class BaseDeclarativeTool<
|
||||
if (validationError) {
|
||||
throw new Error(validationError);
|
||||
}
|
||||
return this.createInvocation(params, this.messageBus);
|
||||
return this.createInvocation(
|
||||
params,
|
||||
this.messageBus,
|
||||
this.name,
|
||||
this.displayName,
|
||||
);
|
||||
}
|
||||
|
||||
override validateToolParams(params: TParams): string | null {
|
||||
@@ -408,6 +441,8 @@ export abstract class BaseDeclarativeTool<
|
||||
protected abstract createInvocation(
|
||||
params: TParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<TParams, TResult>;
|
||||
}
|
||||
|
||||
|
||||
@@ -471,7 +471,7 @@ describe('WebFetchTool', () => {
|
||||
expect(publishSpy).toHaveBeenCalledWith({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
|
||||
toolCall: {
|
||||
name: 'WebFetchToolInvocation',
|
||||
name: 'web_fetch',
|
||||
args: { prompt: 'fetch https://example.com' },
|
||||
},
|
||||
correlationId: 'test-correlation-id',
|
||||
|
||||
@@ -110,8 +110,10 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
private readonly config: Config,
|
||||
params: WebFetchToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
super(params, messageBus);
|
||||
super(params, messageBus, _toolName, _toolDisplayName);
|
||||
}
|
||||
|
||||
private async executeFallback(signal: AbortSignal): Promise<ToolResult> {
|
||||
@@ -450,7 +452,15 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
protected createInvocation(
|
||||
params: WebFetchToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<WebFetchToolParams, ToolResult> {
|
||||
return new WebFetchToolInvocation(this.config, params, messageBus);
|
||||
return new WebFetchToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { WEB_SEARCH_TOOL_NAME } from './tool-names.js';
|
||||
import type { GroundingMetadata } from '@google/genai';
|
||||
import type { ToolInvocation, ToolResult } from './tools.js';
|
||||
@@ -64,8 +65,11 @@ class WebSearchToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: WebSearchToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
super(params);
|
||||
super(params, messageBus, _toolName, _toolDisplayName);
|
||||
}
|
||||
|
||||
override getDescription(): string {
|
||||
@@ -187,7 +191,10 @@ export class WebSearchTool extends BaseDeclarativeTool<
|
||||
> {
|
||||
static readonly Name = WEB_SEARCH_TOOL_NAME;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
WebSearchTool.Name,
|
||||
'GoogleSearch',
|
||||
@@ -203,6 +210,9 @@ export class WebSearchTool extends BaseDeclarativeTool<
|
||||
},
|
||||
required: ['query'],
|
||||
},
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -222,7 +232,16 @@ export class WebSearchTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: WebSearchToolParams,
|
||||
messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<WebSearchToolParams, WebSearchToolResult> {
|
||||
return new WebSearchToolInvocation(this.config, params);
|
||||
return new WebSearchToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user