feat(a2a): enable native gRPC support and protocol routing (#21403)

Co-authored-by: Adam Weidman <adamfweidman@google.com>
This commit is contained in:
Alisa
2026-03-12 14:36:50 -07:00
committed by GitHub
parent 5abc170b08
commit 4d393f9dca
17 changed files with 302 additions and 935 deletions

View File

@@ -35,11 +35,6 @@ const commonRestrictedSyntaxRules = [
message:
'Do not throw string literals or non-Error objects. Throw new Error("...") instead.',
},
{
selector: 'CallExpression[callee.name="fetch"]',
message:
'Use safeFetch() from "@/utils/fetch" instead of the global fetch() to ensure SSRF protection. If you are implementing a custom security layer, use an eslint-disable comment and explain why.',
},
];
export default tseslint.config(

View File

@@ -123,7 +123,6 @@ async function downloadFiles({
downloads.push(
(async () => {
const endpoint = `${REPO_DOWNLOAD_URL}/refs/tags/${releaseTag}/${SOURCE_DIR}/${fileBasename}`;
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(endpoint, {
method: 'GET',
dispatcher: proxy ? new ProxyAgent(proxy) : undefined,

View File

@@ -61,7 +61,6 @@ export const getLatestGitHubRelease = async (
const endpoint = `https://api.github.com/repos/google-github-actions/run-gemini-cli/releases/latest`;
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(endpoint, {
method: 'GET',
headers: {

View File

@@ -5,11 +5,8 @@
*/
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import {
A2AClientManager,
type SendMessageResult,
} from './a2a-client-manager.js';
import type { AgentCard, Task } from '@a2a-js/sdk';
import { A2AClientManager } from './a2a-client-manager.js';
import type { AgentCard } from '@a2a-js/sdk';
import {
ClientFactory,
DefaultAgentCardResolver,
@@ -22,81 +19,95 @@ import type { Config } from '../config/config.js';
import { Agent as UndiciAgent, ProxyAgent } from 'undici';
import { debugLogger } from '../utils/debugLogger.js';
interface MockClient {
sendMessageStream: ReturnType<typeof vi.fn>;
getTask: ReturnType<typeof vi.fn>;
cancelTask: ReturnType<typeof vi.fn>;
}
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
const actual = await importOriginal();
return {
...(actual as Record<string, unknown>),
createAuthenticatingFetchWithRetry: vi.fn(),
ClientFactory: vi.fn(),
DefaultAgentCardResolver: vi.fn(),
ClientFactoryOptions: {
createFrom: vi.fn(),
default: {},
},
};
});
vi.mock('../utils/debugLogger.js', () => ({
debugLogger: {
debug: vi.fn(),
},
}));
vi.mock('@a2a-js/sdk/client', () => {
const ClientFactory = vi.fn();
const DefaultAgentCardResolver = vi.fn();
const RestTransportFactory = vi.fn();
const JsonRpcTransportFactory = vi.fn();
const ClientFactoryOptions = {
default: {},
createFrom: vi.fn(),
};
const createAuthenticatingFetchWithRetry = vi.fn();
DefaultAgentCardResolver.prototype.resolve = vi.fn();
ClientFactory.prototype.createFromUrl = vi.fn();
return {
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
createAuthenticatingFetchWithRetry,
};
});
describe('A2AClientManager', () => {
let manager: A2AClientManager;
const mockAgentCard: AgentCard = {
name: 'test-agent',
description: 'A test agent',
url: 'http://test.agent',
version: '1.0.0',
protocolVersion: '0.1.0',
capabilities: {},
skills: [],
defaultInputModes: [],
defaultOutputModes: [],
};
const mockClient: MockClient = {
sendMessageStream: vi.fn(),
getTask: vi.fn(),
cancelTask: vi.fn(),
};
// Stable mocks initialized once
const sendMessageStreamMock = vi.fn();
const getTaskMock = vi.fn();
const cancelTaskMock = vi.fn();
const getAgentCardMock = vi.fn();
const authFetchMock = vi.fn();
const mockClient = {
sendMessageStream: sendMessageStreamMock,
getTask: getTaskMock,
cancelTask: cancelTaskMock,
getAgentCard: getAgentCardMock,
} as unknown as Client;
const mockAgentCard: Partial<AgentCard> = { name: 'TestAgent' };
beforeEach(() => {
vi.clearAllMocks();
A2AClientManager.resetInstanceForTesting();
manager = A2AClientManager.getInstance();
// Default mock implementations
getAgentCardMock.mockResolvedValue({
// Re-create the instances as plain objects that can be spied on
const factoryInstance = {
createFromUrl: vi.fn(),
createFromAgentCard: vi.fn(),
};
const resolverInstance = {
resolve: vi.fn(),
};
vi.mocked(ClientFactory).mockReturnValue(
factoryInstance as unknown as ClientFactory,
);
vi.mocked(DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as DefaultAgentCardResolver,
);
vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue(
mockClient as unknown as Client,
);
vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue(
mockClient as unknown as Client,
);
vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue(
mockClient,
vi.spyOn(ClientFactoryOptions, 'createFrom').mockImplementation(
(_defaults, overrides) => overrides as unknown as ClientFactoryOptions,
);
vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactoryOptions.createFrom).mockImplementation(
(_defaults, overrides) => overrides as ClientFactoryOptions,
);
vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue(
authFetchMock,
vi.mocked(createAuthenticatingFetchWithRetry).mockImplementation(() =>
authFetchMock.mockResolvedValue({
ok: true,
json: async () => ({}),
} as Response),
);
vi.stubGlobal(
@@ -170,15 +181,19 @@ describe('A2AClientManager', () => {
'TestAgent',
'http://test.agent/card',
);
expect(agentCard).toMatchObject(mockAgentCard);
expect(manager.getAgentCard('TestAgent')).toBe(agentCard);
expect(manager.getClient('TestAgent')).toBeDefined();
});
it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(ClientFactoryOptions.createFrom).toHaveBeenCalled();
});
it('should throw an error if an agent with the same name is already loaded', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
await expect(
manager.loadAgent('TestAgent', 'http://another.agent/card'),
manager.loadAgent('TestAgent', 'http://test.agent/card'),
).rejects.toThrow("Agent with name 'TestAgent' is already loaded.");
});
@@ -193,20 +208,12 @@ describe('A2AClientManager', () => {
shouldRetryWithHeaders: vi.fn(),
};
await manager.loadAgent(
'CustomAuthAgent',
'http://custom.agent/card',
'TestAgent',
'http://test.agent/card',
customAuthHandler as unknown as AuthenticationHandler,
);
expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith(
expect.anything(),
customAuthHandler,
);
// Card resolver should NOT use the authenticated fetch by default.
const resolverInstance = vi.mocked(DefaultAgentCardResolver).mock
.instances[0];
expect(resolverInstance).toBeDefined();
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
.calls[0][0];
expect(resolverOptions?.fetchImpl).not.toBe(authFetchMock);
@@ -267,106 +274,163 @@ describe('A2AClientManager', () => {
it('should log a debug message upon loading an agent', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(debugLogger.debug).toHaveBeenCalledWith(
"[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card",
expect.stringContaining("Loaded agent 'TestAgent'"),
);
});
it('should clear the cache', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(manager.getAgentCard('TestAgent')).toBeDefined();
expect(manager.getClient('TestAgent')).toBeDefined();
manager.clearCache();
expect(manager.getAgentCard('TestAgent')).toBeUndefined();
expect(manager.getClient('TestAgent')).toBeUndefined();
expect(debugLogger.debug).toHaveBeenCalledWith(
'[A2AClientManager] Cache cleared.',
});
it('should throw if resolveAgentCard fails', async () => {
const resolverInstance = {
resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')),
};
vi.mocked(DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as DefaultAgentCardResolver,
);
await expect(
manager.loadAgent('FailAgent', 'http://fail.agent'),
).rejects.toThrow('Resolution failed');
});
it('should throw if factory.createFromAgentCard fails', async () => {
const factoryInstance = {
createFromAgentCard: vi
.fn()
.mockRejectedValue(new Error('Factory failed')),
};
vi.mocked(ClientFactory).mockReturnValue(
factoryInstance as unknown as ClientFactory,
);
await expect(
manager.loadAgent('FailAgent', 'http://fail.agent'),
).rejects.toThrow('Factory failed');
});
});
describe('getAgentCard and getClient', () => {
it('should return undefined if agent is not found', () => {
expect(manager.getAgentCard('Unknown')).toBeUndefined();
expect(manager.getClient('Unknown')).toBeUndefined();
});
});
describe('sendMessageStream', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should send a message and return a stream', async () => {
const mockResult = {
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
sendMessageStreamMock.mockReturnValue(
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield mockResult;
yield { kind: 'message' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = [];
for await (const res of stream) {
results.push(res);
for await (const result of stream) {
results.push(result);
}
expect(results).toEqual([mockResult]);
expect(sendMessageStreamMock).toHaveBeenCalledWith(
expect(results).toHaveLength(1);
expect(mockClient.sendMessageStream).toHaveBeenCalled();
});
it('should use contextId and taskId when provided', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield { kind: 'message' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: 'ctx123',
taskId: 'task456',
});
// trigger execution
for await (const _ of stream) {
break;
}
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.anything(),
message: expect.objectContaining({
contextId: 'ctx123',
taskId: 'task456',
}),
}),
expect.any(Object),
);
});
it('should use contextId and taskId when provided', async () => {
sendMessageStreamMock.mockReturnValue(
it('should correctly propagate AbortSignal to the stream', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield {
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
yield { kind: 'message' };
})(),
);
const expectedContextId = 'user-context-id';
const expectedTaskId = 'user-task-id';
const controller = new AbortController();
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: expectedContextId,
taskId: expectedTaskId,
signal: controller.signal,
});
// trigger execution
for await (const _ of stream) {
// consume stream
break;
}
const call = sendMessageStreamMock.mock.calls[0][0];
expect(call.message.contextId).toBe(expectedContextId);
expect(call.message.taskId).toBe(expectedTaskId);
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({ signal: controller.signal }),
);
});
it('should propagate the original error on failure', async () => {
sendMessageStreamMock.mockImplementationOnce(() => {
throw new Error('Network error');
it('should handle a multi-chunk stream with different event types', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield { kind: 'message', messageId: 'm1' };
yield { kind: 'status-update', taskId: 't1' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = [];
for await (const result of stream) {
results.push(result);
}
expect(results).toHaveLength(2);
expect(results[0].kind).toBe('message');
expect(results[1].kind).toBe('status-update');
});
it('should throw prefixed error on failure', async () => {
mockClient.sendMessageStream.mockImplementation(() => {
throw new Error('Network failure');
});
const stream = manager.sendMessageStream('TestAgent', 'Hello');
await expect(async () => {
for await (const _ of stream) {
// consume
// empty
}
}).rejects.toThrow('Network error');
}).rejects.toThrow(
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure',
);
});
it('should throw an error if the agent is not found', async () => {
const stream = manager.sendMessageStream('NonExistentAgent', 'Hello');
await expect(async () => {
for await (const _ of stream) {
// consume
// empty
}
}).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
@@ -374,28 +438,23 @@ describe('A2AClientManager', () => {
describe('getTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should get a task from the correct agent', async () => {
getTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'completed' },
} as Task);
const mockTask = { id: 'task123', kind: 'task' };
mockClient.getTask.mockResolvedValue(mockTask);
await manager.getTask('TestAgent', 'task123');
expect(getTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
const result = await manager.getTask('TestAgent', 'task123');
expect(result).toBe(mockTask);
expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' });
});
it('should throw prefixed error on failure', async () => {
getTaskMock.mockRejectedValueOnce(new Error('Network error'));
mockClient.getTask.mockRejectedValue(new Error('Not found'));
await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient getTask Error [TestAgent]: Network error',
'A2AClient getTask Error [TestAgent]: Not found',
);
});
@@ -408,28 +467,23 @@ describe('A2AClientManager', () => {
describe('cancelTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should cancel a task on the correct agent', async () => {
cancelTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'canceled' },
} as Task);
const mockTask = { id: 'task123', kind: 'task' };
mockClient.cancelTask.mockResolvedValue(mockTask);
await manager.cancelTask('TestAgent', 'task123');
expect(cancelTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
const result = await manager.cancelTask('TestAgent', 'task123');
expect(result).toBe(mockTask);
expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' });
});
it('should throw prefixed error on failure', async () => {
cancelTaskMock.mockRejectedValueOnce(new Error('Network error'));
mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel'));
await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient cancelTask Error [TestAgent]: Network error',
'A2AClient cancelTask Error [TestAgent]: Cannot cancel',
);
});

View File

@@ -12,36 +12,41 @@ import type {
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
} from '@a2a-js/sdk';
import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client';
import {
type Client,
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
type AuthenticationHandler,
RestTransportFactory,
createAuthenticatingFetchWithRetry,
} from '@a2a-js/sdk/client';
import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
import * as grpc from '@grpc/grpc-js';
import { v4 as uuidv4 } from 'uuid';
import { Agent as UndiciAgent, ProxyAgent } from 'undici';
import { normalizeAgentCard } from './a2aUtils.js';
import type { Config } from '../config/config.js';
import { debugLogger } from '../utils/debugLogger.js';
import { safeLookup } from '../utils/fetch.js';
import { classifyAgentError } from './a2a-errors.js';
// Remote agents can take 10+ minutes (e.g. Deep Research).
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
const A2A_TIMEOUT = 1800000; // 30 minutes
/**
* Result of sending a message, which can be a full message, a task,
* or an incremental status/artifact update.
*/
export type SendMessageResult =
| Message
| Task
| TaskStatusUpdateEvent
| TaskArtifactUpdateEvent;
// Remote agents can take 10+ minutes (e.g. Deep Research).
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
const A2A_TIMEOUT = 1800000; // 30 minutes
/**
* Manages A2A clients and caches loaded agent information.
* Follows a singleton pattern to ensure a single client instance.
* Orchestrates communication with remote A2A agents.
* Manages protocol negotiation, authentication, and transport selection.
*/
export class A2AClientManager {
private static instance: A2AClientManager;
@@ -58,9 +63,6 @@ export class A2AClientManager {
const agentOptions = {
headersTimeout: A2A_TIMEOUT,
bodyTimeout: A2A_TIMEOUT,
connect: {
lookup: safeLookup, // SSRF protection at connection level
},
};
if (proxyUrl) {
@@ -73,7 +75,6 @@ export class A2AClientManager {
}
this.a2aFetch = (input, init) =>
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
fetch(input, { ...init, dispatcher: this.a2aDispatcher } as RequestInit);
}
@@ -139,22 +140,35 @@ export class A2AClientManager {
};
const resolver = new DefaultAgentCardResolver({ fetchImpl: cardFetch });
const rawCard = await resolver.resolve(agentCardUrl, '');
// TODO: Remove normalizeAgentCard once @a2a-js/sdk handles
// proto field name aliases (supportedInterfaces → additionalInterfaces,
// protocolBinding → transport).
const agentCard = normalizeAgentCard(rawCard);
const options = ClientFactoryOptions.createFrom(
const grpcUrl =
agentCard.additionalInterfaces?.find((i) => i.transport === 'GRPC')
?.url ?? agentCard.url;
const clientOptions = ClientFactoryOptions.createFrom(
ClientFactoryOptions.default,
{
transports: [
new RestTransportFactory({ fetchImpl: authFetch }),
new JsonRpcTransportFactory({ fetchImpl: authFetch }),
new GrpcTransportFactory({
grpcChannelCredentials: grpcUrl.startsWith('https://')
? grpc.credentials.createSsl()
: grpc.credentials.createInsecure(),
}),
],
cardResolver: resolver,
},
);
try {
const factory = new ClientFactory(options);
const client = await factory.createFromUrl(agentCardUrl, '');
const agentCard = await client.getAgentCard();
const factory = new ClientFactory(clientOptions);
const client = await factory.createFromAgentCard(agentCard);
this.clients.set(name, client);
this.agentCards.set(name, agentCard);
@@ -192,9 +206,7 @@ export class A2AClientManager {
options?: { contextId?: string; taskId?: string; signal?: AbortSignal },
): AsyncIterable<SendMessageResult> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
const messageParams: MessageSendParams = {
message: {
@@ -207,9 +219,19 @@ export class A2AClientManager {
},
};
yield* client.sendMessageStream(messageParams, {
signal: options?.signal,
});
try {
yield* client.sendMessageStream(messageParams, {
signal: options?.signal,
});
} catch (error: unknown) {
const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`;
if (error instanceof Error) {
throw new Error(`${prefix}: ${error.message}`, { cause: error });
}
throw new Error(
`${prefix}: Unexpected error during sendMessageStream: ${String(error)}`,
);
}
}
/**
@@ -238,9 +260,7 @@ export class A2AClientManager {
*/
async getTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
try {
return await client.getTask({ id: taskId });
} catch (error: unknown) {
@@ -260,9 +280,7 @@ export class A2AClientManager {
*/
async cancelTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
try {
return await client.cancelTask({ id: taskId });
} catch (error: unknown) {

View File

@@ -12,9 +12,6 @@ import {
A2AResultReassembler,
AUTH_REQUIRED_MSG,
normalizeAgentCard,
getGrpcCredentials,
pinUrlToIp,
splitAgentCardUrl,
} from './a2aUtils.js';
import type { SendMessageResult } from './a2a-client-manager.js';
import type {
@@ -26,12 +23,6 @@ import type {
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
} from '@a2a-js/sdk';
import * as dnsPromises from 'node:dns/promises';
import type { LookupAddress } from 'node:dns';
vi.mock('node:dns/promises', () => ({
lookup: vi.fn(),
}));
describe('a2aUtils', () => {
beforeEach(() => {
@@ -42,89 +33,6 @@ describe('a2aUtils', () => {
vi.restoreAllMocks();
});
describe('getGrpcCredentials', () => {
it('should return secure credentials for https', () => {
const credentials = getGrpcCredentials('https://test.agent');
expect(credentials).toBeDefined();
});
it('should return insecure credentials for http', () => {
const credentials = getGrpcCredentials('http://test.agent');
expect(credentials).toBeDefined();
});
});
describe('pinUrlToIp', () => {
it('should resolve and pin hostname to IP', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp(
'http://example.com:9000',
'test-agent',
);
expect(hostname).toBe('example.com');
expect(pinnedUrl).toBe('http://93.184.216.34:9000/');
});
it('should handle raw host:port strings (standard for gRPC)', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp(
'example.com:9000',
'test-agent',
);
expect(hostname).toBe('example.com');
expect(pinnedUrl).toBe('93.184.216.34:9000');
});
it('should throw error if resolution fails (fail closed)', async () => {
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
await expect(
pinUrlToIp('http://unreachable.com', 'test-agent'),
).rejects.toThrow("Failed to resolve host for agent 'test-agent'");
});
it('should throw error if resolved to private IP', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '10.0.0.1', family: 4 }]);
await expect(
pinUrlToIp('http://malicious.com', 'test-agent'),
).rejects.toThrow('resolves to private IP range');
});
it('should allow localhost/127.0.0.1/::1 exceptions', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '127.0.0.1', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp(
'http://localhost:9000',
'test-agent',
);
expect(hostname).toBe('localhost');
expect(pinnedUrl).toBe('http://127.0.0.1:9000/');
});
});
describe('isTerminalState', () => {
it('should return true for completed, failed, canceled, and rejected', () => {
expect(isTerminalState('completed')).toBe(true);
@@ -365,12 +273,12 @@ describe('a2aUtils', () => {
expect(normalized.name).toBe('my-agent');
// @ts-expect-error - testing dynamic preservation
expect(normalized.customField).toBe('keep-me');
expect(normalized.description).toBe('');
expect(normalized.skills).toEqual([]);
expect(normalized.defaultInputModes).toEqual([]);
expect(normalized.description).toBeUndefined();
expect(normalized.skills).toBeUndefined();
expect(normalized.defaultInputModes).toBeUndefined();
});
it('should normalize and synchronize interfaces while preserving other fields', () => {
it('should map supportedInterfaces to additionalInterfaces with protocolBinding → transport', () => {
const raw = {
name: 'test',
supportedInterfaces: [
@@ -384,13 +292,7 @@ describe('a2aUtils', () => {
const normalized = normalizeAgentCard(raw);
// Should exist in both fields
expect(normalized.additionalInterfaces).toHaveLength(1);
expect(
(normalized as unknown as Record<string, unknown>)[
'supportedInterfaces'
],
).toHaveLength(1);
const intf = normalized.additionalInterfaces?.[0] as unknown as Record<
string,
@@ -399,43 +301,18 @@ describe('a2aUtils', () => {
expect(intf['transport']).toBe('GRPC');
expect(intf['url']).toBe('grpc://test');
// Should fallback top-level url
expect(normalized.url).toBe('grpc://test');
});
it('should preserve existing top-level url if present', () => {
it('should not overwrite additionalInterfaces if already present', () => {
const raw = {
name: 'test',
url: 'http://existing',
additionalInterfaces: [{ url: 'http://grpc', transport: 'GRPC' }],
supportedInterfaces: [{ url: 'http://other', transport: 'REST' }],
};
const normalized = normalizeAgentCard(raw);
expect(normalized.url).toBe('http://existing');
});
it('should NOT prepend http:// scheme to raw IP:port strings for gRPC interfaces', () => {
const raw = {
name: 'raw-ip-grpc',
supportedInterfaces: [{ url: '127.0.0.1:9000', transport: 'GRPC' }],
};
const normalized = normalizeAgentCard(raw);
expect(normalized.additionalInterfaces?.[0].url).toBe('127.0.0.1:9000');
expect(normalized.url).toBe('127.0.0.1:9000');
});
it('should prepend http:// scheme to raw IP:port strings for REST interfaces', () => {
const raw = {
name: 'raw-ip-rest',
supportedInterfaces: [{ url: '127.0.0.1:8080', transport: 'REST' }],
};
const normalized = normalizeAgentCard(raw);
expect(normalized.additionalInterfaces?.[0].url).toBe(
'http://127.0.0.1:8080',
);
expect(normalized.additionalInterfaces).toHaveLength(1);
expect(normalized.additionalInterfaces?.[0].url).toBe('http://grpc');
});
it('should NOT override existing transport if protocolBinding is also present', () => {
@@ -448,48 +325,20 @@ describe('a2aUtils', () => {
const normalized = normalizeAgentCard(raw);
expect(normalized.additionalInterfaces?.[0].transport).toBe('GRPC');
});
});
describe('splitAgentCardUrl', () => {
const standard = '.well-known/agent-card.json';
it('should not mutate the original card object', () => {
const raw = {
name: 'test',
supportedInterfaces: [{ url: 'grpc://test', protocolBinding: 'GRPC' }],
};
it('should return baseUrl as-is if it does not end with standard path', () => {
const url = 'http://localhost:9001/custom/path';
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
});
it('should split correctly if URL ends with standard path', () => {
const url = `http://localhost:9001/${standard}`;
expect(splitAgentCardUrl(url)).toEqual({
baseUrl: 'http://localhost:9001/',
path: undefined,
});
});
it('should handle trailing slash in baseUrl when splitting', () => {
const url = `http://example.com/api/${standard}`;
expect(splitAgentCardUrl(url)).toEqual({
baseUrl: 'http://example.com/api/',
path: undefined,
});
});
it('should ignore hashes and query params when splitting', () => {
const url = `http://localhost:9001/${standard}?foo=bar#baz`;
expect(splitAgentCardUrl(url)).toEqual({
baseUrl: 'http://localhost:9001/',
path: undefined,
});
});
it('should return original URL if parsing fails', () => {
const url = 'not-a-url';
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
});
it('should handle standard path appearing earlier in the path', () => {
const url = `http://localhost:9001/${standard}/something-else`;
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
const normalized = normalizeAgentCard(raw);
expect(normalized).not.toBe(raw);
expect(normalized.additionalInterfaces).toBeDefined();
// Original should not have additionalInterfaces added
expect(
(raw as Record<string, unknown>)['additionalInterfaces'],
).toBeUndefined();
});
});

View File

@@ -4,9 +4,6 @@
* SPDX-License-Identifier: Apache-2.0
*/
import * as grpc from '@grpc/grpc-js';
import { lookup } from 'node:dns/promises';
import { z } from 'zod';
import type {
Message,
Part,
@@ -18,37 +15,10 @@ import type {
AgentCard,
AgentInterface,
} from '@a2a-js/sdk';
import { isAddressPrivate } from '../utils/fetch.js';
import type { SendMessageResult } from './a2a-client-manager.js';
export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`;
const AgentInterfaceSchema = z
.object({
url: z.string().default(''),
transport: z.string().optional(),
protocolBinding: z.string().optional(),
})
.passthrough();
const AgentCardSchema = z
.object({
name: z.string().default('unknown'),
description: z.string().default(''),
url: z.string().default(''),
version: z.string().default(''),
protocolVersion: z.string().default(''),
capabilities: z.record(z.unknown()).default({}),
skills: z.array(z.union([z.string(), z.record(z.unknown())])).default([]),
defaultInputModes: z.array(z.string()).default([]),
defaultOutputModes: z.array(z.string()).default([]),
additionalInterfaces: z.array(AgentInterfaceSchema).optional(),
supportedInterfaces: z.array(AgentInterfaceSchema).optional(),
preferredTransport: z.string().optional(),
})
.passthrough();
/**
* Reassembles incremental A2A streaming updates into a coherent result.
* Shows sequential status/messages followed by all reassembled artifacts.
@@ -241,166 +211,45 @@ function extractPartText(part: Part): string {
}
/**
* Normalizes an agent card by ensuring it has the required properties
* and resolving any inconsistencies between protocol versions.
* Normalizes proto field name aliases that the SDK doesn't handle yet.
* The A2A proto spec uses `supported_interfaces` and `protocol_binding`,
* while the SDK expects `additionalInterfaces` and `transport`.
* TODO: Remove once @a2a-js/sdk handles these aliases natively.
*/
export function normalizeAgentCard(card: unknown): AgentCard {
if (!isObject(card)) {
throw new Error('Agent card is missing.');
}
// Use Zod to validate and parse the card, ensuring safe defaults and narrowing types.
const parsed = AgentCardSchema.parse(card);
// Narrowing to AgentCard interface after runtime validation.
// Shallow-copy to avoid mutating the SDK's cached object.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const result = parsed as unknown as AgentCard;
const result = { ...card } as unknown as AgentCard;
// Normalize interfaces and synchronize both interface fields.
const normalizedInterfaces = extractNormalizedInterfaces(parsed);
result.additionalInterfaces = normalizedInterfaces;
// Sync supportedInterfaces for backward compatibility.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const legacyResult = result as unknown as Record<string, AgentInterface[]>;
legacyResult['supportedInterfaces'] = normalizedInterfaces;
// Fallback preferredTransport: If not specified, default to GRPC if available.
if (
!result.preferredTransport &&
normalizedInterfaces.some((i) => i.transport === 'GRPC')
) {
result.preferredTransport = 'GRPC';
// Map supportedInterfaces → additionalInterfaces if needed
if (!result.additionalInterfaces) {
const raw = card;
if (Array.isArray(raw['supportedInterfaces'])) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
result.additionalInterfaces = raw[
'supportedInterfaces'
] as AgentInterface[];
}
}
// Fallback: If top-level URL is missing, use the first interface's URL.
if (result.url === '' && normalizedInterfaces.length > 0) {
result.url = normalizedInterfaces[0].url;
// Map protocolBinding → transport on each interface
for (const intf of result.additionalInterfaces ?? []) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const raw = intf as unknown as Record<string, unknown>;
const binding = raw['protocolBinding'];
if (!intf.transport && typeof binding === 'string') {
intf.transport = binding;
}
}
return result;
}
/**
* Returns gRPC channel credentials based on the URL scheme.
*/
export function getGrpcCredentials(url: string): grpc.ChannelCredentials {
return url.startsWith('https://')
? grpc.credentials.createSsl()
: grpc.credentials.createInsecure();
}
/**
* Returns gRPC channel options to ensure SSL/authority matches the original hostname
* when connecting via a pinned IP address.
*/
export function getGrpcChannelOptions(
hostname: string,
): Record<string, unknown> {
return {
'grpc.default_authority': hostname,
'grpc.ssl_target_name_override': hostname,
};
}
/**
* Resolves a hostname to its IP address and validates it against SSRF.
* Returns the pinned IP-based URL and the original hostname.
*/
export async function pinUrlToIp(
url: string,
agentName: string,
): Promise<{ pinnedUrl: string; hostname: string }> {
if (!url) return { pinnedUrl: url, hostname: '' };
// gRPC URLs in A2A can be 'host:port' or 'dns:///host:port' or have schemes.
// We normalize to host:port for resolution.
const hasScheme = url.includes('://');
const normalizedUrl = hasScheme ? url : `http://${url}`;
try {
const parsed = new URL(normalizedUrl);
const hostname = parsed.hostname;
const sanitizedHost =
hostname.startsWith('[') && hostname.endsWith(']')
? hostname.slice(1, -1)
: hostname;
// Resolve DNS to check the actual target IP and pin it
const addresses = await lookup(hostname, { all: true });
const publicAddresses = addresses.filter(
(addr) =>
!isAddressPrivate(addr.address) ||
sanitizedHost === 'localhost' ||
sanitizedHost === '127.0.0.1' ||
sanitizedHost === '::1',
);
if (publicAddresses.length === 0) {
if (addresses.length > 0) {
throw new Error(
`Refusing to load agent '${agentName}': transport URL '${url}' resolves to private IP range.`,
);
}
throw new Error(
`Failed to resolve any public IP addresses for host: ${hostname}`,
);
}
const pinnedIp = publicAddresses[0].address;
const pinnedHostname = pinnedIp.includes(':') ? `[${pinnedIp}]` : pinnedIp;
// Reconstruct URL with IP
parsed.hostname = pinnedHostname;
let pinnedUrl = parsed.toString();
// If original didn't have scheme, remove it (standard for gRPC targets)
if (!hasScheme) {
pinnedUrl = pinnedUrl.replace(/^http:\/\//, '');
// URL.toString() might append a trailing slash
if (pinnedUrl.endsWith('/') && !url.endsWith('/')) {
pinnedUrl = pinnedUrl.slice(0, -1);
}
}
return { pinnedUrl, hostname };
} catch (e) {
if (e instanceof Error && e.message.includes('Refusing')) throw e;
throw new Error(`Failed to resolve host for agent '${agentName}': ${url}`, {
cause: e,
});
}
}
/**
* Splts an agent card URL into a baseUrl and a standard path if it already
* contains '.well-known/agent-card.json'.
*/
export function splitAgentCardUrl(url: string): {
baseUrl: string;
path?: string;
} {
const standardPath = '.well-known/agent-card.json';
try {
const parsedUrl = new URL(url);
if (parsedUrl.pathname.endsWith(standardPath)) {
// Reconstruct baseUrl from parsed components to avoid issues with hashes or query params.
parsedUrl.pathname = parsedUrl.pathname.substring(
0,
parsedUrl.pathname.lastIndexOf(standardPath),
);
parsedUrl.search = '';
parsedUrl.hash = '';
// We return undefined for path if it's the standard one,
// because the SDK's DefaultAgentCardResolver appends it automatically.
return { baseUrl: parsedUrl.toString(), path: undefined };
}
} catch (_e) {
// Ignore URL parsing errors here, let the resolver handle them.
}
return { baseUrl: url };
}
/**
* Extracts contextId and taskId from a Message, Task, or Update response.
* Follows the pattern from the A2A CLI sample to maintain conversational continuity.
@@ -446,65 +295,6 @@ export function extractIdsFromResponse(result: SendMessageResult): {
return { contextId, taskId, clearTaskId };
}
/**
* Extracts and normalizes interfaces from the card, handling protocol version fallbacks.
* Preserves all original fields to maintain SDK compatibility.
*/
function extractNormalizedInterfaces(
card: Record<string, unknown>,
): AgentInterface[] {
const rawInterfaces =
getArray(card, 'additionalInterfaces') ||
getArray(card, 'supportedInterfaces');
if (!rawInterfaces) {
return [];
}
const mapped: AgentInterface[] = [];
for (const i of rawInterfaces) {
if (isObject(i)) {
// Use schema to validate interface object.
const parsed = AgentInterfaceSchema.parse(i);
// Narrowing to AgentInterface after runtime validation.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const normalized = parsed as unknown as AgentInterface & {
protocolBinding?: string;
};
// Normalize 'transport' from 'protocolBinding' if missing.
if (!normalized.transport && normalized.protocolBinding) {
normalized.transport = normalized.protocolBinding;
}
// Robust URL: Ensure the URL has a scheme (except for gRPC).
if (
normalized.url &&
!normalized.url.includes('://') &&
!normalized.url.startsWith('/') &&
normalized.transport !== 'GRPC'
) {
// Default to http:// for insecure REST/JSON-RPC if scheme is missing.
normalized.url = `http://${normalized.url}`;
}
mapped.push(normalized as AgentInterface);
}
}
return mapped;
}
/**
* Safely extracts an array property from an object.
*/
function getArray(
obj: Record<string, unknown>,
key: string,
): unknown[] | undefined {
const val = obj[key];
return Array.isArray(val) ? val : undefined;
}
// Type Guards
function isTextPart(part: Part): part is TextPart {

View File

@@ -700,7 +700,6 @@ async function fetchAndCacheUserInfo(client: OAuth2Client): Promise<void> {
return;
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(
'https://www.googleapis.com/oauth2/v2/userinfo',
{

View File

@@ -111,7 +111,6 @@ export class MCPOAuthProvider {
scope: config.scopes?.join(' ') || '',
};
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(registrationUrl, {
method: 'POST',
headers: {
@@ -301,7 +300,6 @@ export class MCPOAuthProvider {
? { Accept: 'text/event-stream' }
: { Accept: 'application/json' };
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(mcpServerUrl, {
method: 'HEAD',
headers,

View File

@@ -97,7 +97,6 @@ export class OAuthUtils {
resourceMetadataUrl: string,
): Promise<OAuthProtectedResourceMetadata | null> {
try {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(resourceMetadataUrl);
if (!response.ok) {
return null;
@@ -122,7 +121,6 @@ export class OAuthUtils {
authServerMetadataUrl: string,
): Promise<OAuthAuthorizationServerMetadata | null> {
try {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(authServerMetadataUrl);
if (!response.ok) {
return null;

View File

@@ -546,7 +546,6 @@ export class ClearcutLogger {
let result: LogResponse = {};
try {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(CLEARCUT_URL, {
method: 'POST',
body: safeJsonStringify(request),

View File

@@ -1903,7 +1903,6 @@ export async function connectToMcpServer(
acceptHeader = 'application/json';
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(urlToFetch, {
method: 'HEAD',
headers: {

View File

@@ -5,27 +5,12 @@
*/
import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest';
import {
isPrivateIp,
isPrivateIpAsync,
isAddressPrivate,
safeLookup,
safeFetch,
fetchWithTimeout,
PrivateIpError,
} from './fetch.js';
import * as dnsPromises from 'node:dns/promises';
import * as dns from 'node:dns';
import { isPrivateIp, isAddressPrivate, fetchWithTimeout } from './fetch.js';
vi.mock('node:dns/promises', () => ({
lookup: vi.fn(),
}));
// We need to mock node:dns for safeLookup since it uses the callback API
vi.mock('node:dns', () => ({
lookup: vi.fn(),
}));
// Mock global fetch
const originalFetch = global.fetch;
global.fetch = vi.fn();
@@ -114,150 +99,6 @@ describe('fetch utils', () => {
});
});
describe('isPrivateIpAsync', () => {
it('should identify private IPs directly', async () => {
expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true);
});
it('should identify domains resolving to private IPs', async () => {
vi.mocked(dnsPromises.lookup).mockImplementation(
async () =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[{ address: '10.0.0.1', family: 4 }] as any,
);
expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true);
});
it('should identify domains resolving to public IPs as non-private', async () => {
vi.mocked(dnsPromises.lookup).mockImplementation(
async () =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[{ address: '8.8.8.8', family: 4 }] as any,
);
expect(await isPrivateIpAsync('http://google.com/')).toBe(false);
});
it('should throw error if DNS resolution fails (fail closed)', async () => {
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow(
'Failed to verify if URL resolves to private IP',
);
});
it('should return false for invalid URLs instead of throwing verification error', async () => {
expect(await isPrivateIpAsync('not-a-url')).toBe(false);
});
});
describe('safeLookup', () => {
it('should filter out private IPs', async () => {
const addresses = [
{ address: '8.8.8.8', family: 4 },
{ address: '10.0.0.1', family: 4 },
];
vi.mocked(dns.lookup).mockImplementation(((
_h: string,
_o: dns.LookupOptions,
cb: (
err: Error | null,
addr: Array<{ address: string; family: number }>,
) => void,
) => {
cb(null, addresses);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}) as any);
const result = await new Promise<
Array<{ address: string; family: number }>
>((resolve, reject) => {
safeLookup('example.com', { all: true }, (err, filtered) => {
if (err) reject(err);
else resolve(filtered);
});
});
expect(result).toHaveLength(1);
expect(result[0].address).toBe('8.8.8.8');
});
it('should allow explicit localhost', async () => {
const addresses = [{ address: '127.0.0.1', family: 4 }];
vi.mocked(dns.lookup).mockImplementation(((
_h: string,
_o: dns.LookupOptions,
cb: (
err: Error | null,
addr: Array<{ address: string; family: number }>,
) => void,
) => {
cb(null, addresses);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}) as any);
const result = await new Promise<
Array<{ address: string; family: number }>
>((resolve, reject) => {
safeLookup('localhost', { all: true }, (err, filtered) => {
if (err) reject(err);
else resolve(filtered);
});
});
expect(result).toHaveLength(1);
expect(result[0].address).toBe('127.0.0.1');
});
it('should error if all resolved IPs are private', async () => {
const addresses = [{ address: '10.0.0.1', family: 4 }];
vi.mocked(dns.lookup).mockImplementation(((
_h: string,
_o: dns.LookupOptions,
cb: (
err: Error | null,
addr: Array<{ address: string; family: number }>,
) => void,
) => {
cb(null, addresses);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}) as any);
await expect(
new Promise((resolve, reject) => {
safeLookup('malicious.com', { all: true }, (err, filtered) => {
if (err) reject(err);
else resolve(filtered);
});
}),
).rejects.toThrow(PrivateIpError);
});
});
describe('safeFetch', () => {
it('should forward to fetch with dispatcher', async () => {
vi.mocked(global.fetch).mockResolvedValue(new Response('ok'));
const response = await safeFetch('https://example.com');
expect(response.status).toBe(200);
expect(global.fetch).toHaveBeenCalledWith(
'https://example.com',
expect.objectContaining({
dispatcher: expect.any(Object),
}),
);
});
it('should handle Refusing to connect errors', async () => {
vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError());
await expect(safeFetch('http://10.0.0.1')).rejects.toThrow(
'Access to private network is blocked',
);
});
});
describe('fetchWithTimeout', () => {
it('should handle timeouts', async () => {
vi.mocked(global.fetch).mockImplementation(
@@ -279,13 +120,5 @@ describe('fetch utils', () => {
'Request timed out after 50ms',
);
});
it('should handle private IP errors via handleFetchError', async () => {
vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError());
await expect(fetchWithTimeout('http://10.0.0.1', 1000)).rejects.toThrow(
'Access to private network is blocked: http://10.0.0.1',
);
});
});
});

View File

@@ -6,37 +6,12 @@
import { getErrorMessage, isNodeError } from './errors.js';
import { URL } from 'node:url';
import * as dns from 'node:dns';
import { lookup } from 'node:dns/promises';
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
import ipaddr from 'ipaddr.js';
const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes
const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes
// Configure default global dispatcher with higher timeouts
setGlobalDispatcher(
new Agent({
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
bodyTimeout: DEFAULT_BODY_TIMEOUT,
}),
);
// Local extension of RequestInit to support Node.js/undici dispatcher
interface NodeFetchInit extends RequestInit {
dispatcher?: Agent | ProxyAgent;
}
/**
* Error thrown when a connection to a private IP address is blocked for security reasons.
*/
export class PrivateIpError extends Error {
constructor(message = 'Refusing to connect to private IP address') {
super(message);
this.name = 'PrivateIpError';
}
}
export class FetchError extends Error {
constructor(
message: string,
@@ -48,6 +23,14 @@ export class FetchError extends Error {
}
}
// Configure default global dispatcher with higher timeouts
setGlobalDispatcher(
new Agent({
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
bodyTimeout: DEFAULT_BODY_TIMEOUT,
}),
);
/**
* Sanitizes a hostname by stripping IPv6 brackets if present.
*/
@@ -69,53 +52,6 @@ export function isLoopbackHost(hostname: string): boolean {
);
}
/**
* A custom DNS lookup implementation for undici agents that prevents
* connection to private IP ranges (SSRF protection).
*/
export function safeLookup(
hostname: string,
options: dns.LookupOptions | number | null | undefined,
callback: (
err: Error | null,
addresses: Array<{ address: string; family: number }>,
) => void,
): void {
// Use the callback-based dns.lookup to match undici's expected signature.
// We explicitly handle the 'all' option to ensure we get an array of addresses.
const lookupOptions =
typeof options === 'number' ? { family: options } : { ...options };
const finalOptions = { ...lookupOptions, all: true };
dns.lookup(hostname, finalOptions, (err, addresses) => {
if (err) {
callback(err, []);
return;
}
const addressArray = Array.isArray(addresses) ? addresses : [];
const filtered = addressArray.filter(
(addr) => !isAddressPrivate(addr.address) || isLoopbackHost(hostname),
);
if (filtered.length === 0 && addressArray.length > 0) {
callback(new PrivateIpError(), []);
return;
}
callback(null, filtered);
});
}
// Dedicated dispatcher with connection-level SSRF protection (safeLookup)
const safeDispatcher = new Agent({
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
bodyTimeout: DEFAULT_BODY_TIMEOUT,
connect: {
lookup: safeLookup,
},
});
export function isPrivateIp(url: string): boolean {
try {
const hostname = new URL(url).hostname;
@@ -125,37 +61,6 @@ export function isPrivateIp(url: string): boolean {
}
}
/**
* Checks if a URL resolves to a private IP address.
* Performs DNS resolution to prevent DNS rebinding/SSRF bypasses.
*/
export async function isPrivateIpAsync(url: string): Promise<boolean> {
try {
const parsed = new URL(url);
const hostname = parsed.hostname;
// Fast check for literal IPs or localhost
if (isAddressPrivate(hostname)) {
return true;
}
// Resolve DNS to check the actual target IP
const addresses = await lookup(hostname, { all: true });
return addresses.some((addr) => isAddressPrivate(addr.address));
} catch (e) {
if (
e instanceof Error &&
e.name === 'TypeError' &&
e.message.includes('Invalid URL')
) {
return false;
}
throw new Error(`Failed to verify if URL resolves to private IP: ${url}`, {
cause: e,
});
}
}
/**
* IANA Benchmark Testing Range (198.18.0.0/15).
* Classified as 'unicast' by ipaddr.js but is reserved and should not be
@@ -210,72 +115,15 @@ export function isAddressPrivate(address: string): boolean {
}
}
/**
* Internal helper to map varied fetch errors to a standardized FetchError.
* Centralizes security-related error mapping (e.g. PrivateIpError).
*/
function handleFetchError(error: unknown, url: string): never {
if (error instanceof PrivateIpError) {
throw new FetchError(
`Access to private network is blocked: ${url}`,
'ERR_PRIVATE_NETWORK',
{ cause: error },
);
}
if (error instanceof FetchError) {
throw error;
}
throw new FetchError(
getErrorMessage(error),
isNodeError(error) ? error.code : undefined,
{ cause: error },
);
}
/**
* Enhanced fetch with SSRF protection.
* Prevents access to private/internal networks at the connection level.
*/
export async function safeFetch(
input: RequestInfo | URL,
init?: RequestInit,
): Promise<Response> {
const nodeInit: NodeFetchInit = {
...init,
dispatcher: safeDispatcher,
};
try {
// eslint-disable-next-line no-restricted-syntax
return await fetch(input, nodeInit);
} catch (error) {
const url =
input instanceof Request
? input.url
: typeof input === 'string'
? input
: input.toString();
handleFetchError(error, url);
}
}
/**
* Creates an undici ProxyAgent that incorporates safe DNS lookup.
*/
export function createSafeProxyAgent(proxyUrl: string): ProxyAgent {
return new ProxyAgent({
uri: proxyUrl,
connect: {
lookup: safeLookup,
},
});
}
/**
* Performs a fetch with a specified timeout and connection-level SSRF protection.
*/
export async function fetchWithTimeout(
url: string,
timeout: number,
@@ -294,21 +142,17 @@ export async function fetchWithTimeout(
}
}
const nodeInit: NodeFetchInit = {
...options,
signal: controller.signal,
dispatcher: safeDispatcher,
};
try {
// eslint-disable-next-line no-restricted-syntax
const response = await fetch(url, nodeInit);
const response = await fetch(url, {
...options,
signal: controller.signal,
});
return response;
} catch (error) {
if (isNodeError(error) && error.code === 'ABORT_ERR') {
throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT');
}
handleFetchError(error, url.toString());
throw new FetchError(getErrorMessage(error), undefined, { cause: error });
} finally {
clearTimeout(timeoutId);
}

View File

@@ -454,7 +454,6 @@ export async function exchangeCodeForToken(
params.append('resource', resource);
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(config.tokenUrl, {
method: 'POST',
headers: {
@@ -508,7 +507,6 @@ export async function refreshAccessToken(
params.append('resource', resource);
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(tokenUrl, {
method: 'POST',
headers: {

View File

@@ -42,7 +42,6 @@ async function checkForUpdates(
const currentVersion = context.extension.packageJSON.version;
// Fetch extension details from the VSCode Marketplace.
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(
'https://marketplace.visualstudio.com/_apis/public/gallery/extensionquery',
{

View File

@@ -356,7 +356,6 @@ describe('IDEServer', () => {
});
it('should reject request without auth token', async () => {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(`http://localhost:${port}/mcp`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
@@ -371,7 +370,6 @@ describe('IDEServer', () => {
});
it('should allow request with valid auth token', async () => {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(`http://localhost:${port}/mcp`, {
method: 'POST',
headers: {
@@ -389,7 +387,6 @@ describe('IDEServer', () => {
});
it('should reject request with invalid auth token', async () => {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(`http://localhost:${port}/mcp`, {
method: 'POST',
headers: {
@@ -416,7 +413,6 @@ describe('IDEServer', () => {
];
for (const header of malformedHeaders) {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(`http://localhost:${port}/mcp`, {
method: 'POST',
headers: {