From 7e987113a21d3d15934fb9e76729c0d61e3663bd Mon Sep 17 00:00:00 2001 From: christine betts Date: Tue, 28 Oct 2025 14:48:50 -0400 Subject: [PATCH] Add support for sensitive keychain-stored per-extension settings (#11953) --- docs/extensions/index.md | 4 +- .../cli/src/commands/extensions/disable.ts | 11 +- .../cli/src/commands/extensions/enable.ts | 4 +- .../cli/src/commands/extensions/install.ts | 2 +- packages/cli/src/commands/extensions/link.ts | 2 +- packages/cli/src/commands/extensions/list.ts | 2 +- .../cli/src/commands/extensions/uninstall.ts | 2 +- .../cli/src/commands/extensions/update.ts | 2 +- packages/cli/src/commands/mcp/list.ts | 2 +- packages/cli/src/config/config.ts | 2 +- packages/cli/src/config/extension-manager.ts | 38 ++- packages/cli/src/config/extension.test.ts | 252 ++++++++++-------- .../extensions/extensionSettings.test.ts | 236 ++++++++++++++-- .../config/extensions/extensionSettings.ts | 140 ++++++++-- .../cli/src/config/extensions/update.test.ts | 73 +++-- packages/cli/src/config/extensions/update.ts | 2 +- packages/cli/src/config/settings.test.ts | 4 +- .../src/ui/hooks/useExtensionUpdates.test.tsx | 5 +- packages/core/index.ts | 1 + .../keychain-token-storage.test.ts | 49 ++++ .../token-storage/keychain-token-storage.ts | 78 +++++- packages/core/src/mcp/token-storage/types.ts | 7 + 22 files changed, 706 insertions(+), 212 deletions(-) diff --git a/docs/extensions/index.md b/docs/extensions/index.md index e07930dcf4..84d116cfe6 100644 --- a/docs/extensions/index.md +++ b/docs/extensions/index.md @@ -190,8 +190,8 @@ Each object in the array should have the following properties: - `description`: A description of the setting and what it's used for. - `envVar`: The name of the environment variable that the setting will be stored as. - -**Example** +- `sensitive`: Optional boolean. If true, obfuscates the input the user provides + and stores the secret in keychain storage. **Example** ```json { diff --git a/packages/cli/src/commands/extensions/disable.ts b/packages/cli/src/commands/extensions/disable.ts index 40bed33f83..bb60087275 100644 --- a/packages/cli/src/commands/extensions/disable.ts +++ b/packages/cli/src/commands/extensions/disable.ts @@ -17,7 +17,7 @@ interface DisableArgs { scope?: string; } -export function handleDisable(args: DisableArgs) { +export async function handleDisable(args: DisableArgs) { const workspaceDir = process.cwd(); const extensionManager = new ExtensionManager({ workspaceDir, @@ -25,13 +25,16 @@ export function handleDisable(args: DisableArgs) { requestSetting: promptForSetting, settings: loadSettings(workspaceDir).merged, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); try { if (args.scope?.toLowerCase() === 'workspace') { - extensionManager.disableExtension(args.name, SettingScope.Workspace); + await extensionManager.disableExtension( + args.name, + SettingScope.Workspace, + ); } else { - extensionManager.disableExtension(args.name, SettingScope.User); + await extensionManager.disableExtension(args.name, SettingScope.User); } debugLogger.log( `Extension "${args.name}" successfully disabled for scope "${args.scope}".`, diff --git a/packages/cli/src/commands/extensions/enable.ts b/packages/cli/src/commands/extensions/enable.ts index 468353f6a1..0796830100 100644 --- a/packages/cli/src/commands/extensions/enable.ts +++ b/packages/cli/src/commands/extensions/enable.ts @@ -20,7 +20,7 @@ interface EnableArgs { scope?: string; } -export function handleEnable(args: EnableArgs) { +export async function handleEnable(args: EnableArgs) { const workingDir = process.cwd(); const extensionManager = new ExtensionManager({ workspaceDir: workingDir, @@ -28,7 +28,7 @@ export function handleEnable(args: EnableArgs) { requestSetting: promptForSetting, settings: loadSettings(workingDir).merged, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); try { if (args.scope?.toLowerCase() === 'workspace') { diff --git a/packages/cli/src/commands/extensions/install.ts b/packages/cli/src/commands/extensions/install.ts index 95d2e17b7a..920cfe63a4 100644 --- a/packages/cli/src/commands/extensions/install.ts +++ b/packages/cli/src/commands/extensions/install.ts @@ -76,7 +76,7 @@ export async function handleInstall(args: InstallArgs) { requestSetting: promptForSetting, settings: loadSettings(workspaceDir).merged, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); const name = await extensionManager.installOrUpdateExtension(installMetadata); debugLogger.log(`Extension "${name}" installed successfully and enabled.`); diff --git a/packages/cli/src/commands/extensions/link.ts b/packages/cli/src/commands/extensions/link.ts index 69c18d8bbe..9bee299a5e 100644 --- a/packages/cli/src/commands/extensions/link.ts +++ b/packages/cli/src/commands/extensions/link.ts @@ -33,7 +33,7 @@ export async function handleLink(args: InstallArgs) { requestSetting: promptForSetting, settings: loadSettings(workspaceDir).merged, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); const extensionName = await extensionManager.installOrUpdateExtension(installMetadata); debugLogger.log( diff --git a/packages/cli/src/commands/extensions/list.ts b/packages/cli/src/commands/extensions/list.ts index a0b31e45f3..4596f95cd9 100644 --- a/packages/cli/src/commands/extensions/list.ts +++ b/packages/cli/src/commands/extensions/list.ts @@ -21,7 +21,7 @@ export async function handleList() { requestSetting: promptForSetting, settings: loadSettings(workspaceDir).merged, }); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); if (extensions.length === 0) { debugLogger.log('No extensions installed.'); return; diff --git a/packages/cli/src/commands/extensions/uninstall.ts b/packages/cli/src/commands/extensions/uninstall.ts index 91242fe3a1..c768c95164 100644 --- a/packages/cli/src/commands/extensions/uninstall.ts +++ b/packages/cli/src/commands/extensions/uninstall.ts @@ -25,7 +25,7 @@ export async function handleUninstall(args: UninstallArgs) { requestSetting: promptForSetting, settings: loadSettings(workspaceDir).merged, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.uninstallExtension(args.name, false); debugLogger.log(`Extension "${args.name}" successfully uninstalled.`); } catch (error) { diff --git a/packages/cli/src/commands/extensions/update.ts b/packages/cli/src/commands/extensions/update.ts index b5c1620810..f3e78f2cca 100644 --- a/packages/cli/src/commands/extensions/update.ts +++ b/packages/cli/src/commands/extensions/update.ts @@ -37,7 +37,7 @@ export async function handleUpdate(args: UpdateArgs) { settings: loadSettings(workspaceDir).merged, }); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); if (args.name) { try { const extension = extensions.find( diff --git a/packages/cli/src/commands/mcp/list.ts b/packages/cli/src/commands/mcp/list.ts index 9e41964d17..9b5571d134 100644 --- a/packages/cli/src/commands/mcp/list.ts +++ b/packages/cli/src/commands/mcp/list.ts @@ -33,7 +33,7 @@ async function getMcpServersFromConfig(): Promise< requestConsent: requestConsentNonInteractive, requestSetting: promptForSetting, }); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); const mcpServers = { ...(settings.merged.mcpServers || {}) }; for (const extension of extensions) { Object.entries(extension.mcpServers || {}).forEach(([key, server]) => { diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 2a102f78bc..9d9630634f 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -423,7 +423,7 @@ export async function loadCliConfig( workspaceDir: cwd, enabledExtensionOverrides: argv.extensions, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); // Call the (now wrapper) loadHierarchicalGeminiMemory which calls the server's version const { memoryContent, fileCount, filePaths } = diff --git a/packages/cli/src/config/extension-manager.ts b/packages/cli/src/config/extension-manager.ts index d25591fd48..9980474e73 100644 --- a/packages/cli/src/config/extension-manager.ts +++ b/packages/cli/src/config/extension-manager.ts @@ -133,7 +133,7 @@ export class ExtensionManager implements ExtensionLoader { const isUpdate = !!previousExtensionConfig; let newExtensionConfig: ExtensionConfig | null = null; let localSourcePath: string | undefined; - let extension: GeminiCLIExtension; + let extension: GeminiCLIExtension | null; try { if (!isWorkspaceTrusted(this.settings).isTrusted) { throw new Error( @@ -243,12 +243,16 @@ export class ExtensionManager implements ExtensionLoader { this.requestConsent, previousExtensionConfig, ); - - const extensionStorage = new ExtensionStorage(newExtensionName); - const destinationPath = extensionStorage.getExtensionDir(); + const extensionId = getExtensionId(newExtensionConfig, installMetadata); + const destinationPath = new ExtensionStorage( + newExtensionName, + ).getExtensionDir(); let previousSettings: Record | undefined; if (isUpdate) { - previousSettings = getEnvContents(extensionStorage); + previousSettings = await getEnvContents( + previousExtensionConfig, + extensionId, + ); await this.uninstallExtension(newExtensionName, isUpdate); } @@ -257,6 +261,7 @@ export class ExtensionManager implements ExtensionLoader { if (isUpdate) { await maybePromptForSettings( newExtensionConfig, + extensionId, this.requestSetting, previousExtensionConfig, previousSettings, @@ -264,6 +269,7 @@ export class ExtensionManager implements ExtensionLoader { } else { await maybePromptForSettings( newExtensionConfig, + extensionId, this.requestSetting, ); } @@ -286,7 +292,10 @@ export class ExtensionManager implements ExtensionLoader { // TODO: Gracefully handle this call failing, we should back up the old // extension prior to overwriting it and then restore it. - extension = this.loadExtension(destinationPath)!; + extension = await this.loadExtension(destinationPath)!; + if (!extension) { + throw new Error(`Extension not found`); + } if (isUpdate) { logExtensionUpdateEvent( this.telemetryConfig, @@ -401,7 +410,7 @@ export class ExtensionManager implements ExtensionLoader { this.eventEmitter.emit('extensionUninstalled', { extension }); } - loadExtensions(): GeminiCLIExtension[] { + async loadExtensions(): Promise { if (this.loadedExtensions) { throw new Error('Extensions already loaded, only load extensions once.'); } @@ -413,12 +422,14 @@ export class ExtensionManager implements ExtensionLoader { for (const subdir of fs.readdirSync(extensionsDir)) { const extensionDir = path.join(extensionsDir, subdir); - this.loadExtension(extensionDir); + await this.loadExtension(extensionDir); } return this.loadedExtensions; } - private loadExtension(extensionDir: string): GeminiCLIExtension | null { + private async loadExtension( + extensionDir: string, + ): Promise { this.loadedExtensions ??= []; if (!fs.statSync(extensionDir).isDirectory()) { return null; @@ -441,7 +452,10 @@ export class ExtensionManager implements ExtensionLoader { ); } - const customEnv = getEnvContents(new ExtensionStorage(config.name)); + const customEnv = await getEnvContents( + config, + getExtensionId(config, installMetadata), + ); config = resolveEnvVarsInObject(config, customEnv); if (config.mcpServers) { @@ -573,7 +587,7 @@ export class ExtensionManager implements ExtensionLoader { return output; } - disableExtension(name: string, scope: SettingScope) { + async disableExtension(name: string, scope: SettingScope) { if ( scope === SettingScope.System || scope === SettingScope.SystemDefaults @@ -598,7 +612,7 @@ export class ExtensionManager implements ExtensionLoader { this.eventEmitter.emit('extensionDisabled', { extension }); } - enableExtension(name: string, scope: SettingScope) { + async enableExtension(name: string, scope: SettingScope) { if ( scope === SettingScope.System || scope === SettingScope.SystemDefaults diff --git a/packages/cli/src/config/extension.test.ts b/packages/cli/src/config/extension.test.ts index e4fa0364ac..21df5f26de 100644 --- a/packages/cli/src/config/extension.test.ts +++ b/packages/cli/src/config/extension.test.ts @@ -13,6 +13,7 @@ import { ExtensionUninstallEvent, ExtensionDisableEvent, ExtensionEnableEvent, + KeychainTokenStorage, } from '@google/gemini-cli-core'; import { loadSettings, SettingScope } from './settings.js'; import { isWorkspaceTrusted } from './trustedFolders.js'; @@ -96,6 +97,13 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { ExtensionInstallEvent: vi.fn(), ExtensionUninstallEvent: vi.fn(), ExtensionDisableEvent: vi.fn(), + KeychainTokenStorage: vi.fn().mockImplementation(() => ({ + getSecret: vi.fn(), + setSecret: vi.fn(), + deleteSecret: vi.fn(), + listSecrets: vi.fn(), + isAvailable: vi.fn().mockResolvedValue(true), + })), }; }); @@ -107,6 +115,14 @@ vi.mock('child_process', async (importOriginal) => { }; }); +interface MockKeychainStorage { + getSecret: ReturnType; + setSecret: ReturnType; + deleteSecret: ReturnType; + listSecrets: ReturnType; + isAvailable: ReturnType; +} + describe('extension tests', () => { let tempHomeDir: string; let tempWorkspaceDir: string; @@ -116,8 +132,32 @@ describe('extension tests', () => { let mockPromptForSettings: MockedFunction< (setting: ExtensionSetting) => Promise >; + let mockKeychainStorage: MockKeychainStorage; + let keychainData: Record; beforeEach(() => { + vi.clearAllMocks(); + keychainData = {}; + mockKeychainStorage = { + getSecret: vi + .fn() + .mockImplementation(async (key: string) => keychainData[key] || null), + setSecret: vi + .fn() + .mockImplementation(async (key: string, value: string) => { + keychainData[key] = value; + }), + deleteSecret: vi.fn().mockImplementation(async (key: string) => { + delete keychainData[key]; + }), + listSecrets: vi + .fn() + .mockImplementation(async () => Object.keys(keychainData)), + isAvailable: vi.fn().mockResolvedValue(true), + }; + ( + KeychainTokenStorage as unknown as ReturnType + ).mockImplementation(() => mockKeychainStorage); tempHomeDir = fs.mkdtempSync( path.join(os.tmpdir(), 'gemini-cli-test-home-'), ); @@ -151,7 +191,7 @@ describe('extension tests', () => { }); describe('loadExtensions', () => { - it('should include extension path in loaded extension', () => { + it('should include extension path in loaded extension', async () => { const extensionDir = path.join(userExtensionsDir, 'test-extension'); fs.mkdirSync(extensionDir, { recursive: true }); @@ -161,13 +201,13 @@ describe('extension tests', () => { version: '1.0.0', }); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); expect(extensions).toHaveLength(1); expect(extensions[0].path).toBe(extensionDir); expect(extensions[0].name).toBe('test-extension'); }); - it('should load context file path when GEMINI.md is present', () => { + it('should load context file path when GEMINI.md is present', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'ext1', @@ -180,7 +220,7 @@ describe('extension tests', () => { version: '2.0.0', }); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); expect(extensions).toHaveLength(2); const ext1 = extensions.find((e) => e.name === 'ext1'); @@ -191,7 +231,7 @@ describe('extension tests', () => { expect(ext2?.contextFiles).toEqual([]); }); - it('should load context file path from the extension config', () => { + it('should load context file path from the extension config', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'ext1', @@ -200,7 +240,7 @@ describe('extension tests', () => { contextFileName: 'my-context-file.md', }); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); expect(extensions).toHaveLength(1); const ext1 = extensions.find((e) => e.name === 'ext1'); @@ -209,7 +249,7 @@ describe('extension tests', () => { ]); }); - it('should annotate disabled extensions', () => { + it('should annotate disabled extensions', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'disabled-extension', @@ -220,8 +260,8 @@ describe('extension tests', () => { name: 'enabled-extension', version: '2.0.0', }); - extensionManager.loadExtensions(); - extensionManager.disableExtension( + await extensionManager.loadExtensions(); + await extensionManager.disableExtension( 'disabled-extension', SettingScope.User, ); @@ -233,7 +273,7 @@ describe('extension tests', () => { expect(extensions[1].isActive).toBe(true); }); - it('should hydrate variables', () => { + it('should hydrate variables', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'test-extension', @@ -247,7 +287,7 @@ describe('extension tests', () => { }, }); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); expect(extensions).toHaveLength(1); const expectedCwd = path.join( userExtensionsDir, @@ -266,7 +306,7 @@ describe('extension tests', () => { }); fs.writeFileSync(path.join(sourceExtDir, 'context.md'), 'linked context'); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); const extension = await extensionManager.installOrUpdateExtension({ source: sourceExtDir, type: 'link', @@ -303,7 +343,7 @@ describe('extension tests', () => { }, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: sourceExtDir, type: 'link', @@ -319,7 +359,7 @@ describe('extension tests', () => { ]); }); - it('should resolve environment variables in extension configuration', () => { + it('should resolve environment variables in extension configuration', async () => { process.env['TEST_API_KEY'] = 'test-api-key-123'; process.env['TEST_DB_URL'] = 'postgresql://localhost:5432/testdb'; @@ -352,7 +392,7 @@ describe('extension tests', () => { }; fs.writeFileSync(configPath, JSON.stringify(extensionConfig)); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); expect(extensions).toHaveLength(1); const extension = extensions[0]; @@ -373,7 +413,7 @@ describe('extension tests', () => { } }); - it('should resolve environment variables from an extension .env file', () => { + it('should resolve environment variables from an extension .env file', async () => { const extDir = createExtension({ extensionsDir: userExtensionsDir, name: 'test-extension', @@ -388,12 +428,19 @@ describe('extension tests', () => { }, }, }, + settings: [ + { + name: 'My API Key', + description: 'API key for testing.', + envVar: 'MY_API_KEY', + }, + ], }); const envFilePath = path.join(extDir, '.env'); fs.writeFileSync(envFilePath, 'MY_API_KEY=test-key-from-file\n'); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); expect(extensions).toHaveLength(1); const extension = extensions[0]; @@ -403,7 +450,7 @@ describe('extension tests', () => { expect(serverConfig.env!['STATIC_VALUE']).toBe('no-substitution'); }); - it('should handle missing environment variables gracefully', () => { + it('should handle missing environment variables gracefully', async () => { const userExtensionsDir = path.join( tempHomeDir, EXTENSIONS_DIRECTORY_NAME, @@ -433,7 +480,7 @@ describe('extension tests', () => { JSON.stringify(extensionConfig), ); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); expect(extensions).toHaveLength(1); const extension = extensions[0]; @@ -443,7 +490,7 @@ describe('extension tests', () => { expect(serverConfig.env!['MISSING_VAR_BRACES']).toBe('${ALSO_UNDEFINED}'); }); - it('should skip extensions with invalid JSON and log a warning', () => { + it('should skip extensions with invalid JSON and log a warning', async () => { const consoleSpy = vi .spyOn(console, 'error') .mockImplementation(() => {}); @@ -461,7 +508,7 @@ describe('extension tests', () => { const badConfigPath = path.join(badExtDir, EXTENSIONS_CONFIG_FILENAME); fs.writeFileSync(badConfigPath, '{ "name": "bad-ext"'); // Malformed - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); expect(extensions).toHaveLength(1); expect(extensions[0].name).toBe('good-ext'); @@ -474,7 +521,7 @@ describe('extension tests', () => { consoleSpy.mockRestore(); }); - it('should skip extensions with missing name and log a warning', () => { + it('should skip extensions with missing name and log a warning', async () => { const consoleSpy = vi .spyOn(console, 'error') .mockImplementation(() => {}); @@ -492,7 +539,7 @@ describe('extension tests', () => { const badConfigPath = path.join(badExtDir, EXTENSIONS_CONFIG_FILENAME); fs.writeFileSync(badConfigPath, JSON.stringify({ version: '1.0.0' })); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); expect(extensions).toHaveLength(1); expect(extensions[0].name).toBe('good-ext'); @@ -505,7 +552,7 @@ describe('extension tests', () => { consoleSpy.mockRestore(); }); - it('should filter trust out of mcp servers', () => { + it('should filter trust out of mcp servers', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'test-extension', @@ -519,12 +566,12 @@ describe('extension tests', () => { }, }); - const extensions = extensionManager.loadExtensions(); + const extensions = await extensionManager.loadExtensions(); expect(extensions).toHaveLength(1); expect(extensions[0].mcpServers?.['test-server'].trust).toBeUndefined(); }); - it('should throw an error for invalid extension names', () => { + it('should throw an error for invalid extension names', async () => { const consoleSpy = vi .spyOn(console, 'error') .mockImplementation(() => {}); @@ -533,10 +580,8 @@ describe('extension tests', () => { name: 'bad_name', version: '1.0.0', }); - - const extension = extensionManager - .loadExtensions() - .find((e) => e.name === 'bad_name'); + const extensions = await extensionManager.loadExtensions(); + const extension = extensions.find((e) => e.name === 'bad_name'); expect(extension).toBeUndefined(); expect(consoleSpy).toHaveBeenCalledWith( @@ -546,7 +591,7 @@ describe('extension tests', () => { }); describe('id generation', () => { - it('should generate id from source for non-github git urls', () => { + it('should generate id from source for non-github git urls', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'my-ext', @@ -556,14 +601,12 @@ describe('extension tests', () => { source: 'http://somehost.com/foo/bar', }, }); - - const extension = extensionManager - .loadExtensions() - .find((e) => e.name === 'my-ext'); + const extensions = await extensionManager.loadExtensions(); + const extension = extensions.find((e) => e.name === 'my-ext'); expect(extension?.id).toBe(hashValue('http://somehost.com/foo/bar')); }); - it('should generate id from owner/repo for github http urls', () => { + it('should generate id from owner/repo for github http urls', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'my-ext', @@ -574,13 +617,12 @@ describe('extension tests', () => { }, }); - const extension = extensionManager - .loadExtensions() - .find((e) => e.name === 'my-ext'); + const extensions = await extensionManager.loadExtensions(); + const extension = extensions.find((e) => e.name === 'my-ext'); expect(extension?.id).toBe(hashValue('https://github.com/foo/bar')); }); - it('should generate id from owner/repo for github ssh urls', () => { + it('should generate id from owner/repo for github ssh urls', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'my-ext', @@ -591,13 +633,12 @@ describe('extension tests', () => { }, }); - const extension = extensionManager - .loadExtensions() - .find((e) => e.name === 'my-ext'); + const extensions = await extensionManager.loadExtensions(); + const extension = extensions.find((e) => e.name === 'my-ext'); expect(extension?.id).toBe(hashValue('https://github.com/foo/bar')); }); - it('should generate id from source for github-release extension', () => { + it('should generate id from source for github-release extension', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'my-ext', @@ -607,14 +648,12 @@ describe('extension tests', () => { source: 'https://github.com/foo/bar', }, }); - - const extension = extensionManager - .loadExtensions() - .find((e) => e.name === 'my-ext'); + const extensions = await extensionManager.loadExtensions(); + const extension = extensions.find((e) => e.name === 'my-ext'); expect(extension?.id).toBe(hashValue('https://github.com/foo/bar')); }); - it('should generate id from the original source for local extension', () => { + it('should generate id from the original source for local extension', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'local-ext-name', @@ -625,9 +664,8 @@ describe('extension tests', () => { }, }); - const extension = extensionManager - .loadExtensions() - .find((e) => e.name === 'local-ext-name'); + const extensions = await extensionManager.loadExtensions(); + const extension = extensions.find((e) => e.name === 'local-ext-name'); expect(extension?.id).toBe(hashValue('/some/path')); }); @@ -638,7 +676,7 @@ describe('extension tests', () => { name: 'link-ext-name', version: '1.0.0', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ type: 'link', source: actualExtensionDir, @@ -650,16 +688,15 @@ describe('extension tests', () => { expect(extension?.id).toBe(hashValue(actualExtensionDir)); }); - it('should generate id from name for extension with no install metadata', () => { + it('should generate id from name for extension with no install metadata', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'no-meta-name', version: '1.0.0', }); - const extension = extensionManager - .loadExtensions() - .find((e) => e.name === 'no-meta-name'); + const extensions = await extensionManager.loadExtensions(); + const extension = extensions.find((e) => e.name === 'no-meta-name'); expect(extension?.id).toBe(hashValue('no-meta-name')); }); }); @@ -675,7 +712,7 @@ describe('extension tests', () => { const targetExtDir = path.join(userExtensionsDir, 'my-local-extension'); const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: sourceExtDir, type: 'local', @@ -697,7 +734,7 @@ describe('extension tests', () => { name: 'my-local-extension', version: '1.0.0', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: sourceExtDir, type: 'local', @@ -791,7 +828,7 @@ describe('extension tests', () => { type: 'github-release', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: gitUrl, type: 'git', @@ -816,7 +853,7 @@ describe('extension tests', () => { const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME); const configPath = path.join(targetExtDir, EXTENSIONS_CONFIG_FILENAME); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: sourceExtDir, type: 'link', @@ -846,7 +883,7 @@ describe('extension tests', () => { name: 'my-local-extension', version: '1.1.0', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); if (isUpdate) { await extensionManager.installOrUpdateExtension({ source: sourceExtDir, @@ -920,7 +957,7 @@ describe('extension tests', () => { }, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await expect( extensionManager.installOrUpdateExtension({ source: sourceExtDir, @@ -952,7 +989,7 @@ This extension will run the following MCP servers: }, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await expect( extensionManager.installOrUpdateExtension({ source: sourceExtDir, @@ -974,7 +1011,7 @@ This extension will run the following MCP servers: }, }); mockRequestConsent.mockResolvedValue(false); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await expect( extensionManager.installOrUpdateExtension({ source: sourceExtDir, @@ -992,7 +1029,7 @@ This extension will run the following MCP servers: const targetExtDir = path.join(userExtensionsDir, 'my-local-extension'); const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: sourceExtDir, type: 'local', @@ -1023,7 +1060,7 @@ This extension will run the following MCP servers: }, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); // Install it with hard coded consent first. await extensionManager.installOrUpdateExtension({ source: sourceExtDir, @@ -1058,7 +1095,7 @@ This extension will run the following MCP servers: ], }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: sourceExtDir, type: 'local', @@ -1088,7 +1125,7 @@ This extension will run the following MCP servers: settings: loadSettings(tempWorkspaceDir).merged, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: sourceExtDir, type: 'local', @@ -1111,7 +1148,7 @@ This extension will run the following MCP servers: }); mockPromptForSettings.mockResolvedValueOnce('old-api-key'); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); // Install it so it exists in the userExtensionsDir await extensionManager.installOrUpdateExtension({ source: oldSourceExtDir, @@ -1181,7 +1218,7 @@ This extension will run the following MCP servers: }, ], }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: oldSourceExtDir, type: 'local', @@ -1273,7 +1310,7 @@ This extension will run the following MCP servers: join(tempDir, extensionName), ); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: gitUrl, type: 'github-release', @@ -1298,7 +1335,7 @@ This extension will run the following MCP servers: type: 'github-release', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension( { source: gitUrl, type: 'github-release' }, // Use github-release to force consent ); @@ -1329,7 +1366,7 @@ This extension will run the following MCP servers: }); mockRequestConsent.mockResolvedValue(false); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await expect( extensionManager.installOrUpdateExtension({ source: gitUrl, @@ -1354,7 +1391,7 @@ This extension will run the following MCP servers: type: 'github-release', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension({ source: gitUrl, type: 'git', @@ -1385,7 +1422,7 @@ This extension will run the following MCP servers: type: 'github-release', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.installOrUpdateExtension( { source: gitUrl, type: 'github-release' }, // Note the type ); @@ -1407,8 +1444,7 @@ This extension will run the following MCP servers: name: 'my-local-extension', version: '1.0.0', }); - - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.uninstallExtension('my-local-extension', false); expect(fs.existsSync(sourceExtDir)).toBe(false); @@ -1426,7 +1462,7 @@ This extension will run the following MCP servers: version: '1.0.0', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.uninstallExtension('my-local-extension', false); expect(fs.existsSync(sourceExtDir)).toBe(false); @@ -1435,7 +1471,7 @@ This extension will run the following MCP servers: }); it('should throw an error if the extension does not exist', async () => { - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await expect( extensionManager.uninstallExtension('nonexistent-extension', false), ).rejects.toThrow('Extension not found.'); @@ -1453,7 +1489,7 @@ This extension will run the following MCP servers: }, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.uninstallExtension( 'my-local-extension', isUpdate, @@ -1481,7 +1517,7 @@ This extension will run the following MCP servers: const enablementManager = new ExtensionEnablementManager(); enablementManager.enable('test-extension', true, '/some/scope'); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.uninstallExtension('test-extension', isUpdate); const config = enablementManager.readConfig()['test-extension']; @@ -1506,7 +1542,7 @@ This extension will run the following MCP servers: }, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await extensionManager.uninstallExtension(gitUrl, false); expect(fs.existsSync(sourceExtDir)).toBe(false); @@ -1526,7 +1562,7 @@ This extension will run the following MCP servers: // No installMetadata provided }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); await expect( extensionManager.uninstallExtension( 'https://github.com/google/no-metadata-extension', @@ -1537,14 +1573,14 @@ This extension will run the following MCP servers: }); describe('disableExtension', () => { - it('should disable an extension at the user scope', () => { + it('should disable an extension at the user scope', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'my-extension', version: '1.0.0', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); extensionManager.disableExtension('my-extension', SettingScope.User); expect( isEnabled({ @@ -1554,14 +1590,14 @@ This extension will run the following MCP servers: ).toBe(false); }); - it('should disable an extension at the workspace scope', () => { + it('should disable an extension at the workspace scope', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'my-extension', version: '1.0.0', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); extensionManager.disableExtension('my-extension', SettingScope.Workspace); expect( isEnabled({ @@ -1577,14 +1613,14 @@ This extension will run the following MCP servers: ).toBe(false); }); - it('should handle disabling the same extension twice', () => { + it('should handle disabling the same extension twice', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'my-extension', version: '1.0.0', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); extensionManager.disableExtension('my-extension', SettingScope.User); extensionManager.disableExtension('my-extension', SettingScope.User); expect( @@ -1595,13 +1631,17 @@ This extension will run the following MCP servers: ).toBe(false); }); - it('should throw an error if you request system scope', () => { - expect(() => - extensionManager.disableExtension('my-extension', SettingScope.System), - ).toThrow('System and SystemDefaults scopes are not supported.'); + it('should throw an error if you request system scope', async () => { + await expect( + async () => + await extensionManager.disableExtension( + 'my-extension', + SettingScope.System, + ), + ).rejects.toThrow('System and SystemDefaults scopes are not supported.'); }); - it('should log a disable event', () => { + it('should log a disable event', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'ext1', @@ -1612,7 +1652,7 @@ This extension will run the following MCP servers: }, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); extensionManager.disableExtension('ext1', SettingScope.Workspace); expect(mockLogExtensionDisable).toHaveBeenCalled(); @@ -1634,41 +1674,41 @@ This extension will run the following MCP servers: return extensions.filter((e) => e.isActive); }; - it('should enable an extension at the user scope', () => { + it('should enable an extension at the user scope', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'ext1', version: '1.0.0', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); extensionManager.disableExtension('ext1', SettingScope.User); let activeExtensions = getActiveExtensions(); expect(activeExtensions).toHaveLength(0); - extensionManager.enableExtension('ext1', SettingScope.User); - activeExtensions = getActiveExtensions(); + await extensionManager.enableExtension('ext1', SettingScope.User); + activeExtensions = await getActiveExtensions(); expect(activeExtensions).toHaveLength(1); expect(activeExtensions[0].name).toBe('ext1'); }); - it('should enable an extension at the workspace scope', () => { + it('should enable an extension at the workspace scope', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'ext1', version: '1.0.0', }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); extensionManager.disableExtension('ext1', SettingScope.Workspace); let activeExtensions = getActiveExtensions(); expect(activeExtensions).toHaveLength(0); - extensionManager.enableExtension('ext1', SettingScope.Workspace); - activeExtensions = getActiveExtensions(); + await extensionManager.enableExtension('ext1', SettingScope.Workspace); + activeExtensions = await getActiveExtensions(); expect(activeExtensions).toHaveLength(1); expect(activeExtensions[0].name).toBe('ext1'); }); - it('should log an enable event', () => { + it('should log an enable event', async () => { createExtension({ extensionsDir: userExtensionsDir, name: 'ext1', @@ -1678,7 +1718,7 @@ This extension will run the following MCP servers: type: 'local', }, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); extensionManager.disableExtension('ext1', SettingScope.Workspace); extensionManager.enableExtension('ext1', SettingScope.Workspace); diff --git a/packages/cli/src/config/extensions/extensionSettings.test.ts b/packages/cli/src/config/extensions/extensionSettings.test.ts index 9beb8a4284..e72ba8ad1a 100644 --- a/packages/cli/src/config/extensions/extensionSettings.test.ts +++ b/packages/cli/src/config/extensions/extensionSettings.test.ts @@ -17,6 +17,7 @@ import { ExtensionStorage } from './storage.js'; import prompts from 'prompts'; import * as fsPromises from 'node:fs/promises'; import * as fs from 'node:fs'; +import { KeychainTokenStorage } from '@google/gemini-cli-core'; vi.mock('prompts'); vi.mock('os', async (importOriginal) => { @@ -27,11 +28,59 @@ vi.mock('os', async (importOriginal) => { }; }); +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + KeychainTokenStorage: vi.fn().mockImplementation(() => ({ + getSecret: vi.fn(), + setSecret: vi.fn(), + deleteSecret: vi.fn(), + listSecrets: vi.fn(), + isAvailable: vi.fn().mockResolvedValue(true), + })), + }; +}); + +interface MockKeychainStorage { + getSecret: ReturnType; + setSecret: ReturnType; + deleteSecret: ReturnType; + listSecrets: ReturnType; + isAvailable: ReturnType; +} + describe('extensionSettings', () => { let tempHomeDir: string; let extensionDir: string; + let mockKeychainStorage: MockKeychainStorage; + let keychainData: Record; beforeEach(() => { + vi.clearAllMocks(); + keychainData = {}; + mockKeychainStorage = { + getSecret: vi + .fn() + .mockImplementation(async (key: string) => keychainData[key] || null), + setSecret: vi + .fn() + .mockImplementation(async (key: string, value: string) => { + keychainData[key] = value; + }), + deleteSecret: vi.fn().mockImplementation(async (key: string) => { + delete keychainData[key]; + }), + listSecrets: vi + .fn() + .mockImplementation(async () => Object.keys(keychainData)), + isAvailable: vi.fn().mockResolvedValue(true), + }; + ( + KeychainTokenStorage as unknown as ReturnType + ).mockImplementation(() => mockKeychainStorage); + tempHomeDir = os.tmpdir() + path.sep + `gemini-cli-test-home-${Date.now()}`; extensionDir = path.join(tempHomeDir, '.gemini', 'extensions', 'test-ext'); // Spy and mock the method, but also create the directory so we can write to it. @@ -59,7 +108,13 @@ describe('extensionSettings', () => { it('should do nothing if settings are undefined', async () => { const config: ExtensionConfig = { name: 'test-ext', version: '1.0.0' }; - await maybePromptForSettings(config, mockRequestSetting); + await maybePromptForSettings( + config, + '12345', + mockRequestSetting, + undefined, + undefined, + ); expect(mockRequestSetting).not.toHaveBeenCalled(); }); @@ -69,11 +124,17 @@ describe('extensionSettings', () => { version: '1.0.0', settings: [], }; - await maybePromptForSettings(config, mockRequestSetting); + await maybePromptForSettings( + config, + '12345', + mockRequestSetting, + undefined, + undefined, + ); expect(mockRequestSetting).not.toHaveBeenCalled(); }); - it('should call requestSetting for each setting', async () => { + it('should prompt for all settings if there is no previous config', async () => { const config: ExtensionConfig = { name: 'test-ext', version: '1.0.0', @@ -82,14 +143,25 @@ describe('extensionSettings', () => { { name: 's2', description: 'd2', envVar: 'VAR2' }, ], }; - await maybePromptForSettings(config, mockRequestSetting); + await maybePromptForSettings( + config, + '12345', + mockRequestSetting, + undefined, + undefined, + ); expect(mockRequestSetting).toHaveBeenCalledTimes(2); expect(mockRequestSetting).toHaveBeenCalledWith(config.settings![0]); expect(mockRequestSetting).toHaveBeenCalledWith(config.settings![1]); }); - it('should write the .env file with the correct content', async () => { - const config: ExtensionConfig = { + it('should only prompt for new settings', async () => { + const previousConfig: ExtensionConfig = { + name: 'test-ext', + version: '1.0.0', + settings: [{ name: 's1', description: 'd1', envVar: 'VAR1' }], + }; + const newConfig: ExtensionConfig = { name: 'test-ext', version: '1.0.0', settings: [ @@ -97,35 +169,151 @@ describe('extensionSettings', () => { { name: 's2', description: 'd2', envVar: 'VAR2' }, ], }; - await maybePromptForSettings(config, mockRequestSetting); + const previousSettings = { VAR1: 'previous-VAR1' }; + + await maybePromptForSettings( + newConfig, + '12345', + mockRequestSetting, + previousConfig, + previousSettings, + ); + + expect(mockRequestSetting).toHaveBeenCalledTimes(1); + expect(mockRequestSetting).toHaveBeenCalledWith(newConfig.settings![1]); const expectedEnvPath = path.join(extensionDir, '.env'); const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8'); - const expectedContent = 'VAR1=mock-VAR1\nVAR2=mock-VAR2\n'; + const expectedContent = 'VAR1=previous-VAR1\nVAR2=mock-VAR2\n'; + expect(actualContent).toBe(expectedContent); + }); + it('should remove settings that are no longer in the config', async () => { + const previousConfig: ExtensionConfig = { + name: 'test-ext', + version: '1.0.0', + settings: [ + { name: 's1', description: 'd1', envVar: 'VAR1' }, + { name: 's2', description: 'd2', envVar: 'VAR2' }, + ], + }; + const newConfig: ExtensionConfig = { + name: 'test-ext', + version: '1.0.0', + settings: [{ name: 's1', description: 'd1', envVar: 'VAR1' }], + }; + const previousSettings = { + VAR1: 'previous-VAR1', + VAR2: 'previous-VAR2', + }; + + await maybePromptForSettings( + newConfig, + '12345', + mockRequestSetting, + previousConfig, + previousSettings, + ); + + expect(mockRequestSetting).not.toHaveBeenCalled(); + + const expectedEnvPath = path.join(extensionDir, '.env'); + const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8'); + const expectedContent = 'VAR1=previous-VAR1\n'; + expect(actualContent).toBe(expectedContent); + }); + + it('should reprompt if a setting changes sensitivity', async () => { + const previousConfig: ExtensionConfig = { + name: 'test-ext', + version: '1.0.0', + settings: [ + { name: 's1', description: 'd1', envVar: 'VAR1', sensitive: false }, + ], + }; + const newConfig: ExtensionConfig = { + name: 'test-ext', + version: '1.0.0', + settings: [ + { name: 's1', description: 'd1', envVar: 'VAR1', sensitive: true }, + ], + }; + const previousSettings = { VAR1: 'previous-VAR1' }; + + await maybePromptForSettings( + newConfig, + '12345', + mockRequestSetting, + previousConfig, + previousSettings, + ); + + expect(mockRequestSetting).toHaveBeenCalledTimes(1); + expect(mockRequestSetting).toHaveBeenCalledWith(newConfig.settings![0]); + + // The value should now be in keychain, not the .env file. + const expectedEnvPath = path.join(extensionDir, '.env'); + const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8'); + expect(actualContent).toBe(''); + }); + + it('should not prompt if settings are identical', async () => { + const previousConfig: ExtensionConfig = { + name: 'test-ext', + version: '1.0.0', + settings: [ + { name: 's1', description: 'd1', envVar: 'VAR1' }, + { name: 's2', description: 'd2', envVar: 'VAR2' }, + ], + }; + const newConfig: ExtensionConfig = { + name: 'test-ext', + version: '1.0.0', + settings: [ + { name: 's1', description: 'd1', envVar: 'VAR1' }, + { name: 's2', description: 'd2', envVar: 'VAR2' }, + ], + }; + const previousSettings = { + VAR1: 'previous-VAR1', + VAR2: 'previous-VAR2', + }; + + await maybePromptForSettings( + newConfig, + '12345', + mockRequestSetting, + previousConfig, + previousSettings, + ); + + expect(mockRequestSetting).not.toHaveBeenCalled(); + const expectedEnvPath = path.join(extensionDir, '.env'); + const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8'); + const expectedContent = 'VAR1=previous-VAR1\nVAR2=previous-VAR2\n'; expect(actualContent).toBe(expectedContent); }); }); describe('promptForSetting', () => { - // it('should use prompts with type "password" for sensitive settings', async () => { - // const setting: ExtensionSetting = { - // name: 'API Key', - // description: 'Your secret key', - // envVar: 'API_KEY', - // sensitive: true, - // }; - // vi.mocked(prompts).mockResolvedValue({ value: 'secret-key' }); + it('should use prompts with type "password" for sensitive settings', async () => { + const setting: ExtensionSetting = { + name: 'API Key', + description: 'Your secret key', + envVar: 'API_KEY', + sensitive: true, + }; + vi.mocked(prompts).mockResolvedValue({ value: 'secret-key' }); - // const result = await promptForSetting(setting); + const result = await promptForSetting(setting); - // expect(prompts).toHaveBeenCalledWith({ - // type: 'password', - // name: 'value', - // message: 'API Key\nYour secret key', - // }); - // expect(result).toBe('secret-key'); - // }); + expect(prompts).toHaveBeenCalledWith({ + type: 'password', + name: 'value', + message: 'API Key\nYour secret key', + }); + expect(result).toBe('secret-key'); + }); it('should use prompts with type "text" for non-sensitive settings', async () => { const setting: ExtensionSetting = { diff --git a/packages/cli/src/config/extensions/extensionSettings.ts b/packages/cli/src/config/extensions/extensionSettings.ts index 55eb70b83a..f625ef5ea8 100644 --- a/packages/cli/src/config/extensions/extensionSettings.ts +++ b/packages/cli/src/config/extensions/extensionSettings.ts @@ -12,57 +12,76 @@ import { ExtensionStorage } from './storage.js'; import type { ExtensionConfig } from '../extension.js'; import prompts from 'prompts'; +import { KeychainTokenStorage } from '@google/gemini-cli-core'; export interface ExtensionSetting { name: string; description: string; envVar: string; + // NOTE: If no value is set, this setting will be considered NOT sensitive. + sensitive?: boolean; } export async function maybePromptForSettings( extensionConfig: ExtensionConfig, + extensionId: string, requestSetting: (setting: ExtensionSetting) => Promise, previousExtensionConfig?: ExtensionConfig, previousSettings?: Record, ): Promise { const { name: extensionName, settings } = extensionConfig; + if ( + (!settings || settings.length === 0) && + (!previousExtensionConfig?.settings || + previousExtensionConfig.settings.length === 0) + ) { + return; + } const envFilePath = new ExtensionStorage(extensionName).getEnvFilePath(); + const keychain = new KeychainTokenStorage(extensionId); if (!settings || settings.length === 0) { - // No settings for this extension. Clear any existing .env file. - if (fsSync.existsSync(envFilePath)) { - await fs.writeFile(envFilePath, ''); - } + await clearSettings(envFilePath, keychain); return; } - let settingsToPrompt = settings; - if (previousExtensionConfig) { - const oldSettings = new Set( - previousExtensionConfig.settings?.map((s) => s.name) || [], - ); - settingsToPrompt = settingsToPrompt.filter((s) => !oldSettings.has(s.name)); - } + const settingsChanges = getSettingsChanges( + settings, + previousExtensionConfig?.settings ?? [], + ); const allSettings: Record = { ...(previousSettings ?? {}) }; - if (settingsToPrompt && settingsToPrompt.length > 0) { - for (const setting of settingsToPrompt) { - const answer = await requestSetting(setting); - allSettings[setting.envVar] = answer; - } + for (const removedEnvSetting of settingsChanges.removeEnv) { + delete allSettings[removedEnvSetting.envVar]; } - const validEnvVars = new Set(settings.map((s) => s.envVar)); - const finalSettings: Record = {}; - for (const [key, value] of Object.entries(allSettings)) { - if (validEnvVars.has(key)) { - finalSettings[key] = value; + for (const removedSensitiveSetting of settingsChanges.removeSensitive) { + await keychain.deleteSecret(removedSensitiveSetting.envVar); + } + + for (const setting of settingsChanges.promptForSensitive.concat( + settingsChanges.promptForEnv, + )) { + const answer = await requestSetting(setting); + allSettings[setting.envVar] = answer; + } + + const nonSensitiveSettings: Record = {}; + for (const setting of settings) { + const value = allSettings[setting.envVar]; + if (value === undefined) { + continue; + } + if (setting.sensitive) { + await keychain.setSecret(setting.envVar, value); + } else { + nonSensitiveSettings[setting.envVar] = value; } } let envContent = ''; - for (const [key, value] of Object.entries(finalSettings)) { + for (const [key, value] of Object.entries(nonSensitiveSettings)) { envContent += `${key}=${value}\n`; } @@ -73,17 +92,22 @@ export async function promptForSetting( setting: ExtensionSetting, ): Promise { const response = await prompts({ - // type: setting.sensitive ? 'password' : 'text', - type: 'text', + type: setting.sensitive ? 'password' : 'text', name: 'value', message: `${setting.name}\n${setting.description}`, }); return response.value; } -export function getEnvContents( - extensionStorage: ExtensionStorage, -): Record { +export async function getEnvContents( + extensionConfig: ExtensionConfig, + extensionId: string, +): Promise> { + if (!extensionConfig.settings || extensionConfig.settings.length === 0) { + return Promise.resolve({}); + } + const extensionStorage = new ExtensionStorage(extensionConfig.name); + const keychain = new KeychainTokenStorage(extensionId); let customEnv: Record = {}; if (fsSync.existsSync(extensionStorage.getEnvFilePath())) { const envFile = fsSync.readFileSync( @@ -92,5 +116,67 @@ export function getEnvContents( ); customEnv = dotenv.parse(envFile); } + + if (extensionConfig.settings) { + for (const setting of extensionConfig.settings) { + if (setting.sensitive) { + const secret = await keychain.getSecret(setting.envVar); + if (secret) { + customEnv[setting.envVar] = secret; + } + } + } + } return customEnv; } + +interface settingsChanges { + promptForSensitive: ExtensionSetting[]; + removeSensitive: ExtensionSetting[]; + promptForEnv: ExtensionSetting[]; + removeEnv: ExtensionSetting[]; +} +function getSettingsChanges( + settings: ExtensionSetting[], + oldSettings: ExtensionSetting[], +): settingsChanges { + const isSameSetting = (a: ExtensionSetting, b: ExtensionSetting) => + a.envVar === b.envVar && (a.sensitive ?? false) === (b.sensitive ?? false); + + const sensitiveOld = oldSettings.filter((s) => s.sensitive ?? false); + const sensitiveNew = settings.filter((s) => s.sensitive ?? false); + const envOld = oldSettings.filter((s) => !(s.sensitive ?? false)); + const envNew = settings.filter((s) => !(s.sensitive ?? false)); + + return { + promptForSensitive: sensitiveNew.filter( + (s) => !sensitiveOld.some((old) => isSameSetting(s, old)), + ), + removeSensitive: sensitiveOld.filter( + (s) => !sensitiveNew.some((neu) => isSameSetting(s, neu)), + ), + promptForEnv: envNew.filter( + (s) => !envOld.some((old) => isSameSetting(s, old)), + ), + removeEnv: envOld.filter( + (s) => !envNew.some((neu) => isSameSetting(s, neu)), + ), + }; +} + +async function clearSettings( + envFilePath: string, + keychain: KeychainTokenStorage, +) { + if (fsSync.existsSync(envFilePath)) { + await fs.writeFile(envFilePath, ''); + } + if (!keychain.isAvailable()) { + return; + } + const secrets = await keychain.listSecrets(); + for (const secret of secrets) { + await keychain.deleteSecret(secret); + } + return; +} diff --git a/packages/cli/src/config/extensions/update.test.ts b/packages/cli/src/config/extensions/update.test.ts index 8dfe841d74..c3a1fb64e4 100644 --- a/packages/cli/src/config/extensions/update.test.ts +++ b/packages/cli/src/config/extensions/update.test.ts @@ -9,7 +9,7 @@ import * as fs from 'node:fs'; import * as os from 'node:os'; import * as path from 'node:path'; import { checkForAllExtensionUpdates, updateExtension } from './update.js'; -import { GEMINI_DIR } from '@google/gemini-cli-core'; +import { GEMINI_DIR, KeychainTokenStorage } from '@google/gemini-cli-core'; import { isWorkspaceTrusted } from '../trustedFolders.js'; import { ExtensionUpdateState } from '../../ui/state/extensions.js'; import { createExtension } from '../../test-utils/createExtension.js'; @@ -64,9 +64,24 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { logExtensionUninstall: mockLogExtensionUninstall, ExtensionInstallEvent: vi.fn(), ExtensionUninstallEvent: vi.fn(), + KeychainTokenStorage: vi.fn().mockImplementation(() => ({ + getSecret: vi.fn(), + setSecret: vi.fn(), + deleteSecret: vi.fn(), + listSecrets: vi.fn(), + isAvailable: vi.fn().mockResolvedValue(true), + })), }; }); +interface MockKeychainStorage { + getSecret: ReturnType; + setSecret: ReturnType; + deleteSecret: ReturnType; + listSecrets: ReturnType; + isAvailable: ReturnType; +} + describe('update tests', () => { let tempHomeDir: string; let tempWorkspaceDir: string; @@ -76,8 +91,32 @@ describe('update tests', () => { let mockPromptForSettings: MockedFunction< (setting: ExtensionSetting) => Promise >; + let mockKeychainStorage: MockKeychainStorage; + let keychainData: Record; beforeEach(() => { + vi.clearAllMocks(); + keychainData = {}; + mockKeychainStorage = { + getSecret: vi + .fn() + .mockImplementation(async (key: string) => keychainData[key] || null), + setSecret: vi + .fn() + .mockImplementation(async (key: string, value: string) => { + keychainData[key] = value; + }), + deleteSecret: vi.fn().mockImplementation(async (key: string) => { + delete keychainData[key]; + }), + listSecrets: vi + .fn() + .mockImplementation(async () => Object.keys(keychainData)), + isAvailable: vi.fn().mockResolvedValue(true), + }; + ( + KeychainTokenStorage as unknown as ReturnType + ).mockImplementation(() => mockKeychainStorage); tempHomeDir = fs.mkdtempSync( path.join(os.tmpdir(), 'gemini-cli-test-home-'), ); @@ -110,6 +149,7 @@ describe('update tests', () => { afterEach(() => { fs.rmSync(tempHomeDir, { recursive: true, force: true }); fs.rmSync(tempWorkspaceDir, { recursive: true, force: true }); + vi.restoreAllMocks(); }); describe('updateExtension', () => { @@ -139,11 +179,10 @@ describe('update tests', () => { ); }); mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]); - const extension = extensionManager - .loadExtensions() - .find((e) => e.name === extensionName)!; + const extensions = await extensionManager.loadExtensions(); + const extension = extensions.find((e) => e.name === extensionName)!; const updateInfo = await updateExtension( - extension, + extension!, extensionManager, ExtensionUpdateState.UPDATE_AVAILABLE, () => {}, @@ -189,11 +228,10 @@ describe('update tests', () => { const dispatch = vi.fn(); - const extension = extensionManager - .loadExtensions() - .find((e) => e.name === extensionName)!; + const extensions = await extensionManager.loadExtensions(); + const extension = extensions.find((e) => e.name === extensionName)!; await updateExtension( - extension, + extension!, extensionManager, ExtensionUpdateState.UPDATE_AVAILABLE, dispatch, @@ -231,12 +269,11 @@ describe('update tests', () => { mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]); const dispatch = vi.fn(); - const extension = extensionManager - .loadExtensions() - .find((e) => e.name === extensionName)!; + const extensions = await extensionManager.loadExtensions(); + const extension = extensions.find((e) => e.name === extensionName)!; await expect( updateExtension( - extension, + extension!, extensionManager, ExtensionUpdateState.UPDATE_AVAILABLE, dispatch, @@ -280,7 +317,7 @@ describe('update tests', () => { const dispatch = vi.fn(); await checkForAllExtensionUpdates( - extensionManager.loadExtensions(), + await extensionManager.loadExtensions(), extensionManager, dispatch, ); @@ -312,7 +349,7 @@ describe('update tests', () => { const dispatch = vi.fn(); await checkForAllExtensionUpdates( - extensionManager.loadExtensions(), + await extensionManager.loadExtensions(), extensionManager, dispatch, ); @@ -341,7 +378,7 @@ describe('update tests', () => { }); const dispatch = vi.fn(); await checkForAllExtensionUpdates( - extensionManager.loadExtensions(), + await extensionManager.loadExtensions(), extensionManager, dispatch, ); @@ -370,7 +407,7 @@ describe('update tests', () => { }); const dispatch = vi.fn(); await checkForAllExtensionUpdates( - extensionManager.loadExtensions(), + await extensionManager.loadExtensions(), extensionManager, dispatch, ); @@ -398,7 +435,7 @@ describe('update tests', () => { const dispatch = vi.fn(); await checkForAllExtensionUpdates( - extensionManager.loadExtensions(), + await extensionManager.loadExtensions(), extensionManager, dispatch, ); diff --git a/packages/cli/src/config/extensions/update.ts b/packages/cli/src/config/extensions/update.ts index 40f1330bc7..7bfa253651 100644 --- a/packages/cli/src/config/extensions/update.ts +++ b/packages/cli/src/config/extensions/update.ts @@ -58,7 +58,7 @@ export async function updateExtension( const tempDir = await ExtensionStorage.createTmpDir(); try { - const previousExtensionConfig = await extensionManager.loadExtensionConfig( + const previousExtensionConfig = extensionManager.loadExtensionConfig( extension.path, ); let updatedExtension: GeminiCLIExtension; diff --git a/packages/cli/src/config/settings.test.ts b/packages/cli/src/config/settings.test.ts index 78e85041f2..6ca94c14c3 100644 --- a/packages/cli/src/config/settings.test.ts +++ b/packages/cli/src/config/settings.test.ts @@ -2442,7 +2442,7 @@ describe('Settings Loading and Merging', () => { extensionManager, 'disableExtension', ); - mockDisableExtension.mockImplementation(() => {}); + mockDisableExtension.mockImplementation(async () => {}); migrateDeprecatedSettings(loadedSettings, extensionManager); @@ -2515,7 +2515,7 @@ describe('Settings Loading and Merging', () => { extensionManager, 'disableExtension', ); - mockDisableExtension.mockImplementation(() => {}); + mockDisableExtension.mockImplementation(async () => {}); migrateDeprecatedSettings(loadedSettings, extensionManager); diff --git a/packages/cli/src/ui/hooks/useExtensionUpdates.test.tsx b/packages/cli/src/ui/hooks/useExtensionUpdates.test.tsx index be1a415538..8e36311dc0 100644 --- a/packages/cli/src/ui/hooks/useExtensionUpdates.test.tsx +++ b/packages/cli/src/ui/hooks/useExtensionUpdates.test.tsx @@ -124,7 +124,7 @@ describe('useExtensionUpdates', () => { autoUpdate: true, }, }); - + await extensionManager.loadExtensions(); const addItem = vi.fn(); vi.mocked(checkForAllExtensionUpdates).mockImplementation( @@ -145,7 +145,6 @@ describe('useExtensionUpdates', () => { name: '', }); - extensionManager.loadExtensions(); function TestComponent() { useExtensionUpdates(extensionManager, addItem); return null; @@ -189,7 +188,7 @@ describe('useExtensionUpdates', () => { }, }); - extensionManager.loadExtensions(); + await extensionManager.loadExtensions(); const addItem = vi.fn(); diff --git a/packages/core/index.ts b/packages/core/index.ts index acc9743e61..2369b6b0e2 100644 --- a/packages/core/index.ts +++ b/packages/core/index.ts @@ -44,5 +44,6 @@ export { makeFakeConfig } from './src/test-utils/config.js'; export * from './src/utils/pathReader.js'; export { ClearcutLogger } from './src/telemetry/clearcut-logger/clearcut-logger.js'; export { logModelSlashCommand } from './src/telemetry/loggers.js'; +export { KeychainTokenStorage } from './src/mcp/token-storage/keychain-token-storage.js'; export * from './src/utils/googleQuotaErrors.js'; export type { GoogleApiError } from './src/utils/googleErrors.js'; diff --git a/packages/core/src/mcp/token-storage/keychain-token-storage.test.ts b/packages/core/src/mcp/token-storage/keychain-token-storage.test.ts index 3b97902f19..632387e23b 100644 --- a/packages/core/src/mcp/token-storage/keychain-token-storage.test.ts +++ b/packages/core/src/mcp/token-storage/keychain-token-storage.test.ts @@ -386,5 +386,54 @@ describe('KeychainTokenStorage', () => { ); }); }); + + describe('Secrets', () => { + it('should set and get a secret', async () => { + mockKeytar.setPassword.mockResolvedValue(undefined); + mockKeytar.getPassword.mockResolvedValue('secret-value'); + + await storage.setSecret('secret-key', 'secret-value'); + const value = await storage.getSecret('secret-key'); + + expect(mockKeytar.setPassword).toHaveBeenCalledWith( + mockServiceName, + '__secret__secret-key', + 'secret-value', + ); + expect(mockKeytar.getPassword).toHaveBeenCalledWith( + mockServiceName, + '__secret__secret-key', + ); + expect(value).toBe('secret-value'); + }); + + it('should delete a secret', async () => { + mockKeytar.deletePassword.mockResolvedValue(true); + await storage.deleteSecret('secret-key'); + expect(mockKeytar.deletePassword).toHaveBeenCalledWith( + mockServiceName, + '__secret__secret-key', + ); + }); + + it('should list secrets', async () => { + mockKeytar.findCredentials.mockResolvedValue([ + { account: '__secret__secret1', password: '' }, + { account: '__secret__secret2', password: '' }, + { account: 'server1', password: '' }, + ]); + const secrets = await storage.listSecrets(); + expect(secrets).toEqual(['secret1', 'secret2']); + }); + + it('should not list secrets in listServers', async () => { + mockKeytar.findCredentials.mockResolvedValue([ + { account: '__secret__secret1', password: '' }, + { account: 'server1', password: '' }, + ]); + const servers = await storage.listServers(); + expect(servers).toEqual(['server1']); + }); + }); }); }); diff --git a/packages/core/src/mcp/token-storage/keychain-token-storage.ts b/packages/core/src/mcp/token-storage/keychain-token-storage.ts index aa8cee2e9d..93e94acb29 100644 --- a/packages/core/src/mcp/token-storage/keychain-token-storage.ts +++ b/packages/core/src/mcp/token-storage/keychain-token-storage.ts @@ -6,7 +6,7 @@ import * as crypto from 'node:crypto'; import { BaseTokenStorage } from './base-token-storage.js'; -import type { OAuthCredentials } from './types.js'; +import type { OAuthCredentials, SecretStorage } from './types.js'; import { coreEvents } from '../../utils/events.js'; interface Keytar { @@ -23,8 +23,12 @@ interface Keytar { } const KEYCHAIN_TEST_PREFIX = '__keychain_test__'; +const SECRET_PREFIX = '__secret__'; -export class KeychainTokenStorage extends BaseTokenStorage { +export class KeychainTokenStorage + extends BaseTokenStorage + implements SecretStorage +{ private keychainAvailable: boolean | null = null; private keytarModule: Keytar | null = null; private keytarLoadAttempted = false; @@ -137,7 +141,11 @@ export class KeychainTokenStorage extends BaseTokenStorage { try { const credentials = await keytar.findCredentials(this.serviceName); return credentials - .filter((cred) => !cred.account.startsWith(KEYCHAIN_TEST_PREFIX)) + .filter( + (cred) => + !cred.account.startsWith(KEYCHAIN_TEST_PREFIX) && + !cred.account.startsWith(SECRET_PREFIX), + ) .map((cred: { account: string }) => cred.account); } catch (error) { coreEvents.emitFeedback( @@ -163,7 +171,11 @@ export class KeychainTokenStorage extends BaseTokenStorage { try { const credentials = ( await keytar.findCredentials(this.serviceName) - ).filter((c) => !c.account.startsWith(KEYCHAIN_TEST_PREFIX)); + ).filter( + (c) => + !c.account.startsWith(KEYCHAIN_TEST_PREFIX) && + !c.account.startsWith(SECRET_PREFIX), + ); for (const cred of credentials) { try { @@ -258,4 +270,62 @@ export class KeychainTokenStorage extends BaseTokenStorage { async isAvailable(): Promise { return this.checkKeychainAvailability(); } + + async setSecret(key: string, value: string): Promise { + if (!(await this.checkKeychainAvailability())) { + throw new Error('Keychain is not available'); + } + const keytar = await this.getKeytar(); + if (!keytar) { + throw new Error('Keytar module not available'); + } + await keytar.setPassword(this.serviceName, `${SECRET_PREFIX}${key}`, value); + } + + async getSecret(key: string): Promise { + if (!(await this.checkKeychainAvailability())) { + throw new Error('Keychain is not available'); + } + const keytar = await this.getKeytar(); + if (!keytar) { + throw new Error('Keytar module not available'); + } + return keytar.getPassword(this.serviceName, `${SECRET_PREFIX}${key}`); + } + + async deleteSecret(key: string): Promise { + if (!(await this.checkKeychainAvailability())) { + throw new Error('Keychain is not available'); + } + const keytar = await this.getKeytar(); + if (!keytar) { + throw new Error('Keytar module not available'); + } + const deleted = await keytar.deletePassword( + this.serviceName, + `${SECRET_PREFIX}${key}`, + ); + if (!deleted) { + throw new Error(`No secret found for key: ${key}`); + } + } + + async listSecrets(): Promise { + if (!(await this.checkKeychainAvailability())) { + throw new Error('Keychain is not available'); + } + const keytar = await this.getKeytar(); + if (!keytar) { + throw new Error('Keytar module not available'); + } + try { + const credentials = await keytar.findCredentials(this.serviceName); + return credentials + .filter((cred) => cred.account.startsWith(SECRET_PREFIX)) + .map((cred) => cred.account.substring(SECRET_PREFIX.length)); + } catch (error) { + console.error('Failed to list secrets from keychain:', error); + return []; + } + } } diff --git a/packages/core/src/mcp/token-storage/types.ts b/packages/core/src/mcp/token-storage/types.ts index 1e95a975e0..b167e821d8 100644 --- a/packages/core/src/mcp/token-storage/types.ts +++ b/packages/core/src/mcp/token-storage/types.ts @@ -36,6 +36,13 @@ export interface TokenStorage { clearAll(): Promise; } +export interface SecretStorage { + setSecret(key: string, value: string): Promise; + getSecret(key: string): Promise; + deleteSecret(key: string): Promise; + listSecrets(): Promise; +} + export enum TokenStorageType { KEYCHAIN = 'keychain', ENCRYPTED_FILE = 'encrypted_file',