diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 546ac455b6..097146276e 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -1278,6 +1278,24 @@ export class Config { return this.userMemory; } + /** + * Refreshes the MCP context, including memory, tools, and system instructions. + */ + async refreshMcpContext(): Promise { + if (this.experimentalJitContext && this.contextManager) { + await this.contextManager.refresh(); + } else { + const { refreshServerHierarchicalMemory } = await import( + '../utils/memoryDiscovery.js' + ); + await refreshServerHierarchicalMemory(this); + } + if (this.geminiClient?.isInitialized()) { + await this.geminiClient.setTools(); + await this.geminiClient.updateSystemInstruction(); + } + } + setUserMemory(newUserMemory: string): void { this.userMemory = newUserMemory; } diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts index ba035d54f1..27c6984a7c 100644 --- a/packages/core/src/tools/mcp-client-manager.test.ts +++ b/packages/core/src/tools/mcp-client-manager.test.ts @@ -53,6 +53,7 @@ describe('McpClientManager', () => { getGeminiClient: vi.fn().mockReturnValue({ isInitialized: vi.fn(), }), + refreshMcpContext: vi.fn(), } as unknown as Config); toolRegistry = {} as ToolRegistry; }); @@ -69,6 +70,24 @@ describe('McpClientManager', () => { await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce(); + }); + + it('should batch context refresh when starting multiple servers', async () => { + mockConfig.getMcpServers.mockReturnValue({ + 'server-1': {}, + 'server-2': {}, + 'server-3': {}, + }); + const manager = new McpClientManager(toolRegistry, mockConfig); + await manager.startConfiguredMcpServers(); + + // Each client should be connected/discovered + expect(mockedMcpClient.connect).toHaveBeenCalledTimes(3); + expect(mockedMcpClient.discover).toHaveBeenCalledTimes(3); + + // But context refresh should happen only once + expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce(); }); it('should update global discovery state', async () => { @@ -239,14 +258,16 @@ describe('McpClientManager', () => { const instructions = manager.getMcpInstructions(); expect(instructions).toContain( - "# Instructions for MCP Server 'server-with-instructions'", + "The following are instructions provided by the tool server 'server-with-instructions':", ); + expect(instructions).toContain('---[start of server instructions]---'); expect(instructions).toContain( 'Instructions for server-with-instructions', ); + expect(instructions).toContain('---[end of server instructions]---'); expect(instructions).not.toContain( - "# Instructions for MCP Server 'server-without-instructions'", + "The following are instructions provided by the tool server 'server-without-instructions':", ); }); }); diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index cc4602334c..a4619756f0 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -33,6 +33,7 @@ export class McpClientManager { private discoveryPromise: Promise | undefined; private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED; private readonly eventEmitter?: EventEmitter; + private pendingRefreshPromise: Promise | null = null; private readonly blockedMcpServers: Array<{ name: string; extensionName: string; @@ -66,10 +67,11 @@ export class McpClientManager { async stopExtension(extension: GeminiCLIExtension) { debugLogger.log(`Unloading extension: ${extension.name}`); await Promise.all( - Object.keys(extension.mcpServers ?? {}).map( - this.disconnectClient.bind(this), + Object.keys(extension.mcpServers ?? {}).map((name) => + this.disconnectClient(name, true), ), ); + await this.cliConfig.refreshMcpContext(); } /** @@ -89,6 +91,7 @@ export class McpClientManager { }), ), ); + await this.cliConfig.refreshMcpContext(); } private isAllowedMcpServer(name: string) { @@ -111,7 +114,7 @@ export class McpClientManager { return true; } - private async disconnectClient(name: string) { + private async disconnectClient(name: string, skipRefresh = false) { const existing = this.clients.get(name); if (existing) { try { @@ -123,11 +126,10 @@ export class McpClientManager { `Error stopping client '${name}': ${getErrorMessage(error)}`, ); } finally { - // This is required to update the content generator configuration with the - // new tool configuration. - const geminiClient = this.cliConfig.getGeminiClient(); - if (geminiClient.isInitialized()) { - await geminiClient.setTools(); + if (!skipRefresh) { + // This is required to update the content generator configuration with the + // new tool configuration and system instructions. + await this.cliConfig.refreshMcpContext(); } } } @@ -183,10 +185,7 @@ export class McpClientManager { this.cliConfig.getDebugMode(), async () => { debugLogger.log('Tools changed, updating Gemini context...'); - const geminiClient = this.cliConfig.getGeminiClient(); - if (geminiClient.isInitialized()) { - await geminiClient.setTools(); - } + await this.scheduleMcpContextRefresh(); }, ); if (!existing) { @@ -219,12 +218,6 @@ export class McpClientManager { error, ); } finally { - // This is required to update the content generator configuration with the - // new tool configuration. - const geminiClient = this.cliConfig.getGeminiClient(); - if (geminiClient.isInitialized()) { - await geminiClient.setTools(); - } resolve(); } })().catch(reject); @@ -282,6 +275,7 @@ export class McpClientManager { this.maybeDiscoverMcpServer(name, config), ), ); + await this.cliConfig.refreshMcpContext(); } /** @@ -303,6 +297,7 @@ export class McpClientManager { } }), ); + await this.cliConfig.refreshMcpContext(); } /** @@ -314,6 +309,7 @@ export class McpClientManager { throw new Error(`No MCP server registered with the name "${name}"`); } await this.maybeDiscoverMcpServer(name, client.getServerConfig()); + await this.cliConfig.refreshMcpContext(); } /** @@ -360,13 +356,35 @@ export class McpClientManager { const clientInstructions = client.getInstructions(); if (clientInstructions) { instructions.push( - `# Instructions for MCP Server '${name}'\n${clientInstructions}`, + `The following are instructions provided by the tool server '${name}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`, ); } } return instructions.join('\n\n'); } + private async scheduleMcpContextRefresh(): Promise { + if (this.pendingRefreshPromise) { + return this.pendingRefreshPromise; + } + + this.pendingRefreshPromise = (async () => { + // Debounce to coalesce multiple rapid updates + await new Promise((resolve) => setTimeout(resolve, 300)); + try { + await this.cliConfig.refreshMcpContext(); + } catch (error) { + debugLogger.error( + `Error refreshing MCP context: ${getErrorMessage(error)}`, + ); + } finally { + this.pendingRefreshPromise = null; + } + })(); + + return this.pendingRefreshPromise; + } + getMcpServerCount(): number { return this.clients.size; }