mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-29 06:25:16 -07:00
[feat] Extension Reloading - respect updates to exclude tools (#12728)
This commit is contained in:
@@ -4,10 +4,19 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
describe,
|
||||
expect,
|
||||
it,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type MockInstance,
|
||||
} from 'vitest';
|
||||
import { SimpleExtensionLoader } from './extensionLoader.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { Config, GeminiCLIExtension } from '../config/config.js';
|
||||
import { type McpClientManager } from '../tools/mcp-client-manager.js';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
|
||||
const mockRefreshServerHierarchicalMemory = vi.hoisted(() => vi.fn());
|
||||
|
||||
@@ -23,15 +32,20 @@ describe('SimpleExtensionLoader', () => {
|
||||
let mockConfig: Config;
|
||||
let extensionReloadingEnabled: boolean;
|
||||
let mockMcpClientManager: McpClientManager;
|
||||
const activeExtension = {
|
||||
let mockGeminiClientSetTools: MockInstance<
|
||||
typeof GeminiClient.prototype.setTools
|
||||
>;
|
||||
|
||||
const activeExtension: GeminiCLIExtension = {
|
||||
name: 'test-extension',
|
||||
isActive: true,
|
||||
version: '1.0.0',
|
||||
path: '/path/to/extension',
|
||||
contextFiles: [],
|
||||
excludeTools: ['some-tool'],
|
||||
id: '123',
|
||||
};
|
||||
const inactiveExtension = {
|
||||
const inactiveExtension: GeminiCLIExtension = {
|
||||
name: 'test-extension',
|
||||
isActive: false,
|
||||
version: '1.0.0',
|
||||
@@ -46,9 +60,14 @@ describe('SimpleExtensionLoader', () => {
|
||||
stopExtension: vi.fn(),
|
||||
} as unknown as McpClientManager;
|
||||
extensionReloadingEnabled = false;
|
||||
mockGeminiClientSetTools = vi.fn();
|
||||
mockConfig = {
|
||||
getMcpClientManager: () => mockMcpClientManager,
|
||||
getEnableExtensionReloading: () => extensionReloadingEnabled,
|
||||
getGeminiClient: vi.fn(() => ({
|
||||
isInitialized: () => true,
|
||||
setTools: mockGeminiClientSetTools,
|
||||
})),
|
||||
} as unknown as Config;
|
||||
});
|
||||
|
||||
@@ -106,11 +125,14 @@ describe('SimpleExtensionLoader', () => {
|
||||
mockMcpClientManager.startExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
||||
expect(mockGeminiClientSetTools).toHaveBeenCalledOnce();
|
||||
} else {
|
||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
||||
expect(mockGeminiClientSetTools).not.toHaveBeenCalledOnce();
|
||||
}
|
||||
mockRefreshServerHierarchicalMemory.mockClear();
|
||||
mockGeminiClientSetTools.mockClear();
|
||||
|
||||
await loader.unloadExtension(activeExtension);
|
||||
if (reloadingEnabled) {
|
||||
@@ -118,9 +140,11 @@ describe('SimpleExtensionLoader', () => {
|
||||
mockMcpClientManager.stopExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
||||
expect(mockGeminiClientSetTools).toHaveBeenCalledOnce();
|
||||
} else {
|
||||
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
|
||||
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
||||
expect(mockGeminiClientSetTools).not.toHaveBeenCalledOnce();
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -73,6 +73,8 @@ export abstract class ExtensionLoader {
|
||||
});
|
||||
try {
|
||||
await this.config.getMcpClientManager()!.startExtension(extension);
|
||||
await this.maybeRefreshGeminiTools(extension);
|
||||
|
||||
// Note: Context files are loaded only once all extensions are done
|
||||
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
|
||||
// below.
|
||||
@@ -80,9 +82,6 @@ export abstract class ExtensionLoader {
|
||||
// TODO: Update custom command updating away from the event based system
|
||||
// and call directly into a custom command manager here. See the
|
||||
// useSlashCommandProcessor hook which responds to events fired here today.
|
||||
|
||||
// TODO: Move all enablement of extension features here, including at least:
|
||||
// - excluded tool configuration
|
||||
} finally {
|
||||
this.startCompletedCount++;
|
||||
this.eventEmitter?.emit('extensionsStarting', {
|
||||
@@ -115,6 +114,21 @@ export abstract class ExtensionLoader {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the gemini tools list if it is initialized and the extension has
|
||||
* any excludeTools settings.
|
||||
*/
|
||||
private async maybeRefreshGeminiTools(
|
||||
extension: GeminiCLIExtension,
|
||||
): Promise<void> {
|
||||
if (extension.excludeTools && extension.excludeTools.length > 0) {
|
||||
const geminiClient = this.config?.getGeminiClient();
|
||||
if (geminiClient?.isInitialized()) {
|
||||
await geminiClient.setTools();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* If extension reloading is enabled and `start` has already been called,
|
||||
* then calls `startExtension` to include all extension features into the
|
||||
@@ -150,6 +164,8 @@ export abstract class ExtensionLoader {
|
||||
|
||||
try {
|
||||
await this.config.getMcpClientManager()!.stopExtension(extension);
|
||||
await this.maybeRefreshGeminiTools(extension);
|
||||
|
||||
// Note: Context files are loaded only once all extensions are done
|
||||
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
|
||||
// below.
|
||||
@@ -157,9 +173,6 @@ export abstract class ExtensionLoader {
|
||||
// TODO: Update custom command updating away from the event based system
|
||||
// and call directly into a custom command manager here. See the
|
||||
// useSlashCommandProcessor hook which responds to events fired here today.
|
||||
|
||||
// TODO: Remove all extension features here, including at least:
|
||||
// - excluded tools
|
||||
} finally {
|
||||
this.stopCompletedCount++;
|
||||
this.eventEmitter?.emit('extensionsStopping', {
|
||||
|
||||
@@ -58,7 +58,7 @@ beforeEach(() => {
|
||||
);
|
||||
config = {
|
||||
getCoreTools: () => [],
|
||||
getExcludeTools: () => [],
|
||||
getExcludeTools: () => new Set([]),
|
||||
getAllowedTools: () => [],
|
||||
} as unknown as Config;
|
||||
});
|
||||
@@ -89,7 +89,7 @@ describe('isCommandAllowed', () => {
|
||||
});
|
||||
|
||||
it('should block a command if it is in the blocked list', () => {
|
||||
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
|
||||
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
|
||||
const result = isCommandAllowed('badCommand --danger', config);
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
@@ -99,7 +99,7 @@ describe('isCommandAllowed', () => {
|
||||
|
||||
it('should prioritize the blocklist over the allowlist', () => {
|
||||
config.getCoreTools = () => ['ShellTool(badCommand --danger)'];
|
||||
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
|
||||
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
|
||||
const result = isCommandAllowed('badCommand --danger', config);
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
@@ -114,7 +114,7 @@ describe('isCommandAllowed', () => {
|
||||
});
|
||||
|
||||
it('should block any command when a wildcard is in excludeTools', () => {
|
||||
config.getExcludeTools = () => ['run_shell_command'];
|
||||
config.getExcludeTools = () => new Set(['run_shell_command']);
|
||||
const result = isCommandAllowed('any random command', config);
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
@@ -124,7 +124,7 @@ describe('isCommandAllowed', () => {
|
||||
|
||||
it('should block a command on the blocklist even with a wildcard allow', () => {
|
||||
config.getCoreTools = () => ['ShellTool'];
|
||||
config.getExcludeTools = () => ['ShellTool(badCommand --danger)'];
|
||||
config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']);
|
||||
const result = isCommandAllowed('badCommand --danger', config);
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.reason).toBe(
|
||||
@@ -145,7 +145,7 @@ describe('isCommandAllowed', () => {
|
||||
});
|
||||
|
||||
it('should block a chained command if any part is blocked', () => {
|
||||
config.getExcludeTools = () => ['run_shell_command(badCommand)'];
|
||||
config.getExcludeTools = () => new Set(['run_shell_command(badCommand)']);
|
||||
const result = isCommandAllowed(
|
||||
'echo "hello" && badCommand --danger',
|
||||
config,
|
||||
@@ -159,7 +159,7 @@ describe('isCommandAllowed', () => {
|
||||
it('should block a command that redefines an allowed function to run an unlisted command', () => {
|
||||
config.getCoreTools = () => ['run_shell_command(echo)'];
|
||||
const result = isCommandAllowed(
|
||||
'echo () (curl google.com) ; echo Hello Wolrd',
|
||||
'echo () (curl google.com) ; echo Hello World',
|
||||
config,
|
||||
);
|
||||
expect(result.allowed).toBe(false);
|
||||
@@ -355,7 +355,7 @@ describe('checkCommandPermissions', () => {
|
||||
});
|
||||
|
||||
it('should return a detailed failure object for a blocked command', () => {
|
||||
config.getExcludeTools = () => ['ShellTool(badCommand)'];
|
||||
config.getExcludeTools = () => new Set(['ShellTool(badCommand)']);
|
||||
const result = checkCommandPermissions('badCommand --danger', config);
|
||||
expect(result).toEqual({
|
||||
allAllowed: false,
|
||||
@@ -424,7 +424,7 @@ describe('checkCommandPermissions', () => {
|
||||
});
|
||||
|
||||
it('should block a command on the sessionAllowlist if it is also globally blocked', () => {
|
||||
config.getExcludeTools = () => ['run_shell_command(badCommand)'];
|
||||
config.getExcludeTools = () => new Set(['run_shell_command(badCommand)']);
|
||||
const result = checkCommandPermissions(
|
||||
'badCommand --danger',
|
||||
config,
|
||||
|
||||
@@ -605,9 +605,9 @@ export function checkCommandPermissions(
|
||||
} as AnyToolInvocation & { params: { command: string } };
|
||||
|
||||
// 1. Blocklist Check (Highest Priority)
|
||||
const excludeTools = config.getExcludeTools() || [];
|
||||
const excludeTools = config.getExcludeTools() || new Set([]);
|
||||
const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) =>
|
||||
excludeTools.includes(name),
|
||||
excludeTools.has(name),
|
||||
);
|
||||
|
||||
if (isWildcardBlocked) {
|
||||
@@ -622,7 +622,9 @@ export function checkCommandPermissions(
|
||||
for (const cmd of commandsToValidate) {
|
||||
invocation.params['command'] = cmd;
|
||||
if (
|
||||
doesToolInvocationMatch('run_shell_command', invocation, excludeTools)
|
||||
doesToolInvocationMatch('run_shell_command', invocation, [
|
||||
...excludeTools,
|
||||
])
|
||||
) {
|
||||
return {
|
||||
allAllowed: false,
|
||||
|
||||
Reference in New Issue
Block a user