feat: Implement message bus and policy engine (#11523)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Allen Hutchison
2025-10-21 11:45:33 -07:00
committed by GitHub
parent 0658b4aa31
commit bf80263bd6
19 changed files with 339 additions and 94 deletions
+60 -41
View File
@@ -14,16 +14,70 @@ import {
} from '@google/gemini-cli-core';
describe('createPolicyEngineConfig', () => {
it('should return ASK_USER for all tools by default', () => {
it('should return ASK_USER for write tools and ALLOW for read-only tools by default', () => {
const settings: Settings = {};
const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT);
expect(config.defaultDecision).toBe(PolicyDecision.ASK_USER);
// The order of the rules is not guaranteed, so we sort them by tool name.
config.rules?.sort((a, b) =>
(a.toolName ?? '').localeCompare(b.toolName ?? ''),
);
expect(config.rules).toEqual([
{ toolName: 'replace', decision: 'ask_user', priority: 10 },
{ toolName: 'save_memory', decision: 'ask_user', priority: 10 },
{ toolName: 'run_shell_command', decision: 'ask_user', priority: 10 },
{ toolName: 'write_file', decision: 'ask_user', priority: 10 },
{ toolName: WEB_FETCH_TOOL_NAME, decision: 'ask_user', priority: 10 },
{
toolName: 'glob',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'google_web_search',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'list_directory',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'read_file',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'read_many_files',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'replace',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{
toolName: 'run_shell_command',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{
toolName: 'save_memory',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{
toolName: 'search_file_content',
decision: PolicyDecision.ALLOW,
priority: 50,
},
{
toolName: 'web_fetch',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{
toolName: 'write_file',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
]);
});
@@ -159,18 +213,6 @@ describe('createPolicyEngineConfig', () => {
expect(excludedRule?.priority).toBe(195);
});
it('should allow read-only tools if autoAccept is true', () => {
const settings: Settings = {
tools: { autoAccept: true },
};
const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT);
const rule = config.rules?.find(
(r) => r.toolName === 'glob' && r.decision === PolicyDecision.ALLOW,
);
expect(rule).toBeDefined();
expect(rule?.priority).toBe(50);
});
it('should allow all tools in YOLO mode', () => {
const settings: Settings = {};
const config = createPolicyEngineConfig(settings, ApprovalMode.YOLO);
@@ -419,29 +461,6 @@ describe('createPolicyEngineConfig', () => {
// Exclude (195) should win over trust (90) when evaluated
});
it('should create all read-only tool rules when autoAccept is enabled', () => {
const settings: Settings = {
tools: { autoAccept: true },
};
const config = createPolicyEngineConfig(settings, ApprovalMode.DEFAULT);
// All read-only tools should have allow rules
const readOnlyTools = [
'glob',
'search_file_content',
'list_directory',
'read_file',
'read_many_files',
];
for (const tool of readOnlyTools) {
const rule = config.rules?.find(
(r) => r.toolName === tool && r.decision === PolicyDecision.ALLOW,
);
expect(rule).toBeDefined();
expect(rule?.priority).toBe(50);
}
});
it('should handle all approval modes correctly', () => {
const settings: Settings = {};
+31 -10
View File
@@ -22,8 +22,12 @@ import {
EDIT_TOOL_NAME,
MEMORY_TOOL_NAME,
WEB_SEARCH_TOOL_NAME,
type PolicyEngine,
type MessageBus,
MessageBusType,
type UpdatePolicy,
} from '@google/gemini-cli-core';
import type { Settings } from './settings.js';
import { type Settings } from './settings.js';
// READ_ONLY_TOOLS is a list of built-in tools that do not modify the user's
// files or system state.
@@ -69,6 +73,7 @@ export function createPolicyEngineConfig(
// 90: MCP servers with trust=true
// 100: Explicitly allowed individual tools
// 195: Explicitly excluded MCP servers
// 199: Tools that the user has selected as "Always Allow" in the interactive UI.
// 200: Explicitly excluded individual tools (highest priority)
// MCP servers that are explicitly allowed in settings.mcp.allowed
@@ -137,16 +142,14 @@ export function createPolicyEngineConfig(
}
}
// If auto-accept is enabled, allow all read-only tools.
// Allow all read-only tools.
// Priority: 50
if (settings.tools?.autoAccept) {
for (const tool of READ_ONLY_TOOLS) {
rules.push({
toolName: tool,
decision: PolicyDecision.ALLOW,
priority: 50,
});
}
for (const tool of READ_ONLY_TOOLS) {
rules.push({
toolName: tool,
decision: PolicyDecision.ALLOW,
priority: 50,
});
}
// Only add write tool rules if not in YOLO mode
@@ -179,3 +182,21 @@ export function createPolicyEngineConfig(
defaultDecision: PolicyDecision.ASK_USER,
};
}
export function createPolicyUpdater(
policyEngine: PolicyEngine,
messageBus: MessageBus,
) {
messageBus.subscribe(
MessageBusType.UPDATE_POLICY,
(message: UpdatePolicy) => {
const toolName = message.toolName;
policyEngine.addRule({
toolName,
decision: PolicyDecision.ALLOW,
priority: 199, // High priority, but lower than explicit DENY (200)
});
},
);
}
+8
View File
@@ -170,6 +170,10 @@ describe('gemini.tsx main function', () => {
getScreenReader: () => false,
getGeminiMdFileCount: () => 0,
getProjectRoot: () => '/',
getPolicyEngine: vi.fn(),
getMessageBus: () => ({
subscribe: vi.fn(),
}),
} as unknown as Config;
});
vi.mocked(loadSettings).mockReturnValue({
@@ -301,6 +305,10 @@ describe('gemini.tsx main function kitty protocol', () => {
getExperimentalZedIntegration: () => false,
getScreenReader: () => false,
getGeminiMdFileCount: () => 0,
getPolicyEngine: vi.fn(),
getMessageBus: () => ({
subscribe: vi.fn(),
}),
} as unknown as Config);
vi.mocked(loadSettings).mockReturnValue({
errors: [],
+5
View File
@@ -67,6 +67,7 @@ import {
relaunchOnExitCode,
} from './utils/relaunch.js';
import { loadSandboxConfig } from './config/sandboxConfig.js';
import { createPolicyUpdater } from './config/policy.js';
import { ExtensionEnablementManager } from './config/extensions/extensionEnablement.js';
export function validateDnsResolutionOrder(
@@ -370,6 +371,10 @@ export async function main() {
argv,
);
const policyEngine = config.getPolicyEngine();
const messageBus = config.getMessageBus();
createPolicyUpdater(policyEngine, messageBus);
// Cleanup sessions after config initialization
await cleanupExpiredSessions(config, settings.merged);
+1 -1
View File
@@ -491,7 +491,7 @@ export class Config {
this.fileExclusions = new FileExclusions(this);
this.eventEmitter = params.eventEmitter;
this.policyEngine = new PolicyEngine(params.policyEngineConfig);
this.messageBus = new MessageBus(this.policyEngine);
this.messageBus = new MessageBus(this.policyEngine, this.debugMode);
this.outputSettings = {
format: params.output?.format ?? OutputFormat.TEXT,
};
@@ -11,8 +11,12 @@ import { MessageBusType, type Message } from './types.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
export class MessageBus extends EventEmitter {
constructor(private readonly policyEngine: PolicyEngine) {
constructor(
private readonly policyEngine: PolicyEngine,
private readonly debug = false,
) {
super();
this.debug = debug;
}
private isValidMessage(message: Message): boolean {
@@ -35,6 +39,9 @@ export class MessageBus extends EventEmitter {
}
publish(message: Message): void {
if (this.debug) {
console.debug(`[MESSAGE_BUS] publish: ${safeJsonStringify(message)}`);
}
try {
if (!this.isValidMessage(message)) {
throw new Error(
+8 -1
View File
@@ -12,6 +12,7 @@ export enum MessageBusType {
TOOL_POLICY_REJECTION = 'tool-policy-rejection',
TOOL_EXECUTION_SUCCESS = 'tool-execution-success',
TOOL_EXECUTION_FAILURE = 'tool-execution-failure',
UPDATE_POLICY = 'update-policy',
}
export interface ToolConfirmationRequest {
@@ -31,6 +32,11 @@ export interface ToolConfirmationResponse {
requiresUserConfirmation?: boolean;
}
export interface UpdatePolicy {
type: MessageBusType.UPDATE_POLICY;
toolName: string;
}
export interface ToolPolicyRejection {
type: MessageBusType.TOOL_POLICY_REJECTION;
toolCall: FunctionCall;
@@ -53,4 +59,5 @@ export type Message =
| ToolConfirmationResponse
| ToolPolicyRejection
| ToolExecutionSuccess
| ToolExecutionFailure;
| ToolExecutionFailure
| UpdatePolicy;
+2
View File
@@ -11,6 +11,8 @@ export * from './output/json-formatter.js';
export * from './output/stream-json-formatter.js';
export * from './policy/types.js';
export * from './policy/policy-engine.js';
export * from './confirmation-bus/types.js';
export * from './confirmation-bus/message-bus.js';
// Export Core Logic
export * from './core/client.js';
+22 -4
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs';
import path from 'node:path';
import { glob, escape } from 'glob';
@@ -88,8 +89,11 @@ class GlobToolInvocation extends BaseToolInvocation<
constructor(
private config: Config,
params: GlobToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
getDescription(): string {
@@ -261,8 +265,10 @@ class GlobToolInvocation extends BaseToolInvocation<
*/
export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
static readonly Name = GLOB_TOOL_NAME;
constructor(private config: Config) {
constructor(
private config: Config,
messageBus?: MessageBus,
) {
super(
GlobTool.Name,
'FindFiles',
@@ -299,6 +305,9 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
required: ['pattern'],
type: 'object',
},
true,
false,
messageBus,
);
}
@@ -344,7 +353,16 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
protected createInvocation(
params: GlobToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<GlobToolParams, ToolResult> {
return new GlobToolInvocation(this.config, params);
return new GlobToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -4
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs';
import fsPromises from 'node:fs/promises';
import path from 'node:path';
@@ -61,8 +62,11 @@ class GrepToolInvocation extends BaseToolInvocation<
constructor(
private readonly config: Config,
params: GrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
this.fileExclusions = config.getFileExclusions();
}
@@ -565,8 +569,10 @@ class GrepToolInvocation extends BaseToolInvocation<
*/
export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
static readonly Name = GREP_TOOL_NAME;
constructor(private readonly config: Config) {
constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super(
GrepTool.Name,
'SearchText',
@@ -593,6 +599,9 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
required: ['pattern'],
type: 'object',
},
true,
false,
messageBus,
);
}
@@ -665,7 +674,16 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
protected createInvocation(
params: GrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<GrepToolParams, ToolResult> {
return new GrepToolInvocation(this.config, params);
return new GrepToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs/promises';
import path from 'node:path';
import type { ToolInvocation, ToolResult } from './tools.js';
@@ -71,8 +72,11 @@ class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
constructor(
private readonly config: Config,
params: LSToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
/**
@@ -255,7 +259,10 @@ class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
static readonly Name = LS_TOOL_NAME;
constructor(private config: Config) {
constructor(
private config: Config,
messageBus?: MessageBus,
) {
super(
LSTool.Name,
'ReadFolder',
@@ -296,6 +303,9 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
required: ['path'],
type: 'object',
},
true,
false,
messageBus,
);
}
@@ -323,7 +333,16 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
protected createInvocation(
params: LSToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<LSToolParams, ToolResult> {
return new LSToolInvocation(this.config, params);
return new LSToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import path from 'node:path';
import { makeRelative, shortenPath } from '../utils/paths.js';
import type { ToolInvocation, ToolLocation, ToolResult } from './tools.js';
@@ -48,8 +49,11 @@ class ReadFileToolInvocation extends BaseToolInvocation<
constructor(
private config: Config,
params: ReadFileToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
getDescription(): string {
@@ -138,7 +142,10 @@ export class ReadFileTool extends BaseDeclarativeTool<
> {
static readonly Name = READ_FILE_TOOL_NAME;
constructor(private config: Config) {
constructor(
private config: Config,
messageBus?: MessageBus,
) {
super(
ReadFileTool.Name,
'ReadFile',
@@ -165,6 +172,9 @@ export class ReadFileTool extends BaseDeclarativeTool<
required: ['absolute_path'],
type: 'object',
},
true,
false,
messageBus,
);
}
@@ -209,7 +219,16 @@ export class ReadFileTool extends BaseDeclarativeTool<
protected createInvocation(
params: ReadFileToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<ReadFileToolParams, ToolResult> {
return new ReadFileToolInvocation(this.config, params);
return new ReadFileToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { ToolInvocation, ToolResult } from './tools.js';
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { getErrorMessage } from '../utils/errors.js';
@@ -114,8 +115,11 @@ class ReadManyFilesToolInvocation extends BaseToolInvocation<
constructor(
private readonly config: Config,
params: ReadManyFilesParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
getDescription(): string {
@@ -477,7 +481,10 @@ export class ReadManyFilesTool extends BaseDeclarativeTool<
> {
static readonly Name = READ_MANY_FILES_TOOL_NAME;
constructor(private config: Config) {
constructor(
private config: Config,
messageBus?: MessageBus,
) {
const parameterSchema = {
type: 'object',
properties: {
@@ -559,12 +566,24 @@ This tool is useful when you need to understand or analyze a collection of files
Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. The tool inserts a '--- End of content ---' after the last file. Ensure paths are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`,
Kind.Read,
parameterSchema,
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
);
}
protected createInvocation(
params: ReadManyFilesParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<ReadManyFilesParams, ToolResult> {
return new ReadManyFilesToolInvocation(this.config, params);
return new ReadManyFilesToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import fs from 'node:fs';
import path from 'node:path';
import { EOL } from 'node:os';
@@ -110,8 +111,11 @@ class GrepToolInvocation extends BaseToolInvocation<
constructor(
private readonly config: Config,
params: RipGrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
/**
@@ -449,7 +453,10 @@ export class RipGrepTool extends BaseDeclarativeTool<
> {
static readonly Name = GREP_TOOL_NAME;
constructor(private readonly config: Config) {
constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super(
RipGrepTool.Name,
'SearchText',
@@ -476,6 +483,9 @@ export class RipGrepTool extends BaseDeclarativeTool<
required: ['pattern'],
type: 'object',
},
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
);
}
@@ -548,7 +558,16 @@ export class RipGrepTool extends BaseDeclarativeTool<
protected createInvocation(
params: RipGrepToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<RipGrepToolParams, ToolResult> {
return new GrepToolInvocation(this.config, params);
return new GrepToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+48 -13
View File
@@ -76,13 +76,9 @@ export abstract class BaseToolInvocation<
constructor(
readonly params: TParams,
protected readonly messageBus?: MessageBus,
) {
if (this.messageBus) {
console.debug(
`[DEBUG] Tool ${this.constructor.name} created with messageBus: YES`,
);
}
}
readonly _toolName?: string,
readonly _toolDisplayName?: string,
) {}
abstract getDescription(): string;
@@ -90,11 +86,43 @@ export abstract class BaseToolInvocation<
return [];
}
shouldConfirmExecute(
_abortSignal: AbortSignal,
async shouldConfirmExecute(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
// Default implementation for tools that don't override it.
return Promise.resolve(false);
if (this.messageBus) {
const decision = await this.getMessageBusDecision(abortSignal);
if (decision === 'ALLOW') {
return false;
}
if (decision === 'DENY') {
throw new Error(
`Tool execution for "${
this._toolDisplayName || this._toolName
}" denied by policy.`,
);
}
if (decision === 'ASK_USER') {
const confirmationDetails: ToolCallConfirmationDetails = {
type: 'info',
title: `Confirm: ${this._toolDisplayName || this._toolName}`,
prompt: this.getDescription(),
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
if (this.messageBus && this._toolName) {
this.messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: this._toolName,
});
}
}
},
};
return confirmationDetails;
}
}
return false;
}
protected getMessageBusDecision(
@@ -108,7 +136,7 @@ export abstract class BaseToolInvocation<
const correlationId = randomUUID();
const toolCall = {
name: this.constructor.name,
name: this._toolName || this.constructor.name,
args: this.params as Record<string, unknown>,
};
@@ -385,7 +413,12 @@ export abstract class BaseDeclarativeTool<
if (validationError) {
throw new Error(validationError);
}
return this.createInvocation(params, this.messageBus);
return this.createInvocation(
params,
this.messageBus,
this.name,
this.displayName,
);
}
override validateToolParams(params: TParams): string | null {
@@ -408,6 +441,8 @@ export abstract class BaseDeclarativeTool<
protected abstract createInvocation(
params: TParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<TParams, TResult>;
}
+1 -1
View File
@@ -471,7 +471,7 @@ describe('WebFetchTool', () => {
expect(publishSpy).toHaveBeenCalledWith({
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: {
name: 'WebFetchToolInvocation',
name: 'web_fetch',
args: { prompt: 'fetch https://example.com' },
},
correlationId: 'test-correlation-id',
+12 -2
View File
@@ -110,8 +110,10 @@ class WebFetchToolInvocation extends BaseToolInvocation<
private readonly config: Config,
params: WebFetchToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params, messageBus);
super(params, messageBus, _toolName, _toolDisplayName);
}
private async executeFallback(signal: AbortSignal): Promise<ToolResult> {
@@ -450,7 +452,15 @@ export class WebFetchTool extends BaseDeclarativeTool<
protected createInvocation(
params: WebFetchToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<WebFetchToolParams, ToolResult> {
return new WebFetchToolInvocation(this.config, params, messageBus);
return new WebFetchToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}
+22 -3
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { WEB_SEARCH_TOOL_NAME } from './tool-names.js';
import type { GroundingMetadata } from '@google/genai';
import type { ToolInvocation, ToolResult } from './tools.js';
@@ -64,8 +65,11 @@ class WebSearchToolInvocation extends BaseToolInvocation<
constructor(
private readonly config: Config,
params: WebSearchToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
) {
super(params);
super(params, messageBus, _toolName, _toolDisplayName);
}
override getDescription(): string {
@@ -187,7 +191,10 @@ export class WebSearchTool extends BaseDeclarativeTool<
> {
static readonly Name = WEB_SEARCH_TOOL_NAME;
constructor(private readonly config: Config) {
constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super(
WebSearchTool.Name,
'GoogleSearch',
@@ -203,6 +210,9 @@ export class WebSearchTool extends BaseDeclarativeTool<
},
required: ['query'],
},
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
);
}
@@ -222,7 +232,16 @@ export class WebSearchTool extends BaseDeclarativeTool<
protected createInvocation(
params: WebSearchToolParams,
messageBus?: MessageBus,
_toolName?: string,
_toolDisplayName?: string,
): ToolInvocation<WebSearchToolParams, WebSearchToolResult> {
return new WebSearchToolInvocation(this.config, params);
return new WebSearchToolInvocation(
this.config,
params,
messageBus,
_toolName,
_toolDisplayName,
);
}
}