mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-02 09:20:42 -07:00
Reload gemini memory on extension load/unload + memory refresh refactor (#12651)
This commit is contained in:
@@ -266,6 +266,8 @@ export interface ConfigParameters {
|
||||
folderTrust?: boolean;
|
||||
ideMode?: boolean;
|
||||
loadMemoryFromIncludeDirectories?: boolean;
|
||||
importFormat?: 'tree' | 'flat';
|
||||
discoveryMaxDirs?: number;
|
||||
compressionThreshold?: number;
|
||||
interactive?: boolean;
|
||||
trustedFolder?: boolean;
|
||||
@@ -369,6 +371,8 @@ export class Config {
|
||||
| undefined;
|
||||
private readonly experimentalZedIntegration: boolean = false;
|
||||
private readonly loadMemoryFromIncludeDirectories: boolean = false;
|
||||
private readonly importFormat: 'tree' | 'flat';
|
||||
private readonly discoveryMaxDirs: number;
|
||||
private readonly compressionThreshold: number | undefined;
|
||||
private readonly interactive: boolean;
|
||||
private readonly ptyInfo: string;
|
||||
@@ -479,6 +483,8 @@ export class Config {
|
||||
this.ideMode = params.ideMode ?? false;
|
||||
this.loadMemoryFromIncludeDirectories =
|
||||
params.loadMemoryFromIncludeDirectories ?? false;
|
||||
this.importFormat = params.importFormat ?? 'tree';
|
||||
this.discoveryMaxDirs = params.discoveryMaxDirs ?? 200;
|
||||
this.compressionThreshold = params.compressionThreshold;
|
||||
this.interactive = params.interactive ?? false;
|
||||
this.ptyInfo = params.ptyInfo ?? 'child_process';
|
||||
@@ -707,6 +713,14 @@ export class Config {
|
||||
return this.loadMemoryFromIncludeDirectories;
|
||||
}
|
||||
|
||||
getImportFormat(): 'tree' | 'flat' {
|
||||
return this.importFormat;
|
||||
}
|
||||
|
||||
getDiscoveryMaxDirs(): number {
|
||||
return this.discoveryMaxDirs;
|
||||
}
|
||||
|
||||
getContentGeneratorConfig(): ContentGeneratorConfig {
|
||||
return this.contentGeneratorConfig;
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import { EventEmitter } from 'node:events';
|
||||
import type { LoadServerHierarchicalMemoryResponse } from './memoryDiscovery.js';
|
||||
|
||||
/**
|
||||
* Defines the severity level for user-facing feedback.
|
||||
@@ -53,13 +54,26 @@ export interface ModelChangedPayload {
|
||||
model: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Payload for the 'memory-changed' event.
|
||||
*/
|
||||
export type MemoryChangedPayload = LoadServerHierarchicalMemoryResponse;
|
||||
|
||||
export enum CoreEvent {
|
||||
UserFeedback = 'user-feedback',
|
||||
FallbackModeChanged = 'fallback-mode-changed',
|
||||
ModelChanged = 'model-changed',
|
||||
MemoryChanged = 'memory-changed',
|
||||
}
|
||||
|
||||
export class CoreEventEmitter extends EventEmitter {
|
||||
export interface CoreEvents {
|
||||
[CoreEvent.UserFeedback]: [UserFeedbackPayload];
|
||||
[CoreEvent.FallbackModeChanged]: [FallbackModeChangedPayload];
|
||||
[CoreEvent.ModelChanged]: [ModelChangedPayload];
|
||||
[CoreEvent.MemoryChanged]: [MemoryChangedPayload];
|
||||
}
|
||||
|
||||
export class CoreEventEmitter extends EventEmitter<CoreEvents> {
|
||||
private _feedbackBacklog: UserFeedbackPayload[] = [];
|
||||
private static readonly MAX_BACKLOG_SIZE = 10000;
|
||||
|
||||
@@ -116,63 +130,6 @@ export class CoreEventEmitter extends EventEmitter {
|
||||
this.emit(CoreEvent.UserFeedback, payload);
|
||||
}
|
||||
}
|
||||
|
||||
override on(
|
||||
event: CoreEvent.UserFeedback,
|
||||
listener: (payload: UserFeedbackPayload) => void,
|
||||
): this;
|
||||
override on(
|
||||
event: CoreEvent.FallbackModeChanged,
|
||||
listener: (payload: FallbackModeChangedPayload) => void,
|
||||
): this;
|
||||
override on(
|
||||
event: CoreEvent.ModelChanged,
|
||||
listener: (payload: ModelChangedPayload) => void,
|
||||
): this;
|
||||
override on(
|
||||
event: string | symbol,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
listener: (...args: any[]) => void,
|
||||
): this {
|
||||
return super.on(event, listener);
|
||||
}
|
||||
|
||||
override off(
|
||||
event: CoreEvent.UserFeedback,
|
||||
listener: (payload: UserFeedbackPayload) => void,
|
||||
): this;
|
||||
override off(
|
||||
event: CoreEvent.FallbackModeChanged,
|
||||
listener: (payload: FallbackModeChangedPayload) => void,
|
||||
): this;
|
||||
override off(
|
||||
event: CoreEvent.ModelChanged,
|
||||
listener: (payload: ModelChangedPayload) => void,
|
||||
): this;
|
||||
override off(
|
||||
event: string | symbol,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
listener: (...args: any[]) => void,
|
||||
): this {
|
||||
return super.off(event, listener);
|
||||
}
|
||||
|
||||
override emit(
|
||||
event: CoreEvent.UserFeedback,
|
||||
payload: UserFeedbackPayload,
|
||||
): boolean;
|
||||
override emit(
|
||||
event: CoreEvent.FallbackModeChanged,
|
||||
payload: FallbackModeChangedPayload,
|
||||
): boolean;
|
||||
override emit(
|
||||
event: CoreEvent.ModelChanged,
|
||||
payload: ModelChangedPayload,
|
||||
): boolean;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
override emit(event: string | symbol, ...args: any[]): boolean {
|
||||
return super.emit(event, ...args);
|
||||
}
|
||||
}
|
||||
|
||||
export const coreEvents = new CoreEventEmitter();
|
||||
|
||||
@@ -9,6 +9,16 @@ import { SimpleExtensionLoader } from './extensionLoader.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { type McpClientManager } from '../tools/mcp-client-manager.js';
|
||||
|
||||
const mockRefreshServerHierarchicalMemory = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('./memoryDiscovery.js', async (importActual) => {
|
||||
const actual = await importActual<typeof import('./memoryDiscovery.js')>();
|
||||
return {
|
||||
...actual,
|
||||
refreshServerHierarchicalMemory: mockRefreshServerHierarchicalMemory,
|
||||
};
|
||||
});
|
||||
|
||||
describe('SimpleExtensionLoader', () => {
|
||||
let mockConfig: Config;
|
||||
let extensionReloadingEnabled: boolean;
|
||||
@@ -79,29 +89,59 @@ describe('SimpleExtensionLoader', () => {
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
});
|
||||
|
||||
it.each([true, false])(
|
||||
'should only call `start` and `stop` if extension reloading is enabled ($i)',
|
||||
async (reloadingEnabled) => {
|
||||
extensionReloadingEnabled = reloadingEnabled;
|
||||
const loader = new SimpleExtensionLoader([]);
|
||||
await loader.start(mockConfig);
|
||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||
await loader.loadExtension(activeExtension);
|
||||
if (reloadingEnabled) {
|
||||
expect(
|
||||
mockMcpClientManager.startExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
} else {
|
||||
describe.each([true, false])(
|
||||
'when enableExtensionReloading === $i',
|
||||
(reloadingEnabled) => {
|
||||
beforeEach(() => {
|
||||
extensionReloadingEnabled = reloadingEnabled;
|
||||
});
|
||||
|
||||
it(`should ${reloadingEnabled ? '' : 'not '}reload extension features`, async () => {
|
||||
const loader = new SimpleExtensionLoader([]);
|
||||
await loader.start(mockConfig);
|
||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||
}
|
||||
await loader.unloadExtension(activeExtension);
|
||||
if (reloadingEnabled) {
|
||||
expect(
|
||||
mockMcpClientManager.stopExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
} else {
|
||||
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
|
||||
}
|
||||
await loader.loadExtension(activeExtension);
|
||||
if (reloadingEnabled) {
|
||||
expect(
|
||||
mockMcpClientManager.startExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
||||
} else {
|
||||
expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled();
|
||||
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
||||
}
|
||||
mockRefreshServerHierarchicalMemory.mockClear();
|
||||
|
||||
await loader.unloadExtension(activeExtension);
|
||||
if (reloadingEnabled) {
|
||||
expect(
|
||||
mockMcpClientManager.stopExtension,
|
||||
).toHaveBeenCalledExactlyOnceWith(activeExtension);
|
||||
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
||||
} else {
|
||||
expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled();
|
||||
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
||||
}
|
||||
});
|
||||
|
||||
it.runIf(reloadingEnabled)(
|
||||
'Should only reload memory once all extensions are done',
|
||||
async () => {
|
||||
const anotherExtension = {
|
||||
...activeExtension,
|
||||
name: 'another-extension',
|
||||
};
|
||||
const loader = new SimpleExtensionLoader([]);
|
||||
await loader.loadExtension(activeExtension);
|
||||
await loader.start(mockConfig);
|
||||
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
|
||||
await Promise.all([
|
||||
loader.unloadExtension(activeExtension),
|
||||
loader.loadExtension(anotherExtension),
|
||||
]);
|
||||
expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce();
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import type { Config, GeminiCLIExtension } from '../config/config.js';
|
||||
import { refreshServerHierarchicalMemory } from './memoryDiscovery.js';
|
||||
|
||||
export abstract class ExtensionLoader {
|
||||
// Assigned in `start`.
|
||||
@@ -18,6 +19,9 @@ export abstract class ExtensionLoader {
|
||||
protected stoppingCount: number = 0;
|
||||
protected stopCompletedCount: number = 0;
|
||||
|
||||
// Whether or not we are currently executing `start`
|
||||
private isStarting: boolean = false;
|
||||
|
||||
constructor(private readonly eventEmitter?: EventEmitter<ExtensionEvents>) {}
|
||||
|
||||
/**
|
||||
@@ -32,16 +36,21 @@ export abstract class ExtensionLoader {
|
||||
* McpClientManager, PromptRegistry, and GeminiChat set up.
|
||||
*/
|
||||
async start(config: Config): Promise<void> {
|
||||
if (!this.config) {
|
||||
this.config = config;
|
||||
} else {
|
||||
throw new Error('Already started, you may only call `start` once.');
|
||||
this.isStarting = true;
|
||||
try {
|
||||
if (!this.config) {
|
||||
this.config = config;
|
||||
} else {
|
||||
throw new Error('Already started, you may only call `start` once.');
|
||||
}
|
||||
await Promise.all(
|
||||
this.getExtensions()
|
||||
.filter((e) => e.isActive)
|
||||
.map(this.startExtension.bind(this)),
|
||||
);
|
||||
} finally {
|
||||
this.isStarting = false;
|
||||
}
|
||||
await Promise.all(
|
||||
this.getExtensions()
|
||||
.filter((e) => e.isActive)
|
||||
.map(this.startExtension.bind(this)),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -64,12 +73,15 @@ export abstract class ExtensionLoader {
|
||||
});
|
||||
try {
|
||||
await this.config.getMcpClientManager()!.startExtension(extension);
|
||||
// Note: Context files are loaded only once all extensions are done
|
||||
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
|
||||
// below.
|
||||
|
||||
// TODO: Update custom command updating away from the event based system
|
||||
// and call directly into a custom command manager here. See the
|
||||
// useSlashCommandProcessor hook which responds to events fired here today.
|
||||
|
||||
// TODO: Move all enablement of extension features here, including at least:
|
||||
// - context file loading
|
||||
// - excluded tool configuration
|
||||
} finally {
|
||||
this.startCompletedCount++;
|
||||
@@ -81,6 +93,25 @@ export abstract class ExtensionLoader {
|
||||
this.startingCount = 0;
|
||||
this.startCompletedCount = 0;
|
||||
}
|
||||
await this.maybeRefreshMemories();
|
||||
}
|
||||
}
|
||||
|
||||
private async maybeRefreshMemories(): Promise<void> {
|
||||
if (!this.config) {
|
||||
throw new Error(
|
||||
'Cannot refresh gemini memories prior to calling `start`.',
|
||||
);
|
||||
}
|
||||
if (
|
||||
!this.isStarting && // Don't refresh memories on the first call to `start`.
|
||||
this.startingCount === this.startCompletedCount &&
|
||||
this.stoppingCount === this.stopCompletedCount
|
||||
) {
|
||||
// Wait until all extensions are done starting and stopping before we
|
||||
// reload memory, this is somewhat expensive and also busts the context
|
||||
// cache, we want to only do it once.
|
||||
await refreshServerHierarchicalMemory(this.config);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,12 +150,15 @@ export abstract class ExtensionLoader {
|
||||
|
||||
try {
|
||||
await this.config.getMcpClientManager()!.stopExtension(extension);
|
||||
// Note: Context files are loaded only once all extensions are done
|
||||
// loading/unloading to reduce churn, see the `maybeRefreshMemories` call
|
||||
// below.
|
||||
|
||||
// TODO: Update custom command updating away from the event based system
|
||||
// and call directly into a custom command manager here. See the
|
||||
// useSlashCommandProcessor hook which responds to events fired here today.
|
||||
|
||||
// TODO: Remove all extension features here, including at least:
|
||||
// - context files
|
||||
// - excluded tools
|
||||
} finally {
|
||||
this.stopCompletedCount++;
|
||||
@@ -136,6 +170,7 @@ export abstract class ExtensionLoader {
|
||||
this.stoppingCount = 0;
|
||||
this.stopCompletedCount = 0;
|
||||
}
|
||||
await this.maybeRefreshMemories();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
loadGlobalMemory,
|
||||
loadEnvironmentMemory,
|
||||
loadJitSubdirectoryMemory,
|
||||
refreshServerHierarchicalMemory,
|
||||
} from './memoryDiscovery.js';
|
||||
import {
|
||||
setGeminiMdFilename,
|
||||
@@ -20,8 +21,10 @@ import {
|
||||
} from '../tools/memoryTool.js';
|
||||
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
||||
import { GEMINI_DIR } from './paths.js';
|
||||
import type { GeminiCLIExtension } from '../config/config.js';
|
||||
import { Config, type GeminiCLIExtension } from '../config/config.js';
|
||||
import { Storage } from '../config/storage.js';
|
||||
import { SimpleExtensionLoader } from './extensionLoader.js';
|
||||
import { CoreEvent, coreEvents } from './events.js';
|
||||
|
||||
vi.mock('os', async (importOriginal) => {
|
||||
const actualOs = await importOriginal<typeof os>();
|
||||
@@ -876,4 +879,58 @@ included directory memory
|
||||
expect(result.files.find((f) => f.path === outerMemory)).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
it('refreshServerHierarchicalMemory should refresh memory and update config', async () => {
|
||||
const extensionLoader = new SimpleExtensionLoader([]);
|
||||
const config = new Config({
|
||||
sessionId: '1',
|
||||
targetDir: cwd,
|
||||
cwd,
|
||||
debugMode: false,
|
||||
model: 'fake-model',
|
||||
extensionLoader,
|
||||
});
|
||||
const result = await loadServerHierarchicalMemory(
|
||||
config.getWorkingDir(),
|
||||
config.shouldLoadMemoryFromIncludeDirectories()
|
||||
? config.getWorkspaceContext().getDirectories()
|
||||
: [],
|
||||
config.getDebugMode(),
|
||||
config.getFileService(),
|
||||
config.getExtensionLoader(),
|
||||
config.isTrustedFolder(),
|
||||
config.getImportFormat(),
|
||||
);
|
||||
expect(result.fileCount).equals(0);
|
||||
|
||||
// Now add an extension with a memory file
|
||||
const extensionsDir = new Storage(homedir).getExtensionsDir();
|
||||
const extensionPath = path.join(extensionsDir, 'new-extension');
|
||||
const contextFilePath = path.join(extensionPath, 'CustomContext.md');
|
||||
await fsPromises.mkdir(extensionPath, { recursive: true });
|
||||
await fsPromises.writeFile(contextFilePath, 'Really cool custom context!');
|
||||
await extensionLoader.loadExtension({
|
||||
name: 'new-extension',
|
||||
isActive: true,
|
||||
contextFiles: [contextFilePath],
|
||||
version: '1.0.0',
|
||||
id: '1234',
|
||||
path: extensionPath,
|
||||
});
|
||||
|
||||
const mockEventListener = vi.fn();
|
||||
coreEvents.on(CoreEvent.MemoryChanged, mockEventListener);
|
||||
const refreshResult = await refreshServerHierarchicalMemory(config);
|
||||
expect(refreshResult.fileCount).equals(1);
|
||||
expect(config.getGeminiMdFileCount()).equals(refreshResult.fileCount);
|
||||
expect(refreshResult.memoryContent).toContain(
|
||||
'Really cool custom context!',
|
||||
);
|
||||
expect(config.getUserMemory()).equals(refreshResult.memoryContent);
|
||||
expect(refreshResult.filePaths[0]).toContain(
|
||||
path.join(extensionPath, 'CustomContext.md'),
|
||||
);
|
||||
expect(config.getGeminiMdFilePaths()).equals(refreshResult.filePaths);
|
||||
expect(mockEventListener).toHaveBeenCalledExactlyOnceWith(refreshResult);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -17,6 +17,8 @@ import { DEFAULT_MEMORY_FILE_FILTERING_OPTIONS } from '../config/constants.js';
|
||||
import { GEMINI_DIR } from './paths.js';
|
||||
import type { ExtensionLoader } from './extensionLoader.js';
|
||||
import { debugLogger } from './debugLogger.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { CoreEvent, coreEvents } from './events.js';
|
||||
|
||||
// Simple console logger, similar to the one previously in CLI's config.ts
|
||||
// TODO: Integrate with a more robust server-side logger if available/appropriate.
|
||||
@@ -481,6 +483,15 @@ export async function loadServerHierarchicalMemory(
|
||||
fileFilteringOptions?: FileFilteringOptions,
|
||||
maxDirs: number = 200,
|
||||
): Promise<LoadServerHierarchicalMemoryResponse> {
|
||||
// FIX: Use real, canonical paths for a reliable comparison to handle symlinks.
|
||||
const realCwd = await fs.realpath(path.resolve(currentWorkingDirectory));
|
||||
const realHome = await fs.realpath(path.resolve(homedir()));
|
||||
const isHomeDirectory = realCwd === realHome;
|
||||
|
||||
// If it is the home directory, pass an empty string to the core memory
|
||||
// function to signal that it should skip the workspace search.
|
||||
currentWorkingDirectory = isHomeDirectory ? '' : currentWorkingDirectory;
|
||||
|
||||
if (debugMode)
|
||||
logger.debug(
|
||||
`Loading server hierarchical memory for CWD: ${currentWorkingDirectory} (importFormat: ${importFormat})`,
|
||||
@@ -538,6 +549,33 @@ export async function loadServerHierarchicalMemory(
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the hierarchical memory and resets the state of `config` as needed such
|
||||
* that it reflects the new memory.
|
||||
*
|
||||
* Returns the result of the call to `loadHierarchicalGeminiMemory`.
|
||||
*/
|
||||
export async function refreshServerHierarchicalMemory(config: Config) {
|
||||
const result = await loadServerHierarchicalMemory(
|
||||
config.getWorkingDir(),
|
||||
config.shouldLoadMemoryFromIncludeDirectories()
|
||||
? config.getWorkspaceContext().getDirectories()
|
||||
: [],
|
||||
config.getDebugMode(),
|
||||
config.getFileService(),
|
||||
config.getExtensionLoader(),
|
||||
config.isTrustedFolder(),
|
||||
config.getImportFormat(),
|
||||
config.getFileFilteringOptions(),
|
||||
config.getDiscoveryMaxDirs(),
|
||||
);
|
||||
config.setUserMemory(result.memoryContent);
|
||||
config.setGeminiMdFileCount(result.fileCount);
|
||||
config.setGeminiMdFilePaths(result.filePaths);
|
||||
coreEvents.emit(CoreEvent.MemoryChanged, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
export async function loadJitSubdirectoryMemory(
|
||||
targetPath: string,
|
||||
trustedRoots: string[],
|
||||
|
||||
Reference in New Issue
Block a user