mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-20 10:10:56 -07:00
feat(core,cli): enforce mandatory MessageBus injection (Phase 3 Hard Migration) (#15776)
This commit is contained in:
@@ -8,6 +8,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ActivateSkillTool } from './activate-skill.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
vi.mock('../utils/getFolderStructure.js', () => ({
|
||||
getFolderStructure: vi.fn().mockResolvedValue('Mock folder structure'),
|
||||
@@ -16,13 +17,10 @@ vi.mock('../utils/getFolderStructure.js', () => ({
|
||||
describe('ActivateSkillTool', () => {
|
||||
let mockConfig: Config;
|
||||
let tool: ActivateSkillTool;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
subscribe: vi.fn(),
|
||||
unsubscribe: vi.fn(),
|
||||
} as unknown as MessageBus;
|
||||
let mockMessageBus: MessageBus;
|
||||
|
||||
beforeEach(() => {
|
||||
mockMessageBus = createMockMessageBus();
|
||||
const skills = [
|
||||
{
|
||||
name: 'test-skill',
|
||||
|
||||
@@ -38,7 +38,7 @@ class ActivateSkillToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private config: Config,
|
||||
params: ActivateSkillToolParams,
|
||||
messageBus: MessageBus | undefined,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -145,7 +145,7 @@ export class ActivateSkillTool extends BaseDeclarativeTool<
|
||||
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
const skills = config.getSkillManager().getSkills();
|
||||
const skillNames = skills.map((s) => s.name);
|
||||
@@ -169,15 +169,15 @@ export class ActivateSkillTool extends BaseDeclarativeTool<
|
||||
"Activates a specialized agent skill by name. Returns the skill's instructions wrapped in `<ACTIVATED_SKILL>` tags. These provide specialized guidance for the current task. Use this when you identify a task that matches a skill's description.",
|
||||
Kind.Other,
|
||||
zodToJsonSchema(schema),
|
||||
messageBus,
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: ActivateSkillToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<ActivateSkillToolParams, ToolResult> {
|
||||
|
||||
@@ -118,7 +118,7 @@ class EditToolInvocation
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: EditToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
@@ -492,7 +492,7 @@ export class EditTool
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
EditTool.Name,
|
||||
@@ -535,9 +535,9 @@ Expectation for required parameters:
|
||||
required: ['file_path', 'old_string', 'new_string'],
|
||||
type: 'object',
|
||||
},
|
||||
messageBus,
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -568,14 +568,14 @@ Expectation for required parameters:
|
||||
|
||||
protected createInvocation(
|
||||
params: EditToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
): ToolInvocation<EditToolParams, ToolResult> {
|
||||
return new EditToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
toolName ?? this.name,
|
||||
displayName ?? this.displayName,
|
||||
);
|
||||
|
||||
@@ -9,13 +9,14 @@ import { GetInternalDocsTool } from './get-internal-docs.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import fs from 'node:fs/promises';
|
||||
import path from 'node:path';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
describe('GetInternalDocsTool (Integration)', () => {
|
||||
let tool: GetInternalDocsTool;
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
beforeEach(() => {
|
||||
tool = new GetInternalDocsTool();
|
||||
tool = new GetInternalDocsTool(createMockMessageBus());
|
||||
});
|
||||
|
||||
it('should find the documentation root and list files', async () => {
|
||||
|
||||
@@ -82,7 +82,7 @@ class GetInternalDocsInvocation extends BaseToolInvocation<
|
||||
> {
|
||||
constructor(
|
||||
params: GetInternalDocsParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -161,7 +161,7 @@ export class GetInternalDocsTool extends BaseDeclarativeTool<
|
||||
> {
|
||||
static readonly Name = GET_INTERNAL_DOCS_TOOL_NAME;
|
||||
|
||||
constructor(messageBus?: MessageBus) {
|
||||
constructor(messageBus: MessageBus) {
|
||||
super(
|
||||
GetInternalDocsTool.Name,
|
||||
'GetInternalDocs',
|
||||
@@ -177,21 +177,21 @@ export class GetInternalDocsTool extends BaseDeclarativeTool<
|
||||
},
|
||||
},
|
||||
},
|
||||
messageBus,
|
||||
/* isOutputMarkdown */ true,
|
||||
/* canUpdateOutput */ false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: GetInternalDocsParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<GetInternalDocsParams, ToolResult> {
|
||||
return new GetInternalDocsInvocation(
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
_toolName ?? GetInternalDocsTool.Name,
|
||||
_toolDisplayName,
|
||||
);
|
||||
|
||||
@@ -91,7 +91,7 @@ class GlobToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private config: Config,
|
||||
params: GlobToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -262,7 +262,7 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
|
||||
static readonly Name = GLOB_TOOL_NAME;
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
GlobTool.Name,
|
||||
@@ -300,9 +300,9 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
|
||||
required: ['pattern'],
|
||||
type: 'object',
|
||||
},
|
||||
messageBus,
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -348,14 +348,14 @@ export class GlobTool extends BaseDeclarativeTool<GlobToolParams, ToolResult> {
|
||||
|
||||
protected createInvocation(
|
||||
params: GlobToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<GlobToolParams, ToolResult> {
|
||||
return new GlobToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
|
||||
@@ -62,7 +62,7 @@ class GrepToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: GrepToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -571,7 +571,7 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
|
||||
static readonly Name = GREP_TOOL_NAME;
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
GrepTool.Name,
|
||||
@@ -599,9 +599,9 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
|
||||
required: ['pattern'],
|
||||
type: 'object',
|
||||
},
|
||||
messageBus,
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -674,14 +674,14 @@ export class GrepTool extends BaseDeclarativeTool<GrepToolParams, ToolResult> {
|
||||
|
||||
protected createInvocation(
|
||||
params: GrepToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<GrepToolParams, ToolResult> {
|
||||
return new GrepToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
|
||||
@@ -73,7 +73,7 @@ class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: LSToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -259,7 +259,7 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
|
||||
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
LSTool.Name,
|
||||
@@ -300,9 +300,9 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
|
||||
required: ['dir_path'],
|
||||
type: 'object',
|
||||
},
|
||||
messageBus,
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -330,7 +330,7 @@ export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
|
||||
|
||||
protected createInvocation(
|
||||
params: LSToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<LSToolParams, ToolResult> {
|
||||
|
||||
@@ -57,7 +57,7 @@ import type {
|
||||
} from '../utils/workspaceContext.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { type MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import type { ResourceRegistry } from '../resources/resource-registry.js';
|
||||
import {
|
||||
@@ -895,7 +895,7 @@ export async function discoverTools(
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
mcpClient: Client,
|
||||
cliConfig: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
options?: { timeout?: number; signal?: AbortSignal },
|
||||
): Promise<DiscoveredMCPTool[]> {
|
||||
try {
|
||||
@@ -922,12 +922,12 @@ export async function discoverTools(
|
||||
toolDef.name,
|
||||
toolDef.description ?? '',
|
||||
toolDef.inputSchema ?? { type: 'object', properties: {} },
|
||||
messageBus,
|
||||
mcpServerConfig.trust,
|
||||
undefined,
|
||||
cliConfig,
|
||||
mcpServerConfig.extension?.name,
|
||||
mcpServerConfig.extension?.id,
|
||||
messageBus,
|
||||
);
|
||||
|
||||
discoveredTools.push(tool);
|
||||
|
||||
@@ -13,6 +13,10 @@ import type { ToolResult } from './tools.js';
|
||||
import { ToolConfirmationOutcome } from './tools.js'; // Added ToolConfirmationOutcome
|
||||
import type { CallableTool, Part } from '@google/genai';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import {
|
||||
createMockMessageBus,
|
||||
getMockMessageBusInstance,
|
||||
} from '../test-utils/mock-message-bus.js';
|
||||
|
||||
// Mock @google/genai mcpToTool and CallableTool
|
||||
// We only need to mock the parts of CallableTool that DiscoveredMCPTool uses.
|
||||
@@ -85,12 +89,15 @@ describe('DiscoveredMCPTool', () => {
|
||||
beforeEach(() => {
|
||||
mockCallTool.mockClear();
|
||||
mockToolMethod.mockClear();
|
||||
const bus = createMockMessageBus();
|
||||
getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user';
|
||||
tool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
bus,
|
||||
);
|
||||
// Clear allowlist before each relevant test, especially for shouldConfirmExecute
|
||||
const invocation = tool.build({ param: 'mock' }) as any;
|
||||
@@ -190,6 +197,12 @@ describe('DiscoveredMCPTool', () => {
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
createMockMessageBus(),
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
const params = { param: 'isErrorTrueCase' };
|
||||
const functionCall = {
|
||||
@@ -230,6 +243,12 @@ describe('DiscoveredMCPTool', () => {
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
createMockMessageBus(),
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
const params = { param: 'isErrorTopLevelCase' };
|
||||
const functionCall = {
|
||||
@@ -273,6 +292,12 @@ describe('DiscoveredMCPTool', () => {
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
createMockMessageBus(),
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
const params = { param: 'isErrorFalseCase' };
|
||||
const mockToolSuccessResultObject = {
|
||||
@@ -728,9 +753,12 @@ describe('DiscoveredMCPTool', () => {
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
createMockMessageBus(),
|
||||
true,
|
||||
undefined,
|
||||
{ isTrustedFolder: () => true } as any,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
const invocation = trustedTool.build({ param: 'mock' });
|
||||
expect(
|
||||
@@ -862,15 +890,20 @@ describe('DiscoveredMCPTool', () => {
|
||||
'return confirmation details if trust is false, even if folder is trusted',
|
||||
},
|
||||
])('should $description', async ({ trust, isTrusted, shouldConfirm }) => {
|
||||
const bus = createMockMessageBus();
|
||||
getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user';
|
||||
const testTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
bus,
|
||||
trust,
|
||||
undefined,
|
||||
mockConfig(isTrusted) as any,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
const invocation = testTool.build({ param: 'mock' });
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
|
||||
@@ -70,10 +70,10 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
readonly serverName: string,
|
||||
readonly serverToolName: string,
|
||||
readonly displayName: string,
|
||||
messageBus: MessageBus,
|
||||
readonly trust?: boolean,
|
||||
params: ToolParams = {},
|
||||
private readonly cliConfig?: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
// Use composite format for policy checks: serverName__toolName
|
||||
// This enables server wildcards (e.g., "google-workspace__*")
|
||||
@@ -239,12 +239,12 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
readonly serverToolName: string,
|
||||
description: string,
|
||||
override readonly parameterSchema: unknown,
|
||||
messageBus: MessageBus,
|
||||
readonly trust?: boolean,
|
||||
nameOverride?: string,
|
||||
private readonly cliConfig?: Config,
|
||||
override readonly extensionName?: string,
|
||||
override readonly extensionId?: string,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
nameOverride ?? generateValidName(serverToolName),
|
||||
@@ -252,9 +252,9 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
description,
|
||||
Kind.Other,
|
||||
parameterSchema,
|
||||
messageBus,
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput,
|
||||
messageBus,
|
||||
extensionName,
|
||||
extensionId,
|
||||
);
|
||||
@@ -271,18 +271,18 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
this.serverToolName,
|
||||
this.description,
|
||||
this.parameterSchema,
|
||||
this.messageBus,
|
||||
this.trust,
|
||||
`${this.getFullyQualifiedPrefix()}${this.serverToolName}`,
|
||||
this.cliConfig,
|
||||
this.extensionName,
|
||||
this.extensionId,
|
||||
this.messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: ToolParams,
|
||||
_messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_displayName?: string,
|
||||
): ToolInvocation<ToolParams, ToolResult> {
|
||||
@@ -291,10 +291,10 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
this.serverName,
|
||||
this.serverToolName,
|
||||
_displayName ?? this.displayName,
|
||||
messageBus,
|
||||
this.trust,
|
||||
params,
|
||||
this.cliConfig,
|
||||
_messageBus ?? this.messageBus,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,10 @@ import * as os from 'node:os';
|
||||
import { ToolConfirmationOutcome } from './tools.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { GEMINI_DIR } from '../utils/paths.js';
|
||||
import {
|
||||
createMockMessageBus,
|
||||
getMockMessageBusInstance,
|
||||
} from '../test-utils/mock-message-bus.js';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock(import('node:fs/promises'), async (importOriginal) => {
|
||||
@@ -200,7 +204,7 @@ describe('MemoryTool', () => {
|
||||
let performAddMemoryEntrySpy: Mock<typeof MemoryTool.performAddMemoryEntry>;
|
||||
|
||||
beforeEach(() => {
|
||||
memoryTool = new MemoryTool();
|
||||
memoryTool = new MemoryTool(createMockMessageBus());
|
||||
// Spy on the static method for these tests
|
||||
performAddMemoryEntrySpy = vi
|
||||
.spyOn(MemoryTool, 'performAddMemoryEntry')
|
||||
@@ -300,7 +304,9 @@ describe('MemoryTool', () => {
|
||||
let memoryTool: MemoryTool;
|
||||
|
||||
beforeEach(() => {
|
||||
memoryTool = new MemoryTool();
|
||||
const bus = createMockMessageBus();
|
||||
getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user';
|
||||
memoryTool = new MemoryTool(bus);
|
||||
// Clear the allowlist before each test
|
||||
const invocation = memoryTool.build({ fact: 'mock-fact' });
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
|
||||
@@ -179,7 +179,7 @@ class MemoryToolInvocation extends BaseToolInvocation<
|
||||
|
||||
constructor(
|
||||
params: SaveMemoryParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
@@ -298,16 +298,16 @@ export class MemoryTool
|
||||
{
|
||||
static readonly Name = MEMORY_TOOL_NAME;
|
||||
|
||||
constructor(messageBus?: MessageBus) {
|
||||
constructor(messageBus: MessageBus) {
|
||||
super(
|
||||
MemoryTool.Name,
|
||||
'SaveMemory',
|
||||
memoryToolDescription,
|
||||
Kind.Think,
|
||||
memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
|
||||
messageBus,
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -323,13 +323,13 @@ export class MemoryTool
|
||||
|
||||
protected createInvocation(
|
||||
params: SaveMemoryParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
return new MemoryToolInvocation(
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
toolName ?? this.name,
|
||||
displayName ?? this.displayName,
|
||||
);
|
||||
|
||||
@@ -81,21 +81,21 @@ class TestTool extends BaseDeclarativeTool<TestParams, TestResult> {
|
||||
},
|
||||
required: ['testParam'],
|
||||
},
|
||||
messageBus,
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: TestParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
return new TestToolInvocation(
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
|
||||
@@ -50,7 +50,7 @@ class ReadFileToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private config: Config,
|
||||
params: ReadFileToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -149,7 +149,7 @@ export class ReadFileTool extends BaseDeclarativeTool<
|
||||
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
ReadFileTool.Name,
|
||||
@@ -176,9 +176,9 @@ export class ReadFileTool extends BaseDeclarativeTool<
|
||||
required: ['file_path'],
|
||||
type: 'object',
|
||||
},
|
||||
messageBus,
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -225,14 +225,14 @@ export class ReadFileTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: ReadFileToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<ReadFileToolParams, ToolResult> {
|
||||
return new ReadFileToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
|
||||
@@ -107,7 +107,7 @@ class ReadManyFilesToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: ReadManyFilesParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -447,7 +447,7 @@ export class ReadManyFilesTool extends BaseDeclarativeTool<
|
||||
|
||||
constructor(
|
||||
private config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
const parameterSchema = {
|
||||
type: 'object',
|
||||
@@ -520,22 +520,22 @@ This tool is useful when you need to understand or analyze a collection of files
|
||||
Use this tool when the user's query implies needing the content of several files simultaneously for context, analysis, or summarization. For text files, it uses default UTF-8 encoding and a '--- {filePath} ---' separator between file contents. The tool inserts a '--- End of content ---' after the last file. Ensure glob patterns are relative to the target directory. Glob patterns like 'src/**/*.js' are supported. Avoid using for single files if a more specific single-file reading tool is available, unless the user specifically requests to process a list containing just one file via this tool. Other binary files (not explicitly requested as image/audio/PDF) are generally skipped. Default excludes apply to common non-text files (except for explicitly requested images/audio/PDFs) and large dependency directories unless 'useDefaultExcludes' is false.`,
|
||||
Kind.Read,
|
||||
parameterSchema,
|
||||
messageBus,
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: ReadManyFilesParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<ReadManyFilesParams, ToolResult> {
|
||||
return new ReadManyFilesToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
|
||||
@@ -24,6 +24,7 @@ import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.j
|
||||
import type { ChildProcess } from 'node:child_process';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { downloadRipGrep } from '@joshua.litt/get-ripgrep';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
// Mock dependencies for canUseRipgrep
|
||||
vi.mock('@joshua.litt/get-ripgrep', () => ({
|
||||
downloadRipGrep: vi.fn(),
|
||||
@@ -267,7 +268,7 @@ describe('RipGrepTool', () => {
|
||||
await fs.writeFile(ripgrepBinaryPath, '');
|
||||
storageSpy.mockImplementation(() => binDir);
|
||||
tempRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'grep-tool-root-'));
|
||||
grepTool = new RipGrepTool(mockConfig);
|
||||
grepTool = new RipGrepTool(mockConfig, createMockMessageBus());
|
||||
|
||||
// Create some test files and directories
|
||||
await fs.writeFile(
|
||||
@@ -833,7 +834,10 @@ describe('RipGrepTool', () => {
|
||||
return mockProcess as unknown as ChildProcess;
|
||||
});
|
||||
|
||||
const multiDirGrepTool = new RipGrepTool(multiDirConfig);
|
||||
const multiDirGrepTool = new RipGrepTool(
|
||||
multiDirConfig,
|
||||
createMockMessageBus(),
|
||||
);
|
||||
const params: RipGrepToolParams = { pattern: 'world' };
|
||||
const invocation = multiDirGrepTool.build(params);
|
||||
const result = await invocation.execute(abortSignal);
|
||||
@@ -927,7 +931,10 @@ describe('RipGrepTool', () => {
|
||||
return mockProcess as unknown as ChildProcess;
|
||||
});
|
||||
|
||||
const multiDirGrepTool = new RipGrepTool(multiDirConfig);
|
||||
const multiDirGrepTool = new RipGrepTool(
|
||||
multiDirConfig,
|
||||
createMockMessageBus(),
|
||||
);
|
||||
|
||||
// Search only in the 'sub' directory of the first workspace
|
||||
const params: RipGrepToolParams = { pattern: 'world', dir_path: 'sub' };
|
||||
@@ -1656,7 +1663,10 @@ describe('RipGrepTool', () => {
|
||||
getDebugMode: () => false,
|
||||
getFileFilteringRespectGeminiIgnore: () => true,
|
||||
} as unknown as Config;
|
||||
const geminiIgnoreTool = new RipGrepTool(configWithGeminiIgnore);
|
||||
const geminiIgnoreTool = new RipGrepTool(
|
||||
configWithGeminiIgnore,
|
||||
createMockMessageBus(),
|
||||
);
|
||||
|
||||
mockSpawn.mockImplementationOnce(
|
||||
createMockSpawn({
|
||||
@@ -1693,7 +1703,10 @@ describe('RipGrepTool', () => {
|
||||
getDebugMode: () => false,
|
||||
getFileFilteringRespectGeminiIgnore: () => false,
|
||||
} as unknown as Config;
|
||||
const geminiIgnoreTool = new RipGrepTool(configWithoutGeminiIgnore);
|
||||
const geminiIgnoreTool = new RipGrepTool(
|
||||
configWithoutGeminiIgnore,
|
||||
createMockMessageBus(),
|
||||
);
|
||||
|
||||
mockSpawn.mockImplementationOnce(
|
||||
createMockSpawn({
|
||||
@@ -1816,7 +1829,10 @@ describe('RipGrepTool', () => {
|
||||
getDebugMode: () => false,
|
||||
} as unknown as Config;
|
||||
|
||||
const multiDirGrepTool = new RipGrepTool(multiDirConfig);
|
||||
const multiDirGrepTool = new RipGrepTool(
|
||||
multiDirConfig,
|
||||
createMockMessageBus(),
|
||||
);
|
||||
const params: RipGrepToolParams = { pattern: 'testPattern' };
|
||||
const invocation = multiDirGrepTool.build(params);
|
||||
expect(invocation.getDescription()).toBe("'testPattern' within ./");
|
||||
|
||||
@@ -192,7 +192,7 @@ class GrepToolInvocation extends BaseToolInvocation<
|
||||
private readonly config: Config,
|
||||
private readonly geminiIgnoreParser: GeminiIgnoreParser,
|
||||
params: RipGrepToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -493,7 +493,7 @@ export class RipGrepTool extends BaseDeclarativeTool<
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
RipGrepTool.Name,
|
||||
@@ -551,9 +551,9 @@ export class RipGrepTool extends BaseDeclarativeTool<
|
||||
required: ['pattern'],
|
||||
type: 'object',
|
||||
},
|
||||
messageBus,
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
this.geminiIgnoreParser = new GeminiIgnoreParser(config.getTargetDir());
|
||||
}
|
||||
@@ -586,7 +586,7 @@ export class RipGrepTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: RipGrepToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<RipGrepToolParams, ToolResult> {
|
||||
|
||||
@@ -57,7 +57,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: ShellToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -420,7 +420,7 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
void initializeShellParsers().catch(() => {
|
||||
// Errors are surfaced when parsing commands.
|
||||
@@ -450,9 +450,9 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
},
|
||||
required: ['command'],
|
||||
},
|
||||
messageBus,
|
||||
false, // output is not markdown
|
||||
true, // output can be updated
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -478,14 +478,14 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: ShellToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<ShellToolParams, ToolResult> {
|
||||
return new ShellToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
|
||||
@@ -49,6 +49,10 @@ import {
|
||||
} from './smart-edit.js';
|
||||
import { type FileDiff, ToolConfirmationOutcome } from './tools.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import {
|
||||
createMockMessageBus,
|
||||
getMockMessageBusInstance,
|
||||
} from '../test-utils/mock-message-bus.js';
|
||||
import path from 'node:path';
|
||||
import fs from 'node:fs';
|
||||
import os from 'node:os';
|
||||
@@ -165,7 +169,9 @@ describe('SmartEditTool', () => {
|
||||
},
|
||||
);
|
||||
|
||||
tool = new SmartEditTool(mockConfig);
|
||||
const bus = createMockMessageBus();
|
||||
getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user';
|
||||
tool = new SmartEditTool(mockConfig, bus);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
||||
@@ -386,7 +386,7 @@ class EditToolInvocation
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: EditToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
@@ -853,7 +853,7 @@ export class SmartEditTool
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
SmartEditTool.Name,
|
||||
@@ -915,9 +915,9 @@ A good instruction should concisely answer:
|
||||
required: ['file_path', 'instruction', 'old_string', 'new_string'],
|
||||
type: 'object',
|
||||
},
|
||||
messageBus,
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -955,12 +955,12 @@ A good instruction should concisely answer:
|
||||
|
||||
protected createInvocation(
|
||||
params: EditToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
): ToolInvocation<EditToolParams, ToolResult> {
|
||||
return new EditToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
this.name,
|
||||
this.displayName,
|
||||
);
|
||||
|
||||
@@ -90,12 +90,26 @@ const createMockCallableTool = (
|
||||
});
|
||||
|
||||
// Helper to create a DiscoveredMCPTool
|
||||
const mockMessageBusForHelper = {
|
||||
publish: vi.fn(),
|
||||
subscribe: vi.fn(),
|
||||
unsubscribe: vi.fn(),
|
||||
} as unknown as MessageBus;
|
||||
|
||||
const createMCPTool = (
|
||||
serverName: string,
|
||||
toolName: string,
|
||||
description: string,
|
||||
mockCallable: CallableTool = {} as CallableTool,
|
||||
) => new DiscoveredMCPTool(mockCallable, serverName, toolName, description, {});
|
||||
) =>
|
||||
new DiscoveredMCPTool(
|
||||
mockCallable,
|
||||
serverName,
|
||||
toolName,
|
||||
description,
|
||||
{},
|
||||
mockMessageBusForHelper,
|
||||
);
|
||||
|
||||
// Helper to create a mock spawn process for tool discovery
|
||||
const createDiscoveryProcess = (toolDeclarations: FunctionDeclaration[]) => {
|
||||
@@ -171,6 +185,11 @@ const baseConfigParams: ConfigParameters = {
|
||||
describe('ToolRegistry', () => {
|
||||
let config: Config;
|
||||
let toolRegistry: ToolRegistry;
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
subscribe: vi.fn(),
|
||||
unsubscribe: vi.fn(),
|
||||
} as unknown as MessageBus;
|
||||
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
||||
let mockConfigGetExcludedTools: MockInstance<
|
||||
typeof Config.prototype.getExcludeTools
|
||||
@@ -182,7 +201,7 @@ describe('ToolRegistry', () => {
|
||||
isDirectory: () => true,
|
||||
} as fs.Stats);
|
||||
config = new Config(baseConfigParams);
|
||||
toolRegistry = new ToolRegistry(config);
|
||||
toolRegistry = new ToolRegistry(config, mockMessageBus);
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'debug').mockImplementation(() => {});
|
||||
@@ -372,6 +391,7 @@ describe('ToolRegistry', () => {
|
||||
DISCOVERED_TOOL_PREFIX + 'discovered-1',
|
||||
'desc',
|
||||
{},
|
||||
mockMessageBus,
|
||||
);
|
||||
const mcpZebra = createMCPTool('zebra-server', 'mcp-zebra', 'desc');
|
||||
const mcpApple = createMCPTool('apple-server', 'mcp-apple', 'desc');
|
||||
@@ -482,13 +502,6 @@ describe('ToolRegistry', () => {
|
||||
const discoveryCommand = 'my-discovery-command';
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
|
||||
|
||||
const mockMessageBus = {
|
||||
publish: vi.fn(),
|
||||
subscribe: vi.fn(),
|
||||
unsubscribe: vi.fn(),
|
||||
} as unknown as MessageBus;
|
||||
toolRegistry.setMessageBus(mockMessageBus);
|
||||
|
||||
const toolDeclaration: FunctionDeclaration = {
|
||||
name: 'policy-test-tool',
|
||||
description: 'tests policy',
|
||||
@@ -520,6 +533,7 @@ describe('ToolRegistry', () => {
|
||||
DISCOVERED_TOOL_PREFIX + 'test-tool',
|
||||
'A test tool',
|
||||
{},
|
||||
mockMessageBus,
|
||||
);
|
||||
const params = { param: 'testValue' };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
@@ -34,7 +34,7 @@ class DiscoveredToolInvocation extends BaseToolInvocation<
|
||||
private readonly originalToolName: string,
|
||||
prefixedToolName: string,
|
||||
params: ToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(params, messageBus, prefixedToolName);
|
||||
}
|
||||
@@ -135,7 +135,7 @@ export class DiscoveredTool extends BaseDeclarativeTool<
|
||||
prefixedName: string,
|
||||
description: string,
|
||||
override readonly parameterSchema: Record<string, unknown>,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
const discoveryCmd = config.getToolDiscoveryCommand()!;
|
||||
const callCommand = config.getToolCallCommand()!;
|
||||
@@ -163,16 +163,16 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
||||
fullDescription,
|
||||
Kind.Other,
|
||||
parameterSchema,
|
||||
messageBus,
|
||||
false, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
this.originalName = originalName;
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: ToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_displayName?: string,
|
||||
): ToolInvocation<ToolParams, ToolResult> {
|
||||
@@ -181,7 +181,7 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
||||
this.originalName,
|
||||
_toolName ?? this.name,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -192,26 +192,17 @@ export class ToolRegistry {
|
||||
// and `isActive` to get only the active tools.
|
||||
private allKnownTools: Map<string, AnyDeclarativeTool> = new Map();
|
||||
private config: Config;
|
||||
private messageBus?: MessageBus;
|
||||
private messageBus: MessageBus;
|
||||
|
||||
constructor(config: Config, messageBus?: MessageBus) {
|
||||
constructor(config: Config, messageBus: MessageBus) {
|
||||
this.config = config;
|
||||
this.messageBus = messageBus;
|
||||
}
|
||||
|
||||
getMessageBus(): MessageBus | undefined {
|
||||
getMessageBus(): MessageBus {
|
||||
return this.messageBus;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated migration only - will be removed in PR 3 (Enforcement)
|
||||
* TODO: DELETE ME in PR 3. This is a temporary shim to allow for soft migration
|
||||
* of tools while the core infrastructure is updated to require a MessageBus at birth.
|
||||
*/
|
||||
setMessageBus(messageBus: MessageBus): void {
|
||||
this.messageBus = messageBus;
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a tool definition.
|
||||
*
|
||||
|
||||
@@ -8,6 +8,7 @@ import { describe, it, expect, vi } from 'vitest';
|
||||
import type { ToolInvocation, ToolResult } from './tools.js';
|
||||
import { DeclarativeTool, hasCycleInSchema, Kind } from './tools.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
class TestToolInvocation implements ToolInvocation<object, ToolResult> {
|
||||
constructor(
|
||||
@@ -36,7 +37,16 @@ class TestTool extends DeclarativeTool<object, ToolResult> {
|
||||
private readonly buildFn: (params: object) => TestToolInvocation;
|
||||
|
||||
constructor(buildFn: (params: object) => TestToolInvocation) {
|
||||
super('test-tool', 'Test Tool', 'A tool for testing', Kind.Other, {});
|
||||
super(
|
||||
'test-tool',
|
||||
'Test Tool',
|
||||
'A tool for testing',
|
||||
Kind.Other,
|
||||
{},
|
||||
createMockMessageBus(),
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
);
|
||||
this.buildFn = buildFn;
|
||||
}
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ export abstract class BaseToolInvocation<
|
||||
{
|
||||
constructor(
|
||||
readonly params: TParams,
|
||||
protected readonly messageBus?: MessageBus,
|
||||
protected readonly messageBus: MessageBus,
|
||||
readonly _toolName?: string,
|
||||
readonly _toolDisplayName?: string,
|
||||
readonly _serverName?: string,
|
||||
@@ -98,25 +98,24 @@ export abstract class BaseToolInvocation<
|
||||
async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.messageBus) {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decision === 'DENY') {
|
||||
throw new Error(
|
||||
`Tool execution for "${
|
||||
this._toolDisplayName || this._toolName
|
||||
}" denied by policy.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (decision === 'ASK_USER') {
|
||||
return this.getConfirmationDetails(abortSignal);
|
||||
}
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
return false;
|
||||
}
|
||||
// When no message bus, use default confirmation flow
|
||||
|
||||
if (decision === 'DENY') {
|
||||
throw new Error(
|
||||
`Tool execution for "${
|
||||
this._toolDisplayName || this._toolName
|
||||
}" denied by policy.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (decision === 'ASK_USER') {
|
||||
return this.getConfirmationDetails(abortSignal);
|
||||
}
|
||||
|
||||
// Default to confirmation details if decision is unknown (should not happen with exhaustive policy)
|
||||
return this.getConfirmationDetails(abortSignal);
|
||||
}
|
||||
|
||||
@@ -142,7 +141,7 @@ export abstract class BaseToolInvocation<
|
||||
outcome === ToolConfirmationOutcome.ProceedAlways ||
|
||||
outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave
|
||||
) {
|
||||
if (this.messageBus && this._toolName) {
|
||||
if (this._toolName) {
|
||||
const options = this.getPolicyUpdateOptions(outcome);
|
||||
await this.messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
@@ -206,7 +205,7 @@ export abstract class BaseToolInvocation<
|
||||
timeoutId = undefined;
|
||||
}
|
||||
abortSignal.removeEventListener('abort', abortHandler);
|
||||
this.messageBus?.unsubscribe(
|
||||
this.messageBus.unsubscribe(
|
||||
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
responseHandler,
|
||||
);
|
||||
@@ -341,9 +340,9 @@ export abstract class DeclarativeTool<
|
||||
readonly description: string,
|
||||
readonly kind: Kind,
|
||||
readonly parameterSchema: unknown,
|
||||
readonly messageBus: MessageBus,
|
||||
readonly isOutputMarkdown: boolean = true,
|
||||
readonly canUpdateOutput: boolean = false,
|
||||
readonly messageBus?: MessageBus,
|
||||
readonly extensionName?: string,
|
||||
readonly extensionId?: string,
|
||||
) {}
|
||||
@@ -496,7 +495,7 @@ export abstract class BaseDeclarativeTool<
|
||||
|
||||
protected abstract createInvocation(
|
||||
params: TParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<TParams, TResult>;
|
||||
|
||||
@@ -10,6 +10,10 @@ import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
import { ToolConfirmationOutcome } from './tools.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import {
|
||||
createMockMessageBus,
|
||||
getMockMessageBusInstance,
|
||||
} from '../test-utils/mock-message-bus.js';
|
||||
import * as fetchUtils from '../utils/fetch.js';
|
||||
import { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { PolicyEngine } from '../policy/policy-engine.js';
|
||||
@@ -126,9 +130,12 @@ describe('parsePrompt', () => {
|
||||
|
||||
describe('WebFetchTool', () => {
|
||||
let mockConfig: Config;
|
||||
let bus: MessageBus;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
bus = createMockMessageBus();
|
||||
getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user';
|
||||
mockConfig = {
|
||||
getApprovalMode: vi.fn(),
|
||||
setApprovalMode: vi.fn(),
|
||||
@@ -163,12 +170,12 @@ describe('WebFetchTool', () => {
|
||||
expectedError: 'Error(s) in prompt URLs:',
|
||||
},
|
||||
])('should throw if $name', ({ prompt, expectedError }) => {
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
expect(() => tool.build({ prompt })).toThrow(expectedError);
|
||||
});
|
||||
|
||||
it('should pass if prompt contains at least one valid URL', () => {
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
expect(() =>
|
||||
tool.build({ prompt: 'fetch https://example.com' }),
|
||||
).not.toThrow();
|
||||
@@ -181,7 +188,7 @@ describe('WebFetchTool', () => {
|
||||
vi.spyOn(fetchUtils, 'fetchWithTimeout').mockRejectedValue(
|
||||
new Error('fetch failed'),
|
||||
);
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://private.ip' };
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
@@ -191,7 +198,7 @@ describe('WebFetchTool', () => {
|
||||
it('should return WEB_FETCH_PROCESSING_ERROR on general processing failure', async () => {
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
|
||||
mockGenerateContent.mockRejectedValue(new Error('API error'));
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://public.ip' };
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
@@ -209,7 +216,7 @@ describe('WebFetchTool', () => {
|
||||
candidates: [{ content: { parts: [{ text: 'fallback response' }] } }],
|
||||
});
|
||||
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://private.ip' };
|
||||
const invocation = tool.build(params);
|
||||
await invocation.execute(new AbortController().signal);
|
||||
@@ -237,7 +244,7 @@ describe('WebFetchTool', () => {
|
||||
candidates: [{ content: { parts: [{ text: 'fallback response' }] } }],
|
||||
});
|
||||
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://public.ip' };
|
||||
const invocation = tool.build(params);
|
||||
await invocation.execute(new AbortController().signal);
|
||||
@@ -306,7 +313,7 @@ describe('WebFetchTool', () => {
|
||||
],
|
||||
}));
|
||||
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
@@ -330,7 +337,7 @@ describe('WebFetchTool', () => {
|
||||
|
||||
describe('shouldConfirmExecute', () => {
|
||||
it('should return confirmation details with the correct prompt and parsed urls', async () => {
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
const confirmationDetails = await invocation.shouldConfirmExecute(
|
||||
@@ -347,7 +354,7 @@ describe('WebFetchTool', () => {
|
||||
});
|
||||
|
||||
it('should convert github urls to raw format', async () => {
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = {
|
||||
prompt:
|
||||
'fetch https://github.com/google/gemini-react/blob/main/README.md',
|
||||
@@ -373,7 +380,7 @@ describe('WebFetchTool', () => {
|
||||
vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue(
|
||||
ApprovalMode.AUTO_EDIT,
|
||||
);
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
const confirmationDetails = await invocation.shouldConfirmExecute(
|
||||
@@ -384,7 +391,7 @@ describe('WebFetchTool', () => {
|
||||
});
|
||||
|
||||
it('should call setApprovalMode when onConfirm is called with ProceedAlways', async () => {
|
||||
const tool = new WebFetchTool(mockConfig);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
const confirmationDetails = await invocation.shouldConfirmExecute(
|
||||
@@ -412,8 +419,8 @@ describe('WebFetchTool', () => {
|
||||
let messageBus: MessageBus;
|
||||
let mockUUID: Mock;
|
||||
|
||||
const createToolWithMessageBus = (bus?: MessageBus) => {
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const createToolWithMessageBus = (customBus?: MessageBus) => {
|
||||
const tool = new WebFetchTool(mockConfig, customBus ?? bus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
return { tool, invocation: tool.build(params) };
|
||||
};
|
||||
@@ -516,16 +523,6 @@ describe('WebFetchTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should fall back to legacy confirmation when no message bus', async () => {
|
||||
const { invocation } = createToolWithMessageBus(); // No message bus
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
expect(result).not.toBe(false);
|
||||
expect(result).toHaveProperty('type', 'info');
|
||||
});
|
||||
|
||||
it('should ignore responses with wrong correlation ID', async () => {
|
||||
vi.useFakeTimers();
|
||||
const { invocation } = createToolWithMessageBus(messageBus);
|
||||
|
||||
@@ -114,7 +114,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: WebFetchToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -218,7 +218,8 @@ ${textContent}
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// Legacy confirmation flow (no message bus OR policy decision was ASK_USER)
|
||||
// Check for AUTO_EDIT approval mode. This tool has a specific behavior
|
||||
// where ProceedAlways switches the entire session to AUTO_EDIT.
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
@@ -406,7 +407,7 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
WebFetchTool.Name,
|
||||
@@ -424,9 +425,9 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
required: ['prompt'],
|
||||
type: 'object',
|
||||
},
|
||||
messageBus,
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -452,14 +453,14 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: WebFetchToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<WebFetchToolParams, ToolResult> {
|
||||
return new WebFetchToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
);
|
||||
|
||||
@@ -11,6 +11,7 @@ import { WebSearchTool } from './web-search.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
// Mock GeminiClient and Config constructor
|
||||
vi.mock('../core/client.js');
|
||||
@@ -33,7 +34,7 @@ describe('WebSearchTool', () => {
|
||||
},
|
||||
} as unknown as Config;
|
||||
mockGeminiClient = new GeminiClient(mockConfigInstance);
|
||||
tool = new WebSearchTool(mockConfigInstance);
|
||||
tool = new WebSearchTool(mockConfigInstance, createMockMessageBus());
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
||||
@@ -65,7 +65,7 @@ class WebSearchToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: WebSearchToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -192,7 +192,7 @@ export class WebSearchTool extends BaseDeclarativeTool<
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
WebSearchTool.Name,
|
||||
@@ -209,9 +209,9 @@ export class WebSearchTool extends BaseDeclarativeTool<
|
||||
},
|
||||
required: ['query'],
|
||||
},
|
||||
messageBus,
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -231,7 +231,7 @@ export class WebSearchTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: WebSearchToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
): ToolInvocation<WebSearchToolParams, WebSearchToolResult> {
|
||||
|
||||
@@ -149,7 +149,7 @@ class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: WriteFileToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
@@ -409,7 +409,7 @@ export class WriteFileTool
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
) {
|
||||
super(
|
||||
WriteFileTool.Name,
|
||||
@@ -432,9 +432,9 @@ export class WriteFileTool
|
||||
required: ['file_path', 'content'],
|
||||
type: 'object',
|
||||
},
|
||||
messageBus,
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -475,7 +475,7 @@ export class WriteFileTool
|
||||
|
||||
protected createInvocation(
|
||||
params: WriteFileToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
): ToolInvocation<WriteFileToolParams, ToolResult> {
|
||||
return new WriteFileToolInvocation(
|
||||
this.config,
|
||||
|
||||
@@ -6,9 +6,10 @@
|
||||
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { WriteTodosTool, type WriteTodosToolParams } from './write-todos.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
describe('WriteTodosTool', () => {
|
||||
const tool = new WriteTodosTool();
|
||||
const tool = new WriteTodosTool(createMockMessageBus());
|
||||
const signal = new AbortController().signal;
|
||||
|
||||
describe('validation', () => {
|
||||
|
||||
@@ -101,7 +101,7 @@ class WriteTodosToolInvocation extends BaseToolInvocation<
|
||||
> {
|
||||
constructor(
|
||||
params: WriteTodosToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
@@ -145,7 +145,7 @@ export class WriteTodosTool extends BaseDeclarativeTool<
|
||||
> {
|
||||
static readonly Name = WRITE_TODOS_TOOL_NAME;
|
||||
|
||||
constructor(messageBus?: MessageBus) {
|
||||
constructor(messageBus: MessageBus) {
|
||||
super(
|
||||
WriteTodosTool.Name,
|
||||
'WriteTodos',
|
||||
@@ -180,9 +180,9 @@ export class WriteTodosTool extends BaseDeclarativeTool<
|
||||
required: ['todos'],
|
||||
additionalProperties: false,
|
||||
},
|
||||
messageBus,
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -251,13 +251,13 @@ export class WriteTodosTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: WriteTodosToolParams,
|
||||
messageBus?: MessageBus,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
_displayName?: string,
|
||||
): ToolInvocation<WriteTodosToolParams, ToolResult> {
|
||||
return new WriteTodosToolInvocation(
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_displayName,
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user