mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 21:03:05 -07:00
feat(plan): handle inconsistency in schedulers (#17813)
This commit is contained in:
@@ -7,10 +7,8 @@
|
|||||||
import { describe, it, expect, vi } from 'vitest';
|
import { describe, it, expect, vi } from 'vitest';
|
||||||
import type { Mock } from 'vitest';
|
import type { Mock } from 'vitest';
|
||||||
import type { CallableTool } from '@google/genai';
|
import type { CallableTool } from '@google/genai';
|
||||||
import {
|
import { CoreToolScheduler } from './coreToolScheduler.js';
|
||||||
CoreToolScheduler,
|
import { PLAN_MODE_DENIAL_MESSAGE } from '../scheduler/policy.js';
|
||||||
PLAN_MODE_DENIAL_MESSAGE,
|
|
||||||
} from './coreToolScheduler.js';
|
|
||||||
import type {
|
import type {
|
||||||
ToolCall,
|
ToolCall,
|
||||||
WaitingToolCall,
|
WaitingToolCall,
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import {
|
|||||||
} from '../tools/tools.js';
|
} from '../tools/tools.js';
|
||||||
import type { EditorType } from '../utils/editor.js';
|
import type { EditorType } from '../utils/editor.js';
|
||||||
import type { Config } from '../config/config.js';
|
import type { Config } from '../config/config.js';
|
||||||
import { PolicyDecision, ApprovalMode } from '../policy/types.js';
|
import { PolicyDecision } from '../policy/types.js';
|
||||||
import { logToolCall } from '../telemetry/loggers.js';
|
import { logToolCall } from '../telemetry/loggers.js';
|
||||||
import { ToolErrorType } from '../tools/tool-error.js';
|
import { ToolErrorType } from '../tools/tool-error.js';
|
||||||
import { ToolCallEvent } from '../telemetry/types.js';
|
import { ToolCallEvent } from '../telemetry/types.js';
|
||||||
@@ -44,6 +44,7 @@ import {
|
|||||||
} from '../scheduler/types.js';
|
} from '../scheduler/types.js';
|
||||||
import { ToolExecutor } from '../scheduler/tool-executor.js';
|
import { ToolExecutor } from '../scheduler/tool-executor.js';
|
||||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||||
|
import { getPolicyDenialError } from '../scheduler/policy.js';
|
||||||
|
|
||||||
export type {
|
export type {
|
||||||
ToolCall,
|
ToolCall,
|
||||||
@@ -64,9 +65,6 @@ export type {
|
|||||||
ToolCallResponseInfo,
|
ToolCallResponseInfo,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const PLAN_MODE_DENIAL_MESSAGE =
|
|
||||||
'You are in Plan Mode - adjust your prompt to only use read and search tools.';
|
|
||||||
|
|
||||||
const createErrorResponse = (
|
const createErrorResponse = (
|
||||||
request: ToolCallRequestInfo,
|
request: ToolCallRequestInfo,
|
||||||
error: Error,
|
error: Error,
|
||||||
@@ -599,18 +597,15 @@ export class CoreToolScheduler {
|
|||||||
? toolCall.tool.serverName
|
? toolCall.tool.serverName
|
||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
const { decision } = await this.config
|
const { decision, rule } = await this.config
|
||||||
.getPolicyEngine()
|
.getPolicyEngine()
|
||||||
.check(toolCallForPolicy, serverName);
|
.check(toolCallForPolicy, serverName);
|
||||||
|
|
||||||
if (decision === PolicyDecision.DENY) {
|
if (decision === PolicyDecision.DENY) {
|
||||||
let errorMessage = `Tool execution denied by policy.`;
|
const { errorMessage, errorType } = getPolicyDenialError(
|
||||||
let errorType = ToolErrorType.POLICY_VIOLATION;
|
this.config,
|
||||||
|
rule,
|
||||||
if (this.config.getApprovalMode() === ApprovalMode.PLAN) {
|
);
|
||||||
errorMessage = PLAN_MODE_DENIAL_MESSAGE;
|
|
||||||
errorType = ToolErrorType.STOP_EXECUTION;
|
|
||||||
}
|
|
||||||
this.setStatusInternal(
|
this.setStatusInternal(
|
||||||
reqInfo.callId,
|
reqInfo.callId,
|
||||||
'error',
|
'error',
|
||||||
|
|||||||
@@ -4,8 +4,20 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { describe, it, expect, vi, type Mocked } from 'vitest';
|
import {
|
||||||
import { checkPolicy, updatePolicy } from './policy.js';
|
describe,
|
||||||
|
it,
|
||||||
|
expect,
|
||||||
|
vi,
|
||||||
|
type Mocked,
|
||||||
|
beforeEach,
|
||||||
|
afterEach,
|
||||||
|
} from 'vitest';
|
||||||
|
import {
|
||||||
|
checkPolicy,
|
||||||
|
updatePolicy,
|
||||||
|
PLAN_MODE_DENIAL_MESSAGE,
|
||||||
|
} from './policy.js';
|
||||||
import type { Config } from '../config/config.js';
|
import type { Config } from '../config/config.js';
|
||||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||||
import { MessageBusType } from '../confirmation-bus/types.js';
|
import { MessageBusType } from '../confirmation-bus/types.js';
|
||||||
@@ -15,10 +27,20 @@ import {
|
|||||||
type AnyDeclarativeTool,
|
type AnyDeclarativeTool,
|
||||||
type ToolMcpConfirmationDetails,
|
type ToolMcpConfirmationDetails,
|
||||||
type ToolExecuteConfirmationDetails,
|
type ToolExecuteConfirmationDetails,
|
||||||
|
type AnyToolInvocation,
|
||||||
} from '../tools/tools.js';
|
} from '../tools/tools.js';
|
||||||
import type { ValidatingToolCall } from './types.js';
|
import type {
|
||||||
|
ValidatingToolCall,
|
||||||
|
ToolCallRequestInfo,
|
||||||
|
CompletedToolCall,
|
||||||
|
} from './types.js';
|
||||||
import type { PolicyEngine } from '../policy/policy-engine.js';
|
import type { PolicyEngine } from '../policy/policy-engine.js';
|
||||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||||
|
import { CoreToolScheduler } from '../core/coreToolScheduler.js';
|
||||||
|
import { Scheduler } from './scheduler.js';
|
||||||
|
import { ROOT_SCHEDULER_ID } from './types.js';
|
||||||
|
import { ToolErrorType } from '../tools/tool-error.js';
|
||||||
|
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||||
|
|
||||||
describe('policy.ts', () => {
|
describe('policy.ts', () => {
|
||||||
describe('checkPolicy', () => {
|
describe('checkPolicy', () => {
|
||||||
@@ -420,3 +442,113 @@ describe('policy.ts', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('Plan Mode Denial Consistency', () => {
|
||||||
|
let mockConfig: Mocked<Config>;
|
||||||
|
let mockMessageBus: Mocked<MessageBus>;
|
||||||
|
let mockPolicyEngine: Mocked<PolicyEngine>;
|
||||||
|
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||||
|
let mockTool: AnyDeclarativeTool;
|
||||||
|
let mockInvocation: AnyToolInvocation;
|
||||||
|
|
||||||
|
const req: ToolCallRequestInfo = {
|
||||||
|
callId: 'call-1',
|
||||||
|
name: 'test-tool',
|
||||||
|
args: { foo: 'bar' },
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'prompt-1',
|
||||||
|
schedulerId: ROOT_SCHEDULER_ID,
|
||||||
|
};
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockTool = {
|
||||||
|
name: 'test-tool',
|
||||||
|
build: vi.fn(),
|
||||||
|
} as unknown as AnyDeclarativeTool;
|
||||||
|
|
||||||
|
mockInvocation = {
|
||||||
|
shouldConfirmExecute: vi.fn(),
|
||||||
|
} as unknown as AnyToolInvocation;
|
||||||
|
vi.mocked(mockTool.build).mockReturnValue(mockInvocation);
|
||||||
|
|
||||||
|
mockPolicyEngine = {
|
||||||
|
check: vi.fn().mockResolvedValue({ decision: PolicyDecision.DENY }), // Default to DENY for this test
|
||||||
|
} as unknown as Mocked<PolicyEngine>;
|
||||||
|
|
||||||
|
mockToolRegistry = {
|
||||||
|
getTool: vi.fn().mockReturnValue(mockTool),
|
||||||
|
getAllToolNames: vi.fn().mockReturnValue(['test-tool']),
|
||||||
|
} as unknown as Mocked<ToolRegistry>;
|
||||||
|
|
||||||
|
mockMessageBus = {
|
||||||
|
publish: vi.fn(),
|
||||||
|
subscribe: vi.fn(),
|
||||||
|
} as unknown as Mocked<MessageBus>;
|
||||||
|
|
||||||
|
mockConfig = {
|
||||||
|
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||||
|
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||||
|
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
|
||||||
|
isInteractive: vi.fn().mockReturnValue(true),
|
||||||
|
getEnableHooks: vi.fn().mockReturnValue(false),
|
||||||
|
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN), // Key: Plan Mode
|
||||||
|
setApprovalMode: vi.fn(),
|
||||||
|
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
|
||||||
|
} as unknown as Mocked<Config>;
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe.each([
|
||||||
|
{ enableEventDrivenScheduler: false, name: 'Legacy CoreToolScheduler' },
|
||||||
|
{ enableEventDrivenScheduler: true, name: 'Event-Driven Scheduler' },
|
||||||
|
])('$name', ({ enableEventDrivenScheduler }) => {
|
||||||
|
it('should return the correct Plan Mode denial message when policy denies execution', async () => {
|
||||||
|
let resultMessage: string | undefined;
|
||||||
|
let resultErrorType: ToolErrorType | undefined;
|
||||||
|
|
||||||
|
const signal = new AbortController().signal;
|
||||||
|
|
||||||
|
if (enableEventDrivenScheduler) {
|
||||||
|
const scheduler = new Scheduler({
|
||||||
|
config: mockConfig,
|
||||||
|
messageBus: mockMessageBus,
|
||||||
|
getPreferredEditor: () => undefined,
|
||||||
|
schedulerId: ROOT_SCHEDULER_ID,
|
||||||
|
});
|
||||||
|
|
||||||
|
const results = await scheduler.schedule(req, signal);
|
||||||
|
const result = results[0];
|
||||||
|
|
||||||
|
expect(result.status).toBe('error');
|
||||||
|
if (result.status === 'error') {
|
||||||
|
resultMessage = result.response.error?.message;
|
||||||
|
resultErrorType = result.response.errorType;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let capturedCalls: CompletedToolCall[] = [];
|
||||||
|
const scheduler = new CoreToolScheduler({
|
||||||
|
config: mockConfig,
|
||||||
|
getPreferredEditor: () => undefined,
|
||||||
|
onAllToolCallsComplete: async (calls) => {
|
||||||
|
capturedCalls = calls;
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await scheduler.schedule(req, signal);
|
||||||
|
|
||||||
|
expect(capturedCalls.length).toBeGreaterThan(0);
|
||||||
|
const call = capturedCalls[0];
|
||||||
|
if (call.status === 'error') {
|
||||||
|
resultMessage = call.response.error?.message;
|
||||||
|
resultErrorType = call.response.errorType;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(resultMessage).toBe(PLAN_MODE_DENIAL_MESSAGE);
|
||||||
|
expect(resultErrorType).toBe(ToolErrorType.STOP_EXECUTION);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|||||||
@@ -4,10 +4,12 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import { ToolErrorType } from '../tools/tool-error.js';
|
||||||
import {
|
import {
|
||||||
ApprovalMode,
|
ApprovalMode,
|
||||||
PolicyDecision,
|
PolicyDecision,
|
||||||
type CheckResult,
|
type CheckResult,
|
||||||
|
type PolicyRule,
|
||||||
} from '../policy/types.js';
|
} from '../policy/types.js';
|
||||||
import type { Config } from '../config/config.js';
|
import type { Config } from '../config/config.js';
|
||||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||||
@@ -24,6 +26,30 @@ import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
|||||||
import { EDIT_TOOL_NAMES } from '../tools/tool-names.js';
|
import { EDIT_TOOL_NAMES } from '../tools/tool-names.js';
|
||||||
import type { ValidatingToolCall } from './types.js';
|
import type { ValidatingToolCall } from './types.js';
|
||||||
|
|
||||||
|
export const PLAN_MODE_DENIAL_MESSAGE =
|
||||||
|
'You are in Plan Mode - adjust your prompt to only use read and search tools.';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper to determine the error message and type for a policy denial.
|
||||||
|
*/
|
||||||
|
export function getPolicyDenialError(
|
||||||
|
config: Config,
|
||||||
|
rule?: PolicyRule,
|
||||||
|
): { errorMessage: string; errorType: ToolErrorType } {
|
||||||
|
if (config.getApprovalMode() === ApprovalMode.PLAN) {
|
||||||
|
return {
|
||||||
|
errorMessage: PLAN_MODE_DENIAL_MESSAGE,
|
||||||
|
errorType: ToolErrorType.STOP_EXECUTION,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const denyMessage = rule?.denyMessage ? ` ${rule.denyMessage}` : '';
|
||||||
|
return {
|
||||||
|
errorMessage: `Tool execution denied by policy.${denyMessage}`,
|
||||||
|
errorType: ToolErrorType.POLICY_VIOLATION,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Queries the system PolicyEngine to determine tool allowance.
|
* Queries the system PolicyEngine to determine tool allowance.
|
||||||
* @returns The PolicyDecision.
|
* @returns The PolicyDecision.
|
||||||
|
|||||||
@@ -46,7 +46,14 @@ import { ToolModificationHandler } from './tool-modifier.js';
|
|||||||
|
|
||||||
vi.mock('./state-manager.js');
|
vi.mock('./state-manager.js');
|
||||||
vi.mock('./confirmation.js');
|
vi.mock('./confirmation.js');
|
||||||
vi.mock('./policy.js');
|
vi.mock('./policy.js', async (importOriginal) => {
|
||||||
|
const actual = await importOriginal<typeof import('./policy.js')>();
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
checkPolicy: vi.fn(),
|
||||||
|
updatePolicy: vi.fn(),
|
||||||
|
};
|
||||||
|
});
|
||||||
vi.mock('./tool-executor.js');
|
vi.mock('./tool-executor.js');
|
||||||
vi.mock('./tool-modifier.js');
|
vi.mock('./tool-modifier.js');
|
||||||
|
|
||||||
@@ -55,7 +62,7 @@ import type { Config } from '../config/config.js';
|
|||||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||||
import type { PolicyEngine } from '../policy/policy-engine.js';
|
import type { PolicyEngine } from '../policy/policy-engine.js';
|
||||||
import type { ToolRegistry } from '../tools/tool-registry.js';
|
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||||
import { PolicyDecision } from '../policy/types.js';
|
import { PolicyDecision, ApprovalMode } from '../policy/types.js';
|
||||||
import {
|
import {
|
||||||
ToolConfirmationOutcome,
|
ToolConfirmationOutcome,
|
||||||
type AnyDeclarativeTool,
|
type AnyDeclarativeTool,
|
||||||
@@ -149,6 +156,7 @@ describe('Scheduler (Orchestrator)', () => {
|
|||||||
isInteractive: vi.fn().mockReturnValue(true),
|
isInteractive: vi.fn().mockReturnValue(true),
|
||||||
getEnableHooks: vi.fn().mockReturnValue(true),
|
getEnableHooks: vi.fn().mockReturnValue(true),
|
||||||
setApprovalMode: vi.fn(),
|
setApprovalMode: vi.fn(),
|
||||||
|
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||||
} as unknown as Mocked<Config>;
|
} as unknown as Mocked<Config>;
|
||||||
|
|
||||||
mockMessageBus = {
|
mockMessageBus = {
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import type { Config } from '../config/config.js';
|
|||||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||||
import { SchedulerStateManager } from './state-manager.js';
|
import { SchedulerStateManager } from './state-manager.js';
|
||||||
import { resolveConfirmation } from './confirmation.js';
|
import { resolveConfirmation } from './confirmation.js';
|
||||||
import { checkPolicy, updatePolicy } from './policy.js';
|
import { checkPolicy, updatePolicy, getPolicyDenialError } from './policy.js';
|
||||||
import { ToolExecutor } from './tool-executor.js';
|
import { ToolExecutor } from './tool-executor.js';
|
||||||
import { ToolModificationHandler } from './tool-modifier.js';
|
import { ToolModificationHandler } from './tool-modifier.js';
|
||||||
import {
|
import {
|
||||||
@@ -407,14 +407,18 @@ export class Scheduler {
|
|||||||
const { decision, rule } = await checkPolicy(toolCall, this.config);
|
const { decision, rule } = await checkPolicy(toolCall, this.config);
|
||||||
|
|
||||||
if (decision === PolicyDecision.DENY) {
|
if (decision === PolicyDecision.DENY) {
|
||||||
const denyMessage = rule?.denyMessage ? ` ${rule.denyMessage}` : '';
|
const { errorMessage, errorType } = getPolicyDenialError(
|
||||||
|
this.config,
|
||||||
|
rule,
|
||||||
|
);
|
||||||
|
|
||||||
this.state.updateStatus(
|
this.state.updateStatus(
|
||||||
callId,
|
callId,
|
||||||
'error',
|
'error',
|
||||||
createErrorResponse(
|
createErrorResponse(
|
||||||
toolCall.request,
|
toolCall.request,
|
||||||
new Error(`Tool execution denied by policy.${denyMessage}`),
|
new Error(errorMessage),
|
||||||
ToolErrorType.POLICY_VIOLATION,
|
errorType,
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
this.state.finalizeCall(callId);
|
this.state.finalizeCall(callId);
|
||||||
|
|||||||
Reference in New Issue
Block a user