Fix mcp instructions (#16439)

This commit is contained in:
christine betts
2026-01-20 11:10:21 -05:00
committed by GitHub
parent 1182168bd9
commit 166e04a8dd
3 changed files with 78 additions and 21 deletions

View File

@@ -1278,6 +1278,24 @@ export class Config {
return this.userMemory; return this.userMemory;
} }
/**
* Refreshes the MCP context, including memory, tools, and system instructions.
*/
async refreshMcpContext(): Promise<void> {
if (this.experimentalJitContext && this.contextManager) {
await this.contextManager.refresh();
} else {
const { refreshServerHierarchicalMemory } = await import(
'../utils/memoryDiscovery.js'
);
await refreshServerHierarchicalMemory(this);
}
if (this.geminiClient?.isInitialized()) {
await this.geminiClient.setTools();
await this.geminiClient.updateSystemInstruction();
}
}
setUserMemory(newUserMemory: string): void { setUserMemory(newUserMemory: string): void {
this.userMemory = newUserMemory; this.userMemory = newUserMemory;
} }

View File

@@ -53,6 +53,7 @@ describe('McpClientManager', () => {
getGeminiClient: vi.fn().mockReturnValue({ getGeminiClient: vi.fn().mockReturnValue({
isInitialized: vi.fn(), isInitialized: vi.fn(),
}), }),
refreshMcpContext: vi.fn(),
} as unknown as Config); } as unknown as Config);
toolRegistry = {} as ToolRegistry; toolRegistry = {} as ToolRegistry;
}); });
@@ -69,6 +70,24 @@ describe('McpClientManager', () => {
await manager.startConfiguredMcpServers(); await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce();
});
it('should batch context refresh when starting multiple servers', async () => {
mockConfig.getMcpServers.mockReturnValue({
'server-1': {},
'server-2': {},
'server-3': {},
});
const manager = new McpClientManager(toolRegistry, mockConfig);
await manager.startConfiguredMcpServers();
// Each client should be connected/discovered
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(3);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(3);
// But context refresh should happen only once
expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce();
}); });
it('should update global discovery state', async () => { it('should update global discovery state', async () => {
@@ -239,14 +258,16 @@ describe('McpClientManager', () => {
const instructions = manager.getMcpInstructions(); const instructions = manager.getMcpInstructions();
expect(instructions).toContain( expect(instructions).toContain(
"# Instructions for MCP Server 'server-with-instructions'", "The following are instructions provided by the tool server 'server-with-instructions':",
); );
expect(instructions).toContain('---[start of server instructions]---');
expect(instructions).toContain( expect(instructions).toContain(
'Instructions for server-with-instructions', 'Instructions for server-with-instructions',
); );
expect(instructions).toContain('---[end of server instructions]---');
expect(instructions).not.toContain( expect(instructions).not.toContain(
"# Instructions for MCP Server 'server-without-instructions'", "The following are instructions provided by the tool server 'server-without-instructions':",
); );
}); });
}); });

View File

