From 78bd526792b18381347331985b3bb25f22a7e853 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Fri, 27 Mar 2026 13:09:59 -0400 Subject: [PATCH] chore: update channels --- packages/cli/src/ui/AppContainer.tsx | 18 +-- .../src/ui/components/views/ChannelsList.tsx | 6 +- packages/core/src/channels/types.ts | 21 +-- packages/core/src/config/config.ts | 2 +- packages/core/src/tools/mcp-client.test.ts | 135 ++++++++++++++++++ packages/core/src/tools/mcp-client.ts | 36 +++-- packages/core/src/tools/web-fetch.ts | 22 +-- packages/core/src/utils/textUtils.ts | 21 +++ 8 files changed, 203 insertions(+), 58 deletions(-) diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index 3a1468a33e..0924e1ef7a 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -1218,24 +1218,8 @@ Logging in with Google... Restarting Gemini CLI to continue. const channelsEnabled = config.getChannels().length > 0; useEffect(() => { if (!channelsEnabled) return; - const escapeAttr = (s: string) => - s - .replace(/&/g, '&') - .replace(/"/g, '"') - .replace(//g, '>'); const handler = (payload: ChannelMessagePayload) => { - const meta = payload.metadata ?? {}; - const user = meta['user'] ?? payload.sender; - const chatId = meta['chat_id'] ?? ''; - const msgId = meta['message_id'] ?? ''; - const imagePath = meta['image_path'] ?? ''; - const safeContent = payload.content.replace( - /<\/channel>/gi, - '</channel>', - ); - const formatted = `\n${safeContent}\n`; - addMessage(formatted); + addMessage(payload.content); }; coreEvents.on(CoreEvent.ChannelMessage, handler); return () => { diff --git a/packages/cli/src/ui/components/views/ChannelsList.tsx b/packages/cli/src/ui/components/views/ChannelsList.tsx index ab651b9fca..cd6feb2306 100644 --- a/packages/cli/src/ui/components/views/ChannelsList.tsx +++ b/packages/cli/src/ui/components/views/ChannelsList.tsx @@ -46,7 +46,11 @@ export const ChannelsList: React.FC = ({ channels }) => { Direction:{' '} {channel.supportsReply ? 'two-way' : 'one-way'} diff --git a/packages/core/src/channels/types.ts b/packages/core/src/channels/types.ts index 47564e4733..04d9835087 100644 --- a/packages/core/src/channels/types.ts +++ b/packages/core/src/channels/types.ts @@ -8,20 +8,16 @@ * Payload for the 'channel-message' event, emitted when an MCP server * declaring the `gemini/channel` experimental capability sends a * `notifications/gemini/channel` notification. + * + * XML formatting and escaping happens at the trust boundary in mcp-client.ts, + * so `content` is a pre-formatted, escaped `` XML string ready for + * injection into the conversation. */ export interface ChannelMessagePayload { /** Name of the MCP server acting as the channel. */ channelName: string; - /** Sender identifier (e.g. Telegram username, Discord user ID). */ - sender: string; - /** The message body. */ + /** Pre-formatted, escaped `` XML string. */ content: string; - /** Unix epoch milliseconds when the message was received. */ - timestamp: number; - /** Optional correlation ID for two-way channel replies. */ - replyTo?: string; - /** Extra key-value pairs surfaced as XML attributes on the tag. */ - metadata?: Record; } /** @@ -47,3 +43,10 @@ export const activeChannels = new Map(); export function getActiveChannelNames(): string[] { return Array.from(activeChannels.keys()); } + +/** + * Removes a channel entry when its MCP server disconnects. + */ +export function removeChannel(name: string): void { + activeChannels.delete(name); +} diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 6f59de68cb..71e8ae6605 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -1387,7 +1387,7 @@ export class Config implements McpContext, AgentLoopContext { if (active.length > 0) { coreEvents.emitFeedback( 'info', - `Channels listening for messages: ${active.join(', ')}`, + `Channels listening for messages: ${active.join(', ')}\n Only use channels you trust — messages are injected into the conversation.`, undefined, { style: 'channel' }, ); diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 4a14b671a0..7877a6ecfa 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -43,6 +43,7 @@ import * as fs from 'node:fs'; import * as os from 'node:os'; import * as path from 'node:path'; import { coreEvents } from '../utils/events.js'; +import { activeChannels } from '../channels/types.js'; import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js'; interface TestableTransport { @@ -63,6 +64,7 @@ const MOCK_CONTEXT_DEFAULT = { emitMcpDiagnostic: vi.fn(), setUserInteractedWithMcp: vi.fn(), isTrustedFolder: vi.fn().mockReturnValue(true), + getChannels: vi.fn().mockReturnValue([]), }; let MOCK_CONTEXT: McpContext = MOCK_CONTEXT_DEFAULT; @@ -80,6 +82,8 @@ vi.mock('../utils/events.js', () => ({ coreEvents: { emitFeedback: vi.fn(), emitConsoleLog: vi.fn(), + emitChannelMessage: vi.fn(), + emitMcpProgress: vi.fn(), }, })); @@ -93,6 +97,7 @@ describe('mcp-client', () => { emitMcpDiagnostic: vi.fn(), setUserInteractedWithMcp: vi.fn(), isTrustedFolder: vi.fn().mockReturnValue(true), + getChannels: vi.fn().mockReturnValue([]), }; // create a tmp dir for this test // Create a unique temporary directory for the workspace to avoid conflicts @@ -1110,12 +1115,16 @@ describe('mcp-client', () => { expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce(); + // Simulate a channel entry being registered for this server + activeChannels.set('test-server', { supportsReply: false }); + await client.disconnect(); expect(mockedClient.close).toHaveBeenCalledOnce(); expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledOnce(); expect(mockedPromptRegistry.removePromptsByServer).toHaveBeenCalledOnce(); expect(resourceRegistry.removeResourcesByServer).toHaveBeenCalledOnce(); + expect(activeChannels.has('test-server')).toBe(false); }); }); @@ -1731,6 +1740,132 @@ describe('mcp-client', () => { }); }); + describe('Channel notifications', () => { + const CHANNEL_CAPABILITIES = { + experimental: { 'gemini/channel': { displayName: 'Test' } }, + }; + + /** + * Creates a mock MCP client, connects a McpClient, and returns + * the channel notification handler (or null if none was registered). + * The channel handler is always the last setNotificationHandler call + * when the server is in the --channels list (registered after Progress). + */ + async function connectWithChannels(channels: string[]) { + const mockedClient = { + connect: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue(CHANNEL_CAPABILITIES), + setNotificationHandler: vi.fn(), + request: vi.fn().mockResolvedValue({}), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const client = new McpClient( + 'test-server', + { command: 'test-command' }, + workspaceContext, + { ...MOCK_CONTEXT, getChannels: vi.fn().mockReturnValue(channels) }, + false, + '0.0.1', + ); + await client.connect(); + + const handlerCalls = mockedClient.setNotificationHandler.mock.calls; + return { mockedClient, handlerCalls }; + } + + function getLastHandler( + handlerCalls: any[][], + ): ((notification: any) => void) | undefined { + return handlerCalls.length > 0 + ? handlerCalls[handlerCalls.length - 1][1] + : undefined; + } + + function getEmittedContent(): string { + return (coreEvents.emitChannelMessage as any).mock.calls[0][0].content; + } + + it('should register handler when server declares capability and is in --channels list', async () => { + const { handlerCalls: withChannel } = await connectWithChannels([ + 'test-server', + ]); + const { handlerCalls: withoutChannel } = await connectWithChannels([]); + + // When in --channels list, an extra handler is registered (the channel one). + expect(withChannel.length).toBe(withoutChannel.length + 1); + }); + + it('should NOT register handler when server is not in --channels list', async () => { + const { handlerCalls } = await connectWithChannels([]); + + // Only the ProgressNotificationSchema handler should be registered + // (no tools/resources/prompts capabilities = no other handlers). + expect(handlerCalls).toHaveLength(1); + expect(handlerCalls[0][0]).toBe(ProgressNotificationSchema); + }); + + it('should emit channel message with properly formatted XML', async () => { + const { handlerCalls } = await connectWithChannels(['test-server']); + const handler = getLastHandler(handlerCalls)!; + + handler({ + method: 'notifications/gemini/channel', + params: { + content: 'hello', + sender: 'alice', + meta: { chat_id: '123' }, + }, + }); + + expect(coreEvents.emitChannelMessage).toHaveBeenCalledWith({ + channelName: 'test-server', + content: expect.stringContaining('hello'), + }); + const xml = getEmittedContent(); + expect(xml).toContain(' { + const { handlerCalls } = await connectWithChannels(['test-server']); + const handler = getLastHandler(handlerCalls)!; + + handler({ + method: 'notifications/gemini/channel', + params: { + content: '', + sender: 'evil { + const { handlerCalls } = await connectWithChannels(['test-server']); + const handler = getLastHandler(handlerCalls)!; + + handler({ + method: 'notifications/gemini/channel', + params: { content: '', sender: 'alice' }, + }); + + expect(coreEvents.emitChannelMessage).not.toHaveBeenCalled(); + }); + }); + describe('appendMcpServerCommand', () => { it('should do nothing if no MCP servers or command are configured', () => { const out = populateMcpServerCommand({}, undefined); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index af0ca65ffe..f89ceaa048 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -68,11 +68,12 @@ import type { WorkspaceContext, } from '../utils/workspaceContext.js'; import { getToolCallContext } from '../utils/toolCallContext.js'; +import { escapeXml, sanitizeXmlKey } from '../utils/textUtils.js'; import type { ToolRegistry } from './tool-registry.js'; import { debugLogger } from '../utils/debugLogger.js'; import { type MessageBus } from '../confirmation-bus/message-bus.js'; import { coreEvents } from '../utils/events.js'; -import { activeChannels } from '../channels/types.js'; +import { activeChannels, removeChannel } from '../channels/types.js'; import { type ResourceRegistry, type MCPResource, @@ -289,6 +290,7 @@ export class McpClient implements McpProgressReporter { registries.promptRegistry.removePromptsByServer(this.serverName); registries.resourceRegistry.removeResourcesByServer(this.serverName); } + removeChannel(this.serverName); this.updateStatus(MCPServerStatus.DISCONNECTING); const client = this.client; this.client = undefined; @@ -486,7 +488,7 @@ export class McpClient implements McpProgressReporter { // listen for channel notifications and route them through coreEvents. // Only register if this server is in the --channels list. const channelCap = capabilities?.experimental?.['gemini/channel']; - const enabledChannels = this.cliConfig.getChannels?.() ?? []; + const enabledChannels = this.cliConfig.getChannels(); if (channelCap && enabledChannels.includes(this.serverName)) { debugLogger.log( `Server '${this.serverName}' declares gemini/channel capability. Listening for channel messages...`, @@ -496,12 +498,11 @@ export class McpClient implements McpProgressReporter { channelCap != null && typeof channelCap === 'object' ? Object.fromEntries(Object.entries(channelCap)) : {}; + const rawDisplayName = channelCapRecord['displayName']; activeChannels.set(this.serverName, { supportsReply: capabilities?.tools != null, displayName: - typeof channelCapRecord['displayName'] === 'string' - ? channelCapRecord['displayName'] - : undefined, + typeof rawDisplayName === 'string' ? rawDisplayName : undefined, }); const ChannelNotificationSchema = NotificationSchema.extend({ @@ -517,19 +518,28 @@ export class McpClient implements McpProgressReporter { if (typeof content !== 'string' || !content) return; const rawMeta = params['meta']; - const meta = + const metaObj: Record = rawMeta != null && typeof rawMeta === 'object' ? Object.fromEntries( Object.entries(rawMeta).map(([k, v]) => [k, String(v)]), ) - : undefined; + : {}; + metaObj['user'] = + metaObj['user'] ?? String(params['sender'] ?? 'unknown'); + + const attrs = Object.entries(metaObj) + .filter(([, v]) => v !== '') + .map(([k, v]) => `${sanitizeXmlKey(k)}="${escapeXml(v)}"`) + .join(' '); + + const safeContent = content.replace(/<\/channel/gi, '</channel'); + + const source = escapeXml(this.serverName); + const formattedXml = `\n${safeContent}\n`; + coreEvents.emitChannelMessage({ channelName: this.serverName, - sender: String(params['sender'] ?? 'unknown'), - content, - timestamp: Date.now(), - replyTo: params['replyTo'] ? String(params['replyTo']) : undefined, - metadata: meta, + content: formattedXml, }); }, ); @@ -1817,7 +1827,7 @@ export interface McpContext { source?: string; }>; }; - getChannels?(): string[]; + getChannels(): string[]; } /** diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index dc90d892ef..a8eb3b624b 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -19,7 +19,7 @@ import { ToolErrorType } from './tool-error.js'; import { getErrorMessage } from '../utils/errors.js'; import { getResponseText } from '../utils/partUtils.js'; import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js'; -import { truncateString } from '../utils/textUtils.js'; +import { truncateString, escapeXml } from '../utils/textUtils.js'; import { convert } from 'html-to-text'; import { logWebFetchFallbackAttempt, @@ -188,18 +188,6 @@ function isGroundingSupportItem(item: unknown): item is GroundingSupportItem { return typeof item === 'object' && item !== null; } -/** - * Sanitizes text for safe embedding in XML tags. - */ -function sanitizeXml(text: string): string { - return text - .replace(/&/g, '&') - .replace(//g, '>') - .replace(/"/g, '"') - .replace(/'/g, '''); -} - /** * Parameters for the WebFetch tool */ @@ -434,10 +422,10 @@ class WebFetchToolInvocation extends BaseToolInvocation< .map((url) => { const content = finalContentsByUrl.get(url); if (content !== undefined) { - return `\n${sanitizeXml(content)}\n`; + return `\n${escapeXml(content)}\n`; } const error = errors.find((e) => e.url === url); - return `\nError: ${sanitizeXml(error?.message || 'Unknown error')}\n`; + return `\nError: ${escapeXml(error?.message || 'Unknown error')}\n`; }) .join('\n'); @@ -446,7 +434,7 @@ class WebFetchToolInvocation extends BaseToolInvocation< const fallbackPrompt = `Follow the user's instructions below using the provided webpage content. -${sanitizeXml(this.params.prompt ?? '')} +${escapeXml(this.params.prompt ?? '')} I was unable to access the URL(s) directly using the primary fetch tool. Instead, I have fetched the raw content of the page(s). Please use the following content to answer the request. Do not attempt to access the URL(s) again. @@ -771,7 +759,7 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun const sanitizedPrompt = `Follow the user's instructions to process the authorized URLs. -${sanitizeXml(userPrompt)} +${escapeXml(userPrompt)} diff --git a/packages/core/src/utils/textUtils.ts b/packages/core/src/utils/textUtils.ts index 8d4cbfa6d5..2069eedca9 100644 --- a/packages/core/src/utils/textUtils.ts +++ b/packages/core/src/utils/textUtils.ts @@ -121,6 +121,27 @@ export function truncateString( * @param replacements A record of keys to their replacement values. * @returns The resulting string with placeholders replaced. */ +/** + * Escapes a string for safe embedding in XML content or attributes. + * Replaces &, <, >, ", and ' with their XML entity equivalents. + */ +export function escapeXml(s: string): string { + return s + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, '''); +} + +/** + * Strips characters that are not valid in XML element/attribute names. + * Only allows alphanumeric characters and underscores. + */ +export function sanitizeXmlKey(s: string): string { + return s.replace(/[^a-zA-Z0-9_]/g, ''); +} + export function safeTemplateReplace( template: string, replacements: Record,