diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 0df9fd58eb..3b582abe89 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -7,10 +7,8 @@ import { describe, it, expect, vi } from 'vitest'; import type { Mock } from 'vitest'; import type { CallableTool } from '@google/genai'; -import { - CoreToolScheduler, - PLAN_MODE_DENIAL_MESSAGE, -} from './coreToolScheduler.js'; +import { CoreToolScheduler } from './coreToolScheduler.js'; +import { PLAN_MODE_DENIAL_MESSAGE } from '../scheduler/policy.js'; import type { ToolCall, WaitingToolCall, diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 30093f289e..96cb05d970 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -14,7 +14,7 @@ import { } from '../tools/tools.js'; import type { EditorType } from '../utils/editor.js'; import type { Config } from '../config/config.js'; -import { PolicyDecision, ApprovalMode } from '../policy/types.js'; +import { PolicyDecision } from '../policy/types.js'; import { logToolCall } from '../telemetry/loggers.js'; import { ToolErrorType } from '../tools/tool-error.js'; import { ToolCallEvent } from '../telemetry/types.js'; @@ -44,6 +44,7 @@ import { } from '../scheduler/types.js'; import { ToolExecutor } from '../scheduler/tool-executor.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; +import { getPolicyDenialError } from '../scheduler/policy.js'; export type { ToolCall, @@ -64,9 +65,6 @@ export type { ToolCallResponseInfo, }; -export const PLAN_MODE_DENIAL_MESSAGE = - 'You are in Plan Mode - adjust your prompt to only use read and search tools.'; - const createErrorResponse = ( request: ToolCallRequestInfo, error: Error, @@ -599,18 +597,15 @@ export class CoreToolScheduler { ? 
toolCall.tool.serverName : undefined; - const { decision } = await this.config + const { decision, rule } = await this.config .getPolicyEngine() .check(toolCallForPolicy, serverName); if (decision === PolicyDecision.DENY) { - let errorMessage = `Tool execution denied by policy.`; - let errorType = ToolErrorType.POLICY_VIOLATION; - - if (this.config.getApprovalMode() === ApprovalMode.PLAN) { - errorMessage = PLAN_MODE_DENIAL_MESSAGE; - errorType = ToolErrorType.STOP_EXECUTION; - } + const { errorMessage, errorType } = getPolicyDenialError( + this.config, + rule, + ); this.setStatusInternal( reqInfo.callId, 'error', diff --git a/packages/core/src/scheduler/policy.test.ts b/packages/core/src/scheduler/policy.test.ts index 57703abe3c..ad32b93f93 100644 --- a/packages/core/src/scheduler/policy.test.ts +++ b/packages/core/src/scheduler/policy.test.ts @@ -4,8 +4,20 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, type Mocked } from 'vitest'; -import { checkPolicy, updatePolicy } from './policy.js'; +import { + describe, + it, + expect, + vi, + type Mocked, + beforeEach, + afterEach, +} from 'vitest'; +import { + checkPolicy, + updatePolicy, + PLAN_MODE_DENIAL_MESSAGE, +} from './policy.js'; import type { Config } from '../config/config.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { MessageBusType } from '../confirmation-bus/types.js'; @@ -15,10 +27,20 @@ import { type AnyDeclarativeTool, type ToolMcpConfirmationDetails, type ToolExecuteConfirmationDetails, + type AnyToolInvocation, } from '../tools/tools.js'; -import type { ValidatingToolCall } from './types.js'; +import type { + ValidatingToolCall, + ToolCallRequestInfo, + CompletedToolCall, +} from './types.js'; import type { PolicyEngine } from '../policy/policy-engine.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; +import { CoreToolScheduler } from '../core/coreToolScheduler.js'; +import { Scheduler } from './scheduler.js'; +import { 
ROOT_SCHEDULER_ID } from './types.js'; +import { ToolErrorType } from '../tools/tool-error.js'; +import type { ToolRegistry } from '../tools/tool-registry.js'; describe('policy.ts', () => { describe('checkPolicy', () => { @@ -420,3 +442,113 @@ describe('policy.ts', () => { }); }); }); + +/** + * Runs one denied tool call through both the legacy CoreToolScheduler and the + * event-driven Scheduler. The policy engine is mocked to return DENY and the + * approval mode is mocked to PLAN; both schedulers are expected to report + * PLAN_MODE_DENIAL_MESSAGE with ToolErrorType.STOP_EXECUTION. + */ +describe('Plan Mode Denial Consistency', () => { + let mockConfig: Mocked; + let mockMessageBus: Mocked; + let mockPolicyEngine: Mocked; + let mockToolRegistry: Mocked; + let mockTool: AnyDeclarativeTool; + let mockInvocation: AnyToolInvocation; + + const req: ToolCallRequestInfo = { + callId: 'call-1', + name: 'test-tool', + args: { foo: 'bar' }, + isClientInitiated: false, + prompt_id: 'prompt-1', + schedulerId: ROOT_SCHEDULER_ID, + }; + + beforeEach(() => { + mockTool = { + name: 'test-tool', + build: vi.fn(), + } as unknown as AnyDeclarativeTool; + + mockInvocation = { + shouldConfirmExecute: vi.fn(), + } as unknown as AnyToolInvocation; + vi.mocked(mockTool.build).mockReturnValue(mockInvocation); + + mockPolicyEngine = { + check: vi.fn().mockResolvedValue({ decision: PolicyDecision.DENY }), // Default to DENY for this test + } as unknown as Mocked; + + mockToolRegistry = { + getTool: vi.fn().mockReturnValue(mockTool), + getAllToolNames: vi.fn().mockReturnValue(['test-tool']), + } as unknown as Mocked; + + mockMessageBus = { + publish: vi.fn(), + subscribe: vi.fn(), + } as unknown as Mocked; + + mockConfig = { + getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getMessageBus: vi.fn().mockReturnValue(mockMessageBus), + isInteractive: vi.fn().mockReturnValue(true), + getEnableHooks: vi.fn().mockReturnValue(false), + getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN), // Key: Plan Mode + setApprovalMode: vi.fn(), + getUsageStatisticsEnabled: vi.fn().mockReturnValue(false), + } as unknown as Mocked; + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + describe.each([ + { 
enableEventDrivenScheduler: false, name: 'Legacy CoreToolScheduler' }, + { enableEventDrivenScheduler: true, name: 'Event-Driven Scheduler' }, + ])('$name', ({ enableEventDrivenScheduler }) => { + it('should return the correct Plan Mode denial message when policy denies execution', async () => { + let resultMessage: string | undefined; + let resultErrorType: ToolErrorType | undefined; + + const signal = new AbortController().signal; + + /* Event-driven path: the completed calls are the resolved value of schedule(). */ + if (enableEventDrivenScheduler) { + const scheduler = new Scheduler({ + config: mockConfig, + messageBus: mockMessageBus, + getPreferredEditor: () => undefined, + schedulerId: ROOT_SCHEDULER_ID, + }); + + const results = await scheduler.schedule(req, signal); + const result = results[0]; + + expect(result.status).toBe('error'); + if (result.status === 'error') { + resultMessage = result.response.error?.message; + resultErrorType = result.response.errorType; + } + } else { + /* Legacy path: completed calls are captured through the onAllToolCallsComplete callback. */ + let capturedCalls: CompletedToolCall[] = []; + const scheduler = new CoreToolScheduler({ + config: mockConfig, + getPreferredEditor: () => undefined, + onAllToolCallsComplete: async (calls) => { + capturedCalls = calls; + }, + }); + + await scheduler.schedule(req, signal); + + expect(capturedCalls.length).toBeGreaterThan(0); + const call = capturedCalls[0]; + if (call.status === 'error') { + resultMessage = call.response.error?.message; + resultErrorType = call.response.errorType; + } + } + + /* Both branches must end with the same message and error type. */ + expect(resultMessage).toBe(PLAN_MODE_DENIAL_MESSAGE); + expect(resultErrorType).toBe(ToolErrorType.STOP_EXECUTION); + }); + }); +}); diff --git a/packages/core/src/scheduler/policy.ts b/packages/core/src/scheduler/policy.ts index d28ca6dad6..279dea85c7 100644 --- a/packages/core/src/scheduler/policy.ts +++ b/packages/core/src/scheduler/policy.ts @@ -4,10 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { ToolErrorType } from '../tools/tool-error.js'; import { ApprovalMode, PolicyDecision, type CheckResult, + type PolicyRule, } from '../policy/types.js'; import type 
{ Config } from '../config/config.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; @@ -24,6 +26,30 @@ import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { EDIT_TOOL_NAMES } from '../tools/tool-names.js'; import type { ValidatingToolCall } from './types.js'; +/** Error message that getPolicyDenialError returns when the approval mode is ApprovalMode.PLAN. */ +export const PLAN_MODE_DENIAL_MESSAGE = + 'You are in Plan Mode - adjust your prompt to only use read and search tools.'; + +/** + * Helper to determine the error message and type for a policy denial. + * + * When config.getApprovalMode() is ApprovalMode.PLAN the result is always + * PLAN_MODE_DENIAL_MESSAGE with ToolErrorType.STOP_EXECUTION, and the rule is + * not consulted. Otherwise the result is the generic policy-denial message, + * followed by a space and rule.denyMessage when that field is truthy, with + * ToolErrorType.POLICY_VIOLATION. + * + * @param config - Supplies the current approval mode via getApprovalMode(). + * @param rule - Optional policy rule; only its denyMessage field is read. + * @returns The errorMessage and errorType to report for the denied call. + */ +export function getPolicyDenialError( + config: Config, + rule?: PolicyRule, +): { errorMessage: string; errorType: ToolErrorType } { + /* Plan Mode takes precedence over any rule-specific deny message. */ + if (config.getApprovalMode() === ApprovalMode.PLAN) { + return { + errorMessage: PLAN_MODE_DENIAL_MESSAGE, + errorType: ToolErrorType.STOP_EXECUTION, + }; + } + + /* The leading space separates the optional rule text from the generic sentence. */ + const denyMessage = rule?.denyMessage ? ` ${rule.denyMessage}` : ''; + return { + errorMessage: `Tool execution denied by policy.${denyMessage}`, + errorType: ToolErrorType.POLICY_VIOLATION, + }; +} + /** * Queries the system PolicyEngine to determine tool allowance. * @returns The PolicyDecision. 
diff --git a/packages/core/src/scheduler/scheduler.test.ts b/packages/core/src/scheduler/scheduler.test.ts index 4ae3e84c8c..7fd815a597 100644 --- a/packages/core/src/scheduler/scheduler.test.ts +++ b/packages/core/src/scheduler/scheduler.test.ts @@ -46,7 +46,14 @@ import { ToolModificationHandler } from './tool-modifier.js'; vi.mock('./state-manager.js'); vi.mock('./confirmation.js'); -vi.mock('./policy.js'); +/* Partial mock: all exports of the original ./policy.js module are spread through unchanged, and only checkPolicy and updatePolicy are replaced with vi.fn() stubs. */ +vi.mock('./policy.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + checkPolicy: vi.fn(), + updatePolicy: vi.fn(), + }; +}); vi.mock('./tool-executor.js'); vi.mock('./tool-modifier.js'); @@ -55,7 +62,7 @@ import type { Config } from '../config/config.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { PolicyEngine } from '../policy/policy-engine.js'; import type { ToolRegistry } from '../tools/tool-registry.js'; -import { PolicyDecision } from '../policy/types.js'; +import { PolicyDecision, ApprovalMode } from '../policy/types.js'; import { ToolConfirmationOutcome, type AnyDeclarativeTool, @@ -149,6 +156,7 @@ describe('Scheduler (Orchestrator)', () => { isInteractive: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(true), setApprovalMode: vi.fn(), + getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), } as unknown as Mocked; mockMessageBus = { diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index 0589c50a72..71729923d0 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -8,7 +8,7 @@ import type { Config } from '../config/config.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { SchedulerStateManager } from './state-manager.js'; import { resolveConfirmation } from './confirmation.js'; -import { checkPolicy, updatePolicy } from './policy.js'; +import { checkPolicy, updatePolicy, getPolicyDenialError } from 
'./policy.js'; import { ToolExecutor } from './tool-executor.js'; import { ToolModificationHandler } from './tool-modifier.js'; import { @@ -407,14 +407,18 @@ export class Scheduler { const { decision, rule } = await checkPolicy(toolCall, this.config); if (decision === PolicyDecision.DENY) { - const denyMessage = rule?.denyMessage ? ` ${rule.denyMessage}` : ''; + const { errorMessage, errorType } = getPolicyDenialError( + this.config, + rule, + ); + this.state.updateStatus( callId, 'error', createErrorResponse( toolCall.request, - new Error(`Tool execution denied by policy.${denyMessage}`), - ToolErrorType.POLICY_VIOLATION, + new Error(errorMessage), + errorType, ), ); this.state.finalizeCall(callId);