feat(safety): Introduce safety checker framework (#12504)

This commit is contained in:
Allen Hutchison
2025-11-12 13:18:34 -08:00
committed by GitHub
parent aa9922bc98
commit 1ed163a666
21 changed files with 2636 additions and 328 deletions
@@ -26,12 +26,12 @@ describe('MessageBus', () => {
});
describe('publish', () => {
it('should emit error for invalid message', () => {
it('should emit error for invalid message', async () => {
const errorHandler = vi.fn();
messageBus.on('error', errorHandler);
// @ts-expect-error - Testing invalid message
messageBus.publish({ invalid: 'message' });
await messageBus.publish({ invalid: 'message' });
expect(errorHandler).toHaveBeenCalledWith(
expect.objectContaining({
@@ -40,12 +40,12 @@ describe('MessageBus', () => {
);
});
it('should validate tool confirmation requests have correlationId', () => {
it('should validate tool confirmation requests have correlationId', async () => {
const errorHandler = vi.fn();
messageBus.on('error', errorHandler);
// @ts-expect-error - Testing missing correlationId
messageBus.publish({
await messageBus.publish({
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: { name: 'test' },
});
@@ -53,8 +53,10 @@ describe('MessageBus', () => {
expect(errorHandler).toHaveBeenCalled();
});
it('should emit confirmation response when policy allows', () => {
vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ALLOW);
it('should emit confirmation response when policy allows', async () => {
vi.spyOn(policyEngine, 'check').mockResolvedValue({
decision: PolicyDecision.ALLOW,
});
const responseHandler = vi.fn();
messageBus.subscribe(
@@ -68,7 +70,7 @@ describe('MessageBus', () => {
correlationId: '123',
};
messageBus.publish(request);
await messageBus.publish(request);
const expectedResponse: ToolConfirmationResponse = {
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
@@ -78,8 +80,10 @@ describe('MessageBus', () => {
expect(responseHandler).toHaveBeenCalledWith(expectedResponse);
});
it('should emit rejection and response when policy denies', () => {
vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.DENY);
it('should emit rejection and response when policy denies', async () => {
vi.spyOn(policyEngine, 'check').mockResolvedValue({
decision: PolicyDecision.DENY,
});
const responseHandler = vi.fn();
const rejectionHandler = vi.fn();
@@ -98,7 +102,7 @@ describe('MessageBus', () => {
correlationId: '123',
};
messageBus.publish(request);
await messageBus.publish(request);
const expectedRejection: ToolPolicyRejection = {
type: MessageBusType.TOOL_POLICY_REJECTION,
@@ -114,8 +118,10 @@ describe('MessageBus', () => {
expect(responseHandler).toHaveBeenCalledWith(expectedResponse);
});
it('should pass through to UI when policy says ASK_USER', () => {
vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ASK_USER);
it('should pass through to UI when policy says ASK_USER', async () => {
vi.spyOn(policyEngine, 'check').mockResolvedValue({
decision: PolicyDecision.ASK_USER,
});
const requestHandler = vi.fn();
messageBus.subscribe(
@@ -129,12 +135,12 @@ describe('MessageBus', () => {
correlationId: '123',
};
messageBus.publish(request);
await messageBus.publish(request);
expect(requestHandler).toHaveBeenCalledWith(request);
});
it('should emit other message types directly', () => {
it('should emit other message types directly', async () => {
const successHandler = vi.fn();
messageBus.subscribe(
MessageBusType.TOOL_EXECUTION_SUCCESS,
@@ -147,14 +153,14 @@ describe('MessageBus', () => {
result: 'success',
};
messageBus.publish(message);
await messageBus.publish(message);
expect(successHandler).toHaveBeenCalledWith(message);
});
});
describe('subscribe/unsubscribe', () => {
it('should allow subscribing to specific message types', () => {
it('should allow subscribing to specific message types', async () => {
const handler = vi.fn();
messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler);
@@ -164,12 +170,12 @@ describe('MessageBus', () => {
result: 'test',
};
messageBus.publish(message);
await messageBus.publish(message);
expect(handler).toHaveBeenCalledWith(message);
});
it('should allow unsubscribing from message types', () => {
it('should allow unsubscribing from message types', async () => {
const handler = vi.fn();
messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler);
messageBus.unsubscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler);
@@ -180,12 +186,12 @@ describe('MessageBus', () => {
result: 'test',
};
messageBus.publish(message);
await messageBus.publish(message);
expect(handler).not.toHaveBeenCalled();
});
it('should support multiple subscribers for the same message type', () => {
it('should support multiple subscribers for the same message type', async () => {
const handler1 = vi.fn();
const handler2 = vi.fn();
@@ -198,7 +204,7 @@ describe('MessageBus', () => {
result: 'test',
};
messageBus.publish(message);
await messageBus.publish(message);
expect(handler1).toHaveBeenCalledWith(message);
expect(handler2).toHaveBeenCalledWith(message);
@@ -206,12 +212,12 @@ describe('MessageBus', () => {
});
describe('error handling', () => {
it('should not crash on errors during message processing', () => {
it('should not crash on errors during message processing', async () => {
const errorHandler = vi.fn();
messageBus.on('error', errorHandler);
// Mock policyEngine to throw an error
vi.spyOn(policyEngine, 'check').mockImplementation(() => {
vi.spyOn(policyEngine, 'check').mockImplementation(async () => {
throw new Error('Policy check failed');
});
@@ -222,7 +228,7 @@ describe('MessageBus', () => {
};
// Should not throw
expect(() => messageBus.publish(request)).not.toThrow();
await expect(messageBus.publish(request)).resolves.not.toThrow();
// Should emit error
expect(errorHandler).toHaveBeenCalledWith(
@@ -38,7 +38,7 @@ export class MessageBus extends EventEmitter {
this.emit(message.type, message);
}
publish(message: Message): void {
async publish(message: Message): Promise<void> {
if (this.debug) {
console.debug(`[MESSAGE_BUS] publish: ${safeJsonStringify(message)}`);
}
@@ -50,7 +50,7 @@ export class MessageBus extends EventEmitter {
}
if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
const decision = this.policyEngine.check(
const { decision } = await this.policyEngine.check(
message.toolCall,
message.serverName,
);
+189 -1
View File
@@ -9,7 +9,7 @@ import { describe, it, expect, vi, afterEach, beforeEach } from 'vitest';
import nodePath from 'node:path';
import type { PolicySettings } from './types.js';
import { ApprovalMode, PolicyDecision } from './types.js';
import { ApprovalMode, PolicyDecision, InProcessCheckerType } from './types.js';
import { Storage } from '../config/storage.js';
@@ -642,6 +642,194 @@ priority = 150
vi.doUnmock('node:fs/promises');
});
it('should load safety_checker configuration from TOML', async () => {
const actualFs =
await vi.importActual<typeof import('node:fs/promises')>(
'node:fs/promises',
);
const mockReaddir = vi.fn(
async (
path: string | Buffer | URL,
options?: Parameters<typeof actualFs.readdir>[1],
) => {
if (
typeof path === 'string' &&
nodePath
.normalize(path)
.includes(nodePath.normalize('.gemini/policies'))
) {
return [
{
name: 'safety.toml',
isFile: () => true,
isDirectory: () => false,
},
] as unknown as Awaited<ReturnType<typeof actualFs.readdir>>;
}
return actualFs.readdir(
path,
options as Parameters<typeof actualFs.readdir>[1],
);
},
);
const mockReadFile = vi.fn(
async (
path: Parameters<typeof actualFs.readFile>[0],
options: Parameters<typeof actualFs.readFile>[1],
) => {
if (
typeof path === 'string' &&
nodePath
.normalize(path)
.includes(nodePath.normalize('.gemini/policies/safety.toml'))
) {
return `
[[rule]]
toolName = "write_file"
decision = "allow"
priority = 10
[[rule]]
toolName = "write_file"
decision = "allow"
priority = 10
[[safety_checker]]
toolName = "write_file"
priority = 10
[safety_checker.checker]
type = "in-process"
name = "allowed-path"
required_context = ["environment"]
[safety_checker.checker.config]
`;
}
return actualFs.readFile(path, options);
},
);
vi.doMock('node:fs/promises', () => ({
...actualFs,
default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir },
readFile: mockReadFile,
readdir: mockReaddir,
}));
vi.resetModules();
const { createPolicyEngineConfig } = await import('./config.js');
const settings: PolicySettings = {};
const config = await createPolicyEngineConfig(
settings,
ApprovalMode.DEFAULT,
'/tmp/mock/default/policies',
);
const rule = config.rules?.find(
(r) => r.toolName === 'write_file' && r.decision === PolicyDecision.ALLOW,
);
expect(rule).toBeDefined();
const checker = config.checkers?.find(
(c) => c.toolName === 'write_file' && c.checker.type === 'in-process',
);
expect(checker).toBeDefined();
expect(checker?.checker.type).toBe('in-process');
expect(checker?.checker.name).toBe(InProcessCheckerType.ALLOWED_PATH);
expect(checker?.checker.required_context).toEqual(['environment']);
vi.doUnmock('node:fs/promises');
});
it('should reject invalid in-process checker names', async () => {
const actualFs =
await vi.importActual<typeof import('node:fs/promises')>(
'node:fs/promises',
);
const mockReaddir = vi.fn(
async (
path: string | Buffer | URL,
options?: Parameters<typeof actualFs.readdir>[1],
) => {
if (
typeof path === 'string' &&
nodePath
.normalize(path)
.includes(nodePath.normalize('.gemini/policies'))
) {
return [
{
name: 'invalid_safety.toml',
isFile: () => true,
isDirectory: () => false,
},
] as unknown as Awaited<ReturnType<typeof actualFs.readdir>>;
}
return actualFs.readdir(
path,
options as Parameters<typeof actualFs.readdir>[1],
);
},
);
const mockReadFile = vi.fn(
async (
path: Parameters<typeof actualFs.readFile>[0],
options: Parameters<typeof actualFs.readFile>[1],
) => {
if (
typeof path === 'string' &&
nodePath
.normalize(path)
.includes(
nodePath.normalize('.gemini/policies/invalid_safety.toml'),
)
) {
return `
[[rule]]
toolName = "write_file"
decision = "allow"
priority = 10
[[safety_checker]]
toolName = "write_file"
priority = 10
[safety_checker.checker]
type = "in-process"
name = "invalid-name"
`;
}
return actualFs.readFile(path, options);
},
);
vi.doMock('node:fs/promises', () => ({
...actualFs,
default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir },
readFile: mockReadFile,
readdir: mockReaddir,
}));
vi.resetModules();
const { createPolicyEngineConfig } = await import('./config.js');
const settings: PolicySettings = {};
const config = await createPolicyEngineConfig(
settings,
ApprovalMode.DEFAULT,
'/tmp/mock/default/policies',
);
// 'invalid-name' is not a member of InProcessCheckerType, so the policy file is
// expected to fail validation; as a consequence its [[rule]] entry must not be loaded.
const rule = config.rules?.find((r) => r.toolName === 'write_file');
expect(rule).toBeUndefined();
vi.doUnmock('node:fs/promises');
});
it('should have default ASK_USER rule for discovered tools', async () => {
vi.resetModules();
vi.doUnmock('node:fs/promises');
+8 -4
View File
@@ -114,10 +114,12 @@ export async function createPolicyEngineConfig(
const policyDirs = getPolicyDirectories(defaultPoliciesDir);
// Load policies from TOML files
const { rules: tomlRules, errors } = await loadPoliciesFromToml(
approvalMode,
policyDirs,
(dir) => getPolicyTier(dir, defaultPoliciesDir),
const {
rules: tomlRules,
checkers: tomlCheckers,
errors,
} = await loadPoliciesFromToml(approvalMode, policyDirs, (dir) =>
getPolicyTier(dir, defaultPoliciesDir),
);
// Emit any errors encountered during TOML loading to the UI
@@ -129,6 +131,7 @@ export async function createPolicyEngineConfig(
}
const rules: PolicyRule[] = [...tomlRules];
const checkers = [...tomlCheckers];
// Priority system for policy rules:
// - Higher priority numbers win over lower priority numbers
@@ -225,6 +228,7 @@ export async function createPolicyEngineConfig(
return {
rules,
checkers,
defaultDecision: PolicyDecision.ASK_USER,
};
}
@@ -36,6 +36,11 @@ decision = "allow"
priority = 15
modes = ["autoEdit"]
[rule.safety_checker]
type = "in-process"
name = "allowed-path"
required_context = ["environment"]
[[rule]]
toolName = "save_memory"
decision = "ask_user"
@@ -57,6 +62,11 @@ decision = "allow"
priority = 15
modes = ["autoEdit"]
[rule.safety_checker]
type = "in-process"
name = "allowed-path"
required_context = ["environment"]
[[rule]]
toolName = "web_fetch"
decision = "ask_user"
File diff suppressed because it is too large Load Diff
+89 -11
View File
@@ -9,12 +9,15 @@ import {
PolicyDecision,
type PolicyEngineConfig,
type PolicyRule,
type SafetyCheckerRule,
} from './types.js';
import { stableStringify } from './stable-stringify.js';
import { debugLogger } from '../utils/debugLogger.js';
import type { CheckerRunner } from '../safety/checker-runner.js';
import { SafetyCheckDecision } from '../safety/protocol.js';
function ruleMatches(
rule: PolicyRule,
rule: PolicyRule | SafetyCheckerRule,
toolCall: FunctionCall,
stringifiedArgs: string | undefined,
serverName: string | undefined,
@@ -60,27 +63,41 @@ function ruleMatches(
export class PolicyEngine {
private rules: PolicyRule[];
private checkers: SafetyCheckerRule[];
private readonly defaultDecision: PolicyDecision;
private readonly nonInteractive: boolean;
private readonly checkerRunner?: CheckerRunner;
constructor(config: PolicyEngineConfig = {}) {
constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) {
this.rules = (config.rules ?? []).sort(
(a, b) => (b.priority ?? 0) - (a.priority ?? 0),
);
this.checkers = (config.checkers ?? []).sort(
(a, b) => (b.priority ?? 0) - (a.priority ?? 0),
);
this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER;
this.nonInteractive = config.nonInteractive ?? false;
this.checkerRunner = checkerRunner;
}
/**
* Check if a tool call is allowed based on the configured policies.
* Returns the decision and the matching rule (if any).
*/
check(
async check(
toolCall: FunctionCall,
serverName: string | undefined,
): PolicyDecision {
): Promise<{
decision: PolicyDecision;
rule?: PolicyRule;
}> {
let stringifiedArgs: string | undefined;
// Compute stringified args once before the loop
if (toolCall.args && this.rules.some((rule) => rule.argsPattern)) {
if (
toolCall.args &&
(this.rules.some((rule) => rule.argsPattern) ||
this.checkers.some((checker) => checker.argsPattern))
) {
stringifiedArgs = stableStringify(toolCall.args);
}
@@ -89,20 +106,72 @@ export class PolicyEngine {
);
// Find the first matching rule (already sorted by priority)
let matchedRule: PolicyRule | undefined;
let decision: PolicyDecision | undefined;
for (const rule of this.rules) {
if (ruleMatches(rule, toolCall, stringifiedArgs, serverName)) {
debugLogger.debug(
`[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`,
);
return this.applyNonInteractiveMode(rule.decision);
matchedRule = rule;
decision = this.applyNonInteractiveMode(rule.decision);
break;
}
}
// No matching rule found, use default decision
debugLogger.debug(
`[PolicyEngine.check] NO MATCH - using default decision: ${this.defaultDecision}`,
);
return this.applyNonInteractiveMode(this.defaultDecision);
if (!decision) {
// No matching rule found, use default decision
debugLogger.debug(
`[PolicyEngine.check] NO MATCH - using default decision: ${this.defaultDecision}`,
);
decision = this.applyNonInteractiveMode(this.defaultDecision);
}
// If decision is not DENY, run safety checkers
if (decision !== PolicyDecision.DENY && this.checkerRunner) {
for (const checkerRule of this.checkers) {
if (ruleMatches(checkerRule, toolCall, stringifiedArgs, serverName)) {
debugLogger.debug(
`[PolicyEngine.check] Running safety checker: ${checkerRule.checker.name}`,
);
try {
const result = await this.checkerRunner.runChecker(
toolCall,
checkerRule.checker,
);
if (result.decision === SafetyCheckDecision.DENY) {
debugLogger.debug(
`[PolicyEngine.check] Safety checker denied: ${result.reason}`,
);
return {
decision: PolicyDecision.DENY,
rule: matchedRule,
};
} else if (result.decision === SafetyCheckDecision.ASK_USER) {
debugLogger.debug(
`[PolicyEngine.check] Safety checker requested ASK_USER: ${result.reason}`,
);
decision = PolicyDecision.ASK_USER;
}
} catch (error) {
debugLogger.debug(
`[PolicyEngine.check] Safety checker failed: ${error}`,
);
return {
decision: PolicyDecision.DENY,
rule: matchedRule,
};
}
}
}
}
return {
decision: this.applyNonInteractiveMode(decision),
rule: matchedRule,
};
}
/**
@@ -114,6 +183,11 @@ export class PolicyEngine {
this.rules.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
}
/**
 * Registers an additional safety checker rule.
 *
 * After insertion the checker list is re-sorted in place so that rules with
 * a higher `priority` come first; a missing priority is treated as 0.
 *
 * @param checker The safety checker rule to add.
 */
addChecker(checker: SafetyCheckerRule): void {
  const priorityOf = (entry: SafetyCheckerRule): number => entry.priority ?? 0;

  this.checkers.push(checker);
  this.checkers.sort((first, second) => priorityOf(second) - priorityOf(first));
}
/**
* Remove rules for a specific tool.
*/
@@ -128,6 +202,10 @@ export class PolicyEngine {
return this.rules;
}
/**
 * Returns the registered safety checker rules.
 *
 * The engine's internal array is returned directly (not a copy); it is only
 * typed as readonly, so callers must not cast it back and mutate it.
 */
getCheckers(): readonly SafetyCheckerRule[] {
return this.checkers;
}
private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision {
// In non-interactive mode, ASK_USER becomes DENY
if (this.nonInteractive && decision === PolicyDecision.ASK_USER) {
@@ -85,6 +85,7 @@ priority = 100
decision: PolicyDecision.ALLOW,
priority: 1.1, // tier 1 + 100/1000
});
expect(result.checkers).toHaveLength(0);
expect(result.errors).toHaveLength(0);
});
+122 -6
View File
@@ -4,7 +4,14 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { type PolicyRule, PolicyDecision, type ApprovalMode } from './types.js';
import {
type PolicyRule,
PolicyDecision,
type ApprovalMode,
type SafetyCheckerConfig,
type SafetyCheckerRule,
InProcessCheckerType,
} from './types.js';
import fs from 'node:fs/promises';
import path from 'node:path';
import toml from '@iarna/toml';
@@ -39,11 +46,39 @@ const PolicyRuleSchema = z.object({
modes: z.array(z.string()).optional(),
});
/**
 * Schema for the `checker` table of an in-process (built-in) checker.
 * `name` must be one of the InProcessCheckerType enum values.
 */
const inProcessCheckerSchema = z.object({
  type: z.literal('in-process'),
  name: z.nativeEnum(InProcessCheckerType),
  required_context: z.array(z.string()).optional(),
  config: z.record(z.unknown()).optional(),
});

/**
 * Schema for the `checker` table of an external checker.
 * `name` is accepted as any string.
 */
const externalCheckerSchema = z.object({
  type: z.literal('external'),
  name: z.string(),
  required_context: z.array(z.string()).optional(),
  config: z.record(z.unknown()).optional(),
});

/**
 * Schema for a single safety checker rule in the TOML file.
 *
 * The matching fields (toolName, mcpName, argsPattern, commandPrefix,
 * commandRegex, modes) are all optional; `priority` must be an integer and
 * defaults to 0. The `checker` table is discriminated on its `type` field.
 */
const SafetyCheckerRuleSchema = z.object({
  toolName: z.union([z.string(), z.array(z.string())]).optional(),
  mcpName: z.string().optional(),
  argsPattern: z.string().optional(),
  commandPrefix: z.union([z.string(), z.array(z.string())]).optional(),
  commandRegex: z.string().optional(),
  priority: z.number().int().default(0),
  modes: z.array(z.string()).optional(),
  checker: z.discriminatedUnion('type', [
    inProcessCheckerSchema,
    externalCheckerSchema,
  ]),
});
/**
* Schema for the entire policy TOML file.
*/
const PolicyFileSchema = z.object({
rule: z.array(PolicyRuleSchema),
rule: z.array(PolicyRuleSchema).optional(),
safety_checker: z.array(SafetyCheckerRuleSchema).optional(),
});
/**
@@ -80,6 +115,7 @@ export interface PolicyFileError {
*/
export interface PolicyLoadResult {
rules: PolicyRule[];
checkers: SafetyCheckerRule[];
errors: PolicyFileError[];
}
@@ -194,6 +230,7 @@ export async function loadPoliciesFromToml(
getPolicyTier: (dir: string) => number,
): Promise<PolicyLoadResult> {
const rules: PolicyRule[] = [];
const checkers: SafetyCheckerRule[] = [];
const errors: PolicyFileError[] = [];
for (const dir of policyDirs) {
@@ -267,8 +304,9 @@ export async function loadPoliciesFromToml(
}
// Validate shell command convenience syntax
for (let i = 0; i < validationResult.data.rule.length; i++) {
const rule = validationResult.data.rule[i];
const tomlRules = validationResult.data.rule ?? [];
for (let i = 0; i < tomlRules.length; i++) {
const rule = tomlRules[i];
const validationError = validateShellCommandSyntax(rule, i);
if (validationError) {
errors.push({
@@ -285,7 +323,7 @@ export async function loadPoliciesFromToml(
}
// Transform rules
const parsedRules: PolicyRule[] = validationResult.data.rule
const parsedRules: PolicyRule[] = (validationResult.data.rule ?? [])
.filter((rule) => {
// Filter by mode
if (!rule.modes || rule.modes.length === 0) {
@@ -369,6 +407,84 @@ export async function loadPoliciesFromToml(
.filter((rule): rule is PolicyRule => rule !== null);
rules.push(...parsedRules);
// Transform checkers: every [[safety_checker]] entry that applies to the
// current approval mode is expanded into one SafetyCheckerRule per
// (args pattern, tool name) combination.
const parsedCheckers: SafetyCheckerRule[] = [];
for (const entry of validationResult.data.safety_checker ?? []) {
  // Entries without `modes` apply in every mode; otherwise the current
  // approval mode has to be listed explicitly.
  const appliesToMode =
    !entry.modes ||
    entry.modes.length === 0 ||
    entry.modes.includes(approvalMode);
  if (!appliesToMode) {
    continue;
  }

  // commandPrefix wins over commandRegex; commandRegex in turn replaces a
  // plain argsPattern.
  let basePattern = entry.argsPattern;
  let prefixList: string[] = [];
  if (entry.commandPrefix) {
    prefixList = Array.isArray(entry.commandPrefix)
      ? entry.commandPrefix
      : [entry.commandPrefix];
  } else if (entry.commandRegex) {
    basePattern = `"command":"${entry.commandRegex}`;
  }

  // One pattern per prefix, or the single base pattern (possibly undefined)
  // when no prefixes were supplied.
  const patternSources: Array<string | undefined> =
    prefixList.length > 0
      ? prefixList.map((prefix) => `"command":"${escapeRegex(prefix)}`)
      : [basePattern];

  // A missing tool name is kept as a single `undefined` slot so that one
  // rule is still produced for the entry.
  let toolNames: Array<string | undefined>;
  if (!entry.toolName) {
    toolNames = [undefined];
  } else if (Array.isArray(entry.toolName)) {
    toolNames = entry.toolName;
  } else {
    toolNames = [entry.toolName];
  }

  for (const patternSource of patternSources) {
    for (const toolName of toolNames) {
      // Qualify the tool name with the MCP server name when one is given;
      // a server without a tool name becomes the `<server>__*` form.
      let effectiveToolName: string | undefined = toolName;
      if (entry.mcpName) {
        effectiveToolName = toolName
          ? `${entry.mcpName}__${toolName}`
          : `${entry.mcpName}__*`;
      }

      const checkerRule: SafetyCheckerRule = {
        toolName: effectiveToolName,
        priority: entry.priority,
        checker: entry.checker as SafetyCheckerConfig,
      };

      if (patternSource) {
        try {
          checkerRule.argsPattern = new RegExp(patternSource);
        } catch (e) {
          // Report the bad pattern and drop only this rule.
          const error = e as Error;
          errors.push({
            filePath,
            fileName: file,
            tier: tierName,
            errorType: 'regex_compilation',
            message: 'Invalid regex pattern in safety checker',
            details: `Pattern: ${patternSource}\nError: ${error.message}`,
          });
          continue;
        }
      }

      parsedCheckers.push(checkerRule);
    }
  }
}
checkers.push(...parsedCheckers);
} catch (e) {
const error = e as NodeJS.ErrnoException;
// Catch-all for unexpected errors
@@ -386,5 +502,5 @@ export async function loadPoliciesFromToml(
}
}
return { rules, errors };
return { rules, checkers, errors };
}
+79
View File
@@ -4,6 +4,8 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { SafetyCheckInput } from '../safety/protocol.js';
export enum PolicyDecision {
ALLOW = 'allow',
DENY = 'deny',
@@ -16,6 +18,52 @@ export enum ApprovalMode {
YOLO = 'yolo',
}
/**
* Configuration for the built-in allowed-path checker.
*
* Both lists hold argument keys of the tool call being checked. The checker's
* tests exercise dot notation for nested arguments (e.g. `nested.custom`).
*/
export interface AllowedPathConfig {
/**
* Explicitly include argument keys to be checked as paths.
* Presumably meant for keys that would not be recognized as paths on their
* own (e.g. a tool-specific `custom_arg`) — confirm against the checker.
*/
included_args?: string[];
/**
* Explicitly exclude argument keys from being checked as paths.
* An excluded key (e.g. `path`) is skipped even though it would normally be
* checked.
*/
excluded_args?: string[];
}
/**
* Base interface for external checkers.
*
* - `type` is the discriminant of the SafetyCheckerConfig union.
* - `name` is a free-form string identifying the checker.
* - `config` is left untyped (`unknown`) at this level.
* - `required_context` lists the keys of SafetyCheckInput['context'] that the
*   checker asks for.
*/
export interface ExternalCheckerConfig {
type: 'external';
name: string;
config?: unknown;
required_context?: Array<keyof SafetyCheckInput['context']>;
}
/**
* Names of the built-in (in-process) safety checkers.
* The string value is the `name` used for an in-process checker in a policy
* TOML file (e.g. `name = "allowed-path"`).
*/
export enum InProcessCheckerType {
ALLOWED_PATH = 'allowed-path',
}
/**
* Base interface for in-process checkers.
*
* - `type` is the discriminant of the SafetyCheckerConfig union.
* - `name` selects one of the built-in checkers.
* - `config` is typed as AllowedPathConfig, the configuration of the only
*   built-in checker currently in the enum.
* - `required_context` lists the keys of SafetyCheckInput['context'] that the
*   checker asks for.
*/
export interface InProcessCheckerConfig {
type: 'in-process';
name: InProcessCheckerType;
config?: AllowedPathConfig;
required_context?: Array<keyof SafetyCheckInput['context']>;
}
/**
* A discriminated union for all safety checker configurations.
* Narrow on the `type` field ('external' | 'in-process') to get the concrete
* configuration shape.
*/
export type SafetyCheckerConfig =
| ExternalCheckerConfig
| InProcessCheckerConfig;
export interface PolicyRule {
/**
* The name of the tool this rule applies to.
@@ -41,12 +89,43 @@ export interface PolicyRule {
priority?: number;
}
/**
* Associates a safety checker with the tool calls it should inspect.
*
* PolicyEngine.check matches these rules against a tool call the same way it
* matches policy rules, and runs the matching checkers only when the policy
* decision is not DENY and a checker runner is available.
*/
export interface SafetyCheckerRule {
/**
* The name of the tool this rule applies to.
* If undefined, the rule applies to all tools.
*/
toolName?: string;
/**
* Pattern to match against tool arguments.
* Can be used for more fine-grained control.
*/
argsPattern?: RegExp;
/**
* Priority of this checker. Higher numbers run first.
* Default is 0.
*/
priority?: number;
/**
* Specifies an external or built-in safety checker to execute for
* additional validation of a tool call.
*/
checker: SafetyCheckerConfig;
}
export interface PolicyEngineConfig {
/**
* List of policy rules to apply.
*/
rules?: PolicyRule[];
/**
* List of safety checkers to apply.
*/
checkers?: SafetyCheckerRule[];
/**
* Default decision when no rules match.
* Defaults to ASK_USER.
+244
View File
@@ -0,0 +1,244 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import * as fs from 'node:fs/promises';
import * as os from 'node:os';
import * as path from 'node:path';
import { AllowedPathChecker } from './built-in.js';
import type { SafetyCheckInput } from './protocol.js';
import { SafetyCheckDecision } from './protocol.js';
import type { FunctionCall } from '@google/genai';
describe('AllowedPathChecker', () => {
let checker: AllowedPathChecker;
let testRootDir: string;
let mockCwd: string;
let mockWorkspaces: string[];
beforeEach(async () => {
checker = new AllowedPathChecker();
testRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'safety-test-'));
mockCwd = path.join(testRootDir, 'home', 'user', 'project');
await fs.mkdir(mockCwd, { recursive: true });
mockWorkspaces = [
mockCwd,
path.join(testRootDir, 'home', 'user', 'other-project'),
];
await fs.mkdir(mockWorkspaces[1], { recursive: true });
});
afterEach(async () => {
await fs.rm(testRootDir, { recursive: true, force: true });
});
const createInput = (
toolArgs: Record<string, unknown>,
config?: Record<string, unknown>,
): SafetyCheckInput => ({
protocolVersion: '1.0.0',
toolCall: {
name: 'test_tool',
args: toolArgs,
} as unknown as FunctionCall,
context: {
environment: {
cwd: mockCwd,
workspaces: mockWorkspaces,
},
},
config,
});
it('should allow paths within CWD', async () => {
const filePath = path.join(mockCwd, 'file.txt');
await fs.writeFile(filePath, 'test content');
const input = createInput({
path: filePath,
});
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
});
it('should allow paths within workspace roots', async () => {
const filePath = path.join(mockWorkspaces[1], 'data.json');
await fs.writeFile(filePath, 'test content');
const input = createInput({
path: filePath,
});
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
});
it('should deny paths outside allowed areas', async () => {
const outsidePath = path.join(testRootDir, 'etc', 'passwd');
await fs.mkdir(path.dirname(outsidePath), { recursive: true });
await fs.writeFile(outsidePath, 'secret');
const input = createInput({ path: outsidePath });
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.DENY);
expect(result.reason).toContain('outside of the allowed workspace');
});
it('should deny paths using ../ to escape', async () => {
const secretPath = path.join(testRootDir, 'home', 'user', 'secret.txt');
await fs.writeFile(secretPath, 'secret');
const input = createInput({
path: path.join(mockCwd, '..', 'secret.txt'),
});
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.DENY);
});
it('should check multiple path arguments', async () => {
const passwdPath = path.join(testRootDir, 'etc', 'passwd');
await fs.mkdir(path.dirname(passwdPath), { recursive: true });
await fs.writeFile(passwdPath, 'secret');
const srcPath = path.join(mockCwd, 'src.txt');
await fs.writeFile(srcPath, 'source content');
const input = createInput({
source: srcPath,
destination: passwdPath,
});
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.DENY);
expect(result.reason).toContain(passwdPath);
});
it('should handle non-existent paths gracefully if they are inside allowed dir', async () => {
const input = createInput({
path: path.join(mockCwd, 'new-file.txt'),
});
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
});
it('should deny access if path contains a symlink pointing outside allowed directories', async () => {
const symlinkPath = path.join(mockCwd, 'symlink');
const targetPath = path.join(testRootDir, 'etc', 'passwd');
await fs.mkdir(path.dirname(targetPath), { recursive: true });
await fs.writeFile(targetPath, 'secret');
// Create symlink: mockCwd/symlink -> targetPath
await fs.symlink(targetPath, symlinkPath);
const input = createInput({ path: symlinkPath });
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.DENY);
expect(result.reason).toContain(
'outside of the allowed workspace directories',
);
});
it('should allow access if path contains a symlink pointing INSIDE allowed directories', async () => {
const symlinkPath = path.join(mockCwd, 'symlink-inside');
const realFilePath = path.join(mockCwd, 'real-file');
await fs.writeFile(realFilePath, 'real content');
// Create symlink: mockCwd/symlink-inside -> mockCwd/real-file
await fs.symlink(realFilePath, symlinkPath);
const input = createInput({ path: symlinkPath });
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
});
it('should check explicitly included arguments', async () => {
const outsidePath = path.join(testRootDir, 'etc', 'passwd');
await fs.mkdir(path.dirname(outsidePath), { recursive: true });
await fs.writeFile(outsidePath, 'secret');
const input = createInput(
{ custom_arg: outsidePath },
{ included_args: ['custom_arg'] },
);
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.DENY);
expect(result.reason).toContain('outside of the allowed workspace');
});
it('should skip explicitly excluded arguments', async () => {
const outsidePath = path.join(testRootDir, 'etc', 'passwd');
await fs.mkdir(path.dirname(outsidePath), { recursive: true });
await fs.writeFile(outsidePath, 'secret');
// Normally 'path' would be checked, but we exclude it
const input = createInput(
{ path: outsidePath },
{ excluded_args: ['path'] },
);
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
});
it('should handle both included and excluded arguments', async () => {
const outsidePath = path.join(testRootDir, 'etc', 'passwd');
await fs.mkdir(path.dirname(outsidePath), { recursive: true });
await fs.writeFile(outsidePath, 'secret');
const input = createInput(
{
path: outsidePath, // Excluded
custom_arg: outsidePath, // Included
},
{
excluded_args: ['path'],
included_args: ['custom_arg'],
},
);
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.DENY);
// Should be denied because of custom_arg, not path
expect(result.reason).toContain(outsidePath);
});
it('should check nested path arguments', async () => {
const outsidePath = path.join(testRootDir, 'etc', 'passwd');
await fs.mkdir(path.dirname(outsidePath), { recursive: true });
await fs.writeFile(outsidePath, 'secret');
const input = createInput({
nested: {
path: outsidePath,
},
});
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.DENY);
expect(result.reason).toContain(outsidePath);
expect(result.reason).toContain('nested.path');
});
it('should support dot notation for included_args', async () => {
const outsidePath = path.join(testRootDir, 'etc', 'passwd');
await fs.mkdir(path.dirname(outsidePath), { recursive: true });
await fs.writeFile(outsidePath, 'secret');
const input = createInput(
{
nested: {
custom: outsidePath,
},
},
{ included_args: ['nested.custom'] },
);
const result = await checker.check(input);
expect(result.decision).toBe(SafetyCheckDecision.DENY);
expect(result.reason).toContain(outsidePath);
expect(result.reason).toContain('nested.custom');
});
it('should support dot notation for excluded_args', async () => {
  // Materialise the target file so the checker resolves a real location.
  const forbiddenDir = path.join(testRootDir, 'etc');
  const forbiddenFile = path.join(forbiddenDir, 'passwd');
  await fs.mkdir(forbiddenDir, { recursive: true });
  await fs.writeFile(forbiddenFile, 'secret');

  // `nested.path` would normally be checked because the key contains "path";
  // excluding the dotted key must skip it entirely.
  const toolArgs = { nested: { path: forbiddenFile } };
  const checkerConfig = { excluded_args: ['nested.path'] };
  const { decision } = await checker.check(
    createInput(toolArgs, checkerConfig),
  );

  expect(decision).toBe(SafetyCheckDecision.ALLOW);
});
});
+154
View File
@@ -0,0 +1,154 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as path from 'node:path';
import * as fs from 'node:fs';
import type { SafetyCheckInput, SafetyCheckResult } from './protocol.js';
import { SafetyCheckDecision } from './protocol.js';
import type { AllowedPathConfig } from '../policy/types.js';
/**
 * Contract implemented by every safety checker that runs inside the CLI
 * process (as opposed to being spawned as an external executable).
 */
export interface InProcessChecker {
/**
 * Evaluates a single tool call.
 *
 * @param input The tool call together with its context and optional
 *   checker-specific configuration.
 * @returns A promise resolving to the checker's decision.
 */
check(input: SafetyCheckInput): Promise<SafetyCheckResult>;
}
/**
 * An in-process checker that validates file paths found in tool arguments.
 *
 * Every string argument that looks like a path (by key name, or because it is
 * listed in `included_args`) must resolve to a location inside the current
 * working directory or one of the workspace directories. Symlinks are
 * canonicalised before comparison. The first offending argument produces a
 * DENY; if none offends, the result is ALLOW.
 */
export class AllowedPathChecker implements InProcessChecker {
  /** Key substrings that mark a string argument as path-like. */
  private static readonly PATH_KEY_FRAGMENTS: readonly string[] = [
    'path',
    'directory',
    'file',
  ];

  /** Exact key names that mark a string argument as path-like. */
  private static readonly PATH_KEY_NAMES: readonly string[] = [
    'source',
    'destination',
  ];

  /**
   * Checks all path-like arguments of the tool call.
   *
   * @param input Tool call, context (`environment.cwd` and
   *   `environment.workspaces` are read) and an optional `AllowedPathConfig`.
   * @returns ALLOW when every collected path is inside an allowed directory,
   *   otherwise DENY with a reason naming the path and the argument.
   */
  async check(input: SafetyCheckInput): Promise<SafetyCheckResult> {
    const { toolCall, context } = input;
    const config = input.config as AllowedPathConfig | undefined;
    const { cwd, workspaces } = context.environment;

    // Find all arguments that look like paths.
    const pathsToCheck = this.collectPathsToCheck(
      toolCall.args,
      config?.included_args ?? [],
      config?.excluded_args ?? [],
    );
    if (pathsToCheck.length === 0) {
      return { decision: SafetyCheckDecision.ALLOW };
    }

    // Canonicalise the allowed roots once (they do not depend on the path
    // being checked). Roots that cannot be resolved can never match.
    const allowedDirs = [cwd, ...workspaces]
      .map((dir) => this.safelyResolvePath(dir, cwd))
      .filter((dir): dir is string => dir !== null);

    for (const { path: p, argName } of pathsToCheck) {
      const resolvedPath = this.safelyResolvePath(p, cwd);
      if (resolvedPath === null) {
        // If the path cannot be resolved, deny it.
        return {
          decision: SafetyCheckDecision.DENY,
          reason: `Cannot resolve path "${p}" in argument "${argName}"`,
        };
      }

      const isAllowed = allowedDirs.some((dir) =>
        this.isPathAllowed(resolvedPath, dir),
      );
      if (!isAllowed) {
        return {
          decision: SafetyCheckDecision.DENY,
          reason: `Path "${p}" in argument "${argName}" is outside of the allowed workspace directories.`,
        };
      }
    }

    return { decision: SafetyCheckDecision.ALLOW };
  }

  /**
   * Resolves `inputPath` against `cwd` and canonicalises the longest existing
   * prefix (following symlinks), re-appending the not-yet-existing remainder.
   *
   * @returns The canonical absolute path, or `null` when resolution fails
   *   (for example a dangling symlink or a symlink loop).
   */
  private safelyResolvePath(inputPath: string, cwd: string): string | null {
    try {
      const resolved = path.resolve(cwd, inputPath);

      // Walk up the directory tree until we find an entry that exists. The
      // loop stops at the filesystem root, where dirname(root) === root.
      let current = resolved;
      while (current && current !== path.dirname(current)) {
        if (this.entryExists(current)) {
          // realpathSync throws for dangling symlinks; the catch below turns
          // that into `null`, which the caller treats as a denial.
          const canonical = fs.realpathSync(current);
          // Re-attach the part of the path that does not exist yet.
          // path.join returns `canonical` unchanged for an empty remainder.
          const relative = path.relative(current, resolved);
          return path.join(canonical, relative);
        }
        current = path.dirname(current);
      }

      // Fallback if nothing below the root exists.
      return resolved;
    } catch (_error) {
      return null;
    }
  }

  /**
   * Reports whether a filesystem entry exists at `candidate` WITHOUT
   * following a final symlink. Unlike `fs.existsSync`, a dangling symlink
   * counts as existing, so it is not mistaken for a new file to be created
   * inside the workspace.
   */
  private entryExists(candidate: string): boolean {
    try {
      fs.lstatSync(candidate);
      return true;
    } catch (_error) {
      return false;
    }
  }

  /**
   * Returns true when `targetPath` is `allowedDir` itself or lies beneath it.
   * Both arguments are expected to be canonical absolute paths.
   */
  private isPathAllowed(targetPath: string, allowedDir: string): boolean {
    const relative = path.relative(allowedDir, targetPath);
    if (relative === '') {
      return true;
    }
    // Only a leading ".." path SEGMENT escapes the directory; an entry whose
    // name merely starts with two dots (e.g. "..cache") is still inside.
    const escapesDir =
      relative === '..' || relative.startsWith(`..${path.sep}`);
    // An absolute result means the two paths share no common root (e.g.
    // different drives on Windows).
    return !escapesDir && !path.isAbsolute(relative);
  }

  /**
   * Recursively collects string arguments that should be treated as paths.
   *
   * A string value is collected when its dotted key is listed in
   * `includedArgs`, or when its own key contains "path", "directory" or
   * "file", or equals "source" or "destination". Keys listed in
   * `excludedArgs` (dot notation) are skipped together with their subtree.
   *
   * @param prefix Dotted key of the enclosing object; empty at the top level.
   */
  private collectPathsToCheck(
    args: unknown,
    includedArgs: string[],
    excludedArgs: string[],
    prefix = '',
  ): Array<{ path: string; argName: string }> {
    const paths: Array<{ path: string; argName: string }> = [];
    if (typeof args !== 'object' || args === null) {
      return paths;
    }

    for (const [key, value] of Object.entries(args)) {
      const fullKey = prefix ? `${prefix}.${key}` : key;
      if (excludedArgs.includes(fullKey)) {
        continue;
      }

      if (typeof value === 'string') {
        const looksLikePath =
          includedArgs.includes(fullKey) ||
          AllowedPathChecker.PATH_KEY_FRAGMENTS.some((fragment) =>
            key.includes(fragment),
          ) ||
          AllowedPathChecker.PATH_KEY_NAMES.includes(key);
        if (looksLikePath) {
          paths.push({ path: value, argName: fullKey });
        }
      } else if (value !== null && typeof value === 'object') {
        paths.push(
          ...this.collectPathsToCheck(
            value,
            includedArgs,
            excludedArgs,
            fullKey,
          ),
        );
      }
    }
    return paths;
  }
}
@@ -0,0 +1,303 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { spawn } from 'node:child_process';
import { CheckerRunner } from './checker-runner.js';
import { ContextBuilder } from './context-builder.js';
import { CheckerRegistry } from './registry.js';
import {
type InProcessCheckerConfig,
InProcessCheckerType,
} from '../policy/types.js';
import type { SafetyCheckResult } from './protocol.js';
import { SafetyCheckDecision } from './protocol.js';
import type { Config } from '../config/config.js';
// Mock dependencies
vi.mock('./registry.js');
vi.mock('./context-builder.js');
vi.mock('node:child_process');
describe('CheckerRunner', () => {
  let runner: CheckerRunner;
  let mockContextBuilder: ContextBuilder;
  let mockRegistry: CheckerRegistry;

  const mockToolCall = { name: 'test_tool', args: {} };
  const mockInProcessConfig: InProcessCheckerConfig = {
    type: 'in-process',
    name: InProcessCheckerType.ALLOWED_PATH,
  };

  beforeEach(() => {
    // Both classes are auto-mocked via vi.mock() above.
    mockContextBuilder = new ContextBuilder({} as Config);
    mockRegistry = new CheckerRegistry('/mock/dist');
    CheckerRegistry.prototype.resolveInProcess = vi.fn();
    runner = new CheckerRunner(mockContextBuilder, mockRegistry, {
      checkersPath: '/mock/dist',
    });
  });

  afterEach(() => {
    vi.restoreAllMocks();
    // Always switch back to real timers here rather than at the end of each
    // test body: a failing assertion would otherwise skip the restore and
    // leak fake timers into the following tests.
    vi.useRealTimers();
  });

  it('should run in-process checker successfully', async () => {
    const mockResult: SafetyCheckResult = {
      decision: SafetyCheckDecision.ALLOW,
    };
    const mockChecker = {
      check: vi.fn().mockResolvedValue(mockResult),
    };
    vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker);
    vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({
      environment: { cwd: '/tmp', workspaces: [] },
    });

    const result = await runner.runChecker(mockToolCall, mockInProcessConfig);

    expect(result).toEqual(mockResult);
    expect(mockRegistry.resolveInProcess).toHaveBeenCalledWith(
      InProcessCheckerType.ALLOWED_PATH,
    );
    expect(mockChecker.check).toHaveBeenCalled();
  });

  it('should handle in-process checker errors', async () => {
    const mockChecker = {
      check: vi.fn().mockRejectedValue(new Error('Checker failed')),
    };
    vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker);
    vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({
      environment: { cwd: '/tmp', workspaces: [] },
    });

    const result = await runner.runChecker(mockToolCall, mockInProcessConfig);

    // A rejected checker is reported as a denial, not rethrown.
    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain('Failed to run in-process checker');
    expect(result.reason).toContain('Checker failed');
  });

  it('should respect timeout for in-process checkers', async () => {
    vi.useFakeTimers();
    const mockChecker = {
      check: vi.fn().mockImplementation(async () => {
        await new Promise((resolve) => setTimeout(resolve, 6000)); // Longer than default 5s timeout
        return { decision: SafetyCheckDecision.ALLOW };
      }),
    };
    vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker);
    vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({
      environment: { cwd: '/tmp', workspaces: [] },
    });

    // Start the run first so its timers are registered, then jump past the
    // 5s timeout before awaiting the outcome.
    const runPromise = runner.runChecker(mockToolCall, mockInProcessConfig);
    vi.advanceTimersByTime(5001);
    const result = await runPromise;

    expect(result.decision).toBe(SafetyCheckDecision.DENY);
    expect(result.reason).toContain('timed out');
  });

  it('should use minimal context when requested', async () => {
    const configWithContext: InProcessCheckerConfig = {
      ...mockInProcessConfig,
      required_context: ['environment'],
    };
    const mockChecker = {
      check: vi.fn().mockResolvedValue({ decision: SafetyCheckDecision.ALLOW }),
    };
    vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker);
    vi.mocked(mockContextBuilder.buildMinimalContext).mockReturnValue({
      environment: { cwd: '/tmp', workspaces: [] },
    });

    await runner.runChecker(mockToolCall, configWithContext);

    expect(mockContextBuilder.buildMinimalContext).toHaveBeenCalledWith([
      'environment',
    ]);
    expect(mockContextBuilder.buildFullContext).not.toHaveBeenCalled();
  });

  // Note: the configuration travels in the `config` field of the checker
  // input, next to (not inside) `toolCall`.
  it('should pass config to in-process checker via toolCall', async () => {
    const mockConfig = { included_args: ['foo'] };
    const configWithConfig: InProcessCheckerConfig = {
      ...mockInProcessConfig,
      config: mockConfig,
    };
    const mockResult: SafetyCheckResult = {
      decision: SafetyCheckDecision.ALLOW,
    };
    const mockChecker = {
      check: vi.fn().mockResolvedValue(mockResult),
    };
    vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker);
    vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({
      environment: { cwd: '/tmp', workspaces: [] },
    });

    await runner.runChecker(mockToolCall, configWithConfig);

    expect(mockChecker.check).toHaveBeenCalledWith(
      expect.objectContaining({
        toolCall: mockToolCall,
        config: mockConfig,
      }),
    );
  });

  describe('External Checkers', () => {
    const mockExternalConfig = {
      type: 'external' as const,
      name: 'python-checker',
    };

    it('should spawn external checker directly', async () => {
      const mockCheckerPath = '/mock/dist/python-checker';
      vi.mocked(mockRegistry.resolveExternal).mockReturnValue(mockCheckerPath);
      vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({
        environment: { cwd: '/tmp', workspaces: [] },
      });

      // Emits the checker's JSON result as soon as a 'data' listener is added.
      const mockStdout = {
        on: vi.fn().mockImplementation((event, callback) => {
          if (event === 'data') {
            callback(
              Buffer.from(
                JSON.stringify({ decision: SafetyCheckDecision.ALLOW }),
              ),
            );
          }
        }),
      };
      const mockChildProcess = {
        stdin: { write: vi.fn(), end: vi.fn() },
        stdout: mockStdout,
        stderr: { on: vi.fn() },
        on: vi.fn().mockImplementation((event, callback) => {
          if (event === 'close') {
            // Defer the close callback slightly to allow stdout 'data' to be registered
            setTimeout(() => callback(0), 0);
          }
        }),
        kill: vi.fn(),
      };
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      vi.mocked(spawn).mockReturnValue(mockChildProcess as any);

      const result = await runner.runChecker(mockToolCall, mockExternalConfig);

      expect(result.decision).toBe(SafetyCheckDecision.ALLOW);
      expect(spawn).toHaveBeenCalledWith(
        mockCheckerPath,
        [],
        expect.anything(),
      );
    });

    it('should include checker name in timeout error message', async () => {
      vi.useFakeTimers();
      const mockCheckerPath = '/mock/dist/python-checker';
      vi.mocked(mockRegistry.resolveExternal).mockReturnValue(mockCheckerPath);
      vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({
        environment: { cwd: '/tmp', workspaces: [] },
      });

      const mockChildProcess = {
        stdin: { write: vi.fn(), end: vi.fn() },
        stdout: { on: vi.fn() },
        stderr: { on: vi.fn() },
        on: vi.fn(), // Never calls 'close'
        kill: vi.fn(),
      };
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      vi.mocked(spawn).mockReturnValue(mockChildProcess as any);

      const runPromise = runner.runChecker(mockToolCall, mockExternalConfig);
      vi.advanceTimersByTime(5001);
      const result = await runPromise;

      expect(result.decision).toBe(SafetyCheckDecision.DENY);
      expect(result.reason).toContain(
        'Safety checker "python-checker" timed out',
      );
    });

    it('should send SIGKILL if process ignores SIGTERM', async () => {
      vi.useFakeTimers();
      const mockCheckerPath = '/mock/dist/python-checker';
      vi.mocked(mockRegistry.resolveExternal).mockReturnValue(mockCheckerPath);
      vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({
        environment: { cwd: '/tmp', workspaces: [] },
      });

      const mockChildProcess = {
        stdin: { write: vi.fn(), end: vi.fn() },
        stdout: { on: vi.fn() },
        stderr: { on: vi.fn() },
        on: vi.fn(), // Never calls 'close' automatically
        kill: vi.fn(),
      };
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      vi.mocked(spawn).mockReturnValue(mockChildProcess as any);

      const runPromise = runner.runChecker(mockToolCall, mockExternalConfig);

      // Trigger main timeout
      vi.advanceTimersByTime(5001);
      // Should have sent SIGTERM
      expect(mockChildProcess.kill).toHaveBeenCalledWith('SIGTERM');

      // Advance past cleanup timeout (5000ms)
      vi.advanceTimersByTime(5000);
      // Should have sent SIGKILL
      expect(mockChildProcess.kill).toHaveBeenCalledWith('SIGKILL');

      // Clean up promise
      await runPromise;
    });

    it('should include checker name in non-zero exit code error message', async () => {
      const mockCheckerPath = '/mock/dist/python-checker';
      vi.mocked(mockRegistry.resolveExternal).mockReturnValue(mockCheckerPath);
      vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({
        environment: { cwd: '/tmp', workspaces: [] },
      });

      const mockChildProcess = {
        stdin: { write: vi.fn(), end: vi.fn() },
        stdout: { on: vi.fn() },
        stderr: { on: vi.fn() },
        on: vi.fn().mockImplementation((event, callback) => {
          if (event === 'close') {
            callback(1); // Exit code 1
          }
        }),
        kill: vi.fn(),
      };
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      vi.mocked(spawn).mockReturnValue(mockChildProcess as any);

      const result = await runner.runChecker(mockToolCall, mockExternalConfig);

      expect(result.decision).toBe(SafetyCheckDecision.DENY);
      expect(result.reason).toContain(
        'Safety checker "python-checker" exited with code 1',
      );
    });
  });
});
+297
View File
@@ -0,0 +1,297 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { spawn } from 'node:child_process';
import type { FunctionCall } from '@google/genai';
import type {
SafetyCheckerConfig,
InProcessCheckerConfig,
ExternalCheckerConfig,
} from '../policy/types.js';
import type { SafetyCheckInput, SafetyCheckResult } from './protocol.js';
import { SafetyCheckDecision } from './protocol.js';
import type { CheckerRegistry } from './registry.js';
import type { ContextBuilder } from './context-builder.js';
/**
 * Configuration for the checker runner.
 */
export interface CheckerRunnerConfig {
/**
 * Maximum time (in milliseconds) to wait for a checker to complete.
 * Applies to both in-process and external checkers.
 * Default: 5000 (5 seconds)
 */
timeout?: number;
/**
 * Path to the directory containing external checkers.
 * NOTE(review): the CheckerRunner constructor only reads `timeout`; this
 * value appears to be consumed elsewhere (the registry is constructed with
 * its own path) — confirm it is still needed here.
 */
checkersPath: string;
}
/**
 * Service for executing safety checkers, either in-process or as external
 * child processes.
 *
 * Every failure mode (exception, timeout, non-zero exit code, unparseable
 * output, spawn failure) is reported as a DENY result rather than thrown, so
 * `runChecker` never rejects.
 */
export class CheckerRunner {
  private static readonly DEFAULT_TIMEOUT = 5000; // 5 seconds
  /** Time a timed-out external checker gets between SIGTERM and SIGKILL. */
  private static readonly KILL_GRACE_PERIOD = 5000; // 5 seconds
  /** Protocol version stamped on every checker input. */
  private static readonly PROTOCOL_VERSION = '1.0.0';

  private readonly registry: CheckerRegistry;
  private readonly contextBuilder: ContextBuilder;
  private readonly timeout: number;

  /**
   * @param contextBuilder Supplies the context object sent to checkers.
   * @param registry Resolves checker names to implementations/executables.
   * @param config Runner options; only `timeout` is read here.
   */
  constructor(
    contextBuilder: ContextBuilder,
    registry: CheckerRegistry,
    config: CheckerRunnerConfig,
  ) {
    this.contextBuilder = contextBuilder;
    this.registry = registry;
    this.timeout = config.timeout ?? CheckerRunner.DEFAULT_TIMEOUT;
  }

  /**
   * Runs a safety checker and returns the result.
   *
   * @param toolCall The tool call to validate.
   * @param checkerConfig Which checker to run and how to parameterise it.
   * @returns The checker's decision, or a DENY describing what went wrong.
   */
  async runChecker(
    toolCall: FunctionCall,
    checkerConfig: SafetyCheckerConfig,
  ): Promise<SafetyCheckResult> {
    if (checkerConfig.type === 'in-process') {
      return this.runInProcessChecker(toolCall, checkerConfig);
    }
    return this.runExternalChecker(toolCall, checkerConfig);
  }

  /** Runs a checker that lives inside this process, guarded by the timeout. */
  private async runInProcessChecker(
    toolCall: FunctionCall,
    checkerConfig: InProcessCheckerConfig,
  ): Promise<SafetyCheckResult> {
    try {
      const checker = this.registry.resolveInProcess(checkerConfig.name);
      const input = this.buildInput(toolCall, checkerConfig);
      // The timeout only stops us from waiting on a slow asynchronous
      // checker; it does not cancel the checker's work, and it cannot fire
      // while a checker blocks the event loop synchronously.
      return await this.executeWithTimeout(checker.check(input));
    } catch (error) {
      return {
        decision: SafetyCheckDecision.DENY,
        reason: `Failed to run in-process checker "${checkerConfig.name}": ${
          error instanceof Error ? error.message : String(error)
        }`,
      };
    }
  }

  /** Resolves and runs an external checker executable. */
  private async runExternalChecker(
    toolCall: FunctionCall,
    checkerConfig: ExternalCheckerConfig,
  ): Promise<SafetyCheckResult> {
    try {
      // Resolve the checker executable path
      const checkerPath = this.registry.resolveExternal(checkerConfig.name);
      const input = this.buildInput(toolCall, checkerConfig);
      // Run the checker process
      return await this.executeCheckerProcess(
        checkerPath,
        input,
        checkerConfig.name,
      );
    } catch (error) {
      // If anything goes wrong, deny the operation
      return {
        decision: SafetyCheckDecision.DENY,
        reason: `Failed to run safety checker "${checkerConfig.name}": ${
          error instanceof Error ? error.message : String(error)
        }`,
      };
    }
  }

  /**
   * Builds the input payload shared by in-process and external checkers.
   * A minimal context is used when the checker declares `required_context`,
   * otherwise the full context.
   */
  private buildInput(
    toolCall: FunctionCall,
    checkerConfig: SafetyCheckerConfig,
  ): SafetyCheckInput {
    const context = checkerConfig.required_context
      ? this.contextBuilder.buildMinimalContext(checkerConfig.required_context)
      : this.contextBuilder.buildFullContext();
    return {
      protocolVersion: CheckerRunner.PROTOCOL_VERSION,
      toolCall,
      context,
      config: checkerConfig.config,
    };
  }

  /**
   * Executes an external checker process and handles its lifecycle.
   *
   * The input is written to the child's stdin as JSON and the result is read
   * from stdout as JSON once the process closes. The returned promise always
   * resolves (never rejects); the first outcome wins.
   *
   * NOTE(review): no 'error' listener is attached to `child.stdin`; if the
   * checker exits before reading its input the stream may emit an error
   * asynchronously — confirm whether that needs handling.
   */
  private executeCheckerProcess(
    checkerPath: string,
    input: SafetyCheckInput,
    checkerName: string,
  ): Promise<SafetyCheckResult> {
    return new Promise((resolve) => {
      const child = spawn(checkerPath, [], {
        stdio: ['pipe', 'pipe', 'pipe'],
      });

      // Raw chunks are kept and decoded once at the end: decoding each chunk
      // separately corrupts multi-byte UTF-8 characters that happen to be
      // split across a chunk boundary.
      const stdoutChunks: Buffer[] = [];
      const stderrChunks: Buffer[] = [];
      let killed = false;
      let exited = false;

      // Set up timeout
      const timeoutHandle = setTimeout(() => {
        killed = true;
        child.kill('SIGTERM');
        resolve({
          decision: SafetyCheckDecision.DENY,
          reason: `Safety checker "${checkerName}" timed out after ${this.timeout}ms`,
        });
        // Fallback: if the process ignores SIGTERM, force kill it. The timer
        // is unref'ed so it never keeps the CLI alive on its own.
        setTimeout(() => {
          if (!exited) {
            child.kill('SIGKILL');
          }
        }, CheckerRunner.KILL_GRACE_PERIOD).unref();
      }, this.timeout);

      // Collect output
      const collectInto = (sink: Buffer[]) => (data: Buffer | string) => {
        sink.push(typeof data === 'string' ? Buffer.from(data, 'utf8') : data);
      };
      if (child.stdout) {
        child.stdout.on('data', collectInto(stdoutChunks));
      }
      if (child.stderr) {
        child.stderr.on('data', collectInto(stderrChunks));
      }

      // Handle process completion
      child.on('close', (code: number | null) => {
        exited = true;
        clearTimeout(timeoutHandle);

        // If we already killed it due to timeout, don't process the result
        if (killed) {
          return;
        }

        // Non-zero exit code is a failure
        if (code !== 0) {
          const stderr = Buffer.concat(stderrChunks).toString('utf8');
          resolve({
            decision: SafetyCheckDecision.DENY,
            reason: `Safety checker "${checkerName}" exited with code ${code}${
              stderr ? `: ${stderr}` : ''
            }`,
          });
          return;
        }

        // Try to parse the output
        try {
          const stdout = Buffer.concat(stdoutChunks).toString('utf8');
          const result: SafetyCheckResult = JSON.parse(stdout);
          // Validate the result structure
          if (
            !result.decision ||
            !Object.values(SafetyCheckDecision).includes(result.decision)
          ) {
            throw new Error(
              'Invalid result: missing or invalid "decision" field',
            );
          }
          resolve(result);
        } catch (parseError) {
          resolve({
            decision: SafetyCheckDecision.DENY,
            reason: `Failed to parse output from safety checker "${checkerName}": ${
              parseError instanceof Error
                ? parseError.message
                : String(parseError)
            }`,
          });
        }
      });

      // Handle process errors (e.g. the executable could not be spawned)
      child.on('error', (error: Error) => {
        clearTimeout(timeoutHandle);
        if (!killed) {
          resolve({
            decision: SafetyCheckDecision.DENY,
            reason: `Failed to spawn safety checker "${checkerName}": ${error.message}`,
          });
        }
      });

      // Send input to the checker
      try {
        if (child.stdin) {
          child.stdin.write(JSON.stringify(input));
          child.stdin.end();
        } else {
          throw new Error('Failed to open stdin for checker process');
        }
      } catch (writeError) {
        clearTimeout(timeoutHandle);
        child.kill();
        resolve({
          decision: SafetyCheckDecision.DENY,
          reason: `Failed to write to stdin of safety checker "${checkerName}": ${
            writeError instanceof Error
              ? writeError.message
              : String(writeError)
          }`,
        });
      }
    });
  }

  /**
   * Settles with the outcome of `promise`, or rejects with a timeout error
   * if it has not settled within `this.timeout` milliseconds. The timer is
   * cleared as soon as `promise` settles.
   */
  private executeWithTimeout<T>(promise: Promise<T>): Promise<T> {
    return new Promise((resolve, reject) => {
      const timeoutHandle = setTimeout(() => {
        reject(new Error(`Checker timed out after ${this.timeout}ms`));
      }, this.timeout);
      promise
        .then(resolve)
        .catch(reject)
        .finally(() => {
          clearTimeout(timeoutHandle);
        });
    });
  }
}
@@ -0,0 +1,56 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ContextBuilder } from './context-builder.js';
import type { Config } from '../config/config.js';
import type { ConversationTurn } from './protocol.js';
describe('ContextBuilder', () => {
  // Fixed fixtures shared by every test.
  const mockCwd = '/home/user/project';
  const mockWorkspaces = ['/home/user/project'];
  const mockHistory: ConversationTurn[] = [
    { user: { text: 'hello' }, model: { text: 'hi' } },
  ];

  let mockConfig: Config;
  let contextBuilder: ContextBuilder;

  beforeEach(() => {
    // Pin the working directory so assertions do not depend on where the
    // test runner was started.
    vi.spyOn(process, 'cwd').mockReturnValue(mockCwd);

    const getDirectories = vi.fn().mockReturnValue(mockWorkspaces);
    const getWorkspaceContext = vi.fn().mockReturnValue({ getDirectories });
    mockConfig = {
      getWorkspaceContext,
      apiKey: 'secret-api-key',
      somePublicConfig: 'public-value',
      nested: {
        secretToken: 'hidden',
        public: 'visible',
      },
    } as unknown as Config;

    contextBuilder = new ContextBuilder(mockConfig, mockHistory);
  });

  it('should build full context with all fields', () => {
    const { environment, history } = contextBuilder.buildFullContext();

    expect(environment.cwd).toBe(mockCwd);
    expect(environment.workspaces).toEqual(mockWorkspaces);
    expect(history?.turns).toEqual(mockHistory);
  });

  it('should build minimal context with only required keys', () => {
    const minimal = contextBuilder.buildMinimalContext(['environment']);

    expect(minimal).toHaveProperty('environment');
    // Anything that was not requested must be absent.
    expect(minimal).not.toHaveProperty('config');
    expect(minimal).not.toHaveProperty('history');
  });

  it('should handle missing history', () => {
    // No history argument: the builder falls back to an empty list of turns.
    contextBuilder = new ContextBuilder(mockConfig);

    const { history } = contextBuilder.buildFullContext();

    expect(history?.turns).toEqual([]);
  });
});
@@ -0,0 +1,54 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { SafetyCheckInput, ConversationTurn } from './protocol.js';
import type { Config } from '../config/config.js';
/**
 * Builds the `context` portion of a safety checker's input from the CLI
 * configuration and the conversation history.
 *
 * The context exposes only the working directory, the workspace directories
 * and the conversation turns; nothing else from `Config` is copied into it.
 */
export class ContextBuilder {
  /**
   * @param config Source of the workspace directories.
   * @param conversationHistory Turns to expose to checkers; defaults to none.
   */
  constructor(
    private readonly config: Config,
    private readonly conversationHistory: ConversationTurn[] = [],
  ) {}

  /**
   * Builds the complete context: the process working directory, the
   * configured workspace directories and the conversation history.
   */
  buildFullContext(): SafetyCheckInput['context'] {
    const cwd = process.cwd();
    const workspaceContext = this.config.getWorkspaceContext();
    const workspaces = workspaceContext.getDirectories() as string[];

    const environment = { cwd, workspaces };
    const history = { turns: this.conversationHistory };
    return { environment, history };
  }

  /**
   * Builds a context that contains only the requested top-level keys.
   *
   * @param requiredKeys Keys of the full context to keep. Keys that the full
   *   context does not have are ignored.
   * @returns An object holding just those keys; note that `environment` is
   *   absent unless it was requested.
   */
  buildMinimalContext(
    requiredKeys: Array<keyof SafetyCheckInput['context']>,
  ): SafetyCheckInput['context'] {
    const fullContext = this.buildFullContext();
    const picked: Record<string, unknown> = {};

    requiredKeys.forEach((key) => {
      if (key in fullContext) {
        picked[key] = fullContext[key];
      }
    });

    return picked as SafetyCheckInput['context'];
  }
}
+100
View File
@@ -0,0 +1,100 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { FunctionCall } from '@google/genai';
/**
 * Represents a single turn in the conversation between the user and the model.
 * This provides semantic context for why a tool call might be happening.
 */
export interface ConversationTurn {
/** What the user said in this turn. */
user: {
text: string;
};
/** The model's reply: optional text and/or the tool calls it requested. */
model: {
text?: string;
toolCalls?: FunctionCall[];
};
}
/**
 * The data structure handed to a safety checker: serialized as JSON to the
 * stdin of an external checker process, or passed directly to an in-process
 * checker's `check` method.
 */
export interface SafetyCheckInput {
/**
 * The semantic version of the protocol (e.g., "1.0.0"). This allows
 * for introducing breaking changes in the future while maintaining
 * support for older checkers.
 */
protocolVersion: '1.0.0';
/**
 * The specific tool call that is being validated.
 */
toolCall: FunctionCall;
/**
 * A container for all contextual information from the CLI's internal state.
 * By grouping data into categories, we can easily add new context in the
 * future without creating a flat, unmanageable object.
 */
context: {
/**
 * Information about the user's file system and execution environment.
 */
environment: {
// Working directory of the CLI process.
cwd: string;
workspaces: string[]; // A list of user-configured workspace roots
};
/**
 * The recent history of the conversation. This can be used by checkers
 * that need to understand the intent behind a tool call. Optional:
 * checkers must tolerate its absence.
 */
history?: {
turns: ConversationTurn[];
};
};
/**
 * Configuration for the safety checker.
 * This allows checkers to be parameterized (e.g. allowed paths).
 * Typed as `unknown`: each checker is responsible for interpreting it.
 */
config?: unknown;
}
/**
 * The possible decisions a safety checker can make. The string values are
 * the wire format used in a checker's JSON result.
 */
export enum SafetyCheckDecision {
// The tool call may proceed.
ALLOW = 'allow',
// The tool call is blocked; the result carries a reason.
DENY = 'deny',
// Presumably defers the decision to the user — the handling of this value
// is outside this file; the result carries a reason.
ASK_USER = 'ask_user',
}
/**
 * The data structure returned by a safety checker: written as JSON to stdout
 * by an external checker process, or returned from an in-process checker.
 *
 * A discriminated union on `decision`: `reason` is optional for ALLOW and
 * required for DENY and ASK_USER.
 */
export type SafetyCheckResult =
| {
/**
 * The decision made by the safety checker.
 */
decision: SafetyCheckDecision.ALLOW;
/**
 * Optional explanatory message accompanying an allowed tool call.
 */
reason?: string;
}
| {
decision: SafetyCheckDecision.DENY;
/**
 * A message explaining why the tool call was blocked.
 * This will be shown to the user.
 */
reason: string;
}
| {
decision: SafetyCheckDecision.ASK_USER;
/**
 * A message explaining why the checker did not decide on its own.
 */
reason: string;
};
+47
View File
@@ -0,0 +1,47 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach } from 'vitest';
import { CheckerRegistry } from './registry.js';
import { InProcessCheckerType } from '../policy/types.js';
import { AllowedPathChecker } from './built-in.js';
describe('CheckerRegistry', () => {
  const mockCheckersPath = '/mock/checkers/path';
  let registry: CheckerRegistry;

  beforeEach(() => {
    // A fresh registry per test; the path is never touched on disk here.
    registry = new CheckerRegistry(mockCheckersPath);
  });

  it('should resolve built-in in-process checkers', () => {
    const resolved = registry.resolveInProcess(
      InProcessCheckerType.ALLOWED_PATH,
    );

    expect(resolved).toBeInstanceOf(AllowedPathChecker);
  });

  it('should throw for unknown in-process checkers', () => {
    // A syntactically valid name that is simply not registered.
    const resolveUnknown = () => registry.resolveInProcess('unknown-checker');

    expect(resolveUnknown).toThrow(
      'Unknown in-process checker "unknown-checker"',
    );
  });

  it('should validate checker names', () => {
    // Names with characters outside the allowed set, including an attempted
    // path traversal, are rejected before any lookup.
    for (const badName of ['invalid name!', '../escape']) {
      expect(() => registry.resolveInProcess(badName)).toThrow(
        'Invalid checker name',
      );
    }
  });

  it('should throw for unknown external checkers (for now)', () => {
    const resolveExternal = () => registry.resolveExternal('some-external');

    expect(resolveExternal).toThrow('Unknown external checker "some-external"');
  });
});
+83
View File
@@ -0,0 +1,83 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as path from 'node:path';
import * as fs from 'node:fs';
import { type InProcessChecker, AllowedPathChecker } from './built-in.js';
import { InProcessCheckerType } from '../policy/types.js';
/**
 * Resolves safety checker names to something runnable: an executable path
 * for external checkers, or a checker instance for in-process ones.
 */
export class CheckerRegistry {
  /** Built-in external checkers: name -> path relative to `checkersPath`. */
  private static readonly BUILT_IN_EXTERNAL_CHECKERS = new Map<string, string>([
    // No external built-ins for now
  ]);

  /** Built-in in-process checkers: name -> shared checker instance. */
  private static readonly BUILT_IN_IN_PROCESS_CHECKERS = new Map<
    string,
    InProcessChecker
  >([[InProcessCheckerType.ALLOWED_PATH, new AllowedPathChecker()]]);

  // Checker names may contain lowercase letters, digits and hyphens only.
  private static readonly VALID_NAME_PATTERN = /^[a-z0-9-]+$/;

  /**
   * @param checkersPath Directory that built-in external checker paths are
   *   resolved against.
   */
  constructor(private readonly checkersPath: string) {}

  /**
   * Resolves an external checker name to an executable path below
   * `checkersPath`.
   *
   * @throws Error when the name is invalid, when a built-in checker's file
   *   is missing on disk, or when the name is not a known checker.
   */
  resolveExternal(name: string): string {
    if (!CheckerRegistry.isValidCheckerName(name)) {
      throw new Error(
        `Invalid checker name "${name}". Checker names must contain only lowercase letters, numbers, and hyphens.`,
      );
    }

    const relativePath = CheckerRegistry.BUILT_IN_EXTERNAL_CHECKERS.get(name);
    if (!relativePath) {
      // TODO: Phase 5 - Add support for custom external checkers
      throw new Error(`Unknown external checker "${name}".`);
    }

    const fullPath = path.join(this.checkersPath, relativePath);
    if (!fs.existsSync(fullPath)) {
      throw new Error(`Built-in checker "${name}" not found at ${fullPath}`);
    }
    return fullPath;
  }

  /**
   * Resolves an in-process checker name to a checker instance.
   *
   * @throws Error when the name is invalid or not a known checker; the
   *   latter message lists the available names.
   */
  resolveInProcess(name: string): InProcessChecker {
    if (!CheckerRegistry.isValidCheckerName(name)) {
      throw new Error(`Invalid checker name "${name}".`);
    }

    const registered = CheckerRegistry.BUILT_IN_IN_PROCESS_CHECKERS.get(name);
    if (!registered) {
      const available = Array.from(
        CheckerRegistry.BUILT_IN_IN_PROCESS_CHECKERS.keys(),
      ).join(', ');
      throw new Error(
        `Unknown in-process checker "${name}". Available: ${available}`,
      );
    }
    return registered;
  }

  /**
   * A name is valid when it matches the allowed character set and contains
   * no ".." sequence (a second guard against path traversal).
   */
  private static isValidCheckerName(name: string): boolean {
    if (name.includes('..')) {
      return false;
    }
    return CheckerRegistry.VALID_NAME_PATTERN.test(name);
  }

  /** Lists the names of all built-in checkers, external ones first. */
  static getBuiltInCheckers(): string[] {
    const externalNames = CheckerRegistry.BUILT_IN_EXTERNAL_CHECKERS.keys();
    const inProcessNames = CheckerRegistry.BUILT_IN_IN_PROCESS_CHECKERS.keys();
    return [...externalNames, ...inProcessNames];
  }
}
@@ -47,11 +47,13 @@ describe('BaseToolInvocation', () => {
);
let capturedRequest: ToolConfirmationRequest | undefined;
vi.mocked(messageBus.publish).mockImplementation((request: Message) => {
if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
capturedRequest = request;
}
});
vi.mocked(messageBus.publish).mockImplementation(
async (request: Message) => {
if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
capturedRequest = request;
}
},
);
let responseHandler:
| ((response: ToolConfirmationResponse) => void)
@@ -102,11 +104,13 @@ describe('BaseToolInvocation', () => {
);
let capturedRequest: ToolConfirmationRequest | undefined;
vi.mocked(messageBus.publish).mockImplementation((request: Message) => {
if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
capturedRequest = request;
}
});
vi.mocked(messageBus.publish).mockImplementation(
async (request: Message) => {
if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
capturedRequest = request;
}
},
);
// We need to mock subscribe to avoid hanging if we want to await the promise,
// but for this test we just need to check publish.