Fix mcp instructions (#16439)

This commit is contained in:
christine betts
2026-01-20 11:10:21 -05:00
committed by GitHub
parent 1182168bd9
commit 166e04a8dd
3 changed files with 78 additions and 21 deletions

View File

@@ -1278,6 +1278,24 @@ export class Config {
return this.userMemory;
}
/**
* Refreshes the MCP context, including memory, tools, and system instructions.
*/
async refreshMcpContext(): Promise<void> {
if (this.experimentalJitContext && this.contextManager) {
await this.contextManager.refresh();
} else {
const { refreshServerHierarchicalMemory } = await import(
'../utils/memoryDiscovery.js'
);
await refreshServerHierarchicalMemory(this);
}
if (this.geminiClient?.isInitialized()) {
await this.geminiClient.setTools();
await this.geminiClient.updateSystemInstruction();
}
}
setUserMemory(newUserMemory: string): void {
this.userMemory = newUserMemory;
}

View File

@@ -53,6 +53,7 @@ describe('McpClientManager', () => {
getGeminiClient: vi.fn().mockReturnValue({
isInitialized: vi.fn(),
}),
refreshMcpContext: vi.fn(),
} as unknown as Config);
toolRegistry = {} as ToolRegistry;
});
@@ -69,6 +70,24 @@ describe('McpClientManager', () => {
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce();
});
it('should batch context refresh when starting multiple servers', async () => {
mockConfig.getMcpServers.mockReturnValue({
'server-1': {},
'server-2': {},
'server-3': {},
});
const manager = new McpClientManager(toolRegistry, mockConfig);
await manager.startConfiguredMcpServers();
// Each client should be connected/discovered
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(3);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(3);
// But context refresh should happen only once
expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce();
});
it('should update global discovery state', async () => {
@@ -239,14 +258,16 @@ describe('McpClientManager', () => {
const instructions = manager.getMcpInstructions();
expect(instructions).toContain(
"# Instructions for MCP Server 'server-with-instructions'",
"The following are instructions provided by the tool server 'server-with-instructions':",
);
expect(instructions).toContain('---[start of server instructions]---');
expect(instructions).toContain(
'Instructions for server-with-instructions',
);
expect(instructions).toContain('---[end of server instructions]---');
expect(instructions).not.toContain(
"# Instructions for MCP Server 'server-without-instructions'",
"The following are instructions provided by the tool server 'server-without-instructions':",
);
});
});

View File

@@ -33,6 +33,7 @@ export class McpClientManager {
private discoveryPromise: Promise<void> | undefined;
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
private readonly eventEmitter?: EventEmitter;
private pendingRefreshPromise: Promise<void> | null = null;
private readonly blockedMcpServers: Array<{
name: string;
extensionName: string;
@@ -66,10 +67,11 @@ export class McpClientManager {
async stopExtension(extension: GeminiCLIExtension) {
debugLogger.log(`Unloading extension: ${extension.name}`);
await Promise.all(
Object.keys(extension.mcpServers ?? {}).map(
this.disconnectClient.bind(this),
Object.keys(extension.mcpServers ?? {}).map((name) =>
this.disconnectClient(name, true),
),
);
await this.cliConfig.refreshMcpContext();
}
/**
@@ -89,6 +91,7 @@ export class McpClientManager {
}),
),
);
await this.cliConfig.refreshMcpContext();
}
private isAllowedMcpServer(name: string) {
@@ -111,7 +114,7 @@ export class McpClientManager {
return true;
}
private async disconnectClient(name: string) {
private async disconnectClient(name: string, skipRefresh = false) {
const existing = this.clients.get(name);
if (existing) {
try {
@@ -123,11 +126,10 @@ export class McpClientManager {
`Error stopping client '${name}': ${getErrorMessage(error)}`,
);
} finally {
// This is required to update the content generator configuration with the
// new tool configuration.
const geminiClient = this.cliConfig.getGeminiClient();
if (geminiClient.isInitialized()) {
await geminiClient.setTools();
if (!skipRefresh) {
// This is required to update the content generator configuration with the
// new tool configuration and system instructions.
await this.cliConfig.refreshMcpContext();
}
}
}
@@ -183,10 +185,7 @@ export class McpClientManager {
this.cliConfig.getDebugMode(),
async () => {
debugLogger.log('Tools changed, updating Gemini context...');
const geminiClient = this.cliConfig.getGeminiClient();
if (geminiClient.isInitialized()) {
await geminiClient.setTools();
}
await this.scheduleMcpContextRefresh();
},
);
if (!existing) {
@@ -219,12 +218,6 @@ export class McpClientManager {
error,
);
} finally {
// This is required to update the content generator configuration with the
// new tool configuration.
const geminiClient = this.cliConfig.getGeminiClient();
if (geminiClient.isInitialized()) {
await geminiClient.setTools();
}
resolve();
}
})().catch(reject);
@@ -282,6 +275,7 @@ export class McpClientManager {
this.maybeDiscoverMcpServer(name, config),
),
);
await this.cliConfig.refreshMcpContext();
}
/**
@@ -303,6 +297,7 @@ export class McpClientManager {
}
}),
);
await this.cliConfig.refreshMcpContext();
}
/**
@@ -314,6 +309,7 @@ export class McpClientManager {
throw new Error(`No MCP server registered with the name "${name}"`);
}
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
await this.cliConfig.refreshMcpContext();
}
/**
@@ -360,13 +356,35 @@ export class McpClientManager {
const clientInstructions = client.getInstructions();
if (clientInstructions) {
instructions.push(
`# Instructions for MCP Server '${name}'\n${clientInstructions}`,
`The following are instructions provided by the tool server '${name}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`,
);
}
}
return instructions.join('\n\n');
}
private async scheduleMcpContextRefresh(): Promise<void> {
if (this.pendingRefreshPromise) {
return this.pendingRefreshPromise;
}
this.pendingRefreshPromise = (async () => {
// Debounce to coalesce multiple rapid updates
await new Promise((resolve) => setTimeout(resolve, 300));
try {
await this.cliConfig.refreshMcpContext();
} catch (error) {
debugLogger.error(
`Error refreshing MCP context: ${getErrorMessage(error)}`,
);
} finally {
this.pendingRefreshPromise = null;
}
})();
return this.pendingRefreshPromise;
}
getMcpServerCount(): number {
return this.clients.size;
}