mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-30 06:54:15 -07:00
feat(safety): Introduce safety checker framework (#12504)
This commit is contained in:
@@ -26,12 +26,12 @@ describe('MessageBus', () => {
|
||||
});
|
||||
|
||||
describe('publish', () => {
|
||||
it('should emit error for invalid message', () => {
|
||||
it('should emit error for invalid message', async () => {
|
||||
const errorHandler = vi.fn();
|
||||
messageBus.on('error', errorHandler);
|
||||
|
||||
// @ts-expect-error - Testing invalid message
|
||||
messageBus.publish({ invalid: 'message' });
|
||||
await messageBus.publish({ invalid: 'message' });
|
||||
|
||||
expect(errorHandler).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
@@ -40,12 +40,12 @@ describe('MessageBus', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should validate tool confirmation requests have correlationId', () => {
|
||||
it('should validate tool confirmation requests have correlationId', async () => {
|
||||
const errorHandler = vi.fn();
|
||||
messageBus.on('error', errorHandler);
|
||||
|
||||
// @ts-expect-error - Testing missing correlationId
|
||||
messageBus.publish({
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
|
||||
toolCall: { name: 'test' },
|
||||
});
|
||||
@@ -53,8 +53,10 @@ describe('MessageBus', () => {
|
||||
expect(errorHandler).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should emit confirmation response when policy allows', () => {
|
||||
vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ALLOW);
|
||||
it('should emit confirmation response when policy allows', async () => {
|
||||
vi.spyOn(policyEngine, 'check').mockResolvedValue({
|
||||
decision: PolicyDecision.ALLOW,
|
||||
});
|
||||
|
||||
const responseHandler = vi.fn();
|
||||
messageBus.subscribe(
|
||||
@@ -68,7 +70,7 @@ describe('MessageBus', () => {
|
||||
correlationId: '123',
|
||||
};
|
||||
|
||||
messageBus.publish(request);
|
||||
await messageBus.publish(request);
|
||||
|
||||
const expectedResponse: ToolConfirmationResponse = {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
@@ -78,8 +80,10 @@ describe('MessageBus', () => {
|
||||
expect(responseHandler).toHaveBeenCalledWith(expectedResponse);
|
||||
});
|
||||
|
||||
it('should emit rejection and response when policy denies', () => {
|
||||
vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.DENY);
|
||||
it('should emit rejection and response when policy denies', async () => {
|
||||
vi.spyOn(policyEngine, 'check').mockResolvedValue({
|
||||
decision: PolicyDecision.DENY,
|
||||
});
|
||||
|
||||
const responseHandler = vi.fn();
|
||||
const rejectionHandler = vi.fn();
|
||||
@@ -98,7 +102,7 @@ describe('MessageBus', () => {
|
||||
correlationId: '123',
|
||||
};
|
||||
|
||||
messageBus.publish(request);
|
||||
await messageBus.publish(request);
|
||||
|
||||
const expectedRejection: ToolPolicyRejection = {
|
||||
type: MessageBusType.TOOL_POLICY_REJECTION,
|
||||
@@ -114,8 +118,10 @@ describe('MessageBus', () => {
|
||||
expect(responseHandler).toHaveBeenCalledWith(expectedResponse);
|
||||
});
|
||||
|
||||
it('should pass through to UI when policy says ASK_USER', () => {
|
||||
vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ASK_USER);
|
||||
it('should pass through to UI when policy says ASK_USER', async () => {
|
||||
vi.spyOn(policyEngine, 'check').mockResolvedValue({
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
});
|
||||
|
||||
const requestHandler = vi.fn();
|
||||
messageBus.subscribe(
|
||||
@@ -129,12 +135,12 @@ describe('MessageBus', () => {
|
||||
correlationId: '123',
|
||||
};
|
||||
|
||||
messageBus.publish(request);
|
||||
await messageBus.publish(request);
|
||||
|
||||
expect(requestHandler).toHaveBeenCalledWith(request);
|
||||
});
|
||||
|
||||
it('should emit other message types directly', () => {
|
||||
it('should emit other message types directly', async () => {
|
||||
const successHandler = vi.fn();
|
||||
messageBus.subscribe(
|
||||
MessageBusType.TOOL_EXECUTION_SUCCESS,
|
||||
@@ -147,14 +153,14 @@ describe('MessageBus', () => {
|
||||
result: 'success',
|
||||
};
|
||||
|
||||
messageBus.publish(message);
|
||||
await messageBus.publish(message);
|
||||
|
||||
expect(successHandler).toHaveBeenCalledWith(message);
|
||||
});
|
||||
});
|
||||
|
||||
describe('subscribe/unsubscribe', () => {
|
||||
it('should allow subscribing to specific message types', () => {
|
||||
it('should allow subscribing to specific message types', async () => {
|
||||
const handler = vi.fn();
|
||||
messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler);
|
||||
|
||||
@@ -164,12 +170,12 @@ describe('MessageBus', () => {
|
||||
result: 'test',
|
||||
};
|
||||
|
||||
messageBus.publish(message);
|
||||
await messageBus.publish(message);
|
||||
|
||||
expect(handler).toHaveBeenCalledWith(message);
|
||||
});
|
||||
|
||||
it('should allow unsubscribing from message types', () => {
|
||||
it('should allow unsubscribing from message types', async () => {
|
||||
const handler = vi.fn();
|
||||
messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler);
|
||||
messageBus.unsubscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler);
|
||||
@@ -180,12 +186,12 @@ describe('MessageBus', () => {
|
||||
result: 'test',
|
||||
};
|
||||
|
||||
messageBus.publish(message);
|
||||
await messageBus.publish(message);
|
||||
|
||||
expect(handler).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should support multiple subscribers for the same message type', () => {
|
||||
it('should support multiple subscribers for the same message type', async () => {
|
||||
const handler1 = vi.fn();
|
||||
const handler2 = vi.fn();
|
||||
|
||||
@@ -198,7 +204,7 @@ describe('MessageBus', () => {
|
||||
result: 'test',
|
||||
};
|
||||
|
||||
messageBus.publish(message);
|
||||
await messageBus.publish(message);
|
||||
|
||||
expect(handler1).toHaveBeenCalledWith(message);
|
||||
expect(handler2).toHaveBeenCalledWith(message);
|
||||
@@ -206,12 +212,12 @@ describe('MessageBus', () => {
|
||||
});
|
||||
|
||||
describe('error handling', () => {
|
||||
it('should not crash on errors during message processing', () => {
|
||||
it('should not crash on errors during message processing', async () => {
|
||||
const errorHandler = vi.fn();
|
||||
messageBus.on('error', errorHandler);
|
||||
|
||||
// Mock policyEngine to throw an error
|
||||
vi.spyOn(policyEngine, 'check').mockImplementation(() => {
|
||||
vi.spyOn(policyEngine, 'check').mockImplementation(async () => {
|
||||
throw new Error('Policy check failed');
|
||||
});
|
||||
|
||||
@@ -222,7 +228,7 @@ describe('MessageBus', () => {
|
||||
};
|
||||
|
||||
// Should not throw
|
||||
expect(() => messageBus.publish(request)).not.toThrow();
|
||||
await expect(messageBus.publish(request)).resolves.not.toThrow();
|
||||
|
||||
// Should emit error
|
||||
expect(errorHandler).toHaveBeenCalledWith(
|
||||
|
||||
@@ -38,7 +38,7 @@ export class MessageBus extends EventEmitter {
|
||||
this.emit(message.type, message);
|
||||
}
|
||||
|
||||
publish(message: Message): void {
|
||||
async publish(message: Message): Promise<void> {
|
||||
if (this.debug) {
|
||||
console.debug(`[MESSAGE_BUS] publish: ${safeJsonStringify(message)}`);
|
||||
}
|
||||
@@ -50,7 +50,7 @@ export class MessageBus extends EventEmitter {
|
||||
}
|
||||
|
||||
if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
|
||||
const decision = this.policyEngine.check(
|
||||
const { decision } = await this.policyEngine.check(
|
||||
message.toolCall,
|
||||
message.serverName,
|
||||
);
|
||||
|
||||
@@ -9,7 +9,7 @@ import { describe, it, expect, vi, afterEach, beforeEach } from 'vitest';
|
||||
import nodePath from 'node:path';
|
||||
|
||||
import type { PolicySettings } from './types.js';
|
||||
import { ApprovalMode, PolicyDecision } from './types.js';
|
||||
import { ApprovalMode, PolicyDecision, InProcessCheckerType } from './types.js';
|
||||
|
||||
import { Storage } from '../config/storage.js';
|
||||
|
||||
@@ -642,6 +642,194 @@ priority = 150
|
||||
vi.doUnmock('node:fs/promises');
|
||||
});
|
||||
|
||||
it('should load safety_checker configuration from TOML', async () => {
  const actualFs =
    await vi.importActual<typeof import('node:fs/promises')>(
      'node:fs/promises',
    );

  // Pretend the user policy directory contains exactly one file,
  // `safety.toml`; every other directory falls through to the real fs.
  const mockReaddir = vi.fn(
    async (
      path: string | Buffer | URL,
      options?: Parameters<typeof actualFs.readdir>[1],
    ) => {
      if (
        typeof path === 'string' &&
        nodePath
          .normalize(path)
          .includes(nodePath.normalize('.gemini/policies'))
      ) {
        return [
          {
            name: 'safety.toml',
            isFile: () => true,
            isDirectory: () => false,
          },
        ] as unknown as Awaited<ReturnType<typeof actualFs.readdir>>;
      }
      return actualFs.readdir(
        path,
        options as Parameters<typeof actualFs.readdir>[1],
      );
    },
  );

  // Serve an in-memory policy file declaring one rule and one in-process
  // safety checker for `write_file`.
  const mockReadFile = vi.fn(
    async (
      path: Parameters<typeof actualFs.readFile>[0],
      options: Parameters<typeof actualFs.readFile>[1],
    ) => {
      if (
        typeof path === 'string' &&
        nodePath
          .normalize(path)
          .includes(nodePath.normalize('.gemini/policies/safety.toml'))
      ) {
        return `
[[rule]]
toolName = "write_file"
decision = "allow"
priority = 10

[[safety_checker]]
toolName = "write_file"
priority = 10
[safety_checker.checker]
type = "in-process"
name = "allowed-path"
required_context = ["environment"]
[safety_checker.checker.config]
`;
      }
      return actualFs.readFile(path, options);
    },
  );

  vi.doMock('node:fs/promises', () => ({
    ...actualFs,
    default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir },
    readFile: mockReadFile,
    readdir: mockReaddir,
  }));

  try {
    // Re-import so the config module picks up the mocked fs.
    vi.resetModules();
    const { createPolicyEngineConfig } = await import('./config.js');

    const settings: PolicySettings = {};
    const config = await createPolicyEngineConfig(
      settings,
      ApprovalMode.DEFAULT,
      '/tmp/mock/default/policies',
    );

    const rule = config.rules?.find(
      (r) => r.toolName === 'write_file' && r.decision === PolicyDecision.ALLOW,
    );
    expect(rule).toBeDefined();

    const checker = config.checkers?.find(
      (c) => c.toolName === 'write_file' && c.checker.type === 'in-process',
    );
    expect(checker).toBeDefined();
    expect(checker?.checker.type).toBe('in-process');
    expect(checker?.checker.name).toBe(InProcessCheckerType.ALLOWED_PATH);
    expect(checker?.checker.required_context).toEqual(['environment']);
  } finally {
    // Always remove the mock, even when an assertion above fails, so it
    // cannot leak into later tests.
    vi.doUnmock('node:fs/promises');
  }
});
|
||||
|
||||
it('should reject invalid in-process checker names', async () => {
  const actualFs =
    await vi.importActual<typeof import('node:fs/promises')>(
      'node:fs/promises',
    );

  // Pretend the user policy directory contains exactly one file,
  // `invalid_safety.toml`; every other directory falls through to the real fs.
  const mockReaddir = vi.fn(
    async (
      path: string | Buffer | URL,
      options?: Parameters<typeof actualFs.readdir>[1],
    ) => {
      if (
        typeof path === 'string' &&
        nodePath
          .normalize(path)
          .includes(nodePath.normalize('.gemini/policies'))
      ) {
        return [
          {
            name: 'invalid_safety.toml',
            isFile: () => true,
            isDirectory: () => false,
          },
        ] as unknown as Awaited<ReturnType<typeof actualFs.readdir>>;
      }
      return actualFs.readdir(
        path,
        options as Parameters<typeof actualFs.readdir>[1],
      );
    },
  );

  // Serve a policy file whose checker name is not a member of
  // InProcessCheckerType.
  const mockReadFile = vi.fn(
    async (
      path: Parameters<typeof actualFs.readFile>[0],
      options: Parameters<typeof actualFs.readFile>[1],
    ) => {
      if (
        typeof path === 'string' &&
        nodePath
          .normalize(path)
          .includes(
            nodePath.normalize('.gemini/policies/invalid_safety.toml'),
          )
      ) {
        return `
[[rule]]
toolName = "write_file"
decision = "allow"
priority = 10

[[safety_checker]]
toolName = "write_file"
priority = 10
[safety_checker.checker]
type = "in-process"
name = "invalid-name"
`;
      }
      return actualFs.readFile(path, options);
    },
  );

  vi.doMock('node:fs/promises', () => ({
    ...actualFs,
    default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir },
    readFile: mockReadFile,
    readdir: mockReaddir,
  }));

  try {
    // Re-import so the config module picks up the mocked fs.
    vi.resetModules();
    const { createPolicyEngineConfig } = await import('./config.js');

    const settings: PolicySettings = {};
    const config = await createPolicyEngineConfig(
      settings,
      ApprovalMode.DEFAULT,
      '/tmp/mock/default/policies',
    );

    // 'invalid-name' is not in the InProcessCheckerType enum, so the file is
    // expected to fail validation: its otherwise-valid rule must not be
    // loaded, and neither must the malformed checker.
    const rule = config.rules?.find((r) => r.toolName === 'write_file');
    expect(rule).toBeUndefined();

    const checker = config.checkers?.find((c) => c.toolName === 'write_file');
    expect(checker).toBeUndefined();
  } finally {
    // Always remove the mock, even when an assertion above fails, so it
    // cannot leak into later tests.
    vi.doUnmock('node:fs/promises');
  }
});
|
||||
|
||||
it('should have default ASK_USER rule for discovered tools', async () => {
|
||||
vi.resetModules();
|
||||
vi.doUnmock('node:fs/promises');
|
||||
|
||||
@@ -114,10 +114,12 @@ export async function createPolicyEngineConfig(
|
||||
const policyDirs = getPolicyDirectories(defaultPoliciesDir);
|
||||
|
||||
// Load policies from TOML files
|
||||
const { rules: tomlRules, errors } = await loadPoliciesFromToml(
|
||||
approvalMode,
|
||||
policyDirs,
|
||||
(dir) => getPolicyTier(dir, defaultPoliciesDir),
|
||||
const {
|
||||
rules: tomlRules,
|
||||
checkers: tomlCheckers,
|
||||
errors,
|
||||
} = await loadPoliciesFromToml(approvalMode, policyDirs, (dir) =>
|
||||
getPolicyTier(dir, defaultPoliciesDir),
|
||||
);
|
||||
|
||||
// Emit any errors encountered during TOML loading to the UI
|
||||
@@ -129,6 +131,7 @@ export async function createPolicyEngineConfig(
|
||||
}
|
||||
|
||||
const rules: PolicyRule[] = [...tomlRules];
|
||||
const checkers = [...tomlCheckers];
|
||||
|
||||
// Priority system for policy rules:
|
||||
// - Higher priority numbers win over lower priority numbers
|
||||
@@ -225,6 +228,7 @@ export async function createPolicyEngineConfig(
|
||||
|
||||
return {
|
||||
rules,
|
||||
checkers,
|
||||
defaultDecision: PolicyDecision.ASK_USER,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -36,6 +36,11 @@ decision = "allow"
|
||||
priority = 15
|
||||
modes = ["autoEdit"]
|
||||
|
||||
[rule.safety_checker]
|
||||
type = "in-process"
|
||||
name = "allowed-path"
|
||||
required_context = ["environment"]
|
||||
|
||||
[[rule]]
|
||||
toolName = "save_memory"
|
||||
decision = "ask_user"
|
||||
@@ -57,6 +62,11 @@ decision = "allow"
|
||||
priority = 15
|
||||
modes = ["autoEdit"]
|
||||
|
||||
[rule.safety_checker]
|
||||
type = "in-process"
|
||||
name = "allowed-path"
|
||||
required_context = ["environment"]
|
||||
|
||||
[[rule]]
|
||||
toolName = "web_fetch"
|
||||
decision = "ask_user"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,12 +9,15 @@ import {
|
||||
PolicyDecision,
|
||||
type PolicyEngineConfig,
|
||||
type PolicyRule,
|
||||
type SafetyCheckerRule,
|
||||
} from './types.js';
|
||||
import { stableStringify } from './stable-stringify.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { CheckerRunner } from '../safety/checker-runner.js';
|
||||
import { SafetyCheckDecision } from '../safety/protocol.js';
|
||||
|
||||
function ruleMatches(
|
||||
rule: PolicyRule,
|
||||
rule: PolicyRule | SafetyCheckerRule,
|
||||
toolCall: FunctionCall,
|
||||
stringifiedArgs: string | undefined,
|
||||
serverName: string | undefined,
|
||||
@@ -60,27 +63,41 @@ function ruleMatches(
|
||||
|
||||
export class PolicyEngine {
|
||||
private rules: PolicyRule[];
|
||||
private checkers: SafetyCheckerRule[];
|
||||
private readonly defaultDecision: PolicyDecision;
|
||||
private readonly nonInteractive: boolean;
|
||||
private readonly checkerRunner?: CheckerRunner;
|
||||
|
||||
/**
 * Build a policy engine from a configuration.
 *
 * Rules and safety checkers are stored sorted by descending priority
 * (missing priority counts as 0). The arrays are copied before sorting so
 * the caller's `config.rules` / `config.checkers` are not reordered.
 *
 * @param config Rules, safety checkers, default decision and
 *   non-interactive flag. Defaults: no rules, no checkers, ASK_USER,
 *   interactive.
 * @param checkerRunner Optional runner used to execute matching safety
 *   checkers; when omitted, `check` does not run any safety checker.
 */
constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) {
  const byDescendingPriority = (
    a: { priority?: number },
    b: { priority?: number },
  ): number => (b.priority ?? 0) - (a.priority ?? 0);

  this.rules = [...(config.rules ?? [])].sort(byDescendingPriority);
  this.checkers = [...(config.checkers ?? [])].sort(byDescendingPriority);
  this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER;
  this.nonInteractive = config.nonInteractive ?? false;
  this.checkerRunner = checkerRunner;
}
|
||||
|
||||
/**
|
||||
* Check if a tool call is allowed based on the configured policies.
|
||||
* Returns the decision and the matching rule (if any).
|
||||
*/
|
||||
check(
|
||||
async check(
|
||||
toolCall: FunctionCall,
|
||||
serverName: string | undefined,
|
||||
): PolicyDecision {
|
||||
): Promise<{
|
||||
decision: PolicyDecision;
|
||||
rule?: PolicyRule;
|
||||
}> {
|
||||
let stringifiedArgs: string | undefined;
|
||||
// Compute stringified args once before the loop
|
||||
if (toolCall.args && this.rules.some((rule) => rule.argsPattern)) {
|
||||
if (
|
||||
toolCall.args &&
|
||||
(this.rules.some((rule) => rule.argsPattern) ||
|
||||
this.checkers.some((checker) => checker.argsPattern))
|
||||
) {
|
||||
stringifiedArgs = stableStringify(toolCall.args);
|
||||
}
|
||||
|
||||
@@ -89,20 +106,72 @@ export class PolicyEngine {
|
||||
);
|
||||
|
||||
// Find the first matching rule (already sorted by priority)
|
||||
let matchedRule: PolicyRule | undefined;
|
||||
let decision: PolicyDecision | undefined;
|
||||
|
||||
for (const rule of this.rules) {
|
||||
if (ruleMatches(rule, toolCall, stringifiedArgs, serverName)) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`,
|
||||
);
|
||||
return this.applyNonInteractiveMode(rule.decision);
|
||||
matchedRule = rule;
|
||||
decision = this.applyNonInteractiveMode(rule.decision);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// No matching rule found, use default decision
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] NO MATCH - using default decision: ${this.defaultDecision}`,
|
||||
);
|
||||
return this.applyNonInteractiveMode(this.defaultDecision);
|
||||
if (!decision) {
|
||||
// No matching rule found, use default decision
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] NO MATCH - using default decision: ${this.defaultDecision}`,
|
||||
);
|
||||
decision = this.applyNonInteractiveMode(this.defaultDecision);
|
||||
}
|
||||
|
||||
// If decision is not DENY, run safety checkers
|
||||
if (decision !== PolicyDecision.DENY && this.checkerRunner) {
|
||||
for (const checkerRule of this.checkers) {
|
||||
if (ruleMatches(checkerRule, toolCall, stringifiedArgs, serverName)) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] Running safety checker: ${checkerRule.checker.name}`,
|
||||
);
|
||||
try {
|
||||
const result = await this.checkerRunner.runChecker(
|
||||
toolCall,
|
||||
checkerRule.checker,
|
||||
);
|
||||
|
||||
if (result.decision === SafetyCheckDecision.DENY) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] Safety checker denied: ${result.reason}`,
|
||||
);
|
||||
return {
|
||||
decision: PolicyDecision.DENY,
|
||||
rule: matchedRule,
|
||||
};
|
||||
} else if (result.decision === SafetyCheckDecision.ASK_USER) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] Safety checker requested ASK_USER: ${result.reason}`,
|
||||
);
|
||||
decision = PolicyDecision.ASK_USER;
|
||||
}
|
||||
} catch (error) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] Safety checker failed: ${error}`,
|
||||
);
|
||||
return {
|
||||
decision: PolicyDecision.DENY,
|
||||
rule: matchedRule,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
decision: this.applyNonInteractiveMode(decision),
|
||||
rule: matchedRule,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -114,6 +183,11 @@ export class PolicyEngine {
|
||||
this.rules.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
|
||||
}
|
||||
|
||||
/**
 * Register an additional safety checker rule.
 *
 * The checker list is re-sorted by descending priority afterwards
 * (missing priority counts as 0), so higher-priority checkers run first.
 *
 * @param checker The safety checker rule to add.
 */
addChecker(checker: SafetyCheckerRule): void {
  this.checkers.push(checker);
  this.checkers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
}
|
||||
|
||||
/**
|
||||
* Remove rules for a specific tool.
|
||||
*/
|
||||
@@ -128,6 +202,10 @@ export class PolicyEngine {
|
||||
return this.rules;
|
||||
}
|
||||
|
||||
/**
 * Get the registered safety checker rules.
 *
 * @returns The engine's own checker list, typed read-only; it is kept in
 *   descending priority order by the constructor and `addChecker`.
 */
getCheckers(): readonly SafetyCheckerRule[] {
  return this.checkers;
}
|
||||
|
||||
private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision {
|
||||
// In non-interactive mode, ASK_USER becomes DENY
|
||||
if (this.nonInteractive && decision === PolicyDecision.ASK_USER) {
|
||||
|
||||
@@ -85,6 +85,7 @@ priority = 100
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 1.1, // tier 1 + 100/1000
|
||||
});
|
||||
expect(result.checkers).toHaveLength(0);
|
||||
expect(result.errors).toHaveLength(0);
|
||||
});
|
||||
|
||||
|
||||
@@ -4,7 +4,14 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type PolicyRule, PolicyDecision, type ApprovalMode } from './types.js';
|
||||
import {
|
||||
type PolicyRule,
|
||||
PolicyDecision,
|
||||
type ApprovalMode,
|
||||
type SafetyCheckerConfig,
|
||||
type SafetyCheckerRule,
|
||||
InProcessCheckerType,
|
||||
} from './types.js';
|
||||
import fs from 'node:fs/promises';
|
||||
import path from 'node:path';
|
||||
import toml from '@iarna/toml';
|
||||
@@ -39,11 +46,39 @@ const PolicyRuleSchema = z.object({
|
||||
modes: z.array(z.string()).optional(),
|
||||
});
|
||||
|
||||
/**
 * Schema for a single safety checker rule (`[[safety_checker]]`) in the
 * TOML file.
 *
 * The matching fields (`toolName`, `mcpName`, `argsPattern`,
 * `commandPrefix`, `commandRegex`, `modes`) are all optional; `priority`
 * must be an integer and defaults to 0. The nested `checker` table is
 * discriminated on `type`:
 * - `in-process`: `name` must be a member of `InProcessCheckerType`.
 * - `external`: `name` may be any string.
 */
const SafetyCheckerRuleSchema = z.object({
  toolName: z.union([z.string(), z.array(z.string())]).optional(),
  mcpName: z.string().optional(),
  argsPattern: z.string().optional(),
  commandPrefix: z.union([z.string(), z.array(z.string())]).optional(),
  commandRegex: z.string().optional(),
  priority: z.number().int().default(0),
  modes: z.array(z.string()).optional(),
  checker: z.discriminatedUnion('type', [
    z.object({
      type: z.literal('in-process'),
      name: z.nativeEnum(InProcessCheckerType),
      required_context: z.array(z.string()).optional(),
      config: z.record(z.unknown()).optional(),
    }),
    z.object({
      type: z.literal('external'),
      name: z.string(),
      required_context: z.array(z.string()).optional(),
      config: z.record(z.unknown()).optional(),
    }),
  ]),
});
|
||||
|
||||
/**
 * Schema for the entire policy TOML file.
 *
 * Both top-level arrays are optional, so a file may declare only
 * `[[rule]]` entries, only `[[safety_checker]]` entries, or both.
 */
const PolicyFileSchema = z.object({
  rule: z.array(PolicyRuleSchema).optional(),
  safety_checker: z.array(SafetyCheckerRuleSchema).optional(),
});
|
||||
|
||||
/**
|
||||
@@ -80,6 +115,7 @@ export interface PolicyFileError {
|
||||
*/
|
||||
export interface PolicyLoadResult {
  /** Policy rules parsed from the TOML files. */
  rules: PolicyRule[];
  /** Safety checker rules parsed from the TOML files. */
  checkers: SafetyCheckerRule[];
  /** Errors collected while loading; loading continues past them. */
  errors: PolicyFileError[];
}
|
||||
|
||||
@@ -194,6 +230,7 @@ export async function loadPoliciesFromToml(
|
||||
getPolicyTier: (dir: string) => number,
|
||||
): Promise<PolicyLoadResult> {
|
||||
const rules: PolicyRule[] = [];
|
||||
const checkers: SafetyCheckerRule[] = [];
|
||||
const errors: PolicyFileError[] = [];
|
||||
|
||||
for (const dir of policyDirs) {
|
||||
@@ -267,8 +304,9 @@ export async function loadPoliciesFromToml(
|
||||
}
|
||||
|
||||
// Validate shell command convenience syntax
|
||||
for (let i = 0; i < validationResult.data.rule.length; i++) {
|
||||
const rule = validationResult.data.rule[i];
|
||||
const tomlRules = validationResult.data.rule ?? [];
|
||||
for (let i = 0; i < tomlRules.length; i++) {
|
||||
const rule = tomlRules[i];
|
||||
const validationError = validateShellCommandSyntax(rule, i);
|
||||
if (validationError) {
|
||||
errors.push({
|
||||
@@ -285,7 +323,7 @@ export async function loadPoliciesFromToml(
|
||||
}
|
||||
|
||||
// Transform rules
|
||||
const parsedRules: PolicyRule[] = validationResult.data.rule
|
||||
const parsedRules: PolicyRule[] = (validationResult.data.rule ?? [])
|
||||
.filter((rule) => {
|
||||
// Filter by mode
|
||||
if (!rule.modes || rule.modes.length === 0) {
|
||||
@@ -369,6 +407,84 @@ export async function loadPoliciesFromToml(
|
||||
.filter((rule): rule is PolicyRule => rule !== null);
|
||||
|
||||
rules.push(...parsedRules);
|
||||
|
||||
// Transform checkers
|
||||
const parsedCheckers: SafetyCheckerRule[] = (
|
||||
validationResult.data.safety_checker ?? []
|
||||
)
|
||||
.filter((checker) => {
|
||||
if (!checker.modes || checker.modes.length === 0) {
|
||||
return true;
|
||||
}
|
||||
return checker.modes.includes(approvalMode);
|
||||
})
|
||||
.flatMap((checker) => {
|
||||
let effectiveArgsPattern = checker.argsPattern;
|
||||
const commandPrefixes: string[] = [];
|
||||
|
||||
if (checker.commandPrefix) {
|
||||
const prefixes = Array.isArray(checker.commandPrefix)
|
||||
? checker.commandPrefix
|
||||
: [checker.commandPrefix];
|
||||
commandPrefixes.push(...prefixes);
|
||||
} else if (checker.commandRegex) {
|
||||
effectiveArgsPattern = `"command":"${checker.commandRegex}`;
|
||||
}
|
||||
|
||||
const argsPatterns: Array<string | undefined> =
|
||||
commandPrefixes.length > 0
|
||||
? commandPrefixes.map(
|
||||
(prefix) => `"command":"${escapeRegex(prefix)}`,
|
||||
)
|
||||
: [effectiveArgsPattern];
|
||||
|
||||
return argsPatterns.flatMap((argsPattern) => {
|
||||
const toolNames: Array<string | undefined> = checker.toolName
|
||||
? Array.isArray(checker.toolName)
|
||||
? checker.toolName
|
||||
: [checker.toolName]
|
||||
: [undefined];
|
||||
|
||||
return toolNames.map((toolName) => {
|
||||
let effectiveToolName: string | undefined;
|
||||
if (checker.mcpName && toolName) {
|
||||
effectiveToolName = `${checker.mcpName}__${toolName}`;
|
||||
} else if (checker.mcpName) {
|
||||
effectiveToolName = `${checker.mcpName}__*`;
|
||||
} else {
|
||||
effectiveToolName = toolName;
|
||||
}
|
||||
|
||||
const safetyCheckerRule: SafetyCheckerRule = {
|
||||
toolName: effectiveToolName,
|
||||
priority: checker.priority,
|
||||
checker: checker.checker as SafetyCheckerConfig,
|
||||
};
|
||||
|
||||
if (argsPattern) {
|
||||
try {
|
||||
safetyCheckerRule.argsPattern = new RegExp(argsPattern);
|
||||
} catch (e) {
|
||||
const error = e as Error;
|
||||
errors.push({
|
||||
filePath,
|
||||
fileName: file,
|
||||
tier: tierName,
|
||||
errorType: 'regex_compilation',
|
||||
message: 'Invalid regex pattern in safety checker',
|
||||
details: `Pattern: ${argsPattern}\nError: ${error.message}`,
|
||||
});
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return safetyCheckerRule;
|
||||
});
|
||||
});
|
||||
})
|
||||
.filter((checker): checker is SafetyCheckerRule => checker !== null);
|
||||
|
||||
checkers.push(...parsedCheckers);
|
||||
} catch (e) {
|
||||
const error = e as NodeJS.ErrnoException;
|
||||
// Catch-all for unexpected errors
|
||||
@@ -386,5 +502,5 @@ export async function loadPoliciesFromToml(
|
||||
}
|
||||
}
|
||||
|
||||
return { rules, errors };
|
||||
return { rules, checkers, errors };
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { SafetyCheckInput } from '../safety/protocol.js';
|
||||
|
||||
export enum PolicyDecision {
|
||||
ALLOW = 'allow',
|
||||
DENY = 'deny',
|
||||
@@ -16,6 +18,52 @@ export enum ApprovalMode {
|
||||
YOLO = 'yolo',
|
||||
}
|
||||
|
||||
/**
 * Configuration for the built-in allowed-path checker.
 */
export interface AllowedPathConfig {
  /**
   * Explicitly include argument keys to be checked as paths.
   */
  included_args?: string[];

  /**
   * Explicitly exclude argument keys from being checked as paths.
   */
  excluded_args?: string[];
}
|
||||
|
||||
/**
 * Base interface for external checkers.
 */
export interface ExternalCheckerConfig {
  /** Discriminant identifying an external checker. */
  type: 'external';
  /** Name of the external checker; any string is accepted. */
  name: string;
  /** Checker-specific configuration; its shape is not constrained here. */
  config?: unknown;
  /** Keys of `SafetyCheckInput['context']` that the checker requires. */
  required_context?: Array<keyof SafetyCheckInput['context']>;
}
|
||||
|
||||
/**
 * Names of the in-process checkers. The string values are the names used
 * in policy files (e.g. `name = "allowed-path"`).
 */
export enum InProcessCheckerType {
  ALLOWED_PATH = 'allowed-path',
}
|
||||
|
||||
/**
 * Base interface for in-process checkers.
 */
export interface InProcessCheckerConfig {
  /** Discriminant identifying an in-process checker. */
  type: 'in-process';
  /** Which in-process checker to run. */
  name: InProcessCheckerType;
  /** Optional configuration for the allowed-path checker. */
  config?: AllowedPathConfig;
  /** Keys of `SafetyCheckInput['context']` that the checker requires. */
  required_context?: Array<keyof SafetyCheckInput['context']>;
}
|
||||
|
||||
/**
 * A discriminated union for all safety checker configurations.
 * Narrow on the `type` field ('external' | 'in-process').
 */
export type SafetyCheckerConfig =
  | ExternalCheckerConfig
  | InProcessCheckerConfig;
|
||||
|
||||
export interface PolicyRule {
|
||||
/**
|
||||
* The name of the tool this rule applies to.
|
||||
@@ -41,12 +89,43 @@ export interface PolicyRule {
|
||||
priority?: number;
|
||||
}
|
||||
|
||||
/**
 * Associates a safety checker with the tool calls it should run on.
 */
export interface SafetyCheckerRule {
  /**
   * The name of the tool this rule applies to.
   * If undefined, the rule applies to all tools.
   */
  toolName?: string;

  /**
   * Pattern to match against tool arguments.
   * Can be used for more fine-grained control.
   */
  argsPattern?: RegExp;

  /**
   * Priority of this checker. Higher numbers run first.
   * Default is 0.
   */
  priority?: number;

  /**
   * Specifies an external or built-in safety checker to execute for
   * additional validation of a tool call.
   */
  checker: SafetyCheckerConfig;
}
|
||||
|
||||
export interface PolicyEngineConfig {
|
||||
/**
|
||||
* List of policy rules to apply.
|
||||
*/
|
||||
rules?: PolicyRule[];
|
||||
|
||||
/**
|
||||
* List of safety checkers to apply.
|
||||
*/
|
||||
checkers?: SafetyCheckerRule[];
|
||||
|
||||
/**
|
||||
* Default decision when no rules match.
|
||||
* Defaults to ASK_USER.
|
||||
|
||||
@@ -0,0 +1,244 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as os from 'node:os';
|
||||
import * as path from 'node:path';
|
||||
import { AllowedPathChecker } from './built-in.js';
|
||||
import type { SafetyCheckInput } from './protocol.js';
|
||||
import { SafetyCheckDecision } from './protocol.js';
|
||||
import type { FunctionCall } from '@google/genai';
|
||||
|
||||
describe('AllowedPathChecker', () => {
  let checker: AllowedPathChecker;
  let testRootDir: string;
  let mockCwd: string;
  let mockWorkspaces: string[];

  beforeEach(async () => {
    checker = new AllowedPathChecker();
    // Every test gets its own throw-away directory tree:
    //   <tmp>/home/user/project        (cwd, workspace #0)
    //   <tmp>/home/user/other-project  (workspace #1)
    testRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'safety-test-'));
    mockCwd = path.join(testRootDir, 'home', 'user', 'project');
    await fs.mkdir(mockCwd, { recursive: true });
    mockWorkspaces = [
      mockCwd,
      path.join(testRootDir, 'home', 'user', 'other-project'),
    ];
    await fs.mkdir(mockWorkspaces[1], { recursive: true });
  });

  afterEach(async () => {
    await fs.rm(testRootDir, { recursive: true, force: true });
  });

  /** Builds a checker input for `test_tool` with the given args and checker config. */
  const createInput = (
    toolArgs: Record<string, unknown>,
    config?: Record<string, unknown>,
  ): SafetyCheckInput => ({
    protocolVersion: '1.0.0',
    toolCall: {
      name: 'test_tool',
      args: toolArgs,
    } as unknown as FunctionCall,
    context: {
      environment: {
        cwd: mockCwd,
        workspaces: mockWorkspaces,
      },
    },
    config,
  });

  /**
   * Creates `<testRootDir>/etc/passwd`, a real file that lives outside every
   * allowed directory, and returns its absolute path.
   */
  const createOutsideFile = async (): Promise<string> => {
    const outsidePath = path.join(testRootDir, 'etc', 'passwd');
    await fs.mkdir(path.dirname(outsidePath), { recursive: true });
    await fs.writeFile(outsidePath, 'secret');
    return outsidePath;
  };

  it('should allow paths within CWD', async () => {
    const filePath = path.join(mockCwd, 'file.txt');
    await fs.writeFile(filePath, 'test content');
    const input = createInput({
      path: filePath,
    });
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
  });

  it('should allow paths within workspace roots', async () => {
    const filePath = path.join(mockWorkspaces[1], 'data.json');
    await fs.writeFile(filePath, 'test content');
    const input = createInput({
      path: filePath,
    });
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
  });

  it('should deny paths outside allowed areas', async () => {
    const outsidePath = await createOutsideFile();
    const input = createInput({ path: outsidePath });
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain('outside of the allowed workspace');
  });

  it('should deny paths using ../ to escape', async () => {
    const secretPath = path.join(testRootDir, 'home', 'user', 'secret.txt');
    await fs.writeFile(secretPath, 'secret');
    const input = createInput({
      path: path.join(mockCwd, '..', 'secret.txt'),
    });
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.DENY);
  });

  it('should check multiple path arguments', async () => {
    const passwdPath = await createOutsideFile();
    const srcPath = path.join(mockCwd, 'src.txt');
    await fs.writeFile(srcPath, 'source content');

    const input = createInput({
      source: srcPath,
      destination: passwdPath,
    });
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain(passwdPath);
  });

  it('should handle non-existent paths gracefully if they are inside allowed dir', async () => {
    const input = createInput({
      path: path.join(mockCwd, 'new-file.txt'),
    });
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
  });

  it('should deny access if path contains a symlink pointing outside allowed directories', async () => {
    const symlinkPath = path.join(mockCwd, 'symlink');
    const targetPath = await createOutsideFile();

    // Create symlink: mockCwd/symlink -> targetPath
    await fs.symlink(targetPath, symlinkPath);

    const input = createInput({ path: symlinkPath });
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain(
      'outside of the allowed workspace directories',
    );
  });

  it('should allow access if path contains a symlink pointing INSIDE allowed directories', async () => {
    const symlinkPath = path.join(mockCwd, 'symlink-inside');
    const realFilePath = path.join(mockCwd, 'real-file');
    await fs.writeFile(realFilePath, 'real content');

    // Create symlink: mockCwd/symlink-inside -> mockCwd/real-file
    await fs.symlink(realFilePath, symlinkPath);

    const input = createInput({ path: symlinkPath });
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
  });

  it('should check explicitly included arguments', async () => {
    const outsidePath = await createOutsideFile();
    const input = createInput(
      { custom_arg: outsidePath },
      { included_args: ['custom_arg'] },
    );
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain('outside of the allowed workspace');
  });

  it('should skip explicitly excluded arguments', async () => {
    const outsidePath = await createOutsideFile();
    // Normally 'path' would be checked, but we exclude it
    const input = createInput(
      { path: outsidePath },
      { excluded_args: ['path'] },
    );
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
  });

  it('should handle both included and excluded arguments', async () => {
    const outsidePath = await createOutsideFile();
    const input = createInput(
      {
        path: outsidePath, // Excluded
        custom_arg: outsidePath, // Included
      },
      {
        excluded_args: ['path'],
        included_args: ['custom_arg'],
      },
    );
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain(outsidePath);
    // Both arguments carry the same path, so only the argument name in the
    // reason proves the denial came from custom_arg and not from path.
    expect(result.reason).toContain('"custom_arg"');
    expect(result.reason).not.toContain('argument "path"');
  });

  it('should check nested path arguments', async () => {
    const outsidePath = await createOutsideFile();
    const input = createInput({
      nested: {
        path: outsidePath,
      },
    });
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain(outsidePath);
    expect(result.reason).toContain('nested.path');
  });

  it('should support dot notation for included_args', async () => {
    const outsidePath = await createOutsideFile();
    const input = createInput(
      {
        nested: {
          custom: outsidePath,
        },
      },
      { included_args: ['nested.custom'] },
    );
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain(outsidePath);
    expect(result.reason).toContain('nested.custom');
  });

  it('should support dot notation for excluded_args', async () => {
    const outsidePath = await createOutsideFile();
    const input = createInput(
      {
        nested: {
          path: outsidePath,
        },
      },
      { excluded_args: ['nested.path'] },
    );
    const result = await checker.check(input);
    expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
  });
});
|
||||
@@ -0,0 +1,154 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as path from 'node:path';
|
||||
import * as fs from 'node:fs';
|
||||
import type { SafetyCheckInput, SafetyCheckResult } from './protocol.js';
|
||||
import { SafetyCheckDecision } from './protocol.js';
|
||||
import type { AllowedPathConfig } from '../policy/types.js';
|
||||
|
||||
/**
 * Interface for all in-process safety checkers.
 *
 * An implementation receives the full check payload for one tool call and
 * answers asynchronously with a decision.
 */
export interface InProcessChecker {
  /**
   * Evaluates a single tool call.
   *
   * @param input - The check payload (tool call, context and optional
   *   checker-specific config).
   * @returns The checker's decision, optionally with a reason.
   */
  check(input: SafetyCheckInput): Promise<SafetyCheckResult>;
}
|
||||
|
||||
/**
 * An in-process checker to validate file paths.
 *
 * Every string argument of the tool call that looks like a path (or is
 * explicitly listed in `included_args`) is canonicalised, following symlinks,
 * and must end up inside the current working directory or one of the
 * workspace directories. The first offending argument produces a DENY.
 */
export class AllowedPathChecker implements InProcessChecker {
  /** Upper bound on chained dangling symlinks followed while canonicalising. */
  private static readonly MAX_SYMLINK_HOPS = 32;

  /**
   * Checks all path-like arguments of `input.toolCall`.
   *
   * @param input - Check payload; `input.config` may carry `included_args`
   *   and `excluded_args` (dot-notation argument names).
   * @returns ALLOW when every collected path is inside an allowed directory,
   *   otherwise DENY with a reason naming the path and the argument.
   */
  async check(input: SafetyCheckInput): Promise<SafetyCheckResult> {
    const { toolCall, context } = input;
    const config = input.config as AllowedPathConfig | undefined;
    const { cwd, workspaces } = context.environment;

    // Find all arguments that look like paths
    const pathsToCheck = this.collectPathsToCheck(
      toolCall.args,
      config?.included_args ?? [],
      config?.excluded_args ?? [],
    );
    if (pathsToCheck.length === 0) {
      return { decision: SafetyCheckDecision.ALLOW };
    }

    // Canonicalise the allowed directories once (they may be symlinks
    // themselves). Directories that cannot be resolved never match.
    const allowedDirs = [cwd, ...workspaces]
      .map((dir) => this.safelyResolvePath(dir, cwd))
      .filter((dir): dir is string => dir !== null);

    for (const { path: p, argName } of pathsToCheck) {
      const resolvedPath = this.safelyResolvePath(p, cwd);

      if (resolvedPath === null) {
        // If path cannot be resolved, deny it
        return {
          decision: SafetyCheckDecision.DENY,
          reason: `Cannot resolve path "${p}" in argument "${argName}"`,
        };
      }

      const isAllowed = allowedDirs.some((dir) =>
        this.isPathAllowed(resolvedPath, dir),
      );

      if (!isAllowed) {
        return {
          decision: SafetyCheckDecision.DENY,
          reason: `Path "${p}" in argument "${argName}" is outside of the allowed workspace directories.`,
        };
      }
    }

    return { decision: SafetyCheckDecision.ALLOW };
  }

  /**
   * Resolves `inputPath` against `cwd` and canonicalises it.
   *
   * @returns The canonical absolute path, or `null` when resolution fails
   *   (for example a symlink loop or a realpath error).
   */
  private safelyResolvePath(inputPath: string, cwd: string): string | null {
    try {
      return this.canonicalize(path.resolve(cwd, inputPath), 0);
    } catch (_error) {
      return null;
    }
  }

  /**
   * Canonicalises an absolute path that may not exist yet.
   *
   * Walks up the directory tree to the nearest entry present on disk, takes
   * its real path and re-appends the missing tail. A dangling symlink counts
   * as present: it is followed manually, because a write through it would
   * create the file at the link target, not next to the link.
   */
  private canonicalize(resolved: string, hops: number): string {
    if (hops > AllowedPathChecker.MAX_SYMLINK_HOPS) {
      throw new Error(`Too many levels of symbolic links: ${resolved}`);
    }

    let current = resolved;
    // Stop at root (dirname(root) === root)
    while (current && current !== path.dirname(current)) {
      const stats = this.lstatIfPresent(current);
      if (stats) {
        // Part of the requested path that lies below `current`.
        const remainder = path.relative(current, resolved);

        if (stats.isSymbolicLink() && !fs.existsSync(current)) {
          // Dangling symlink: realpathSync would throw, so resolve one hop by
          // hand relative to the canonical parent and continue from there.
          const parent = fs.realpathSync(path.dirname(current));
          const target = path.resolve(parent, fs.readlinkSync(current));
          return path.join(this.canonicalize(target, hops + 1), remainder);
        }

        // path.join handles an empty remainder correctly (returns canonical)
        return path.join(fs.realpathSync(current), remainder);
      }
      current = path.dirname(current);
    }

    // Fallback if nothing exists (unlikely if root exists)
    return resolved;
  }

  /** `lstat` that reports any failure (missing entry, ENOTDIR, EACCES) as "absent". */
  private lstatIfPresent(target: string): fs.Stats | undefined {
    try {
      return fs.lstatSync(target);
    } catch (_error) {
      return undefined;
    }
  }

  /**
   * Returns true when `targetPath` is `allowedDir` itself or lies below it.
   * Both paths are expected to be canonical absolute paths.
   */
  private isPathAllowed(targetPath: string, allowedDir: string): boolean {
    const relative = path.relative(allowedDir, targetPath);
    if (relative === '') {
      return true;
    }
    if (path.isAbsolute(relative)) {
      // No relative route exists (e.g. a different drive on Windows).
      return false;
    }
    // Only a leading ".." *segment* escapes; names such as "..notes.txt" are
    // ordinary entries inside the directory.
    return relative !== '..' && !relative.startsWith(`..${path.sep}`);
  }

  /** Heuristic: does this argument name suggest that its value is a path? */
  private looksLikePathKey(key: string): boolean {
    const lowered = key.toLowerCase();
    return (
      lowered.includes('path') ||
      lowered.includes('directory') ||
      lowered.includes('file') ||
      lowered === 'source' ||
      lowered === 'destination'
    );
  }

  /**
   * Recursively collects the string arguments that must be validated.
   *
   * A string is collected when its dot-notation name is in `includedArgs` or
   * its key looks like a path (case-insensitive). String elements of an array
   * inherit that property from the array's own key and are named
   * `<array>.<index>`. Anything whose dot-notation name is in `excludedArgs`
   * is skipped together with its children.
   */
  private collectPathsToCheck(
    args: unknown,
    includedArgs: string[],
    excludedArgs: string[],
    prefix = '',
  ): Array<{ path: string; argName: string }> {
    const paths: Array<{ path: string; argName: string }> = [];

    if (typeof args !== 'object' || args === null) {
      return paths;
    }

    for (const [key, value] of Object.entries(args)) {
      const fullKey = prefix ? `${prefix}.${key}` : key;

      if (excludedArgs.includes(fullKey)) {
        continue;
      }

      const isPathArg =
        includedArgs.includes(fullKey) || this.looksLikePathKey(key);

      if (typeof value === 'string') {
        if (isPathArg) {
          paths.push({ path: value, argName: fullKey });
        }
      } else if (Array.isArray(value)) {
        value.forEach((element: unknown, index: number) => {
          const elementKey = `${fullKey}.${index}`;
          if (excludedArgs.includes(elementKey)) {
            return;
          }
          if (typeof element === 'string') {
            if (isPathArg || includedArgs.includes(elementKey)) {
              paths.push({ path: element, argName: elementKey });
            }
          } else {
            paths.push(
              ...this.collectPathsToCheck(
                element,
                includedArgs,
                excludedArgs,
                elementKey,
              ),
            );
          }
        });
      } else if (typeof value === 'object') {
        paths.push(
          ...this.collectPathsToCheck(
            value,
            includedArgs,
            excludedArgs,
            fullKey,
          ),
        );
      }
    }

    return paths;
  }
}
|
||||
@@ -0,0 +1,303 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { CheckerRunner } from './checker-runner.js';
|
||||
import { ContextBuilder } from './context-builder.js';
|
||||
import { CheckerRegistry } from './registry.js';
|
||||
import {
|
||||
type InProcessCheckerConfig,
|
||||
InProcessCheckerType,
|
||||
} from '../policy/types.js';
|
||||
import type { SafetyCheckResult } from './protocol.js';
|
||||
import { SafetyCheckDecision } from './protocol.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
// Mock dependencies
vi.mock('./registry.js');
vi.mock('./context-builder.js');
vi.mock('node:child_process');

describe('CheckerRunner', () => {
  let runner: CheckerRunner;
  let mockContextBuilder: ContextBuilder;
  let mockRegistry: CheckerRegistry;

  const mockToolCall = { name: 'test_tool', args: {} };
  const mockInProcessConfig: InProcessCheckerConfig = {
    type: 'in-process',
    name: InProcessCheckerType.ALLOWED_PATH,
  };

  /** Makes the mocked context builder return a minimal full context. */
  const stubFullContext = () => {
    vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({
      environment: { cwd: '/tmp', workspaces: [] },
    });
  };

  beforeEach(() => {
    mockContextBuilder = new ContextBuilder({} as Config);
    mockRegistry = new CheckerRegistry('/mock/dist');
    CheckerRegistry.prototype.resolveInProcess = vi.fn();

    runner = new CheckerRunner(mockContextBuilder, mockRegistry, {
      checkersPath: '/mock/dist',
    });
  });

  afterEach(() => {
    // Restore real timers here rather than at the end of each test body, so
    // a failing assertion cannot leak fake timers into the following tests.
    vi.useRealTimers();
    vi.restoreAllMocks();
  });

  it('should run in-process checker successfully', async () => {
    const mockResult: SafetyCheckResult = {
      decision: SafetyCheckDecision.ALLOW,
    };
    const mockChecker = {
      check: vi.fn().mockResolvedValue(mockResult),
    };
    vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker);
    stubFullContext();

    const result = await runner.runChecker(mockToolCall, mockInProcessConfig);

    expect(result).toEqual(mockResult);
    expect(mockRegistry.resolveInProcess).toHaveBeenCalledWith(
      InProcessCheckerType.ALLOWED_PATH,
    );
    expect(mockChecker.check).toHaveBeenCalled();
  });

  it('should handle in-process checker errors', async () => {
    const mockChecker = {
      check: vi.fn().mockRejectedValue(new Error('Checker failed')),
    };
    vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker);
    stubFullContext();

    const result = await runner.runChecker(mockToolCall, mockInProcessConfig);

    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain('Failed to run in-process checker');
    expect(result.reason).toContain('Checker failed');
  });

  it('should respect timeout for in-process checkers', async () => {
    vi.useFakeTimers();
    const mockChecker = {
      check: vi.fn().mockImplementation(async () => {
        await new Promise((resolve) => setTimeout(resolve, 6000)); // Longer than default 5s timeout
        return { decision: SafetyCheckDecision.ALLOW };
      }),
    };
    vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker);
    stubFullContext();

    const runPromise = runner.runChecker(mockToolCall, mockInProcessConfig);
    vi.advanceTimersByTime(5001);

    const result = await runPromise;
    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain('timed out');
  });

  it('should use minimal context when requested', async () => {
    const configWithContext: InProcessCheckerConfig = {
      ...mockInProcessConfig,
      required_context: ['environment'],
    };
    const mockChecker = {
      check: vi.fn().mockResolvedValue({ decision: SafetyCheckDecision.ALLOW }),
    };
    vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker);
    vi.mocked(mockContextBuilder.buildMinimalContext).mockReturnValue({
      environment: { cwd: '/tmp', workspaces: [] },
    });

    await runner.runChecker(mockToolCall, configWithContext);

    expect(mockContextBuilder.buildMinimalContext).toHaveBeenCalledWith([
      'environment',
    ]);
    expect(mockContextBuilder.buildFullContext).not.toHaveBeenCalled();
  });

  it('should pass config to in-process checker via the check input', async () => {
    const mockConfig = { included_args: ['foo'] };
    const configWithConfig: InProcessCheckerConfig = {
      ...mockInProcessConfig,
      config: mockConfig,
    };
    const mockResult: SafetyCheckResult = {
      decision: SafetyCheckDecision.ALLOW,
    };
    const mockChecker = {
      check: vi.fn().mockResolvedValue(mockResult),
    };
    vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker);
    stubFullContext();

    await runner.runChecker(mockToolCall, configWithConfig);

    expect(mockChecker.check).toHaveBeenCalledWith(
      expect.objectContaining({
        toolCall: mockToolCall,
        config: mockConfig,
      }),
    );
  });

  describe('External Checkers', () => {
    const mockExternalConfig = {
      type: 'external' as const,
      name: 'python-checker',
    };
    const mockCheckerPath = '/mock/dist/python-checker';

    /**
     * Builds a fake child process. By default neither the process nor its
     * stdout ever emits anything (so 'close' is never called).
     */
    const createMockChild = (
      onProcess: ReturnType<typeof vi.fn> = vi.fn(),
      onStdout: ReturnType<typeof vi.fn> = vi.fn(),
    ) => ({
      stdin: { write: vi.fn(), end: vi.fn() },
      stdout: { on: onStdout },
      stderr: { on: vi.fn() },
      on: onProcess,
      kill: vi.fn(),
    });

    /** Wires registry, context builder and `spawn` up to the given fake child. */
    const useExternalChecker = (child: ReturnType<typeof createMockChild>) => {
      vi.mocked(mockRegistry.resolveExternal).mockReturnValue(mockCheckerPath);
      stubFullContext();
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      vi.mocked(spawn).mockReturnValue(child as any);
    };

    it('should spawn external checker directly', async () => {
      const onStdout = vi.fn().mockImplementation((event, callback) => {
        if (event === 'data') {
          callback(
            Buffer.from(
              JSON.stringify({ decision: SafetyCheckDecision.ALLOW }),
            ),
          );
        }
      });
      const onProcess = vi.fn().mockImplementation((event, callback) => {
        if (event === 'close') {
          // Defer the close callback slightly to allow stdout 'data' to be registered
          setTimeout(() => callback(0), 0);
        }
      });
      useExternalChecker(createMockChild(onProcess, onStdout));

      const result = await runner.runChecker(mockToolCall, mockExternalConfig);

      expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
      expect(spawn).toHaveBeenCalledWith(
        mockCheckerPath,
        [],
        expect.anything(),
      );
    });

    it('should include checker name in timeout error message', async () => {
      vi.useFakeTimers();
      useExternalChecker(createMockChild()); // Never calls 'close'

      const runPromise = runner.runChecker(mockToolCall, mockExternalConfig);
      vi.advanceTimersByTime(5001);

      const result = await runPromise;
      expect(result.decision).toBe(SafetyCheckDecision.DENY);
      expect(result.reason).toContain(
        'Safety checker "python-checker" timed out',
      );
    });

    it('should send SIGKILL if process ignores SIGTERM', async () => {
      vi.useFakeTimers();
      const mockChildProcess = createMockChild(); // Never calls 'close' automatically
      useExternalChecker(mockChildProcess);

      const runPromise = runner.runChecker(mockToolCall, mockExternalConfig);

      // Trigger main timeout
      vi.advanceTimersByTime(5001);

      // Should have sent SIGTERM
      expect(mockChildProcess.kill).toHaveBeenCalledWith('SIGTERM');

      // Advance past cleanup timeout (5000ms)
      vi.advanceTimersByTime(5000);

      // Should have sent SIGKILL
      expect(mockChildProcess.kill).toHaveBeenCalledWith('SIGKILL');

      // Clean up promise
      await runPromise;
    });

    it('should include checker name in non-zero exit code error message', async () => {
      const onProcess = vi.fn().mockImplementation((event, callback) => {
        if (event === 'close') {
          callback(1); // Exit code 1
        }
      });
      useExternalChecker(createMockChild(onProcess));

      const result = await runner.runChecker(mockToolCall, mockExternalConfig);

      expect(result.decision).toBe(SafetyCheckDecision.DENY);
      expect(result.reason).toContain(
        'Safety checker "python-checker" exited with code 1',
      );
    });
  });
});
|
||||
@@ -0,0 +1,297 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { spawn } from 'node:child_process';
|
||||
import type { FunctionCall } from '@google/genai';
|
||||
import type {
|
||||
SafetyCheckerConfig,
|
||||
InProcessCheckerConfig,
|
||||
ExternalCheckerConfig,
|
||||
} from '../policy/types.js';
|
||||
import type { SafetyCheckInput, SafetyCheckResult } from './protocol.js';
|
||||
import { SafetyCheckDecision } from './protocol.js';
|
||||
import type { CheckerRegistry } from './registry.js';
|
||||
import type { ContextBuilder } from './context-builder.js';
|
||||
|
||||
/**
 * Configuration for the checker runner.
 */
export interface CheckerRunnerConfig {
  /**
   * Maximum time (in milliseconds) to wait for a checker to complete.
   * Applies to both in-process and external checkers.
   * Default: 5000 (5 seconds)
   */
  timeout?: number;

  /**
   * Path to the directory containing external checkers.
   */
  checkersPath: string;
}
|
||||
|
||||
/**
|
||||
* Service for executing safety checker processes.
|
||||
*/
|
||||
export class CheckerRunner {
|
||||
private static readonly DEFAULT_TIMEOUT = 5000; // 5 seconds
|
||||
|
||||
private readonly registry: CheckerRegistry;
|
||||
private readonly contextBuilder: ContextBuilder;
|
||||
private readonly timeout: number;
|
||||
|
||||
constructor(
|
||||
contextBuilder: ContextBuilder,
|
||||
registry: CheckerRegistry,
|
||||
config: CheckerRunnerConfig,
|
||||
) {
|
||||
this.contextBuilder = contextBuilder;
|
||||
this.registry = registry;
|
||||
this.timeout = config.timeout ?? CheckerRunner.DEFAULT_TIMEOUT;
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs a safety checker and returns the result.
|
||||
*/
|
||||
async runChecker(
|
||||
toolCall: FunctionCall,
|
||||
checkerConfig: SafetyCheckerConfig,
|
||||
): Promise<SafetyCheckResult> {
|
||||
if (checkerConfig.type === 'in-process') {
|
||||
return this.runInProcessChecker(toolCall, checkerConfig);
|
||||
}
|
||||
return this.runExternalChecker(toolCall, checkerConfig);
|
||||
}
|
||||
|
||||
private async runInProcessChecker(
|
||||
toolCall: FunctionCall,
|
||||
checkerConfig: InProcessCheckerConfig,
|
||||
): Promise<SafetyCheckResult> {
|
||||
try {
|
||||
const checker = this.registry.resolveInProcess(checkerConfig.name);
|
||||
const context = checkerConfig.required_context
|
||||
? this.contextBuilder.buildMinimalContext(
|
||||
checkerConfig.required_context,
|
||||
)
|
||||
: this.contextBuilder.buildFullContext();
|
||||
|
||||
const input: SafetyCheckInput = {
|
||||
protocolVersion: '1.0.0',
|
||||
toolCall,
|
||||
context,
|
||||
config: checkerConfig.config,
|
||||
};
|
||||
|
||||
// In-process checkers can be async, but we'll also apply a timeout
|
||||
// for safety, in case of infinite loops or unexpected delays.
|
||||
return await this.executeWithTimeout(checker.check(input));
|
||||
} catch (error) {
|
||||
return {
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Failed to run in-process checker "${checkerConfig.name}": ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private async runExternalChecker(
|
||||
toolCall: FunctionCall,
|
||||
checkerConfig: ExternalCheckerConfig,
|
||||
): Promise<SafetyCheckResult> {
|
||||
try {
|
||||
// Resolve the checker executable path
|
||||
const checkerPath = this.registry.resolveExternal(checkerConfig.name);
|
||||
|
||||
// Build the appropriate context
|
||||
const context = checkerConfig.required_context
|
||||
? this.contextBuilder.buildMinimalContext(
|
||||
checkerConfig.required_context,
|
||||
)
|
||||
: this.contextBuilder.buildFullContext();
|
||||
|
||||
// Create the input payload
|
||||
const input: SafetyCheckInput = {
|
||||
protocolVersion: '1.0.0',
|
||||
toolCall,
|
||||
context,
|
||||
config: checkerConfig.config,
|
||||
};
|
||||
|
||||
// Run the checker process
|
||||
return await this.executeCheckerProcess(
|
||||
checkerPath,
|
||||
input,
|
||||
checkerConfig.name,
|
||||
);
|
||||
} catch (error) {
|
||||
// If anything goes wrong, deny the operation
|
||||
return {
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Failed to run safety checker "${checkerConfig.name}": ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes an external checker process and handles its lifecycle.
|
||||
*/
|
||||
private executeCheckerProcess(
|
||||
checkerPath: string,
|
||||
input: SafetyCheckInput,
|
||||
checkerName: string,
|
||||
): Promise<SafetyCheckResult> {
|
||||
return new Promise((resolve) => {
|
||||
const child = spawn(checkerPath, [], {
|
||||
stdio: ['pipe', 'pipe', 'pipe'],
|
||||
});
|
||||
|
||||
let stdout = '';
|
||||
let stderr = '';
|
||||
let timeoutHandle: NodeJS.Timeout | null = null;
|
||||
let killed = false;
|
||||
|
||||
let exited = false;
|
||||
|
||||
// Set up timeout
|
||||
timeoutHandle = setTimeout(() => {
|
||||
killed = true;
|
||||
child.kill('SIGTERM');
|
||||
resolve({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Safety checker "${checkerName}" timed out after ${this.timeout}ms`,
|
||||
});
|
||||
|
||||
// Fallback: if process doesn't exit after 5s, force kill
|
||||
setTimeout(() => {
|
||||
if (!exited) {
|
||||
child.kill('SIGKILL');
|
||||
}
|
||||
}, 5000).unref();
|
||||
}, this.timeout);
|
||||
|
||||
// Collect output
|
||||
if (child.stdout) {
|
||||
child.stdout.on('data', (data: Buffer) => {
|
||||
stdout += data.toString();
|
||||
});
|
||||
}
|
||||
|
||||
if (child.stderr) {
|
||||
child.stderr.on('data', (data: Buffer) => {
|
||||
stderr += data.toString();
|
||||
});
|
||||
}
|
||||
|
||||
// Handle process completion
|
||||
child.on('close', (code: number | null) => {
|
||||
exited = true;
|
||||
if (timeoutHandle) {
|
||||
clearTimeout(timeoutHandle);
|
||||
}
|
||||
|
||||
// If we already killed it due to timeout, don't process the result
|
||||
if (killed) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Non-zero exit code is a failure
|
||||
if (code !== 0) {
|
||||
resolve({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Safety checker "${checkerName}" exited with code ${code}${
|
||||
stderr ? `: ${stderr}` : ''
|
||||
}`,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to parse the output
|
||||
try {
|
||||
const result: SafetyCheckResult = JSON.parse(stdout);
|
||||
|
||||
// Validate the result structure
|
||||
if (
|
||||
!result.decision ||
|
||||
!Object.values(SafetyCheckDecision).includes(result.decision)
|
||||
) {
|
||||
throw new Error(
|
||||
'Invalid result: missing or invalid "decision" field',
|
||||
);
|
||||
}
|
||||
|
||||
resolve(result);
|
||||
} catch (parseError) {
|
||||
resolve({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Failed to parse output from safety checker "${checkerName}": ${
|
||||
parseError instanceof Error
|
||||
? parseError.message
|
||||
: String(parseError)
|
||||
}`,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Handle process errors
|
||||
child.on('error', (error: Error) => {
|
||||
if (timeoutHandle) {
|
||||
clearTimeout(timeoutHandle);
|
||||
}
|
||||
|
||||
if (!killed) {
|
||||
resolve({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Failed to spawn safety checker "${checkerName}": ${error.message}`,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Send input to the checker
|
||||
try {
|
||||
if (child.stdin) {
|
||||
child.stdin.write(JSON.stringify(input));
|
||||
child.stdin.end();
|
||||
} else {
|
||||
throw new Error('Failed to open stdin for checker process');
|
||||
}
|
||||
} catch (writeError) {
|
||||
if (timeoutHandle) {
|
||||
clearTimeout(timeoutHandle);
|
||||
}
|
||||
|
||||
child.kill();
|
||||
resolve({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: `Failed to write to stdin of safety checker "${checkerName}": ${
|
||||
writeError instanceof Error
|
||||
? writeError.message
|
||||
: String(writeError)
|
||||
}`,
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Guards a promise with the checker timeout.
 *
 * The returned promise settles with the outcome of `promise` if that arrives
 * first; otherwise it rejects with a timeout error once `this.timeout`
 * milliseconds have elapsed. The timer is cleared after `promise` settles.
 */
private executeWithTimeout<T>(promise: Promise<T>): Promise<T> {
  return new Promise<T>((settle, fail) => {
    // Fires only if `promise` has not settled within the allowed window.
    const onDeadline = (): void => {
      fail(new Error(`Checker timed out after ${this.timeout}ms`));
    };
    const timer = setTimeout(onDeadline, this.timeout);

    const cancelTimer = (): void => {
      clearTimeout(timer);
    };

    promise.then(settle).catch(fail).finally(cancelTimer);
  });
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ContextBuilder } from './context-builder.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { ConversationTurn } from './protocol.js';
|
||||
|
||||
// Unit tests for ContextBuilder: full context, key-filtered context, and the
// default (empty) conversation history.
describe('ContextBuilder', () => {
  const mockCwd = '/home/user/project';
  const mockWorkspaces = ['/home/user/project'];
  const mockHistory: ConversationTurn[] = [
    { user: { text: 'hello' }, model: { text: 'hi' } },
  ];

  let mockConfig: Config;
  let contextBuilder: ContextBuilder;

  beforeEach(() => {
    // Pin the working directory so assertions do not depend on the host.
    vi.spyOn(process, 'cwd').mockReturnValue(mockCwd);

    const getDirectories = vi.fn().mockReturnValue(mockWorkspaces);
    const getWorkspaceContext = vi.fn().mockReturnValue({ getDirectories });

    // Only getWorkspaceContext is exercised; the extra fields are stand-ins
    // for unrelated configuration values.
    mockConfig = {
      getWorkspaceContext,
      apiKey: 'secret-api-key',
      somePublicConfig: 'public-value',
      nested: { secretToken: 'hidden', public: 'visible' },
    } as unknown as Config;

    contextBuilder = new ContextBuilder(mockConfig, mockHistory);
  });

  it('should build full context with all fields', () => {
    const { environment, history } = contextBuilder.buildFullContext();

    expect(environment.cwd).toBe(mockCwd);
    expect(environment.workspaces).toEqual(mockWorkspaces);
    expect(history?.turns).toEqual(mockHistory);
  });

  it('should build minimal context with only required keys', () => {
    const minimal = contextBuilder.buildMinimalContext(['environment']);

    expect(minimal).toHaveProperty('environment');
    for (const absentKey of ['config', 'history']) {
      expect(minimal).not.toHaveProperty(absentKey);
    }
  });

  it('should handle missing history', () => {
    // Omitting the second constructor argument falls back to an empty history.
    contextBuilder = new ContextBuilder(mockConfig);

    const { history } = contextBuilder.buildFullContext();
    expect(history?.turns).toEqual([]);
  });
});
|
||||
@@ -0,0 +1,54 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { SafetyCheckInput, ConversationTurn } from './protocol.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
/**
 * Assembles the `context` section of a {@link SafetyCheckInput} from the CLI
 * configuration and the recorded conversation history.
 */
export class ContextBuilder {
  /**
   * @param config Source of the workspace directories.
   * @param conversationHistory Turns to expose to checkers; defaults to none.
   */
  constructor(
    private readonly config: Config,
    private readonly conversationHistory: ConversationTurn[] = [],
  ) {}

  /**
   * Produces the complete context: the process working directory, the
   * workspace directories reported by the config, and every recorded turn.
   */
  buildFullContext(): SafetyCheckInput['context'] {
    const cwd = process.cwd();
    const workspaceContext = this.config.getWorkspaceContext();
    const workspaces = workspaceContext.getDirectories() as string[];

    return {
      environment: { cwd, workspaces },
      history: { turns: this.conversationHistory },
    };
  }

  /**
   * Produces a context restricted to `requiredKeys`.
   *
   * Keys that are not present on the full context are ignored. The result is
   * returned under the full context type even when some sections are omitted.
   */
  buildMinimalContext(
    requiredKeys: Array<keyof SafetyCheckInput['context']>,
  ): SafetyCheckInput['context'] {
    const fullContext = this.buildFullContext();
    const picked: Record<string, unknown> = {};

    requiredKeys.forEach((key) => {
      if (!(key in fullContext)) {
        return;
      }
      picked[key] = fullContext[key];
    });

    return picked as SafetyCheckInput['context'];
  }
}
|
||||
@@ -0,0 +1,100 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { FunctionCall } from '@google/genai';
|
||||
|
||||
/**
 * One exchange in the conversation: what the user said and how the model
 * answered. Gives checkers semantic context for why a tool call might be
 * happening.
 */
export interface ConversationTurn {
  /** The user's side of the turn. */
  user: {
    text: string;
  };

  /** The model's side of the turn; both parts are optional. */
  model: {
    /** Text produced by the model, if any. */
    text?: string;
    /** Function calls issued by the model in this turn, if any. */
    toolCalls?: FunctionCall[];
  };
}
|
||||
|
||||
/**
 * Payload the CLI passes to a safety checker process on stdin.
 */
export interface SafetyCheckInput {
  /**
   * Semantic version of this protocol (e.g., "1.0.0"). Carrying the version
   * allows breaking changes to be introduced later while older checkers
   * remain supported.
   */
  protocolVersion: '1.0.0';

  /** The tool call under validation. */
  toolCall: FunctionCall;

  /**
   * Contextual information taken from the CLI's internal state. Data is
   * grouped by category so new kinds of context can be added later without
   * producing a flat, unmanageable object.
   */
  context: {
    /** The user's file system and execution environment. */
    environment: {
      cwd: string;
      /** User-configured workspace roots. */
      workspaces: string[];
    };

    /**
     * Recent conversation history, for checkers that need to understand the
     * intent behind a tool call.
     */
    history?: {
      turns: ConversationTurn[];
    };
  };

  /**
   * Checker-specific configuration, which lets a checker be parameterized
   * (e.g. allowed paths).
   */
  config?: unknown;
}
|
||||
|
||||
/**
 * Verdicts a safety checker can return. Each member's string value is the
 * serialized form of the decision.
 */
export enum SafetyCheckDecision {
  /** Serialized as `allow`. */
  ALLOW = 'allow',
  /** Serialized as `deny`. */
  DENY = 'deny',
  /** Serialized as `ask_user`. */
  ASK_USER = 'ask_user',
}
|
||||
|
||||
/**
 * Result a safety checker process reports on stdout, discriminated by
 * `decision`.
 */
export type SafetyCheckResult =
  | {
      /** The checker allowed the tool call. */
      decision: SafetyCheckDecision.ALLOW;
      /** Optional explanatory message. */
      reason?: string;
    }
  | {
      /** The checker denied the tool call. */
      decision: SafetyCheckDecision.DENY;
      /**
       * Message explaining why the tool call was blocked. This will be shown
       * to the user.
       */
      reason: string;
    }
  | {
      /** The checker asks for the user to be consulted. */
      decision: SafetyCheckDecision.ASK_USER;
      /** Message accompanying the decision; required for this variant. */
      reason: string;
    };
|
||||
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach } from 'vitest';
|
||||
import { CheckerRegistry } from './registry.js';
|
||||
import { InProcessCheckerType } from '../policy/types.js';
|
||||
import { AllowedPathChecker } from './built-in.js';
|
||||
|
||||
// Unit tests for CheckerRegistry name validation and checker resolution.
describe('CheckerRegistry', () => {
  const mockCheckersPath = '/mock/checkers/path';
  let registry: CheckerRegistry;

  beforeEach(() => {
    registry = new CheckerRegistry(mockCheckersPath);
  });

  it('should resolve built-in in-process checkers', () => {
    const resolved = registry.resolveInProcess(
      InProcessCheckerType.ALLOWED_PATH,
    );

    expect(resolved).toBeInstanceOf(AllowedPathChecker);
  });

  it('should throw for unknown in-process checkers', () => {
    const resolveUnknown = (): unknown =>
      registry.resolveInProcess('unknown-checker');

    expect(resolveUnknown).toThrow(
      'Unknown in-process checker "unknown-checker"',
    );
  });

  it('should validate checker names', () => {
    // Names with disallowed characters or path-traversal sequences are rejected.
    for (const badName of ['invalid name!', '../escape']) {
      expect(() => registry.resolveInProcess(badName)).toThrow(
        'Invalid checker name',
      );
    }
  });

  it('should throw for unknown external checkers (for now)', () => {
    const resolveUnknown = (): unknown =>
      registry.resolveExternal('some-external');

    expect(resolveUnknown).toThrow('Unknown external checker "some-external"');
  });
});
|
||||
@@ -0,0 +1,83 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as path from 'node:path';
|
||||
import * as fs from 'node:fs';
|
||||
import { type InProcessChecker, AllowedPathChecker } from './built-in.js';
|
||||
import { InProcessCheckerType } from '../policy/types.js';
|
||||
|
||||
/**
 * Looks up safety checkers by name: external checkers resolve to an
 * executable path, in-process checkers resolve to a checker instance.
 */
export class CheckerRegistry {
  /** Built-in external checkers, keyed by name, valued by relative path. */
  private static readonly BUILT_IN_EXTERNAL_CHECKERS = new Map<string, string>([
    // No external built-ins for now
  ]);

  /** Built-in in-process checkers, keyed by name. Instances are shared. */
  private static readonly BUILT_IN_IN_PROCESS_CHECKERS = new Map<
    string,
    InProcessChecker
  >([[InProcessCheckerType.ALLOWED_PATH, new AllowedPathChecker()]]);

  /** Accepted checker names: lowercase letters, digits, and hyphens only. */
  private static readonly VALID_NAME_PATTERN = /^[a-z0-9-]+$/;

  /**
   * @param checkersPath Directory against which built-in external checker
   *   paths are joined.
   */
  constructor(private readonly checkersPath: string) {}

  /**
   * Resolves an external checker name to the path of its executable.
   *
   * @throws Error if the name is invalid, the checker is unknown, or the
   *   built-in checker's file does not exist.
   */
  resolveExternal(name: string): string {
    if (!CheckerRegistry.isValidCheckerName(name)) {
      throw new Error(
        `Invalid checker name "${name}". Checker names must contain only lowercase letters, numbers, and hyphens.`,
      );
    }

    const relativePath = CheckerRegistry.BUILT_IN_EXTERNAL_CHECKERS.get(name);
    if (!relativePath) {
      // TODO: Phase 5 - Add support for custom external checkers
      throw new Error(`Unknown external checker "${name}".`);
    }

    const resolvedPath = path.join(this.checkersPath, relativePath);
    if (fs.existsSync(resolvedPath)) {
      return resolvedPath;
    }
    throw new Error(`Built-in checker "${name}" not found at ${resolvedPath}`);
  }

  /**
   * Resolves an in-process checker name to its checker instance.
   *
   * @throws Error if the name is invalid or no such checker is registered.
   */
  resolveInProcess(name: string): InProcessChecker {
    if (!CheckerRegistry.isValidCheckerName(name)) {
      throw new Error(`Invalid checker name "${name}".`);
    }

    const registered = CheckerRegistry.BUILT_IN_IN_PROCESS_CHECKERS;
    const checker = registered.get(name);
    if (!checker) {
      const available = Array.from(registered.keys()).join(', ');
      throw new Error(
        `Unknown in-process checker "${name}". Available: ${available}`,
      );
    }
    return checker;
  }

  /** Returns true when `name` matches the name pattern and has no "..". */
  private static isValidCheckerName(name: string): boolean {
    return this.VALID_NAME_PATTERN.test(name) ? !name.includes('..') : false;
  }

  /** Lists built-in checker names: external ones first, then in-process. */
  static getBuiltInCheckers(): string[] {
    const externalNames = Array.from(this.BUILT_IN_EXTERNAL_CHECKERS.keys());
    const inProcessNames = Array.from(this.BUILT_IN_IN_PROCESS_CHECKERS.keys());
    return externalNames.concat(inProcessNames);
  }
}
|
||||
@@ -47,11 +47,13 @@ describe('BaseToolInvocation', () => {
|
||||
);
|
||||
|
||||
let capturedRequest: ToolConfirmationRequest | undefined;
|
||||
vi.mocked(messageBus.publish).mockImplementation((request: Message) => {
|
||||
if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
|
||||
capturedRequest = request;
|
||||
}
|
||||
});
|
||||
vi.mocked(messageBus.publish).mockImplementation(
|
||||
async (request: Message) => {
|
||||
if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
|
||||
capturedRequest = request;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let responseHandler:
|
||||
| ((response: ToolConfirmationResponse) => void)
|
||||
@@ -102,11 +104,13 @@ describe('BaseToolInvocation', () => {
|
||||
);
|
||||
|
||||
let capturedRequest: ToolConfirmationRequest | undefined;
|
||||
vi.mocked(messageBus.publish).mockImplementation((request: Message) => {
|
||||
if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
|
||||
capturedRequest = request;
|
||||
}
|
||||
});
|
||||
vi.mocked(messageBus.publish).mockImplementation(
|
||||
async (request: Message) => {
|
||||
if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
|
||||
capturedRequest = request;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
// We need to mock subscribe to avoid hanging if we want to await the promise,
|
||||
// but for this test we just need to check publish.
|
||||
|
||||
Reference in New Issue
Block a user