mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 12:54:07 -07:00
[feat] Extension Reloading - respect updates to exclude tools (#12728)
This commit is contained in:
@@ -752,11 +752,11 @@ describe('mergeMcpServers', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
describe('mergeExcludeTools', () => {
|
describe('mergeExcludeTools', () => {
|
||||||
const defaultExcludes = [
|
const defaultExcludes = new Set([
|
||||||
SHELL_TOOL_NAME,
|
SHELL_TOOL_NAME,
|
||||||
EDIT_TOOL_NAME,
|
EDIT_TOOL_NAME,
|
||||||
WRITE_FILE_TOOL_NAME,
|
WRITE_FILE_TOOL_NAME,
|
||||||
];
|
]);
|
||||||
const originalIsTTY = process.stdin.isTTY;
|
const originalIsTTY = process.stdin.isTTY;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
@@ -799,7 +799,7 @@ describe('mergeExcludeTools', () => {
|
|||||||
argv,
|
argv,
|
||||||
);
|
);
|
||||||
expect(config.getExcludeTools()).toEqual(
|
expect(config.getExcludeTools()).toEqual(
|
||||||
expect.arrayContaining(['tool1', 'tool2', 'tool3', 'tool4', 'tool5']),
|
new Set(['tool1', 'tool2', 'tool3', 'tool4', 'tool5']),
|
||||||
);
|
);
|
||||||
expect(config.getExcludeTools()).toHaveLength(5);
|
expect(config.getExcludeTools()).toHaveLength(5);
|
||||||
});
|
});
|
||||||
@@ -821,7 +821,7 @@ describe('mergeExcludeTools', () => {
|
|||||||
const argv = await parseArguments({} as Settings);
|
const argv = await parseArguments({} as Settings);
|
||||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||||
expect(config.getExcludeTools()).toEqual(
|
expect(config.getExcludeTools()).toEqual(
|
||||||
expect.arrayContaining(['tool1', 'tool2', 'tool3']),
|
new Set(['tool1', 'tool2', 'tool3']),
|
||||||
);
|
);
|
||||||
expect(config.getExcludeTools()).toHaveLength(3);
|
expect(config.getExcludeTools()).toHaveLength(3);
|
||||||
});
|
});
|
||||||
@@ -852,7 +852,7 @@ describe('mergeExcludeTools', () => {
|
|||||||
const argv = await parseArguments({} as Settings);
|
const argv = await parseArguments({} as Settings);
|
||||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||||
expect(config.getExcludeTools()).toEqual(
|
expect(config.getExcludeTools()).toEqual(
|
||||||
expect.arrayContaining(['tool1', 'tool2', 'tool3', 'tool4']),
|
new Set(['tool1', 'tool2', 'tool3', 'tool4']),
|
||||||
);
|
);
|
||||||
expect(config.getExcludeTools()).toHaveLength(4);
|
expect(config.getExcludeTools()).toHaveLength(4);
|
||||||
});
|
});
|
||||||
@@ -863,7 +863,7 @@ describe('mergeExcludeTools', () => {
|
|||||||
process.argv = ['node', 'script.js'];
|
process.argv = ['node', 'script.js'];
|
||||||
const argv = await parseArguments({} as Settings);
|
const argv = await parseArguments({} as Settings);
|
||||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||||
expect(config.getExcludeTools()).toEqual([]);
|
expect(config.getExcludeTools()).toEqual(new Set([]));
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return default excludes when no excludeTools are specified and it is not interactive', async () => {
|
it('should return default excludes when no excludeTools are specified and it is not interactive', async () => {
|
||||||
@@ -881,9 +881,7 @@ describe('mergeExcludeTools', () => {
|
|||||||
const settings: Settings = { tools: { exclude: ['tool1', 'tool2'] } };
|
const settings: Settings = { tools: { exclude: ['tool1', 'tool2'] } };
|
||||||
vi.spyOn(ExtensionManager.prototype, 'getExtensions').mockReturnValue([]);
|
vi.spyOn(ExtensionManager.prototype, 'getExtensions').mockReturnValue([]);
|
||||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||||
expect(config.getExcludeTools()).toEqual(
|
expect(config.getExcludeTools()).toEqual(new Set(['tool1', 'tool2']));
|
||||||
expect.arrayContaining(['tool1', 'tool2']),
|
|
||||||
);
|
|
||||||
expect(config.getExcludeTools()).toHaveLength(2);
|
expect(config.getExcludeTools()).toHaveLength(2);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -903,9 +901,7 @@ describe('mergeExcludeTools', () => {
|
|||||||
process.argv = ['node', 'script.js'];
|
process.argv = ['node', 'script.js'];
|
||||||
const argv = await parseArguments({} as Settings);
|
const argv = await parseArguments({} as Settings);
|
||||||
const config = await loadCliConfig(settings, 'test-session', argv);
|
const config = await loadCliConfig(settings, 'test-session', argv);
|
||||||
expect(config.getExcludeTools()).toEqual(
|
expect(config.getExcludeTools()).toEqual(new Set(['tool1', 'tool2']));
|
||||||
expect.arrayContaining(['tool1', 'tool2']),
|
|
||||||
);
|
|
||||||
expect(config.getExcludeTools()).toHaveLength(2);
|
expect(config.getExcludeTools()).toHaveLength(2);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ describe('handleAtCommand', () => {
|
|||||||
getToolRegistry,
|
getToolRegistry,
|
||||||
getTargetDir: () => testRootDir,
|
getTargetDir: () => testRootDir,
|
||||||
isSandboxed: () => false,
|
isSandboxed: () => false,
|
||||||
|
getExcludeTools: vi.fn(),
|
||||||
getFileService: () => new FileDiscoveryService(testRootDir),
|
getFileService: () => new FileDiscoveryService(testRootDir),
|
||||||
getFileFilteringRespectGitIgnore: () => true,
|
getFileFilteringRespectGitIgnore: () => true,
|
||||||
getFileFilteringRespectGeminiIgnore: () => true,
|
getFileFilteringRespectGeminiIgnore: () => true,
|
||||||
|
|||||||
@@ -872,27 +872,6 @@ describe('Server Config (config.ts)', () => {
|
|||||||
expect(wasShellToolRegistered).toBe(true);
|
expect(wasShellToolRegistered).toBe(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should not register a tool if excludeTools contains the non-minified class name', async () => {
|
|
||||||
const params: ConfigParameters = {
|
|
||||||
...baseParams,
|
|
||||||
coreTools: undefined, // all tools enabled by default
|
|
||||||
excludeTools: ['ShellTool'],
|
|
||||||
};
|
|
||||||
const config = new Config(params);
|
|
||||||
await config.initialize();
|
|
||||||
|
|
||||||
const registerToolMock = (
|
|
||||||
(await vi.importMock('../tools/tool-registry')) as {
|
|
||||||
ToolRegistry: { prototype: { registerTool: Mock } };
|
|
||||||
}
|
|
||||||
).ToolRegistry.prototype.registerTool;
|
|
||||||
|
|
||||||
const wasShellToolRegistered = (
|
|
||||||
registerToolMock as Mock
|
|
||||||
).mock.calls.some((call) => call[0] instanceof vi.mocked(ShellTool));
|
|
||||||
expect(wasShellToolRegistered).toBe(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should register a tool if coreTools contains an argument-specific pattern with the non-minified class name', async () => {
|
it('should register a tool if coreTools contains an argument-specific pattern with the non-minified class name', async () => {
|
||||||
const params: ConfigParameters = {
|
const params: ConfigParameters = {
|
||||||
...baseParams,
|
...baseParams,
|
||||||
|
|||||||
@@ -826,7 +826,7 @@ export class Config {
|
|||||||
*
|
*
|
||||||
* May change over time.
|
* May change over time.
|
||||||
*/
|
*/
|
||||||
getExcludeTools(): string[] | undefined {
|
getExcludeTools(): Set<string> | undefined {
|
||||||
const excludeToolsSet = new Set([...(this.excludeTools ?? [])]);
|
const excludeToolsSet = new Set([...(this.excludeTools ?? [])]);
|
||||||
for (const extension of this.getExtensionLoader().getExtensions()) {
|
for (const extension of this.getExtensionLoader().getExtensions()) {
|
||||||
if (!extension.isActive) {
|
if (!extension.isActive) {
|
||||||
@@ -836,7 +836,7 @@ export class Config {
|
|||||||
excludeToolsSet.add(tool);
|
excludeToolsSet.add(tool);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return [...excludeToolsSet];
|
return excludeToolsSet;
|
||||||
}
|
}
|
||||||
|
|
||||||
getToolDiscoveryCommand(): string | undefined {
|
getToolDiscoveryCommand(): string | undefined {
|
||||||
@@ -1282,7 +1282,6 @@ export class Config {
|
|||||||
const className = ToolClass.name;
|
const className = ToolClass.name;
|
||||||
const toolName = ToolClass.Name || className;
|
const toolName = ToolClass.Name || className;
|
||||||
const coreTools = this.getCoreTools();
|
const coreTools = this.getCoreTools();
|
||||||
const excludeTools = this.getExcludeTools() || [];
|
|
||||||
// On some platforms, the className can be minified to _ClassName.
|
// On some platforms, the className can be minified to _ClassName.
|
||||||
const normalizedClassName = className.replace(/^_+/, '');
|
const normalizedClassName = className.replace(/^_+/, '');
|
||||||
|
|
||||||
@@ -1297,14 +1296,6 @@ export class Config {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const isExcluded = excludeTools.some(
|
|
||||||
(tool) => tool === toolName || tool === normalizedClassName,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (isExcluded) {
|
|
||||||
isEnabled = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isEnabled) {
|
if (isEnabled) {
|
||||||
// Pass message bus to tools when feature flag is enabled
|
// Pass message bus to tools when feature flag is enabled
|
||||||
// This first implementation is only focused on the general case of
|
// This first implementation is only focused on the general case of
|
||||||
@@ -1363,15 +1354,12 @@ export class Config {
|
|||||||
);
|
);
|
||||||
if (definition) {
|
if (definition) {
|
||||||
// We must respect the main allowed/exclude lists for agents too.
|
// We must respect the main allowed/exclude lists for agents too.
|
||||||
const excludeTools = this.getExcludeTools() || [];
|
|
||||||
|
|
||||||
const allowedTools = this.getAllowedTools();
|
const allowedTools = this.getAllowedTools();
|
||||||
|
|
||||||
const isExcluded = excludeTools.includes(definition.name);
|
|
||||||
const isAllowed =
|
const isAllowed =
|
||||||
!allowedTools || allowedTools.includes(definition.name);
|
!allowedTools || allowedTools.includes(definition.name);
|
||||||
|
|
||||||
if (isAllowed && !isExcluded) {
|
if (isAllowed) {
|
||||||
const messageBusEnabled = this.getEnableMessageBusIntegration();
|
const messageBusEnabled = this.getEnableMessageBusIntegration();
|
||||||
const wrapper = new SubagentToolWrapper(
|
const wrapper = new SubagentToolWrapper(
|
||||||
definition,
|
definition,
|
||||||
|
|||||||
@@ -243,6 +243,10 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getFullyQualifiedPrefix(): string {
|
||||||
|
return `${this.serverName}__`;
|
||||||
|
}
|
||||||
|
|
||||||
asFullyQualifiedTool(): DiscoveredMCPTool {
|
asFullyQualifiedTool(): DiscoveredMCPTool {
|
||||||
return new DiscoveredMCPTool(
|
return new DiscoveredMCPTool(
|
||||||
this.mcpTool,
|
this.mcpTool,
|
||||||
@@ -251,7 +255,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
|||||||
this.description,
|
this.description,
|
||||||
this.parameterSchema,
|
this.parameterSchema,
|
||||||
this.trust,
|
this.trust,
|
||||||
`${this.serverName}__${this.serverToolName}`,
|
`${this.getFullyQualifiedPrefix()}${this.serverToolName}`,
|
||||||
this.cliConfig,
|
this.cliConfig,
|
||||||
this.extensionName,
|
this.extensionName,
|
||||||
this.extensionId,
|
this.extensionId,
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ describe('ShellTool', () => {
|
|||||||
getAllowedTools: vi.fn().mockReturnValue([]),
|
getAllowedTools: vi.fn().mockReturnValue([]),
|
||||||
getApprovalMode: vi.fn().mockReturnValue('strict'),
|
getApprovalMode: vi.fn().mockReturnValue('strict'),
|
||||||
getCoreTools: vi.fn().mockReturnValue([]),
|
getCoreTools: vi.fn().mockReturnValue([]),
|
||||||
getExcludeTools: vi.fn().mockReturnValue([]),
|
getExcludeTools: vi.fn().mockReturnValue(new Set([])),
|
||||||
getDebugMode: vi.fn().mockReturnValue(false),
|
getDebugMode: vi.fn().mockReturnValue(false),
|
||||||
getTargetDir: vi.fn().mockReturnValue(tempRootDir),
|
getTargetDir: vi.fn().mockReturnValue(tempRootDir),
|
||||||
getSummarizeToolOutputConfig: vi.fn().mockReturnValue(undefined),
|
getSummarizeToolOutputConfig: vi.fn().mockReturnValue(undefined),
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||||
import type { Mocked } from 'vitest';
|
import type { Mocked, MockInstance } from 'vitest';
|
||||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||||
import type { ConfigParameters } from '../config/config.js';
|
import type { ConfigParameters } from '../config/config.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
@@ -109,6 +109,9 @@ describe('ToolRegistry', () => {
|
|||||||
let config: Config;
|
let config: Config;
|
||||||
let toolRegistry: ToolRegistry;
|
let toolRegistry: ToolRegistry;
|
||||||
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
||||||
|
let mockConfigGetExcludedTools: MockInstance<
|
||||||
|
typeof Config.prototype.getExcludeTools
|
||||||
|
>;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.mocked(fs.existsSync).mockReturnValue(true);
|
vi.mocked(fs.existsSync).mockReturnValue(true);
|
||||||
@@ -132,6 +135,7 @@ describe('ToolRegistry', () => {
|
|||||||
config,
|
config,
|
||||||
'getToolDiscoveryCommand',
|
'getToolDiscoveryCommand',
|
||||||
);
|
);
|
||||||
|
mockConfigGetExcludedTools = vi.spyOn(config, 'getExcludeTools');
|
||||||
vi.spyOn(config, 'getMcpServers');
|
vi.spyOn(config, 'getMcpServers');
|
||||||
vi.spyOn(config, 'getMcpServerCommand');
|
vi.spyOn(config, 'getMcpServerCommand');
|
||||||
vi.spyOn(config, 'getPromptRegistry').mockReturnValue({
|
vi.spyOn(config, 'getPromptRegistry').mockReturnValue({
|
||||||
@@ -152,6 +156,75 @@ describe('ToolRegistry', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('excluded tools', () => {
|
||||||
|
const simpleTool = new MockTool({
|
||||||
|
name: 'tool-a',
|
||||||
|
displayName: 'Tool a',
|
||||||
|
});
|
||||||
|
const excludedTool = new ExcludedMockTool({
|
||||||
|
name: 'excluded-tool-class',
|
||||||
|
displayName: 'Excluded Tool Class',
|
||||||
|
});
|
||||||
|
const mockCallable = {} as CallableTool;
|
||||||
|
const mcpTool = new DiscoveredMCPTool(
|
||||||
|
mockCallable,
|
||||||
|
'mcp-server',
|
||||||
|
'excluded-mcp-tool',
|
||||||
|
'description',
|
||||||
|
{},
|
||||||
|
);
|
||||||
|
const allowedTool = new MockTool({
|
||||||
|
name: 'allowed-tool',
|
||||||
|
displayName: 'Allowed Tool',
|
||||||
|
});
|
||||||
|
|
||||||
|
it.each([
|
||||||
|
{
|
||||||
|
name: 'should match simple names',
|
||||||
|
tools: [simpleTool],
|
||||||
|
excludedTools: ['tool-a'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'should match simple MCP tool names, when qualified or unqualified',
|
||||||
|
tools: [mcpTool, mcpTool.asFullyQualifiedTool()],
|
||||||
|
excludedTools: [mcpTool.name],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'should match qualified MCP tool names when qualified or unqualified',
|
||||||
|
tools: [mcpTool, mcpTool.asFullyQualifiedTool()],
|
||||||
|
excludedTools: [`${mcpTool.getFullyQualifiedPrefix()}${mcpTool.name}`],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'should match class names',
|
||||||
|
tools: [excludedTool],
|
||||||
|
excludedTools: ['ExcludedMockTool'],
|
||||||
|
},
|
||||||
|
])('$name', ({ tools, excludedTools }) => {
|
||||||
|
toolRegistry.registerTool(allowedTool);
|
||||||
|
for (const tool of tools) {
|
||||||
|
toolRegistry.registerTool(tool);
|
||||||
|
}
|
||||||
|
mockConfigGetExcludedTools.mockReturnValue(new Set(excludedTools));
|
||||||
|
|
||||||
|
expect(toolRegistry.getAllTools()).toEqual([allowedTool]);
|
||||||
|
expect(toolRegistry.getAllToolNames()).toEqual([allowedTool.name]);
|
||||||
|
expect(toolRegistry.getFunctionDeclarations()).toEqual(
|
||||||
|
toolRegistry.getFunctionDeclarationsFiltered([allowedTool.name]),
|
||||||
|
);
|
||||||
|
for (const tool of tools) {
|
||||||
|
expect(toolRegistry.getTool(tool.name)).toBeUndefined();
|
||||||
|
expect(
|
||||||
|
toolRegistry.getFunctionDeclarationsFiltered([tool.name]),
|
||||||
|
).toHaveLength(0);
|
||||||
|
if (tool instanceof DiscoveredMCPTool) {
|
||||||
|
expect(toolRegistry.getToolsByServer(tool.serverName)).toHaveLength(
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('getAllTools', () => {
|
describe('getAllTools', () => {
|
||||||
it('should return all registered tools sorted alphabetically by displayName', () => {
|
it('should return all registered tools sorted alphabetically by displayName', () => {
|
||||||
// Register tools with displayNames in non-alphabetical order
|
// Register tools with displayNames in non-alphabetical order
|
||||||
@@ -521,3 +594,12 @@ describe('ToolRegistry', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used for tests that exclude by class name.
|
||||||
|
*/
|
||||||
|
class ExcludedMockTool extends MockTool {
|
||||||
|
constructor(options: ConstructorParameters<typeof MockTool>[0]) {
|
||||||
|
super(options);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -189,7 +189,9 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
|||||||
|
|
||||||
export class ToolRegistry {
|
export class ToolRegistry {
|
||||||
// The tools keyed by tool name as seen by the LLM.
|
// The tools keyed by tool name as seen by the LLM.
|
||||||
private tools: Map<string, AnyDeclarativeTool> = new Map();
|
// This includes tools which are currently not active, use `getActiveTools`
|
||||||
|
// and `isActive` to get only the active tools.
|
||||||
|
private allKnownTools: Map<string, AnyDeclarativeTool> = new Map();
|
||||||
private config: Config;
|
private config: Config;
|
||||||
private messageBus?: MessageBus;
|
private messageBus?: MessageBus;
|
||||||
|
|
||||||
@@ -207,10 +209,14 @@ export class ToolRegistry {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Registers a tool definition.
|
* Registers a tool definition.
|
||||||
|
*
|
||||||
|
* Note that excluded tools are still registered to allow for enabling them
|
||||||
|
* later in the session.
|
||||||
|
*
|
||||||
* @param tool - The tool object containing schema and execution logic.
|
* @param tool - The tool object containing schema and execution logic.
|
||||||
*/
|
*/
|
||||||
registerTool(tool: AnyDeclarativeTool): void {
|
registerTool(tool: AnyDeclarativeTool): void {
|
||||||
if (this.tools.has(tool.name)) {
|
if (this.allKnownTools.has(tool.name)) {
|
||||||
if (tool instanceof DiscoveredMCPTool) {
|
if (tool instanceof DiscoveredMCPTool) {
|
||||||
tool = tool.asFullyQualifiedTool();
|
tool = tool.asFullyQualifiedTool();
|
||||||
} else {
|
} else {
|
||||||
@@ -220,7 +226,7 @@ export class ToolRegistry {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this.tools.set(tool.name, tool);
|
this.allKnownTools.set(tool.name, tool);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -229,7 +235,7 @@ export class ToolRegistry {
|
|||||||
* 2. Discovered tools.
|
* 2. Discovered tools.
|
||||||
* 3. MCP tools ordered by server name.
|
* 3. MCP tools ordered by server name.
|
||||||
*
|
*
|
||||||
* This is a stable sort in that ties preseve existing order.
|
* This is a stable sort in that tries preserve existing order.
|
||||||
*/
|
*/
|
||||||
sortTools(): void {
|
sortTools(): void {
|
||||||
const getPriority = (tool: AnyDeclarativeTool): number => {
|
const getPriority = (tool: AnyDeclarativeTool): number => {
|
||||||
@@ -238,8 +244,8 @@ export class ToolRegistry {
|
|||||||
return 0; // Built-in
|
return 0; // Built-in
|
||||||
};
|
};
|
||||||
|
|
||||||
this.tools = new Map(
|
this.allKnownTools = new Map(
|
||||||
Array.from(this.tools.entries()).sort((a, b) => {
|
Array.from(this.allKnownTools.entries()).sort((a, b) => {
|
||||||
const toolA = a[1];
|
const toolA = a[1];
|
||||||
const toolB = b[1];
|
const toolB = b[1];
|
||||||
const priorityA = getPriority(toolA);
|
const priorityA = getPriority(toolA);
|
||||||
@@ -261,9 +267,9 @@ export class ToolRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private removeDiscoveredTools(): void {
|
private removeDiscoveredTools(): void {
|
||||||
for (const tool of this.tools.values()) {
|
for (const tool of this.allKnownTools.values()) {
|
||||||
if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
|
if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) {
|
||||||
this.tools.delete(tool.name);
|
this.allKnownTools.delete(tool.name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -273,9 +279,9 @@ export class ToolRegistry {
|
|||||||
* @param serverName The name of the server to remove tools from.
|
* @param serverName The name of the server to remove tools from.
|
||||||
*/
|
*/
|
||||||
removeMcpToolsByServer(serverName: string): void {
|
removeMcpToolsByServer(serverName: string): void {
|
||||||
for (const [name, tool] of this.tools.entries()) {
|
for (const [name, tool] of this.allKnownTools.entries()) {
|
||||||
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
|
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
|
||||||
this.tools.delete(name);
|
this.allKnownTools.delete(name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -416,6 +422,45 @@ export class ToolRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @returns All the tools that are not excluded.
|
||||||
|
*/
|
||||||
|
private getActiveTools(): AnyDeclarativeTool[] {
|
||||||
|
const excludedTools = this.config.getExcludeTools() ?? new Set([]);
|
||||||
|
const activeTools: AnyDeclarativeTool[] = [];
|
||||||
|
for (const tool of this.allKnownTools.values()) {
|
||||||
|
if (this.isActiveTool(tool, excludedTools)) {
|
||||||
|
activeTools.push(tool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return activeTools;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param tool
|
||||||
|
* @param excludeTools (optional, helps performance for repeated calls)
|
||||||
|
* @returns Whether or not the `tool` is not excluded.
|
||||||
|
*/
|
||||||
|
private isActiveTool(
|
||||||
|
tool: AnyDeclarativeTool,
|
||||||
|
excludeTools?: Set<string>,
|
||||||
|
): boolean {
|
||||||
|
excludeTools ??= this.config.getExcludeTools() ?? new Set([]);
|
||||||
|
const normalizedClassName = tool.constructor.name.replace(/^_+/, '');
|
||||||
|
const possibleNames = [tool.name, normalizedClassName];
|
||||||
|
if (tool instanceof DiscoveredMCPTool) {
|
||||||
|
// Check both the unqualified and qualified name for MCP tools.
|
||||||
|
if (tool.name.startsWith(tool.getFullyQualifiedPrefix())) {
|
||||||
|
possibleNames.push(
|
||||||
|
tool.name.substring(tool.getFullyQualifiedPrefix().length),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
possibleNames.push(`${tool.getFullyQualifiedPrefix()}${tool.name}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return !possibleNames.some((name) => excludeTools.has(name));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retrieves the list of tool schemas (FunctionDeclaration array).
|
* Retrieves the list of tool schemas (FunctionDeclaration array).
|
||||||
* Extracts the declarations from the ToolListUnion structure.
|
* Extracts the declarations from the ToolListUnion structure.
|
||||||
@@ -424,7 +469,7 @@ export class ToolRegistry {
|
|||||||
*/
|
*/
|
||||||
getFunctionDeclarations(): FunctionDeclaration[] {
|
getFunctionDeclarations(): FunctionDeclaration[] {
|
||||||
const declarations: FunctionDeclaration[] = [];
|
const declarations: FunctionDeclaration[] = [];
|
||||||
this.tools.forEach((tool) => {
|
this.getActiveTools().forEach((tool) => {
|
||||||
declarations.push(tool.schema);
|
declarations.push(tool.schema);
|
||||||
});
|
});
|
||||||
return declarations;
|
return declarations;
|
||||||
@@ -438,8 +483,8 @@ export class ToolRegistry {
|
|||||||
getFunctionDeclarationsFiltered(toolNames: string[]): FunctionDeclaration[] {
|
getFunctionDeclarationsFiltered(toolNames: string[]): FunctionDeclaration[] {
|
||||||
const declarations: FunctionDeclaration[] = [];
|
const declarations: FunctionDeclaration[] = [];
|
||||||
for (const name of toolNames) {
|
for (const name of toolNames) {
|
||||||
const tool = this.tools.get(name);
|
const tool = this.allKnownTools.get(name);
|
||||||
if (tool) {
|
if (tool && this.isActiveTool(tool)) {
|
||||||
declarations.push(tool.schema);
|
declarations.push(tool.schema);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -447,17 +492,18 @@ export class ToolRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns an array of all registered and discovered tool names.
|
* Returns an array of all registered and discovered tool names which are not
|
||||||
|
* excluded via configuration.
|
||||||
*/
|
*/
|
||||||
getAllToolNames(): string[] {
|
getAllToolNames(): string[] {
|
||||||
return Array.from(this.tools.keys());
|
return this.getActiveTools().map((tool) => tool.name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns an array of all registered and discovered tool instances.
|
* Returns an array of all registered and discovered tool instances.
|
||||||
*/
|
*/
|
||||||
getAllTools(): AnyDeclarativeTool[] {
|
getAllTools(): AnyDeclarativeTool[] {
|
||||||
return Array.from(this.tools.values()).sort((a, b) =>
|
return this.getActiveTools().sort((a, b) =>
|
||||||
a.displayName.localeCompare(b.displayName),
|
a.displayName.localeCompare(b.displayName),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -467,7 +513,7 @@ export class ToolRegistry {
|
|||||||
*/
|
*/
|
||||||
getToolsByServer(serverName: string): AnyDeclarativeTool[] {
|
getToolsByServer(serverName: string): AnyDeclarativeTool[] {
|
||||||
const serverTools: AnyDeclarativeTool[] = [];
|
const serverTools: AnyDeclarativeTool[] = [];
|
||||||
for (const tool of this.tools.values()) {
|
for (const tool of this.getActiveTools()) {
|
||||||
if ((tool as DiscoveredMCPTool)?.serverName === serverName) {
|
if ((tool as DiscoveredMCPTool)?.serverName === serverName) {
|
||||||
serverTools.push(tool);
|
serverTools.push(tool);
|
||||||
}
|
}
|
||||||
@@ -479,6 +525,10 @@ export class ToolRegistry {
|
|||||||
* Get the definition of a specific tool.
|
* Get the definition of a specific tool.
|
||||||
*/
|
*/
|
||||||
getTool(name: string): AnyDeclarativeTool | undefined {
|
getTool(name: string): AnyDeclarativeTool | undefined {
|
||||||
return this.tools.get(name);
|
const tool = this.allKnownTools.get(name);
|
||||||
|
if (tool && this.isActiveTool(tool)) {
|
||||||
|
return tool;
|
||||||
|
}
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,10 +4,19 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest';
|
import {
|
||||||
|
describe,
|
||||||
|
expect,
|
||||||
|
it,
|
||||||
|
vi,
|
||||||
|
beforeEach,
|
||||||
|
afterEach,
|
||||||
|
type MockInstance,
|
||||||
|
} from 'vitest';
|
||||||
import { SimpleExtensionLoader } from './extensionLoader.js';
|
import { SimpleExtensionLoader } from './extensionLoader.js';
|
||||||
import type { Config } from '../config/config.js';
|
import type { Config, GeminiCLIExtension } from '../config/config.js';
|
||||||
import { type McpClientManager } from '../tools/mcp-client-manager.js';
|
import { type McpClientManager } from '../tools/mcp-client-manager.js';
|
||||||
|
import type { GeminiClient } from '../core/client.js';
|
||||||
|
|
||||||
const mockRefreshServerHierarchicalMemory = vi.hoisted(() => vi.fn());
|
const mockRefreshServerHierarchicalMemory = vi.hoisted(() => vi.fn());
|
||||||
|
|
||||||
@@ -23,15 +32,20 @@ describe('SimpleExtensionLoader', () => {
|
|||||||
let mockConfig: Config;
|
let mockConfig: Config;
|
||||||
let extensionReloadingEnabled: boolean;
|
let extensionReloadingEnabled: boolean;
|
||||||
let mockMcpClientManager: McpClientManager;
|
let mockMcpClientManager: McpClientManager;
|
||||||
const activeExtension = {
|
let mockGeminiClientSetTools: MockInstance<
|
||||||
|
typeof GeminiClient.prototype.setTools
|
||||||
|
>;
|
||||||
|
|
||||||
|
const activeExtension: GeminiCLIExtension = {
|
||||||
name: 'test-extension',
|
name: 'test-extension',
|
||||||
isActive: true,
|
isActive: true,
|
||||||
version: '1.0.0',
|
version: '1.0.0',
|
||||||
path: '/path/to/extension',
|
path: '/path/to/extension',
|
||||||
contextFiles: [],
|
contextFiles: [],
|
||||||
|
excludeTools: ['some-tool'],
|
||||||
id: '123',
|
id: '123',
|
||||||
};
|
};
|
||||||
const inactiveExtension = {
|
const inactiveExtension: GeminiCLIExtension = {
|
||||||
name: 'test-extension',
|
name: 'test-extension',
|
||||||
isActive: false,
|
isActive: false,
|
||||||
version: '1.0.0',
|
version: '1.0.0',
|
||||||
@@ -46,9 +60,14 @@ describe('SimpleExtensionLoader', () => {
|
|||||||
stopExtension: vi.fn(),
|
stopExtension: vi.fn(),
|
||||||
} as unknown as McpClientManager;
|
} as unknown as McpClientManager;
|
||||||
extensionReloadingEnabled = false;
|
extensionReloadingEnabled = false;
|
||||||
|
mockGeminiClientSetTools = vi.fn();
|
||||||
mockConfig = {
|
mockConfig = {
|
||||||
getMcpClientManager: () => mockMcpClientManager,
|
getMcpClientManager: () => mockMcpClientManager,
|
||||||
getEnableExtensionReloading: () => extensionReloadingEnabled,
|
getEnableExtensionReloading: () => extensionReloadingEnabled,
|
||||||
|
getGeminiClient: vi.fn(() => ({
|
||||||
|
isInitialized: () => true,
|
||||||
|
setTools: mockGeminiClientSetTools,
|
||||||
|
})),
|
||||||
} as unknown as Config;
|
} as unknown as Config;
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -106,11 +125,14 @@ describe('SimpleExtensionLoader', () => {
|
|||||||
mockMcpClientManager.startExtension,
|
mockMcpClientManager.startExtension,
|
||||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||||
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
||||||
|
expect(mockGeminiClientSetTools).toHaveBeenCalledOnce();
|
||||||
} else {
|
} else {
|
||||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||||
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
||||||
|
expect(mockGeminiClientSetTools).not.toHaveBeenCalledOnce();
|
||||||
}
|
}
|
||||||
mockRefreshServerHierarchicalMemory.mockClear();
|
mockRefreshServerHierarchicalMemory.mockClear();
|
||||||
|
mockGeminiClientSetTools.mockClear();
|
||||||
|
|
||||||
await loader.unloadExtension(activeExtension);
|
await loader.unloadExtension(activeExtension);
|
||||||
if (reloadingEnabled) {
|
if (reloadingEnabled) {
|
||||||
@@ -118,9 +140,11 @@ describe('SimpleExtensionLoader', () => {
|
|||||||
mockMcpClientManager.stopExtension,
|
mockMcpClientManager.stopExtension,
|
||||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||||
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
||||||
|
expect(mockGeminiClientSetTools).toHaveBeenCalledOnce();
|
||||||
} else {
|
} else {
|
||||||
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
|
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
|
||||||
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
||||||
|
expect(mockGeminiClientSetTools).not.toHaveBeenCalledOnce();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -73,6 +73,8 @@ export abstract class ExtensionLoader {
|
|||||||
});
|
});
|
||||||
try {
|
try {
|
||||||
await this.config.getMcpClientManager()!.startExtension(extension);
|
await this.config.getMcpClientManager()!.startExtension(extension);
|
||||||
|
await this.maybeRefreshGeminiTools(extension);
|
||||||
|
|
||||||
// Note: Context files are loaded only once all extensions are done
|
// Note: Context files are loaded only once all extensions are done
|
||||||
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
|
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
|
||||||
// below.
|
// below.
|
||||||
@@ -80,9 +82,6 @@ export abstract class ExtensionLoader {
|
|||||||
// TODO: Update custom command updating away from the event based system
|
// TODO: Update custom command updating away from the event based system
|
||||||
// and call directly into a custom command manager here. See the
|
// and call directly into a custom command manager here. See the
|
||||||
// useSlashCommandProcessor hook which responds to events fired here today.
|
// useSlashCommandProcessor hook which responds to events fired here today.
|
||||||
|
|
||||||
// TODO: Move all enablement of extension features here, including at least:
|
|
||||||
// - excluded tool configuration
|
|
||||||
} finally {
|
} finally {
|
||||||
this.startCompletedCount++;
|
this.startCompletedCount++;
|
||||||
this.eventEmitter?.emit('extensionsStarting', {
|
this.eventEmitter?.emit('extensionsStarting', {
|
||||||
@@ -115,6 +114,21 @@ export abstract class ExtensionLoader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Refreshes the gemini tools list if it is initialized and the extension has
|
||||||
|
* any excludeTools settings.
|
||||||
|
*/
|
||||||
|
private async maybeRefreshGeminiTools(
|
||||||
|
extension: GeminiCLIExtension,
|
||||||
|
): Promise<void> {
|
||||||
|
if (extension.excludeTools && extension.excludeTools.length > 0) {
|
||||||
|
const geminiClient = this.config?.getGeminiClient();
|
||||||
|
if (geminiClient?.isInitialized()) {
|
||||||
|
await geminiClient.setTools();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* If extension reloading is enabled and `start` has already been called,
|
* If extension reloading is enabled and `start` has already been called,
|
||||||
* then calls `startExtension` to include all extension features into the
|
* then calls `startExtension` to include all extension features into the
|
||||||
@@ -150,6 +164,8 @@ export abstract class ExtensionLoader {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
await this.config.getMcpClientManager()!.stopExtension(extension);
|
await this.config.getMcpClientManager()!.stopExtension(extension);
|
||||||
|
await this.maybeRefreshGeminiTools(extension);
|
||||||
|
|
||||||
// Note: Context files are loaded only once all extensions are done
|
// Note: Context files are loaded only once all extensions are done
|
||||||
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
|
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
|
||||||
// below.
|
// below.
|
||||||
@@ -157,9 +173,6 @@ export abstract class ExtensionLoader {
|
|||||||
// TODO: Update custom command updating away from the event based system
|
// TODO: Update custom command updating away from the event based system
|
||||||
// and call directly into a custom command manager here. See the
|
// and call directly into a custom command manager here. See the
|
||||||
// useSlashCommandProcessor hook which responds to events fired here today.
|
// useSlashCommandProcessor hook which responds to events fired here today.
|
||||||
|
|
||||||
// TODO: Remove all extension features here, including at least:
|
|
||||||
// - excluded tools
|
|
||||||
} finally {
|
} finally {
|
||||||
this.stopCompletedCount++;
|
this.stopCompletedCount++;
|
||||||
this.eventEmitter?.emit('extensionsStopping', {
|
this.eventEmitter?.emit('extensionsStopping', {
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ beforeEach(() => {
|
|||||||
);
|
);
|
||||||
config = {
|
config = {
|
||||||
getCoreTools: () => [],
|
getCoreTools: () => [],
|
||||||
getExcludeTools: () => [],
|
getExcludeTools: () => new Set([]),
|
||||||
getAllowedTools: () => [],
|
getAllowedTools: () => [],
|
||||||
} as unknown as Config;
|
} as unknown as Config;
|
||||||
});
|
});
|
||||||
@@ -89,7 +89,7 @@ describe('isCommandAllowed', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should block a command if it is in the blocked list', () => {
|
it('should block a command if it is in the blocked list', () => {
|
||||||
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
|
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
|
||||||
const result = isCommandAllowed('badCommand --danger', config);
|
const result = isCommandAllowed('badCommand --danger', config);
|
||||||
expect(result.allowed).toBe(false);
|
expect(result.allowed).toBe(false);
|
||||||
expect(result.reason).toBe(
|
expect(result.reason).toBe(
|
||||||
@@ -99,7 +99,7 @@ describe('isCommandAllowed', () => {
|
|||||||
|
|
||||||
it('should prioritize the blocklist over the allowlist', () => {
|
it('should prioritize the blocklist over the allowlist', () => {
|
||||||
config.getCoreTools = () => ['ShellTool(badCommand --danger)'];
|
config.getCoreTools = () => ['ShellTool(badCommand --danger)'];
|
||||||
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
|
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
|
||||||
const result = isCommandAllowed('badCommand --danger', config);
|
const result = isCommandAllowed('badCommand --danger', config);
|
||||||
expect(result.allowed).toBe(false);
|
expect(result.allowed).toBe(false);
|
||||||
expect(result.reason).toBe(
|
expect(result.reason).toBe(
|
||||||
@@ -114,7 +114,7 @@ describe('isCommandAllowed', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should block any command when a wildcard is in excludeTools', () => {
|
it('should block any command when a wildcard is in excludeTools', () => {
|
||||||
config.getExcludeTools = () => ['run_shell_command'];
|
config.getExcludeTools = () => new Set(['run_shell_command']);
|
||||||
const result = isCommandAllowed('any random command', config);
|
const result = isCommandAllowed('any random command', config);
|
||||||
expect(result.allowed).toBe(false);
|
expect(result.allowed).toBe(false);
|
||||||
expect(result.reason).toBe(
|
expect(result.reason).toBe(
|
||||||
@@ -124,7 +124,7 @@ describe('isCommandAllowed', () => {
|
|||||||
|
|
||||||
it('should block a command on the blocklist even with a wildcard allow', () => {
|
it('should block a command on the blocklist even with a wildcard allow', () => {
|
||||||
config.getCoreTools = () => ['ShellTool'];
|
config.getCoreTools = () => ['ShellTool'];
|
||||||
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
|
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
|
||||||
const result = isCommandAllowed('badCommand --danger', config);
|
const result = isCommandAllowed('badCommand --danger', config);
|
||||||
expect(result.allowed).toBe(false);
|
expect(result.allowed).toBe(false);
|
||||||
expect(result.reason).toBe(
|
expect(result.reason).toBe(
|
||||||
@@ -145,7 +145,7 @@ describe('isCommandAllowed', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should block a chained command if any part is blocked', () => {
|
it('should block a chained command if any part is blocked', () => {
|
||||||
config.getExcludeTools = () => ['run_shell_command(badCommand)'];
|
config.getExcludeTools = () => new Set(['run_shell_command(badCommand)']);
|
||||||
const result = isCommandAllowed(
|
const result = isCommandAllowed(
|
||||||
'echo "hello" && badCommand --danger',
|
'echo "hello" && badCommand --danger',
|
||||||
config,
|
config,
|
||||||
@@ -159,7 +159,7 @@ describe('isCommandAllowed', () => {
|
|||||||
it('should block a command that redefines an allowed function to run an unlisted command', () => {
|
it('should block a command that redefines an allowed function to run an unlisted command', () => {
|
||||||
config.getCoreTools = () => ['run_shell_command(echo)'];
|
config.getCoreTools = () => ['run_shell_command(echo)'];
|
||||||
const result = isCommandAllowed(
|
const result = isCommandAllowed(
|
||||||
'echo () (curl google.com) ; echo Hello Wolrd',
|
'echo () (curl google.com) ; echo Hello World',
|
||||||
config,
|
config,
|
||||||
);
|
);
|
||||||
expect(result.allowed).toBe(false);
|
expect(result.allowed).toBe(false);
|
||||||
@@ -355,7 +355,7 @@ describe('checkCommandPermissions', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should return a detailed failure object for a blocked command', () => {
|
it('should return a detailed failure object for a blocked command', () => {
|
||||||
config.getExcludeTools = () => ['ShellTool(badCommand)'];
|
config.getExcludeTools = () => new Set(['ShellTool(badCommand)']);
|
||||||
const result = checkCommandPermissions('badCommand --danger', config);
|
const result = checkCommandPermissions('badCommand --danger', config);
|
||||||
expect(result).toEqual({
|
expect(result).toEqual({
|
||||||
allAllowed: false,
|
allAllowed: false,
|
||||||
@@ -424,7 +424,7 @@ describe('checkCommandPermissions', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should block a command on the sessionAllowlist if it is also globally blocked', () => {
|
it('should block a command on the sessionAllowlist if it is also globally blocked', () => {
|
||||||
config.getExcludeTools = () => ['run_shell_command(badCommand)'];
|
config.getExcludeTools = () => new Set(['run_shell_command(badCommand)']);
|
||||||
const result = checkCommandPermissions(
|
const result = checkCommandPermissions(
|
||||||
'badCommand --danger',
|
'badCommand --danger',
|
||||||
config,
|
config,
|
||||||
|
|||||||
@@ -605,9 +605,9 @@ export function checkCommandPermissions(
|
|||||||
} as AnyToolInvocation & { params: { command: string } };
|
} as AnyToolInvocation & { params: { command: string } };
|
||||||
|
|
||||||
// 1. Blocklist Check (Highest Priority)
|
// 1. Blocklist Check (Highest Priority)
|
||||||
const excludeTools = config.getExcludeTools() || [];
|
const excludeTools = config.getExcludeTools() || new Set([]);
|
||||||
const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) =>
|
const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) =>
|
||||||
excludeTools.includes(name),
|
excludeTools.has(name),
|
||||||
);
|
);
|
||||||
|
|
||||||
if (isWildcardBlocked) {
|
if (isWildcardBlocked) {
|
||||||
@@ -622,7 +622,9 @@ export function checkCommandPermissions(
|
|||||||
for (const cmd of commandsToValidate) {
|
for (const cmd of commandsToValidate) {
|
||||||
invocation.params['command'] = cmd;
|
invocation.params['command'] = cmd;
|
||||||
if (
|
if (
|
||||||
doesToolInvocationMatch('run_shell_command', invocation, excludeTools)
|
doesToolInvocationMatch('run_shell_command', invocation, [
|
||||||
|
...excludeTools,
|
||||||
|
])
|
||||||
) {
|
) {
|
||||||
return {
|
return {
|
||||||
allAllowed: false,
|
allAllowed: false,
|
||||||
|
|||||||
Reference in New Issue
Block a user