@@ -33,6 +33,7 @@ export class McpClientManager {
private discoveryPromise: Promise<void> | undefined; private discoveryPromise: Promise<void> | undefined;
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED; private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
private readonly eventEmitter?: EventEmitter; private readonly eventEmitter?: EventEmitter;
private pendingRefreshPromise: Promise<void> | null = null;
private readonly blockedMcpServers: Array<{ private readonly blockedMcpServers: Array<{
name: string; name: string;
extensionName: string; extensionName: string;
@@ -66,10 +67,11 @@ export class McpClientManager {
async stopExtension(extension: GeminiCLIExtension) { async stopExtension(extension: GeminiCLIExtension) {
debugLogger.log(`Unloading extension: ${extension.name}`); debugLogger.log(`Unloading extension: ${extension.name}`);
await Promise.all( await Promise.all(
Object.keys(extension.mcpServers ?? {}).map( Object.keys(extension.mcpServers ?? {}).map((name) =>
this.disconnectClient.bind(this), this.disconnectClient(name, true),
), ),
); );
await this.cliConfig.refreshMcpContext();
} }
/** /**
@@ -89,6 +91,7 @@ export class McpClientManager {
}), }),
), ),
); );
await this.cliConfig.refreshMcpContext();
} }
private isAllowedMcpServer(name: string) { private isAllowedMcpServer(name: string) {
@@ -111,7 +114,7 @@ export class McpClientManager {
return true; return true;
} }
private async disconnectClient(name: string) { private async disconnectClient(name: string, skipRefresh = false) {
const existing = this.clients.get(name); const existing = this.clients.get(name);
if (existing) { if (existing) {
try { try {
@@ -123,11 +126,10 @@ export class McpClientManager {
`Error stopping client '${name}': ${getErrorMessage(error)}`, `Error stopping client '${name}': ${getErrorMessage(error)}`,
); );
} finally { } finally {
// This is required to update the content generator configuration with the if (!skipRefresh) {
// new tool configuration. // This is required to update the content generator configuration with the
const geminiClient = this.cliConfig.getGeminiClient(); // new tool configuration and system instructions.
if (geminiClient.isInitialized()) { await this.cliConfig.refreshMcpContext();
await geminiClient.setTools();
} }
} }
} }
@@ -183,10 +185,7 @@ export class McpClientManager {
this.cliConfig.getDebugMode(), this.cliConfig.getDebugMode(),
async () => { async () => {
debugLogger.log('Tools changed, updating Gemini context...'); debugLogger.log('Tools changed, updating Gemini context...');
const geminiClient = this.cliConfig.getGeminiClient(); await this.scheduleMcpContextRefresh();
if (geminiClient.isInitialized()) {
await geminiClient.setTools();
}
}, },
); );
if (!existing) { if (!existing) {
@@ -219,12 +218,6 @@ export class McpClientManager {
error, error,
); );
} finally { } finally {
// This is required to update the content generator configuration with the
// new tool configuration.
const geminiClient = this.cliConfig.getGeminiClient();
if (geminiClient.isInitialized()) {
await geminiClient.setTools();
}
resolve(); resolve();
} }
})().catch(reject); })().catch(reject);
@@ -282,6 +275,7 @@ export class McpClientManager {
this.maybeDiscoverMcpServer(name, config), this.maybeDiscoverMcpServer(name, config),
), ),
); );
await this.cliConfig.refreshMcpContext();
} }
/** /**
@@ -303,6 +297,7 @@ export class McpClientManager {
} }
}), }),
); );
await this.cliConfig.refreshMcpContext();
} }
/** /**
@@ -314,6 +309,7 @@ export class McpClientManager {
throw new Error(`No MCP server registered with the name "${name}"`); throw new Error(`No MCP server registered with the name "${name}"`);
} }
await this.maybeDiscoverMcpServer(name, client.getServerConfig()); await this.maybeDiscoverMcpServer(name, client.getServerConfig());
await this.cliConfig.refreshMcpContext();
} }
/** /**
@@ -360,13 +356,35 @@ export class McpClientManager {
const clientInstructions = client.getInstructions(); const clientInstructions = client.getInstructions();
if (clientInstructions) { if (clientInstructions) {
instructions.push( instructions.push(
`# Instructions for MCP Server '${name}'\n${clientInstructions}`, `The following are instructions provided by the tool server '${name}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`,
); );
} }
} }
return instructions.join('\n\n'); return instructions.join('\n\n');
} }
private async scheduleMcpContextRefresh(): Promise<void> {
if (this.pendingRefreshPromise) {
return this.pendingRefreshPromise;
}
this.pendingRefreshPromise = (async () => {
// Debounce to coalesce multiple rapid updates
await new Promise((resolve) => setTimeout(resolve, 300));
try {
await this.cliConfig.refreshMcpContext();
} catch (error) {
debugLogger.error(
`Error refreshing MCP context: ${getErrorMessage(error)}`,
);
} finally {
this.pendingRefreshPromise = null;
}
})();
return this.pendingRefreshPromise;
}
getMcpServerCount(): number { getMcpServerCount(): number {
return this.clients.size; return this.clients.size;
} }