Add support for sensitive keychain-stored per-extension settings (#11953)

This commit is contained in:
christine betts
2025-10-28 14:48:50 -04:00
committed by GitHub
parent 7a238bd938
commit 7e987113a2
22 changed files with 706 additions and 212 deletions
+2 -2
View File
@@ -190,8 +190,8 @@ Each object in the array should have the following properties:
- `description`: A description of the setting and what it's used for. - `description`: A description of the setting and what it's used for.
- `envVar`: The name of the environment variable that the setting will be stored - `envVar`: The name of the environment variable that the setting will be stored
as. as.
- `sensitive`: Optional boolean. If true, obfuscates the input the user provides
**Example** and stores the secret in keychain storage. **Example**
```json ```json
{ {
@@ -17,7 +17,7 @@ interface DisableArgs {
scope?: string; scope?: string;
} }
export function handleDisable(args: DisableArgs) { export async function handleDisable(args: DisableArgs) {
const workspaceDir = process.cwd(); const workspaceDir = process.cwd();
const extensionManager = new ExtensionManager({ const extensionManager = new ExtensionManager({
workspaceDir, workspaceDir,
@@ -25,13 +25,16 @@ export function handleDisable(args: DisableArgs) {
requestSetting: promptForSetting, requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged, settings: loadSettings(workspaceDir).merged,
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
try { try {
if (args.scope?.toLowerCase() === 'workspace') { if (args.scope?.toLowerCase() === 'workspace') {
extensionManager.disableExtension(args.name, SettingScope.Workspace); await extensionManager.disableExtension(
args.name,
SettingScope.Workspace,
);
} else { } else {
extensionManager.disableExtension(args.name, SettingScope.User); await extensionManager.disableExtension(args.name, SettingScope.User);
} }
debugLogger.log( debugLogger.log(
`Extension "${args.name}" successfully disabled for scope "${args.scope}".`, `Extension "${args.name}" successfully disabled for scope "${args.scope}".`,
@@ -20,7 +20,7 @@ interface EnableArgs {
scope?: string; scope?: string;
} }
export function handleEnable(args: EnableArgs) { export async function handleEnable(args: EnableArgs) {
const workingDir = process.cwd(); const workingDir = process.cwd();
const extensionManager = new ExtensionManager({ const extensionManager = new ExtensionManager({
workspaceDir: workingDir, workspaceDir: workingDir,
@@ -28,7 +28,7 @@ export function handleEnable(args: EnableArgs) {
requestSetting: promptForSetting, requestSetting: promptForSetting,
settings: loadSettings(workingDir).merged, settings: loadSettings(workingDir).merged,
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
try { try {
if (args.scope?.toLowerCase() === 'workspace') { if (args.scope?.toLowerCase() === 'workspace') {
@@ -76,7 +76,7 @@ export async function handleInstall(args: InstallArgs) {
requestSetting: promptForSetting, requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged, settings: loadSettings(workspaceDir).merged,
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
const name = const name =
await extensionManager.installOrUpdateExtension(installMetadata); await extensionManager.installOrUpdateExtension(installMetadata);
debugLogger.log(`Extension "${name}" installed successfully and enabled.`); debugLogger.log(`Extension "${name}" installed successfully and enabled.`);
+1 -1
View File
@@ -33,7 +33,7 @@ export async function handleLink(args: InstallArgs) {
requestSetting: promptForSetting, requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged, settings: loadSettings(workspaceDir).merged,
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
const extensionName = const extensionName =
await extensionManager.installOrUpdateExtension(installMetadata); await extensionManager.installOrUpdateExtension(installMetadata);
debugLogger.log( debugLogger.log(
+1 -1
View File
@@ -21,7 +21,7 @@ export async function handleList() {
requestSetting: promptForSetting, requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged, settings: loadSettings(workspaceDir).merged,
}); });
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
if (extensions.length === 0) { if (extensions.length === 0) {
debugLogger.log('No extensions installed.'); debugLogger.log('No extensions installed.');
return; return;
@@ -25,7 +25,7 @@ export async function handleUninstall(args: UninstallArgs) {
requestSetting: promptForSetting, requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged, settings: loadSettings(workspaceDir).merged,
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.uninstallExtension(args.name, false); await extensionManager.uninstallExtension(args.name, false);
debugLogger.log(`Extension "${args.name}" successfully uninstalled.`); debugLogger.log(`Extension "${args.name}" successfully uninstalled.`);
} catch (error) { } catch (error) {
@@ -37,7 +37,7 @@ export async function handleUpdate(args: UpdateArgs) {
settings: loadSettings(workspaceDir).merged, settings: loadSettings(workspaceDir).merged,
}); });
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
if (args.name) { if (args.name) {
try { try {
const extension = extensions.find( const extension = extensions.find(
+1 -1
View File
@@ -33,7 +33,7 @@ async function getMcpServersFromConfig(): Promise<
requestConsent: requestConsentNonInteractive, requestConsent: requestConsentNonInteractive,
requestSetting: promptForSetting, requestSetting: promptForSetting,
}); });
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
const mcpServers = { ...(settings.merged.mcpServers || {}) }; const mcpServers = { ...(settings.merged.mcpServers || {}) };
for (const extension of extensions) { for (const extension of extensions) {
Object.entries(extension.mcpServers || {}).forEach(([key, server]) => { Object.entries(extension.mcpServers || {}).forEach(([key, server]) => {
+1 -1
View File
@@ -423,7 +423,7 @@ export async function loadCliConfig(
workspaceDir: cwd, workspaceDir: cwd,
enabledExtensionOverrides: argv.extensions, enabledExtensionOverrides: argv.extensions,
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
// Call the (now wrapper) loadHierarchicalGeminiMemory which calls the server's version // Call the (now wrapper) loadHierarchicalGeminiMemory which calls the server's version
const { memoryContent, fileCount, filePaths } = const { memoryContent, fileCount, filePaths } =
+26 -12
View File
@@ -133,7 +133,7 @@ export class ExtensionManager implements ExtensionLoader {
const isUpdate = !!previousExtensionConfig; const isUpdate = !!previousExtensionConfig;
let newExtensionConfig: ExtensionConfig | null = null; let newExtensionConfig: ExtensionConfig | null = null;
let localSourcePath: string | undefined; let localSourcePath: string | undefined;
let extension: GeminiCLIExtension; let extension: GeminiCLIExtension | null;
try { try {
if (!isWorkspaceTrusted(this.settings).isTrusted) { if (!isWorkspaceTrusted(this.settings).isTrusted) {
throw new Error( throw new Error(
@@ -243,12 +243,16 @@ export class ExtensionManager implements ExtensionLoader {
this.requestConsent, this.requestConsent,
previousExtensionConfig, previousExtensionConfig,
); );
const extensionId = getExtensionId(newExtensionConfig, installMetadata);
const extensionStorage = new ExtensionStorage(newExtensionName); const destinationPath = new ExtensionStorage(
const destinationPath = extensionStorage.getExtensionDir(); newExtensionName,
).getExtensionDir();
let previousSettings: Record<string, string> | undefined; let previousSettings: Record<string, string> | undefined;
if (isUpdate) { if (isUpdate) {
previousSettings = getEnvContents(extensionStorage); previousSettings = await getEnvContents(
previousExtensionConfig,
extensionId,
);
await this.uninstallExtension(newExtensionName, isUpdate); await this.uninstallExtension(newExtensionName, isUpdate);
} }
@@ -257,6 +261,7 @@ export class ExtensionManager implements ExtensionLoader {
if (isUpdate) { if (isUpdate) {
await maybePromptForSettings( await maybePromptForSettings(
newExtensionConfig, newExtensionConfig,
extensionId,
this.requestSetting, this.requestSetting,
previousExtensionConfig, previousExtensionConfig,
previousSettings, previousSettings,
@@ -264,6 +269,7 @@ export class ExtensionManager implements ExtensionLoader {
} else { } else {
await maybePromptForSettings( await maybePromptForSettings(
newExtensionConfig, newExtensionConfig,
extensionId,
this.requestSetting, this.requestSetting,
); );
} }
@@ -286,7 +292,10 @@ export class ExtensionManager implements ExtensionLoader {
// TODO: Gracefully handle this call failing, we should back up the old // TODO: Gracefully handle this call failing, we should back up the old
// extension prior to overwriting it and then restore it. // extension prior to overwriting it and then restore it.
extension = this.loadExtension(destinationPath)!; extension = await this.loadExtension(destinationPath)!;
if (!extension) {
throw new Error(`Extension not found`);
}
if (isUpdate) { if (isUpdate) {
logExtensionUpdateEvent( logExtensionUpdateEvent(
this.telemetryConfig, this.telemetryConfig,
@@ -401,7 +410,7 @@ export class ExtensionManager implements ExtensionLoader {
this.eventEmitter.emit('extensionUninstalled', { extension }); this.eventEmitter.emit('extensionUninstalled', { extension });
} }
loadExtensions(): GeminiCLIExtension[] { async loadExtensions(): Promise<GeminiCLIExtension[]> {
if (this.loadedExtensions) { if (this.loadedExtensions) {
throw new Error('Extensions already loaded, only load extensions once.'); throw new Error('Extensions already loaded, only load extensions once.');
} }
@@ -413,12 +422,14 @@ export class ExtensionManager implements ExtensionLoader {
for (const subdir of fs.readdirSync(extensionsDir)) { for (const subdir of fs.readdirSync(extensionsDir)) {
const extensionDir = path.join(extensionsDir, subdir); const extensionDir = path.join(extensionsDir, subdir);
this.loadExtension(extensionDir); await this.loadExtension(extensionDir);
} }
return this.loadedExtensions; return this.loadedExtensions;
} }
private loadExtension(extensionDir: string): GeminiCLIExtension | null { private async loadExtension(
extensionDir: string,
): Promise<GeminiCLIExtension | null> {
this.loadedExtensions ??= []; this.loadedExtensions ??= [];
if (!fs.statSync(extensionDir).isDirectory()) { if (!fs.statSync(extensionDir).isDirectory()) {
return null; return null;
@@ -441,7 +452,10 @@ export class ExtensionManager implements ExtensionLoader {
); );
} }
const customEnv = getEnvContents(new ExtensionStorage(config.name)); const customEnv = await getEnvContents(
config,
getExtensionId(config, installMetadata),
);
config = resolveEnvVarsInObject(config, customEnv); config = resolveEnvVarsInObject(config, customEnv);
if (config.mcpServers) { if (config.mcpServers) {
@@ -573,7 +587,7 @@ export class ExtensionManager implements ExtensionLoader {
return output; return output;
} }
disableExtension(name: string, scope: SettingScope) { async disableExtension(name: string, scope: SettingScope) {
if ( if (
scope === SettingScope.System || scope === SettingScope.System ||
scope === SettingScope.SystemDefaults scope === SettingScope.SystemDefaults
@@ -598,7 +612,7 @@ export class ExtensionManager implements ExtensionLoader {
this.eventEmitter.emit('extensionDisabled', { extension }); this.eventEmitter.emit('extensionDisabled', { extension });
} }
enableExtension(name: string, scope: SettingScope) { async enableExtension(name: string, scope: SettingScope) {
if ( if (
scope === SettingScope.System || scope === SettingScope.System ||
scope === SettingScope.SystemDefaults scope === SettingScope.SystemDefaults
+146 -106
View File
@@ -13,6 +13,7 @@ import {
ExtensionUninstallEvent, ExtensionUninstallEvent,
ExtensionDisableEvent, ExtensionDisableEvent,
ExtensionEnableEvent, ExtensionEnableEvent,
KeychainTokenStorage,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { loadSettings, SettingScope } from './settings.js'; import { loadSettings, SettingScope } from './settings.js';
import { isWorkspaceTrusted } from './trustedFolders.js'; import { isWorkspaceTrusted } from './trustedFolders.js';
@@ -96,6 +97,13 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
ExtensionInstallEvent: vi.fn(), ExtensionInstallEvent: vi.fn(),
ExtensionUninstallEvent: vi.fn(), ExtensionUninstallEvent: vi.fn(),
ExtensionDisableEvent: vi.fn(), ExtensionDisableEvent: vi.fn(),
KeychainTokenStorage: vi.fn().mockImplementation(() => ({
getSecret: vi.fn(),
setSecret: vi.fn(),
deleteSecret: vi.fn(),
listSecrets: vi.fn(),
isAvailable: vi.fn().mockResolvedValue(true),
})),
}; };
}); });
@@ -107,6 +115,14 @@ vi.mock('child_process', async (importOriginal) => {
}; };
}); });
interface MockKeychainStorage {
getSecret: ReturnType<typeof vi.fn>;
setSecret: ReturnType<typeof vi.fn>;
deleteSecret: ReturnType<typeof vi.fn>;
listSecrets: ReturnType<typeof vi.fn>;
isAvailable: ReturnType<typeof vi.fn>;
}
describe('extension tests', () => { describe('extension tests', () => {
let tempHomeDir: string; let tempHomeDir: string;
let tempWorkspaceDir: string; let tempWorkspaceDir: string;
@@ -116,8 +132,32 @@ describe('extension tests', () => {
let mockPromptForSettings: MockedFunction< let mockPromptForSettings: MockedFunction<
(setting: ExtensionSetting) => Promise<string> (setting: ExtensionSetting) => Promise<string>
>; >;
let mockKeychainStorage: MockKeychainStorage;
let keychainData: Record<string, string>;
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks();
keychainData = {};
mockKeychainStorage = {
getSecret: vi
.fn()
.mockImplementation(async (key: string) => keychainData[key] || null),
setSecret: vi
.fn()
.mockImplementation(async (key: string, value: string) => {
keychainData[key] = value;
}),
deleteSecret: vi.fn().mockImplementation(async (key: string) => {
delete keychainData[key];
}),
listSecrets: vi
.fn()
.mockImplementation(async () => Object.keys(keychainData)),
isAvailable: vi.fn().mockResolvedValue(true),
};
(
KeychainTokenStorage as unknown as ReturnType<typeof vi.fn>
).mockImplementation(() => mockKeychainStorage);
tempHomeDir = fs.mkdtempSync( tempHomeDir = fs.mkdtempSync(
path.join(os.tmpdir(), 'gemini-cli-test-home-'), path.join(os.tmpdir(), 'gemini-cli-test-home-'),
); );
@@ -151,7 +191,7 @@ describe('extension tests', () => {
}); });
describe('loadExtensions', () => { describe('loadExtensions', () => {
it('should include extension path in loaded extension', () => { it('should include extension path in loaded extension', async () => {
const extensionDir = path.join(userExtensionsDir, 'test-extension'); const extensionDir = path.join(userExtensionsDir, 'test-extension');
fs.mkdirSync(extensionDir, { recursive: true }); fs.mkdirSync(extensionDir, { recursive: true });
@@ -161,13 +201,13 @@ describe('extension tests', () => {
version: '1.0.0', version: '1.0.0',
}); });
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1); expect(extensions).toHaveLength(1);
expect(extensions[0].path).toBe(extensionDir); expect(extensions[0].path).toBe(extensionDir);
expect(extensions[0].name).toBe('test-extension'); expect(extensions[0].name).toBe('test-extension');
}); });
it('should load context file path when GEMINI.md is present', () => { it('should load context file path when GEMINI.md is present', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'ext1', name: 'ext1',
@@ -180,7 +220,7 @@ describe('extension tests', () => {
version: '2.0.0', version: '2.0.0',
}); });
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(2); expect(extensions).toHaveLength(2);
const ext1 = extensions.find((e) => e.name === 'ext1'); const ext1 = extensions.find((e) => e.name === 'ext1');
@@ -191,7 +231,7 @@ describe('extension tests', () => {
expect(ext2?.contextFiles).toEqual([]); expect(ext2?.contextFiles).toEqual([]);
}); });
it('should load context file path from the extension config', () => { it('should load context file path from the extension config', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'ext1', name: 'ext1',
@@ -200,7 +240,7 @@ describe('extension tests', () => {
contextFileName: 'my-context-file.md', contextFileName: 'my-context-file.md',
}); });
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1); expect(extensions).toHaveLength(1);
const ext1 = extensions.find((e) => e.name === 'ext1'); const ext1 = extensions.find((e) => e.name === 'ext1');
@@ -209,7 +249,7 @@ describe('extension tests', () => {
]); ]);
}); });
it('should annotate disabled extensions', () => { it('should annotate disabled extensions', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'disabled-extension', name: 'disabled-extension',
@@ -220,8 +260,8 @@ describe('extension tests', () => {
name: 'enabled-extension', name: 'enabled-extension',
version: '2.0.0', version: '2.0.0',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
extensionManager.disableExtension( await extensionManager.disableExtension(
'disabled-extension', 'disabled-extension',
SettingScope.User, SettingScope.User,
); );
@@ -233,7 +273,7 @@ describe('extension tests', () => {
expect(extensions[1].isActive).toBe(true); expect(extensions[1].isActive).toBe(true);
}); });
it('should hydrate variables', () => { it('should hydrate variables', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'test-extension', name: 'test-extension',
@@ -247,7 +287,7 @@ describe('extension tests', () => {
}, },
}); });
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1); expect(extensions).toHaveLength(1);
const expectedCwd = path.join( const expectedCwd = path.join(
userExtensionsDir, userExtensionsDir,
@@ -266,7 +306,7 @@ describe('extension tests', () => {
}); });
fs.writeFileSync(path.join(sourceExtDir, 'context.md'), 'linked context'); fs.writeFileSync(path.join(sourceExtDir, 'context.md'), 'linked context');
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
const extension = await extensionManager.installOrUpdateExtension({ const extension = await extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
type: 'link', type: 'link',
@@ -303,7 +343,7 @@ describe('extension tests', () => {
}, },
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
type: 'link', type: 'link',
@@ -319,7 +359,7 @@ describe('extension tests', () => {
]); ]);
}); });
it('should resolve environment variables in extension configuration', () => { it('should resolve environment variables in extension configuration', async () => {
process.env['TEST_API_KEY'] = 'test-api-key-123'; process.env['TEST_API_KEY'] = 'test-api-key-123';
process.env['TEST_DB_URL'] = 'postgresql://localhost:5432/testdb'; process.env['TEST_DB_URL'] = 'postgresql://localhost:5432/testdb';
@@ -352,7 +392,7 @@ describe('extension tests', () => {
}; };
fs.writeFileSync(configPath, JSON.stringify(extensionConfig)); fs.writeFileSync(configPath, JSON.stringify(extensionConfig));
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1); expect(extensions).toHaveLength(1);
const extension = extensions[0]; const extension = extensions[0];
@@ -373,7 +413,7 @@ describe('extension tests', () => {
} }
}); });
it('should resolve environment variables from an extension .env file', () => { it('should resolve environment variables from an extension .env file', async () => {
const extDir = createExtension({ const extDir = createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'test-extension', name: 'test-extension',
@@ -388,12 +428,19 @@ describe('extension tests', () => {
}, },
}, },
}, },
settings: [
{
name: 'My API Key',
description: 'API key for testing.',
envVar: 'MY_API_KEY',
},
],
}); });
const envFilePath = path.join(extDir, '.env'); const envFilePath = path.join(extDir, '.env');
fs.writeFileSync(envFilePath, 'MY_API_KEY=test-key-from-file\n'); fs.writeFileSync(envFilePath, 'MY_API_KEY=test-key-from-file\n');
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1); expect(extensions).toHaveLength(1);
const extension = extensions[0]; const extension = extensions[0];
@@ -403,7 +450,7 @@ describe('extension tests', () => {
expect(serverConfig.env!['STATIC_VALUE']).toBe('no-substitution'); expect(serverConfig.env!['STATIC_VALUE']).toBe('no-substitution');
}); });
it('should handle missing environment variables gracefully', () => { it('should handle missing environment variables gracefully', async () => {
const userExtensionsDir = path.join( const userExtensionsDir = path.join(
tempHomeDir, tempHomeDir,
EXTENSIONS_DIRECTORY_NAME, EXTENSIONS_DIRECTORY_NAME,
@@ -433,7 +480,7 @@ describe('extension tests', () => {
JSON.stringify(extensionConfig), JSON.stringify(extensionConfig),
); );
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1); expect(extensions).toHaveLength(1);
const extension = extensions[0]; const extension = extensions[0];
@@ -443,7 +490,7 @@ describe('extension tests', () => {
expect(serverConfig.env!['MISSING_VAR_BRACES']).toBe('${ALSO_UNDEFINED}'); expect(serverConfig.env!['MISSING_VAR_BRACES']).toBe('${ALSO_UNDEFINED}');
}); });
it('should skip extensions with invalid JSON and log a warning', () => { it('should skip extensions with invalid JSON and log a warning', async () => {
const consoleSpy = vi const consoleSpy = vi
.spyOn(console, 'error') .spyOn(console, 'error')
.mockImplementation(() => {}); .mockImplementation(() => {});
@@ -461,7 +508,7 @@ describe('extension tests', () => {
const badConfigPath = path.join(badExtDir, EXTENSIONS_CONFIG_FILENAME); const badConfigPath = path.join(badExtDir, EXTENSIONS_CONFIG_FILENAME);
fs.writeFileSync(badConfigPath, '{ "name": "bad-ext"'); // Malformed fs.writeFileSync(badConfigPath, '{ "name": "bad-ext"'); // Malformed
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1); expect(extensions).toHaveLength(1);
expect(extensions[0].name).toBe('good-ext'); expect(extensions[0].name).toBe('good-ext');
@@ -474,7 +521,7 @@ describe('extension tests', () => {
consoleSpy.mockRestore(); consoleSpy.mockRestore();
}); });
it('should skip extensions with missing name and log a warning', () => { it('should skip extensions with missing name and log a warning', async () => {
const consoleSpy = vi const consoleSpy = vi
.spyOn(console, 'error') .spyOn(console, 'error')
.mockImplementation(() => {}); .mockImplementation(() => {});
@@ -492,7 +539,7 @@ describe('extension tests', () => {
const badConfigPath = path.join(badExtDir, EXTENSIONS_CONFIG_FILENAME); const badConfigPath = path.join(badExtDir, EXTENSIONS_CONFIG_FILENAME);
fs.writeFileSync(badConfigPath, JSON.stringify({ version: '1.0.0' })); fs.writeFileSync(badConfigPath, JSON.stringify({ version: '1.0.0' }));
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1); expect(extensions).toHaveLength(1);
expect(extensions[0].name).toBe('good-ext'); expect(extensions[0].name).toBe('good-ext');
@@ -505,7 +552,7 @@ describe('extension tests', () => {
consoleSpy.mockRestore(); consoleSpy.mockRestore();
}); });
it('should filter trust out of mcp servers', () => { it('should filter trust out of mcp servers', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'test-extension', name: 'test-extension',
@@ -519,12 +566,12 @@ describe('extension tests', () => {
}, },
}); });
const extensions = extensionManager.loadExtensions(); const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1); expect(extensions).toHaveLength(1);
expect(extensions[0].mcpServers?.['test-server'].trust).toBeUndefined(); expect(extensions[0].mcpServers?.['test-server'].trust).toBeUndefined();
}); });
it('should throw an error for invalid extension names', () => { it('should throw an error for invalid extension names', async () => {
const consoleSpy = vi const consoleSpy = vi
.spyOn(console, 'error') .spyOn(console, 'error')
.mockImplementation(() => {}); .mockImplementation(() => {});
@@ -533,10 +580,8 @@ describe('extension tests', () => {
name: 'bad_name', name: 'bad_name',
version: '1.0.0', version: '1.0.0',
}); });
const extensions = await extensionManager.loadExtensions();
const extension = extensionManager const extension = extensions.find((e) => e.name === 'bad_name');
.loadExtensions()
.find((e) => e.name === 'bad_name');
expect(extension).toBeUndefined(); expect(extension).toBeUndefined();
expect(consoleSpy).toHaveBeenCalledWith( expect(consoleSpy).toHaveBeenCalledWith(
@@ -546,7 +591,7 @@ describe('extension tests', () => {
}); });
describe('id generation', () => { describe('id generation', () => {
it('should generate id from source for non-github git urls', () => { it('should generate id from source for non-github git urls', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'my-ext', name: 'my-ext',
@@ -556,14 +601,12 @@ describe('extension tests', () => {
source: 'http://somehost.com/foo/bar', source: 'http://somehost.com/foo/bar',
}, },
}); });
const extensions = await extensionManager.loadExtensions();
const extension = extensionManager const extension = extensions.find((e) => e.name === 'my-ext');
.loadExtensions()
.find((e) => e.name === 'my-ext');
expect(extension?.id).toBe(hashValue('http://somehost.com/foo/bar')); expect(extension?.id).toBe(hashValue('http://somehost.com/foo/bar'));
}); });
it('should generate id from owner/repo for github http urls', () => { it('should generate id from owner/repo for github http urls', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'my-ext', name: 'my-ext',
@@ -574,13 +617,12 @@ describe('extension tests', () => {
}, },
}); });
const extension = extensionManager const extensions = await extensionManager.loadExtensions();
.loadExtensions() const extension = extensions.find((e) => e.name === 'my-ext');
.find((e) => e.name === 'my-ext');
expect(extension?.id).toBe(hashValue('https://github.com/foo/bar')); expect(extension?.id).toBe(hashValue('https://github.com/foo/bar'));
}); });
it('should generate id from owner/repo for github ssh urls', () => { it('should generate id from owner/repo for github ssh urls', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'my-ext', name: 'my-ext',
@@ -591,13 +633,12 @@ describe('extension tests', () => {
}, },
}); });
const extension = extensionManager const extensions = await extensionManager.loadExtensions();
.loadExtensions() const extension = extensions.find((e) => e.name === 'my-ext');
.find((e) => e.name === 'my-ext');
expect(extension?.id).toBe(hashValue('https://github.com/foo/bar')); expect(extension?.id).toBe(hashValue('https://github.com/foo/bar'));
}); });
it('should generate id from source for github-release extension', () => { it('should generate id from source for github-release extension', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'my-ext', name: 'my-ext',
@@ -607,14 +648,12 @@ describe('extension tests', () => {
source: 'https://github.com/foo/bar', source: 'https://github.com/foo/bar',
}, },
}); });
const extensions = await extensionManager.loadExtensions();
const extension = extensionManager const extension = extensions.find((e) => e.name === 'my-ext');
.loadExtensions()
.find((e) => e.name === 'my-ext');
expect(extension?.id).toBe(hashValue('https://github.com/foo/bar')); expect(extension?.id).toBe(hashValue('https://github.com/foo/bar'));
}); });
it('should generate id from the original source for local extension', () => { it('should generate id from the original source for local extension', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'local-ext-name', name: 'local-ext-name',
@@ -625,9 +664,8 @@ describe('extension tests', () => {
}, },
}); });
const extension = extensionManager const extensions = await extensionManager.loadExtensions();
.loadExtensions() const extension = extensions.find((e) => e.name === 'local-ext-name');
.find((e) => e.name === 'local-ext-name');
expect(extension?.id).toBe(hashValue('/some/path')); expect(extension?.id).toBe(hashValue('/some/path'));
}); });
@@ -638,7 +676,7 @@ describe('extension tests', () => {
name: 'link-ext-name', name: 'link-ext-name',
version: '1.0.0', version: '1.0.0',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
type: 'link', type: 'link',
source: actualExtensionDir, source: actualExtensionDir,
@@ -650,16 +688,15 @@ describe('extension tests', () => {
expect(extension?.id).toBe(hashValue(actualExtensionDir)); expect(extension?.id).toBe(hashValue(actualExtensionDir));
}); });
it('should generate id from name for extension with no install metadata', () => { it('should generate id from name for extension with no install metadata', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'no-meta-name', name: 'no-meta-name',
version: '1.0.0', version: '1.0.0',
}); });
const extension = extensionManager const extensions = await extensionManager.loadExtensions();
.loadExtensions() const extension = extensions.find((e) => e.name === 'no-meta-name');
.find((e) => e.name === 'no-meta-name');
expect(extension?.id).toBe(hashValue('no-meta-name')); expect(extension?.id).toBe(hashValue('no-meta-name'));
}); });
}); });
@@ -675,7 +712,7 @@ describe('extension tests', () => {
const targetExtDir = path.join(userExtensionsDir, 'my-local-extension'); const targetExtDir = path.join(userExtensionsDir, 'my-local-extension');
const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME); const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME);
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
type: 'local', type: 'local',
@@ -697,7 +734,7 @@ describe('extension tests', () => {
name: 'my-local-extension', name: 'my-local-extension',
version: '1.0.0', version: '1.0.0',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
type: 'local', type: 'local',
@@ -791,7 +828,7 @@ describe('extension tests', () => {
type: 'github-release', type: 'github-release',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: gitUrl, source: gitUrl,
type: 'git', type: 'git',
@@ -816,7 +853,7 @@ describe('extension tests', () => {
const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME); const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME);
const configPath = path.join(targetExtDir, EXTENSIONS_CONFIG_FILENAME); const configPath = path.join(targetExtDir, EXTENSIONS_CONFIG_FILENAME);
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
type: 'link', type: 'link',
@@ -846,7 +883,7 @@ describe('extension tests', () => {
name: 'my-local-extension', name: 'my-local-extension',
version: '1.1.0', version: '1.1.0',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
if (isUpdate) { if (isUpdate) {
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
@@ -920,7 +957,7 @@ describe('extension tests', () => {
}, },
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await expect( await expect(
extensionManager.installOrUpdateExtension({ extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
@@ -952,7 +989,7 @@ This extension will run the following MCP servers:
}, },
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await expect( await expect(
extensionManager.installOrUpdateExtension({ extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
@@ -974,7 +1011,7 @@ This extension will run the following MCP servers:
}, },
}); });
mockRequestConsent.mockResolvedValue(false); mockRequestConsent.mockResolvedValue(false);
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await expect( await expect(
extensionManager.installOrUpdateExtension({ extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
@@ -992,7 +1029,7 @@ This extension will run the following MCP servers:
const targetExtDir = path.join(userExtensionsDir, 'my-local-extension'); const targetExtDir = path.join(userExtensionsDir, 'my-local-extension');
const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME); const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME);
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
type: 'local', type: 'local',
@@ -1023,7 +1060,7 @@ This extension will run the following MCP servers:
}, },
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
// Install it with hard coded consent first. // Install it with hard coded consent first.
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
@@ -1058,7 +1095,7 @@ This extension will run the following MCP servers:
], ],
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
type: 'local', type: 'local',
@@ -1088,7 +1125,7 @@ This extension will run the following MCP servers:
settings: loadSettings(tempWorkspaceDir).merged, settings: loadSettings(tempWorkspaceDir).merged,
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: sourceExtDir, source: sourceExtDir,
type: 'local', type: 'local',
@@ -1111,7 +1148,7 @@ This extension will run the following MCP servers:
}); });
mockPromptForSettings.mockResolvedValueOnce('old-api-key'); mockPromptForSettings.mockResolvedValueOnce('old-api-key');
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
// Install it so it exists in the userExtensionsDir // Install it so it exists in the userExtensionsDir
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: oldSourceExtDir, source: oldSourceExtDir,
@@ -1181,7 +1218,7 @@ This extension will run the following MCP servers:
}, },
], ],
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: oldSourceExtDir, source: oldSourceExtDir,
type: 'local', type: 'local',
@@ -1273,7 +1310,7 @@ This extension will run the following MCP servers:
join(tempDir, extensionName), join(tempDir, extensionName),
); );
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: gitUrl, source: gitUrl,
type: 'github-release', type: 'github-release',
@@ -1298,7 +1335,7 @@ This extension will run the following MCP servers:
type: 'github-release', type: 'github-release',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension( await extensionManager.installOrUpdateExtension(
{ source: gitUrl, type: 'github-release' }, // Use github-release to force consent { source: gitUrl, type: 'github-release' }, // Use github-release to force consent
); );
@@ -1329,7 +1366,7 @@ This extension will run the following MCP servers:
}); });
mockRequestConsent.mockResolvedValue(false); mockRequestConsent.mockResolvedValue(false);
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await expect( await expect(
extensionManager.installOrUpdateExtension({ extensionManager.installOrUpdateExtension({
source: gitUrl, source: gitUrl,
@@ -1354,7 +1391,7 @@ This extension will run the following MCP servers:
type: 'github-release', type: 'github-release',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({ await extensionManager.installOrUpdateExtension({
source: gitUrl, source: gitUrl,
type: 'git', type: 'git',
@@ -1385,7 +1422,7 @@ This extension will run the following MCP servers:
type: 'github-release', type: 'github-release',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension( await extensionManager.installOrUpdateExtension(
{ source: gitUrl, type: 'github-release' }, // Note the type { source: gitUrl, type: 'github-release' }, // Note the type
); );
@@ -1407,8 +1444,7 @@ This extension will run the following MCP servers:
name: 'my-local-extension', name: 'my-local-extension',
version: '1.0.0', version: '1.0.0',
}); });
await extensionManager.loadExtensions();
extensionManager.loadExtensions();
await extensionManager.uninstallExtension('my-local-extension', false); await extensionManager.uninstallExtension('my-local-extension', false);
expect(fs.existsSync(sourceExtDir)).toBe(false); expect(fs.existsSync(sourceExtDir)).toBe(false);
@@ -1426,7 +1462,7 @@ This extension will run the following MCP servers:
version: '1.0.0', version: '1.0.0',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.uninstallExtension('my-local-extension', false); await extensionManager.uninstallExtension('my-local-extension', false);
expect(fs.existsSync(sourceExtDir)).toBe(false); expect(fs.existsSync(sourceExtDir)).toBe(false);
@@ -1435,7 +1471,7 @@ This extension will run the following MCP servers:
}); });
it('should throw an error if the extension does not exist', async () => { it('should throw an error if the extension does not exist', async () => {
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await expect( await expect(
extensionManager.uninstallExtension('nonexistent-extension', false), extensionManager.uninstallExtension('nonexistent-extension', false),
).rejects.toThrow('Extension not found.'); ).rejects.toThrow('Extension not found.');
@@ -1453,7 +1489,7 @@ This extension will run the following MCP servers:
}, },
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.uninstallExtension( await extensionManager.uninstallExtension(
'my-local-extension', 'my-local-extension',
isUpdate, isUpdate,
@@ -1481,7 +1517,7 @@ This extension will run the following MCP servers:
const enablementManager = new ExtensionEnablementManager(); const enablementManager = new ExtensionEnablementManager();
enablementManager.enable('test-extension', true, '/some/scope'); enablementManager.enable('test-extension', true, '/some/scope');
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.uninstallExtension('test-extension', isUpdate); await extensionManager.uninstallExtension('test-extension', isUpdate);
const config = enablementManager.readConfig()['test-extension']; const config = enablementManager.readConfig()['test-extension'];
@@ -1506,7 +1542,7 @@ This extension will run the following MCP servers:
}, },
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await extensionManager.uninstallExtension(gitUrl, false); await extensionManager.uninstallExtension(gitUrl, false);
expect(fs.existsSync(sourceExtDir)).toBe(false); expect(fs.existsSync(sourceExtDir)).toBe(false);
@@ -1526,7 +1562,7 @@ This extension will run the following MCP servers:
// No installMetadata provided // No installMetadata provided
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
await expect( await expect(
extensionManager.uninstallExtension( extensionManager.uninstallExtension(
'https://github.com/google/no-metadata-extension', 'https://github.com/google/no-metadata-extension',
@@ -1537,14 +1573,14 @@ This extension will run the following MCP servers:
}); });
describe('disableExtension', () => { describe('disableExtension', () => {
it('should disable an extension at the user scope', () => { it('should disable an extension at the user scope', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'my-extension', name: 'my-extension',
version: '1.0.0', version: '1.0.0',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
extensionManager.disableExtension('my-extension', SettingScope.User); extensionManager.disableExtension('my-extension', SettingScope.User);
expect( expect(
isEnabled({ isEnabled({
@@ -1554,14 +1590,14 @@ This extension will run the following MCP servers:
).toBe(false); ).toBe(false);
}); });
it('should disable an extension at the workspace scope', () => { it('should disable an extension at the workspace scope', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'my-extension', name: 'my-extension',
version: '1.0.0', version: '1.0.0',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
extensionManager.disableExtension('my-extension', SettingScope.Workspace); extensionManager.disableExtension('my-extension', SettingScope.Workspace);
expect( expect(
isEnabled({ isEnabled({
@@ -1577,14 +1613,14 @@ This extension will run the following MCP servers:
).toBe(false); ).toBe(false);
}); });
it('should handle disabling the same extension twice', () => { it('should handle disabling the same extension twice', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'my-extension', name: 'my-extension',
version: '1.0.0', version: '1.0.0',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
extensionManager.disableExtension('my-extension', SettingScope.User); extensionManager.disableExtension('my-extension', SettingScope.User);
extensionManager.disableExtension('my-extension', SettingScope.User); extensionManager.disableExtension('my-extension', SettingScope.User);
expect( expect(
@@ -1595,13 +1631,17 @@ This extension will run the following MCP servers:
).toBe(false); ).toBe(false);
}); });
it('should throw an error if you request system scope', () => { it('should throw an error if you request system scope', async () => {
expect(() => await expect(
extensionManager.disableExtension('my-extension', SettingScope.System), async () =>
).toThrow('System and SystemDefaults scopes are not supported.'); await extensionManager.disableExtension(
'my-extension',
SettingScope.System,
),
).rejects.toThrow('System and SystemDefaults scopes are not supported.');
}); });
it('should log a disable event', () => { it('should log a disable event', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'ext1', name: 'ext1',
@@ -1612,7 +1652,7 @@ This extension will run the following MCP servers:
}, },
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.Workspace); extensionManager.disableExtension('ext1', SettingScope.Workspace);
expect(mockLogExtensionDisable).toHaveBeenCalled(); expect(mockLogExtensionDisable).toHaveBeenCalled();
@@ -1634,41 +1674,41 @@ This extension will run the following MCP servers:
return extensions.filter((e) => e.isActive); return extensions.filter((e) => e.isActive);
}; };
it('should enable an extension at the user scope', () => { it('should enable an extension at the user scope', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'ext1', name: 'ext1',
version: '1.0.0', version: '1.0.0',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.User); extensionManager.disableExtension('ext1', SettingScope.User);
let activeExtensions = getActiveExtensions(); let activeExtensions = getActiveExtensions();
expect(activeExtensions).toHaveLength(0); expect(activeExtensions).toHaveLength(0);
extensionManager.enableExtension('ext1', SettingScope.User); await extensionManager.enableExtension('ext1', SettingScope.User);
activeExtensions = getActiveExtensions(); activeExtensions = await getActiveExtensions();
expect(activeExtensions).toHaveLength(1); expect(activeExtensions).toHaveLength(1);
expect(activeExtensions[0].name).toBe('ext1'); expect(activeExtensions[0].name).toBe('ext1');
}); });
it('should enable an extension at the workspace scope', () => { it('should enable an extension at the workspace scope', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'ext1', name: 'ext1',
version: '1.0.0', version: '1.0.0',
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.Workspace); extensionManager.disableExtension('ext1', SettingScope.Workspace);
let activeExtensions = getActiveExtensions(); let activeExtensions = getActiveExtensions();
expect(activeExtensions).toHaveLength(0); expect(activeExtensions).toHaveLength(0);
extensionManager.enableExtension('ext1', SettingScope.Workspace); await extensionManager.enableExtension('ext1', SettingScope.Workspace);
activeExtensions = getActiveExtensions(); activeExtensions = await getActiveExtensions();
expect(activeExtensions).toHaveLength(1); expect(activeExtensions).toHaveLength(1);
expect(activeExtensions[0].name).toBe('ext1'); expect(activeExtensions[0].name).toBe('ext1');
}); });
it('should log an enable event', () => { it('should log an enable event', async () => {
createExtension({ createExtension({
extensionsDir: userExtensionsDir, extensionsDir: userExtensionsDir,
name: 'ext1', name: 'ext1',
@@ -1678,7 +1718,7 @@ This extension will run the following MCP servers:
type: 'local', type: 'local',
}, },
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.Workspace); extensionManager.disableExtension('ext1', SettingScope.Workspace);
extensionManager.enableExtension('ext1', SettingScope.Workspace); extensionManager.enableExtension('ext1', SettingScope.Workspace);
@@ -17,6 +17,7 @@ import { ExtensionStorage } from './storage.js';
import prompts from 'prompts'; import prompts from 'prompts';
import * as fsPromises from 'node:fs/promises'; import * as fsPromises from 'node:fs/promises';
import * as fs from 'node:fs'; import * as fs from 'node:fs';
import { KeychainTokenStorage } from '@google/gemini-cli-core';
vi.mock('prompts'); vi.mock('prompts');
vi.mock('os', async (importOriginal) => { vi.mock('os', async (importOriginal) => {
@@ -27,11 +28,59 @@ vi.mock('os', async (importOriginal) => {
}; };
}); });
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actual =
await importOriginal<typeof import('@google/gemini-cli-core')>();
return {
...actual,
KeychainTokenStorage: vi.fn().mockImplementation(() => ({
getSecret: vi.fn(),
setSecret: vi.fn(),
deleteSecret: vi.fn(),
listSecrets: vi.fn(),
isAvailable: vi.fn().mockResolvedValue(true),
})),
};
});
interface MockKeychainStorage {
getSecret: ReturnType<typeof vi.fn>;
setSecret: ReturnType<typeof vi.fn>;
deleteSecret: ReturnType<typeof vi.fn>;
listSecrets: ReturnType<typeof vi.fn>;
isAvailable: ReturnType<typeof vi.fn>;
}
describe('extensionSettings', () => { describe('extensionSettings', () => {
let tempHomeDir: string; let tempHomeDir: string;
let extensionDir: string; let extensionDir: string;
let mockKeychainStorage: MockKeychainStorage;
let keychainData: Record<string, string>;
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks();
keychainData = {};
mockKeychainStorage = {
getSecret: vi
.fn()
.mockImplementation(async (key: string) => keychainData[key] || null),
setSecret: vi
.fn()
.mockImplementation(async (key: string, value: string) => {
keychainData[key] = value;
}),
deleteSecret: vi.fn().mockImplementation(async (key: string) => {
delete keychainData[key];
}),
listSecrets: vi
.fn()
.mockImplementation(async () => Object.keys(keychainData)),
isAvailable: vi.fn().mockResolvedValue(true),
};
(
KeychainTokenStorage as unknown as ReturnType<typeof vi.fn>
).mockImplementation(() => mockKeychainStorage);
tempHomeDir = os.tmpdir() + path.sep + `gemini-cli-test-home-${Date.now()}`; tempHomeDir = os.tmpdir() + path.sep + `gemini-cli-test-home-${Date.now()}`;
extensionDir = path.join(tempHomeDir, '.gemini', 'extensions', 'test-ext'); extensionDir = path.join(tempHomeDir, '.gemini', 'extensions', 'test-ext');
// Spy and mock the method, but also create the directory so we can write to it. // Spy and mock the method, but also create the directory so we can write to it.
@@ -59,7 +108,13 @@ describe('extensionSettings', () => {
it('should do nothing if settings are undefined', async () => { it('should do nothing if settings are undefined', async () => {
const config: ExtensionConfig = { name: 'test-ext', version: '1.0.0' }; const config: ExtensionConfig = { name: 'test-ext', version: '1.0.0' };
await maybePromptForSettings(config, mockRequestSetting); await maybePromptForSettings(
config,
'12345',
mockRequestSetting,
undefined,
undefined,
);
expect(mockRequestSetting).not.toHaveBeenCalled(); expect(mockRequestSetting).not.toHaveBeenCalled();
}); });
@@ -69,11 +124,17 @@ describe('extensionSettings', () => {
version: '1.0.0', version: '1.0.0',
settings: [], settings: [],
}; };
await maybePromptForSettings(config, mockRequestSetting); await maybePromptForSettings(
config,
'12345',
mockRequestSetting,
undefined,
undefined,
);
expect(mockRequestSetting).not.toHaveBeenCalled(); expect(mockRequestSetting).not.toHaveBeenCalled();
}); });
it('should call requestSetting for each setting', async () => { it('should prompt for all settings if there is no previous config', async () => {
const config: ExtensionConfig = { const config: ExtensionConfig = {
name: 'test-ext', name: 'test-ext',
version: '1.0.0', version: '1.0.0',
@@ -82,14 +143,25 @@ describe('extensionSettings', () => {
{ name: 's2', description: 'd2', envVar: 'VAR2' }, { name: 's2', description: 'd2', envVar: 'VAR2' },
], ],
}; };
await maybePromptForSettings(config, mockRequestSetting); await maybePromptForSettings(
config,
'12345',
mockRequestSetting,
undefined,
undefined,
);
expect(mockRequestSetting).toHaveBeenCalledTimes(2); expect(mockRequestSetting).toHaveBeenCalledTimes(2);
expect(mockRequestSetting).toHaveBeenCalledWith(config.settings![0]); expect(mockRequestSetting).toHaveBeenCalledWith(config.settings![0]);
expect(mockRequestSetting).toHaveBeenCalledWith(config.settings![1]); expect(mockRequestSetting).toHaveBeenCalledWith(config.settings![1]);
}); });
it('should write the .env file with the correct content', async () => { it('should only prompt for new settings', async () => {
const config: ExtensionConfig = { const previousConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [{ name: 's1', description: 'd1', envVar: 'VAR1' }],
};
const newConfig: ExtensionConfig = {
name: 'test-ext', name: 'test-ext',
version: '1.0.0', version: '1.0.0',
settings: [ settings: [
@@ -97,35 +169,151 @@ describe('extensionSettings', () => {
{ name: 's2', description: 'd2', envVar: 'VAR2' }, { name: 's2', description: 'd2', envVar: 'VAR2' },
], ],
}; };
await maybePromptForSettings(config, mockRequestSetting); const previousSettings = { VAR1: 'previous-VAR1' };
await maybePromptForSettings(
newConfig,
'12345',
mockRequestSetting,
previousConfig,
previousSettings,
);
expect(mockRequestSetting).toHaveBeenCalledTimes(1);
expect(mockRequestSetting).toHaveBeenCalledWith(newConfig.settings![1]);
const expectedEnvPath = path.join(extensionDir, '.env'); const expectedEnvPath = path.join(extensionDir, '.env');
const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8'); const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8');
const expectedContent = 'VAR1=mock-VAR1\nVAR2=mock-VAR2\n'; const expectedContent = 'VAR1=previous-VAR1\nVAR2=mock-VAR2\n';
expect(actualContent).toBe(expectedContent);
});
it('should remove settings that are no longer in the config', async () => {
const previousConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
{ name: 's1', description: 'd1', envVar: 'VAR1' },
{ name: 's2', description: 'd2', envVar: 'VAR2' },
],
};
const newConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [{ name: 's1', description: 'd1', envVar: 'VAR1' }],
};
const previousSettings = {
VAR1: 'previous-VAR1',
VAR2: 'previous-VAR2',
};
await maybePromptForSettings(
newConfig,
'12345',
mockRequestSetting,
previousConfig,
previousSettings,
);
expect(mockRequestSetting).not.toHaveBeenCalled();
const expectedEnvPath = path.join(extensionDir, '.env');
const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8');
const expectedContent = 'VAR1=previous-VAR1\n';
expect(actualContent).toBe(expectedContent);
});
it('should reprompt if a setting changes sensitivity', async () => {
const previousConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
{ name: 's1', description: 'd1', envVar: 'VAR1', sensitive: false },
],
};
const newConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
{ name: 's1', description: 'd1', envVar: 'VAR1', sensitive: true },
],
};
const previousSettings = { VAR1: 'previous-VAR1' };
await maybePromptForSettings(
newConfig,
'12345',
mockRequestSetting,
previousConfig,
previousSettings,
);
expect(mockRequestSetting).toHaveBeenCalledTimes(1);
expect(mockRequestSetting).toHaveBeenCalledWith(newConfig.settings![0]);
// The value should now be in keychain, not the .env file.
const expectedEnvPath = path.join(extensionDir, '.env');
const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8');
expect(actualContent).toBe('');
});
it('should not prompt if settings are identical', async () => {
const previousConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
{ name: 's1', description: 'd1', envVar: 'VAR1' },
{ name: 's2', description: 'd2', envVar: 'VAR2' },
],
};
const newConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
{ name: 's1', description: 'd1', envVar: 'VAR1' },
{ name: 's2', description: 'd2', envVar: 'VAR2' },
],
};
const previousSettings = {
VAR1: 'previous-VAR1',
VAR2: 'previous-VAR2',
};
await maybePromptForSettings(
newConfig,
'12345',
mockRequestSetting,
previousConfig,
previousSettings,
);
expect(mockRequestSetting).not.toHaveBeenCalled();
const expectedEnvPath = path.join(extensionDir, '.env');
const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8');
const expectedContent = 'VAR1=previous-VAR1\nVAR2=previous-VAR2\n';
expect(actualContent).toBe(expectedContent); expect(actualContent).toBe(expectedContent);
}); });
}); });
describe('promptForSetting', () => { describe('promptForSetting', () => {
// it('should use prompts with type "password" for sensitive settings', async () => { it('should use prompts with type "password" for sensitive settings', async () => {
// const setting: ExtensionSetting = { const setting: ExtensionSetting = {
// name: 'API Key', name: 'API Key',
// description: 'Your secret key', description: 'Your secret key',
// envVar: 'API_KEY', envVar: 'API_KEY',
// sensitive: true, sensitive: true,
// }; };
// vi.mocked(prompts).mockResolvedValue({ value: 'secret-key' }); vi.mocked(prompts).mockResolvedValue({ value: 'secret-key' });
// const result = await promptForSetting(setting); const result = await promptForSetting(setting);
// expect(prompts).toHaveBeenCalledWith({ expect(prompts).toHaveBeenCalledWith({
// type: 'password', type: 'password',
// name: 'value', name: 'value',
// message: 'API Key\nYour secret key', message: 'API Key\nYour secret key',
// }); });
// expect(result).toBe('secret-key'); expect(result).toBe('secret-key');
// }); });
it('should use prompts with type "text" for non-sensitive settings', async () => { it('should use prompts with type "text" for non-sensitive settings', async () => {
const setting: ExtensionSetting = { const setting: ExtensionSetting = {
@@ -12,57 +12,76 @@ import { ExtensionStorage } from './storage.js';
import type { ExtensionConfig } from '../extension.js'; import type { ExtensionConfig } from '../extension.js';
import prompts from 'prompts'; import prompts from 'prompts';
import { KeychainTokenStorage } from '@google/gemini-cli-core';
export interface ExtensionSetting { export interface ExtensionSetting {
name: string; name: string;
description: string; description: string;
envVar: string; envVar: string;
// NOTE: If no value is set, this setting will be considered NOT sensitive.
sensitive?: boolean;
} }
export async function maybePromptForSettings( export async function maybePromptForSettings(
extensionConfig: ExtensionConfig, extensionConfig: ExtensionConfig,
extensionId: string,
requestSetting: (setting: ExtensionSetting) => Promise<string>, requestSetting: (setting: ExtensionSetting) => Promise<string>,
previousExtensionConfig?: ExtensionConfig, previousExtensionConfig?: ExtensionConfig,
previousSettings?: Record<string, string>, previousSettings?: Record<string, string>,
): Promise<void> { ): Promise<void> {
const { name: extensionName, settings } = extensionConfig; const { name: extensionName, settings } = extensionConfig;
if (
(!settings || settings.length === 0) &&
(!previousExtensionConfig?.settings ||
previousExtensionConfig.settings.length === 0)
) {
return;
}
const envFilePath = new ExtensionStorage(extensionName).getEnvFilePath(); const envFilePath = new ExtensionStorage(extensionName).getEnvFilePath();
const keychain = new KeychainTokenStorage(extensionId);
if (!settings || settings.length === 0) { if (!settings || settings.length === 0) {
// No settings for this extension. Clear any existing .env file. await clearSettings(envFilePath, keychain);
if (fsSync.existsSync(envFilePath)) {
await fs.writeFile(envFilePath, '');
}
return; return;
} }
let settingsToPrompt = settings; const settingsChanges = getSettingsChanges(
if (previousExtensionConfig) { settings,
const oldSettings = new Set( previousExtensionConfig?.settings ?? [],
previousExtensionConfig.settings?.map((s) => s.name) || [], );
);
settingsToPrompt = settingsToPrompt.filter((s) => !oldSettings.has(s.name));
}
const allSettings: Record<string, string> = { ...(previousSettings ?? {}) }; const allSettings: Record<string, string> = { ...(previousSettings ?? {}) };
if (settingsToPrompt && settingsToPrompt.length > 0) { for (const removedEnvSetting of settingsChanges.removeEnv) {
for (const setting of settingsToPrompt) { delete allSettings[removedEnvSetting.envVar];
const answer = await requestSetting(setting);
allSettings[setting.envVar] = answer;
}
} }
const validEnvVars = new Set(settings.map((s) => s.envVar)); for (const removedSensitiveSetting of settingsChanges.removeSensitive) {
const finalSettings: Record<string, string> = {}; await keychain.deleteSecret(removedSensitiveSetting.envVar);
for (const [key, value] of Object.entries(allSettings)) { }
if (validEnvVars.has(key)) {
finalSettings[key] = value; for (const setting of settingsChanges.promptForSensitive.concat(
settingsChanges.promptForEnv,
)) {
const answer = await requestSetting(setting);
allSettings[setting.envVar] = answer;
}
const nonSensitiveSettings: Record<string, string> = {};
for (const setting of settings) {
const value = allSettings[setting.envVar];
if (value === undefined) {
continue;
}
if (setting.sensitive) {
await keychain.setSecret(setting.envVar, value);
} else {
nonSensitiveSettings[setting.envVar] = value;
} }
} }
let envContent = ''; let envContent = '';
for (const [key, value] of Object.entries(finalSettings)) { for (const [key, value] of Object.entries(nonSensitiveSettings)) {
envContent += `${key}=${value}\n`; envContent += `${key}=${value}\n`;
} }
@@ -73,17 +92,22 @@ export async function promptForSetting(
setting: ExtensionSetting, setting: ExtensionSetting,
): Promise<string> { ): Promise<string> {
const response = await prompts({ const response = await prompts({
// type: setting.sensitive ? 'password' : 'text', type: setting.sensitive ? 'password' : 'text',
type: 'text',
name: 'value', name: 'value',
message: `${setting.name}\n${setting.description}`, message: `${setting.name}\n${setting.description}`,
}); });
return response.value; return response.value;
} }
export function getEnvContents( export async function getEnvContents(
extensionStorage: ExtensionStorage, extensionConfig: ExtensionConfig,
): Record<string, string> { extensionId: string,
): Promise<Record<string, string>> {
if (!extensionConfig.settings || extensionConfig.settings.length === 0) {
return Promise.resolve({});
}
const extensionStorage = new ExtensionStorage(extensionConfig.name);
const keychain = new KeychainTokenStorage(extensionId);
let customEnv: Record<string, string> = {}; let customEnv: Record<string, string> = {};
if (fsSync.existsSync(extensionStorage.getEnvFilePath())) { if (fsSync.existsSync(extensionStorage.getEnvFilePath())) {
const envFile = fsSync.readFileSync( const envFile = fsSync.readFileSync(
@@ -92,5 +116,67 @@ export function getEnvContents(
); );
customEnv = dotenv.parse(envFile); customEnv = dotenv.parse(envFile);
} }
if (extensionConfig.settings) {
for (const setting of extensionConfig.settings) {
if (setting.sensitive) {
const secret = await keychain.getSecret(setting.envVar);
if (secret) {
customEnv[setting.envVar] = secret;
}
}
}
}
return customEnv; return customEnv;
} }
interface settingsChanges {
promptForSensitive: ExtensionSetting[];
removeSensitive: ExtensionSetting[];
promptForEnv: ExtensionSetting[];
removeEnv: ExtensionSetting[];
}
function getSettingsChanges(
settings: ExtensionSetting[],
oldSettings: ExtensionSetting[],
): settingsChanges {
const isSameSetting = (a: ExtensionSetting, b: ExtensionSetting) =>
a.envVar === b.envVar && (a.sensitive ?? false) === (b.sensitive ?? false);
const sensitiveOld = oldSettings.filter((s) => s.sensitive ?? false);
const sensitiveNew = settings.filter((s) => s.sensitive ?? false);
const envOld = oldSettings.filter((s) => !(s.sensitive ?? false));
const envNew = settings.filter((s) => !(s.sensitive ?? false));
return {
promptForSensitive: sensitiveNew.filter(
(s) => !sensitiveOld.some((old) => isSameSetting(s, old)),
),
removeSensitive: sensitiveOld.filter(
(s) => !sensitiveNew.some((neu) => isSameSetting(s, neu)),
),
promptForEnv: envNew.filter(
(s) => !envOld.some((old) => isSameSetting(s, old)),
),
removeEnv: envOld.filter(
(s) => !envNew.some((neu) => isSameSetting(s, neu)),
),
};
}
async function clearSettings(
envFilePath: string,
keychain: KeychainTokenStorage,
) {
if (fsSync.existsSync(envFilePath)) {
await fs.writeFile(envFilePath, '');
}
if (!keychain.isAvailable()) {
return;
}
const secrets = await keychain.listSecrets();
for (const secret of secrets) {
await keychain.deleteSecret(secret);
}
return;
}
@@ -9,7 +9,7 @@ import * as fs from 'node:fs';
import * as os from 'node:os'; import * as os from 'node:os';
import * as path from 'node:path'; import * as path from 'node:path';
import { checkForAllExtensionUpdates, updateExtension } from './update.js'; import { checkForAllExtensionUpdates, updateExtension } from './update.js';
import { GEMINI_DIR } from '@google/gemini-cli-core'; import { GEMINI_DIR, KeychainTokenStorage } from '@google/gemini-cli-core';
import { isWorkspaceTrusted } from '../trustedFolders.js'; import { isWorkspaceTrusted } from '../trustedFolders.js';
import { ExtensionUpdateState } from '../../ui/state/extensions.js'; import { ExtensionUpdateState } from '../../ui/state/extensions.js';
import { createExtension } from '../../test-utils/createExtension.js'; import { createExtension } from '../../test-utils/createExtension.js';
@@ -64,9 +64,24 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
logExtensionUninstall: mockLogExtensionUninstall, logExtensionUninstall: mockLogExtensionUninstall,
ExtensionInstallEvent: vi.fn(), ExtensionInstallEvent: vi.fn(),
ExtensionUninstallEvent: vi.fn(), ExtensionUninstallEvent: vi.fn(),
KeychainTokenStorage: vi.fn().mockImplementation(() => ({
getSecret: vi.fn(),
setSecret: vi.fn(),
deleteSecret: vi.fn(),
listSecrets: vi.fn(),
isAvailable: vi.fn().mockResolvedValue(true),
})),
}; };
}); });
interface MockKeychainStorage {
getSecret: ReturnType<typeof vi.fn>;
setSecret: ReturnType<typeof vi.fn>;
deleteSecret: ReturnType<typeof vi.fn>;
listSecrets: ReturnType<typeof vi.fn>;
isAvailable: ReturnType<typeof vi.fn>;
}
describe('update tests', () => { describe('update tests', () => {
let tempHomeDir: string; let tempHomeDir: string;
let tempWorkspaceDir: string; let tempWorkspaceDir: string;
@@ -76,8 +91,32 @@ describe('update tests', () => {
let mockPromptForSettings: MockedFunction< let mockPromptForSettings: MockedFunction<
(setting: ExtensionSetting) => Promise<string> (setting: ExtensionSetting) => Promise<string>
>; >;
let mockKeychainStorage: MockKeychainStorage;
let keychainData: Record<string, string>;
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks();
keychainData = {};
mockKeychainStorage = {
getSecret: vi
.fn()
.mockImplementation(async (key: string) => keychainData[key] || null),
setSecret: vi
.fn()
.mockImplementation(async (key: string, value: string) => {
keychainData[key] = value;
}),
deleteSecret: vi.fn().mockImplementation(async (key: string) => {
delete keychainData[key];
}),
listSecrets: vi
.fn()
.mockImplementation(async () => Object.keys(keychainData)),
isAvailable: vi.fn().mockResolvedValue(true),
};
(
KeychainTokenStorage as unknown as ReturnType<typeof vi.fn>
).mockImplementation(() => mockKeychainStorage);
tempHomeDir = fs.mkdtempSync( tempHomeDir = fs.mkdtempSync(
path.join(os.tmpdir(), 'gemini-cli-test-home-'), path.join(os.tmpdir(), 'gemini-cli-test-home-'),
); );
@@ -110,6 +149,7 @@ describe('update tests', () => {
afterEach(() => { afterEach(() => {
fs.rmSync(tempHomeDir, { recursive: true, force: true }); fs.rmSync(tempHomeDir, { recursive: true, force: true });
fs.rmSync(tempWorkspaceDir, { recursive: true, force: true }); fs.rmSync(tempWorkspaceDir, { recursive: true, force: true });
vi.restoreAllMocks();
}); });
describe('updateExtension', () => { describe('updateExtension', () => {
@@ -139,11 +179,10 @@ describe('update tests', () => {
); );
}); });
mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]); mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]);
const extension = extensionManager const extensions = await extensionManager.loadExtensions();
.loadExtensions() const extension = extensions.find((e) => e.name === extensionName)!;
.find((e) => e.name === extensionName)!;
const updateInfo = await updateExtension( const updateInfo = await updateExtension(
extension, extension!,
extensionManager, extensionManager,
ExtensionUpdateState.UPDATE_AVAILABLE, ExtensionUpdateState.UPDATE_AVAILABLE,
() => {}, () => {},
@@ -189,11 +228,10 @@ describe('update tests', () => {
const dispatch = vi.fn(); const dispatch = vi.fn();
const extension = extensionManager const extensions = await extensionManager.loadExtensions();
.loadExtensions() const extension = extensions.find((e) => e.name === extensionName)!;
.find((e) => e.name === extensionName)!;
await updateExtension( await updateExtension(
extension, extension!,
extensionManager, extensionManager,
ExtensionUpdateState.UPDATE_AVAILABLE, ExtensionUpdateState.UPDATE_AVAILABLE,
dispatch, dispatch,
@@ -231,12 +269,11 @@ describe('update tests', () => {
mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]); mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]);
const dispatch = vi.fn(); const dispatch = vi.fn();
const extension = extensionManager const extensions = await extensionManager.loadExtensions();
.loadExtensions() const extension = extensions.find((e) => e.name === extensionName)!;
.find((e) => e.name === extensionName)!;
await expect( await expect(
updateExtension( updateExtension(
extension, extension!,
extensionManager, extensionManager,
ExtensionUpdateState.UPDATE_AVAILABLE, ExtensionUpdateState.UPDATE_AVAILABLE,
dispatch, dispatch,
@@ -280,7 +317,7 @@ describe('update tests', () => {
const dispatch = vi.fn(); const dispatch = vi.fn();
await checkForAllExtensionUpdates( await checkForAllExtensionUpdates(
extensionManager.loadExtensions(), await extensionManager.loadExtensions(),
extensionManager, extensionManager,
dispatch, dispatch,
); );
@@ -312,7 +349,7 @@ describe('update tests', () => {
const dispatch = vi.fn(); const dispatch = vi.fn();
await checkForAllExtensionUpdates( await checkForAllExtensionUpdates(
extensionManager.loadExtensions(), await extensionManager.loadExtensions(),
extensionManager, extensionManager,
dispatch, dispatch,
); );
@@ -341,7 +378,7 @@ describe('update tests', () => {
}); });
const dispatch = vi.fn(); const dispatch = vi.fn();
await checkForAllExtensionUpdates( await checkForAllExtensionUpdates(
extensionManager.loadExtensions(), await extensionManager.loadExtensions(),
extensionManager, extensionManager,
dispatch, dispatch,
); );
@@ -370,7 +407,7 @@ describe('update tests', () => {
}); });
const dispatch = vi.fn(); const dispatch = vi.fn();
await checkForAllExtensionUpdates( await checkForAllExtensionUpdates(
extensionManager.loadExtensions(), await extensionManager.loadExtensions(),
extensionManager, extensionManager,
dispatch, dispatch,
); );
@@ -398,7 +435,7 @@ describe('update tests', () => {
const dispatch = vi.fn(); const dispatch = vi.fn();
await checkForAllExtensionUpdates( await checkForAllExtensionUpdates(
extensionManager.loadExtensions(), await extensionManager.loadExtensions(),
extensionManager, extensionManager,
dispatch, dispatch,
); );
+1 -1
View File
@@ -58,7 +58,7 @@ export async function updateExtension(
const tempDir = await ExtensionStorage.createTmpDir(); const tempDir = await ExtensionStorage.createTmpDir();
try { try {
const previousExtensionConfig = await extensionManager.loadExtensionConfig( const previousExtensionConfig = extensionManager.loadExtensionConfig(
extension.path, extension.path,
); );
let updatedExtension: GeminiCLIExtension; let updatedExtension: GeminiCLIExtension;
+2 -2
View File
@@ -2442,7 +2442,7 @@ describe('Settings Loading and Merging', () => {
extensionManager, extensionManager,
'disableExtension', 'disableExtension',
); );
mockDisableExtension.mockImplementation(() => {}); mockDisableExtension.mockImplementation(async () => {});
migrateDeprecatedSettings(loadedSettings, extensionManager); migrateDeprecatedSettings(loadedSettings, extensionManager);
@@ -2515,7 +2515,7 @@ describe('Settings Loading and Merging', () => {
extensionManager, extensionManager,
'disableExtension', 'disableExtension',
); );
mockDisableExtension.mockImplementation(() => {}); mockDisableExtension.mockImplementation(async () => {});
migrateDeprecatedSettings(loadedSettings, extensionManager); migrateDeprecatedSettings(loadedSettings, extensionManager);
@@ -124,7 +124,7 @@ describe('useExtensionUpdates', () => {
autoUpdate: true, autoUpdate: true,
}, },
}); });
await extensionManager.loadExtensions();
const addItem = vi.fn(); const addItem = vi.fn();
vi.mocked(checkForAllExtensionUpdates).mockImplementation( vi.mocked(checkForAllExtensionUpdates).mockImplementation(
@@ -145,7 +145,6 @@ describe('useExtensionUpdates', () => {
name: '', name: '',
}); });
extensionManager.loadExtensions();
function TestComponent() { function TestComponent() {
useExtensionUpdates(extensionManager, addItem); useExtensionUpdates(extensionManager, addItem);
return null; return null;
@@ -189,7 +188,7 @@ describe('useExtensionUpdates', () => {
}, },
}); });
extensionManager.loadExtensions(); await extensionManager.loadExtensions();
const addItem = vi.fn(); const addItem = vi.fn();
+1
View File
@@ -44,5 +44,6 @@ export { makeFakeConfig } from './src/test-utils/config.js';
export * from './src/utils/pathReader.js'; export * from './src/utils/pathReader.js';
export { ClearcutLogger } from './src/telemetry/clearcut-logger/clearcut-logger.js'; export { ClearcutLogger } from './src/telemetry/clearcut-logger/clearcut-logger.js';
export { logModelSlashCommand } from './src/telemetry/loggers.js'; export { logModelSlashCommand } from './src/telemetry/loggers.js';
export { KeychainTokenStorage } from './src/mcp/token-storage/keychain-token-storage.js';
export * from './src/utils/googleQuotaErrors.js'; export * from './src/utils/googleQuotaErrors.js';
export type { GoogleApiError } from './src/utils/googleErrors.js'; export type { GoogleApiError } from './src/utils/googleErrors.js';
@@ -386,5 +386,54 @@ describe('KeychainTokenStorage', () => {
); );
}); });
}); });
describe('Secrets', () => {
it('should set and get a secret', async () => {
mockKeytar.setPassword.mockResolvedValue(undefined);
mockKeytar.getPassword.mockResolvedValue('secret-value');
await storage.setSecret('secret-key', 'secret-value');
const value = await storage.getSecret('secret-key');
expect(mockKeytar.setPassword).toHaveBeenCalledWith(
mockServiceName,
'__secret__secret-key',
'secret-value',
);
expect(mockKeytar.getPassword).toHaveBeenCalledWith(
mockServiceName,
'__secret__secret-key',
);
expect(value).toBe('secret-value');
});
it('should delete a secret', async () => {
mockKeytar.deletePassword.mockResolvedValue(true);
await storage.deleteSecret('secret-key');
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
mockServiceName,
'__secret__secret-key',
);
});
it('should list secrets', async () => {
mockKeytar.findCredentials.mockResolvedValue([
{ account: '__secret__secret1', password: '' },
{ account: '__secret__secret2', password: '' },
{ account: 'server1', password: '' },
]);
const secrets = await storage.listSecrets();
expect(secrets).toEqual(['secret1', 'secret2']);
});
it('should not list secrets in listServers', async () => {
mockKeytar.findCredentials.mockResolvedValue([
{ account: '__secret__secret1', password: '' },
{ account: 'server1', password: '' },
]);
const servers = await storage.listServers();
expect(servers).toEqual(['server1']);
});
});
}); });
}); });
@@ -6,7 +6,7 @@
import * as crypto from 'node:crypto'; import * as crypto from 'node:crypto';
import { BaseTokenStorage } from './base-token-storage.js'; import { BaseTokenStorage } from './base-token-storage.js';
import type { OAuthCredentials } from './types.js'; import type { OAuthCredentials, SecretStorage } from './types.js';
import { coreEvents } from '../../utils/events.js'; import { coreEvents } from '../../utils/events.js';
interface Keytar { interface Keytar {
@@ -23,8 +23,12 @@ interface Keytar {
} }
const KEYCHAIN_TEST_PREFIX = '__keychain_test__'; const KEYCHAIN_TEST_PREFIX = '__keychain_test__';
const SECRET_PREFIX = '__secret__';
export class KeychainTokenStorage extends BaseTokenStorage { export class KeychainTokenStorage
extends BaseTokenStorage
implements SecretStorage
{
private keychainAvailable: boolean | null = null; private keychainAvailable: boolean | null = null;
private keytarModule: Keytar | null = null; private keytarModule: Keytar | null = null;
private keytarLoadAttempted = false; private keytarLoadAttempted = false;
@@ -137,7 +141,11 @@ export class KeychainTokenStorage extends BaseTokenStorage {
try { try {
const credentials = await keytar.findCredentials(this.serviceName); const credentials = await keytar.findCredentials(this.serviceName);
return credentials return credentials
.filter((cred) => !cred.account.startsWith(KEYCHAIN_TEST_PREFIX)) .filter(
(cred) =>
!cred.account.startsWith(KEYCHAIN_TEST_PREFIX) &&
!cred.account.startsWith(SECRET_PREFIX),
)
.map((cred: { account: string }) => cred.account); .map((cred: { account: string }) => cred.account);
} catch (error) { } catch (error) {
coreEvents.emitFeedback( coreEvents.emitFeedback(
@@ -163,7 +171,11 @@ export class KeychainTokenStorage extends BaseTokenStorage {
try { try {
const credentials = ( const credentials = (
await keytar.findCredentials(this.serviceName) await keytar.findCredentials(this.serviceName)
).filter((c) => !c.account.startsWith(KEYCHAIN_TEST_PREFIX)); ).filter(
(c) =>
!c.account.startsWith(KEYCHAIN_TEST_PREFIX) &&
!c.account.startsWith(SECRET_PREFIX),
);
for (const cred of credentials) { for (const cred of credentials) {
try { try {
@@ -258,4 +270,62 @@ export class KeychainTokenStorage extends BaseTokenStorage {
async isAvailable(): Promise<boolean> { async isAvailable(): Promise<boolean> {
return this.checkKeychainAvailability(); return this.checkKeychainAvailability();
} }
async setSecret(key: string, value: string): Promise<void> {
if (!(await this.checkKeychainAvailability())) {
throw new Error('Keychain is not available');
}
const keytar = await this.getKeytar();
if (!keytar) {
throw new Error('Keytar module not available');
}
await keytar.setPassword(this.serviceName, `${SECRET_PREFIX}${key}`, value);
}
async getSecret(key: string): Promise<string | null> {
if (!(await this.checkKeychainAvailability())) {
throw new Error('Keychain is not available');
}
const keytar = await this.getKeytar();
if (!keytar) {
throw new Error('Keytar module not available');
}
return keytar.getPassword(this.serviceName, `${SECRET_PREFIX}${key}`);
}
async deleteSecret(key: string): Promise<void> {
if (!(await this.checkKeychainAvailability())) {
throw new Error('Keychain is not available');
}
const keytar = await this.getKeytar();
if (!keytar) {
throw new Error('Keytar module not available');
}
const deleted = await keytar.deletePassword(
this.serviceName,
`${SECRET_PREFIX}${key}`,
);
if (!deleted) {
throw new Error(`No secret found for key: ${key}`);
}
}
async listSecrets(): Promise<string[]> {
if (!(await this.checkKeychainAvailability())) {
throw new Error('Keychain is not available');
}
const keytar = await this.getKeytar();
if (!keytar) {
throw new Error('Keytar module not available');
}
try {
const credentials = await keytar.findCredentials(this.serviceName);
return credentials
.filter((cred) => cred.account.startsWith(SECRET_PREFIX))
.map((cred) => cred.account.substring(SECRET_PREFIX.length));
} catch (error) {
console.error('Failed to list secrets from keychain:', error);
return [];
}
}
} }
@@ -36,6 +36,13 @@ export interface TokenStorage {
clearAll(): Promise<void>; clearAll(): Promise<void>;
} }
export interface SecretStorage {
setSecret(key: string, value: string): Promise<void>;
getSecret(key: string): Promise<string | null>;
deleteSecret(key: string): Promise<void>;
listSecrets(): Promise<string[]>;
}
export enum TokenStorageType { export enum TokenStorageType {
KEYCHAIN = 'keychain', KEYCHAIN = 'keychain',
ENCRYPTED_FILE = 'encrypted_file', ENCRYPTED_FILE = 'encrypted_file',