fix(policy): resolve regressions and refine sensitive tool plumbing

This commit is contained in:
Spencer
2026-02-27 23:06:07 +00:00
parent 903c9f79b3
commit 99bc45b689
3 changed files with 419 additions and 233 deletions
+190 -90
View File
@@ -33,6 +33,7 @@ import {
isEnabled, isEnabled,
McpClient, McpClient,
populateMcpServerCommand, populateMcpServerCommand,
type McpContext,
} from './mcp-client.js'; } from './mcp-client.js';
import type { ToolRegistry } from './tool-registry.js'; import type { ToolRegistry } from './tool-registry.js';
import type { ResourceRegistry } from '../resources/resource-registry.js'; import type { ResourceRegistry } from '../resources/resource-registry.js';
@@ -42,12 +43,28 @@ import * as path from 'node:path';
import { coreEvents } from '../utils/events.js'; import { coreEvents } from '../utils/events.js';
import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js'; import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js';
interface TestableTransport {
_authProvider?: GoogleCredentialProvider;
_requestInit?: {
headers?: Record<string, string>;
};
}
const EMPTY_CONFIG: EnvironmentSanitizationConfig = { const EMPTY_CONFIG: EnvironmentSanitizationConfig = {
enableEnvironmentVariableRedaction: true, enableEnvironmentVariableRedaction: true,
allowedEnvironmentVariables: [], allowedEnvironmentVariables: [],
blockedEnvironmentVariables: [], blockedEnvironmentVariables: [],
}; };
const MOCK_CONTEXT_DEFAULT = {
sanitizationConfig: EMPTY_CONFIG,
emitMcpDiagnostic: vi.fn(),
setUserInteractedWithMcp: vi.fn(),
isTrustedFolder: vi.fn().mockReturnValue(true),
};
let MOCK_CONTEXT: McpContext = MOCK_CONTEXT_DEFAULT;
vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js');
vi.mock('@google/genai'); vi.mock('@google/genai');
@@ -69,6 +86,12 @@ describe('mcp-client', () => {
let testWorkspace: string; let testWorkspace: string;
beforeEach(() => { beforeEach(() => {
MOCK_CONTEXT = {
sanitizationConfig: EMPTY_CONFIG,
emitMcpDiagnostic: vi.fn(),
setUserInteractedWithMcp: vi.fn(),
isTrustedFolder: vi.fn().mockReturnValue(true),
};
// create a tmp dir for this test // create a tmp dir for this test
// Create a unique temporary directory for the workspace to avoid conflicts // Create a unique temporary directory for the workspace to avoid conflicts
testWorkspace = fs.mkdtempSync( testWorkspace = fs.mkdtempSync(
@@ -136,12 +159,12 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
await client.connect(); await client.connect();
await client.discover({} as Config); await client.discover(MOCK_CONTEXT);
expect(mockedClient.listTools).toHaveBeenCalledWith( expect(mockedClient.listTools).toHaveBeenCalledWith(
{}, {},
expect.objectContaining({ timeout: 600000, progressReporter: client }), expect.objectContaining({ timeout: 600000, progressReporter: client }),
@@ -217,12 +240,12 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
await client.connect(); await client.connect();
await client.discover({} as Config); await client.discover(MOCK_CONTEXT);
expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2); expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2);
expect(consoleWarnSpy).not.toHaveBeenCalled(); expect(consoleWarnSpy).not.toHaveBeenCalled();
consoleWarnSpy.mockRestore(); consoleWarnSpy.mockRestore();
@@ -269,16 +292,17 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
await client.connect(); await client.connect();
await expect(client.discover({} as Config)).rejects.toThrow('Test error'); await expect(client.discover(MOCK_CONTEXT)).rejects.toThrow('Test error');
expect(coreEvents.emitFeedback).toHaveBeenCalledWith( expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith(
'error', 'error',
`Error discovering prompts from test-server: Test error`, `Error discovering prompts from test-server: Test error`,
expect.any(Error), expect.any(Error),
'test-server',
); );
}); });
@@ -323,12 +347,12 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
await client.connect(); await client.connect();
await expect(client.discover({} as Config)).rejects.toThrow( await expect(client.discover(MOCK_CONTEXT)).rejects.toThrow(
'No prompts, tools, or resources found on the server.', 'No prompts, tools, or resources found on the server.',
); );
}); });
@@ -383,12 +407,12 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
await client.connect(); await client.connect();
await client.discover({} as Config); await client.discover(MOCK_CONTEXT);
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
}); });
@@ -451,7 +475,7 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
@@ -532,7 +556,7 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
@@ -610,7 +634,7 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
@@ -696,12 +720,12 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
await client.connect(); await client.connect();
await client.discover({} as Config); await client.discover(MOCK_CONTEXT);
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock
.calls[0][0]; .calls[0][0];
@@ -773,12 +797,12 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
await client.connect(); await client.connect();
await client.discover({} as Config); await client.discover(MOCK_CONTEXT);
expect(resourceRegistry.setResourcesForServer).toHaveBeenCalledWith( expect(resourceRegistry.setResourcesForServer).toHaveBeenCalledWith(
'test-server', 'test-server',
[ [
@@ -859,12 +883,12 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
await client.connect(); await client.connect();
await client.discover({} as Config); await client.discover(MOCK_CONTEXT);
expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2); expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2);
expect(resourceListHandler).toBeDefined(); expect(resourceListHandler).toBeDefined();
@@ -878,9 +902,11 @@ describe('mcp-client', () => {
[expect.objectContaining({ uri: 'file:///tmp/two.txt' })], [expect.objectContaining({ uri: 'file:///tmp/two.txt' })],
); );
expect(coreEvents.emitFeedback).toHaveBeenCalledWith( expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith(
'info', 'info',
'Resources updated for server: test-server', 'Resources updated for server: test-server',
undefined,
'test-server',
); );
}); });
@@ -943,12 +969,12 @@ describe('mcp-client', () => {
promptRegistry, promptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
await client.connect(); await client.connect();
await client.discover({ sanitizationConfig: EMPTY_CONFIG } as Config); await client.discover(MOCK_CONTEXT);
expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2); expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2);
expect(promptListHandler).toBeDefined(); expect(promptListHandler).toBeDefined();
@@ -963,9 +989,11 @@ describe('mcp-client', () => {
expect(promptRegistry.registerPrompt).toHaveBeenLastCalledWith( expect(promptRegistry.registerPrompt).toHaveBeenLastCalledWith(
expect.objectContaining({ name: 'two' }), expect.objectContaining({ name: 'two' }),
); );
expect(coreEvents.emitFeedback).toHaveBeenCalledWith( expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith(
'info', 'info',
'Prompts updated for server: test-server', 'Prompts updated for server: test-server',
undefined,
'test-server',
); );
}); });
@@ -1025,12 +1053,12 @@ describe('mcp-client', () => {
mockedPromptRegistry, mockedPromptRegistry,
resourceRegistry, resourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
await client.connect(); await client.connect();
await client.discover({} as Config); await client.discover(MOCK_CONTEXT);
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce(); expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce();
@@ -1075,7 +1103,7 @@ describe('mcp-client', () => {
{} as PromptRegistry, {} as PromptRegistry,
{} as ResourceRegistry, {} as ResourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
@@ -1112,7 +1140,7 @@ describe('mcp-client', () => {
{} as PromptRegistry, {} as PromptRegistry,
{} as ResourceRegistry, {} as ResourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
@@ -1168,7 +1196,7 @@ describe('mcp-client', () => {
{} as PromptRegistry, {} as PromptRegistry,
{} as ResourceRegistry, {} as ResourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
onToolsUpdatedSpy, onToolsUpdatedSpy,
@@ -1200,9 +1228,11 @@ describe('mcp-client', () => {
expect(onToolsUpdatedSpy).toHaveBeenCalled(); expect(onToolsUpdatedSpy).toHaveBeenCalled();
// It should emit feedback event // It should emit feedback event
expect(coreEvents.emitFeedback).toHaveBeenCalledWith( expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith(
'info', 'info',
'Tools updated for server: test-server', 'Tools updated for server: test-server',
undefined,
'test-server',
); );
}); });
@@ -1239,7 +1269,7 @@ describe('mcp-client', () => {
{} as PromptRegistry, {} as PromptRegistry,
{} as ResourceRegistry, {} as ResourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
@@ -1310,7 +1340,7 @@ describe('mcp-client', () => {
{} as PromptRegistry, {} as PromptRegistry,
{} as ResourceRegistry, {} as ResourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
onToolsUpdatedSpy, onToolsUpdatedSpy,
@@ -1323,7 +1353,7 @@ describe('mcp-client', () => {
{} as PromptRegistry, {} as PromptRegistry,
{} as ResourceRegistry, {} as ResourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
onToolsUpdatedSpy, onToolsUpdatedSpy,
@@ -1409,7 +1439,7 @@ describe('mcp-client', () => {
{} as PromptRegistry, {} as PromptRegistry,
{} as ResourceRegistry, {} as ResourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
); );
@@ -1474,7 +1504,7 @@ describe('mcp-client', () => {
{} as PromptRegistry, {} as PromptRegistry,
{} as ResourceRegistry, {} as ResourceRegistry,
workspaceContext, workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
onToolsUpdatedSpy, onToolsUpdatedSpy,
@@ -1526,7 +1556,7 @@ describe('mcp-client', () => {
httpUrl: 'http://test-server', httpUrl: 'http://test-server',
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
@@ -1544,7 +1574,7 @@ describe('mcp-client', () => {
headers: { Authorization: 'derp' }, headers: { Authorization: 'derp' },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
@@ -1565,7 +1595,7 @@ describe('mcp-client', () => {
url: 'http://test-server', url: 'http://test-server',
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
expect(transport).toMatchObject({ expect(transport).toMatchObject({
@@ -1582,7 +1612,7 @@ describe('mcp-client', () => {
headers: { Authorization: 'derp' }, headers: { Authorization: 'derp' },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
@@ -1602,7 +1632,7 @@ describe('mcp-client', () => {
type: 'http', type: 'http',
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
@@ -1620,7 +1650,7 @@ describe('mcp-client', () => {
type: 'sse', type: 'sse',
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(SSEClientTransport); expect(transport).toBeInstanceOf(SSEClientTransport);
@@ -1637,7 +1667,7 @@ describe('mcp-client', () => {
url: 'http://test-server', url: 'http://test-server',
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
@@ -1656,7 +1686,7 @@ describe('mcp-client', () => {
headers: { Authorization: 'Bearer token' }, headers: { Authorization: 'Bearer token' },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
@@ -1677,7 +1707,7 @@ describe('mcp-client', () => {
headers: { 'X-API-Key': 'key123' }, headers: { 'X-API-Key': 'key123' },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(SSEClientTransport); expect(transport).toBeInstanceOf(SSEClientTransport);
@@ -1697,7 +1727,7 @@ describe('mcp-client', () => {
url: 'http://test-server-url', url: 'http://test-server-url',
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
// httpUrl should take priority and create HTTP transport // httpUrl should take priority and create HTTP transport
@@ -1723,7 +1753,7 @@ describe('mcp-client', () => {
cwd: 'test/cwd', cwd: 'test/cwd',
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(mockedTransport).toHaveBeenCalledWith({ expect(mockedTransport).toHaveBeenCalledWith({
@@ -1749,7 +1779,7 @@ describe('mcp-client', () => {
cwd: 'test/cwd', cwd: 'test/cwd',
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
const callArgs = mockedTransport.mock.calls[0][0]; const callArgs = mockedTransport.mock.calls[0][0];
@@ -1784,7 +1814,7 @@ describe('mcp-client', () => {
}, },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
const callArgs = mockedTransport.mock.calls[0][0]; const callArgs = mockedTransport.mock.calls[0][0];
@@ -1792,6 +1822,80 @@ describe('mcp-client', () => {
expect(callArgs.env!['GEMINI_CLI_EXT_VAR']).toBeUndefined(); expect(callArgs.env!['GEMINI_CLI_EXT_VAR']).toBeUndefined();
}); });
it('should include extension settings with defined values in environment', async () => {
const mockedTransport = vi
.spyOn(SdkClientStdioLib, 'StdioClientTransport')
.mockReturnValue({} as SdkClientStdioLib.StdioClientTransport);
await createTransport(
'test-server',
{
command: 'test-command',
extension: {
name: 'test-ext',
resolvedSettings: [
{
envVar: 'GEMINI_CLI_EXT_VAR',
value: 'defined-value',
sensitive: false,
name: 'ext-setting',
},
],
version: '',
isActive: false,
path: '',
contextFiles: [],
id: '',
},
},
false,
MOCK_CONTEXT,
);
const callArgs = mockedTransport.mock.calls[0][0];
expect(callArgs.env).toBeDefined();
expect(callArgs.env!['GEMINI_CLI_EXT_VAR']).toBe('defined-value');
});
it('should resolve environment variables in mcpServerConfig.env using extension settings', async () => {
const mockedTransport = vi
.spyOn(SdkClientStdioLib, 'StdioClientTransport')
.mockReturnValue({} as SdkClientStdioLib.StdioClientTransport);
await createTransport(
'test-server',
{
command: 'test-command',
env: {
RESOLVED_VAR: '$GEMINI_CLI_EXT_VAR',
},
extension: {
name: 'test-ext',
resolvedSettings: [
{
envVar: 'GEMINI_CLI_EXT_VAR',
value: 'ext-value',
sensitive: false,
name: 'ext-setting',
},
],
version: '',
isActive: false,
path: '',
contextFiles: [],
id: '',
},
},
false,
MOCK_CONTEXT,
);
const callArgs = mockedTransport.mock.calls[0][0];
expect(callArgs.env).toBeDefined();
expect(callArgs.env!['GEMINI_CLI_EXT_VAR']).toBe('ext-value');
expect(callArgs.env!['RESOLVED_VAR']).toBe('ext-value');
});
it('should expand environment variables in mcpServerConfig.env and not redact them', async () => { it('should expand environment variables in mcpServerConfig.env and not redact them', async () => {
const mockedTransport = vi const mockedTransport = vi
.spyOn(SdkClientStdioLib, 'StdioClientTransport') .spyOn(SdkClientStdioLib, 'StdioClientTransport')
@@ -1814,7 +1918,7 @@ describe('mcp-client', () => {
}, },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
const callArgs = mockedTransport.mock.calls[0][0]; const callArgs = mockedTransport.mock.calls[0][0];
@@ -1851,17 +1955,15 @@ describe('mcp-client', () => {
}, },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any const testableTransport = transport as unknown as TestableTransport;
const authProvider = (transport as any)._authProvider; const authProvider = testableTransport._authProvider;
expect(authProvider).toBeInstanceOf(GoogleCredentialProvider); expect(authProvider).toBeInstanceOf(GoogleCredentialProvider);
// eslint-disable-next-line @typescript-eslint/no-explicit-any const googUserProject =
const googUserProject = (transport as any)._requestInit?.headers?.[ testableTransport._requestInit?.headers?.['X-Goog-User-Project'];
'X-Goog-User-Project'
];
expect(googUserProject).toBe('myproject'); expect(googUserProject).toBe('myproject');
}); });
@@ -1884,14 +1986,14 @@ describe('mcp-client', () => {
}, },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
expect(mockGetRequestHeaders).toHaveBeenCalled(); expect(mockGetRequestHeaders).toHaveBeenCalled();
// eslint-disable-next-line @typescript-eslint/no-explicit-any const testableTransport = transport as unknown as TestableTransport;
const headers = (transport as any)._requestInit?.headers; const headers = testableTransport._requestInit?.headers;
expect(headers['X-Goog-User-Project']).toBe('provider-project'); expect(headers?.['X-Goog-User-Project']).toBe('provider-project');
}); });
it('should prioritize provider headers over config headers', async () => { it('should prioritize provider headers over config headers', async () => {
@@ -1916,13 +2018,13 @@ describe('mcp-client', () => {
}, },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any const testableTransport = transport as unknown as TestableTransport;
const headers = (transport as any)._requestInit?.headers; const headers = testableTransport._requestInit?.headers;
expect(headers['X-Goog-User-Project']).toBe('provider-project'); expect(headers?.['X-Goog-User-Project']).toBe('provider-project');
}); });
it('should use GoogleCredentialProvider with SSE transport', async () => { it('should use GoogleCredentialProvider with SSE transport', async () => {
@@ -1937,12 +2039,12 @@ describe('mcp-client', () => {
}, },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(SSEClientTransport); expect(transport).toBeInstanceOf(SSEClientTransport);
// eslint-disable-next-line @typescript-eslint/no-explicit-any const testableTransport = transport as unknown as TestableTransport;
const authProvider = (transport as any)._authProvider; const authProvider = testableTransport._authProvider;
expect(authProvider).toBeInstanceOf(GoogleCredentialProvider); expect(authProvider).toBeInstanceOf(GoogleCredentialProvider);
}); });
@@ -1957,7 +2059,7 @@ describe('mcp-client', () => {
}, },
}, },
false, false,
EMPTY_CONFIG, MOCK_CONTEXT,
), ),
).rejects.toThrow( ).rejects.toThrow(
'URL must be provided in the config for Google Credentials provider', 'URL must be provided in the config for Google Credentials provider',
@@ -2104,12 +2206,11 @@ describe('connectToMcpServer with OAuth', () => {
scopes: ['test-scope'], scopes: ['test-scope'],
}); });
// We need this to be an any type because we dig into its private state. // We need this to be typed to dig into its private state.
// eslint-disable-next-line @typescript-eslint/no-explicit-any let capturedTransport: TestableTransport | undefined;
let capturedTransport: any;
vi.mocked(mockedClient.connect).mockImplementationOnce( vi.mocked(mockedClient.connect).mockImplementationOnce(
async (transport) => { async (transport) => {
capturedTransport = transport; capturedTransport = transport as unknown as TestableTransport;
return Promise.resolve(); return Promise.resolve();
}, },
); );
@@ -2120,15 +2221,15 @@ describe('connectToMcpServer with OAuth', () => {
{ httpUrl: serverUrl, oauth: { enabled: true } }, { httpUrl: serverUrl, oauth: { enabled: true } },
false, false,
workspaceContext, workspaceContext,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(client).toBe(mockedClient); expect(client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
const authHeader = const authHeader = (capturedTransport as TestableTransport)._requestInit
capturedTransport._requestInit?.headers?.['Authorization']; ?.headers?.['Authorization'];
expect(authHeader).toBe('Bearer test-access-token'); expect(authHeader).toBe('Bearer test-access-token');
}); });
@@ -2150,12 +2251,11 @@ describe('connectToMcpServer with OAuth', () => {
'test-access-token-from-discovery', 'test-access-token-from-discovery',
); );
// We need this to be an any type because we dig into its private state. // We need this to be typed to dig into its private state.
// eslint-disable-next-line @typescript-eslint/no-explicit-any let capturedTransport: TestableTransport | undefined;
let capturedTransport: any;
vi.mocked(mockedClient.connect).mockImplementationOnce( vi.mocked(mockedClient.connect).mockImplementationOnce(
async (transport) => { async (transport) => {
capturedTransport = transport; capturedTransport = transport as unknown as TestableTransport;
return Promise.resolve(); return Promise.resolve();
}, },
); );
@@ -2166,7 +2266,7 @@ describe('connectToMcpServer with OAuth', () => {
{ httpUrl: serverUrl, oauth: { enabled: true } }, { httpUrl: serverUrl, oauth: { enabled: true } },
false, false,
workspaceContext, workspaceContext,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(client).toBe(mockedClient); expect(client).toBe(mockedClient);
@@ -2174,8 +2274,8 @@ describe('connectToMcpServer with OAuth', () => {
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl); expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
const authHeader = const authHeader = (capturedTransport as TestableTransport)._requestInit
capturedTransport._requestInit?.headers?.['Authorization']; ?.headers?.['Authorization'];
expect(authHeader).toBe('Bearer test-access-token-from-discovery'); expect(authHeader).toBe('Bearer test-access-token-from-discovery');
}); });
@@ -2206,7 +2306,7 @@ describe('connectToMcpServer with OAuth', () => {
{ httpUrl: serverUrl, oauth: { enabled: true } }, { httpUrl: serverUrl, oauth: { enabled: true } },
false, false,
workspaceContext, workspaceContext,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(client).toBe(mockedClient); expect(client).toBe(mockedClient);
@@ -2250,7 +2350,7 @@ describe('connectToMcpServer with OAuth', () => {
{ httpUrl: serverUrl, oauth: { enabled: true } }, { httpUrl: serverUrl, oauth: { enabled: true } },
false, false,
workspaceContext, workspaceContext,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(client).toBe(mockedClient); expect(client).toBe(mockedClient);
@@ -2306,7 +2406,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
{ url: 'http://test-server', type: 'http' }, { url: 'http://test-server', type: 'http' },
false, false,
workspaceContext, workspaceContext,
EMPTY_CONFIG, MOCK_CONTEXT,
), ),
).rejects.toThrow('Connection failed'); ).rejects.toThrow('Connection failed');
@@ -2326,7 +2426,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
{ url: 'http://test-server', type: 'sse' }, { url: 'http://test-server', type: 'sse' },
false, false,
workspaceContext, workspaceContext,
EMPTY_CONFIG, MOCK_CONTEXT,
), ),
).rejects.toThrow('Connection failed'); ).rejects.toThrow('Connection failed');
@@ -2345,7 +2445,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
{ url: 'http://test-server' }, { url: 'http://test-server' },
false, false,
workspaceContext, workspaceContext,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(client).toBe(mockedClient); expect(client).toBe(mockedClient);
@@ -2368,7 +2468,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
{ url: 'http://test-server' }, { url: 'http://test-server' },
false, false,
workspaceContext, workspaceContext,
EMPTY_CONFIG, MOCK_CONTEXT,
), ),
).rejects.toThrow('Server error'); ).rejects.toThrow('Server error');
@@ -2386,7 +2486,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
{ url: 'http://test-server' }, { url: 'http://test-server' },
false, false,
workspaceContext, workspaceContext,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(client).toBe(mockedClient); expect(client).toBe(mockedClient);
@@ -2471,7 +2571,7 @@ describe('connectToMcpServer - OAuth with transport fallback', () => {
{ url: 'http://test-server', oauth: { enabled: true } }, { url: 'http://test-server', oauth: { enabled: true } },
false, false,
workspaceContext, workspaceContext,
EMPTY_CONFIG, MOCK_CONTEXT,
); );
expect(client).toBe(mockedClient); expect(client).toBe(mockedClient);
+155 -41
View File
@@ -34,7 +34,11 @@ import {
ProgressNotificationSchema, ProgressNotificationSchema,
} from '@modelcontextprotocol/sdk/types.js'; } from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote'; import { parse } from 'shell-quote';
import type { Config, MCPServerConfig } from '../config/config.js'; import type {
Config,
MCPServerConfig,
GeminiCLIExtension,
} from '../config/config.js';
import { AuthProviderType } from '../config/config.js'; import { AuthProviderType } from '../config/config.js';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js'; import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js';
@@ -146,7 +150,7 @@ export class McpClient implements McpProgressReporter {
private readonly promptRegistry: PromptRegistry, private readonly promptRegistry: PromptRegistry,
private readonly resourceRegistry: ResourceRegistry, private readonly resourceRegistry: ResourceRegistry,
private readonly workspaceContext: WorkspaceContext, private readonly workspaceContext: WorkspaceContext,
private readonly cliConfig: Config, private readonly cliConfig: McpContext,
private readonly debugMode: boolean, private readonly debugMode: boolean,
private readonly clientVersion: string, private readonly clientVersion: string,
private readonly onToolsUpdated?: (signal?: AbortSignal) => Promise<void>, private readonly onToolsUpdated?: (signal?: AbortSignal) => Promise<void>,
@@ -169,7 +173,7 @@ export class McpClient implements McpProgressReporter {
this.serverConfig, this.serverConfig,
this.debugMode, this.debugMode,
this.workspaceContext, this.workspaceContext,
this.cliConfig.sanitizationConfig, this.cliConfig,
); );
this.registerNotificationHandlers(); this.registerNotificationHandlers();
@@ -180,10 +184,11 @@ export class McpClient implements McpProgressReporter {
return; return;
} }
if (originalOnError) originalOnError(error); if (originalOnError) originalOnError(error);
coreEvents.emitFeedback( this.cliConfig.emitMcpDiagnostic(
'error', 'error',
`MCP ERROR (${this.serverName})`, `MCP ERROR (${this.serverName})`,
error, error,
this.serverName,
); );
this.updateStatus(MCPServerStatus.DISCONNECTED); this.updateStatus(MCPServerStatus.DISCONNECTED);
}; };
@@ -197,7 +202,7 @@ export class McpClient implements McpProgressReporter {
/** /**
* Discovers tools and prompts from the MCP server. * Discovers tools and prompts from the MCP server.
*/ */
async discover(cliConfig: Config): Promise<void> { async discover(cliConfig: McpContext): Promise<void> {
this.assertConnected(); this.assertConnected();
const prompts = await this.fetchPrompts(); const prompts = await this.fetchPrompts();
@@ -261,7 +266,7 @@ export class McpClient implements McpProgressReporter {
} }
private async discoverTools( private async discoverTools(
cliConfig: Config, cliConfig: McpContext,
options?: { timeout?: number; signal?: AbortSignal }, options?: { timeout?: number; signal?: AbortSignal },
): Promise<DiscoveredMCPTool[]> { ): Promise<DiscoveredMCPTool[]> {
this.assertConnected(); this.assertConnected();
@@ -284,12 +289,17 @@ export class McpClient implements McpProgressReporter {
signal?: AbortSignal; signal?: AbortSignal;
}): Promise<DiscoveredMCPPrompt[]> { }): Promise<DiscoveredMCPPrompt[]> {
this.assertConnected(); this.assertConnected();
return discoverPrompts(this.serverName, this.client!, options); return discoverPrompts(
this.serverName,
this.client!,
this.cliConfig,
options,
);
} }
private async discoverResources(): Promise<Resource[]> { private async discoverResources(): Promise<Resource[]> {
this.assertConnected(); this.assertConnected();
return discoverResources(this.serverName, this.client!); return discoverResources(this.serverName, this.client!, this.cliConfig);
} }
private updateResourceRegistry(resources: Resource[]): void { private updateResourceRegistry(resources: Resource[]): void {
@@ -433,9 +443,11 @@ export class McpClient implements McpProgressReporter {
clearTimeout(timeoutId); clearTimeout(timeoutId);
coreEvents.emitFeedback( this.cliConfig.emitMcpDiagnostic(
'info', 'info',
`Resources updated for server: ${this.serverName}`, `Resources updated for server: ${this.serverName}`,
undefined,
this.serverName,
); );
} while (this.pendingResourceRefresh); } while (this.pendingResourceRefresh);
} catch (error) { } catch (error) {
@@ -504,9 +516,11 @@ export class McpClient implements McpProgressReporter {
clearTimeout(timeoutId); clearTimeout(timeoutId);
coreEvents.emitFeedback( this.cliConfig.emitMcpDiagnostic(
'info', 'info',
`Prompts updated for server: ${this.serverName}`, `Prompts updated for server: ${this.serverName}`,
undefined,
this.serverName,
); );
} while (this.pendingPromptRefresh); } while (this.pendingPromptRefresh);
} catch (error) { } catch (error) {
@@ -581,9 +595,11 @@ export class McpClient implements McpProgressReporter {
clearTimeout(timeoutId); clearTimeout(timeoutId);
coreEvents.emitFeedback( this.cliConfig.emitMcpDiagnostic(
'info', 'info',
`Tools updated for server: ${this.serverName}`, `Tools updated for server: ${this.serverName}`,
undefined,
this.serverName,
); );
} while (this.pendingToolRefresh); } while (this.pendingToolRefresh);
} catch (error) { } catch (error) {
@@ -715,6 +731,7 @@ async function handleAutomaticOAuth(
mcpServerName: string, mcpServerName: string,
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
wwwAuthenticate: string, wwwAuthenticate: string,
cliConfig: McpContext,
): Promise<boolean> { ): Promise<boolean> {
try { try {
debugLogger.log(`🔐 '${mcpServerName}' requires OAuth authentication`); debugLogger.log(`🔐 '${mcpServerName}' requires OAuth authentication`);
@@ -734,9 +751,11 @@ async function handleAutomaticOAuth(
} }
if (!oauthConfig) { if (!oauthConfig) {
coreEvents.emitFeedback( cliConfig.emitMcpDiagnostic(
'error', 'error',
`Could not configure OAuth for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`, `Could not configure OAuth for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`,
undefined,
mcpServerName,
); );
return false; return false;
} }
@@ -764,10 +783,11 @@ async function handleAutomaticOAuth(
); );
return true; return true;
} catch (error) { } catch (error) {
coreEvents.emitFeedback( cliConfig.emitMcpDiagnostic(
'error', 'error',
`Failed to handle automatic OAuth for server '${mcpServerName}': ${getErrorMessage(error)}`, `Failed to handle automatic OAuth for server '${mcpServerName}': ${getErrorMessage(error)}`,
error, error,
mcpServerName,
); );
return false; return false;
} }
@@ -778,15 +798,25 @@ async function handleAutomaticOAuth(
* *
* @param mcpServerConfig The MCP server configuration * @param mcpServerConfig The MCP server configuration
* @param headers Additional headers * @param headers Additional headers
* @param sanitizationConfig Configuration for environment sanitization
*/ */
function createTransportRequestInit( function createTransportRequestInit(
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
headers: Record<string, string>, headers: Record<string, string>,
sanitizationConfig: EnvironmentSanitizationConfig,
): RequestInit { ): RequestInit {
const extensionEnv = getExtensionEnvironment(mcpServerConfig.extension);
const expansionEnv = { ...process.env, ...extensionEnv };
const sanitizedEnv = sanitizeEnvironment(expansionEnv, {
...sanitizationConfig,
enableEnvironmentVariableRedaction: true,
});
const expandedHeaders: Record<string, string> = {}; const expandedHeaders: Record<string, string> = {};
if (mcpServerConfig.headers) { if (mcpServerConfig.headers) {
for (const [key, value] of Object.entries(mcpServerConfig.headers)) { for (const [key, value] of Object.entries(mcpServerConfig.headers)) {
expandedHeaders[key] = expandEnvVars(value, process.env); expandedHeaders[key] = expandEnvVars(value, sanitizedEnv);
} }
} }
@@ -826,12 +856,14 @@ function createAuthProvider(
* @param mcpServerName The name of the MCP server * @param mcpServerName The name of the MCP server
* @param mcpServerConfig The MCP server configuration * @param mcpServerConfig The MCP server configuration
* @param accessToken The OAuth access token * @param accessToken The OAuth access token
* @param cliConfig The CLI configuration providing sanitization and diagnostic reporting
* @returns The transport with OAuth token, or null if creation fails * @returns The transport with OAuth token, or null if creation fails
*/ */
async function createTransportWithOAuth( async function createTransportWithOAuth(
mcpServerName: string, mcpServerName: string,
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
accessToken: string, accessToken: string,
cliConfig: McpContext,
): Promise<StreamableHTTPClientTransport | SSEClientTransport | null> { ): Promise<StreamableHTTPClientTransport | SSEClientTransport | null> {
try { try {
const headers: Record<string, string> = { const headers: Record<string, string> = {
@@ -840,15 +872,20 @@ async function createTransportWithOAuth(
const transportOptions: const transportOptions:
| StreamableHTTPClientTransportOptions | StreamableHTTPClientTransportOptions
| SSEClientTransportOptions = { | SSEClientTransportOptions = {
requestInit: createTransportRequestInit(mcpServerConfig, headers), requestInit: createTransportRequestInit(
mcpServerConfig,
headers,
cliConfig.sanitizationConfig,
),
}; };
return createUrlTransport(mcpServerName, mcpServerConfig, transportOptions); return createUrlTransport(mcpServerName, mcpServerConfig, transportOptions);
} catch (error) { } catch (error) {
coreEvents.emitFeedback( cliConfig.emitMcpDiagnostic(
'error', 'error',
`Failed to create OAuth transport for server '${mcpServerName}': ${getErrorMessage(error)}`, `Failed to create OAuth transport for server '${mcpServerName}': ${getErrorMessage(error)}`,
error, error,
mcpServerName,
); );
return null; return null;
} }
@@ -970,7 +1007,7 @@ export async function connectAndDiscover(
promptRegistry: PromptRegistry, promptRegistry: PromptRegistry,
debugMode: boolean, debugMode: boolean,
workspaceContext: WorkspaceContext, workspaceContext: WorkspaceContext,
cliConfig: Config, cliConfig: McpContext,
): Promise<void> { ): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@@ -982,16 +1019,21 @@ export async function connectAndDiscover(
mcpServerConfig, mcpServerConfig,
debugMode, debugMode,
workspaceContext, workspaceContext,
cliConfig.sanitizationConfig, cliConfig,
); );
mcpClient.onerror = (error) => { mcpClient.onerror = (error) => {
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error); cliConfig.emitMcpDiagnostic(
'error',
`MCP ERROR (${mcpServerName}):`,
error,
mcpServerName,
);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
}; };
// Attempt to discover both prompts and tools // Attempt to discover both prompts and tools
const prompts = await discoverPrompts(mcpServerName, mcpClient); const prompts = await discoverPrompts(mcpServerName, mcpClient, cliConfig);
const tools = await discoverTools( const tools = await discoverTools(
mcpServerName, mcpServerName,
mcpServerConfig, mcpServerConfig,
@@ -1022,12 +1064,13 @@ export async function connectAndDiscover(
// eslint-disable-next-line @typescript-eslint/no-floating-promises // eslint-disable-next-line @typescript-eslint/no-floating-promises
mcpClient.close(); mcpClient.close();
} }
coreEvents.emitFeedback( cliConfig.emitMcpDiagnostic(
'error', 'error',
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage( `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(
error, error,
)}`, )}`,
error, error,
mcpServerName,
); );
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
} }
@@ -1050,7 +1093,7 @@ export async function discoverTools(
mcpServerName: string, mcpServerName: string,
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
mcpClient: Client, mcpClient: Client,
cliConfig: Config, cliConfig: McpContext,
messageBus: MessageBus, messageBus: MessageBus,
options?: { options?: {
timeout?: number; timeout?: number;
@@ -1099,13 +1142,14 @@ export async function discoverTools(
discoveredTools.push(tool); discoveredTools.push(tool);
} catch (error) { } catch (error) {
coreEvents.emitFeedback( cliConfig.emitMcpDiagnostic(
'error', 'error',
`Error discovering tool: '${ `Error discovering tool: '${
toolDef.name toolDef.name
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
}' from MCP server '${mcpServerName}': ${(error as Error).message}`, }' from MCP server '${mcpServerName}': ${(error as Error).message}`,
error, error,
mcpServerName,
); );
} }
} }
@@ -1115,12 +1159,13 @@ export async function discoverTools(
error instanceof Error && error instanceof Error &&
!error.message?.includes('Method not found') !error.message?.includes('Method not found')
) { ) {
coreEvents.emitFeedback( cliConfig.emitMcpDiagnostic(
'error', 'error',
`Error discovering tools from ${mcpServerName}: ${getErrorMessage( `Error discovering tools from ${mcpServerName}: ${getErrorMessage(
error, error,
)}`, )}`,
error, error,
mcpServerName,
); );
} }
return []; return [];
@@ -1216,6 +1261,7 @@ class McpCallableTool implements CallableTool {
export async function discoverPrompts( export async function discoverPrompts(
mcpServerName: string, mcpServerName: string,
mcpClient: Client, mcpClient: Client,
cliConfig: McpContext,
options?: { signal?: AbortSignal }, options?: { signal?: AbortSignal },
): Promise<DiscoveredMCPPrompt[]> { ): Promise<DiscoveredMCPPrompt[]> {
// Only request prompts if the server supports them. // Only request prompts if the server supports them.
@@ -1227,19 +1273,26 @@ export async function discoverPrompts(
...prompt, ...prompt,
serverName: mcpServerName, serverName: mcpServerName,
invoke: (params: Record<string, unknown>) => invoke: (params: Record<string, unknown>) =>
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params), invokeMcpPrompt(
mcpServerName,
mcpClient,
prompt.name,
params,
cliConfig,
),
})); }));
} catch (error) { } catch (error) {
// It's okay if the method is not found, which is a common case. // It's okay if the method is not found, which is a common case.
if (error instanceof Error && error.message?.includes('Method not found')) { if (error instanceof Error && error.message?.includes('Method not found')) {
return []; return [];
} }
coreEvents.emitFeedback( cliConfig.emitMcpDiagnostic(
'error', 'error',
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage( `Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
error, error,
)}`, )}`,
error, error,
mcpServerName,
); );
throw error; throw error;
} }
@@ -1248,18 +1301,20 @@ export async function discoverPrompts(
export async function discoverResources( export async function discoverResources(
mcpServerName: string, mcpServerName: string,
mcpClient: Client, mcpClient: Client,
cliConfig: McpContext,
): Promise<Resource[]> { ): Promise<Resource[]> {
if (mcpClient.getServerCapabilities()?.resources == null) { if (mcpClient.getServerCapabilities()?.resources == null) {
return []; return [];
} }
const resources = await listResources(mcpServerName, mcpClient); const resources = await listResources(mcpServerName, mcpClient, cliConfig);
return resources; return resources;
} }
async function listResources( async function listResources(
mcpServerName: string, mcpServerName: string,
mcpClient: Client, mcpClient: Client,
cliConfig: McpContext,
): Promise<Resource[]> { ): Promise<Resource[]> {
const resources: Resource[] = []; const resources: Resource[] = [];
let cursor: string | undefined; let cursor: string | undefined;
@@ -1279,12 +1334,13 @@ async function listResources(
if (error instanceof Error && error.message?.includes('Method not found')) { if (error instanceof Error && error.message?.includes('Method not found')) {
return []; return [];
} }
coreEvents.emitFeedback( cliConfig.emitMcpDiagnostic(
'error', 'error',
`Error discovering resources from ${mcpServerName}: ${getErrorMessage( `Error discovering resources from ${mcpServerName}: ${getErrorMessage(
error, error,
)}`, )}`,
error, error,
mcpServerName,
); );
throw error; throw error;
} }
@@ -1305,7 +1361,9 @@ export async function invokeMcpPrompt(
mcpClient: Client, mcpClient: Client,
promptName: string, promptName: string,
promptParams: Record<string, unknown>, promptParams: Record<string, unknown>,
cliConfig: McpContext,
): Promise<GetPromptResult> { ): Promise<GetPromptResult> {
cliConfig.setUserInteractedWithMcp?.();
try { try {
const sanitizedParams: Record<string, string> = {}; const sanitizedParams: Record<string, string> = {};
for (const [key, value] of Object.entries(promptParams)) { for (const [key, value] of Object.entries(promptParams)) {
@@ -1325,12 +1383,13 @@ export async function invokeMcpPrompt(
error instanceof Error && error instanceof Error &&
!error.message?.includes('Method not found') !error.message?.includes('Method not found')
) { ) {
coreEvents.emitFeedback( cliConfig.emitMcpDiagnostic(
'error', 'error',
`Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage( `Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage(
error, error,
)}`, )}`,
error, error,
mcpServerName,
); );
} }
throw error; throw error;
@@ -1415,14 +1474,17 @@ async function connectWithSSETransport(
* @param serverName The name of the MCP server * @param serverName The name of the MCP server
* @throws Always throws an error with authentication instructions * @throws Always throws an error with authentication instructions
*/ */
async function showAuthRequiredMessage(serverName: string): Promise<never> { async function showAuthRequiredMessage(
serverName: string,
cliConfig: McpContext,
): Promise<never> {
const hasRejectedToken = !!(await getStoredOAuthToken(serverName)); const hasRejectedToken = !!(await getStoredOAuthToken(serverName));
const message = hasRejectedToken const message = hasRejectedToken
? `MCP server '${serverName}' rejected stored OAuth token. Please re-authenticate using: /mcp auth ${serverName}` ? `MCP server '${serverName}' rejected stored OAuth token. Please re-authenticate using: /mcp auth ${serverName}`
: `MCP server '${serverName}' requires authentication using: /mcp auth ${serverName}`; : `MCP server '${serverName}' requires authentication using: /mcp auth ${serverName}`;
coreEvents.emitFeedback('info', message); cliConfig.emitMcpDiagnostic('info', message, undefined, serverName);
throw new UnauthorizedError(message); throw new UnauthorizedError(message);
} }
@@ -1435,6 +1497,7 @@ async function showAuthRequiredMessage(serverName: string): Promise<never> {
* @param config The MCP server configuration * @param config The MCP server configuration
* @param accessToken The OAuth access token to use * @param accessToken The OAuth access token to use
* @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server) * @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server)
* @param cliConfig The CLI configuration providing sanitization and diagnostic reporting
*/ */
async function retryWithOAuth( async function retryWithOAuth(
client: Client, client: Client,
@@ -1442,6 +1505,7 @@ async function retryWithOAuth(
config: MCPServerConfig, config: MCPServerConfig,
accessToken: string, accessToken: string,
httpReturned404: boolean, httpReturned404: boolean,
cliConfig: McpContext,
): Promise<void> { ): Promise<void> {
if (httpReturned404) { if (httpReturned404) {
// HTTP returned 404, only try SSE // HTTP returned 404, only try SSE
@@ -1462,6 +1526,7 @@ async function retryWithOAuth(
serverName, serverName,
config, config,
accessToken, accessToken,
cliConfig,
); );
if (!httpTransport) { if (!httpTransport) {
throw new Error( throw new Error(
@@ -1497,6 +1562,23 @@ async function retryWithOAuth(
} }
} }
/**
* Interface for MCP operations that require configuration or diagnostic reporting.
* This is implemented by the central Config class and can be mocked for testing
* or used by the non-interactive CLI.
*/
export interface McpContext {
readonly sanitizationConfig: EnvironmentSanitizationConfig;
emitMcpDiagnostic(
severity: 'info' | 'warning' | 'error',
message: string,
error?: unknown,
serverName?: string,
): void;
setUserInteractedWithMcp?(): void;
isTrustedFolder(): boolean;
}
/** /**
* Creates and connects an MCP client to a server based on the provided configuration. * Creates and connects an MCP client to a server based on the provided configuration.
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
@@ -1513,7 +1595,7 @@ export async function connectToMcpServer(
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
debugMode: boolean, debugMode: boolean,
workspaceContext: WorkspaceContext, workspaceContext: WorkspaceContext,
sanitizationConfig: EnvironmentSanitizationConfig, cliConfig: McpContext,
): Promise<Client> { ): Promise<Client> {
const mcpClient = new Client( const mcpClient = new Client(
{ {
@@ -1580,7 +1662,7 @@ export async function connectToMcpServer(
mcpServerName, mcpServerName,
mcpServerConfig, mcpServerConfig,
debugMode, debugMode,
sanitizationConfig, cliConfig,
); );
try { try {
await mcpClient.connect(transport, { await mcpClient.connect(transport, {
@@ -1659,7 +1741,7 @@ export async function connectToMcpServer(
const shouldTriggerOAuth = mcpServerConfig.oauth?.enabled; const shouldTriggerOAuth = mcpServerConfig.oauth?.enabled;
if (!shouldTriggerOAuth) { if (!shouldTriggerOAuth) {
await showAuthRequiredMessage(mcpServerName); await showAuthRequiredMessage(mcpServerName, cliConfig);
} }
// Try to extract www-authenticate header from the error // Try to extract www-authenticate header from the error
@@ -1725,6 +1807,7 @@ export async function connectToMcpServer(
mcpServerName, mcpServerName,
mcpServerConfig, mcpServerConfig,
wwwAuthenticate, wwwAuthenticate,
cliConfig,
); );
if (oauthSuccess) { if (oauthSuccess) {
// Retry connection with OAuth token // Retry connection with OAuth token
@@ -1741,6 +1824,7 @@ export async function connectToMcpServer(
mcpServerConfig, mcpServerConfig,
accessToken, accessToken,
httpReturned404, httpReturned404,
cliConfig,
); );
return mcpClient; return mcpClient;
} else { } else {
@@ -1754,7 +1838,7 @@ export async function connectToMcpServer(
const shouldTryDiscovery = mcpServerConfig.oauth?.enabled; const shouldTryDiscovery = mcpServerConfig.oauth?.enabled;
if (!shouldTryDiscovery) { if (!shouldTryDiscovery) {
await showAuthRequiredMessage(mcpServerName); await showAuthRequiredMessage(mcpServerName, cliConfig);
} }
// For SSE/HTTP servers, try to discover OAuth configuration from the base URL // For SSE/HTTP servers, try to discover OAuth configuration from the base URL
@@ -1813,6 +1897,7 @@ export async function connectToMcpServer(
mcpServerName, mcpServerName,
mcpServerConfig, mcpServerConfig,
accessToken, accessToken,
cliConfig,
); );
if (!oauthTransport) { if (!oauthTransport) {
throw new Error( throw new Error(
@@ -1900,7 +1985,7 @@ export async function createTransport(
mcpServerName: string, mcpServerName: string,
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
debugMode: boolean, debugMode: boolean,
sanitizationConfig: EnvironmentSanitizationConfig, cliConfig: McpContext,
): Promise<Transport> { ): Promise<Transport> {
const noUrl = !mcpServerConfig.url && !mcpServerConfig.httpUrl; const noUrl = !mcpServerConfig.url && !mcpServerConfig.httpUrl;
if (noUrl) { if (noUrl) {
@@ -1938,9 +2023,11 @@ export async function createTransport(
if (!accessToken) { if (!accessToken) {
// Emit info message (not error) since this is expected behavior // Emit info message (not error) since this is expected behavior
coreEvents.emitFeedback( cliConfig.emitMcpDiagnostic(
'info', 'info',
`MCP server '${mcpServerName}' requires authentication using: /mcp auth ${mcpServerName}`, `MCP server '${mcpServerName}' requires authentication using: /mcp auth ${mcpServerName}`,
undefined,
mcpServerName,
); );
} }
} else { } else {
@@ -1960,7 +2047,11 @@ export async function createTransport(
const transportOptions: const transportOptions:
| StreamableHTTPClientTransportOptions | StreamableHTTPClientTransportOptions
| SSEClientTransportOptions = { | SSEClientTransportOptions = {
requestInit: createTransportRequestInit(mcpServerConfig, headers), requestInit: createTransportRequestInit(
mcpServerConfig,
headers,
cliConfig.sanitizationConfig,
),
authProvider, authProvider,
}; };
@@ -1968,15 +2059,24 @@ export async function createTransport(
} }
if (mcpServerConfig.command) { if (mcpServerConfig.command) {
if (!cliConfig.isTrustedFolder()) {
throw new Error(
`MCP server '${mcpServerName}' uses stdio transport but current folder is not trusted. Use 'gemini trust' to enable it.`,
);
}
const extensionEnv = getExtensionEnvironment(mcpServerConfig.extension);
const expansionEnv = { ...process.env, ...extensionEnv };
// 1. Sanitize the base process environment to prevent unintended leaks of system-wide secrets. // 1. Sanitize the base process environment to prevent unintended leaks of system-wide secrets.
const sanitizedEnv = sanitizeEnvironment(process.env, { const sanitizedEnv = sanitizeEnvironment(expansionEnv, {
...sanitizationConfig, ...cliConfig.sanitizationConfig,
enableEnvironmentVariableRedaction: true, enableEnvironmentVariableRedaction: true,
}); });
const finalEnv: Record<string, string> = { const finalEnv: Record<string, string> = {
[GEMINI_CLI_IDENTIFICATION_ENV_VAR]: [GEMINI_CLI_IDENTIFICATION_ENV_VAR]:
GEMINI_CLI_IDENTIFICATION_ENV_VAR_VALUE, GEMINI_CLI_IDENTIFICATION_ENV_VAR_VALUE,
...extensionEnv,
}; };
for (const [key, value] of Object.entries(sanitizedEnv)) { for (const [key, value] of Object.entries(sanitizedEnv)) {
if (value !== undefined) { if (value !== undefined) {
@@ -1987,7 +2087,7 @@ export async function createTransport(
// Expand and merge explicit environment variables from the MCP configuration. // Expand and merge explicit environment variables from the MCP configuration.
if (mcpServerConfig.env) { if (mcpServerConfig.env) {
for (const [key, value] of Object.entries(mcpServerConfig.env)) { for (const [key, value] of Object.entries(mcpServerConfig.env)) {
finalEnv[key] = expandEnvVars(value, process.env); finalEnv[key] = expandEnvVars(value, expansionEnv);
} }
} }
@@ -2045,6 +2145,20 @@ interface NamedTool {
name?: string; name?: string;
} }
function getExtensionEnvironment(
extension?: GeminiCLIExtension,
): Record<string, string> {
const env: Record<string, string> = {};
if (extension?.resolvedSettings) {
for (const setting of extension.resolvedSettings) {
if (setting.value !== undefined) {
env[setting.envVar] = setting.value;
}
}
}
return env;
}
/** Visible for testing */ /** Visible for testing */
export function isEnabled( export function isEnabled(
funcDecl: NamedTool, funcDecl: NamedTool,
+74 -102
View File
@@ -11,6 +11,7 @@ import type { ToolInvocation, ToolLocation, ToolResult } from './tools.js';
import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js'; import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { ToolErrorType } from './tool-error.js'; import { ToolErrorType } from './tool-error.js';
import type { PartUnion } from '@google/genai';
import { import {
processSingleFileContent, processSingleFileContent,
getSpecificMimeType, getSpecificMimeType,
@@ -44,29 +45,20 @@ export interface ReadFileToolParams {
*/ */
end_line?: number; end_line?: number;
} }
class ReadFileToolInvocation extends BaseToolInvocation< class ReadFileToolInvocation extends BaseToolInvocation<
ReadFileToolParams, ReadFileToolParams,
ToolResult ToolResult
> { > {
private readonly resolvedPath: string; private readonly resolvedPath: string;
constructor( constructor(
private readonly config: Config, private config: Config,
params: ReadFileToolParams, params: ReadFileToolParams,
messageBus: MessageBus, messageBus: MessageBus,
_toolName?: string, _toolName?: string,
_toolDisplayName?: string, _toolDisplayName?: string,
isSensitive?: boolean,
) { ) {
super( super(params, messageBus, _toolName, _toolDisplayName);
params,
messageBus,
_toolName,
_toolDisplayName,
undefined,
undefined,
isSensitive,
);
this.resolvedPath = path.resolve( this.resolvedPath = path.resolve(
this.config.getTargetDir(), this.config.getTargetDir(),
this.params.file_path, this.params.file_path,
@@ -105,80 +97,66 @@ class ReadFileToolInvocation extends BaseToolInvocation<
}, },
}; };
} }
try {
const result = await processSingleFileContent(
this.resolvedPath,
this.config.getTargetDir(),
this.config.getFileSystemService(),
this.params.start_line,
this.params.end_line,
);
if (result.error) { const result = await processSingleFileContent(
return { this.resolvedPath,
llmContent: result.llmContent, this.config.getTargetDir(),
returnDisplay: result.returnDisplay || 'Error reading file', this.config.getFileSystemService(),
error: { this.params.start_line,
message: result.error, this.params.end_line,
type: result.errorType, );
},
};
}
let llmContent = result.llmContent; if (result.error) {
return {
if (result.isTruncated && typeof llmContent === 'string') { llmContent: result.llmContent,
const [startLine, endLine] = result.linesShown || [1, 0]; returnDisplay: result.returnDisplay || 'Error reading file',
llmContent = `
IMPORTANT: The file content has been truncated.
Status: Showing lines ${startLine}-${endLine} of ${result.originalLineCount} total lines.
Action: To read more of the file, you can use the 'start_line' and 'end_line' parameters in a subsequent 'read_file' call. For example, to read the next section of the file, use start_line: ${
endLine + 1
}.
--- FILE CONTENT (truncated) ---
${llmContent}
`;
}
const programming_language = getProgrammingLanguage({
file_path: this.resolvedPath,
});
logFileOperation(
this.config,
new FileOperationEvent(
this._toolName || READ_FILE_TOOL_NAME,
FileOperation.READ,
result.originalLineCount,
getSpecificMimeType(this.resolvedPath),
path.extname(this.resolvedPath),
programming_language,
),
);
const finalResult: ToolResult = {
llmContent,
returnDisplay: result.returnDisplay || '',
};
return finalResult;
} catch (err: unknown) {
const error = err instanceof Error ? err : new Error(String(err));
const errorMessage = String(error.message);
const toolResult: ToolResult = {
llmContent: [
{
text: `Error reading file: ${errorMessage}`,
},
],
returnDisplay: `Error: ${errorMessage}`,
error: { error: {
message: errorMessage, message: result.error,
type: ToolErrorType.EXECUTION_FAILED, type: result.errorType,
}, },
}; };
return toolResult;
} }
let llmContent: PartUnion;
if (result.isTruncated) {
const [start, end] = result.linesShown!;
const total = result.originalLineCount!;
llmContent = `
IMPORTANT: The file content has been truncated.
Status: Showing lines ${start}-${end} of ${total} total lines.
Action: To read more of the file, you can use the 'start_line' and 'end_line' parameters in a subsequent 'read_file' call. For example, to read the next section of the file, use start_line: ${end + 1}.
--- FILE CONTENT (truncated) ---
${result.llmContent}`;
} else {
llmContent = result.llmContent || '';
}
const lines =
typeof result.llmContent === 'string'
? result.llmContent.split('\n').length
: undefined;
const mimetype = getSpecificMimeType(this.resolvedPath);
const programming_language = getProgrammingLanguage({
file_path: this.resolvedPath,
});
logFileOperation(
this.config,
new FileOperationEvent(
READ_FILE_TOOL_NAME,
FileOperation.READ,
lines,
mimetype,
path.extname(this.resolvedPath),
programming_language,
),
);
return {
llmContent,
returnDisplay: result.returnDisplay || '',
};
} }
} }
@@ -205,9 +183,6 @@ export class ReadFileTool extends BaseDeclarativeTool<
messageBus, messageBus,
true, true,
false, false,
undefined,
undefined,
true,
); );
this.fileDiscoveryService = new FileDiscoveryService( this.fileDiscoveryService = new FileDiscoveryService(
config.getTargetDir(), config.getTargetDir(),
@@ -218,30 +193,15 @@ export class ReadFileTool extends BaseDeclarativeTool<
protected override validateToolParamValues( protected override validateToolParamValues(
params: ReadFileToolParams, params: ReadFileToolParams,
): string | null { ): string | null {
if (!params.file_path) { if (params.file_path.trim() === '') {
return "The 'file_path' parameter must be non-empty."; return "The 'file_path' parameter must be non-empty.";
} }
if (params.start_line !== undefined && params.start_line < 1) {
return 'start_line must be at least 1';
}
if (params.end_line !== undefined && params.end_line < 1) {
return 'end_line must be at least 1';
}
if (
params.start_line !== undefined &&
params.end_line !== undefined &&
params.start_line > params.end_line
) {
return 'start_line cannot be greater than end_line';
}
const resolvedPath = path.resolve( const resolvedPath = path.resolve(
this.config.getTargetDir(), this.config.getTargetDir(),
params.file_path, params.file_path,
); );
const validationError = this.config.validatePathAccess( const validationError = this.config.validatePathAccess(
resolvedPath, resolvedPath,
'read', 'read',
@@ -250,6 +210,20 @@ export class ReadFileTool extends BaseDeclarativeTool<
return validationError; return validationError;
} }
if (params.start_line !== undefined && params.start_line < 1) {
return 'start_line must be at least 1';
}
if (params.end_line !== undefined && params.end_line < 1) {
return 'end_line must be at least 1';
}
if (
params.start_line !== undefined &&
params.end_line !== undefined &&
params.start_line > params.end_line
) {
return 'start_line cannot be greater than end_line';
}
const fileFilteringOptions = this.config.getFileFilteringOptions(); const fileFilteringOptions = this.config.getFileFilteringOptions();
if ( if (
this.fileDiscoveryService.shouldIgnoreFile( this.fileDiscoveryService.shouldIgnoreFile(
@@ -268,7 +242,6 @@ export class ReadFileTool extends BaseDeclarativeTool<
messageBus: MessageBus, messageBus: MessageBus,
_toolName?: string, _toolName?: string,
_toolDisplayName?: string, _toolDisplayName?: string,
isSensitive?: boolean,
): ToolInvocation<ReadFileToolParams, ToolResult> { ): ToolInvocation<ReadFileToolParams, ToolResult> {
return new ReadFileToolInvocation( return new ReadFileToolInvocation(
this.config, this.config,
@@ -276,7 +249,6 @@ export class ReadFileTool extends BaseDeclarativeTool<
messageBus, messageBus,
_toolName, _toolName,
_toolDisplayName, _toolDisplayName,
isSensitive,
); );
} }