diff --git a/eslint.config.js b/eslint.config.js index a0a0429119..d3a267f30a 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -35,11 +35,6 @@ const commonRestrictedSyntaxRules = [ message: 'Do not throw string literals or non-Error objects. Throw new Error("...") instead.', }, - { - selector: 'CallExpression[callee.name="fetch"]', - message: - 'Use safeFetch() from "@/utils/fetch" instead of the global fetch() to ensure SSRF protection. If you are implementing a custom security layer, use an eslint-disable comment and explain why.', - }, ]; export default tseslint.config( diff --git a/packages/cli/src/ui/commands/setupGithubCommand.ts b/packages/cli/src/ui/commands/setupGithubCommand.ts index 2554ebaa60..c68dd5cb88 100644 --- a/packages/cli/src/ui/commands/setupGithubCommand.ts +++ b/packages/cli/src/ui/commands/setupGithubCommand.ts @@ -123,7 +123,6 @@ async function downloadFiles({ downloads.push( (async () => { const endpoint = `${REPO_DOWNLOAD_URL}/refs/tags/${releaseTag}/${SOURCE_DIR}/${fileBasename}`; - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(endpoint, { method: 'GET', dispatcher: proxy ? new ProxyAgent(proxy) : undefined, diff --git a/packages/cli/src/utils/gitUtils.ts b/packages/cli/src/utils/gitUtils.ts index 83d89ad164..e27673f0fe 100644 --- a/packages/cli/src/utils/gitUtils.ts +++ b/packages/cli/src/utils/gitUtils.ts @@ -61,7 +61,6 @@ export const getLatestGitHubRelease = async ( const endpoint = `https://api.github.com/repos/google-github-actions/run-gemini-cli/releases/latest`; - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(endpoint, { method: 'GET', headers: { diff --git a/packages/core/src/agents/a2a-client-manager.test.ts b/packages/core/src/agents/a2a-client-manager.test.ts index aab0de5506..0a0aa4d956 100644 --- a/packages/core/src/agents/a2a-client-manager.test.ts +++ b/packages/core/src/agents/a2a-client-manager.test.ts @@ -5,11 +5,8 @@ */ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { - A2AClientManager, - type SendMessageResult, -} from './a2a-client-manager.js'; -import type { AgentCard, Task } from '@a2a-js/sdk'; +import { A2AClientManager } from './a2a-client-manager.js'; +import type { AgentCard } from '@a2a-js/sdk'; import { ClientFactory, DefaultAgentCardResolver, @@ -22,81 +19,95 @@ import type { Config } from '../config/config.js'; import { Agent as UndiciAgent, ProxyAgent } from 'undici'; import { debugLogger } from '../utils/debugLogger.js'; +interface MockClient { + sendMessageStream: ReturnType; + getTask: ReturnType; + cancelTask: ReturnType; +} + +vi.mock('@a2a-js/sdk/client', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...(actual as Record), + createAuthenticatingFetchWithRetry: vi.fn(), + ClientFactory: vi.fn(), + DefaultAgentCardResolver: vi.fn(), + ClientFactoryOptions: { + createFrom: vi.fn(), + default: {}, + }, + }; +}); + vi.mock('../utils/debugLogger.js', () => ({ debugLogger: { debug: vi.fn(), }, })); -vi.mock('@a2a-js/sdk/client', () => { - const ClientFactory = vi.fn(); - const DefaultAgentCardResolver = vi.fn(); - const RestTransportFactory = vi.fn(); - const JsonRpcTransportFactory = vi.fn(); - const ClientFactoryOptions = { - default: {}, - createFrom: vi.fn(), - }; - const createAuthenticatingFetchWithRetry = vi.fn(); - - DefaultAgentCardResolver.prototype.resolve = vi.fn(); - ClientFactory.prototype.createFromUrl = vi.fn(); - - return { - ClientFactory, - ClientFactoryOptions, - DefaultAgentCardResolver, - RestTransportFactory, - JsonRpcTransportFactory, - createAuthenticatingFetchWithRetry, - }; -}); - describe('A2AClientManager', () => { let manager: A2AClientManager; + const mockAgentCard: AgentCard = { + name: 'test-agent', + description: 'A test agent', + url: 'http://test.agent', + version: '1.0.0', + protocolVersion: '0.1.0', + capabilities: {}, + skills: [], + defaultInputModes: [], + defaultOutputModes: [], + }; + + const mockClient: MockClient = { + sendMessageStream: vi.fn(), + getTask: vi.fn(), + cancelTask: vi.fn(), + }; - // Stable mocks initialized once - const sendMessageStreamMock = vi.fn(); - const getTaskMock = vi.fn(); - const cancelTaskMock = vi.fn(); - const getAgentCardMock = vi.fn(); const authFetchMock = vi.fn(); - const mockClient = { - sendMessageStream: sendMessageStreamMock, - getTask: getTaskMock, - cancelTask: cancelTaskMock, - getAgentCard: getAgentCardMock, - } as unknown as Client; - - const mockAgentCard: Partial = { name: 'TestAgent' }; - beforeEach(() => { vi.clearAllMocks(); A2AClientManager.resetInstanceForTesting(); manager = A2AClientManager.getInstance(); - // Default mock implementations - getAgentCardMock.mockResolvedValue({ + // Re-create the instances as plain objects that can be spied on + const factoryInstance = { + createFromUrl: vi.fn(), + createFromAgentCard: vi.fn(), + }; + const resolverInstance = { + resolve: vi.fn(), + }; + + vi.mocked(ClientFactory).mockReturnValue( + factoryInstance as unknown as ClientFactory, + ); + vi.mocked(DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as DefaultAgentCardResolver, + ); + + vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue( + mockClient as unknown as Client, + ); + vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue( + mockClient as unknown as Client, + ); + vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({ ...mockAgentCard, url: 'http://test.agent/real/endpoint', } as AgentCard); - vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue( - mockClient, + vi.spyOn(ClientFactoryOptions, 'createFrom').mockImplementation( + (_defaults, overrides) => overrides as unknown as ClientFactoryOptions, ); - vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({ - ...mockAgentCard, - url: 'http://test.agent/real/endpoint', - } as AgentCard); - - vi.mocked(ClientFactoryOptions.createFrom).mockImplementation( - (_defaults, overrides) => overrides as ClientFactoryOptions, - ); - - vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue( - authFetchMock, + vi.mocked(createAuthenticatingFetchWithRetry).mockImplementation(() => + authFetchMock.mockResolvedValue({ + ok: true, + json: async () => ({}), + } as Response), ); vi.stubGlobal( @@ -170,15 +181,19 @@ describe('A2AClientManager', () => { 'TestAgent', 'http://test.agent/card', ); - expect(agentCard).toMatchObject(mockAgentCard); expect(manager.getAgentCard('TestAgent')).toBe(agentCard); expect(manager.getClient('TestAgent')).toBeDefined(); }); + it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => { + await manager.loadAgent('TestAgent', 'http://test.agent/card'); + expect(ClientFactoryOptions.createFrom).toHaveBeenCalled(); + }); + it('should throw an error if an agent with the same name is already loaded', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); await expect( - manager.loadAgent('TestAgent', 'http://another.agent/card'), + manager.loadAgent('TestAgent', 'http://test.agent/card'), ).rejects.toThrow("Agent with name 'TestAgent' is already loaded."); }); @@ -193,20 +208,12 @@ describe('A2AClientManager', () => { shouldRetryWithHeaders: vi.fn(), }; await manager.loadAgent( - 'CustomAuthAgent', - 'http://custom.agent/card', + 'TestAgent', + 'http://test.agent/card', customAuthHandler as unknown as AuthenticationHandler, ); - expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith( - expect.anything(), - customAuthHandler, - ); - // Card resolver should NOT use the authenticated fetch by default. - const resolverInstance = vi.mocked(DefaultAgentCardResolver).mock - .instances[0]; - expect(resolverInstance).toBeDefined(); const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock .calls[0][0]; expect(resolverOptions?.fetchImpl).not.toBe(authFetchMock); @@ -267,106 +274,163 @@ describe('A2AClientManager', () => { it('should log a debug message upon loading an agent', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); expect(debugLogger.debug).toHaveBeenCalledWith( - "[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card", + expect.stringContaining("Loaded agent 'TestAgent'"), ); }); it('should clear the cache', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); - expect(manager.getAgentCard('TestAgent')).toBeDefined(); - expect(manager.getClient('TestAgent')).toBeDefined(); - manager.clearCache(); - expect(manager.getAgentCard('TestAgent')).toBeUndefined(); expect(manager.getClient('TestAgent')).toBeUndefined(); - expect(debugLogger.debug).toHaveBeenCalledWith( - '[A2AClientManager] Cache cleared.', + }); + + it('should throw if resolveAgentCard fails', async () => { + const resolverInstance = { + resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')), + }; + vi.mocked(DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as DefaultAgentCardResolver, ); + + await expect( + manager.loadAgent('FailAgent', 'http://fail.agent'), + ).rejects.toThrow('Resolution failed'); + }); + + it('should throw if factory.createFromAgentCard fails', async () => { + const factoryInstance = { + createFromAgentCard: vi + .fn() + .mockRejectedValue(new Error('Factory failed')), + }; + vi.mocked(ClientFactory).mockReturnValue( + factoryInstance as unknown as ClientFactory, + ); + + await expect( + manager.loadAgent('FailAgent', 'http://fail.agent'), + ).rejects.toThrow('Factory failed'); + }); + }); + + describe('getAgentCard and getClient', () => { + it('should return undefined if agent is not found', () => { + expect(manager.getAgentCard('Unknown')).toBeUndefined(); + expect(manager.getClient('Unknown')).toBeUndefined(); }); }); describe('sendMessageStream', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should send a message and return a stream', async () => { - const mockResult = { - kind: 'message', - messageId: 'a', - parts: [], - role: 'agent', - } as SendMessageResult; - - sendMessageStreamMock.mockReturnValue( + mockClient.sendMessageStream.mockReturnValue( (async function* () { - yield mockResult; + yield { kind: 'message' }; })(), ); const stream = manager.sendMessageStream('TestAgent', 'Hello'); const results = []; - for await (const res of stream) { - results.push(res); + for await (const result of stream) { + results.push(result); } - expect(results).toEqual([mockResult]); - expect(sendMessageStreamMock).toHaveBeenCalledWith( + expect(results).toHaveLength(1); + expect(mockClient.sendMessageStream).toHaveBeenCalled(); + }); + + it('should use contextId and taskId when provided', async () => { + mockClient.sendMessageStream.mockReturnValue( + (async function* () { + yield { kind: 'message' }; + })(), + ); + + const stream = manager.sendMessageStream('TestAgent', 'Hello', { + contextId: 'ctx123', + taskId: 'task456', + }); + // trigger execution + for await (const _ of stream) { + break; + } + + expect(mockClient.sendMessageStream).toHaveBeenCalledWith( expect.objectContaining({ - message: expect.anything(), + message: expect.objectContaining({ + contextId: 'ctx123', + taskId: 'task456', + }), }), expect.any(Object), ); }); - it('should use contextId and taskId when provided', async () => { - sendMessageStreamMock.mockReturnValue( + it('should correctly propagate AbortSignal to the stream', async () => { + mockClient.sendMessageStream.mockReturnValue( (async function* () { - yield { - kind: 'message', - messageId: 'a', - parts: [], - role: 'agent', - } as SendMessageResult; + yield { kind: 'message' }; })(), ); - const expectedContextId = 'user-context-id'; - const expectedTaskId = 'user-task-id'; - + const controller = new AbortController(); const stream = manager.sendMessageStream('TestAgent', 'Hello', { - contextId: expectedContextId, - taskId: expectedTaskId, + signal: controller.signal, }); - + // trigger execution for await (const _ of stream) { - // consume stream + break; } - const call = sendMessageStreamMock.mock.calls[0][0]; - expect(call.message.contextId).toBe(expectedContextId); - expect(call.message.taskId).toBe(expectedTaskId); + expect(mockClient.sendMessageStream).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ signal: controller.signal }), + ); }); - it('should propagate the original error on failure', async () => { - sendMessageStreamMock.mockImplementationOnce(() => { - throw new Error('Network error'); + it('should handle a multi-chunk stream with different event types', async () => { + mockClient.sendMessageStream.mockReturnValue( + (async function* () { + yield { kind: 'message', messageId: 'm1' }; + yield { kind: 'status-update', taskId: 't1' }; + })(), + ); + + const stream = manager.sendMessageStream('TestAgent', 'Hello'); + const results = []; + for await (const result of stream) { + results.push(result); + } + + expect(results).toHaveLength(2); + expect(results[0].kind).toBe('message'); + expect(results[1].kind).toBe('status-update'); + }); + + it('should throw prefixed error on failure', async () => { + mockClient.sendMessageStream.mockImplementation(() => { + throw new Error('Network failure'); }); const stream = manager.sendMessageStream('TestAgent', 'Hello'); await expect(async () => { for await (const _ of stream) { - // consume + // empty } - }).rejects.toThrow('Network error'); + }).rejects.toThrow( + '[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure', + ); }); it('should throw an error if the agent is not found', async () => { const stream = manager.sendMessageStream('NonExistentAgent', 'Hello'); await expect(async () => { for await (const _ of stream) { - // consume + // empty } }).rejects.toThrow("Agent 'NonExistentAgent' not found."); }); @@ -374,28 +438,23 @@ describe('A2AClientManager', () => { describe('getTask', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should get a task from the correct agent', async () => { - getTaskMock.mockResolvedValue({ - id: 'task123', - contextId: 'a', - kind: 'task', - status: { state: 'completed' }, - } as Task); + const mockTask = { id: 'task123', kind: 'task' }; + mockClient.getTask.mockResolvedValue(mockTask); - await manager.getTask('TestAgent', 'task123'); - expect(getTaskMock).toHaveBeenCalledWith({ - id: 'task123', - }); + const result = await manager.getTask('TestAgent', 'task123'); + expect(result).toBe(mockTask); + expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' }); }); it('should throw prefixed error on failure', async () => { - getTaskMock.mockRejectedValueOnce(new Error('Network error')); + mockClient.getTask.mockRejectedValue(new Error('Not found')); await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow( - 'A2AClient getTask Error [TestAgent]: Network error', + 'A2AClient getTask Error [TestAgent]: Not found', ); }); @@ -408,28 +467,23 @@ describe('A2AClientManager', () => { describe('cancelTask', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should cancel a task on the correct agent', async () => { - cancelTaskMock.mockResolvedValue({ - id: 'task123', - contextId: 'a', - kind: 'task', - status: { state: 'canceled' }, - } as Task); + const mockTask = { id: 'task123', kind: 'task' }; + mockClient.cancelTask.mockResolvedValue(mockTask); - await manager.cancelTask('TestAgent', 'task123'); - expect(cancelTaskMock).toHaveBeenCalledWith({ - id: 'task123', - }); + const result = await manager.cancelTask('TestAgent', 'task123'); + expect(result).toBe(mockTask); + expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' }); }); it('should throw prefixed error on failure', async () => { - cancelTaskMock.mockRejectedValueOnce(new Error('Network error')); + mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel')); await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow( - 'A2AClient cancelTask Error [TestAgent]: Network error', + 'A2AClient cancelTask Error [TestAgent]: Cannot cancel', ); }); diff --git a/packages/core/src/agents/a2a-client-manager.ts b/packages/core/src/agents/a2a-client-manager.ts index 7d558e7dbe..3a03c033d8 100644 --- a/packages/core/src/agents/a2a-client-manager.ts +++ b/packages/core/src/agents/a2a-client-manager.ts @@ -12,36 +12,41 @@ import type { TaskStatusUpdateEvent, TaskArtifactUpdateEvent, } from '@a2a-js/sdk'; +import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client'; import { - type Client, ClientFactory, ClientFactoryOptions, DefaultAgentCardResolver, - RestTransportFactory, JsonRpcTransportFactory, - type AuthenticationHandler, + RestTransportFactory, createAuthenticatingFetchWithRetry, } from '@a2a-js/sdk/client'; +import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc'; +import * as grpc from '@grpc/grpc-js'; import { v4 as uuidv4 } from 'uuid'; import { Agent as UndiciAgent, ProxyAgent } from 'undici'; +import { normalizeAgentCard } from './a2aUtils.js'; import type { Config } from '../config/config.js'; import { debugLogger } from '../utils/debugLogger.js'; -import { safeLookup } from '../utils/fetch.js'; import { classifyAgentError } from './a2a-errors.js'; -// Remote agents can take 10+ minutes (e.g. Deep Research). -// Use a dedicated dispatcher so the global 5-min timeout isn't affected. -const A2A_TIMEOUT = 1800000; // 30 minutes - +/** + * Result of sending a message, which can be a full message, a task, + * or an incremental status/artifact update. + */ export type SendMessageResult = | Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent; +// Remote agents can take 10+ minutes (e.g. Deep Research). +// Use a dedicated dispatcher so the global 5-min timeout isn't affected. +const A2A_TIMEOUT = 1800000; // 30 minutes + /** - * Manages A2A clients and caches loaded agent information. - * Follows a singleton pattern to ensure a single client instance. + * Orchestrates communication with remote A2A agents. + * Manages protocol negotiation, authentication, and transport selection. */ export class A2AClientManager { private static instance: A2AClientManager; @@ -58,9 +63,6 @@ export class A2AClientManager { const agentOptions = { headersTimeout: A2A_TIMEOUT, bodyTimeout: A2A_TIMEOUT, - connect: { - lookup: safeLookup, // SSRF protection at connection level - }, }; if (proxyUrl) { @@ -73,7 +75,6 @@ export class A2AClientManager { } this.a2aFetch = (input, init) => - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection fetch(input, { ...init, dispatcher: this.a2aDispatcher } as RequestInit); } @@ -139,22 +140,35 @@ export class A2AClientManager { }; const resolver = new DefaultAgentCardResolver({ fetchImpl: cardFetch }); + const rawCard = await resolver.resolve(agentCardUrl, ''); + // TODO: Remove normalizeAgentCard once @a2a-js/sdk handles + // proto field name aliases (supportedInterfaces → additionalInterfaces, + // protocolBinding → transport). + const agentCard = normalizeAgentCard(rawCard); - const options = ClientFactoryOptions.createFrom( + const grpcUrl = + agentCard.additionalInterfaces?.find((i) => i.transport === 'GRPC') + ?.url ?? agentCard.url; + + const clientOptions = ClientFactoryOptions.createFrom( ClientFactoryOptions.default, { transports: [ new RestTransportFactory({ fetchImpl: authFetch }), new JsonRpcTransportFactory({ fetchImpl: authFetch }), + new GrpcTransportFactory({ + grpcChannelCredentials: grpcUrl.startsWith('https://') + ? grpc.credentials.createSsl() + : grpc.credentials.createInsecure(), + }), ], cardResolver: resolver, }, ); try { - const factory = new ClientFactory(options); - const client = await factory.createFromUrl(agentCardUrl, ''); - const agentCard = await client.getAgentCard(); + const factory = new ClientFactory(clientOptions); + const client = await factory.createFromAgentCard(agentCard); this.clients.set(name, client); this.agentCards.set(name, agentCard); @@ -192,9 +206,7 @@ export class A2AClientManager { options?: { contextId?: string; taskId?: string; signal?: AbortSignal }, ): AsyncIterable { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); const messageParams: MessageSendParams = { message: { @@ -207,9 +219,19 @@ export class A2AClientManager { }, }; - yield* client.sendMessageStream(messageParams, { - signal: options?.signal, - }); + try { + yield* client.sendMessageStream(messageParams, { + signal: options?.signal, + }); + } catch (error: unknown) { + const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`; + if (error instanceof Error) { + throw new Error(`${prefix}: ${error.message}`, { cause: error }); + } + throw new Error( + `${prefix}: Unexpected error during sendMessageStream: ${String(error)}`, + ); + } } /** @@ -238,9 +260,7 @@ export class A2AClientManager { */ async getTask(agentName: string, taskId: string): Promise { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); try { return await client.getTask({ id: taskId }); } catch (error: unknown) { @@ -260,9 +280,7 @@ export class A2AClientManager { */ async cancelTask(agentName: string, taskId: string): Promise { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); try { return await client.cancelTask({ id: taskId }); } catch (error: unknown) { diff --git a/packages/core/src/agents/a2aUtils.test.ts b/packages/core/src/agents/a2aUtils.test.ts index c3fe170aa5..0dce551be4 100644 --- a/packages/core/src/agents/a2aUtils.test.ts +++ b/packages/core/src/agents/a2aUtils.test.ts @@ -12,9 +12,6 @@ import { A2AResultReassembler, AUTH_REQUIRED_MSG, normalizeAgentCard, - getGrpcCredentials, - pinUrlToIp, - splitAgentCardUrl, } from './a2aUtils.js'; import type { SendMessageResult } from './a2a-client-manager.js'; import type { @@ -26,12 +23,6 @@ import type { TaskStatusUpdateEvent, TaskArtifactUpdateEvent, } from '@a2a-js/sdk'; -import * as dnsPromises from 'node:dns/promises'; -import type { LookupAddress } from 'node:dns'; - -vi.mock('node:dns/promises', () => ({ - lookup: vi.fn(), -})); describe('a2aUtils', () => { beforeEach(() => { @@ -42,89 +33,6 @@ describe('a2aUtils', () => { vi.restoreAllMocks(); }); - describe('getGrpcCredentials', () => { - it('should return secure credentials for https', () => { - const credentials = getGrpcCredentials('https://test.agent'); - expect(credentials).toBeDefined(); - }); - - it('should return insecure credentials for http', () => { - const credentials = getGrpcCredentials('http://test.agent'); - expect(credentials).toBeDefined(); - }); - }); - - describe('pinUrlToIp', () => { - it('should resolve and pin hostname to IP', async () => { - vi.mocked( - dnsPromises.lookup as unknown as ( - hostname: string, - options: { all: true }, - ) => Promise, - ).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]); - - const { pinnedUrl, hostname } = await pinUrlToIp( - 'http://example.com:9000', - 'test-agent', - ); - expect(hostname).toBe('example.com'); - expect(pinnedUrl).toBe('http://93.184.216.34:9000/'); - }); - - it('should handle raw host:port strings (standard for gRPC)', async () => { - vi.mocked( - dnsPromises.lookup as unknown as ( - hostname: string, - options: { all: true }, - ) => Promise, - ).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]); - - const { pinnedUrl, hostname } = await pinUrlToIp( - 'example.com:9000', - 'test-agent', - ); - expect(hostname).toBe('example.com'); - expect(pinnedUrl).toBe('93.184.216.34:9000'); - }); - - it('should throw error if resolution fails (fail closed)', async () => { - vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error')); - - await expect( - pinUrlToIp('http://unreachable.com', 'test-agent'), - ).rejects.toThrow("Failed to resolve host for agent 'test-agent'"); - }); - - it('should throw error if resolved to private IP', async () => { - vi.mocked( - dnsPromises.lookup as unknown as ( - hostname: string, - options: { all: true }, - ) => Promise, - ).mockResolvedValue([{ address: '10.0.0.1', family: 4 }]); - - await expect( - pinUrlToIp('http://malicious.com', 'test-agent'), - ).rejects.toThrow('resolves to private IP range'); - }); - - it('should allow localhost/127.0.0.1/::1 exceptions', async () => { - vi.mocked( - dnsPromises.lookup as unknown as ( - hostname: string, - options: { all: true }, - ) => Promise, - ).mockResolvedValue([{ address: '127.0.0.1', family: 4 }]); - - const { pinnedUrl, hostname } = await pinUrlToIp( - 'http://localhost:9000', - 'test-agent', - ); - expect(hostname).toBe('localhost'); - expect(pinnedUrl).toBe('http://127.0.0.1:9000/'); - }); - }); - describe('isTerminalState', () => { it('should return true for completed, failed, canceled, and rejected', () => { expect(isTerminalState('completed')).toBe(true); @@ -365,12 +273,12 @@ describe('a2aUtils', () => { expect(normalized.name).toBe('my-agent'); // @ts-expect-error - testing dynamic preservation expect(normalized.customField).toBe('keep-me'); - expect(normalized.description).toBe(''); - expect(normalized.skills).toEqual([]); - expect(normalized.defaultInputModes).toEqual([]); + expect(normalized.description).toBeUndefined(); + expect(normalized.skills).toBeUndefined(); + expect(normalized.defaultInputModes).toBeUndefined(); }); - it('should normalize and synchronize interfaces while preserving other fields', () => { + it('should map supportedInterfaces to additionalInterfaces with protocolBinding → transport', () => { const raw = { name: 'test', supportedInterfaces: [ @@ -384,13 +292,7 @@ describe('a2aUtils', () => { const normalized = normalizeAgentCard(raw); - // Should exist in both fields expect(normalized.additionalInterfaces).toHaveLength(1); - expect( - (normalized as unknown as Record)[ - 'supportedInterfaces' - ], - ).toHaveLength(1); const intf = normalized.additionalInterfaces?.[0] as unknown as Record< string, @@ -399,43 +301,18 @@ describe('a2aUtils', () => { expect(intf['transport']).toBe('GRPC'); expect(intf['url']).toBe('grpc://test'); - - // Should fallback top-level url - expect(normalized.url).toBe('grpc://test'); }); - it('should preserve existing top-level url if present', () => { + it('should not overwrite additionalInterfaces if already present', () => { const raw = { name: 'test', - url: 'http://existing', + additionalInterfaces: [{ url: 'http://grpc', transport: 'GRPC' }], supportedInterfaces: [{ url: 'http://other', transport: 'REST' }], }; const normalized = normalizeAgentCard(raw); - expect(normalized.url).toBe('http://existing'); - }); - - it('should NOT prepend http:// scheme to raw IP:port strings for gRPC interfaces', () => { - const raw = { - name: 'raw-ip-grpc', - supportedInterfaces: [{ url: '127.0.0.1:9000', transport: 'GRPC' }], - }; - - const normalized = normalizeAgentCard(raw); - expect(normalized.additionalInterfaces?.[0].url).toBe('127.0.0.1:9000'); - expect(normalized.url).toBe('127.0.0.1:9000'); - }); - - it('should prepend http:// scheme to raw IP:port strings for REST interfaces', () => { - const raw = { - name: 'raw-ip-rest', - supportedInterfaces: [{ url: '127.0.0.1:8080', transport: 'REST' }], - }; - - const normalized = normalizeAgentCard(raw); - expect(normalized.additionalInterfaces?.[0].url).toBe( - 'http://127.0.0.1:8080', - ); + expect(normalized.additionalInterfaces).toHaveLength(1); + expect(normalized.additionalInterfaces?.[0].url).toBe('http://grpc'); }); it('should NOT override existing transport if protocolBinding is also present', () => { @@ -448,48 +325,20 @@ describe('a2aUtils', () => { const normalized = normalizeAgentCard(raw); expect(normalized.additionalInterfaces?.[0].transport).toBe('GRPC'); }); - }); - describe('splitAgentCardUrl', () => { - const standard = '.well-known/agent-card.json'; + it('should not mutate the original card object', () => { + const raw = { + name: 'test', + supportedInterfaces: [{ url: 'grpc://test', protocolBinding: 'GRPC' }], + }; - it('should return baseUrl as-is if it does not end with standard path', () => { - const url = 'http://localhost:9001/custom/path'; - expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url }); - }); - - it('should split correctly if URL ends with standard path', () => { - const url = `http://localhost:9001/${standard}`; - expect(splitAgentCardUrl(url)).toEqual({ - baseUrl: 'http://localhost:9001/', - path: undefined, - }); - }); - - it('should handle trailing slash in baseUrl when splitting', () => { - const url = `http://example.com/api/${standard}`; - expect(splitAgentCardUrl(url)).toEqual({ - baseUrl: 'http://example.com/api/', - path: undefined, - }); - }); - - it('should ignore hashes and query params when splitting', () => { - const url = `http://localhost:9001/${standard}?foo=bar#baz`; - expect(splitAgentCardUrl(url)).toEqual({ - baseUrl: 'http://localhost:9001/', - path: undefined, - }); - }); - - it('should return original URL if parsing fails', () => { - const url = 'not-a-url'; - expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url }); - }); - - it('should handle standard path appearing earlier in the path', () => { - const url = `http://localhost:9001/${standard}/something-else`; - expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url }); + const normalized = normalizeAgentCard(raw); + expect(normalized).not.toBe(raw); + expect(normalized.additionalInterfaces).toBeDefined(); + // Original should not have additionalInterfaces added + expect( + (raw as Record)['additionalInterfaces'], + ).toBeUndefined(); }); }); diff --git a/packages/core/src/agents/a2aUtils.ts b/packages/core/src/agents/a2aUtils.ts index ec8b36bba1..70fc9cf557 100644 --- a/packages/core/src/agents/a2aUtils.ts +++ b/packages/core/src/agents/a2aUtils.ts @@ -4,9 +4,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -import * as grpc from '@grpc/grpc-js'; -import { lookup } from 'node:dns/promises'; -import { z } from 'zod'; import type { Message, Part, @@ -18,37 +15,10 @@ import type { AgentCard, AgentInterface, } from '@a2a-js/sdk'; -import { isAddressPrivate } from '../utils/fetch.js'; import type { SendMessageResult } from './a2a-client-manager.js'; export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`; -const AgentInterfaceSchema = z - .object({ - url: z.string().default(''), - transport: z.string().optional(), - protocolBinding: z.string().optional(), - }) - .passthrough(); - -const AgentCardSchema = z - .object({ - name: z.string().default('unknown'), - description: z.string().default(''), - url: z.string().default(''), - version: z.string().default(''), - protocolVersion: z.string().default(''), - capabilities: z.record(z.unknown()).default({}), - skills: z.array(z.union([z.string(), z.record(z.unknown())])).default([]), - defaultInputModes: z.array(z.string()).default([]), - defaultOutputModes: z.array(z.string()).default([]), - - additionalInterfaces: z.array(AgentInterfaceSchema).optional(), - supportedInterfaces: z.array(AgentInterfaceSchema).optional(), - preferredTransport: z.string().optional(), - }) - .passthrough(); - /** * Reassembles incremental A2A streaming updates into a coherent result. * Shows sequential status/messages followed by all reassembled artifacts. @@ -241,166 +211,45 @@ function extractPartText(part: Part): string { } /** - * Normalizes an agent card by ensuring it has the required properties - * and resolving any inconsistencies between protocol versions. + * Normalizes proto field name aliases that the SDK doesn't handle yet. + * The A2A proto spec uses `supported_interfaces` and `protocol_binding`, + * while the SDK expects `additionalInterfaces` and `transport`. + * TODO: Remove once @a2a-js/sdk handles these aliases natively. */ export function normalizeAgentCard(card: unknown): AgentCard { if (!isObject(card)) { throw new Error('Agent card is missing.'); } - // Use Zod to validate and parse the card, ensuring safe defaults and narrowing types. - const parsed = AgentCardSchema.parse(card); - // Narrowing to AgentCard interface after runtime validation. + // Shallow-copy to avoid mutating the SDK's cached object. // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const result = parsed as unknown as AgentCard; + const result = { ...card } as unknown as AgentCard; - // Normalize interfaces and synchronize both interface fields. - const normalizedInterfaces = extractNormalizedInterfaces(parsed); - result.additionalInterfaces = normalizedInterfaces; - - // Sync supportedInterfaces for backward compatibility. - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const legacyResult = result as unknown as Record; - legacyResult['supportedInterfaces'] = normalizedInterfaces; - - // Fallback preferredTransport: If not specified, default to GRPC if available. - if ( - !result.preferredTransport && - normalizedInterfaces.some((i) => i.transport === 'GRPC') - ) { - result.preferredTransport = 'GRPC'; + // Map supportedInterfaces → additionalInterfaces if needed + if (!result.additionalInterfaces) { + const raw = card; + if (Array.isArray(raw['supportedInterfaces'])) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + result.additionalInterfaces = raw[ + 'supportedInterfaces' + ] as AgentInterface[]; + } } - // Fallback: If top-level URL is missing, use the first interface's URL. - if (result.url === '' && normalizedInterfaces.length > 0) { - result.url = normalizedInterfaces[0].url; + // Map protocolBinding → transport on each interface + for (const intf of result.additionalInterfaces ?? []) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const raw = intf as unknown as Record; + const binding = raw['protocolBinding']; + + if (!intf.transport && typeof binding === 'string') { + intf.transport = binding; + } } return result; } -/** - * Returns gRPC channel credentials based on the URL scheme. - */ -export function getGrpcCredentials(url: string): grpc.ChannelCredentials { - return url.startsWith('https://') - ? grpc.credentials.createSsl() - : grpc.credentials.createInsecure(); -} - -/** - * Returns gRPC channel options to ensure SSL/authority matches the original hostname - * when connecting via a pinned IP address. - */ -export function getGrpcChannelOptions( - hostname: string, -): Record { - return { - 'grpc.default_authority': hostname, - 'grpc.ssl_target_name_override': hostname, - }; -} - -/** - * Resolves a hostname to its IP address and validates it against SSRF. - * Returns the pinned IP-based URL and the original hostname. - */ -export async function pinUrlToIp( - url: string, - agentName: string, -): Promise<{ pinnedUrl: string; hostname: string }> { - if (!url) return { pinnedUrl: url, hostname: '' }; - - // gRPC URLs in A2A can be 'host:port' or 'dns:///host:port' or have schemes. - // We normalize to host:port for resolution. - const hasScheme = url.includes('://'); - const normalizedUrl = hasScheme ? url : `http://${url}`; - - try { - const parsed = new URL(normalizedUrl); - const hostname = parsed.hostname; - - const sanitizedHost = - hostname.startsWith('[') && hostname.endsWith(']') - ? hostname.slice(1, -1) - : hostname; - - // Resolve DNS to check the actual target IP and pin it - const addresses = await lookup(hostname, { all: true }); - const publicAddresses = addresses.filter( - (addr) => - !isAddressPrivate(addr.address) || - sanitizedHost === 'localhost' || - sanitizedHost === '127.0.0.1' || - sanitizedHost === '::1', - ); - - if (publicAddresses.length === 0) { - if (addresses.length > 0) { - throw new Error( - `Refusing to load agent '${agentName}': transport URL '${url}' resolves to private IP range.`, - ); - } - throw new Error( - `Failed to resolve any public IP addresses for host: ${hostname}`, - ); - } - - const pinnedIp = publicAddresses[0].address; - const pinnedHostname = pinnedIp.includes(':') ? `[${pinnedIp}]` : pinnedIp; - - // Reconstruct URL with IP - parsed.hostname = pinnedHostname; - let pinnedUrl = parsed.toString(); - - // If original didn't have scheme, remove it (standard for gRPC targets) - if (!hasScheme) { - pinnedUrl = pinnedUrl.replace(/^http:\/\//, ''); - // URL.toString() might append a trailing slash - if (pinnedUrl.endsWith('/') && !url.endsWith('/')) { - pinnedUrl = pinnedUrl.slice(0, -1); - } - } - - return { pinnedUrl, hostname }; - } catch (e) { - if (e instanceof Error && e.message.includes('Refusing')) throw e; - throw new Error(`Failed to resolve host for agent '${agentName}': ${url}`, { - cause: e, - }); - } -} - -/** - * Splts an agent card URL into a baseUrl and a standard path if it already - * contains '.well-known/agent-card.json'. - */ -export function splitAgentCardUrl(url: string): { - baseUrl: string; - path?: string; -} { - const standardPath = '.well-known/agent-card.json'; - try { - const parsedUrl = new URL(url); - if (parsedUrl.pathname.endsWith(standardPath)) { - // Reconstruct baseUrl from parsed components to avoid issues with hashes or query params. - parsedUrl.pathname = parsedUrl.pathname.substring( - 0, - parsedUrl.pathname.lastIndexOf(standardPath), - ); - parsedUrl.search = ''; - parsedUrl.hash = ''; - // We return undefined for path if it's the standard one, - // because the SDK's DefaultAgentCardResolver appends it automatically. - return { baseUrl: parsedUrl.toString(), path: undefined }; - } - } catch (_e) { - // Ignore URL parsing errors here, let the resolver handle them. - } - return { baseUrl: url }; -} - /** * Extracts contextId and taskId from a Message, Task, or Update response. * Follows the pattern from the A2A CLI sample to maintain conversational continuity. @@ -446,65 +295,6 @@ export function extractIdsFromResponse(result: SendMessageResult): { return { contextId, taskId, clearTaskId }; } -/** - * Extracts and normalizes interfaces from the card, handling protocol version fallbacks. - * Preserves all original fields to maintain SDK compatibility. - */ -function extractNormalizedInterfaces( - card: Record, -): AgentInterface[] { - const rawInterfaces = - getArray(card, 'additionalInterfaces') || - getArray(card, 'supportedInterfaces'); - - if (!rawInterfaces) { - return []; - } - - const mapped: AgentInterface[] = []; - for (const i of rawInterfaces) { - if (isObject(i)) { - // Use schema to validate interface object. - const parsed = AgentInterfaceSchema.parse(i); - // Narrowing to AgentInterface after runtime validation. - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const normalized = parsed as unknown as AgentInterface & { - protocolBinding?: string; - }; - - // Normalize 'transport' from 'protocolBinding' if missing. - if (!normalized.transport && normalized.protocolBinding) { - normalized.transport = normalized.protocolBinding; - } - - // Robust URL: Ensure the URL has a scheme (except for gRPC). - if ( - normalized.url && - !normalized.url.includes('://') && - !normalized.url.startsWith('/') && - normalized.transport !== 'GRPC' - ) { - // Default to http:// for insecure REST/JSON-RPC if scheme is missing. - normalized.url = `http://${normalized.url}`; - } - - mapped.push(normalized as AgentInterface); - } - } - return mapped; -} - -/** - * Safely extracts an array property from an object. - */ -function getArray( - obj: Record, - key: string, -): unknown[] | undefined { - const val = obj[key]; - return Array.isArray(val) ? val : undefined; -} - // Type Guards function isTextPart(part: Part): part is TextPart { diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index 654ba0e10a..e238a4a860 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -700,7 +700,6 @@ async function fetchAndCacheUserInfo(client: OAuth2Client): Promise { return; } - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch( 'https://www.googleapis.com/oauth2/v2/userinfo', { diff --git a/packages/core/src/mcp/oauth-provider.ts b/packages/core/src/mcp/oauth-provider.ts index 01934d9019..6aaafa6054 100644 --- a/packages/core/src/mcp/oauth-provider.ts +++ b/packages/core/src/mcp/oauth-provider.ts @@ -111,7 +111,6 @@ export class MCPOAuthProvider { scope: config.scopes?.join(' ') || '', }; - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(registrationUrl, { method: 'POST', headers: { @@ -301,7 +300,6 @@ export class MCPOAuthProvider { ? { Accept: 'text/event-stream' } : { Accept: 'application/json' }; - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(mcpServerUrl, { method: 'HEAD', headers, diff --git a/packages/core/src/mcp/oauth-utils.ts b/packages/core/src/mcp/oauth-utils.ts index 207b694181..320c3b9685 100644 --- a/packages/core/src/mcp/oauth-utils.ts +++ b/packages/core/src/mcp/oauth-utils.ts @@ -97,7 +97,6 @@ export class OAuthUtils { resourceMetadataUrl: string, ): Promise { try { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(resourceMetadataUrl); if (!response.ok) { return null; @@ -122,7 +121,6 @@ export class OAuthUtils { authServerMetadataUrl: string, ): Promise { try { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(authServerMetadataUrl); if (!response.ok) { return null; diff --git a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts index 5953578eae..2f059030ca 100644 --- a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts +++ b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts @@ -546,7 +546,6 @@ export class ClearcutLogger { let result: LogResponse = {}; try { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(CLEARCUT_URL, { method: 'POST', body: safeJsonStringify(request), diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 7932e35f38..6dbae6dcde 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -1903,7 +1903,6 @@ export async function connectToMcpServer( acceptHeader = 'application/json'; } - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(urlToFetch, { method: 'HEAD', headers: { diff --git a/packages/core/src/utils/fetch.test.ts b/packages/core/src/utils/fetch.test.ts index 3eddefaf3d..4ac0c7b344 100644 --- a/packages/core/src/utils/fetch.test.ts +++ b/packages/core/src/utils/fetch.test.ts @@ -5,27 +5,12 @@ */ import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest'; -import { - isPrivateIp, - isPrivateIpAsync, - isAddressPrivate, - safeLookup, - safeFetch, - fetchWithTimeout, - PrivateIpError, -} from './fetch.js'; -import * as dnsPromises from 'node:dns/promises'; -import * as dns from 'node:dns'; +import { isPrivateIp, isAddressPrivate, fetchWithTimeout } from './fetch.js'; vi.mock('node:dns/promises', () => ({ lookup: vi.fn(), })); -// We need to mock node:dns for safeLookup since it uses the callback API -vi.mock('node:dns', () => ({ - lookup: vi.fn(), -})); - // Mock global fetch const originalFetch = global.fetch; global.fetch = vi.fn(); @@ -114,150 +99,6 @@ describe('fetch utils', () => { }); }); - describe('isPrivateIpAsync', () => { - it('should identify private IPs directly', async () => { - expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true); - }); - - it('should identify domains resolving to private IPs', async () => { - vi.mocked(dnsPromises.lookup).mockImplementation( - async () => - // eslint-disable-next-line @typescript-eslint/no-explicit-any - [{ address: '10.0.0.1', family: 4 }] as any, - ); - expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true); - }); - - it('should identify domains resolving to public IPs as non-private', async () => { - vi.mocked(dnsPromises.lookup).mockImplementation( - async () => - // eslint-disable-next-line @typescript-eslint/no-explicit-any - [{ address: '8.8.8.8', family: 4 }] as any, - ); - expect(await isPrivateIpAsync('http://google.com/')).toBe(false); - }); - - it('should throw error if DNS resolution fails (fail closed)', async () => { - vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error')); - await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow( - 'Failed to verify if URL resolves to private IP', - ); - }); - - it('should return false for invalid URLs instead of throwing verification error', async () => { - expect(await isPrivateIpAsync('not-a-url')).toBe(false); - }); - }); - - describe('safeLookup', () => { - it('should filter out private IPs', async () => { - const addresses = [ - { address: '8.8.8.8', family: 4 }, - { address: '10.0.0.1', family: 4 }, - ]; - - vi.mocked(dns.lookup).mockImplementation((( - _h: string, - _o: dns.LookupOptions, - cb: ( - err: Error | null, - addr: Array<{ address: string; family: number }>, - ) => void, - ) => { - cb(null, addresses); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - }) as any); - - const result = await new Promise< - Array<{ address: string; family: number }> - >((resolve, reject) => { - safeLookup('example.com', { all: true }, (err, filtered) => { - if (err) reject(err); - else resolve(filtered); - }); - }); - - expect(result).toHaveLength(1); - expect(result[0].address).toBe('8.8.8.8'); - }); - - it('should allow explicit localhost', async () => { - const addresses = [{ address: '127.0.0.1', family: 4 }]; - - vi.mocked(dns.lookup).mockImplementation((( - _h: string, - _o: dns.LookupOptions, - cb: ( - err: Error | null, - addr: Array<{ address: string; family: number }>, - ) => void, - ) => { - cb(null, addresses); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - }) as any); - - const result = await new Promise< - Array<{ address: string; family: number }> - >((resolve, reject) => { - safeLookup('localhost', { all: true }, (err, filtered) => { - if (err) reject(err); - else resolve(filtered); - }); - }); - - expect(result).toHaveLength(1); - expect(result[0].address).toBe('127.0.0.1'); - }); - - it('should error if all resolved IPs are private', async () => { - const addresses = [{ address: '10.0.0.1', family: 4 }]; - - vi.mocked(dns.lookup).mockImplementation((( - _h: string, - _o: dns.LookupOptions, - cb: ( - err: Error | null, - addr: Array<{ address: string; family: number }>, - ) => void, - ) => { - cb(null, addresses); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - }) as any); - - await expect( - new Promise((resolve, reject) => { - safeLookup('malicious.com', { all: true }, (err, filtered) => { - if (err) reject(err); - else resolve(filtered); - }); - }), - ).rejects.toThrow(PrivateIpError); - }); - }); - - describe('safeFetch', () => { - it('should forward to fetch with dispatcher', async () => { - vi.mocked(global.fetch).mockResolvedValue(new Response('ok')); - - const response = await safeFetch('https://example.com'); - expect(response.status).toBe(200); - expect(global.fetch).toHaveBeenCalledWith( - 'https://example.com', - expect.objectContaining({ - dispatcher: expect.any(Object), - }), - ); - }); - - it('should handle Refusing to connect errors', async () => { - vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError()); - - await expect(safeFetch('http://10.0.0.1')).rejects.toThrow( - 'Access to private network is blocked', - ); - }); - }); - describe('fetchWithTimeout', () => { it('should handle timeouts', async () => { vi.mocked(global.fetch).mockImplementation( @@ -279,13 +120,5 @@ describe('fetch utils', () => { 'Request timed out after 50ms', ); }); - - it('should handle private IP errors via handleFetchError', async () => { - vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError()); - - await expect(fetchWithTimeout('http://10.0.0.1', 1000)).rejects.toThrow( - 'Access to private network is blocked: http://10.0.0.1', - ); - }); }); }); diff --git a/packages/core/src/utils/fetch.ts b/packages/core/src/utils/fetch.ts index a324172d94..e339ea7fed 100644 --- a/packages/core/src/utils/fetch.ts +++ b/packages/core/src/utils/fetch.ts @@ -6,37 +6,12 @@ import { getErrorMessage, isNodeError } from './errors.js'; import { URL } from 'node:url'; -import * as dns from 'node:dns'; -import { lookup } from 'node:dns/promises'; import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici'; import ipaddr from 'ipaddr.js'; const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes -// Configure default global dispatcher with higher timeouts -setGlobalDispatcher( - new Agent({ - headersTimeout: DEFAULT_HEADERS_TIMEOUT, - bodyTimeout: DEFAULT_BODY_TIMEOUT, - }), -); - -// Local extension of RequestInit to support Node.js/undici dispatcher -interface NodeFetchInit extends RequestInit { - dispatcher?: Agent | ProxyAgent; -} - -/** - * Error thrown when a connection to a private IP address is blocked for security reasons. - */ -export class PrivateIpError extends Error { - constructor(message = 'Refusing to connect to private IP address') { - super(message); - this.name = 'PrivateIpError'; - } -} - export class FetchError extends Error { constructor( message: string, @@ -48,6 +23,14 @@ export class FetchError extends Error { } } +// Configure default global dispatcher with higher timeouts +setGlobalDispatcher( + new Agent({ + headersTimeout: DEFAULT_HEADERS_TIMEOUT, + bodyTimeout: DEFAULT_BODY_TIMEOUT, + }), +); + /** * Sanitizes a hostname by stripping IPv6 brackets if present. */ @@ -69,53 +52,6 @@ export function isLoopbackHost(hostname: string): boolean { ); } -/** - * A custom DNS lookup implementation for undici agents that prevents - * connection to private IP ranges (SSRF protection). - */ -export function safeLookup( - hostname: string, - options: dns.LookupOptions | number | null | undefined, - callback: ( - err: Error | null, - addresses: Array<{ address: string; family: number }>, - ) => void, -): void { - // Use the callback-based dns.lookup to match undici's expected signature. - // We explicitly handle the 'all' option to ensure we get an array of addresses. - const lookupOptions = - typeof options === 'number' ? { family: options } : { ...options }; - const finalOptions = { ...lookupOptions, all: true }; - - dns.lookup(hostname, finalOptions, (err, addresses) => { - if (err) { - callback(err, []); - return; - } - - const addressArray = Array.isArray(addresses) ? addresses : []; - const filtered = addressArray.filter( - (addr) => !isAddressPrivate(addr.address) || isLoopbackHost(hostname), - ); - - if (filtered.length === 0 && addressArray.length > 0) { - callback(new PrivateIpError(), []); - return; - } - - callback(null, filtered); - }); -} - -// Dedicated dispatcher with connection-level SSRF protection (safeLookup) -const safeDispatcher = new Agent({ - headersTimeout: DEFAULT_HEADERS_TIMEOUT, - bodyTimeout: DEFAULT_BODY_TIMEOUT, - connect: { - lookup: safeLookup, - }, -}); - export function isPrivateIp(url: string): boolean { try { const hostname = new URL(url).hostname; @@ -125,37 +61,6 @@ export function isPrivateIp(url: string): boolean { } } -/** - * Checks if a URL resolves to a private IP address. - * Performs DNS resolution to prevent DNS rebinding/SSRF bypasses. - */ -export async function isPrivateIpAsync(url: string): Promise { - try { - const parsed = new URL(url); - const hostname = parsed.hostname; - - // Fast check for literal IPs or localhost - if (isAddressPrivate(hostname)) { - return true; - } - - // Resolve DNS to check the actual target IP - const addresses = await lookup(hostname, { all: true }); - return addresses.some((addr) => isAddressPrivate(addr.address)); - } catch (e) { - if ( - e instanceof Error && - e.name === 'TypeError' && - e.message.includes('Invalid URL') - ) { - return false; - } - throw new Error(`Failed to verify if URL resolves to private IP: ${url}`, { - cause: e, - }); - } -} - /** * IANA Benchmark Testing Range (198.18.0.0/15). * Classified as 'unicast' by ipaddr.js but is reserved and should not be @@ -210,72 +115,15 @@ export function isAddressPrivate(address: string): boolean { } } -/** - * Internal helper to map varied fetch errors to a standardized FetchError. - * Centralizes security-related error mapping (e.g. PrivateIpError). - */ -function handleFetchError(error: unknown, url: string): never { - if (error instanceof PrivateIpError) { - throw new FetchError( - `Access to private network is blocked: ${url}`, - 'ERR_PRIVATE_NETWORK', - { cause: error }, - ); - } - - if (error instanceof FetchError) { - throw error; - } - - throw new FetchError( - getErrorMessage(error), - isNodeError(error) ? error.code : undefined, - { cause: error }, - ); -} - -/** - * Enhanced fetch with SSRF protection. - * Prevents access to private/internal networks at the connection level. - */ -export async function safeFetch( - input: RequestInfo | URL, - init?: RequestInit, -): Promise { - const nodeInit: NodeFetchInit = { - ...init, - dispatcher: safeDispatcher, - }; - - try { - // eslint-disable-next-line no-restricted-syntax - return await fetch(input, nodeInit); - } catch (error) { - const url = - input instanceof Request - ? input.url - : typeof input === 'string' - ? input - : input.toString(); - handleFetchError(error, url); - } -} - /** * Creates an undici ProxyAgent that incorporates safe DNS lookup. */ export function createSafeProxyAgent(proxyUrl: string): ProxyAgent { return new ProxyAgent({ uri: proxyUrl, - connect: { - lookup: safeLookup, - }, }); } -/** - * Performs a fetch with a specified timeout and connection-level SSRF protection. - */ export async function fetchWithTimeout( url: string, timeout: number, @@ -294,21 +142,17 @@ export async function fetchWithTimeout( } } - const nodeInit: NodeFetchInit = { - ...options, - signal: controller.signal, - dispatcher: safeDispatcher, - }; - try { - // eslint-disable-next-line no-restricted-syntax - const response = await fetch(url, nodeInit); + const response = await fetch(url, { + ...options, + signal: controller.signal, + }); return response; } catch (error) { if (isNodeError(error) && error.code === 'ABORT_ERR') { throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT'); } - handleFetchError(error, url.toString()); + throw new FetchError(getErrorMessage(error), undefined, { cause: error }); } finally { clearTimeout(timeoutId); } diff --git a/packages/core/src/utils/oauth-flow.ts b/packages/core/src/utils/oauth-flow.ts index 45318efdb5..e13fd37837 100644 --- a/packages/core/src/utils/oauth-flow.ts +++ b/packages/core/src/utils/oauth-flow.ts @@ -454,7 +454,6 @@ export async function exchangeCodeForToken( params.append('resource', resource); } - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(config.tokenUrl, { method: 'POST', headers: { @@ -508,7 +507,6 @@ export async function refreshAccessToken( params.append('resource', resource); } - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(tokenUrl, { method: 'POST', headers: { diff --git a/packages/vscode-ide-companion/src/extension.ts b/packages/vscode-ide-companion/src/extension.ts index e8cef91c2b..456ec6e872 100644 --- a/packages/vscode-ide-companion/src/extension.ts +++ b/packages/vscode-ide-companion/src/extension.ts @@ -42,7 +42,6 @@ async function checkForUpdates( const currentVersion = context.extension.packageJSON.version; // Fetch extension details from the VSCode Marketplace. - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch( 'https://marketplace.visualstudio.com/_apis/public/gallery/extensionquery', { diff --git a/packages/vscode-ide-companion/src/ide-server.test.ts b/packages/vscode-ide-companion/src/ide-server.test.ts index b3d39bf832..eb28638a78 100644 --- a/packages/vscode-ide-companion/src/ide-server.test.ts +++ b/packages/vscode-ide-companion/src/ide-server.test.ts @@ -356,7 +356,6 @@ describe('IDEServer', () => { }); it('should reject request without auth token', async () => { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(`http://localhost:${port}/mcp`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -371,7 +370,6 @@ describe('IDEServer', () => { }); it('should allow request with valid auth token', async () => { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(`http://localhost:${port}/mcp`, { method: 'POST', headers: { @@ -389,7 +387,6 @@ describe('IDEServer', () => { }); it('should reject request with invalid auth token', async () => { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(`http://localhost:${port}/mcp`, { method: 'POST', headers: { @@ -416,7 +413,6 @@ describe('IDEServer', () => { ]; for (const header of malformedHeaders) { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(`http://localhost:${port}/mcp`, { method: 'POST', headers: {