mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-17 17:41:24 -07:00
feat(core): wire up UI for ASK_USER policy decisions in message bus (#10630)
This commit is contained in:
@@ -50,6 +50,9 @@ export function createMockConfig(
|
||||
getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'),
|
||||
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||
getUserTier: vi.fn(),
|
||||
getEnableMessageBusIntegration: vi.fn().mockReturnValue(false),
|
||||
getMessageBus: vi.fn(),
|
||||
getPolicyEngine: vi.fn(),
|
||||
...overrides,
|
||||
} as unknown as Config;
|
||||
|
||||
|
||||
@@ -70,6 +70,9 @@ const mockConfig = {
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getShellExecutionConfig: () => ({ terminalWidth: 80, terminalHeight: 24 }),
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const mockTool = new MockTool({
|
||||
|
||||
@@ -24,6 +24,11 @@ export interface ToolConfirmationResponse {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE;
|
||||
correlationId: string;
|
||||
confirmed: boolean;
|
||||
/**
|
||||
* When true, indicates that policy decision was ASK_USER and the tool should
|
||||
* show its legacy confirmation UI instead of auto-proceeding.
|
||||
*/
|
||||
requiresUserConfirmation?: boolean;
|
||||
}
|
||||
|
||||
export interface ToolPolicyRejection {
|
||||
|
||||
@@ -255,6 +255,9 @@ describe('CoreToolScheduler', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -332,6 +335,9 @@ describe('CoreToolScheduler', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null,
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -365,15 +371,18 @@ describe('CoreToolScheduler', () => {
|
||||
describe('getToolSuggestion', () => {
|
||||
it('should suggest the top N closest tool names for a typo', () => {
|
||||
// Create mocked tool registry
|
||||
const mockToolRegistry = {
|
||||
getAllToolNames: () => ['list_files', 'read_file', 'write_file'],
|
||||
} as unknown as ToolRegistry;
|
||||
const mockConfig = {
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
const mockToolRegistry = {
|
||||
getAllToolNames: () => ['list_files', 'read_file', 'write_file'],
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
// Create scheduler
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -448,6 +457,9 @@ describe('CoreToolScheduler with payload', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -768,6 +780,9 @@ describe('CoreToolScheduler edit cancellation', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -874,6 +889,9 @@ describe('CoreToolScheduler YOLO mode', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -981,6 +999,9 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -1113,6 +1134,9 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -1215,6 +1239,9 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -1287,6 +1314,9 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const testTool = new TestApprovalTool(mockConfig);
|
||||
@@ -1475,6 +1505,8 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null,
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -1595,6 +1627,8 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null,
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
|
||||
@@ -42,6 +42,8 @@ import * as path from 'node:path';
|
||||
import { doesToolInvocationMatch } from '../utils/tool-utils.js';
|
||||
import levenshtein from 'fast-levenshtein';
|
||||
import { ShellToolInvocation } from '../tools/shell.js';
|
||||
import type { ToolConfirmationRequest } from '../confirmation-bus/types.js';
|
||||
import { MessageBusType } from '../confirmation-bus/types.js';
|
||||
|
||||
export type ValidatingToolCall = {
|
||||
status: 'validating';
|
||||
@@ -352,6 +354,15 @@ export class CoreToolScheduler {
|
||||
this.onToolCallsUpdate = options.onToolCallsUpdate;
|
||||
this.getPreferredEditor = options.getPreferredEditor;
|
||||
this.onEditorClose = options.onEditorClose;
|
||||
|
||||
// Subscribe to message bus for ASK_USER policy decisions
|
||||
if (this.config.getEnableMessageBusIntegration()) {
|
||||
const messageBus = this.config.getMessageBus();
|
||||
messageBus.subscribe(
|
||||
MessageBusType.TOOL_CONFIRMATION_REQUEST,
|
||||
this.handleToolConfirmationRequest.bind(this),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private setStatusInternal(
|
||||
@@ -1160,6 +1171,26 @@ export class CoreToolScheduler {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle tool confirmation requests from the message bus when policy decision is ASK_USER.
|
||||
* This publishes a response with requiresUserConfirmation=true to signal the tool
|
||||
* that it should fall back to its legacy confirmation UI.
|
||||
*/
|
||||
private handleToolConfirmationRequest(
|
||||
request: ToolConfirmationRequest,
|
||||
): void {
|
||||
// When ASK_USER policy decision is made, the message bus emits the request here.
|
||||
// We respond with requiresUserConfirmation=true to tell the tool to use its
|
||||
// legacy confirmation flow (which will show diffs, URLs, etc in the UI).
|
||||
const messageBus = this.config.getMessageBus();
|
||||
messageBus.publish({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: request.correlationId,
|
||||
confirmed: false, // Not auto-approved
|
||||
requiresUserConfirmation: true, // Use legacy UI confirmation
|
||||
});
|
||||
}
|
||||
|
||||
private async autoApproveCompatiblePendingTools(
|
||||
signal: AbortSignal,
|
||||
triggeringCallId: string,
|
||||
|
||||
@@ -62,6 +62,9 @@ describe('executeToolCall', () => {
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
getEnableMessageBusIntegration: () => false,
|
||||
getMessageBus: () => null,
|
||||
getPolicyEngine: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
abortController = new AbortController();
|
||||
|
||||
@@ -52,6 +52,22 @@ class TestToolInvocation extends BaseToolInvocation<TestParams, TestResult> {
|
||||
testValue: this.params.testParam,
|
||||
};
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<false> {
|
||||
// This conditional is here to allow testing of the case where there is no message bus.
|
||||
if (this.messageBus) {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
return false;
|
||||
}
|
||||
if (decision === 'DENY') {
|
||||
throw new Error('Tool execution denied by policy');
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
class TestTool extends BaseDeclarativeTool<TestParams, TestResult> {
|
||||
@@ -200,7 +216,7 @@ describe('Message Bus Integration', () => {
|
||||
abortController.abort();
|
||||
|
||||
await expect(confirmationPromise).rejects.toThrow(
|
||||
'Tool confirmation aborted',
|
||||
'Tool execution denied by policy',
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ export abstract class BaseToolInvocation<
|
||||
protected readonly messageBus?: MessageBus,
|
||||
) {
|
||||
if (this.messageBus) {
|
||||
console.log(
|
||||
console.debug(
|
||||
`[DEBUG] Tool ${this.constructor.name} created with messageBus: YES`,
|
||||
);
|
||||
}
|
||||
@@ -93,27 +93,17 @@ export abstract class BaseToolInvocation<
|
||||
shouldConfirmExecute(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// If message bus is available, use it for confirmation
|
||||
if (this.messageBus) {
|
||||
console.log(
|
||||
`[DEBUG] Using message bus for tool confirmation: ${this.constructor.name}`,
|
||||
);
|
||||
return this.handleMessageBusConfirmation(_abortSignal);
|
||||
}
|
||||
|
||||
// Fall back to existing confirmation flow
|
||||
// Default implementation for tools that don't override it.
|
||||
return Promise.resolve(false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle tool confirmation using the message bus.
|
||||
* This method publishes a confirmation request and waits for the response.
|
||||
*/
|
||||
protected async handleMessageBusConfirmation(
|
||||
protected getMessageBusDecision(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
): Promise<'ALLOW' | 'DENY' | 'ASK_USER'> {
|
||||
if (!this.messageBus) {
|
||||
return false;
|
||||
// If there's no message bus, we can't make a decision, so we allow.
|
||||
// The legacy confirmation flow will still apply if the tool needs it.
|
||||
return Promise.resolve('ALLOW');
|
||||
}
|
||||
|
||||
const correlationId = randomUUID();
|
||||
@@ -122,85 +112,74 @@ export abstract class BaseToolInvocation<
|
||||
args: this.params as Record<string, unknown>,
|
||||
};
|
||||
|
||||
return new Promise<ToolCallConfirmationDetails | false>(
|
||||
(resolve, reject) => {
|
||||
if (!this.messageBus) {
|
||||
resolve(false);
|
||||
return;
|
||||
return new Promise<'ALLOW' | 'DENY' | 'ASK_USER'>((resolve) => {
|
||||
if (!this.messageBus) {
|
||||
resolve('ALLOW');
|
||||
return;
|
||||
}
|
||||
|
||||
let timeoutId: NodeJS.Timeout | undefined;
|
||||
|
||||
const cleanup = () => {
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = undefined;
|
||||
}
|
||||
|
||||
let timeoutId: NodeJS.Timeout | undefined;
|
||||
|
||||
// Centralized cleanup function
|
||||
const cleanup = () => {
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = undefined;
|
||||
}
|
||||
abortSignal.removeEventListener('abort', abortHandler);
|
||||
this.messageBus?.unsubscribe(
|
||||
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
responseHandler,
|
||||
);
|
||||
};
|
||||
|
||||
// Set up abort handler
|
||||
const abortHandler = () => {
|
||||
cleanup();
|
||||
reject(new Error('Tool confirmation aborted'));
|
||||
};
|
||||
|
||||
// Check if already aborted
|
||||
if (abortSignal.aborted) {
|
||||
reject(new Error('Tool confirmation aborted'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Set up response handler
|
||||
const responseHandler = (response: ToolConfirmationResponse) => {
|
||||
if (response.correlationId === correlationId) {
|
||||
cleanup();
|
||||
|
||||
if (response.confirmed) {
|
||||
// Tool was confirmed, return false to indicate no further confirmation needed
|
||||
resolve(false);
|
||||
} else {
|
||||
// Tool was denied, reject to prevent execution
|
||||
reject(new Error('Tool execution denied by policy'));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Add event listener for abort signal
|
||||
abortSignal.addEventListener('abort', abortHandler);
|
||||
|
||||
// Set up timeout
|
||||
timeoutId = setTimeout(() => {
|
||||
cleanup();
|
||||
resolve(false);
|
||||
}, 30000); // 30 second timeout
|
||||
|
||||
// Subscribe to response
|
||||
this.messageBus.subscribe(
|
||||
abortSignal.removeEventListener('abort', abortHandler);
|
||||
this.messageBus?.unsubscribe(
|
||||
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
responseHandler,
|
||||
);
|
||||
};
|
||||
|
||||
// Publish confirmation request
|
||||
const request: ToolConfirmationRequest = {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
|
||||
toolCall,
|
||||
correlationId,
|
||||
};
|
||||
const abortHandler = () => {
|
||||
cleanup();
|
||||
resolve('DENY');
|
||||
};
|
||||
|
||||
try {
|
||||
this.messageBus.publish(request);
|
||||
} catch (_error) {
|
||||
if (abortSignal.aborted) {
|
||||
resolve('DENY');
|
||||
return;
|
||||
}
|
||||
|
||||
const responseHandler = (response: ToolConfirmationResponse) => {
|
||||
if (response.correlationId === correlationId) {
|
||||
cleanup();
|
||||
resolve(false);
|
||||
if (response.requiresUserConfirmation) {
|
||||
resolve('ASK_USER');
|
||||
} else if (response.confirmed) {
|
||||
resolve('ALLOW');
|
||||
} else {
|
||||
resolve('DENY');
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
abortSignal.addEventListener('abort', abortHandler);
|
||||
|
||||
timeoutId = setTimeout(() => {
|
||||
cleanup();
|
||||
resolve('ASK_USER'); // Default to ASK_USER on timeout
|
||||
}, 30000);
|
||||
|
||||
this.messageBus.subscribe(
|
||||
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
responseHandler,
|
||||
);
|
||||
|
||||
const request: ToolConfirmationRequest = {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
|
||||
toolCall,
|
||||
correlationId,
|
||||
};
|
||||
|
||||
try {
|
||||
this.messageBus.publish(request);
|
||||
} catch (_error) {
|
||||
cleanup();
|
||||
resolve('ALLOW');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
abstract execute(
|
||||
|
||||
@@ -4,13 +4,20 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||
import { WebFetchTool, parsePrompt } from './web-fetch.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../config/config.js';
|
||||
import { ToolConfirmationOutcome } from './tools.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import * as fetchUtils from '../utils/fetch.js';
|
||||
import { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { PolicyEngine } from '../policy/policy-engine.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type ToolConfirmationResponse,
|
||||
} from '../confirmation-bus/types.js';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import {
|
||||
logWebFetchFallbackAttempt,
|
||||
WebFetchFallbackAttemptEvent,
|
||||
@@ -35,6 +42,10 @@ vi.mock('../utils/fetch.js', async (importOriginal) => {
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('node:crypto', () => ({
|
||||
randomUUID: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('parsePrompt', () => {
|
||||
it('should extract valid URLs separated by whitespace', () => {
|
||||
const prompt = 'Go to https://example.com and http://google.com';
|
||||
@@ -313,4 +324,229 @@ describe('WebFetchTool', () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Message Bus Integration', () => {
|
||||
let policyEngine: PolicyEngine;
|
||||
let messageBus: MessageBus;
|
||||
let mockUUID: Mock;
|
||||
|
||||
beforeEach(() => {
|
||||
policyEngine = new PolicyEngine();
|
||||
messageBus = new MessageBus(policyEngine);
|
||||
mockUUID = vi.mocked(randomUUID);
|
||||
mockUUID.mockReturnValue('test-correlation-id');
|
||||
});
|
||||
|
||||
it('should use message bus for confirmation when available', async () => {
|
||||
const tool = new WebFetchTool(mockConfig, messageBus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
// Mock message bus publish and subscribe
|
||||
const publishSpy = vi.spyOn(messageBus, 'publish');
|
||||
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
|
||||
const unsubscribeSpy = vi.spyOn(messageBus, 'unsubscribe');
|
||||
|
||||
// Start confirmation process
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Verify confirmation request was published
|
||||
expect(publishSpy).toHaveBeenCalledWith({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
|
||||
toolCall: {
|
||||
name: 'WebFetchToolInvocation',
|
||||
args: { prompt: 'fetch https://example.com' },
|
||||
},
|
||||
correlationId: 'test-correlation-id',
|
||||
});
|
||||
|
||||
// Verify subscription to response
|
||||
expect(subscribeSpy).toHaveBeenCalledWith(
|
||||
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
expect.any(Function),
|
||||
);
|
||||
|
||||
// Simulate confirmation response
|
||||
const responseHandler = subscribeSpy.mock.calls[0][1];
|
||||
const response: ToolConfirmationResponse = {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: 'test-correlation-id',
|
||||
confirmed: true,
|
||||
};
|
||||
|
||||
responseHandler(response);
|
||||
|
||||
const result = await confirmationPromise;
|
||||
expect(result).toBe(false); // No further confirmation needed
|
||||
expect(unsubscribeSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reject promise when confirmation is denied via message bus', async () => {
|
||||
const tool = new WebFetchTool(mockConfig, messageBus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
|
||||
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Simulate denial response
|
||||
const responseHandler = subscribeSpy.mock.calls[0][1];
|
||||
const response: ToolConfirmationResponse = {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: 'test-correlation-id',
|
||||
confirmed: false,
|
||||
};
|
||||
|
||||
responseHandler(response);
|
||||
|
||||
// Should reject with error when denied
|
||||
await expect(confirmationPromise).rejects.toThrow(
|
||||
'Tool execution denied by policy',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle timeout gracefully', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, messageBus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Fast-forward past timeout
|
||||
await vi.advanceTimersByTimeAsync(30000);
|
||||
const result = await confirmationPromise;
|
||||
expect(result).not.toBe(false);
|
||||
expect(result).toHaveProperty('type', 'info');
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should handle abort signal during confirmation', async () => {
|
||||
const tool = new WebFetchTool(mockConfig, messageBus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
const abortController = new AbortController();
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
abortController.signal,
|
||||
);
|
||||
|
||||
// Abort the operation
|
||||
abortController.abort();
|
||||
|
||||
await expect(confirmationPromise).rejects.toThrow(
|
||||
'Tool execution denied by policy.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should fall back to legacy confirmation when no message bus', async () => {
|
||||
const tool = new WebFetchTool(mockConfig); // No message bus
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Should use legacy confirmation flow (returns confirmation details, not false)
|
||||
expect(result).not.toBe(false);
|
||||
expect(result).toHaveProperty('type', 'info');
|
||||
});
|
||||
|
||||
it('should ignore responses with wrong correlation ID', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, messageBus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Send response with wrong correlation ID
|
||||
const responseHandler = subscribeSpy.mock.calls[0][1];
|
||||
const wrongResponse: ToolConfirmationResponse = {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: 'wrong-id',
|
||||
confirmed: true,
|
||||
};
|
||||
|
||||
responseHandler(wrongResponse);
|
||||
|
||||
// Should timeout since correct response wasn't received
|
||||
await vi.advanceTimersByTimeAsync(30000);
|
||||
const result = await confirmationPromise;
|
||||
expect(result).not.toBe(false);
|
||||
expect(result).toHaveProperty('type', 'info');
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should handle message bus publish errors gracefully', async () => {
|
||||
const tool = new WebFetchTool(mockConfig, messageBus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
// Mock publish to throw error
|
||||
vi.spyOn(messageBus, 'publish').mockImplementation(() => {
|
||||
throw new Error('Message bus error');
|
||||
});
|
||||
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(result).toBe(false); // Should gracefully fall back
|
||||
});
|
||||
|
||||
it('should execute normally after confirmation approval', async () => {
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
|
||||
mockGenerateContent.mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'Fetched content from https://example.com' }],
|
||||
role: 'model',
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, messageBus);
|
||||
const params = { prompt: 'fetch https://example.com' };
|
||||
const invocation = tool.build(params);
|
||||
|
||||
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
|
||||
|
||||
// Start confirmation
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Approve via message bus
|
||||
const responseHandler = subscribeSpy.mock.calls[0][1];
|
||||
responseHandler({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: 'test-correlation-id',
|
||||
confirmed: true,
|
||||
});
|
||||
|
||||
await confirmationPromise;
|
||||
|
||||
// Execute the tool
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result.error).toBeUndefined();
|
||||
expect(result.llmContent).toContain('Fetched content');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
Kind,
|
||||
ToolConfirmationOutcome,
|
||||
} from './tools.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
@@ -107,8 +108,9 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: WebFetchToolParams,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(params);
|
||||
super(params, messageBus);
|
||||
}
|
||||
|
||||
private async executeFallback(signal: AbortSignal): Promise<ToolResult> {
|
||||
@@ -181,9 +183,22 @@ ${textContent}
|
||||
return `Processing URLs and instructions from prompt: "${displayPrompt}"`;
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(): Promise<
|
||||
ToolCallConfirmationDetails | false
|
||||
> {
|
||||
override async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// Try message bus confirmation first if available
|
||||
if (this.messageBus) {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
return false; // No confirmation needed
|
||||
}
|
||||
if (decision === 'DENY') {
|
||||
throw new Error('Tool execution denied by policy.');
|
||||
}
|
||||
// if 'ASK_USER', fall through to legacy logic
|
||||
}
|
||||
|
||||
// Legacy confirmation flow (no message bus OR policy decision was ASK_USER)
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
@@ -366,7 +381,10 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
> {
|
||||
static readonly Name: string = 'web_fetch';
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
WebFetchTool.Name,
|
||||
'WebFetch',
|
||||
@@ -383,6 +401,9 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
required: ['prompt'],
|
||||
type: 'object',
|
||||
},
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
const proxy = config.getProxy();
|
||||
if (proxy) {
|
||||
@@ -412,7 +433,8 @@ export class WebFetchTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: WebFetchToolParams,
|
||||
messageBus?: MessageBus,
|
||||
): ToolInvocation<WebFetchToolParams, ToolResult> {
|
||||
return new WebFetchToolInvocation(this.config, params);
|
||||
return new WebFetchToolInvocation(this.config, params, messageBus);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user