feat: Implement message bus and policy engine (#11523)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Allen Hutchison
2025-10-21 11:45:33 -07:00
committed by GitHub
parent 0658b4aa31
commit bf80263bd6
19 changed files with 339 additions and 94 deletions
+1 -1
View File
@@ -491,7 +491,7 @@ export class Config {
this.fileExclusions = new FileExclusions(this);
this.eventEmitter = params.eventEmitter;
this.policyEngine = new PolicyEngine(params.policyEngineConfig);
this.messageBus = new MessageBus(this.policyEngine);
this.messageBus = new MessageBus(this.policyEngine, this.debugMode);
this.outputSettings = {
format: params.output?.format ?? OutputFormat.TEXT,
};
@@ -11,8 +11,12 @@ import { MessageBusType, type Message } from './types.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
export class MessageBus extends EventEmitter {
constructor(private readonly policyEngine: PolicyEngine) {
constructor(
private readonly policyEngine: PolicyEngine,
private readonly debug = false,
) {
super();
this.debug = debug;
}
private isValidMessage(message: Message): boolean {
@@ -35,6 +39,9 @@ export class MessageBus extends EventEmitter {
}
publish(message: Message): void {
if (this.debug) {
console.debug(`[MESSAGE_BUS] publish: ${safeJsonStringify(message)}`);
}
try {
if (!this.isValidMessage(message)) {
throw new Error(
+8 -1
View File
@@ -12,6 +12,7 @@ export enum MessageBusType {
TOOL_POLICY_REJECTION = 'tool-policy-rejection',
TOOL_EXECUTION_SUCCESS = 'tool-execution-success',
TOOL_EXECUTION_FAILURE = 'tool-execution-failure',
UPDATE_POLICY = 'update-policy',
}
export interface ToolConfirmationRequest {
@@ -31,6 +32,11 @@ export interface ToolConfirmationResponse {
requiresUserConfirmation?: boolean;
}
export interface UpdatePolicy {
type: MessageBusType.UPDATE_POLICY;
toolName: string;
}
export interface ToolPolicyRejection {
type: MessageBusType.TOOL_POLICY_REJECTION;
toolCall: FunctionCall;
@@ -53,4 +59,5 @@ export type Message =
| ToolConfirmationResponse
| ToolPolicyRejection
| ToolExecutionSuccess
| ToolExecutionFailure;
| ToolExecutionFailure
| UpdatePolicy;
+2
View File
@@ -11,6 +11,8 @@ export * from './output/json-formatter.js';
export * from './output/stream-json-formatter.js';
export * from './policy/types.js';
export * from './policy/policy-engine.js';
export * from './confirmation-bus/types.js';
export * from './confirmation-bus/message-bus.js';
// Export Core Logic
export * from './core/client.js';
+22 -4
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs';
import path from 'node:path';
import { glob, escape } from 'glob';
@@ -88,8 +89,11 @@ class GlobToolInvocation extends BaseToolInvocation<
constructor(
private config: Config,
params: GlobToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
getDescription(): string {
@@ -261,8 +265,10 @@ class GlobToolInvocation extends BaseToolInvocation<
*/
export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
static readonly Name = GLOB_TOOL_NAME;
constructor(private config: Config) {
constructor(
private config: Config,
messageBus?: MessageBus,
) {
super(
GlobTool.Name,
'FindFiles',
@@ -299,6 +305,9 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
required: ['pattern'],
type: 'object',
},
true,
false,
messageBus,
);
}
@@ -344,7 +353,16 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
protected createInvocation(
params: GlobToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<GlobToolParams, ToolResult> {
return new GlobToolInvocation(this.config, params);
return new GlobToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -4
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs';
import fsPromises from 'node:fs/promises';
import path from 'node:path';
@@ -61,8 +62,11 @@ class GrepToolInvocation extends BaseToolInvocation<
constructor(
private readonly config: Config,
params: GrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
this.fileExclusions = config.getFileExclusions();
}
@@ -565,8 +569,10 @@ class GrepToolInvocation extends BaseToolInvocation<
*/
export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
static readonly Name = GREP_TOOL_NAME;
constructor(private readonly config: Config) {
constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super(
GrepTool.Name,
'SearchText',
@@ -593,6 +599,9 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
required: ['pattern'],
type: 'object',
},
true,
false,
messageBus,
);
}
@@ -665,7 +674,16 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
protected createInvocation(
params: GrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<GrepToolParams, ToolResult> {
return new GrepToolInvocation(this.config, params);
return new GrepToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs/promises';
import path from 'node:path';
import type { ToolInvocation, ToolResult } from './tools.js';
@@ -71,8 +72,11 @@ class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
constructor(
private readonly config: Config,
params: LSToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
/**
@@ -255,7 +259,10 @@ class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
static readonly Name = LS_TOOL_NAME;
constructor(private config: Config) {
constructor(
private config: Config,
messageBus?: MessageBus,
) {
super(
LSTool.Name,
'ReadFolder',
@@ -296,6 +303,9 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
required: ['path'],
type: 'object',
},
true,
false,
messageBus,
);
}
@@ -323,7 +333,16 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
protected createInvocation(
params: LSToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<LSToolParams, ToolResult> {
return new LSToolInvocation(this.config, params);
return new LSToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import path from 'node:path';
import { makeRelative, shortenPath } from '../utils/paths.js';
import type { ToolInvocation, ToolLocation, ToolResult } from './tools.js';
@@ -48,8 +49,11 @@ class ReadFileToolInvocation extends BaseToolInvocation<
constructor(
private config: Config,
params: ReadFileToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
getDescription(): string {
@@ -138,7 +142,10 @@ export class ReadFileTool extends BaseDeclarativeTool<
> {
static readonly Name = READ_FILE_TOOL_NAME;
constructor(private config: Config) {
constructor(
private config: Config,
messageBus?: MessageBus,
) {
super(
ReadFileTool.Name,
'ReadFile',
@@ -165,6 +172,9 @@ export class ReadFileTool extends BaseDeclarativeTool<
required: ['absolute_path'],
type: 'object',
},
true,
false,
messageBus,
);
}
@@ -209,7 +219,16 @@ export class ReadFileTool extends BaseDeclarativeTool<
protected createInvocation(
params: ReadFileToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<ReadFileToolParams, ToolResult> {
return new ReadFileToolInvocation(this.config, params);
return new ReadFileToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { ToolInvocation, ToolResult } from './tools.js';
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { getErrorMessage } from '../utils/errors.js';
@@ -114,8 +115,11 @@ class ReadManyFilesToolInvocation extends BaseToolInvocation<
constructor(
private readonly config: Config,
params: ReadManyFilesParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
getDescription(): string {
@@ -477,7 +481,10 @@ export class ReadManyFilesTool extends BaseDeclarativeTool<
> {
static readonly Name = READ_MANY_FILES_TOOL_NAME;
constructor(private config: Config) {
constructor(
private config: Config,
messageBus?: MessageBus,
) {
const parameterSchema = {
type: 'object',
properties: {
@@ -559,12 +566,24 @@ This tool is useful when you need to understand or analyze a collection of files
Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. The tool inserts a '--- End of content ---' after the last file. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`,
Kind.Read,
parameterSchema,
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
);
}
protected createInvocation(
params: ReadManyFilesParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<ReadManyFilesParams, ToolResult> {
return new ReadManyFilesToolInvocation(this.config, params);
return new ReadManyFilesToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs';
import path from 'node:path';
import { EOL } from 'node:os';
@@ -110,8 +111,11 @@ class GrepToolInvocation extends BaseToolInvocation<
constructor(
private readonly config: Config,
params: RipGrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
/**
@@ -449,7 +453,10 @@ export class RipGrepTool extends BaseDeclarativeTool<
> {
static readonly Name = GREP_TOOL_NAME;
constructor(private readonly config: Config) {
constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super(
RipGrepTool.Name,
'SearchText',
@@ -476,6 +483,9 @@ export class RipGrepTool extends BaseDeclarativeTool<
required: ['pattern'],
type: 'object',
},
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
);
}
@@ -548,7 +558,16 @@ export class RipGrepTool extends BaseDeclarativeTool<
protected createInvocation(
params: RipGrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<RipGrepToolParams, ToolResult> {
return new GrepToolInvocation(this.config, params);
return new GrepToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+48 -13
View File
@@ -76,13 +76,9 @@ export abstract class BaseToolInvocation<
constructor(
readonly params: TParams,
protected readonly messageBus?: MessageBus,
) {
if (this.messageBus) {
console.debug(
`[DEBUG] Tool ${this.constructor.name} created with messageBus: YES`,
);
}
}
readonly _toolName?: string,
readonly _toolDisplayName?: string,
) {}
abstract getDescription(): string;
@@ -90,11 +86,43 @@ export abstract class BaseToolInvocation<
return [];
}
shouldConfirmExecute(
_abortSignal: AbortSignal,
async shouldConfirmExecute(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
// Default implementation for tools that don't override it.
return Promise.resolve(false);
if (this.messageBus) {
const decision = await this.getMessageBusDecision(abortSignal);
if (decision === 'ALLOW') {
return false;
}
if (decision === 'DENY') {
throw new Error(
`Tool execution for "${
this._toolDisplayName || this._toolName
}" denied by policy.`,
);
}
if (decision === 'ASK_USER') {
const confirmationDetails: ToolCallConfirmationDetails = {
type: 'info',
title: `Confirm: ${this._toolDisplayName || this._toolName}`,
prompt: this.getDescription(),
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
if (this.messageBus && this._toolName) {
this.messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: this._toolName,
});
}
}
},
};
return confirmationDetails;
}
}
return false;
}
protected getMessageBusDecision(
@@ -108,7 +136,7 @@ export abstract class BaseToolInvocation<
const correlationId = randomUUID();
const toolCall = {
name: this.constructor.name,
name: this._toolName || this.constructor.name,
args: this.params as Record<string, unknown>,
};
@@ -385,7 +413,12 @@ export abstract class BaseDeclarativeTool<
if (validationError) {
throw new Error(validationError);
}
return this.createInvocation(params, this.messageBus);
return this.createInvocation(
params,
this.messageBus,
this.name,
this.displayName,
);
}
override validateToolParams(params: TParams): string | null {
@@ -408,6 +441,8 @@ export abstract class BaseDeclarativeTool<
protected abstract createInvocation(
params: TParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<TParams, TResult>;
}
+1 -1
View File
@@ -471,7 +471,7 @@ describe('WebFetchTool', () => {
expect(publishSpy).toHaveBeenCalledWith({
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: {
name: 'WebFetchToolInvocation',
name: 'web_fetch',
args: { prompt: 'fetch https://example.com' },
},
correlationId: 'test-correlation-id',
+12 -2
View File
@@ -110,8 +110,10 @@ class WebFetchToolInvocation extends BaseToolInvocation<
private readonly config: Config,
params: WebFetchToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params, messageBus);
super(params, messageBus, _toolName, _toolDisplayName);
}
private async executeFallback(signal: AbortSignal): Promise<ToolResult> {
@@ -450,7 +452,15 @@ export class WebFetchTool extends BaseDeclarativeTool<
protected createInvocation(
params: WebFetchToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<WebFetchToolParams, ToolResult> {
return new WebFetchToolInvocation(this.config, params, messageBus);
return new WebFetchToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { WEB_SEARCH_TOOL_NAME } from './tool-names.js';
import type { GroundingMetadata } from '@google/genai';
import type { ToolInvocation, ToolResult } from './tools.js';
@@ -64,8 +65,11 @@ class WebSearchToolInvocation extends BaseToolInvocation<
constructor(
private readonly config: Config,
params: WebSearchToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
override getDescription(): string {
@@ -187,7 +191,10 @@ export class WebSearchTool extends BaseDeclarativeTool<
> {
static readonly Name = WEB_SEARCH_TOOL_NAME;
constructor(private readonly config: Config) {
constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super(
WebSearchTool.Name,
'GoogleSearch',
@@ -203,6 +210,9 @@ export class WebSearchTool extends BaseDeclarativeTool<
},
required: ['query'],
},
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
);
}
@@ -222,7 +232,16 @@ export class WebSearchTool extends BaseDeclarativeTool<
protected createInvocation(
params: WebSearchToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<WebSearchToolParams, WebSearchToolResult> {
return new WebSearchToolInvocation(this.config, params);
return new WebSearchToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}