diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index 22b4f58e60..85ee967143 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -13,10 +13,10 @@ import { getMCPServerStatus, getMCPDiscoveryState, DiscoveredMCPTool, + type MessageBus, } from '@google/gemini-cli-core'; import type { CallableTool } from '@google/genai'; -import { Type } from '@google/genai'; import { MessageType } from '../types.js'; vi.mock('@google/gemini-cli-core', async (importOriginal) => { @@ -37,6 +37,12 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { }; }); +const mockMessageBus = { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), +} as unknown as MessageBus; + // Helper function to create a mock DiscoveredMCPTool const createMockMCPTool = ( name: string, @@ -50,8 +56,14 @@ const createMockMCPTool = ( } as unknown as CallableTool, serverName, name, - description || `Description for ${name}`, - { type: Type.OBJECT, properties: {} }, + description || 'Mock tool description', + { type: 'object', properties: {} }, + mockMessageBus, + undefined, // trust + undefined, // nameOverride + undefined, // cliConfig + undefined, // extensionName + undefined, // extensionId ); describe('mcpCommand', () => { diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts index ae093ee56c..5e86c9b27a 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts @@ -54,6 +54,12 @@ describe('handleAtCommand', () => { const getToolRegistry = vi.fn(); + const mockMessageBus = { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + } as unknown as core.MessageBus; + mockConfig = { getToolRegistry, getTargetDir: () => testRootDir, @@ -94,11 +100,12 @@ describe('handleAtCommand', () => { getMcpClientManager: () => ({ getClient: () => undefined, }), + getMessageBus: () => mockMessageBus, } as unknown as Config; - const registry = new ToolRegistry(mockConfig); - registry.registerTool(new ReadManyFilesTool(mockConfig)); - registry.registerTool(new GlobTool(mockConfig)); + const registry = new ToolRegistry(mockConfig, mockMessageBus); + registry.registerTool(new ReadManyFilesTool(mockConfig, mockMessageBus)); + registry.registerTool(new GlobTool(mockConfig, mockMessageBus)); getToolRegistry.mockReturnValue(registry); }); diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts index 40fb42ea44..10196a3545 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts @@ -164,7 +164,10 @@ export async function handleAtCommand({ }; const toolRegistry = config.getToolRegistry(); - const readManyFilesTool = new ReadManyFilesTool(config); + const readManyFilesTool = new ReadManyFilesTool( + config, + config.getMessageBus(), + ); const globTool = toolRegistry.getTool('glob'); if (!readManyFilesTool) { diff --git a/packages/cli/src/zed-integration/zedIntegration.test.ts b/packages/cli/src/zed-integration/zedIntegration.test.ts index 3cc0933170..f0ceec4e22 100644 --- a/packages/cli/src/zed-integration/zedIntegration.test.ts +++ b/packages/cli/src/zed-integration/zedIntegration.test.ts @@ -24,6 +24,7 @@ import { ReadManyFilesTool, type GeminiChat, type Config, + type MessageBus, } from '@google/gemini-cli-core'; import { SettingScope, type LoadedSettings } from '../config/settings.js'; import { loadCliConfig, type CliArgs } from '../config/config.js'; @@ -97,6 +98,11 @@ describe('GeminiAgent', () => { getGeminiClient: vi.fn().mockReturnValue({ startChat: vi.fn().mockResolvedValue({}), }), + getMessageBus: vi.fn().mockReturnValue({ + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + }), } as unknown as Mocked>>; mockSettings = { merged: { @@ -261,6 +267,7 @@ describe('Session', () => { let session: Session; let mockToolRegistry: { getTool: Mock }; let mockTool: { kind: string; build: Mock }; + let mockMessageBus: Mocked; beforeEach(() => { mockChat = { @@ -279,6 +286,11 @@ describe('Session', () => { mockToolRegistry = { getTool: vi.fn().mockReturnValue(mockTool), }; + mockMessageBus = { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + } as unknown as Mocked; mockConfig = { getModel: vi.fn().mockReturnValue('gemini-pro'), getPreviewFeatures: vi.fn().mockReturnValue({}), @@ -290,6 +302,7 @@ describe('Session', () => { getTargetDir: vi.fn().mockReturnValue('/tmp'), getEnableRecursiveFileSearch: vi.fn().mockReturnValue(false), getDebugMode: vi.fn().mockReturnValue(false), + getMessageBus: vi.fn().mockReturnValue(mockMessageBus), } as unknown as Mocked; mockConnection = { sessionUpdate: vi.fn(), diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index a957f32a14..d4381efc0e 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -609,7 +609,10 @@ export class Session { const ignoredPaths: string[] = []; const toolRegistry = this.config.getToolRegistry(); - const readManyFilesTool = new ReadManyFilesTool(this.config); + const readManyFilesTool = new ReadManyFilesTool( + this.config, + this.config.getMessageBus(), + ); const globTool = toolRegistry.getTool('glob'); if (!readManyFilesTool) { diff --git a/packages/core/src/agents/delegate-to-agent-tool.test.ts b/packages/core/src/agents/delegate-to-agent-tool.test.ts index 9afaee7b87..5c8601f217 100644 --- a/packages/core/src/agents/delegate-to-agent-tool.test.ts +++ b/packages/core/src/agents/delegate-to-agent-tool.test.ts @@ -13,6 +13,7 @@ import { LocalSubagentInvocation } from './local-invocation.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { MessageBusType } from '../confirmation-bus/types.js'; import { DELEGATE_TO_AGENT_TOOL_NAME } from '../tools/tool-names.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; vi.mock('./local-invocation.js', () => ({ LocalSubagentInvocation: vi.fn().mockImplementation(() => ({ @@ -58,11 +59,7 @@ describe('DelegateToAgentTool', () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any (registry as any).agents.set(mockAgentDef.name, mockAgentDef); - messageBus = { - publish: vi.fn(), - subscribe: vi.fn(), - unsubscribe: vi.fn(), - } as unknown as MessageBus; + messageBus = createMockMessageBus(); tool = new DelegateToAgentTool(registry, config, messageBus); }); @@ -155,7 +152,7 @@ describe('DelegateToAgentTool', () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any (registry as any).agents.set(invalidAgentDef.name, invalidAgentDef); - expect(() => new DelegateToAgentTool(registry, config)).toThrow( + expect(() => new DelegateToAgentTool(registry, config, messageBus)).toThrow( "Agent 'invalid_agent' cannot have an input parameter named 'agent_name' as it is a reserved parameter for delegation.", ); }); diff --git a/packages/core/src/agents/delegate-to-agent-tool.ts b/packages/core/src/agents/delegate-to-agent-tool.ts index 435f697c72..6ac716b7f4 100644 --- a/packages/core/src/agents/delegate-to-agent-tool.ts +++ b/packages/core/src/agents/delegate-to-agent-tool.ts @@ -30,7 +30,7 @@ export class DelegateToAgentTool extends BaseDeclarativeTool< constructor( private readonly registry: AgentRegistry, private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { const definitions = registry.getAllDefinitions(); @@ -119,15 +119,15 @@ export class DelegateToAgentTool extends BaseDeclarativeTool< registry.getToolDescription(), Kind.Think, zodToJsonSchema(schema), + messageBus, /* isOutputMarkdown */ true, /* canUpdateOutput */ true, - messageBus, ); } protected createInvocation( params: DelegateParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { @@ -135,7 +135,7 @@ export class DelegateToAgentTool extends BaseDeclarativeTool< params, this.registry, this.config, - messageBus ?? this.messageBus, + messageBus, _toolName, _toolDisplayName, ); @@ -150,7 +150,7 @@ class DelegateInvocation extends BaseToolInvocation< params: DelegateParams, private readonly registry: AgentRegistry, private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { diff --git a/packages/core/src/agents/introspection-agent.test.ts b/packages/core/src/agents/introspection-agent.test.ts index 3e28659390..5feac834f5 100644 --- a/packages/core/src/agents/introspection-agent.test.ts +++ b/packages/core/src/agents/introspection-agent.test.ts @@ -6,7 +6,7 @@ import { describe, it, expect } from 'vitest'; import { IntrospectionAgent } from './introspection-agent.js'; -import { GetInternalDocsTool } from '../tools/get-internal-docs.js'; +import { GET_INTERNAL_DOCS_TOOL_NAME } from '../tools/tool-names.js'; import { GEMINI_MODEL_ALIAS_FLASH } from '../config/models.js'; import type { LocalAgentDefinition } from './types.js'; @@ -32,9 +32,7 @@ describe('IntrospectionAgent', () => { expect(localAgent.modelConfig?.model).toBe(GEMINI_MODEL_ALIAS_FLASH); const tools = localAgent.toolConfig?.tools || []; - const hasInternalDocsTool = tools.some( - (t) => t instanceof GetInternalDocsTool, - ); + const hasInternalDocsTool = tools.includes(GET_INTERNAL_DOCS_TOOL_NAME); expect(hasInternalDocsTool).toBe(true); }); diff --git a/packages/core/src/agents/introspection-agent.ts b/packages/core/src/agents/introspection-agent.ts index 413caa28a6..8801af6d50 100644 --- a/packages/core/src/agents/introspection-agent.ts +++ b/packages/core/src/agents/introspection-agent.ts @@ -5,7 +5,7 @@ */ import type { AgentDefinition } from './types.js'; -import { GetInternalDocsTool } from '../tools/get-internal-docs.js'; +import { GET_INTERNAL_DOCS_TOOL_NAME } from '../tools/tool-names.js'; import { GEMINI_MODEL_ALIAS_FLASH } from '../config/models.js'; import { z } from 'zod'; @@ -60,7 +60,7 @@ export const IntrospectionAgent: AgentDefinition< }, toolConfig: { - tools: [new GetInternalDocsTool()], + tools: [GET_INTERNAL_DOCS_TOOL_NAME], }, promptConfig: { diff --git a/packages/core/src/agents/local-executor.test.ts b/packages/core/src/agents/local-executor.test.ts index 1a25d8bd7a..98d017c864 100644 --- a/packages/core/src/agents/local-executor.test.ts +++ b/packages/core/src/agents/local-executor.test.ts @@ -269,8 +269,13 @@ describe('LocalAgentExecutor', () => { vi.useFakeTimers(); mockConfig = makeFakeConfig(); - parentToolRegistry = new ToolRegistry(mockConfig); - parentToolRegistry.registerTool(new LSTool(mockConfig)); + parentToolRegistry = new ToolRegistry( + mockConfig, + mockConfig.getMessageBus(), + ); + parentToolRegistry.registerTool( + new LSTool(mockConfig, mockConfig.getMessageBus()), + ); parentToolRegistry.registerTool( new MockTool({ name: READ_FILE_TOOL_NAME }), ); diff --git a/packages/core/src/agents/local-executor.ts b/packages/core/src/agents/local-executor.ts index 3a713c0167..994c616594 100644 --- a/packages/core/src/agents/local-executor.ts +++ b/packages/core/src/agents/local-executor.ts @@ -99,7 +99,10 @@ export class LocalAgentExecutor { onActivity?: ActivityCallback, ): Promise> { // Create an isolated tool registry for this agent instance. - const agentToolRegistry = new ToolRegistry(runtimeContext); + const agentToolRegistry = new ToolRegistry( + runtimeContext, + runtimeContext.getMessageBus(), + ); const parentToolRegistry = runtimeContext.getToolRegistry(); if (definition.toolConfig) { diff --git a/packages/core/src/agents/local-invocation.test.ts b/packages/core/src/agents/local-invocation.test.ts index 3aa5a39628..91614cea04 100644 --- a/packages/core/src/agents/local-invocation.test.ts +++ b/packages/core/src/agents/local-invocation.test.ts @@ -15,6 +15,7 @@ import { ToolErrorType } from '../tools/tool-error.js'; import type { Config } from '../config/config.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { type z } from 'zod'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; vi.mock('./local-executor.js'); @@ -39,10 +40,12 @@ const testDefinition: LocalAgentDefinition = { describe('LocalSubagentInvocation', () => { let mockExecutorInstance: Mocked>; + let mockMessageBus: MessageBus; beforeEach(() => { vi.clearAllMocks(); mockConfig = makeFakeConfig(); + mockMessageBus = createMockMessageBus(); mockExecutorInstance = { run: vi.fn(), @@ -55,7 +58,6 @@ describe('LocalSubagentInvocation', () => { }); it('should pass the messageBus to the parent constructor', () => { - const mockMessageBus = {} as MessageBus; const params = { task: 'Analyze data' }; const invocation = new LocalSubagentInvocation( testDefinition, @@ -76,6 +78,7 @@ describe('LocalSubagentInvocation', () => { testDefinition, mockConfig, params, + mockMessageBus, ); const description = invocation.getDescription(); expect(description).toBe( @@ -90,6 +93,7 @@ describe('LocalSubagentInvocation', () => { testDefinition, mockConfig, params, + mockMessageBus, ); const description = invocation.getDescription(); // Default INPUT_PREVIEW_MAX_LENGTH is 50 @@ -112,6 +116,7 @@ describe('LocalSubagentInvocation', () => { longNameDef, mockConfig, params, + mockMessageBus, ); const description = invocation.getDescription(); // Default DESCRIPTION_MAX_LENGTH is 200 @@ -137,6 +142,7 @@ describe('LocalSubagentInvocation', () => { testDefinition, mockConfig, params, + mockMessageBus, ); }); diff --git a/packages/core/src/agents/local-invocation.ts b/packages/core/src/agents/local-invocation.ts index 1ca315e1ca..a75fa8a11a 100644 --- a/packages/core/src/agents/local-invocation.ts +++ b/packages/core/src/agents/local-invocation.ts @@ -37,13 +37,13 @@ export class LocalSubagentInvocation extends BaseToolInvocation< * @param definition The definition object that configures the agent. * @param config The global runtime configuration. * @param params The validated input parameters for the agent. - * @param messageBus Optional message bus for policy enforcement. + * @param messageBus Message bus for policy enforcement. */ constructor( private readonly definition: LocalAgentDefinition, private readonly config: Config, params: AgentInputs, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { diff --git a/packages/core/src/agents/remote-invocation.test.ts b/packages/core/src/agents/remote-invocation.test.ts index bbe6d15f31..610961e440 100644 --- a/packages/core/src/agents/remote-invocation.test.ts +++ b/packages/core/src/agents/remote-invocation.test.ts @@ -8,6 +8,7 @@ import { describe, it, expect } from 'vitest'; import type { ToolCallConfirmationDetails } from '../tools/tools.js'; import { RemoteAgentInvocation } from './remote-invocation.js'; import type { RemoteAgentDefinition } from './types.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; class TestableRemoteAgentInvocation extends RemoteAgentInvocation { override async getConfirmationDetails( @@ -29,8 +30,14 @@ describe('RemoteAgentInvocation', () => { }, }; + const mockMessageBus = createMockMessageBus(); + it('should be instantiated with correct params', () => { - const invocation = new RemoteAgentInvocation(mockDefinition, {}); + const invocation = new RemoteAgentInvocation( + mockDefinition, + {}, + mockMessageBus, + ); expect(invocation).toBeDefined(); expect(invocation.getDescription()).toBe( 'Calling remote agent Test Remote Agent', @@ -38,7 +45,11 @@ describe('RemoteAgentInvocation', () => { }); it('should return false for confirmation details (not yet implemented)', async () => { - const invocation = new TestableRemoteAgentInvocation(mockDefinition, {}); + const invocation = new TestableRemoteAgentInvocation( + mockDefinition, + {}, + mockMessageBus, + ); const details = await invocation.getConfirmationDetails( new AbortController().signal, ); diff --git a/packages/core/src/agents/remote-invocation.ts b/packages/core/src/agents/remote-invocation.ts index c74af79e95..28ee8de6bb 100644 --- a/packages/core/src/agents/remote-invocation.ts +++ b/packages/core/src/agents/remote-invocation.ts @@ -25,7 +25,7 @@ export class RemoteAgentInvocation extends BaseToolInvocation< constructor( private readonly definition: RemoteAgentDefinition, params: AgentInputs, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { diff --git a/packages/core/src/agents/subagent-tool-wrapper.test.ts b/packages/core/src/agents/subagent-tool-wrapper.test.ts index 1b8f6269ec..29a241f32e 100644 --- a/packages/core/src/agents/subagent-tool-wrapper.test.ts +++ b/packages/core/src/agents/subagent-tool-wrapper.test.ts @@ -13,6 +13,7 @@ import type { LocalAgentDefinition, AgentInputs } from './types.js'; import type { Config } from '../config/config.js'; import { Kind } from '../tools/tools.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; // Mock dependencies to isolate the SubagentToolWrapper class vi.mock('./local-invocation.js'); @@ -25,6 +26,7 @@ const mockConvertInputConfigToJsonSchema = vi.mocked( // Define reusable test data let mockConfig: Config; +let mockMessageBus: MessageBus; const mockDefinition: LocalAgentDefinition = { kind: 'local', @@ -59,6 +61,7 @@ describe('SubagentToolWrapper', () => { beforeEach(() => { vi.clearAllMocks(); mockConfig = makeFakeConfig(); + mockMessageBus = createMockMessageBus(); // Provide a mock implementation for the schema conversion utility // eslint-disable-next-line @typescript-eslint/no-explicit-any mockConvertInputConfigToJsonSchema.mockReturnValue(mockSchema as any); @@ -66,7 +69,7 @@ describe('SubagentToolWrapper', () => { describe('constructor', () => { it('should call convertInputConfigToJsonSchema with the correct agent inputConfig', () => { - new SubagentToolWrapper(mockDefinition, mockConfig); + new SubagentToolWrapper(mockDefinition, mockConfig, mockMessageBus); expect(convertInputConfigToJsonSchema).toHaveBeenCalledExactlyOnceWith( mockDefinition.inputConfig, @@ -74,7 +77,11 @@ describe('SubagentToolWrapper', () => { }); it('should correctly configure the tool properties from the agent definition', () => { - const wrapper = new SubagentToolWrapper(mockDefinition, mockConfig); + const wrapper = new SubagentToolWrapper( + mockDefinition, + mockConfig, + mockMessageBus, + ); expect(wrapper.name).toBe(mockDefinition.name); expect(wrapper.displayName).toBe(mockDefinition.displayName); @@ -92,12 +99,17 @@ describe('SubagentToolWrapper', () => { const wrapper = new SubagentToolWrapper( definitionWithoutDisplayName, mockConfig, + mockMessageBus, ); expect(wrapper.displayName).toBe(definitionWithoutDisplayName.name); }); it('should generate a valid tool schema using the definition and converted schema', () => { - const wrapper = new SubagentToolWrapper(mockDefinition, mockConfig); + const wrapper = new SubagentToolWrapper( + mockDefinition, + mockConfig, + mockMessageBus, + ); const schema = wrapper.schema; expect(schema.name).toBe(mockDefinition.name); @@ -108,7 +120,11 @@ describe('SubagentToolWrapper', () => { describe('createInvocation', () => { it('should create a LocalSubagentInvocation with the correct parameters', () => { - const wrapper = new SubagentToolWrapper(mockDefinition, mockConfig); + const wrapper = new SubagentToolWrapper( + mockDefinition, + mockConfig, + mockMessageBus, + ); const params: AgentInputs = { goal: 'Test the invocation', priority: 1 }; // The public `build` method calls the protected `createInvocation` after validation @@ -119,18 +135,22 @@ describe('SubagentToolWrapper', () => { mockDefinition, mockConfig, params, - undefined, + mockMessageBus, mockDefinition.name, mockDefinition.displayName, ); }); it('should pass the messageBus to the LocalSubagentInvocation constructor', () => { - const mockMessageBus = {} as MessageBus; + const specificMessageBus = { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + } as unknown as MessageBus; const wrapper = new SubagentToolWrapper( mockDefinition, mockConfig, - mockMessageBus, + specificMessageBus, ); const params: AgentInputs = { goal: 'Test the invocation', priority: 1 }; @@ -140,14 +160,18 @@ describe('SubagentToolWrapper', () => { mockDefinition, mockConfig, params, - mockMessageBus, + specificMessageBus, mockDefinition.name, mockDefinition.displayName, ); }); it('should throw a validation error for invalid parameters before creating an invocation', () => { - const wrapper = new SubagentToolWrapper(mockDefinition, mockConfig); + const wrapper = new SubagentToolWrapper( + mockDefinition, + mockConfig, + mockMessageBus, + ); // Missing the required 'goal' parameter const invalidParams = { priority: 1 }; diff --git a/packages/core/src/agents/subagent-tool-wrapper.ts b/packages/core/src/agents/subagent-tool-wrapper.ts index 5acc1799a9..ccb0627b0b 100644 --- a/packages/core/src/agents/subagent-tool-wrapper.ts +++ b/packages/core/src/agents/subagent-tool-wrapper.ts @@ -38,7 +38,7 @@ export class SubagentToolWrapper extends BaseDeclarativeTool< constructor( private readonly definition: AgentDefinition, private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { const parameterSchema = convertInputConfigToJsonSchema( definition.inputConfig, @@ -50,9 +50,9 @@ export class SubagentToolWrapper extends BaseDeclarativeTool< definition.description, Kind.Think, parameterSchema, + messageBus, /* isOutputMarkdown */ true, /* canUpdateOutput */ true, - messageBus, ); } @@ -67,12 +67,12 @@ export class SubagentToolWrapper extends BaseDeclarativeTool< */ protected createInvocation( params: AgentInputs, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { const definition = this.definition; - const effectiveMessageBus = messageBus ?? this.messageBus; + const effectiveMessageBus = messageBus; if (definition.kind === 'remote') { return new RemoteAgentInvocation( diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 28714db617..1f389cfaa0 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -61,7 +61,6 @@ vi.mock('../tools/tool-registry', () => { ToolRegistryMock.prototype.sortTools = vi.fn(); ToolRegistryMock.prototype.getAllTools = vi.fn(() => []); // Mock methods if needed ToolRegistryMock.prototype.getTool = vi.fn(); - ToolRegistryMock.prototype.setMessageBus = vi.fn(); ToolRegistryMock.prototype.getFunctionDeclarations = vi.fn(() => []); return { ToolRegistry: ToolRegistryMock }; }); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 348e86b61e..2871e51f80 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -1633,10 +1633,7 @@ export class Config { } async createToolRegistry(): Promise { - const registry = new ToolRegistry(this); - - // Set message bus on tool registry before discovery so MCP tools can access it - registry.setMessageBus(this.messageBus); + const registry = new ToolRegistry(this, this.messageBus); // helper to create & register core tools that are enabled // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -1659,9 +1656,7 @@ export class Config { } if (isEnabled) { - // Pass message bus to tools when feature flag is enabled - // This first implementation is only focused on the general case of - // the tool registry. + // Pass message bus to tools (required for policy engine integration) const toolArgs = [...args, this.getMessageBus()]; registry.registerTool(new ToolClass(...toolArgs)); diff --git a/packages/core/src/core/coreToolHookTriggers.test.ts b/packages/core/src/core/coreToolHookTriggers.test.ts index 68a4357523..403f11339a 100644 --- a/packages/core/src/core/coreToolHookTriggers.test.ts +++ b/packages/core/src/core/coreToolHookTriggers.test.ts @@ -19,8 +19,8 @@ import { } from '../confirmation-bus/types.js'; class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> { - constructor(params: { key?: string }) { - super(params); + constructor(params: { key?: string }, messageBus: MessageBus) { + super(params, messageBus); } getDescription() { return 'mock'; @@ -47,12 +47,14 @@ describe('executeToolWithHooks', () => { unsubscribe: vi.fn(), } as unknown as MessageBus; mockTool = { - build: vi.fn().mockImplementation((params) => new MockInvocation(params)), + build: vi + .fn() + .mockImplementation((params) => new MockInvocation(params, messageBus)), } as unknown as AnyDeclarativeTool; }); it('should prioritize continue: false over decision: block in BeforeTool', async () => { - const invocation = new MockInvocation({}); + const invocation = new MockInvocation({}, messageBus); const abortSignal = new AbortController().signal; vi.mocked(messageBus.request).mockResolvedValue({ @@ -81,7 +83,7 @@ describe('executeToolWithHooks', () => { }); it('should block execution in BeforeTool if decision is block', async () => { - const invocation = new MockInvocation({}); + const invocation = new MockInvocation({}, messageBus); const abortSignal = new AbortController().signal; vi.mocked(messageBus.request).mockResolvedValue({ @@ -108,7 +110,7 @@ describe('executeToolWithHooks', () => { }); it('should handle continue: false in AfterTool', async () => { - const invocation = new MockInvocation({}); + const invocation = new MockInvocation({}, messageBus); const abortSignal = new AbortController().signal; const spy = vi.spyOn(invocation, 'execute'); @@ -146,7 +148,7 @@ describe('executeToolWithHooks', () => { }); it('should block result in AfterTool if decision is deny', async () => { - const invocation = new MockInvocation({}); + const invocation = new MockInvocation({}, messageBus); const abortSignal = new AbortController().signal; // BeforeTool allow @@ -183,7 +185,7 @@ describe('executeToolWithHooks', () => { it('should apply modified tool input from BeforeTool hook', async () => { const params = { key: 'original' }; - const invocation = new MockInvocation(params); + const invocation = new MockInvocation(params, messageBus); const toolName = 'test-tool'; const abortSignal = new AbortController().signal; @@ -235,7 +237,7 @@ describe('executeToolWithHooks', () => { it('should not modify input if hook does not provide tool_input', async () => { const params = { key: 'original' }; - const invocation = new MockInvocation(params); + const invocation = new MockInvocation(params, messageBus); const toolName = 'test-tool'; const abortSignal = new AbortController().signal; diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 3a8fad3aa3..ba4e22506b 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -49,7 +49,10 @@ vi.mock('fs/promises', () => ({ class TestApprovalTool extends BaseDeclarativeTool<{ id: string }, ToolResult> { static readonly Name = 'testApprovalTool'; - constructor(private config: Config) { + constructor( + private config: Config, + messageBus: MessageBus, + ) { super( TestApprovalTool.Name, 'TestApprovalTool', @@ -60,20 +63,17 @@ class TestApprovalTool extends BaseDeclarativeTool<{ id: string }, ToolResult> { required: ['id'], type: 'object', }, + messageBus, ); } protected createInvocation( params: { id: string }, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation<{ id: string }, ToolResult> { - return new TestApprovalInvocation( - this.config, - params, - messageBus ?? this.messageBus, - ); + return new TestApprovalInvocation(this.config, params, messageBus); } } @@ -84,7 +84,7 @@ class TestApprovalInvocation extends BaseToolInvocation< constructor( private config: Config, params: { id: string }, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super(params, messageBus); } @@ -133,7 +133,7 @@ class AbortDuringConfirmationInvocation extends BaseToolInvocation< private readonly abortController: AbortController, private readonly abortError: Error, params: Record, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super(params, messageBus); } @@ -161,6 +161,7 @@ class AbortDuringConfirmationTool extends BaseDeclarativeTool< constructor( private readonly abortController: AbortController, private readonly abortError: Error, + messageBus: MessageBus, ) { super( 'abortDuringConfirmationTool', @@ -171,12 +172,13 @@ class AbortDuringConfirmationTool extends BaseDeclarativeTool< type: 'object', properties: {}, }, + messageBus, ); } protected createInvocation( params: Record, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation, ToolResult> { @@ -184,7 +186,7 @@ class AbortDuringConfirmationTool extends BaseDeclarativeTool< this.abortController, this.abortError, params, - messageBus ?? this.messageBus, + messageBus, ); } } @@ -546,6 +548,7 @@ describe('CoreToolScheduler', () => { const declarativeTool = new AbortDuringConfirmationTool( abortController, abortError, + createMockMessageBus(), ); const mockToolRegistry = { @@ -741,7 +744,7 @@ class MockEditToolInvocation extends BaseToolInvocation< Record, ToolResult > { - constructor(params: Record, messageBus?: MessageBus) { + constructor(params: Record, messageBus: MessageBus) { super(params, messageBus); } @@ -777,23 +780,30 @@ class MockEditTool extends BaseDeclarativeTool< Record, ToolResult > { - constructor() { - super('mockEditTool', 'mockEditTool', 'A mock edit tool', Kind.Edit, {}); + constructor(messageBus: MessageBus) { + super( + 'mockEditTool', + 'mockEditTool', + 'A mock edit tool', + Kind.Edit, + {}, + messageBus, + ); } protected createInvocation( params: Record, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation, ToolResult> { - return new MockEditToolInvocation(params, messageBus ?? this.messageBus); + return new MockEditToolInvocation(params, messageBus); } } describe('CoreToolScheduler edit cancellation', () => { it('should preserve diff when an edit is cancelled', async () => { - const mockEditTool = new MockEditTool(); + const mockEditTool = new MockEditTool(createMockMessageBus()); const mockToolRegistry = { getTool: () => mockEditTool, getFunctionDeclarations: () => [], @@ -1362,7 +1372,7 @@ describe('CoreToolScheduler request queueing', () => { .fn() .mockReturnValue(new HookSystem(mockConfig)); - const testTool = new TestApprovalTool(mockConfig); + const testTool = new TestApprovalTool(mockConfig, mockMessageBus); const toolRegistry = { getTool: () => testTool, getFunctionDeclarations: () => [], diff --git a/packages/core/src/hooks/hookEventHandler.test.ts b/packages/core/src/hooks/hookEventHandler.test.ts index e30a57ba87..c6032298f6 100644 --- a/packages/core/src/hooks/hookEventHandler.test.ts +++ b/packages/core/src/hooks/hookEventHandler.test.ts @@ -18,6 +18,7 @@ import { SessionStartSource, type HookExecutionResult, } from './types.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; // Mock debugLogger const mockDebugLogger = vi.hoisted(() => ({ @@ -92,6 +93,7 @@ describe('HookEventHandler', () => { mockHookPlanner, mockHookRunner, mockHookAggregator, + createMockMessageBus(), ); }); diff --git a/packages/core/src/hooks/hookEventHandler.ts b/packages/core/src/hooks/hookEventHandler.ts index e36bd3719a..92268b7f51 100644 --- a/packages/core/src/hooks/hookEventHandler.ts +++ b/packages/core/src/hooks/hookEventHandler.ts @@ -280,7 +280,7 @@ export class HookEventHandler { private readonly hookPlanner: HookPlanner; private readonly hookRunner: HookRunner; private readonly hookAggregator: HookAggregator; - private readonly messageBus?: MessageBus; + private readonly messageBus: MessageBus; constructor( config: Config, @@ -288,7 +288,7 @@ export class HookEventHandler { hookPlanner: HookPlanner, hookRunner: HookRunner, hookAggregator: HookAggregator, - messageBus?: MessageBus, + messageBus: MessageBus, ) { this.config = config; this.hookPlanner = hookPlanner; diff --git a/packages/core/src/telemetry/loggers.test.ts b/packages/core/src/telemetry/loggers.test.ts index 614694097f..3dabc4a89d 100644 --- a/packages/core/src/telemetry/loggers.test.ts +++ b/packages/core/src/telemetry/loggers.test.ts @@ -17,6 +17,7 @@ import { ToolConfirmationOutcome, ToolErrorType, ToolRegistry, + type MessageBus, } from '../index.js'; import { OutputFormat } from '../output/types.js'; import { logs } from '@opentelemetry/api-logs'; @@ -94,6 +95,7 @@ import { import * as metrics from './metrics.js'; import { FileOperation } from './metrics.js'; import * as sdk from './sdk.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; import { vi, describe, beforeEach, it, expect, afterEach } from 'vitest'; import { type GeminiCLIExtension } from '../config/config.js'; import { @@ -999,7 +1001,8 @@ describe('loggers', () => { }, }), getQuestion: () => 'test-question', - getToolRegistry: () => new ToolRegistry(cfg1), + getToolRegistry: () => + new ToolRegistry(cfg1, {} as unknown as MessageBus), getUserMemory: () => 'user-memory', } as unknown as Config; @@ -1031,7 +1034,7 @@ describe('loggers', () => { }); it('should log a tool call with all fields', () => { - const tool = new EditTool(mockConfig); + const tool = new EditTool(mockConfig, createMockMessageBus()); const call: CompletedToolCall = { status: 'success', request: { @@ -1247,7 +1250,7 @@ describe('loggers', () => { contentLength: 13, }, outcome: ToolConfirmationOutcome.ModifyWithEditor, - tool: new EditTool(mockConfig), + tool: new EditTool(mockConfig, createMockMessageBus()), invocation: {} as AnyToolInvocation, durationMs: 100, }; @@ -1326,7 +1329,7 @@ describe('loggers', () => { errorType: undefined, contentLength: 13, }, - tool: new EditTool(mockConfig), + tool: new EditTool(mockConfig, createMockMessageBus()), invocation: {} as AnyToolInvocation, durationMs: 100, }; @@ -1478,6 +1481,7 @@ describe('loggers', () => { }, required: ['arg1', 'arg2'], }, + createMockMessageBus(), false, undefined, undefined, diff --git a/packages/core/src/test-utils/mock-tool.ts b/packages/core/src/test-utils/mock-tool.ts index 4a0eeccc2e..2c12aa0962 100644 --- a/packages/core/src/test-utils/mock-tool.ts +++ b/packages/core/src/test-utils/mock-tool.ts @@ -47,7 +47,7 @@ class MockToolInvocation extends BaseToolInvocation< constructor( private readonly tool: MockTool, params: { [key: string]: unknown }, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super(params, messageBus, tool.name, tool.displayName); } @@ -98,9 +98,9 @@ export class MockTool extends BaseDeclarativeTool< options.description ?? options.name, Kind.Other, options.params, + options.messageBus ?? createMockMessageBus(), options.isOutputMarkdown ?? false, options.canUpdateOutput ?? false, - options.messageBus ?? createMockMessageBus(), ); if (options.shouldConfirmExecute) { @@ -122,11 +122,11 @@ export class MockTool extends BaseDeclarativeTool< protected createInvocation( params: { [key: string]: unknown }, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation<{ [key: string]: unknown }, ToolResult> { - return new MockToolInvocation(this, params, messageBus ?? this.messageBus); + return new MockToolInvocation(this, params, messageBus); } } @@ -146,7 +146,7 @@ export class MockModifiableToolInvocation extends BaseToolInvocation< constructor( private readonly tool: MockModifiableTool, params: Record, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super(params, messageBus, tool.name, tool.displayName); } @@ -207,9 +207,9 @@ export class MockModifiableTool type: 'object', properties: { param: { type: 'string' } }, }, + createMockMessageBus(), true, false, - createMockMessageBus(), ); } @@ -230,14 +230,10 @@ export class MockModifiableTool protected createInvocation( params: Record, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation, ToolResult> { - return new MockModifiableToolInvocation( - this, - params, - messageBus ?? this.messageBus, - ); + return new MockModifiableToolInvocation(this, params, messageBus); } } diff --git a/packages/core/src/tools/activate-skill.test.ts b/packages/core/src/tools/activate-skill.test.ts index 80f4dc6885..3e7fe4a6e8 100644 --- a/packages/core/src/tools/activate-skill.test.ts +++ b/packages/core/src/tools/activate-skill.test.ts @@ -8,6 +8,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import { ActivateSkillTool } from './activate-skill.js'; import type { Config } from '../config/config.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; vi.mock('../utils/getFolderStructure.js', () => ({ getFolderStructure: vi.fn().mockResolvedValue('Mock folder structure'), @@ -16,13 +17,10 @@ vi.mock('../utils/getFolderStructure.js', () => ({ describe('ActivateSkillTool', () => { let mockConfig: Config; let tool: ActivateSkillTool; - const mockMessageBus = { - publish: vi.fn(), - subscribe: vi.fn(), - unsubscribe: vi.fn(), - } as unknown as MessageBus; + let mockMessageBus: MessageBus; beforeEach(() => { + mockMessageBus = createMockMessageBus(); const skills = [ { name: 'test-skill', diff --git a/packages/core/src/tools/activate-skill.ts b/packages/core/src/tools/activate-skill.ts index afea50316c..31ee4d0c24 100644 --- a/packages/core/src/tools/activate-skill.ts +++ b/packages/core/src/tools/activate-skill.ts @@ -38,7 +38,7 @@ class ActivateSkillToolInvocation extends BaseToolInvocation< constructor( private config: Config, params: ActivateSkillToolParams, - messageBus: MessageBus | undefined, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -145,7 +145,7 @@ export class ActivateSkillTool extends BaseDeclarativeTool< constructor( private config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { const skills = config.getSkillManager().getSkills(); const skillNames = skills.map((s) => s.name); @@ -169,15 +169,15 @@ export class ActivateSkillTool extends BaseDeclarativeTool< "Activates a specialized agent skill by name. Returns the skill's instructions wrapped in `` tags. These provide specialized guidance for the current task. Use this when you identify a task that matches a skill's description.", Kind.Other, zodToJsonSchema(schema), + messageBus, true, false, - messageBus, ); } protected createInvocation( params: ActivateSkillToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index 475e8f2745..838bbc6c6e 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -118,7 +118,7 @@ class EditToolInvocation constructor( private readonly config: Config, params: EditToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, toolName?: string, displayName?: string, ) { @@ -492,7 +492,7 @@ export class EditTool constructor( private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super( EditTool.Name, @@ -535,9 +535,9 @@ Expectation for required parameters: required: ['file_path', 'old_string', 'new_string'], type: 'object', }, + messageBus, true, // isOutputMarkdown false, // canUpdateOutput - messageBus, ); } @@ -568,14 +568,14 @@ Expectation for required parameters: protected createInvocation( params: EditToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, toolName?: string, displayName?: string, ): ToolInvocation { return new EditToolInvocation( this.config, params, - messageBus ?? this.messageBus, + messageBus, toolName ?? this.name, displayName ?? this.displayName, ); diff --git a/packages/core/src/tools/get-internal-docs.test.ts b/packages/core/src/tools/get-internal-docs.test.ts index 40a47b6477..bee9265e70 100644 --- a/packages/core/src/tools/get-internal-docs.test.ts +++ b/packages/core/src/tools/get-internal-docs.test.ts @@ -9,13 +9,14 @@ import { GetInternalDocsTool } from './get-internal-docs.js'; import { ToolErrorType } from './tool-error.js'; import fs from 'node:fs/promises'; import path from 'node:path'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; describe('GetInternalDocsTool (Integration)', () => { let tool: GetInternalDocsTool; const abortSignal = new AbortController().signal; beforeEach(() => { - tool = new GetInternalDocsTool(); + tool = new GetInternalDocsTool(createMockMessageBus()); }); it('should find the documentation root and list files', async () => { diff --git a/packages/core/src/tools/get-internal-docs.ts b/packages/core/src/tools/get-internal-docs.ts index 90637ffced..c18c155404 100644 --- a/packages/core/src/tools/get-internal-docs.ts +++ b/packages/core/src/tools/get-internal-docs.ts @@ -82,7 +82,7 @@ class GetInternalDocsInvocation extends BaseToolInvocation< > { constructor( params: GetInternalDocsParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -161,7 +161,7 @@ export class GetInternalDocsTool extends BaseDeclarativeTool< > { static readonly Name = GET_INTERNAL_DOCS_TOOL_NAME; - constructor(messageBus?: MessageBus) { + constructor(messageBus: MessageBus) { super( GetInternalDocsTool.Name, 'GetInternalDocs', @@ -177,21 +177,21 @@ export class GetInternalDocsTool extends BaseDeclarativeTool< }, }, }, + messageBus, /* isOutputMarkdown */ true, /* canUpdateOutput */ false, - messageBus, ); } protected createInvocation( params: GetInternalDocsParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { return new GetInternalDocsInvocation( params, - messageBus ?? this.messageBus, + messageBus, _toolName ?? GetInternalDocsTool.Name, _toolDisplayName, ); diff --git a/packages/core/src/tools/glob.ts b/packages/core/src/tools/glob.ts index 40a911d2f0..7a98d8e3e2 100644 --- a/packages/core/src/tools/glob.ts +++ b/packages/core/src/tools/glob.ts @@ -91,7 +91,7 @@ class GlobToolInvocation extends BaseToolInvocation< constructor( private config: Config, params: GlobToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -262,7 +262,7 @@ export class GlobTool extends BaseDeclarativeTool { static readonly Name = GLOB_TOOL_NAME; constructor( private config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super( GlobTool.Name, @@ -300,9 +300,9 @@ export class GlobTool extends BaseDeclarativeTool { required: ['pattern'], type: 'object', }, + messageBus, true, false, - messageBus, ); } @@ -348,14 +348,14 @@ export class GlobTool extends BaseDeclarativeTool { protected createInvocation( params: GlobToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { return new GlobToolInvocation( this.config, params, - messageBus ?? this.messageBus, + messageBus, _toolName, _toolDisplayName, ); diff --git a/packages/core/src/tools/grep.ts b/packages/core/src/tools/grep.ts index a4f35e07b4..3fbbb141d6 100644 --- a/packages/core/src/tools/grep.ts +++ b/packages/core/src/tools/grep.ts @@ -62,7 +62,7 @@ class GrepToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: GrepToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -571,7 +571,7 @@ export class GrepTool extends BaseDeclarativeTool { static readonly Name = GREP_TOOL_NAME; constructor( private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super( GrepTool.Name, @@ -599,9 +599,9 @@ export class GrepTool extends BaseDeclarativeTool { required: ['pattern'], type: 'object', }, + messageBus, true, false, - messageBus, ); } @@ -674,14 +674,14 @@ export class GrepTool extends BaseDeclarativeTool { protected createInvocation( params: GrepToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { return new GrepToolInvocation( this.config, params, - messageBus ?? this.messageBus, + messageBus, _toolName, _toolDisplayName, ); diff --git a/packages/core/src/tools/ls.ts b/packages/core/src/tools/ls.ts index c7e02acaf0..80a5ecbc0d 100644 --- a/packages/core/src/tools/ls.ts +++ b/packages/core/src/tools/ls.ts @@ -73,7 +73,7 @@ class LSToolInvocation extends BaseToolInvocation { constructor( private readonly config: Config, params: LSToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -259,7 +259,7 @@ export class LSTool extends BaseDeclarativeTool { constructor( private config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super( LSTool.Name, @@ -300,9 +300,9 @@ export class LSTool extends BaseDeclarativeTool { required: ['dir_path'], type: 'object', }, + messageBus, true, false, - messageBus, ); } @@ -330,7 +330,7 @@ export class LSTool extends BaseDeclarativeTool { protected createInvocation( params: LSToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index db7e102c89..1f96d34169 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -57,7 +57,7 @@ import type { } from '../utils/workspaceContext.js'; import type { ToolRegistry } from './tool-registry.js'; import { debugLogger } from '../utils/debugLogger.js'; -import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { type MessageBus } from '../confirmation-bus/message-bus.js'; import { coreEvents } from '../utils/events.js'; import type { ResourceRegistry } from '../resources/resource-registry.js'; import { @@ -895,7 +895,7 @@ export async function discoverTools( mcpServerConfig: MCPServerConfig, mcpClient: Client, cliConfig: Config, - messageBus?: MessageBus, + messageBus: MessageBus, options?: { timeout?: number; signal?: AbortSignal }, ): Promise { try { @@ -922,12 +922,12 @@ export async function discoverTools( toolDef.name, toolDef.description ?? '', toolDef.inputSchema ?? { type: 'object', properties: {} }, + messageBus, mcpServerConfig.trust, undefined, cliConfig, mcpServerConfig.extension?.name, mcpServerConfig.extension?.id, - messageBus, ); discoveredTools.push(tool); diff --git a/packages/core/src/tools/mcp-tool.test.ts b/packages/core/src/tools/mcp-tool.test.ts index fbd0d25dc1..5abc5779e9 100644 --- a/packages/core/src/tools/mcp-tool.test.ts +++ b/packages/core/src/tools/mcp-tool.test.ts @@ -13,6 +13,10 @@ import type { ToolResult } from './tools.js'; import { ToolConfirmationOutcome } from './tools.js'; // Added ToolConfirmationOutcome import type { CallableTool, Part } from '@google/genai'; import { ToolErrorType } from './tool-error.js'; +import { + createMockMessageBus, + getMockMessageBusInstance, +} from '../test-utils/mock-message-bus.js'; // Mock @google/genai mcpToTool and CallableTool // We only need to mock the parts of CallableTool that DiscoveredMCPTool uses. @@ -85,12 +89,15 @@ describe('DiscoveredMCPTool', () => { beforeEach(() => { mockCallTool.mockClear(); mockToolMethod.mockClear(); + const bus = createMockMessageBus(); + getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user'; tool = new DiscoveredMCPTool( mockCallableToolInstance, serverName, serverToolName, baseDescription, inputSchema, + bus, ); // Clear allowlist before each relevant test, especially for shouldConfirmExecute const invocation = tool.build({ param: 'mock' }) as any; @@ -190,6 +197,12 @@ describe('DiscoveredMCPTool', () => { serverToolName, baseDescription, inputSchema, + createMockMessageBus(), + undefined, + undefined, + undefined, + undefined, + undefined, ); const params = { param: 'isErrorTrueCase' }; const functionCall = { @@ -230,6 +243,12 @@ describe('DiscoveredMCPTool', () => { serverToolName, baseDescription, inputSchema, + createMockMessageBus(), + undefined, + undefined, + undefined, + undefined, + undefined, ); const params = { param: 'isErrorTopLevelCase' }; const functionCall = { @@ -273,6 +292,12 @@ describe('DiscoveredMCPTool', () => { serverToolName, baseDescription, inputSchema, + createMockMessageBus(), + undefined, + undefined, + undefined, + undefined, + undefined, ); const params = { param: 'isErrorFalseCase' }; const mockToolSuccessResultObject = { @@ -728,9 +753,12 @@ describe('DiscoveredMCPTool', () => { serverToolName, baseDescription, inputSchema, + createMockMessageBus(), true, undefined, { isTrustedFolder: () => true } as any, + undefined, + undefined, ); const invocation = trustedTool.build({ param: 'mock' }); expect( @@ -862,15 +890,20 @@ describe('DiscoveredMCPTool', () => { 'return confirmation details if trust is false, even if folder is trusted', }, ])('should $description', async ({ trust, isTrusted, shouldConfirm }) => { + const bus = createMockMessageBus(); + getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user'; const testTool = new DiscoveredMCPTool( mockCallableToolInstance, serverName, serverToolName, baseDescription, inputSchema, + bus, trust, undefined, mockConfig(isTrusted) as any, + undefined, + undefined, ); const invocation = testTool.build({ param: 'mock' }); const confirmation = await invocation.shouldConfirmExecute( diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 844a05ab09..44a07d99e8 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -70,10 +70,10 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< readonly serverName: string, readonly serverToolName: string, readonly displayName: string, + messageBus: MessageBus, readonly trust?: boolean, params: ToolParams = {}, private readonly cliConfig?: Config, - messageBus?: MessageBus, ) { // Use composite format for policy checks: serverName__toolName // This enables server wildcards (e.g., "google-workspace__*") @@ -239,12 +239,12 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< readonly serverToolName: string, description: string, override readonly parameterSchema: unknown, + messageBus: MessageBus, readonly trust?: boolean, nameOverride?: string, private readonly cliConfig?: Config, override readonly extensionName?: string, override readonly extensionId?: string, - messageBus?: MessageBus, ) { super( nameOverride ?? generateValidName(serverToolName), @@ -252,9 +252,9 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< description, Kind.Other, parameterSchema, + messageBus, true, // isOutputMarkdown false, // canUpdateOutput, - messageBus, extensionName, extensionId, ); @@ -271,18 +271,18 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< this.serverToolName, this.description, this.parameterSchema, + this.messageBus, this.trust, `${this.getFullyQualifiedPrefix()}${this.serverToolName}`, this.cliConfig, this.extensionName, this.extensionId, - this.messageBus, ); } protected createInvocation( params: ToolParams, - _messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _displayName?: string, ): ToolInvocation { @@ -291,10 +291,10 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< this.serverName, this.serverToolName, _displayName ?? this.displayName, + messageBus, this.trust, params, this.cliConfig, - _messageBus ?? this.messageBus, ); } } diff --git a/packages/core/src/tools/memoryTool.test.ts b/packages/core/src/tools/memoryTool.test.ts index fb251d37ec..4581b19232 100644 --- a/packages/core/src/tools/memoryTool.test.ts +++ b/packages/core/src/tools/memoryTool.test.ts @@ -19,6 +19,10 @@ import * as os from 'node:os'; import { ToolConfirmationOutcome } from './tools.js'; import { ToolErrorType } from './tool-error.js'; import { GEMINI_DIR } from '../utils/paths.js'; +import { + createMockMessageBus, + getMockMessageBusInstance, +} from '../test-utils/mock-message-bus.js'; // Mock dependencies vi.mock(import('node:fs/promises'), async (importOriginal) => { @@ -200,7 +204,7 @@ describe('MemoryTool', () => { let performAddMemoryEntrySpy: Mock; beforeEach(() => { - memoryTool = new MemoryTool(); + memoryTool = new MemoryTool(createMockMessageBus()); // Spy on the static method for these tests performAddMemoryEntrySpy = vi .spyOn(MemoryTool, 'performAddMemoryEntry') @@ -300,7 +304,9 @@ describe('MemoryTool', () => { let memoryTool: MemoryTool; beforeEach(() => { - memoryTool = new MemoryTool(); + const bus = createMockMessageBus(); + getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user'; + memoryTool = new MemoryTool(bus); // Clear the allowlist before each test const invocation = memoryTool.build({ fact: 'mock-fact' }); // eslint-disable-next-line @typescript-eslint/no-explicit-any diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index 3e38d6d294..56de14eae7 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -179,7 +179,7 @@ class MemoryToolInvocation extends BaseToolInvocation< constructor( params: SaveMemoryParams, - messageBus?: MessageBus, + messageBus: MessageBus, toolName?: string, displayName?: string, ) { @@ -298,16 +298,16 @@ export class MemoryTool { static readonly Name = MEMORY_TOOL_NAME; - constructor(messageBus?: MessageBus) { + constructor(messageBus: MessageBus) { super( MemoryTool.Name, 'SaveMemory', memoryToolDescription, Kind.Think, memoryToolSchemaData.parametersJsonSchema as Record, + messageBus, true, false, - messageBus, ); } @@ -323,13 +323,13 @@ export class MemoryTool protected createInvocation( params: SaveMemoryParams, - messageBus?: MessageBus, + messageBus: MessageBus, toolName?: string, displayName?: string, ) { return new MemoryToolInvocation( params, - messageBus ?? this.messageBus, + messageBus, toolName ?? this.name, displayName ?? this.displayName, ); diff --git a/packages/core/src/tools/message-bus-integration.test.ts b/packages/core/src/tools/message-bus-integration.test.ts index bafa140f5d..bfc369b58b 100644 --- a/packages/core/src/tools/message-bus-integration.test.ts +++ b/packages/core/src/tools/message-bus-integration.test.ts @@ -81,21 +81,21 @@ class TestTool extends BaseDeclarativeTool { }, required: ['testParam'], }, + messageBus, true, false, - messageBus, ); } protected createInvocation( params: TestParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { return new TestToolInvocation( params, - messageBus ?? this.messageBus, + messageBus, _toolName, _toolDisplayName, ); diff --git a/packages/core/src/tools/read-file.ts b/packages/core/src/tools/read-file.ts index 6b19878385..f748bf8b45 100644 --- a/packages/core/src/tools/read-file.ts +++ b/packages/core/src/tools/read-file.ts @@ -50,7 +50,7 @@ class ReadFileToolInvocation extends BaseToolInvocation< constructor( private config: Config, params: ReadFileToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -149,7 +149,7 @@ export class ReadFileTool extends BaseDeclarativeTool< constructor( private config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super( ReadFileTool.Name, @@ -176,9 +176,9 @@ export class ReadFileTool extends BaseDeclarativeTool< required: ['file_path'], type: 'object', }, + messageBus, true, false, - messageBus, ); } @@ -225,14 +225,14 @@ export class ReadFileTool extends BaseDeclarativeTool< protected createInvocation( params: ReadFileToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { return new ReadFileToolInvocation( this.config, params, - messageBus ?? this.messageBus, + messageBus, _toolName, _toolDisplayName, ); diff --git a/packages/core/src/tools/read-many-files.ts b/packages/core/src/tools/read-many-files.ts index de96588337..c1d8c18cd7 100644 --- a/packages/core/src/tools/read-many-files.ts +++ b/packages/core/src/tools/read-many-files.ts @@ -107,7 +107,7 @@ class ReadManyFilesToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: ReadManyFilesParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -447,7 +447,7 @@ export class ReadManyFilesTool extends BaseDeclarativeTool< constructor( private config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { const parameterSchema = { type: 'object', @@ -520,22 +520,22 @@ This tool is useful when you need to understand or analyze a collection of files Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. The tool inserts a '--- End of content ---' after the last file. Ensure glob patterns are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/audio/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/audio/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`, Kind.Read, parameterSchema, + messageBus, true, // isOutputMarkdown false, // canUpdateOutput - messageBus, ); } protected createInvocation( params: ReadManyFilesParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { return new ReadManyFilesToolInvocation( this.config, params, - messageBus ?? this.messageBus, + messageBus, _toolName, _toolDisplayName, ); diff --git a/packages/core/src/tools/ripGrep.test.ts b/packages/core/src/tools/ripGrep.test.ts index 0f978313ed..e8eafc9b23 100644 --- a/packages/core/src/tools/ripGrep.test.ts +++ b/packages/core/src/tools/ripGrep.test.ts @@ -24,6 +24,7 @@ import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.j import type { ChildProcess } from 'node:child_process'; import { spawn } from 'node:child_process'; import { downloadRipGrep } from '@joshua.litt/get-ripgrep'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; // Mock dependencies for canUseRipgrep vi.mock('@joshua.litt/get-ripgrep', () => ({ downloadRipGrep: vi.fn(), @@ -267,7 +268,7 @@ describe('RipGrepTool', () => { await fs.writeFile(ripgrepBinaryPath, ''); storageSpy.mockImplementation(() => binDir); tempRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'grep-tool-root-')); - grepTool = new RipGrepTool(mockConfig); + grepTool = new RipGrepTool(mockConfig, createMockMessageBus()); // Create some test files and directories await fs.writeFile( @@ -833,7 +834,10 @@ describe('RipGrepTool', () => { return mockProcess as unknown as ChildProcess; }); - const multiDirGrepTool = new RipGrepTool(multiDirConfig); + const multiDirGrepTool = new RipGrepTool( + multiDirConfig, + createMockMessageBus(), + ); const params: RipGrepToolParams = { pattern: 'world' }; const invocation = multiDirGrepTool.build(params); const result = await invocation.execute(abortSignal); @@ -927,7 +931,10 @@ describe('RipGrepTool', () => { return mockProcess as unknown as ChildProcess; }); - const multiDirGrepTool = new RipGrepTool(multiDirConfig); + const multiDirGrepTool = new RipGrepTool( + multiDirConfig, + createMockMessageBus(), + ); // Search only in the 'sub' directory of the first workspace const params: RipGrepToolParams = { pattern: 'world', dir_path: 'sub' }; @@ -1656,7 +1663,10 @@ describe('RipGrepTool', () => { getDebugMode: () => false, getFileFilteringRespectGeminiIgnore: () => true, } as unknown as Config; - const geminiIgnoreTool = new RipGrepTool(configWithGeminiIgnore); + const geminiIgnoreTool = new RipGrepTool( + configWithGeminiIgnore, + createMockMessageBus(), + ); mockSpawn.mockImplementationOnce( createMockSpawn({ @@ -1693,7 +1703,10 @@ describe('RipGrepTool', () => { getDebugMode: () => false, getFileFilteringRespectGeminiIgnore: () => false, } as unknown as Config; - const geminiIgnoreTool = new RipGrepTool(configWithoutGeminiIgnore); + const geminiIgnoreTool = new RipGrepTool( + configWithoutGeminiIgnore, + createMockMessageBus(), + ); mockSpawn.mockImplementationOnce( createMockSpawn({ @@ -1816,7 +1829,10 @@ describe('RipGrepTool', () => { getDebugMode: () => false, } as unknown as Config; - const multiDirGrepTool = new RipGrepTool(multiDirConfig); + const multiDirGrepTool = new RipGrepTool( + multiDirConfig, + createMockMessageBus(), + ); const params: RipGrepToolParams = { pattern: 'testPattern' }; const invocation = multiDirGrepTool.build(params); expect(invocation.getDescription()).toBe("'testPattern' within ./"); diff --git a/packages/core/src/tools/ripGrep.ts b/packages/core/src/tools/ripGrep.ts index 973e0a5fa3..0e52884b14 100644 --- a/packages/core/src/tools/ripGrep.ts +++ b/packages/core/src/tools/ripGrep.ts @@ -192,7 +192,7 @@ class GrepToolInvocation extends BaseToolInvocation< private readonly config: Config, private readonly geminiIgnoreParser: GeminiIgnoreParser, params: RipGrepToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -493,7 +493,7 @@ export class RipGrepTool extends BaseDeclarativeTool< constructor( private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super( RipGrepTool.Name, @@ -551,9 +551,9 @@ export class RipGrepTool extends BaseDeclarativeTool< required: ['pattern'], type: 'object', }, + messageBus, true, // isOutputMarkdown false, // canUpdateOutput - messageBus, ); this.geminiIgnoreParser = new GeminiIgnoreParser(config.getTargetDir()); } @@ -586,7 +586,7 @@ export class RipGrepTool extends BaseDeclarativeTool< protected createInvocation( params: RipGrepToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts index 8554b2e081..a2d3b611c5 100644 --- a/packages/core/src/tools/shell.ts +++ b/packages/core/src/tools/shell.ts @@ -57,7 +57,7 @@ export class ShellToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: ShellToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -420,7 +420,7 @@ export class ShellTool extends BaseDeclarativeTool< constructor( private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { void initializeShellParsers().catch(() => { // Errors are surfaced when parsing commands. @@ -450,9 +450,9 @@ export class ShellTool extends BaseDeclarativeTool< }, required: ['command'], }, + messageBus, false, // output is not markdown true, // output can be updated - messageBus, ); } @@ -478,14 +478,14 @@ export class ShellTool extends BaseDeclarativeTool< protected createInvocation( params: ShellToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { return new ShellToolInvocation( this.config, params, - messageBus ?? this.messageBus, + messageBus, _toolName, _toolDisplayName, ); diff --git a/packages/core/src/tools/smart-edit.test.ts b/packages/core/src/tools/smart-edit.test.ts index 448662fc6d..9c76d77ee4 100644 --- a/packages/core/src/tools/smart-edit.test.ts +++ b/packages/core/src/tools/smart-edit.test.ts @@ -49,6 +49,10 @@ import { } from './smart-edit.js'; import { type FileDiff, ToolConfirmationOutcome } from './tools.js'; import { ToolErrorType } from './tool-error.js'; +import { + createMockMessageBus, + getMockMessageBusInstance, +} from '../test-utils/mock-message-bus.js'; import path from 'node:path'; import fs from 'node:fs'; import os from 'node:os'; @@ -165,7 +169,9 @@ describe('SmartEditTool', () => { }, ); - tool = new SmartEditTool(mockConfig); + const bus = createMockMessageBus(); + getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user'; + tool = new SmartEditTool(mockConfig, bus); }); afterEach(() => { diff --git a/packages/core/src/tools/smart-edit.ts b/packages/core/src/tools/smart-edit.ts index b9a79b7218..aee3a115f8 100644 --- a/packages/core/src/tools/smart-edit.ts +++ b/packages/core/src/tools/smart-edit.ts @@ -386,7 +386,7 @@ class EditToolInvocation constructor( private readonly config: Config, params: EditToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, toolName?: string, displayName?: string, ) { @@ -853,7 +853,7 @@ export class SmartEditTool constructor( private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super( SmartEditTool.Name, @@ -915,9 +915,9 @@ A good instruction should concisely answer: required: ['file_path', 'instruction', 'old_string', 'new_string'], type: 'object', }, + messageBus, true, // isOutputMarkdown false, // canUpdateOutput - messageBus, ); } @@ -955,12 +955,12 @@ A good instruction should concisely answer: protected createInvocation( params: EditToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, ): ToolInvocation { return new EditToolInvocation( this.config, params, - messageBus ?? this.messageBus, + messageBus, this.name, this.displayName, ); diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index 55eb89150a..f665827e35 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -90,12 +90,26 @@ const createMockCallableTool = ( }); // Helper to create a DiscoveredMCPTool +const mockMessageBusForHelper = { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), +} as unknown as MessageBus; + const createMCPTool = ( serverName: string, toolName: string, description: string, mockCallable: CallableTool = {} as CallableTool, -) => new DiscoveredMCPTool(mockCallable, serverName, toolName, description, {}); +) => + new DiscoveredMCPTool( + mockCallable, + serverName, + toolName, + description, + {}, + mockMessageBusForHelper, + ); // Helper to create a mock spawn process for tool discovery const createDiscoveryProcess = (toolDeclarations: FunctionDeclaration[]) => { @@ -171,6 +185,11 @@ const baseConfigParams: ConfigParameters = { describe('ToolRegistry', () => { let config: Config; let toolRegistry: ToolRegistry; + const mockMessageBus = { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + } as unknown as MessageBus; let mockConfigGetToolDiscoveryCommand: ReturnType; let mockConfigGetExcludedTools: MockInstance< typeof Config.prototype.getExcludeTools @@ -182,7 +201,7 @@ describe('ToolRegistry', () => { isDirectory: () => true, } as fs.Stats); config = new Config(baseConfigParams); - toolRegistry = new ToolRegistry(config); + toolRegistry = new ToolRegistry(config, mockMessageBus); vi.spyOn(console, 'warn').mockImplementation(() => {}); vi.spyOn(console, 'error').mockImplementation(() => {}); vi.spyOn(console, 'debug').mockImplementation(() => {}); @@ -372,6 +391,7 @@ describe('ToolRegistry', () => { DISCOVERED_TOOL_PREFIX + 'discovered-1', 'desc', {}, + mockMessageBus, ); const mcpZebra = createMCPTool('zebra-server', 'mcp-zebra', 'desc'); const mcpApple = createMCPTool('apple-server', 'mcp-apple', 'desc'); @@ -482,13 +502,6 @@ describe('ToolRegistry', () => { const discoveryCommand = 'my-discovery-command'; mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand); - const mockMessageBus = { - publish: vi.fn(), - subscribe: vi.fn(), - unsubscribe: vi.fn(), - } as unknown as MessageBus; - toolRegistry.setMessageBus(mockMessageBus); - const toolDeclaration: FunctionDeclaration = { name: 'policy-test-tool', description: 'tests policy', @@ -520,6 +533,7 @@ describe('ToolRegistry', () => { DISCOVERED_TOOL_PREFIX + 'test-tool', 'A test tool', {}, + mockMessageBus, ); const params = { param: 'testValue' }; const invocation = tool.build(params); diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index ef481c42af..18c30c5f76 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -34,7 +34,7 @@ class DiscoveredToolInvocation extends BaseToolInvocation< private readonly originalToolName: string, prefixedToolName: string, params: ToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super(params, messageBus, prefixedToolName); } @@ -135,7 +135,7 @@ export class DiscoveredTool extends BaseDeclarativeTool< prefixedName: string, description: string, override readonly parameterSchema: Record, - messageBus?: MessageBus, + messageBus: MessageBus, ) { const discoveryCmd = config.getToolDiscoveryCommand()!; const callCommand = config.getToolCallCommand()!; @@ -163,16 +163,16 @@ Signal: Signal number or \`(none)\` if no signal was received. fullDescription, Kind.Other, parameterSchema, + messageBus, false, // isOutputMarkdown false, // canUpdateOutput - messageBus, ); this.originalName = originalName; } protected createInvocation( params: ToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _displayName?: string, ): ToolInvocation { @@ -181,7 +181,7 @@ Signal: Signal number or \`(none)\` if no signal was received. this.originalName, _toolName ?? this.name, params, - messageBus ?? this.messageBus, + messageBus, ); } } @@ -192,26 +192,17 @@ export class ToolRegistry { // and `isActive` to get only the active tools. private allKnownTools: Map = new Map(); private config: Config; - private messageBus?: MessageBus; + private messageBus: MessageBus; - constructor(config: Config, messageBus?: MessageBus) { + constructor(config: Config, messageBus: MessageBus) { this.config = config; this.messageBus = messageBus; } - getMessageBus(): MessageBus | undefined { + getMessageBus(): MessageBus { return this.messageBus; } - /** - * @deprecated migration only - will be removed in PR 3 (Enforcement) - * TODO: DELETE ME in PR 3. This is a temporary shim to allow for soft migration - * of tools while the core infrastructure is updated to require a MessageBus at birth. - */ - setMessageBus(messageBus: MessageBus): void { - this.messageBus = messageBus; - } - /** * Registers a tool definition. * diff --git a/packages/core/src/tools/tools.test.ts b/packages/core/src/tools/tools.test.ts index 38827268c1..514f4f3455 100644 --- a/packages/core/src/tools/tools.test.ts +++ b/packages/core/src/tools/tools.test.ts @@ -8,6 +8,7 @@ import { describe, it, expect, vi } from 'vitest'; import type { ToolInvocation, ToolResult } from './tools.js'; import { DeclarativeTool, hasCycleInSchema, Kind } from './tools.js'; import { ToolErrorType } from './tool-error.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; class TestToolInvocation implements ToolInvocation { constructor( @@ -36,7 +37,16 @@ class TestTool extends DeclarativeTool { private readonly buildFn: (params: object) => TestToolInvocation; constructor(buildFn: (params: object) => TestToolInvocation) { - super('test-tool', 'Test Tool', 'A tool for testing', Kind.Other, {}); + super( + 'test-tool', + 'Test Tool', + 'A tool for testing', + Kind.Other, + {}, + createMockMessageBus(), + true, // isOutputMarkdown + false, // canUpdateOutput + ); this.buildFn = buildFn; } diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index d4b7fc3094..1b6f6f92ee 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -83,7 +83,7 @@ export abstract class BaseToolInvocation< { constructor( readonly params: TParams, - protected readonly messageBus?: MessageBus, + protected readonly messageBus: MessageBus, readonly _toolName?: string, readonly _toolDisplayName?: string, readonly _serverName?: string, @@ -98,25 +98,24 @@ export abstract class BaseToolInvocation< async shouldConfirmExecute( abortSignal: AbortSignal, ): Promise { - if (this.messageBus) { - const decision = await this.getMessageBusDecision(abortSignal); - if (decision === 'ALLOW') { - return false; - } - - if (decision === 'DENY') { - throw new Error( - `Tool execution for "${ - this._toolDisplayName || this._toolName - }" denied by policy.`, - ); - } - - if (decision === 'ASK_USER') { - return this.getConfirmationDetails(abortSignal); - } + const decision = await this.getMessageBusDecision(abortSignal); + if (decision === 'ALLOW') { + return false; } - // When no message bus, use default confirmation flow + + if (decision === 'DENY') { + throw new Error( + `Tool execution for "${ + this._toolDisplayName || this._toolName + }" denied by policy.`, + ); + } + + if (decision === 'ASK_USER') { + return this.getConfirmationDetails(abortSignal); + } + + // Default to confirmation details if decision is unknown (should not happen with exhaustive policy) return this.getConfirmationDetails(abortSignal); } @@ -142,7 +141,7 @@ export abstract class BaseToolInvocation< outcome === ToolConfirmationOutcome.ProceedAlways || outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave ) { - if (this.messageBus && this._toolName) { + if (this._toolName) { const options = this.getPolicyUpdateOptions(outcome); await this.messageBus.publish({ type: MessageBusType.UPDATE_POLICY, @@ -206,7 +205,7 @@ export abstract class BaseToolInvocation< timeoutId = undefined; } abortSignal.removeEventListener('abort', abortHandler); - this.messageBus?.unsubscribe( + this.messageBus.unsubscribe( MessageBusType.TOOL_CONFIRMATION_RESPONSE, responseHandler, ); @@ -341,9 +340,9 @@ export abstract class DeclarativeTool< readonly description: string, readonly kind: Kind, readonly parameterSchema: unknown, + readonly messageBus: MessageBus, readonly isOutputMarkdown: boolean = true, readonly canUpdateOutput: boolean = false, - readonly messageBus?: MessageBus, readonly extensionName?: string, readonly extensionId?: string, ) {} @@ -496,7 +495,7 @@ export abstract class BaseDeclarativeTool< protected abstract createInvocation( params: TParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation; diff --git a/packages/core/src/tools/web-fetch.test.ts b/packages/core/src/tools/web-fetch.test.ts index f37db3d558..ac483fccd9 100644 --- a/packages/core/src/tools/web-fetch.test.ts +++ b/packages/core/src/tools/web-fetch.test.ts @@ -10,6 +10,10 @@ import type { Config } from '../config/config.js'; import { ApprovalMode } from '../policy/types.js'; import { ToolConfirmationOutcome } from './tools.js'; import { ToolErrorType } from './tool-error.js'; +import { + createMockMessageBus, + getMockMessageBusInstance, +} from '../test-utils/mock-message-bus.js'; import * as fetchUtils from '../utils/fetch.js'; import { MessageBus } from '../confirmation-bus/message-bus.js'; import { PolicyEngine } from '../policy/policy-engine.js'; @@ -126,9 +130,12 @@ describe('parsePrompt', () => { describe('WebFetchTool', () => { let mockConfig: Config; + let bus: MessageBus; beforeEach(() => { vi.resetAllMocks(); + bus = createMockMessageBus(); + getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user'; mockConfig = { getApprovalMode: vi.fn(), setApprovalMode: vi.fn(), @@ -163,12 +170,12 @@ describe('WebFetchTool', () => { expectedError: 'Error(s) in prompt URLs:', }, ])('should throw if $name', ({ prompt, expectedError }) => { - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); expect(() => tool.build({ prompt })).toThrow(expectedError); }); it('should pass if prompt contains at least one valid URL', () => { - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); expect(() => tool.build({ prompt: 'fetch https://example.com' }), ).not.toThrow(); @@ -181,7 +188,7 @@ describe('WebFetchTool', () => { vi.spyOn(fetchUtils, 'fetchWithTimeout').mockRejectedValue( new Error('fetch failed'), ); - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); const params = { prompt: 'fetch https://private.ip' }; const invocation = tool.build(params); const result = await invocation.execute(new AbortController().signal); @@ -191,7 +198,7 @@ describe('WebFetchTool', () => { it('should return WEB_FETCH_PROCESSING_ERROR on general processing failure', async () => { vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false); mockGenerateContent.mockRejectedValue(new Error('API error')); - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); const params = { prompt: 'fetch https://public.ip' }; const invocation = tool.build(params); const result = await invocation.execute(new AbortController().signal); @@ -209,7 +216,7 @@ describe('WebFetchTool', () => { candidates: [{ content: { parts: [{ text: 'fallback response' }] } }], }); - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); const params = { prompt: 'fetch https://private.ip' }; const invocation = tool.build(params); await invocation.execute(new AbortController().signal); @@ -237,7 +244,7 @@ describe('WebFetchTool', () => { candidates: [{ content: { parts: [{ text: 'fallback response' }] } }], }); - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); const params = { prompt: 'fetch https://public.ip' }; const invocation = tool.build(params); await invocation.execute(new AbortController().signal); @@ -306,7 +313,7 @@ describe('WebFetchTool', () => { ], })); - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); const params = { prompt: 'fetch https://example.com' }; const invocation = tool.build(params); const result = await invocation.execute(new AbortController().signal); @@ -330,7 +337,7 @@ describe('WebFetchTool', () => { describe('shouldConfirmExecute', () => { it('should return confirmation details with the correct prompt and parsed urls', async () => { - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); const params = { prompt: 'fetch https://example.com' }; const invocation = tool.build(params); const confirmationDetails = await invocation.shouldConfirmExecute( @@ -347,7 +354,7 @@ describe('WebFetchTool', () => { }); it('should convert github urls to raw format', async () => { - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); const params = { prompt: 'fetch https://github.com/google/gemini-react/blob/main/README.md', @@ -373,7 +380,7 @@ describe('WebFetchTool', () => { vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue( ApprovalMode.AUTO_EDIT, ); - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); const params = { prompt: 'fetch https://example.com' }; const invocation = tool.build(params); const confirmationDetails = await invocation.shouldConfirmExecute( @@ -384,7 +391,7 @@ describe('WebFetchTool', () => { }); it('should call setApprovalMode when onConfirm is called with ProceedAlways', async () => { - const tool = new WebFetchTool(mockConfig); + const tool = new WebFetchTool(mockConfig, bus); const params = { prompt: 'fetch https://example.com' }; const invocation = tool.build(params); const confirmationDetails = await invocation.shouldConfirmExecute( @@ -412,8 +419,8 @@ describe('WebFetchTool', () => { let messageBus: MessageBus; let mockUUID: Mock; - const createToolWithMessageBus = (bus?: MessageBus) => { - const tool = new WebFetchTool(mockConfig, bus); + const createToolWithMessageBus = (customBus?: MessageBus) => { + const tool = new WebFetchTool(mockConfig, customBus ?? bus); const params = { prompt: 'fetch https://example.com' }; return { tool, invocation: tool.build(params) }; }; @@ -516,16 +523,6 @@ describe('WebFetchTool', () => { ); }); - it('should fall back to legacy confirmation when no message bus', async () => { - const { invocation } = createToolWithMessageBus(); // No message bus - const result = await invocation.shouldConfirmExecute( - new AbortController().signal, - ); - - expect(result).not.toBe(false); - expect(result).toHaveProperty('type', 'info'); - }); - it('should ignore responses with wrong correlation ID', async () => { vi.useFakeTimers(); const { invocation } = createToolWithMessageBus(messageBus); diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index b31f30ae53..3f8df7fa14 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -114,7 +114,7 @@ class WebFetchToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: WebFetchToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -218,7 +218,8 @@ ${textContent} protected override async getConfirmationDetails( _abortSignal: AbortSignal, ): Promise { - // Legacy confirmation flow (no message bus OR policy decision was ASK_USER) + // Check for AUTO_EDIT approval mode. This tool has a specific behavior + // where ProceedAlways switches the entire session to AUTO_EDIT. if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { return false; } @@ -406,7 +407,7 @@ export class WebFetchTool extends BaseDeclarativeTool< constructor( private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super( WebFetchTool.Name, @@ -424,9 +425,9 @@ export class WebFetchTool extends BaseDeclarativeTool< required: ['prompt'], type: 'object', }, + messageBus, true, // isOutputMarkdown false, // canUpdateOutput - messageBus, ); } @@ -452,14 +453,14 @@ export class WebFetchTool extends BaseDeclarativeTool< protected createInvocation( params: WebFetchToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { return new WebFetchToolInvocation( this.config, params, - messageBus ?? this.messageBus, + messageBus, _toolName, _toolDisplayName, ); diff --git a/packages/core/src/tools/web-search.test.ts b/packages/core/src/tools/web-search.test.ts index 560e17e4ce..3812a54879 100644 --- a/packages/core/src/tools/web-search.test.ts +++ b/packages/core/src/tools/web-search.test.ts @@ -11,6 +11,7 @@ import { WebSearchTool } from './web-search.js'; import type { Config } from '../config/config.js'; import { GeminiClient } from '../core/client.js'; import { ToolErrorType } from './tool-error.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; // Mock GeminiClient and Config constructor vi.mock('../core/client.js'); @@ -33,7 +34,7 @@ describe('WebSearchTool', () => { }, } as unknown as Config; mockGeminiClient = new GeminiClient(mockConfigInstance); - tool = new WebSearchTool(mockConfigInstance); + tool = new WebSearchTool(mockConfigInstance, createMockMessageBus()); }); afterEach(() => { diff --git a/packages/core/src/tools/web-search.ts b/packages/core/src/tools/web-search.ts index 6a9e7c9751..5a1eeffb6d 100644 --- a/packages/core/src/tools/web-search.ts +++ b/packages/core/src/tools/web-search.ts @@ -65,7 +65,7 @@ class WebSearchToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: WebSearchToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -192,7 +192,7 @@ export class WebSearchTool extends BaseDeclarativeTool< constructor( private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super( WebSearchTool.Name, @@ -209,9 +209,9 @@ export class WebSearchTool extends BaseDeclarativeTool< }, required: ['query'], }, + messageBus, true, // isOutputMarkdown false, // canUpdateOutput - messageBus, ); } @@ -231,7 +231,7 @@ export class WebSearchTool extends BaseDeclarativeTool< protected createInvocation( params: WebSearchToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ): ToolInvocation { diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index c3d7e39409..339a60b4b6 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -149,7 +149,7 @@ class WriteFileToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: WriteFileToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, toolName?: string, displayName?: string, ) { @@ -409,7 +409,7 @@ export class WriteFileTool constructor( private readonly config: Config, - messageBus?: MessageBus, + messageBus: MessageBus, ) { super( WriteFileTool.Name, @@ -432,9 +432,9 @@ export class WriteFileTool required: ['file_path', 'content'], type: 'object', }, + messageBus, true, false, - messageBus, ); } @@ -475,7 +475,7 @@ export class WriteFileTool protected createInvocation( params: WriteFileToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, ): ToolInvocation { return new WriteFileToolInvocation( this.config, diff --git a/packages/core/src/tools/write-todos.test.ts b/packages/core/src/tools/write-todos.test.ts index 9c2bc36fa5..117a3d2681 100644 --- a/packages/core/src/tools/write-todos.test.ts +++ b/packages/core/src/tools/write-todos.test.ts @@ -6,9 +6,10 @@ import { describe, expect, it } from 'vitest'; import { WriteTodosTool, type WriteTodosToolParams } from './write-todos.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; describe('WriteTodosTool', () => { - const tool = new WriteTodosTool(); + const tool = new WriteTodosTool(createMockMessageBus()); const signal = new AbortController().signal; describe('validation', () => { diff --git a/packages/core/src/tools/write-todos.ts b/packages/core/src/tools/write-todos.ts index 57c1ad5048..6f12574107 100644 --- a/packages/core/src/tools/write-todos.ts +++ b/packages/core/src/tools/write-todos.ts @@ -101,7 +101,7 @@ class WriteTodosToolInvocation extends BaseToolInvocation< > { constructor( params: WriteTodosToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _toolDisplayName?: string, ) { @@ -145,7 +145,7 @@ export class WriteTodosTool extends BaseDeclarativeTool< > { static readonly Name = WRITE_TODOS_TOOL_NAME; - constructor(messageBus?: MessageBus) { + constructor(messageBus: MessageBus) { super( WriteTodosTool.Name, 'WriteTodos', @@ -180,9 +180,9 @@ export class WriteTodosTool extends BaseDeclarativeTool< required: ['todos'], additionalProperties: false, }, + messageBus, true, // isOutputMarkdown false, // canUpdateOutput - messageBus, ); } @@ -251,13 +251,13 @@ export class WriteTodosTool extends BaseDeclarativeTool< protected createInvocation( params: WriteTodosToolParams, - messageBus?: MessageBus, + messageBus: MessageBus, _toolName?: string, _displayName?: string, ): ToolInvocation { return new WriteTodosToolInvocation( params, - messageBus ?? this.messageBus, + messageBus, _toolName, _displayName, );