Add support for running available commands prior to MCP servers loading (#15596)

This commit is contained in:
Adib234
2026-01-15 15:33:16 -05:00
committed by GitHub
parent 8a627d6c9a
commit 1e8f87fbdf
7 changed files with 230 additions and 8 deletions
@@ -28,6 +28,9 @@ import {
ToolConfirmationOutcome, ToolConfirmationOutcome,
Storage, Storage,
IdeClient, IdeClient,
addMCPStatusChangeListener,
removeMCPStatusChangeListener,
MCPDiscoveryState,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { useSessionStats } from '../contexts/SessionContext.js'; import { useSessionStats } from '../contexts/SessionContext.js';
import type { import type {
@@ -269,6 +272,10 @@ export const useSlashCommandProcessor = (
ideClient.addStatusChangeListener(listener); ideClient.addStatusChangeListener(listener);
})(); })();
// Listen for MCP server status changes (e.g. connection, discovery completion)
// to reload slash commands (since they may include MCP prompts).
addMCPStatusChangeListener(listener);
// TODO: Ideally this would happen more directly inside the ExtensionLoader, // TODO: Ideally this would happen more directly inside the ExtensionLoader,
// but the CommandService today is not conducive to that since it isn't a // but the CommandService today is not conducive to that since it isn't a
// long lived service but instead gets fully re-created based on reload // long lived service but instead gets fully re-created based on reload
@@ -289,6 +296,7 @@ export const useSlashCommandProcessor = (
const ideClient = await IdeClient.getInstance(); const ideClient = await IdeClient.getInstance();
ideClient.removeStatusChangeListener(listener); ideClient.removeStatusChangeListener(listener);
})(); })();
removeMCPStatusChangeListener(listener);
appEvents.off('extensionsStarting', extensionEventListener); appEvents.off('extensionsStarting', extensionEventListener);
appEvents.off('extensionsStopping', extensionEventListener); appEvents.off('extensionsStopping', extensionEventListener);
}; };
@@ -572,9 +580,16 @@ export const useSlashCommandProcessor = (
} }
} }
const isMcpLoading =
config?.getMcpClientManager()?.getDiscoveryState() ===
MCPDiscoveryState.IN_PROGRESS;
const errorMessage = isMcpLoading
? `Unknown command: ${trimmed}. Command might have been from an MCP server but MCP servers are not done loading.`
: `Unknown command: ${trimmed}`;
addMessage({ addMessage({
type: MessageType.ERROR, type: MessageType.ERROR,
content: `Unknown command: ${trimmed}`, content: errorMessage,
timestamp: new Date(), timestamp: new Date(),
}); });
@@ -36,6 +36,7 @@ import {
debugLogger, debugLogger,
coreEvents, coreEvents,
CoreEvent, CoreEvent,
MCPDiscoveryState,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import type { Part, PartListUnion } from '@google/genai'; import type { Part, PartListUnion } from '@google/genai';
import type { UseHistoryManagerReturn } from './useHistoryManager.js'; import type { UseHistoryManagerReturn } from './useHistoryManager.js';
@@ -178,6 +179,11 @@ describe('useGeminiStream', () => {
return clientInstance; return clientInstance;
}); });
const mockMcpClientManager = {
getDiscoveryState: vi.fn().mockReturnValue(MCPDiscoveryState.COMPLETED),
getMcpServerCount: vi.fn().mockReturnValue(0),
};
const contentGeneratorConfig = { const contentGeneratorConfig = {
model: 'test-model', model: 'test-model',
apiKey: 'test-key', apiKey: 'test-key',
@@ -211,6 +217,7 @@ describe('useGeminiStream', () => {
getProjectRoot: vi.fn(() => '/test/dir'), getProjectRoot: vi.fn(() => '/test/dir'),
getCheckpointingEnabled: vi.fn(() => false), getCheckpointingEnabled: vi.fn(() => false),
getGeminiClient: mockGetGeminiClient, getGeminiClient: mockGetGeminiClient,
getMcpClientManager: () => mockMcpClientManager as any,
getApprovalMode: () => ApprovalMode.DEFAULT, getApprovalMode: () => ApprovalMode.DEFAULT,
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,
@@ -254,6 +261,7 @@ describe('useGeminiStream', () => {
.mockClear() .mockClear()
.mockReturnValue((async function* () {})()); .mockReturnValue((async function* () {})());
handleAtCommandSpy = vi.spyOn(atCommandProcessor, 'handleAtCommand'); handleAtCommandSpy = vi.spyOn(atCommandProcessor, 'handleAtCommand');
vi.spyOn(coreEvents, 'emitFeedback');
}); });
const mockLoadedSettings: LoadedSettings = { const mockLoadedSettings: LoadedSettings = {
@@ -1954,6 +1962,73 @@ describe('useGeminiStream', () => {
}); });
}); });
describe('MCP Discovery State', () => {
it('should block non-slash command queries when discovery is in progress and servers exist', async () => {
const mockMcpClientManager = {
getDiscoveryState: vi
.fn()
.mockReturnValue(MCPDiscoveryState.IN_PROGRESS),
getMcpServerCount: vi.fn().mockReturnValue(1),
};
mockConfig.getMcpClientManager = () => mockMcpClientManager as any;
const { result } = renderTestHook();
await act(async () => {
await result.current.submitQuery('test query');
});
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
'info',
'Waiting for MCP servers to initialize... Slash commands are still available.',
);
expect(mockSendMessageStream).not.toHaveBeenCalled();
});
it('should NOT block queries when discovery is NOT_STARTED but there are no servers', async () => {
const mockMcpClientManager = {
getDiscoveryState: vi
.fn()
.mockReturnValue(MCPDiscoveryState.NOT_STARTED),
getMcpServerCount: vi.fn().mockReturnValue(0),
};
mockConfig.getMcpClientManager = () => mockMcpClientManager as any;
const { result } = renderTestHook();
await act(async () => {
await result.current.submitQuery('test query');
});
expect(coreEvents.emitFeedback).not.toHaveBeenCalledWith(
'info',
'Waiting for MCP servers to initialize... Slash commands are still available.',
);
expect(mockSendMessageStream).toHaveBeenCalled();
});
it('should NOT block slash commands even when discovery is in progress', async () => {
const mockMcpClientManager = {
getDiscoveryState: vi
.fn()
.mockReturnValue(MCPDiscoveryState.IN_PROGRESS),
getMcpServerCount: vi.fn().mockReturnValue(1),
};
mockConfig.getMcpClientManager = () => mockMcpClientManager as any;
const { result } = renderTestHook();
await act(async () => {
await result.current.submitQuery('/help');
});
expect(coreEvents.emitFeedback).not.toHaveBeenCalledWith(
'info',
'Waiting for MCP servers to initialize... Slash commands are still available.',
);
});
});
describe('handleFinishedEvent', () => { describe('handleFinishedEvent', () => {
it('should add info message for MAX_TOKENS finish reason', async () => { it('should add info message for MAX_TOKENS finish reason', async () => {
// Setup mock to return a stream with MAX_TOKENS finish reason // Setup mock to return a stream with MAX_TOKENS finish reason
@@ -3015,4 +3090,68 @@ describe('useGeminiStream', () => {
}); });
}); });
}); });
describe('MCP Server Initialization', () => {
it('should allow slash commands to run while MCP servers are initializing', async () => {
const mockMcpClientManager = {
getDiscoveryState: vi
.fn()
.mockReturnValue(MCPDiscoveryState.IN_PROGRESS),
getMcpServerCount: vi.fn().mockReturnValue(1),
};
mockConfig.getMcpClientManager = () => mockMcpClientManager as any;
const { result } = renderTestHook();
await act(async () => {
await result.current.submitQuery('/help');
});
// Slash command should be handled, and no Gemini call should be made.
expect(mockHandleSlashCommand).toHaveBeenCalledWith('/help');
expect(coreEvents.emitFeedback).not.toHaveBeenCalled();
});
it('should block normal prompts and provide feedback while MCP servers are initializing', async () => {
const mockMcpClientManager = {
getDiscoveryState: vi
.fn()
.mockReturnValue(MCPDiscoveryState.IN_PROGRESS),
getMcpServerCount: vi.fn().mockReturnValue(1),
};
mockConfig.getMcpClientManager = () => mockMcpClientManager as any;
const { result } = renderTestHook();
await act(async () => {
await result.current.submitQuery('a normal prompt');
});
// No slash command, no Gemini call, but feedback should be emitted.
expect(mockHandleSlashCommand).not.toHaveBeenCalled();
expect(mockSendMessageStream).not.toHaveBeenCalled();
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
'info',
'Waiting for MCP servers to initialize... Slash commands are still available.',
);
});
it('should allow normal prompts to run when MCP servers are finished initializing', async () => {
const mockMcpClientManager = {
getDiscoveryState: vi.fn().mockReturnValue(MCPDiscoveryState.COMPLETED),
getMcpServerCount: vi.fn().mockReturnValue(1),
};
mockConfig.getMcpClientManager = () => mockMcpClientManager as any;
const { result } = renderTestHook();
await act(async () => {
await result.current.submitQuery('a normal prompt');
});
// Prompt should be sent to Gemini.
expect(mockHandleSlashCommand).not.toHaveBeenCalled();
expect(mockSendMessageStream).toHaveBeenCalled();
expect(coreEvents.emitFeedback).not.toHaveBeenCalled();
});
});
}); });
@@ -30,6 +30,7 @@ import {
ToolErrorType, ToolErrorType,
coreEvents, coreEvents,
CoreEvent, CoreEvent,
MCPDiscoveryState,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import type { import type {
Config, Config,
@@ -951,6 +952,26 @@ export const useGeminiStream = (
{ name: 'submitQuery' }, { name: 'submitQuery' },
async ({ metadata: spanMetadata }) => { async ({ metadata: spanMetadata }) => {
spanMetadata.input = query; spanMetadata.input = query;
const discoveryState = config
.getMcpClientManager()
?.getDiscoveryState();
const mcpServerCount =
config.getMcpClientManager()?.getMcpServerCount() ?? 0;
if (
!options?.isContinuation &&
typeof query === 'string' &&
!isSlashCommand(query.trim()) &&
mcpServerCount > 0 &&
discoveryState !== MCPDiscoveryState.COMPLETED
) {
coreEvents.emitFeedback(
'info',
'Waiting for MCP servers to initialize... Slash commands are still available.',
);
return;
}
const queryId = `${Date.now()}-${Math.random()}`; const queryId = `${Date.now()}-${Math.random()}`;
activeQueryIdRef.current = queryId; activeQueryIdRef.current = queryId;
if ( if (
+29
View File
@@ -274,6 +274,35 @@ describe('Server Config (config.ts)', () => {
); );
}); });
it('should not await MCP initialization', async () => {
const config = new Config({
...baseParams,
checkpointing: false,
});
const { McpClientManager } = await import(
'../tools/mcp-client-manager.js'
);
let mcpStarted = false;
(McpClientManager as unknown as Mock).mockImplementation(() => ({
startConfiguredMcpServers: vi.fn().mockImplementation(async () => {
await new Promise((resolve) => setTimeout(resolve, 50));
mcpStarted = true;
}),
getMcpInstructions: vi.fn(),
}));
await config.initialize();
// Should return immediately, before MCP finishes (50ms delay)
expect(mcpStarted).toBe(false);
// Wait for it to eventually finish to avoid open handles
await new Promise((resolve) => setTimeout(resolve, 60));
expect(mcpStarted).toBe(true);
});
describe('getCompressionThreshold', () => { describe('getCompressionThreshold', () => {
it('should return the local compression threshold if it is set', async () => { it('should return the local compression threshold if it is set', async () => {
const config = new Config({ const config = new Config({
+8 -6
View File
@@ -796,12 +796,14 @@ export class Config {
this, this,
this.eventEmitter, this.eventEmitter,
); );
const initMcpHandle = startupProfiler.start('initialize_mcp_clients'); // We do not await this promise so that the CLI can start up even if
await Promise.all([ // MCP servers are slow to connect.
await this.mcpClientManager.startConfiguredMcpServers(), Promise.all([
await this.getExtensionLoader().start(this), this.mcpClientManager.startConfiguredMcpServers(),
]); this.getExtensionLoader().start(this),
initMcpHandle?.end(); ]).catch((error) => {
debugLogger.error('Error initializing MCP clients:', error);
});
if (this.skillsSupport) { if (this.skillsSupport) {
this.getSkillManager().setAdminSettings(this.adminSkillsEnabled); this.getSkillManager().setAdminSettings(this.adminSkillsEnabled);
@@ -14,7 +14,7 @@ import {
type MockedObject, type MockedObject,
} from 'vitest'; } from 'vitest';
import { McpClientManager } from './mcp-client-manager.js'; import { McpClientManager } from './mcp-client-manager.js';
import { McpClient } from './mcp-client.js'; import { McpClient, MCPDiscoveryState } from './mcp-client.js';
import type { ToolRegistry } from './tool-registry.js'; import type { ToolRegistry } from './tool-registry.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
@@ -71,6 +71,18 @@ describe('McpClientManager', () => {
expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
}); });
it('should update global discovery state', async () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': {},
});
const manager = new McpClientManager(toolRegistry, mockConfig);
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.NOT_STARTED);
const promise = manager.startConfiguredMcpServers();
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS);
await promise;
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED);
});
it('should not discover tools if folder is not trusted', async () => { it('should not discover tools if folder is not trusted', async () => {
mockConfig.getMcpServers.mockReturnValue({ mockConfig.getMcpServers.mockReturnValue({
'test-server': {}, 'test-server': {},
@@ -362,4 +362,8 @@ export class McpClientManager {
} }
return instructions.join('\n\n'); return instructions.join('\n\n');
} }
getMcpServerCount(): number {
return this.clients.size;
}
} }