This commit is contained in:
A.K.M. Adib
2026-01-28 15:33:09 -05:00
parent 065e69a12b
commit 41d08b1196
5 changed files with 53 additions and 22 deletions
@@ -7,10 +7,8 @@
import { describe, it, expect, vi } from 'vitest'; import { describe, it, expect, vi } from 'vitest';
import type { Mock } from 'vitest'; import type { Mock } from 'vitest';
import type { CallableTool } from '@google/genai'; import type { CallableTool } from '@google/genai';
import { import { CoreToolScheduler } from './coreToolScheduler.js';
CoreToolScheduler, import { PLAN_MODE_DENIAL_MESSAGE } from '../scheduler/policy.js';
PLAN_MODE_DENIAL_MESSAGE,
} from './coreToolScheduler.js';
import type { import type {
ToolCall, ToolCall,
WaitingToolCall, WaitingToolCall,
+7 -12
View File
@@ -14,7 +14,7 @@ import {
} from '../tools/tools.js'; } from '../tools/tools.js';
import type { EditorType } from '../utils/editor.js'; import type { EditorType } from '../utils/editor.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import { PolicyDecision, ApprovalMode } from '../policy/types.js'; import { PolicyDecision } from '../policy/types.js';
import { logToolCall } from '../telemetry/loggers.js'; import { logToolCall } from '../telemetry/loggers.js';
import { ToolErrorType } from '../tools/tool-error.js'; import { ToolErrorType } from '../tools/tool-error.js';
import { ToolCallEvent } from '../telemetry/types.js'; import { ToolCallEvent } from '../telemetry/types.js';
@@ -44,6 +44,7 @@ import {
} from '../scheduler/types.js'; } from '../scheduler/types.js';
import { ToolExecutor } from '../scheduler/tool-executor.js'; import { ToolExecutor } from '../scheduler/tool-executor.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { getPolicyDenialError } from '../scheduler/policy.js';
export type { export type {
ToolCall, ToolCall,
@@ -64,9 +65,6 @@ export type {
ToolCallResponseInfo, ToolCallResponseInfo,
}; };
export const PLAN_MODE_DENIAL_MESSAGE =
'You are in Plan Mode - adjust your prompt to only use read and search tools.';
const createErrorResponse = ( const createErrorResponse = (
request: ToolCallRequestInfo, request: ToolCallRequestInfo,
error: Error, error: Error,
@@ -599,18 +597,15 @@ export class CoreToolScheduler {
? toolCall.tool.serverName ? toolCall.tool.serverName
: undefined; : undefined;
const { decision } = await this.config const { decision, rule } = await this.config
.getPolicyEngine() .getPolicyEngine()
.check(toolCallForPolicy, serverName); .check(toolCallForPolicy, serverName);
if (decision === PolicyDecision.DENY) { if (decision === PolicyDecision.DENY) {
let errorMessage = `Tool execution denied by policy.`; const { errorMessage, errorType } = getPolicyDenialError(
let errorType = ToolErrorType.POLICY_VIOLATION; this.config,
rule,
if (this.config.getApprovalMode() === ApprovalMode.PLAN) { );
errorMessage = PLAN_MODE_DENIAL_MESSAGE;
errorType = ToolErrorType.STOP_EXECUTION;
}
this.setStatusInternal( this.setStatusInternal(
reqInfo.callId, reqInfo.callId,
'error', 'error',
+26
View File
@@ -4,10 +4,12 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { ToolErrorType } from '../tools/tool-error.js';
import { import {
ApprovalMode, ApprovalMode,
PolicyDecision, PolicyDecision,
type CheckResult, type CheckResult,
type PolicyRule,
} from '../policy/types.js'; } from '../policy/types.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js';
@@ -24,6 +26,30 @@ import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { EDIT_TOOL_NAMES } from '../tools/tool-names.js'; import { EDIT_TOOL_NAMES } from '../tools/tool-names.js';
import type { ValidatingToolCall } from './types.js'; import type { ValidatingToolCall } from './types.js';
export const PLAN_MODE_DENIAL_MESSAGE =
'You are in Plan Mode - adjust your prompt to only use read and search tools.';
/**
* Helper to determine the error message and type for a policy denial.
*/
export function getPolicyDenialError(
config: Config,
rule?: PolicyRule,
): { errorMessage: string; errorType: ToolErrorType } {
if (config.getApprovalMode() === ApprovalMode.PLAN) {
return {
errorMessage: PLAN_MODE_DENIAL_MESSAGE,
errorType: ToolErrorType.STOP_EXECUTION,
};
}
const denyMessage = rule?.denyMessage ? ` ${rule.denyMessage}` : '';
return {
errorMessage: `Tool execution denied by policy.${denyMessage}`,
errorType: ToolErrorType.POLICY_VIOLATION,
};
}
/** /**
* Queries the system PolicyEngine to determine tool allowance. * Queries the system PolicyEngine to determine tool allowance.
* @returns The PolicyDecision. * @returns The PolicyDecision.
+10 -2
View File
@@ -46,7 +46,14 @@ import { ToolModificationHandler } from './tool-modifier.js';
vi.mock('./state-manager.js'); vi.mock('./state-manager.js');
vi.mock('./confirmation.js'); vi.mock('./confirmation.js');
vi.mock('./policy.js'); vi.mock('./policy.js', async (importOriginal) => {
const actual = await importOriginal<typeof import('./policy.js')>();
return {
...actual,
checkPolicy: vi.fn(),
updatePolicy: vi.fn(),
};
});
vi.mock('./tool-executor.js'); vi.mock('./tool-executor.js');
vi.mock('./tool-modifier.js'); vi.mock('./tool-modifier.js');
@@ -55,7 +62,7 @@ import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { PolicyEngine } from '../policy/policy-engine.js'; import type { PolicyEngine } from '../policy/policy-engine.js';
import type { ToolRegistry } from '../tools/tool-registry.js'; import type { ToolRegistry } from '../tools/tool-registry.js';
import { PolicyDecision } from '../policy/types.js'; import { PolicyDecision, ApprovalMode } from '../policy/types.js';
import { import {
ToolConfirmationOutcome, ToolConfirmationOutcome,
type AnyDeclarativeTool, type AnyDeclarativeTool,
@@ -149,6 +156,7 @@ describe('Scheduler (Orchestrator)', () => {
isInteractive: vi.fn().mockReturnValue(true), isInteractive: vi.fn().mockReturnValue(true),
getEnableHooks: vi.fn().mockReturnValue(true), getEnableHooks: vi.fn().mockReturnValue(true),
setApprovalMode: vi.fn(), setApprovalMode: vi.fn(),
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
} as unknown as Mocked<Config>; } as unknown as Mocked<Config>;
mockMessageBus = { mockMessageBus = {
+8 -4
View File
@@ -8,7 +8,7 @@ import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { SchedulerStateManager } from './state-manager.js'; import { SchedulerStateManager } from './state-manager.js';
import { resolveConfirmation } from './confirmation.js'; import { resolveConfirmation } from './confirmation.js';
import { checkPolicy, updatePolicy } from './policy.js'; import { checkPolicy, updatePolicy, getPolicyDenialError } from './policy.js';
import { ToolExecutor } from './tool-executor.js'; import { ToolExecutor } from './tool-executor.js';
import { ToolModificationHandler } from './tool-modifier.js'; import { ToolModificationHandler } from './tool-modifier.js';
import { import {
@@ -407,14 +407,18 @@ export class Scheduler {
const { decision, rule } = await checkPolicy(toolCall, this.config); const { decision, rule } = await checkPolicy(toolCall, this.config);
if (decision === PolicyDecision.DENY) { if (decision === PolicyDecision.DENY) {
const denyMessage = rule?.denyMessage ? ` ${rule.denyMessage}` : ''; const { errorMessage, errorType } = getPolicyDenialError(
this.config,
rule,
);
this.state.updateStatus( this.state.updateStatus(
callId, callId,
'error', 'error',
createErrorResponse( createErrorResponse(
toolCall.request, toolCall.request,
new Error(`Tool execution denied by policy.${denyMessage}`), new Error(errorMessage),
ToolErrorType.POLICY_VIOLATION, errorType,
), ),
); );
this.state.finalizeCall(callId); this.state.finalizeCall(callId);