feat(plan): handle inconsistency in schedulers (#17813)

This commit is contained in:
Adib234
2026-02-02 22:10:04 -05:00
committed by GitHub
parent 18cce6a9ab
commit 01e33465bd
6 changed files with 188 additions and 25 deletions

View File

@@ -7,10 +7,8 @@
import { describe, it, expect, vi } from 'vitest';
import type { Mock } from 'vitest';
import type { CallableTool } from '@google/genai';
import {
CoreToolScheduler,
PLAN_MODE_DENIAL_MESSAGE,
} from './coreToolScheduler.js';
import { CoreToolScheduler } from './coreToolScheduler.js';
import { PLAN_MODE_DENIAL_MESSAGE } from '../scheduler/policy.js';
import type {
ToolCall,
WaitingToolCall,

View File

@@ -14,7 +14,7 @@ import {
} from '../tools/tools.js';
import type { EditorType } from '../utils/editor.js';
import type { Config } from '../config/config.js';
import { PolicyDecision, ApprovalMode } from '../policy/types.js';
import { PolicyDecision } from '../policy/types.js';
import { logToolCall } from '../telemetry/loggers.js';
import { ToolErrorType } from '../tools/tool-error.js';
import { ToolCallEvent } from '../telemetry/types.js';
@@ -44,6 +44,7 @@ import {
} from '../scheduler/types.js';
import { ToolExecutor } from '../scheduler/tool-executor.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { getPolicyDenialError } from '../scheduler/policy.js';
export type {
ToolCall,
@@ -64,9 +65,6 @@ export type {
ToolCallResponseInfo,
};
export const PLAN_MODE_DENIAL_MESSAGE =
'You are in Plan Mode - adjust your prompt to only use read and search tools.';
const createErrorResponse = (
request: ToolCallRequestInfo,
error: Error,
@@ -599,18 +597,15 @@ export class CoreToolScheduler {
? toolCall.tool.serverName
: undefined;
const { decision } = await this.config
const { decision, rule } = await this.config
.getPolicyEngine()
.check(toolCallForPolicy, serverName);
if (decision === PolicyDecision.DENY) {
let errorMessage = `Tool execution denied by policy.`;
let errorType = ToolErrorType.POLICY_VIOLATION;
if (this.config.getApprovalMode() === ApprovalMode.PLAN) {
errorMessage = PLAN_MODE_DENIAL_MESSAGE;
errorType = ToolErrorType.STOP_EXECUTION;
}
const { errorMessage, errorType } = getPolicyDenialError(
this.config,
rule,
);
this.setStatusInternal(
reqInfo.callId,
'error',

View File

@@ -4,8 +4,20 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, type Mocked } from 'vitest';
import { checkPolicy, updatePolicy } from './policy.js';
import {
describe,
it,
expect,
vi,
type Mocked,
beforeEach,
afterEach,
} from 'vitest';
import {
checkPolicy,
updatePolicy,
PLAN_MODE_DENIAL_MESSAGE,
} from './policy.js';
import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { MessageBusType } from '../confirmation-bus/types.js';
@@ -15,10 +27,20 @@ import {
type AnyDeclarativeTool,
type ToolMcpConfirmationDetails,
type ToolExecuteConfirmationDetails,
type AnyToolInvocation,
} from '../tools/tools.js';
import type { ValidatingToolCall } from './types.js';
import type {
ValidatingToolCall,
ToolCallRequestInfo,
CompletedToolCall,
} from './types.js';
import type { PolicyEngine } from '../policy/policy-engine.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { CoreToolScheduler } from '../core/coreToolScheduler.js';
import { Scheduler } from './scheduler.js';
import { ROOT_SCHEDULER_ID } from './types.js';
import { ToolErrorType } from '../tools/tool-error.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
describe('policy.ts', () => {
describe('checkPolicy', () => {
@@ -420,3 +442,113 @@ describe('policy.ts', () => {
});
});
});
/**
 * Verifies that the legacy CoreToolScheduler and the event-driven Scheduler
 * report a policy denial identically while the approval mode is Plan Mode:
 * both must surface PLAN_MODE_DENIAL_MESSAGE with
 * ToolErrorType.STOP_EXECUTION.
 */
describe('Plan Mode Denial Consistency', () => {
  let mockConfig: Mocked<Config>;
  let mockMessageBus: Mocked<MessageBus>;
  let mockPolicyEngine: Mocked<PolicyEngine>;
  let mockToolRegistry: Mocked<ToolRegistry>;
  let mockTool: AnyDeclarativeTool;
  let mockInvocation: AnyToolInvocation;

  // Single tool-call request shared by both scheduler variants.
  const req: ToolCallRequestInfo = {
    callId: 'call-1',
    name: 'test-tool',
    args: { foo: 'bar' },
    isClientInitiated: false,
    prompt_id: 'prompt-1',
    schedulerId: ROOT_SCHEDULER_ID,
  };

  beforeEach(() => {
    mockTool = {
      name: 'test-tool',
      build: vi.fn(),
    } as unknown as AnyDeclarativeTool;

    mockInvocation = {
      shouldConfirmExecute: vi.fn(),
    } as unknown as AnyToolInvocation;
    vi.mocked(mockTool.build).mockReturnValue(mockInvocation);

    mockPolicyEngine = {
      check: vi.fn().mockResolvedValue({ decision: PolicyDecision.DENY }), // Default to DENY for this test
    } as unknown as Mocked<PolicyEngine>;

    mockToolRegistry = {
      getTool: vi.fn().mockReturnValue(mockTool),
      getAllToolNames: vi.fn().mockReturnValue(['test-tool']),
    } as unknown as Mocked<ToolRegistry>;

    mockMessageBus = {
      publish: vi.fn(),
      subscribe: vi.fn(),
    } as unknown as Mocked<MessageBus>;

    mockConfig = {
      getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
      getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
      getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
      isInteractive: vi.fn().mockReturnValue(true),
      getEnableHooks: vi.fn().mockReturnValue(false),
      getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN), // Key: Plan Mode
      setApprovalMode: vi.fn(),
      getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
    } as unknown as Mocked<Config>;
  });

  afterEach(() => {
    vi.clearAllMocks();
  });

  describe.each([
    { enableEventDrivenScheduler: false, name: 'Legacy CoreToolScheduler' },
    { enableEventDrivenScheduler: true, name: 'Event-Driven Scheduler' },
  ])('$name', ({ enableEventDrivenScheduler }) => {
    it('should return the correct Plan Mode denial message when policy denies execution', async () => {
      let resultMessage: string | undefined;
      let resultErrorType: ToolErrorType | undefined;
      const signal = new AbortController().signal;

      if (enableEventDrivenScheduler) {
        const scheduler = new Scheduler({
          config: mockConfig,
          messageBus: mockMessageBus,
          getPreferredEditor: () => undefined,
          schedulerId: ROOT_SCHEDULER_ID,
        });

        const results = await scheduler.schedule(req, signal);
        const result = results[0];

        expect(result.status).toBe('error');
        if (result.status === 'error') {
          resultMessage = result.response.error?.message;
          resultErrorType = result.response.errorType;
        }
      } else {
        // The legacy scheduler reports results through the completion
        // callback, so capture them there.
        let capturedCalls: CompletedToolCall[] = [];
        const scheduler = new CoreToolScheduler({
          config: mockConfig,
          getPreferredEditor: () => undefined,
          onAllToolCallsComplete: async (calls) => {
            capturedCalls = calls;
          },
        });

        await scheduler.schedule(req, signal);

        expect(capturedCalls.length).toBeGreaterThan(0);
        const call = capturedCalls[0];

        // Assert the status explicitly, mirroring the event-driven branch, so
        // an unexpected status fails here instead of surfacing later as an
        // `undefined` message mismatch.
        expect(call.status).toBe('error');
        if (call.status === 'error') {
          resultMessage = call.response.error?.message;
          resultErrorType = call.response.errorType;
        }
      }

      expect(resultMessage).toBe(PLAN_MODE_DENIAL_MESSAGE);
      expect(resultErrorType).toBe(ToolErrorType.STOP_EXECUTION);
    });
  });
});

View File

@@ -4,10 +4,12 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { ToolErrorType } from '../tools/tool-error.js';
import {
ApprovalMode,
PolicyDecision,
type CheckResult,
type PolicyRule,
} from '../policy/types.js';
import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
@@ -24,6 +26,30 @@ import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { EDIT_TOOL_NAMES } from '../tools/tool-names.js';
import type { ValidatingToolCall } from './types.js';
/**
 * Error message reported when a tool call is denied by policy while the
 * approval mode is Plan Mode.
 */
export const PLAN_MODE_DENIAL_MESSAGE =
'You are in Plan Mode - adjust your prompt to only use read and search tools.';
/**
 * Builds the error message and error type to report for a tool call that the
 * policy engine denied.
 *
 * @param config - Configuration queried for the current approval mode.
 * @param rule - The policy rule behind the denial, if any. Its `denyMessage`
 *   is appended to the generic message outside of Plan Mode.
 * @returns `PLAN_MODE_DENIAL_MESSAGE` with `STOP_EXECUTION` when the approval
 *   mode is Plan Mode; otherwise the generic policy-denial message with
 *   `POLICY_VIOLATION`.
 */
export function getPolicyDenialError(
  config: Config,
  rule?: PolicyRule,
): { errorMessage: string; errorType: ToolErrorType } {
  const inPlanMode = config.getApprovalMode() === ApprovalMode.PLAN;

  if (!inPlanMode) {
    // Only add the rule-specific suffix when the rule actually carries one.
    const ruleSuffix = rule?.denyMessage ? ` ${rule.denyMessage}` : '';
    return {
      errorMessage: `Tool execution denied by policy.${ruleSuffix}`,
      errorType: ToolErrorType.POLICY_VIOLATION,
    };
  }

  // Plan Mode takes precedence: the rule's own deny message is not used.
  return {
    errorMessage: PLAN_MODE_DENIAL_MESSAGE,
    errorType: ToolErrorType.STOP_EXECUTION,
  };
}
/**
* Queries the system PolicyEngine to determine tool allowance.
* @returns The PolicyDecision.

View File

@@ -46,7 +46,14 @@ import { ToolModificationHandler } from './tool-modifier.js';
vi.mock('./state-manager.js');
vi.mock('./confirmation.js');
vi.mock('./policy.js');
// Partially mock ./policy.js: keep the module's real exports (loaded through
// importOriginal) and replace only checkPolicy and updatePolicy with vi.fn()
// stubs. Every other export, such as getPolicyDenialError, keeps its real
// implementation.
vi.mock('./policy.js', async (importOriginal) => {
const actual = await importOriginal<typeof import('./policy.js')>();
return {
...actual,
checkPolicy: vi.fn(),
updatePolicy: vi.fn(),
};
});
vi.mock('./tool-executor.js');
vi.mock('./tool-modifier.js');
@@ -55,7 +62,7 @@ import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { PolicyEngine } from '../policy/policy-engine.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
import { PolicyDecision } from '../policy/types.js';
import { PolicyDecision, ApprovalMode } from '../policy/types.js';
import {
ToolConfirmationOutcome,
type AnyDeclarativeTool,
@@ -149,6 +156,7 @@ describe('Scheduler (Orchestrator)', () => {
isInteractive: vi.fn().mockReturnValue(true),
getEnableHooks: vi.fn().mockReturnValue(true),
setApprovalMode: vi.fn(),
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
} as unknown as Mocked<Config>;
mockMessageBus = {

View File

@@ -8,7 +8,7 @@ import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { SchedulerStateManager } from './state-manager.js';
import { resolveConfirmation } from './confirmation.js';
import { checkPolicy, updatePolicy } from './policy.js';
import { checkPolicy, updatePolicy, getPolicyDenialError } from './policy.js';
import { ToolExecutor } from './tool-executor.js';
import { ToolModificationHandler } from './tool-modifier.js';
import {
@@ -407,14 +407,18 @@ export class Scheduler {
const { decision, rule } = await checkPolicy(toolCall, this.config);
if (decision === PolicyDecision.DENY) {
const denyMessage = rule?.denyMessage ? ` ${rule.denyMessage}` : '';
const { errorMessage, errorType } = getPolicyDenialError(
this.config,
rule,
);
this.state.updateStatus(
callId,
'error',
createErrorResponse(
toolCall.request,
new Error(`Tool execution denied by policy.${denyMessage}`),
ToolErrorType.POLICY_VIOLATION,
new Error(errorMessage),
errorType,
),
);
this.state.finalizeCall(callId);