feat(core): Implement JIT context memory loading and UI sync (#14469)

This commit is contained in:
Sandy Tao
2025-12-19 07:04:03 -10:00
committed by GitHub
parent 3c92bdb1ad
commit 2e229d3bb6
14 changed files with 292 additions and 91 deletions

View File

@@ -435,9 +435,15 @@ export async function loadCliConfig(
});
await extensionManager.loadExtensions();
// Call the (now wrapper) loadHierarchicalGeminiMemory which calls the server's version
const { memoryContent, fileCount, filePaths } =
await loadServerHierarchicalMemory(
const experimentalJitContext = settings.experimental?.jitContext ?? false;
let memoryContent = '';
let fileCount = 0;
let filePaths: string[] = [];
if (!experimentalJitContext) {
// Call the (now wrapper) loadHierarchicalGeminiMemory which calls the server's version
const result = await loadServerHierarchicalMemory(
cwd,
[],
debugMode,
@@ -448,6 +454,10 @@ export async function loadCliConfig(
memoryFileFiltering,
settings.context?.discoveryMaxDirs,
);
memoryContent = result.memoryContent;
fileCount = result.fileCount;
filePaths = result.filePaths;
}
const question = argv.promptInteractive || argv.prompt || '';

View File

@@ -157,12 +157,14 @@ describe('memoryCommand', () => {
let mockSetUserMemory: Mock;
let mockSetGeminiMdFileCount: Mock;
let mockSetGeminiMdFilePaths: Mock;
let mockContextManagerRefresh: Mock;
beforeEach(() => {
refreshCommand = getSubCommand('refresh');
mockSetUserMemory = vi.fn();
mockSetGeminiMdFileCount = vi.fn();
mockSetGeminiMdFilePaths = vi.fn();
mockContextManagerRefresh = vi.fn().mockResolvedValue(undefined);
const mockConfig = {
setUserMemory: mockSetUserMemory,
@@ -185,6 +187,12 @@ describe('memoryCommand', () => {
updateSystemInstructionIfInitialized: vi
.fn()
.mockResolvedValue(undefined),
isJitContextEnabled: vi.fn().mockReturnValue(false),
getContextManager: vi.fn().mockReturnValue({
refresh: mockContextManagerRefresh,
}),
getUserMemory: vi.fn().mockReturnValue(''),
getGeminiMdFileCount: vi.fn().mockReturnValue(0),
};
mockContext = createMockCommandContext({
@@ -203,7 +211,32 @@ describe('memoryCommand', () => {
mockRefreshServerHierarchicalMemory.mockClear();
});
it('should display success message when memory is refreshed with content', async () => {
it('should use ContextManager.refresh when JIT is enabled', async () => {
if (!refreshCommand.action) throw new Error('Command has no action');
// Enable JIT in mock config
const config = mockContext.services.config;
if (!config) throw new Error('Config is undefined');
vi.mocked(config.isJitContextEnabled).mockReturnValue(true);
vi.mocked(config.getUserMemory).mockReturnValue('JIT Memory Content');
vi.mocked(config.getGeminiMdFileCount).mockReturnValue(3);
await refreshCommand.action(mockContext, '');
expect(mockContextManagerRefresh).toHaveBeenCalledOnce();
expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled();
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
{
type: MessageType.INFO,
text: 'Memory refreshed successfully. Loaded 18 characters from 3 file(s).',
},
expect.any(Number),
);
});
it('should display success message when memory is refreshed with content (Legacy)', async () => {
if (!refreshCommand.action) throw new Error('Command has no action');
const refreshResult: LoadServerHierarchicalMemoryResponse = {

View File

@@ -87,8 +87,18 @@ export const memoryCommand: SlashCommand = {
try {
const config = context.services.config;
if (config) {
const { memoryContent, fileCount } =
await refreshServerHierarchicalMemory(config);
let memoryContent = '';
let fileCount = 0;
if (config.isJitContextEnabled()) {
await config.getContextManager()?.refresh();
memoryContent = config.getUserMemory();
fileCount = config.getGeminiMdFileCount();
} else {
const result = await refreshServerHierarchicalMemory(config);
memoryContent = result.memoryContent;
fileCount = result.fileCount;
}
await config.updateSystemInstructionIfInitialized();

View File

@@ -65,6 +65,13 @@ vi.mock('../tools/tool-registry', () => {
return { ToolRegistry: ToolRegistryMock };
});
vi.mock('../tools/mcp-client-manager.js', () => ({
McpClientManager: vi.fn().mockImplementation(() => ({
startConfiguredMcpServers: vi.fn(),
getMcpInstructions: vi.fn().mockReturnValue('MCP Instructions'),
})),
}));
vi.mock('../utils/memoryDiscovery.js', () => ({
loadServerHierarchicalMemory: vi.fn(),
}));
@@ -168,12 +175,15 @@ vi.mock('../utils/fetch.js', () => ({
setGlobalProxy: mockSetGlobalProxy,
}));
vi.mock('../services/contextManager.js');
import { BaseLlmClient } from '../core/baseLlmClient.js';
import { tokenLimit } from '../core/tokenLimits.js';
import { uiTelemetryService } from '../telemetry/index.js';
import { getCodeAssistServer } from '../code_assist/codeAssist.js';
import { getExperiments } from '../code_assist/experiments/experiments.js';
import type { CodeAssistServer } from '../code_assist/server.js';
import { ContextManager } from '../services/contextManager.js';
vi.mock('../core/baseLlmClient.js');
vi.mock('../core/tokenLimits.js', () => ({
@@ -1777,7 +1787,7 @@ describe('Config Quota & Preview Model Access', () => {
sessionId: 'test-session',
model: 'gemini-pro',
usageStatisticsEnabled: false,
embeddingModel: 'gemini-embedding', // required in type but not in the original file I copied, adding here
embeddingModel: 'gemini-embedding',
sandbox: {
command: 'docker',
image: 'gemini-cli-sandbox',
@@ -1877,3 +1887,71 @@ describe('Config Quota & Preview Model Access', () => {
});
});
});
describe('Config JIT Initialization', () => {
let config: Config;
let mockContextManager: {
refresh: Mock;
getGlobalMemory: Mock;
getEnvironmentMemory: Mock;
getLoadedPaths: Mock;
};
beforeEach(() => {
vi.clearAllMocks();
mockContextManager = {
refresh: vi.fn(),
getGlobalMemory: vi.fn().mockReturnValue('Global Memory'),
getEnvironmentMemory: vi
.fn()
.mockReturnValue('Environment Memory\n\nMCP Instructions'),
getLoadedPaths: vi.fn().mockReturnValue(new Set(['/path/to/GEMINI.md'])),
};
(ContextManager as unknown as Mock).mockImplementation(
() => mockContextManager,
);
});
it('should initialize ContextManager, load memory, and delegate to it when experimentalJitContext is enabled', async () => {
const params: ConfigParameters = {
sessionId: 'test-session',
targetDir: '/tmp/test',
debugMode: false,
model: 'test-model',
experimentalJitContext: true,
userMemory: 'Initial Memory',
cwd: '/tmp/test',
};
config = new Config(params);
await config.initialize();
expect(ContextManager).toHaveBeenCalledWith(config);
expect(mockContextManager.refresh).toHaveBeenCalled();
expect(config.getUserMemory()).toBe(
'Global Memory\n\nEnvironment Memory\n\nMCP Instructions',
);
// Verify state update (delegated to ContextManager)
expect(config.getGeminiMdFileCount()).toBe(1);
expect(config.getGeminiMdFilePaths()).toEqual(['/path/to/GEMINI.md']);
});
it('should NOT initialize ContextManager when experimentalJitContext is disabled', async () => {
const params: ConfigParameters = {
sessionId: 'test-session',
targetDir: '/tmp/test',
debugMode: false,
model: 'test-model',
experimentalJitContext: false,
userMemory: 'Initial Memory',
cwd: '/tmp/test',
};
config = new Config(params);
await config.initialize();
expect(ContextManager).not.toHaveBeenCalled();
expect(config.getUserMemory()).toBe('Initial Memory');
});
});

View File

@@ -696,6 +696,7 @@ export class Config {
if (this.experimentalJitContext) {
this.contextManager = new ContextManager(this);
await this.contextManager.refresh();
}
await this.geminiClient.initialize();
@@ -1062,6 +1063,14 @@ export class Config {
}
getUserMemory(): string {
if (this.experimentalJitContext && this.contextManager) {
return [
this.contextManager.getGlobalMemory(),
this.contextManager.getEnvironmentMemory(),
]
.filter(Boolean)
.join('\n\n');
}
return this.userMemory;
}
@@ -1086,6 +1095,9 @@ export class Config {
}
getGeminiMdFileCount(): number {
if (this.experimentalJitContext && this.contextManager) {
return this.contextManager.getLoadedPaths().size;
}
return this.geminiMdFileCount;
}
@@ -1094,6 +1106,9 @@ export class Config {
}
getGeminiMdFilePaths(): string[] {
if (this.experimentalJitContext && this.contextManager) {
return Array.from(this.contextManager.getLoadedPaths());
}
return this.geminiMdFilePaths;
}

View File

@@ -208,6 +208,9 @@ describe('Gemini Client (client.ts)', () => {
getVertexAI: vi.fn().mockReturnValue(false),
getUserAgent: vi.fn().mockReturnValue('test-agent'),
getUserMemory: vi.fn().mockReturnValue(''),
getGlobalMemory: vi.fn().mockReturnValue(''),
getEnvironmentMemory: vi.fn().mockReturnValue(''),
isJitContextEnabled: vi.fn().mockReturnValue(false),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
getProxy: vi.fn().mockReturnValue(undefined),
@@ -1532,6 +1535,39 @@ ${JSON.stringify(
});
});
it('should use getGlobalMemory for system instruction when JIT is enabled', async () => {
vi.mocked(mockConfig.isJitContextEnabled).mockReturnValue(true);
vi.mocked(mockConfig.getGlobalMemory).mockReturnValue(
'Global JIT Memory',
);
vi.mocked(mockConfig.getUserMemory).mockReturnValue('Full JIT Memory');
const { getCoreSystemPrompt } = await import('./prompts.js');
const mockGetCoreSystemPrompt = vi.mocked(getCoreSystemPrompt);
await client.updateSystemInstruction();
expect(mockGetCoreSystemPrompt).toHaveBeenCalledWith(
mockConfig,
'Global JIT Memory',
);
});
it('should use getUserMemory for system instruction when JIT is disabled', async () => {
vi.mocked(mockConfig.isJitContextEnabled).mockReturnValue(false);
vi.mocked(mockConfig.getUserMemory).mockReturnValue('Legacy Memory');
const { getCoreSystemPrompt } = await import('./prompts.js');
const mockGetCoreSystemPrompt = vi.mocked(getCoreSystemPrompt);
await client.updateSystemInstruction();
expect(mockGetCoreSystemPrompt).toHaveBeenCalledWith(
mockConfig,
'Legacy Memory',
);
});
it('should recursively call sendMessageStream with "Please continue." when InvalidStream event is received', async () => {
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
true,

View File

@@ -179,8 +179,10 @@ export class GeminiClient {
return;
}
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(this.config, userMemory);
const systemMemory = this.config.isJitContextEnabled()
? this.config.getGlobalMemory()
: this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(this.config, systemMemory);
this.getChat().setSystemInstruction(systemInstruction);
}
@@ -198,8 +200,10 @@ export class GeminiClient {
const history = await getInitialChatHistory(this.config, extraHistory);
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(this.config, userMemory);
const systemMemory = this.config.isJitContextEnabled()
? this.config.getGlobalMemory()
: this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(this.config, systemMemory);
return new GeminiChat(
this.config,
systemInstruction,

View File

@@ -8,7 +8,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ContextManager } from './contextManager.js';
import * as memoryDiscovery from '../utils/memoryDiscovery.js';
import type { Config } from '../config/config.js';
import type { ExtensionLoader } from '../utils/extensionLoader.js';
import { coreEvents, CoreEvent } from '../utils/events.js';
// Mock memoryDiscovery module
vi.mock('../utils/memoryDiscovery.js', async (importOriginal) => {
@@ -19,6 +19,9 @@ vi.mock('../utils/memoryDiscovery.js', async (importOriginal) => {
loadGlobalMemory: vi.fn(),
loadEnvironmentMemory: vi.fn(),
loadJitSubdirectoryMemory: vi.fn(),
concatenateInstructions: vi
.fn()
.mockImplementation(actual.concatenateInstructions),
};
});
@@ -30,58 +33,84 @@ describe('ContextManager', () => {
mockConfig = {
getDebugMode: vi.fn().mockReturnValue(false),
getWorkingDir: vi.fn().mockReturnValue('/app'),
getWorkspaceContext: vi.fn().mockReturnValue({
getDirectories: vi.fn().mockReturnValue(['/app']),
}),
getExtensionLoader: vi.fn().mockReturnValue({}),
getMcpClientManager: vi.fn().mockReturnValue({
getMcpInstructions: vi.fn().mockReturnValue('MCP Instructions'),
}),
} as unknown as Config;
contextManager = new ContextManager(mockConfig);
vi.clearAllMocks();
vi.spyOn(coreEvents, 'emit');
});
describe('loadGlobalMemory', () => {
it('should load and format global memory', async () => {
const mockResult: memoryDiscovery.MemoryLoadResult = {
describe('refresh', () => {
it('should load and format global and environment memory', async () => {
const mockGlobalResult: memoryDiscovery.MemoryLoadResult = {
files: [
{ path: '/home/user/.gemini/GEMINI.md', content: 'Global Content' },
],
};
vi.mocked(memoryDiscovery.loadGlobalMemory).mockResolvedValue(mockResult);
const result = await contextManager.loadGlobalMemory();
expect(memoryDiscovery.loadGlobalMemory).toHaveBeenCalledWith(false);
// The path will be relative to CWD (/app), so it might contain ../
expect(result).toMatch(/--- Context from: .*GEMINI.md ---/);
expect(result).toContain('Global Content');
expect(contextManager.getLoadedPaths()).toContain(
'/home/user/.gemini/GEMINI.md',
vi.mocked(memoryDiscovery.loadGlobalMemory).mockResolvedValue(
mockGlobalResult,
);
expect(contextManager.getGlobalMemory()).toBe(result);
});
});
describe('loadEnvironmentMemory', () => {
it('should load and format environment memory', async () => {
const mockResult: memoryDiscovery.MemoryLoadResult = {
const mockEnvResult: memoryDiscovery.MemoryLoadResult = {
files: [{ path: '/app/GEMINI.md', content: 'Env Content' }],
};
vi.mocked(memoryDiscovery.loadEnvironmentMemory).mockResolvedValue(
mockResult,
mockEnvResult,
);
const mockExtensionLoader = {} as unknown as ExtensionLoader;
const result = await contextManager.loadEnvironmentMemory(
['/app'],
mockExtensionLoader,
await contextManager.refresh();
expect(memoryDiscovery.loadGlobalMemory).toHaveBeenCalledWith(false);
expect(contextManager.getGlobalMemory()).toMatch(
/--- Context from: .*GEMINI.md ---/,
);
expect(contextManager.getGlobalMemory()).toContain('Global Content');
expect(memoryDiscovery.loadEnvironmentMemory).toHaveBeenCalledWith(
['/app'],
mockExtensionLoader,
expect.anything(),
false,
);
expect(result).toContain('--- Context from: GEMINI.md ---');
expect(result).toContain('Env Content');
expect(contextManager.getEnvironmentMemory()).toContain(
'--- Context from: GEMINI.md ---',
);
expect(contextManager.getEnvironmentMemory()).toContain('Env Content');
expect(contextManager.getEnvironmentMemory()).toContain(
'MCP Instructions',
);
expect(contextManager.getLoadedPaths()).toContain(
'/home/user/.gemini/GEMINI.md',
);
expect(contextManager.getLoadedPaths()).toContain('/app/GEMINI.md');
expect(contextManager.getEnvironmentMemory()).toBe(result);
});
it('should emit MemoryChanged event when memory is refreshed', async () => {
const mockGlobalResult = {
files: [{ path: '/app/GEMINI.md', content: 'content' }],
};
const mockEnvResult = {
files: [{ path: '/app/src/GEMINI.md', content: 'env content' }],
};
vi.mocked(memoryDiscovery.loadGlobalMemory).mockResolvedValue(
mockGlobalResult,
);
vi.mocked(memoryDiscovery.loadEnvironmentMemory).mockResolvedValue(
mockEnvResult,
);
await contextManager.refresh();
expect(coreEvents.emit).toHaveBeenCalledWith(CoreEvent.MemoryChanged, {
fileCount: 2,
});
});
});
@@ -122,27 +151,4 @@ describe('ContextManager', () => {
expect(result).toBe('');
});
});
describe('reset', () => {
it('should clear loaded paths and memory', async () => {
// Setup some state
const mockResult: memoryDiscovery.MemoryLoadResult = {
files: [
{ path: '/home/user/.gemini/GEMINI.md', content: 'Global Content' },
],
};
vi.mocked(memoryDiscovery.loadGlobalMemory).mockResolvedValue(mockResult);
await contextManager.loadGlobalMemory();
expect(contextManager.getLoadedPaths().size).toBeGreaterThan(0);
expect(contextManager.getGlobalMemory()).toBeTruthy();
// Reset
contextManager.reset();
expect(contextManager.getLoadedPaths().size).toBe(0);
expect(contextManager.getGlobalMemory()).toBe('');
expect(contextManager.getEnvironmentMemory()).toBe('');
});
});
});

View File

@@ -10,8 +10,8 @@ import {
loadJitSubdirectoryMemory,
concatenateInstructions,
} from '../utils/memoryDiscovery.js';
import type { ExtensionLoader } from '../utils/extensionLoader.js';
import type { Config } from '../config/config.js';
import { coreEvents, CoreEvent } from '../utils/events.js';
export class ContextManager {
private readonly loadedPaths: Set<string> = new Set();
@@ -24,36 +24,40 @@ export class ContextManager {
}
/**
* Loads the global memory (Tier 1) and returns the formatted content.
* Refreshes the memory by reloading global and environment memory.
*/
async loadGlobalMemory(): Promise<string> {
async refresh(): Promise<void> {
this.loadedPaths.clear();
await this.loadGlobalMemory();
await this.loadEnvironmentMemory();
this.emitMemoryChanged();
}
private async loadGlobalMemory(): Promise<void> {
const result = await loadGlobalMemory(this.config.getDebugMode());
this.markAsLoaded(result.files.map((f) => f.path));
this.globalMemory = concatenateInstructions(
result.files.map((f) => ({ filePath: f.path, content: f.content })),
this.config.getWorkingDir(),
);
return this.globalMemory;
}
/**
* Loads the environment memory (Tier 2) and returns the formatted content.
*/
async loadEnvironmentMemory(
trustedRoots: string[],
extensionLoader: ExtensionLoader,
): Promise<string> {
private async loadEnvironmentMemory(): Promise<void> {
const result = await loadEnvironmentMemory(
trustedRoots,
extensionLoader,
[...this.config.getWorkspaceContext().getDirectories()],
this.config.getExtensionLoader(),
this.config.getDebugMode(),
);
this.markAsLoaded(result.files.map((f) => f.path));
this.environmentMemory = concatenateInstructions(
const envMemory = concatenateInstructions(
result.files.map((f) => ({ filePath: f.path, content: f.content })),
this.config.getWorkingDir(),
);
return this.environmentMemory;
const mcpInstructions =
this.config.getMcpClientManager()?.getMcpInstructions() || '';
this.environmentMemory = [envMemory, mcpInstructions.trimStart()]
.filter(Boolean)
.join('\n\n');
}
/**
@@ -82,6 +86,12 @@ export class ContextManager {
);
}
private emitMemoryChanged(): void {
coreEvents.emit(CoreEvent.MemoryChanged, {
fileCount: this.loadedPaths.size,
});
}
getGlobalMemory(): string {
return this.globalMemory;
}
@@ -96,15 +106,6 @@ export class ContextManager {
}
}
/**
* Resets the loaded paths tracking and memory. Useful for testing or full reloads.
*/
reset(): void {
this.loadedPaths.clear();
this.globalMemory = '';
this.environmentMemory = '';
}
getLoadedPaths(): ReadonlySet<string> {
return this.loadedPaths;
}

View File

@@ -92,6 +92,7 @@ describe('getEnvironmentContext', () => {
getDirectories: vi.fn().mockReturnValue(['/test/dir']),
}),
getFileService: vi.fn(),
getEnvironmentMemory: vi.fn().mockReturnValue('Mock Environment Memory'),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
storage: {
@@ -122,6 +123,7 @@ describe('getEnvironmentContext', () => {
expect(context).toContain(
'Here is the folder structure of the current working directories:\n\nMock Folder Structure',
);
expect(context).toContain('Mock Environment Memory');
expect(getFolderStructure).toHaveBeenCalledWith('/test/dir', {
fileService: undefined,
});

View File

@@ -62,6 +62,7 @@ export async function getEnvironmentContext(config: Config): Promise<Part[]> {
const platform = process.platform;
const directoryContext = await getDirectoryContextString(config);
const tempDir = config.storage.getProjectTempDir();
const environmentMemory = config.getEnvironmentMemory();
const context = `
This is the Gemini CLI. We are setting up the context for our chat.
@@ -69,6 +70,8 @@ Today's date is ${today} (formatted according to the user's locale).
My operating system is: ${platform}
The project's temporary directory is: ${tempDir}
${directoryContext}
${environmentMemory}
`.trim();
const initialParts: Part[] = [{ text: context }];

View File

@@ -5,7 +5,6 @@
*/
import { EventEmitter } from 'node:events';
import type { LoadServerHierarchicalMemoryResponse } from './memoryDiscovery.js';
/**
* Defines the severity level for user-facing feedback.
@@ -64,7 +63,9 @@ export interface OutputPayload {
/**
* Payload for the 'memory-changed' event.
*/
export type MemoryChangedPayload = LoadServerHierarchicalMemoryResponse;
export interface MemoryChangedPayload {
fileCount: number;
}
export enum CoreEvent {
UserFeedback = 'user-feedback',

View File

@@ -932,7 +932,9 @@ included directory memory
path.join(extensionPath, 'CustomContext.md'),
);
expect(config.getGeminiMdFilePaths()).equals(refreshResult.filePaths);
expect(mockEventListener).toHaveBeenCalledExactlyOnceWith(refreshResult);
expect(mockEventListener).toHaveBeenCalledExactlyOnceWith({
fileCount: refreshResult.fileCount,
});
});
it('should include MCP instructions in user memory', async () => {

View File

@@ -577,7 +577,7 @@ export async function refreshServerHierarchicalMemory(config: Config) {
config.setUserMemory(finalMemory);
config.setGeminiMdFileCount(result.fileCount);
config.setGeminiMdFilePaths(result.filePaths);
coreEvents.emit(CoreEvent.MemoryChanged, result);
coreEvents.emit(CoreEvent.MemoryChanged, { fileCount: result.fileCount });
return result;
}