Add support for sensitive keychain-stored per-extension settings (#11953)

This commit is contained in:
christine betts
2025-10-28 14:48:50 -04:00
committed by GitHub
parent 7a238bd938
commit 7e987113a2
22 changed files with 706 additions and 212 deletions
@@ -17,7 +17,7 @@ interface DisableArgs {
scope?: string;
}
export function handleDisable(args: DisableArgs) {
export async function handleDisable(args: DisableArgs) {
const workspaceDir = process.cwd();
const extensionManager = new ExtensionManager({
workspaceDir,
@@ -25,13 +25,16 @@ export function handleDisable(args: DisableArgs) {
requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged,
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
try {
if (args.scope?.toLowerCase() === 'workspace') {
extensionManager.disableExtension(args.name, SettingScope.Workspace);
await extensionManager.disableExtension(
args.name,
SettingScope.Workspace,
);
} else {
extensionManager.disableExtension(args.name, SettingScope.User);
await extensionManager.disableExtension(args.name, SettingScope.User);
}
debugLogger.log(
`Extension "${args.name}" successfully disabled for scope "${args.scope}".`,
@@ -20,7 +20,7 @@ interface EnableArgs {
scope?: string;
}
export function handleEnable(args: EnableArgs) {
export async function handleEnable(args: EnableArgs) {
const workingDir = process.cwd();
const extensionManager = new ExtensionManager({
workspaceDir: workingDir,
@@ -28,7 +28,7 @@ export function handleEnable(args: EnableArgs) {
requestSetting: promptForSetting,
settings: loadSettings(workingDir).merged,
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
try {
if (args.scope?.toLowerCase() === 'workspace') {
@@ -76,7 +76,7 @@ export async function handleInstall(args: InstallArgs) {
requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged,
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
const name =
await extensionManager.installOrUpdateExtension(installMetadata);
debugLogger.log(`Extension "${name}" installed successfully and enabled.`);
+1 -1
View File
@@ -33,7 +33,7 @@ export async function handleLink(args: InstallArgs) {
requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged,
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
const extensionName =
await extensionManager.installOrUpdateExtension(installMetadata);
debugLogger.log(
+1 -1
View File
@@ -21,7 +21,7 @@ export async function handleList() {
requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged,
});
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
if (extensions.length === 0) {
debugLogger.log('No extensions installed.');
return;
@@ -25,7 +25,7 @@ export async function handleUninstall(args: UninstallArgs) {
requestSetting: promptForSetting,
settings: loadSettings(workspaceDir).merged,
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.uninstallExtension(args.name, false);
debugLogger.log(`Extension "${args.name}" successfully uninstalled.`);
} catch (error) {
@@ -37,7 +37,7 @@ export async function handleUpdate(args: UpdateArgs) {
settings: loadSettings(workspaceDir).merged,
});
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
if (args.name) {
try {
const extension = extensions.find(
+1 -1
View File
@@ -33,7 +33,7 @@ async function getMcpServersFromConfig(): Promise<
requestConsent: requestConsentNonInteractive,
requestSetting: promptForSetting,
});
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
const mcpServers = { ...(settings.merged.mcpServers || {}) };
for (const extension of extensions) {
Object.entries(extension.mcpServers || {}).forEach(([key, server]) => {
+1 -1
View File
@@ -423,7 +423,7 @@ export async function loadCliConfig(
workspaceDir: cwd,
enabledExtensionOverrides: argv.extensions,
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
// Call the (now wrapper) loadHierarchicalGeminiMemory which calls the server's version
const { memoryContent, fileCount, filePaths } =
+26 -12
View File
@@ -133,7 +133,7 @@ export class ExtensionManager implements ExtensionLoader {
const isUpdate = !!previousExtensionConfig;
let newExtensionConfig: ExtensionConfig | null = null;
let localSourcePath: string | undefined;
let extension: GeminiCLIExtension;
let extension: GeminiCLIExtension | null;
try {
if (!isWorkspaceTrusted(this.settings).isTrusted) {
throw new Error(
@@ -243,12 +243,16 @@ export class ExtensionManager implements ExtensionLoader {
this.requestConsent,
previousExtensionConfig,
);
const extensionStorage = new ExtensionStorage(newExtensionName);
const destinationPath = extensionStorage.getExtensionDir();
const extensionId = getExtensionId(newExtensionConfig, installMetadata);
const destinationPath = new ExtensionStorage(
newExtensionName,
).getExtensionDir();
let previousSettings: Record<string, string> | undefined;
if (isUpdate) {
previousSettings = getEnvContents(extensionStorage);
previousSettings = await getEnvContents(
previousExtensionConfig,
extensionId,
);
await this.uninstallExtension(newExtensionName, isUpdate);
}
@@ -257,6 +261,7 @@ export class ExtensionManager implements ExtensionLoader {
if (isUpdate) {
await maybePromptForSettings(
newExtensionConfig,
extensionId,
this.requestSetting,
previousExtensionConfig,
previousSettings,
@@ -264,6 +269,7 @@ export class ExtensionManager implements ExtensionLoader {
} else {
await maybePromptForSettings(
newExtensionConfig,
extensionId,
this.requestSetting,
);
}
@@ -286,7 +292,10 @@ export class ExtensionManager implements ExtensionLoader {
// TODO: Gracefully handle this call failing, we should back up the old
// extension prior to overwriting it and then restore it.
extension = this.loadExtension(destinationPath)!;
extension = await this.loadExtension(destinationPath)!;
if (!extension) {
throw new Error(`Extension not found`);
}
if (isUpdate) {
logExtensionUpdateEvent(
this.telemetryConfig,
@@ -401,7 +410,7 @@ export class ExtensionManager implements ExtensionLoader {
this.eventEmitter.emit('extensionUninstalled', { extension });
}
loadExtensions(): GeminiCLIExtension[] {
async loadExtensions(): Promise<GeminiCLIExtension[]> {
if (this.loadedExtensions) {
throw new Error('Extensions already loaded, only load extensions once.');
}
@@ -413,12 +422,14 @@ export class ExtensionManager implements ExtensionLoader {
for (const subdir of fs.readdirSync(extensionsDir)) {
const extensionDir = path.join(extensionsDir, subdir);
this.loadExtension(extensionDir);
await this.loadExtension(extensionDir);
}
return this.loadedExtensions;
}
private loadExtension(extensionDir: string): GeminiCLIExtension | null {
private async loadExtension(
extensionDir: string,
): Promise<GeminiCLIExtension | null> {
this.loadedExtensions ??= [];
if (!fs.statSync(extensionDir).isDirectory()) {
return null;
@@ -441,7 +452,10 @@ export class ExtensionManager implements ExtensionLoader {
);
}
const customEnv = getEnvContents(new ExtensionStorage(config.name));
const customEnv = await getEnvContents(
config,
getExtensionId(config, installMetadata),
);
config = resolveEnvVarsInObject(config, customEnv);
if (config.mcpServers) {
@@ -573,7 +587,7 @@ export class ExtensionManager implements ExtensionLoader {
return output;
}
disableExtension(name: string, scope: SettingScope) {
async disableExtension(name: string, scope: SettingScope) {
if (
scope === SettingScope.System ||
scope === SettingScope.SystemDefaults
@@ -598,7 +612,7 @@ export class ExtensionManager implements ExtensionLoader {
this.eventEmitter.emit('extensionDisabled', { extension });
}
enableExtension(name: string, scope: SettingScope) {
async enableExtension(name: string, scope: SettingScope) {
if (
scope === SettingScope.System ||
scope === SettingScope.SystemDefaults
+146 -106
View File
@@ -13,6 +13,7 @@ import {
ExtensionUninstallEvent,
ExtensionDisableEvent,
ExtensionEnableEvent,
KeychainTokenStorage,
} from '@google/gemini-cli-core';
import { loadSettings, SettingScope } from './settings.js';
import { isWorkspaceTrusted } from './trustedFolders.js';
@@ -96,6 +97,13 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
ExtensionInstallEvent: vi.fn(),
ExtensionUninstallEvent: vi.fn(),
ExtensionDisableEvent: vi.fn(),
KeychainTokenStorage: vi.fn().mockImplementation(() => ({
getSecret: vi.fn(),
setSecret: vi.fn(),
deleteSecret: vi.fn(),
listSecrets: vi.fn(),
isAvailable: vi.fn().mockResolvedValue(true),
})),
};
});
@@ -107,6 +115,14 @@ vi.mock('child_process', async (importOriginal) => {
};
});
interface MockKeychainStorage {
getSecret: ReturnType<typeof vi.fn>;
setSecret: ReturnType<typeof vi.fn>;
deleteSecret: ReturnType<typeof vi.fn>;
listSecrets: ReturnType<typeof vi.fn>;
isAvailable: ReturnType<typeof vi.fn>;
}
describe('extension tests', () => {
let tempHomeDir: string;
let tempWorkspaceDir: string;
@@ -116,8 +132,32 @@ describe('extension tests', () => {
let mockPromptForSettings: MockedFunction<
(setting: ExtensionSetting) => Promise<string>
>;
let mockKeychainStorage: MockKeychainStorage;
let keychainData: Record<string, string>;
beforeEach(() => {
vi.clearAllMocks();
keychainData = {};
mockKeychainStorage = {
getSecret: vi
.fn()
.mockImplementation(async (key: string) => keychainData[key] || null),
setSecret: vi
.fn()
.mockImplementation(async (key: string, value: string) => {
keychainData[key] = value;
}),
deleteSecret: vi.fn().mockImplementation(async (key: string) => {
delete keychainData[key];
}),
listSecrets: vi
.fn()
.mockImplementation(async () => Object.keys(keychainData)),
isAvailable: vi.fn().mockResolvedValue(true),
};
(
KeychainTokenStorage as unknown as ReturnType<typeof vi.fn>
).mockImplementation(() => mockKeychainStorage);
tempHomeDir = fs.mkdtempSync(
path.join(os.tmpdir(), 'gemini-cli-test-home-'),
);
@@ -151,7 +191,7 @@ describe('extension tests', () => {
});
describe('loadExtensions', () => {
it('should include extension path in loaded extension', () => {
it('should include extension path in loaded extension', async () => {
const extensionDir = path.join(userExtensionsDir, 'test-extension');
fs.mkdirSync(extensionDir, { recursive: true });
@@ -161,13 +201,13 @@ describe('extension tests', () => {
version: '1.0.0',
});
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
expect(extensions[0].path).toBe(extensionDir);
expect(extensions[0].name).toBe('test-extension');
});
it('should load context file path when GEMINI.md is present', () => {
it('should load context file path when GEMINI.md is present', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'ext1',
@@ -180,7 +220,7 @@ describe('extension tests', () => {
version: '2.0.0',
});
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(2);
const ext1 = extensions.find((e) => e.name === 'ext1');
@@ -191,7 +231,7 @@ describe('extension tests', () => {
expect(ext2?.contextFiles).toEqual([]);
});
it('should load context file path from the extension config', () => {
it('should load context file path from the extension config', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'ext1',
@@ -200,7 +240,7 @@ describe('extension tests', () => {
contextFileName: 'my-context-file.md',
});
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
const ext1 = extensions.find((e) => e.name === 'ext1');
@@ -209,7 +249,7 @@ describe('extension tests', () => {
]);
});
it('should annotate disabled extensions', () => {
it('should annotate disabled extensions', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'disabled-extension',
@@ -220,8 +260,8 @@ describe('extension tests', () => {
name: 'enabled-extension',
version: '2.0.0',
});
extensionManager.loadExtensions();
extensionManager.disableExtension(
await extensionManager.loadExtensions();
await extensionManager.disableExtension(
'disabled-extension',
SettingScope.User,
);
@@ -233,7 +273,7 @@ describe('extension tests', () => {
expect(extensions[1].isActive).toBe(true);
});
it('should hydrate variables', () => {
it('should hydrate variables', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'test-extension',
@@ -247,7 +287,7 @@ describe('extension tests', () => {
},
});
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
const expectedCwd = path.join(
userExtensionsDir,
@@ -266,7 +306,7 @@ describe('extension tests', () => {
});
fs.writeFileSync(path.join(sourceExtDir, 'context.md'), 'linked context');
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
const extension = await extensionManager.installOrUpdateExtension({
source: sourceExtDir,
type: 'link',
@@ -303,7 +343,7 @@ describe('extension tests', () => {
},
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: sourceExtDir,
type: 'link',
@@ -319,7 +359,7 @@ describe('extension tests', () => {
]);
});
it('should resolve environment variables in extension configuration', () => {
it('should resolve environment variables in extension configuration', async () => {
process.env['TEST_API_KEY'] = 'test-api-key-123';
process.env['TEST_DB_URL'] = 'postgresql://localhost:5432/testdb';
@@ -352,7 +392,7 @@ describe('extension tests', () => {
};
fs.writeFileSync(configPath, JSON.stringify(extensionConfig));
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
const extension = extensions[0];
@@ -373,7 +413,7 @@ describe('extension tests', () => {
}
});
it('should resolve environment variables from an extension .env file', () => {
it('should resolve environment variables from an extension .env file', async () => {
const extDir = createExtension({
extensionsDir: userExtensionsDir,
name: 'test-extension',
@@ -388,12 +428,19 @@ describe('extension tests', () => {
},
},
},
settings: [
{
name: 'My API Key',
description: 'API key for testing.',
envVar: 'MY_API_KEY',
},
],
});
const envFilePath = path.join(extDir, '.env');
fs.writeFileSync(envFilePath, 'MY_API_KEY=test-key-from-file\n');
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
const extension = extensions[0];
@@ -403,7 +450,7 @@ describe('extension tests', () => {
expect(serverConfig.env!['STATIC_VALUE']).toBe('no-substitution');
});
it('should handle missing environment variables gracefully', () => {
it('should handle missing environment variables gracefully', async () => {
const userExtensionsDir = path.join(
tempHomeDir,
EXTENSIONS_DIRECTORY_NAME,
@@ -433,7 +480,7 @@ describe('extension tests', () => {
JSON.stringify(extensionConfig),
);
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
const extension = extensions[0];
@@ -443,7 +490,7 @@ describe('extension tests', () => {
expect(serverConfig.env!['MISSING_VAR_BRACES']).toBe('${ALSO_UNDEFINED}');
});
it('should skip extensions with invalid JSON and log a warning', () => {
it('should skip extensions with invalid JSON and log a warning', async () => {
const consoleSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
@@ -461,7 +508,7 @@ describe('extension tests', () => {
const badConfigPath = path.join(badExtDir, EXTENSIONS_CONFIG_FILENAME);
fs.writeFileSync(badConfigPath, '{ "name": "bad-ext"'); // Malformed
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
expect(extensions[0].name).toBe('good-ext');
@@ -474,7 +521,7 @@ describe('extension tests', () => {
consoleSpy.mockRestore();
});
it('should skip extensions with missing name and log a warning', () => {
it('should skip extensions with missing name and log a warning', async () => {
const consoleSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
@@ -492,7 +539,7 @@ describe('extension tests', () => {
const badConfigPath = path.join(badExtDir, EXTENSIONS_CONFIG_FILENAME);
fs.writeFileSync(badConfigPath, JSON.stringify({ version: '1.0.0' }));
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
expect(extensions[0].name).toBe('good-ext');
@@ -505,7 +552,7 @@ describe('extension tests', () => {
consoleSpy.mockRestore();
});
it('should filter trust out of mcp servers', () => {
it('should filter trust out of mcp servers', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'test-extension',
@@ -519,12 +566,12 @@ describe('extension tests', () => {
},
});
const extensions = extensionManager.loadExtensions();
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
expect(extensions[0].mcpServers?.['test-server'].trust).toBeUndefined();
});
it('should throw an error for invalid extension names', () => {
it('should throw an error for invalid extension names', async () => {
const consoleSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
@@ -533,10 +580,8 @@ describe('extension tests', () => {
name: 'bad_name',
version: '1.0.0',
});
const extension = extensionManager
.loadExtensions()
.find((e) => e.name === 'bad_name');
const extensions = await extensionManager.loadExtensions();
const extension = extensions.find((e) => e.name === 'bad_name');
expect(extension).toBeUndefined();
expect(consoleSpy).toHaveBeenCalledWith(
@@ -546,7 +591,7 @@ describe('extension tests', () => {
});
describe('id generation', () => {
it('should generate id from source for non-github git urls', () => {
it('should generate id from source for non-github git urls', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'my-ext',
@@ -556,14 +601,12 @@ describe('extension tests', () => {
source: 'http://somehost.com/foo/bar',
},
});
const extension = extensionManager
.loadExtensions()
.find((e) => e.name === 'my-ext');
const extensions = await extensionManager.loadExtensions();
const extension = extensions.find((e) => e.name === 'my-ext');
expect(extension?.id).toBe(hashValue('http://somehost.com/foo/bar'));
});
it('should generate id from owner/repo for github http urls', () => {
it('should generate id from owner/repo for github http urls', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'my-ext',
@@ -574,13 +617,12 @@ describe('extension tests', () => {
},
});
const extension = extensionManager
.loadExtensions()
.find((e) => e.name === 'my-ext');
const extensions = await extensionManager.loadExtensions();
const extension = extensions.find((e) => e.name === 'my-ext');
expect(extension?.id).toBe(hashValue('https://github.com/foo/bar'));
});
it('should generate id from owner/repo for github ssh urls', () => {
it('should generate id from owner/repo for github ssh urls', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'my-ext',
@@ -591,13 +633,12 @@ describe('extension tests', () => {
},
});
const extension = extensionManager
.loadExtensions()
.find((e) => e.name === 'my-ext');
const extensions = await extensionManager.loadExtensions();
const extension = extensions.find((e) => e.name === 'my-ext');
expect(extension?.id).toBe(hashValue('https://github.com/foo/bar'));
});
it('should generate id from source for github-release extension', () => {
it('should generate id from source for github-release extension', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'my-ext',
@@ -607,14 +648,12 @@ describe('extension tests', () => {
source: 'https://github.com/foo/bar',
},
});
const extension = extensionManager
.loadExtensions()
.find((e) => e.name === 'my-ext');
const extensions = await extensionManager.loadExtensions();
const extension = extensions.find((e) => e.name === 'my-ext');
expect(extension?.id).toBe(hashValue('https://github.com/foo/bar'));
});
it('should generate id from the original source for local extension', () => {
it('should generate id from the original source for local extension', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'local-ext-name',
@@ -625,9 +664,8 @@ describe('extension tests', () => {
},
});
const extension = extensionManager
.loadExtensions()
.find((e) => e.name === 'local-ext-name');
const extensions = await extensionManager.loadExtensions();
const extension = extensions.find((e) => e.name === 'local-ext-name');
expect(extension?.id).toBe(hashValue('/some/path'));
});
@@ -638,7 +676,7 @@ describe('extension tests', () => {
name: 'link-ext-name',
version: '1.0.0',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
type: 'link',
source: actualExtensionDir,
@@ -650,16 +688,15 @@ describe('extension tests', () => {
expect(extension?.id).toBe(hashValue(actualExtensionDir));
});
it('should generate id from name for extension with no install metadata', () => {
it('should generate id from name for extension with no install metadata', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'no-meta-name',
version: '1.0.0',
});
const extension = extensionManager
.loadExtensions()
.find((e) => e.name === 'no-meta-name');
const extensions = await extensionManager.loadExtensions();
const extension = extensions.find((e) => e.name === 'no-meta-name');
expect(extension?.id).toBe(hashValue('no-meta-name'));
});
});
@@ -675,7 +712,7 @@ describe('extension tests', () => {
const targetExtDir = path.join(userExtensionsDir, 'my-local-extension');
const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME);
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: sourceExtDir,
type: 'local',
@@ -697,7 +734,7 @@ describe('extension tests', () => {
name: 'my-local-extension',
version: '1.0.0',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: sourceExtDir,
type: 'local',
@@ -791,7 +828,7 @@ describe('extension tests', () => {
type: 'github-release',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: gitUrl,
type: 'git',
@@ -816,7 +853,7 @@ describe('extension tests', () => {
const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME);
const configPath = path.join(targetExtDir, EXTENSIONS_CONFIG_FILENAME);
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: sourceExtDir,
type: 'link',
@@ -846,7 +883,7 @@ describe('extension tests', () => {
name: 'my-local-extension',
version: '1.1.0',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
if (isUpdate) {
await extensionManager.installOrUpdateExtension({
source: sourceExtDir,
@@ -920,7 +957,7 @@ describe('extension tests', () => {
},
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await expect(
extensionManager.installOrUpdateExtension({
source: sourceExtDir,
@@ -952,7 +989,7 @@ This extension will run the following MCP servers:
},
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await expect(
extensionManager.installOrUpdateExtension({
source: sourceExtDir,
@@ -974,7 +1011,7 @@ This extension will run the following MCP servers:
},
});
mockRequestConsent.mockResolvedValue(false);
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await expect(
extensionManager.installOrUpdateExtension({
source: sourceExtDir,
@@ -992,7 +1029,7 @@ This extension will run the following MCP servers:
const targetExtDir = path.join(userExtensionsDir, 'my-local-extension');
const metadataPath = path.join(targetExtDir, INSTALL_METADATA_FILENAME);
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: sourceExtDir,
type: 'local',
@@ -1023,7 +1060,7 @@ This extension will run the following MCP servers:
},
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
// Install it with hard coded consent first.
await extensionManager.installOrUpdateExtension({
source: sourceExtDir,
@@ -1058,7 +1095,7 @@ This extension will run the following MCP servers:
],
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: sourceExtDir,
type: 'local',
@@ -1088,7 +1125,7 @@ This extension will run the following MCP servers:
settings: loadSettings(tempWorkspaceDir).merged,
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: sourceExtDir,
type: 'local',
@@ -1111,7 +1148,7 @@ This extension will run the following MCP servers:
});
mockPromptForSettings.mockResolvedValueOnce('old-api-key');
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
// Install it so it exists in the userExtensionsDir
await extensionManager.installOrUpdateExtension({
source: oldSourceExtDir,
@@ -1181,7 +1218,7 @@ This extension will run the following MCP servers:
},
],
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: oldSourceExtDir,
type: 'local',
@@ -1273,7 +1310,7 @@ This extension will run the following MCP servers:
join(tempDir, extensionName),
);
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: gitUrl,
type: 'github-release',
@@ -1298,7 +1335,7 @@ This extension will run the following MCP servers:
type: 'github-release',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension(
{ source: gitUrl, type: 'github-release' }, // Use github-release to force consent
);
@@ -1329,7 +1366,7 @@ This extension will run the following MCP servers:
});
mockRequestConsent.mockResolvedValue(false);
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await expect(
extensionManager.installOrUpdateExtension({
source: gitUrl,
@@ -1354,7 +1391,7 @@ This extension will run the following MCP servers:
type: 'github-release',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension({
source: gitUrl,
type: 'git',
@@ -1385,7 +1422,7 @@ This extension will run the following MCP servers:
type: 'github-release',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.installOrUpdateExtension(
{ source: gitUrl, type: 'github-release' }, // Note the type
);
@@ -1407,8 +1444,7 @@ This extension will run the following MCP servers:
name: 'my-local-extension',
version: '1.0.0',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.uninstallExtension('my-local-extension', false);
expect(fs.existsSync(sourceExtDir)).toBe(false);
@@ -1426,7 +1462,7 @@ This extension will run the following MCP servers:
version: '1.0.0',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.uninstallExtension('my-local-extension', false);
expect(fs.existsSync(sourceExtDir)).toBe(false);
@@ -1435,7 +1471,7 @@ This extension will run the following MCP servers:
});
it('should throw an error if the extension does not exist', async () => {
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await expect(
extensionManager.uninstallExtension('nonexistent-extension', false),
).rejects.toThrow('Extension not found.');
@@ -1453,7 +1489,7 @@ This extension will run the following MCP servers:
},
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.uninstallExtension(
'my-local-extension',
isUpdate,
@@ -1481,7 +1517,7 @@ This extension will run the following MCP servers:
const enablementManager = new ExtensionEnablementManager();
enablementManager.enable('test-extension', true, '/some/scope');
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.uninstallExtension('test-extension', isUpdate);
const config = enablementManager.readConfig()['test-extension'];
@@ -1506,7 +1542,7 @@ This extension will run the following MCP servers:
},
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await extensionManager.uninstallExtension(gitUrl, false);
expect(fs.existsSync(sourceExtDir)).toBe(false);
@@ -1526,7 +1562,7 @@ This extension will run the following MCP servers:
// No installMetadata provided
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
await expect(
extensionManager.uninstallExtension(
'https://github.com/google/no-metadata-extension',
@@ -1537,14 +1573,14 @@ This extension will run the following MCP servers:
});
describe('disableExtension', () => {
it('should disable an extension at the user scope', () => {
it('should disable an extension at the user scope', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'my-extension',
version: '1.0.0',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
extensionManager.disableExtension('my-extension', SettingScope.User);
expect(
isEnabled({
@@ -1554,14 +1590,14 @@ This extension will run the following MCP servers:
).toBe(false);
});
it('should disable an extension at the workspace scope', () => {
it('should disable an extension at the workspace scope', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'my-extension',
version: '1.0.0',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
extensionManager.disableExtension('my-extension', SettingScope.Workspace);
expect(
isEnabled({
@@ -1577,14 +1613,14 @@ This extension will run the following MCP servers:
).toBe(false);
});
it('should handle disabling the same extension twice', () => {
it('should handle disabling the same extension twice', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'my-extension',
version: '1.0.0',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
extensionManager.disableExtension('my-extension', SettingScope.User);
extensionManager.disableExtension('my-extension', SettingScope.User);
expect(
@@ -1595,13 +1631,17 @@ This extension will run the following MCP servers:
).toBe(false);
});
it('should throw an error if you request system scope', () => {
expect(() =>
extensionManager.disableExtension('my-extension', SettingScope.System),
).toThrow('System and SystemDefaults scopes are not supported.');
it('should throw an error if you request system scope', async () => {
await expect(
async () =>
await extensionManager.disableExtension(
'my-extension',
SettingScope.System,
),
).rejects.toThrow('System and SystemDefaults scopes are not supported.');
});
it('should log a disable event', () => {
it('should log a disable event', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'ext1',
@@ -1612,7 +1652,7 @@ This extension will run the following MCP servers:
},
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.Workspace);
expect(mockLogExtensionDisable).toHaveBeenCalled();
@@ -1634,41 +1674,41 @@ This extension will run the following MCP servers:
return extensions.filter((e) => e.isActive);
};
it('should enable an extension at the user scope', () => {
it('should enable an extension at the user scope', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'ext1',
version: '1.0.0',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.User);
let activeExtensions = getActiveExtensions();
expect(activeExtensions).toHaveLength(0);
extensionManager.enableExtension('ext1', SettingScope.User);
activeExtensions = getActiveExtensions();
await extensionManager.enableExtension('ext1', SettingScope.User);
activeExtensions = await getActiveExtensions();
expect(activeExtensions).toHaveLength(1);
expect(activeExtensions[0].name).toBe('ext1');
});
it('should enable an extension at the workspace scope', () => {
it('should enable an extension at the workspace scope', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'ext1',
version: '1.0.0',
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.Workspace);
let activeExtensions = getActiveExtensions();
expect(activeExtensions).toHaveLength(0);
extensionManager.enableExtension('ext1', SettingScope.Workspace);
activeExtensions = getActiveExtensions();
await extensionManager.enableExtension('ext1', SettingScope.Workspace);
activeExtensions = await getActiveExtensions();
expect(activeExtensions).toHaveLength(1);
expect(activeExtensions[0].name).toBe('ext1');
});
it('should log an enable event', () => {
it('should log an enable event', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'ext1',
@@ -1678,7 +1718,7 @@ This extension will run the following MCP servers:
type: 'local',
},
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.Workspace);
extensionManager.enableExtension('ext1', SettingScope.Workspace);
@@ -17,6 +17,7 @@ import { ExtensionStorage } from './storage.js';
import prompts from 'prompts';
import * as fsPromises from 'node:fs/promises';
import * as fs from 'node:fs';
import { KeychainTokenStorage } from '@google/gemini-cli-core';
vi.mock('prompts');
vi.mock('os', async (importOriginal) => {
@@ -27,11 +28,59 @@ vi.mock('os', async (importOriginal) => {
};
});
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actual =
await importOriginal<typeof import('@google/gemini-cli-core')>();
return {
...actual,
KeychainTokenStorage: vi.fn().mockImplementation(() => ({
getSecret: vi.fn(),
setSecret: vi.fn(),
deleteSecret: vi.fn(),
listSecrets: vi.fn(),
isAvailable: vi.fn().mockResolvedValue(true),
})),
};
});
interface MockKeychainStorage {
getSecret: ReturnType<typeof vi.fn>;
setSecret: ReturnType<typeof vi.fn>;
deleteSecret: ReturnType<typeof vi.fn>;
listSecrets: ReturnType<typeof vi.fn>;
isAvailable: ReturnType<typeof vi.fn>;
}
describe('extensionSettings', () => {
let tempHomeDir: string;
let extensionDir: string;
let mockKeychainStorage: MockKeychainStorage;
let keychainData: Record<string, string>;
beforeEach(() => {
vi.clearAllMocks();
keychainData = {};
mockKeychainStorage = {
getSecret: vi
.fn()
.mockImplementation(async (key: string) => keychainData[key] || null),
setSecret: vi
.fn()
.mockImplementation(async (key: string, value: string) => {
keychainData[key] = value;
}),
deleteSecret: vi.fn().mockImplementation(async (key: string) => {
delete keychainData[key];
}),
listSecrets: vi
.fn()
.mockImplementation(async () => Object.keys(keychainData)),
isAvailable: vi.fn().mockResolvedValue(true),
};
(
KeychainTokenStorage as unknown as ReturnType<typeof vi.fn>
).mockImplementation(() => mockKeychainStorage);
tempHomeDir = os.tmpdir() + path.sep + `gemini-cli-test-home-${Date.now()}`;
extensionDir = path.join(tempHomeDir, '.gemini', 'extensions', 'test-ext');
// Spy and mock the method, but also create the directory so we can write to it.
@@ -59,7 +108,13 @@ describe('extensionSettings', () => {
it('should do nothing if settings are undefined', async () => {
const config: ExtensionConfig = { name: 'test-ext', version: '1.0.0' };
await maybePromptForSettings(config, mockRequestSetting);
await maybePromptForSettings(
config,
'12345',
mockRequestSetting,
undefined,
undefined,
);
expect(mockRequestSetting).not.toHaveBeenCalled();
});
@@ -69,11 +124,17 @@ describe('extensionSettings', () => {
version: '1.0.0',
settings: [],
};
await maybePromptForSettings(config, mockRequestSetting);
await maybePromptForSettings(
config,
'12345',
mockRequestSetting,
undefined,
undefined,
);
expect(mockRequestSetting).not.toHaveBeenCalled();
});
it('should call requestSetting for each setting', async () => {
it('should prompt for all settings if there is no previous config', async () => {
const config: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
@@ -82,14 +143,25 @@ describe('extensionSettings', () => {
{ name: 's2', description: 'd2', envVar: 'VAR2' },
],
};
await maybePromptForSettings(config, mockRequestSetting);
await maybePromptForSettings(
config,
'12345',
mockRequestSetting,
undefined,
undefined,
);
expect(mockRequestSetting).toHaveBeenCalledTimes(2);
expect(mockRequestSetting).toHaveBeenCalledWith(config.settings![0]);
expect(mockRequestSetting).toHaveBeenCalledWith(config.settings![1]);
});
it('should write the .env file with the correct content', async () => {
const config: ExtensionConfig = {
it('should only prompt for new settings', async () => {
const previousConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [{ name: 's1', description: 'd1', envVar: 'VAR1' }],
};
const newConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
@@ -97,35 +169,151 @@ describe('extensionSettings', () => {
{ name: 's2', description: 'd2', envVar: 'VAR2' },
],
};
await maybePromptForSettings(config, mockRequestSetting);
const previousSettings = { VAR1: 'previous-VAR1' };
await maybePromptForSettings(
newConfig,
'12345',
mockRequestSetting,
previousConfig,
previousSettings,
);
expect(mockRequestSetting).toHaveBeenCalledTimes(1);
expect(mockRequestSetting).toHaveBeenCalledWith(newConfig.settings![1]);
const expectedEnvPath = path.join(extensionDir, '.env');
const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8');
const expectedContent = 'VAR1=mock-VAR1\nVAR2=mock-VAR2\n';
const expectedContent = 'VAR1=previous-VAR1\nVAR2=mock-VAR2\n';
expect(actualContent).toBe(expectedContent);
});
it('should remove settings that are no longer in the config', async () => {
const previousConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
{ name: 's1', description: 'd1', envVar: 'VAR1' },
{ name: 's2', description: 'd2', envVar: 'VAR2' },
],
};
const newConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [{ name: 's1', description: 'd1', envVar: 'VAR1' }],
};
const previousSettings = {
VAR1: 'previous-VAR1',
VAR2: 'previous-VAR2',
};
await maybePromptForSettings(
newConfig,
'12345',
mockRequestSetting,
previousConfig,
previousSettings,
);
expect(mockRequestSetting).not.toHaveBeenCalled();
const expectedEnvPath = path.join(extensionDir, '.env');
const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8');
const expectedContent = 'VAR1=previous-VAR1\n';
expect(actualContent).toBe(expectedContent);
});
it('should reprompt if a setting changes sensitivity', async () => {
const previousConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
{ name: 's1', description: 'd1', envVar: 'VAR1', sensitive: false },
],
};
const newConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
{ name: 's1', description: 'd1', envVar: 'VAR1', sensitive: true },
],
};
const previousSettings = { VAR1: 'previous-VAR1' };
await maybePromptForSettings(
newConfig,
'12345',
mockRequestSetting,
previousConfig,
previousSettings,
);
expect(mockRequestSetting).toHaveBeenCalledTimes(1);
expect(mockRequestSetting).toHaveBeenCalledWith(newConfig.settings![0]);
// The value should now be in keychain, not the .env file.
const expectedEnvPath = path.join(extensionDir, '.env');
const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8');
expect(actualContent).toBe('');
});
it('should not prompt if settings are identical', async () => {
const previousConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
{ name: 's1', description: 'd1', envVar: 'VAR1' },
{ name: 's2', description: 'd2', envVar: 'VAR2' },
],
};
const newConfig: ExtensionConfig = {
name: 'test-ext',
version: '1.0.0',
settings: [
{ name: 's1', description: 'd1', envVar: 'VAR1' },
{ name: 's2', description: 'd2', envVar: 'VAR2' },
],
};
const previousSettings = {
VAR1: 'previous-VAR1',
VAR2: 'previous-VAR2',
};
await maybePromptForSettings(
newConfig,
'12345',
mockRequestSetting,
previousConfig,
previousSettings,
);
expect(mockRequestSetting).not.toHaveBeenCalled();
const expectedEnvPath = path.join(extensionDir, '.env');
const actualContent = await fsPromises.readFile(expectedEnvPath, 'utf-8');
const expectedContent = 'VAR1=previous-VAR1\nVAR2=previous-VAR2\n';
expect(actualContent).toBe(expectedContent);
});
});
describe('promptForSetting', () => {
// it('should use prompts with type "password" for sensitive settings', async () => {
// const setting: ExtensionSetting = {
// name: 'API Key',
// description: 'Your secret key',
// envVar: 'API_KEY',
// sensitive: true,
// };
// vi.mocked(prompts).mockResolvedValue({ value: 'secret-key' });
it('should use prompts with type "password" for sensitive settings', async () => {
const setting: ExtensionSetting = {
name: 'API Key',
description: 'Your secret key',
envVar: 'API_KEY',
sensitive: true,
};
vi.mocked(prompts).mockResolvedValue({ value: 'secret-key' });
// const result = await promptForSetting(setting);
const result = await promptForSetting(setting);
// expect(prompts).toHaveBeenCalledWith({
// type: 'password',
// name: 'value',
// message: 'API Key\nYour secret key',
// });
// expect(result).toBe('secret-key');
// });
expect(prompts).toHaveBeenCalledWith({
type: 'password',
name: 'value',
message: 'API Key\nYour secret key',
});
expect(result).toBe('secret-key');
});
it('should use prompts with type "text" for non-sensitive settings', async () => {
const setting: ExtensionSetting = {
@@ -12,57 +12,76 @@ import { ExtensionStorage } from './storage.js';
import type { ExtensionConfig } from '../extension.js';
import prompts from 'prompts';
import { KeychainTokenStorage } from '@google/gemini-cli-core';
export interface ExtensionSetting {
name: string;
description: string;
envVar: string;
// NOTE: If no value is set, this setting will be considered NOT sensitive.
sensitive?: boolean;
}
export async function maybePromptForSettings(
extensionConfig: ExtensionConfig,
extensionId: string,
requestSetting: (setting: ExtensionSetting) => Promise<string>,
previousExtensionConfig?: ExtensionConfig,
previousSettings?: Record<string, string>,
): Promise<void> {
const { name: extensionName, settings } = extensionConfig;
if (
(!settings || settings.length === 0) &&
(!previousExtensionConfig?.settings ||
previousExtensionConfig.settings.length === 0)
) {
return;
}
const envFilePath = new ExtensionStorage(extensionName).getEnvFilePath();
const keychain = new KeychainTokenStorage(extensionId);
if (!settings || settings.length === 0) {
// No settings for this extension. Clear any existing .env file.
if (fsSync.existsSync(envFilePath)) {
await fs.writeFile(envFilePath, '');
}
await clearSettings(envFilePath, keychain);
return;
}
let settingsToPrompt = settings;
if (previousExtensionConfig) {
const oldSettings = new Set(
previousExtensionConfig.settings?.map((s) => s.name) || [],
);
settingsToPrompt = settingsToPrompt.filter((s) => !oldSettings.has(s.name));
}
const settingsChanges = getSettingsChanges(
settings,
previousExtensionConfig?.settings ?? [],
);
const allSettings: Record<string, string> = { ...(previousSettings ?? {}) };
if (settingsToPrompt && settingsToPrompt.length > 0) {
for (const setting of settingsToPrompt) {
const answer = await requestSetting(setting);
allSettings[setting.envVar] = answer;
}
for (const removedEnvSetting of settingsChanges.removeEnv) {
delete allSettings[removedEnvSetting.envVar];
}
const validEnvVars = new Set(settings.map((s) => s.envVar));
const finalSettings: Record<string, string> = {};
for (const [key, value] of Object.entries(allSettings)) {
if (validEnvVars.has(key)) {
finalSettings[key] = value;
for (const removedSensitiveSetting of settingsChanges.removeSensitive) {
await keychain.deleteSecret(removedSensitiveSetting.envVar);
}
for (const setting of settingsChanges.promptForSensitive.concat(
settingsChanges.promptForEnv,
)) {
const answer = await requestSetting(setting);
allSettings[setting.envVar] = answer;
}
const nonSensitiveSettings: Record<string, string> = {};
for (const setting of settings) {
const value = allSettings[setting.envVar];
if (value === undefined) {
continue;
}
if (setting.sensitive) {
await keychain.setSecret(setting.envVar, value);
} else {
nonSensitiveSettings[setting.envVar] = value;
}
}
let envContent = '';
for (const [key, value] of Object.entries(finalSettings)) {
for (const [key, value] of Object.entries(nonSensitiveSettings)) {
envContent += `${key}=${value}\n`;
}
@@ -73,17 +92,22 @@ export async function promptForSetting(
setting: ExtensionSetting,
): Promise<string> {
const response = await prompts({
// type: setting.sensitive ? 'password' : 'text',
type: 'text',
type: setting.sensitive ? 'password' : 'text',
name: 'value',
message: `${setting.name}\n${setting.description}`,
});
return response.value;
}
export function getEnvContents(
extensionStorage: ExtensionStorage,
): Record<string, string> {
export async function getEnvContents(
extensionConfig: ExtensionConfig,
extensionId: string,
): Promise<Record<string, string>> {
if (!extensionConfig.settings || extensionConfig.settings.length === 0) {
return Promise.resolve({});
}
const extensionStorage = new ExtensionStorage(extensionConfig.name);
const keychain = new KeychainTokenStorage(extensionId);
let customEnv: Record<string, string> = {};
if (fsSync.existsSync(extensionStorage.getEnvFilePath())) {
const envFile = fsSync.readFileSync(
@@ -92,5 +116,67 @@ export function getEnvContents(
);
customEnv = dotenv.parse(envFile);
}
if (extensionConfig.settings) {
for (const setting of extensionConfig.settings) {
if (setting.sensitive) {
const secret = await keychain.getSecret(setting.envVar);
if (secret) {
customEnv[setting.envVar] = secret;
}
}
}
}
return customEnv;
}
interface settingsChanges {
promptForSensitive: ExtensionSetting[];
removeSensitive: ExtensionSetting[];
promptForEnv: ExtensionSetting[];
removeEnv: ExtensionSetting[];
}
function getSettingsChanges(
settings: ExtensionSetting[],
oldSettings: ExtensionSetting[],
): settingsChanges {
const isSameSetting = (a: ExtensionSetting, b: ExtensionSetting) =>
a.envVar === b.envVar && (a.sensitive ?? false) === (b.sensitive ?? false);
const sensitiveOld = oldSettings.filter((s) => s.sensitive ?? false);
const sensitiveNew = settings.filter((s) => s.sensitive ?? false);
const envOld = oldSettings.filter((s) => !(s.sensitive ?? false));
const envNew = settings.filter((s) => !(s.sensitive ?? false));
return {
promptForSensitive: sensitiveNew.filter(
(s) => !sensitiveOld.some((old) => isSameSetting(s, old)),
),
removeSensitive: sensitiveOld.filter(
(s) => !sensitiveNew.some((neu) => isSameSetting(s, neu)),
),
promptForEnv: envNew.filter(
(s) => !envOld.some((old) => isSameSetting(s, old)),
),
removeEnv: envOld.filter(
(s) => !envNew.some((neu) => isSameSetting(s, neu)),
),
};
}
async function clearSettings(
envFilePath: string,
keychain: KeychainTokenStorage,
) {
if (fsSync.existsSync(envFilePath)) {
await fs.writeFile(envFilePath, '');
}
if (!keychain.isAvailable()) {
return;
}
const secrets = await keychain.listSecrets();
for (const secret of secrets) {
await keychain.deleteSecret(secret);
}
return;
}
@@ -9,7 +9,7 @@ import * as fs from 'node:fs';
import * as os from 'node:os';
import * as path from 'node:path';
import { checkForAllExtensionUpdates, updateExtension } from './update.js';
import { GEMINI_DIR } from '@google/gemini-cli-core';
import { GEMINI_DIR, KeychainTokenStorage } from '@google/gemini-cli-core';
import { isWorkspaceTrusted } from '../trustedFolders.js';
import { ExtensionUpdateState } from '../../ui/state/extensions.js';
import { createExtension } from '../../test-utils/createExtension.js';
@@ -64,9 +64,24 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
logExtensionUninstall: mockLogExtensionUninstall,
ExtensionInstallEvent: vi.fn(),
ExtensionUninstallEvent: vi.fn(),
KeychainTokenStorage: vi.fn().mockImplementation(() => ({
getSecret: vi.fn(),
setSecret: vi.fn(),
deleteSecret: vi.fn(),
listSecrets: vi.fn(),
isAvailable: vi.fn().mockResolvedValue(true),
})),
};
});
interface MockKeychainStorage {
getSecret: ReturnType<typeof vi.fn>;
setSecret: ReturnType<typeof vi.fn>;
deleteSecret: ReturnType<typeof vi.fn>;
listSecrets: ReturnType<typeof vi.fn>;
isAvailable: ReturnType<typeof vi.fn>;
}
describe('update tests', () => {
let tempHomeDir: string;
let tempWorkspaceDir: string;
@@ -76,8 +91,32 @@ describe('update tests', () => {
let mockPromptForSettings: MockedFunction<
(setting: ExtensionSetting) => Promise<string>
>;
let mockKeychainStorage: MockKeychainStorage;
let keychainData: Record<string, string>;
beforeEach(() => {
vi.clearAllMocks();
keychainData = {};
mockKeychainStorage = {
getSecret: vi
.fn()
.mockImplementation(async (key: string) => keychainData[key] || null),
setSecret: vi
.fn()
.mockImplementation(async (key: string, value: string) => {
keychainData[key] = value;
}),
deleteSecret: vi.fn().mockImplementation(async (key: string) => {
delete keychainData[key];
}),
listSecrets: vi
.fn()
.mockImplementation(async () => Object.keys(keychainData)),
isAvailable: vi.fn().mockResolvedValue(true),
};
(
KeychainTokenStorage as unknown as ReturnType<typeof vi.fn>
).mockImplementation(() => mockKeychainStorage);
tempHomeDir = fs.mkdtempSync(
path.join(os.tmpdir(), 'gemini-cli-test-home-'),
);
@@ -110,6 +149,7 @@ describe('update tests', () => {
afterEach(() => {
fs.rmSync(tempHomeDir, { recursive: true, force: true });
fs.rmSync(tempWorkspaceDir, { recursive: true, force: true });
vi.restoreAllMocks();
});
describe('updateExtension', () => {
@@ -139,11 +179,10 @@ describe('update tests', () => {
);
});
mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]);
const extension = extensionManager
.loadExtensions()
.find((e) => e.name === extensionName)!;
const extensions = await extensionManager.loadExtensions();
const extension = extensions.find((e) => e.name === extensionName)!;
const updateInfo = await updateExtension(
extension,
extension!,
extensionManager,
ExtensionUpdateState.UPDATE_AVAILABLE,
() => {},
@@ -189,11 +228,10 @@ describe('update tests', () => {
const dispatch = vi.fn();
const extension = extensionManager
.loadExtensions()
.find((e) => e.name === extensionName)!;
const extensions = await extensionManager.loadExtensions();
const extension = extensions.find((e) => e.name === extensionName)!;
await updateExtension(
extension,
extension!,
extensionManager,
ExtensionUpdateState.UPDATE_AVAILABLE,
dispatch,
@@ -231,12 +269,11 @@ describe('update tests', () => {
mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]);
const dispatch = vi.fn();
const extension = extensionManager
.loadExtensions()
.find((e) => e.name === extensionName)!;
const extensions = await extensionManager.loadExtensions();
const extension = extensions.find((e) => e.name === extensionName)!;
await expect(
updateExtension(
extension,
extension!,
extensionManager,
ExtensionUpdateState.UPDATE_AVAILABLE,
dispatch,
@@ -280,7 +317,7 @@ describe('update tests', () => {
const dispatch = vi.fn();
await checkForAllExtensionUpdates(
extensionManager.loadExtensions(),
await extensionManager.loadExtensions(),
extensionManager,
dispatch,
);
@@ -312,7 +349,7 @@ describe('update tests', () => {
const dispatch = vi.fn();
await checkForAllExtensionUpdates(
extensionManager.loadExtensions(),
await extensionManager.loadExtensions(),
extensionManager,
dispatch,
);
@@ -341,7 +378,7 @@ describe('update tests', () => {
});
const dispatch = vi.fn();
await checkForAllExtensionUpdates(
extensionManager.loadExtensions(),
await extensionManager.loadExtensions(),
extensionManager,
dispatch,
);
@@ -370,7 +407,7 @@ describe('update tests', () => {
});
const dispatch = vi.fn();
await checkForAllExtensionUpdates(
extensionManager.loadExtensions(),
await extensionManager.loadExtensions(),
extensionManager,
dispatch,
);
@@ -398,7 +435,7 @@ describe('update tests', () => {
const dispatch = vi.fn();
await checkForAllExtensionUpdates(
extensionManager.loadExtensions(),
await extensionManager.loadExtensions(),
extensionManager,
dispatch,
);
+1 -1
View File
@@ -58,7 +58,7 @@ export async function updateExtension(
const tempDir = await ExtensionStorage.createTmpDir();
try {
const previousExtensionConfig = await extensionManager.loadExtensionConfig(
const previousExtensionConfig = extensionManager.loadExtensionConfig(
extension.path,
);
let updatedExtension: GeminiCLIExtension;
+2 -2
View File
@@ -2442,7 +2442,7 @@ describe('Settings Loading and Merging', () => {
extensionManager,
'disableExtension',
);
mockDisableExtension.mockImplementation(() => {});
mockDisableExtension.mockImplementation(async () => {});
migrateDeprecatedSettings(loadedSettings, extensionManager);
@@ -2515,7 +2515,7 @@ describe('Settings Loading and Merging', () => {
extensionManager,
'disableExtension',
);
mockDisableExtension.mockImplementation(() => {});
mockDisableExtension.mockImplementation(async () => {});
migrateDeprecatedSettings(loadedSettings, extensionManager);
@@ -124,7 +124,7 @@ describe('useExtensionUpdates', () => {
autoUpdate: true,
},
});
await extensionManager.loadExtensions();
const addItem = vi.fn();
vi.mocked(checkForAllExtensionUpdates).mockImplementation(
@@ -145,7 +145,6 @@ describe('useExtensionUpdates', () => {
name: '',
});
extensionManager.loadExtensions();
function TestComponent() {
useExtensionUpdates(extensionManager, addItem);
return null;
@@ -189,7 +188,7 @@ describe('useExtensionUpdates', () => {
},
});
extensionManager.loadExtensions();
await extensionManager.loadExtensions();
const addItem = vi.fn();