diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index bc299e53e2..bceb6deeac 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -1063,6 +1063,9 @@ export const useGeminiStream = ( 'Response stopped due to prohibited image content.', [FinishReason.NO_IMAGE]: 'Response stopped because no image was generated.', + [FinishReason.IMAGE_RECITATION]: + 'Response stopped due to image recitation policy.', + [FinishReason.IMAGE_OTHER]: 'Response stopped for other image reasons.', }; const message = finishReasonMessages[finishReason]; diff --git a/packages/core/src/agents/a2a-client-manager.test.ts b/packages/core/src/agents/a2a-client-manager.test.ts index afa66d0e5f..7859de38e6 100644 --- a/packages/core/src/agents/a2a-client-manager.test.ts +++ b/packages/core/src/agents/a2a-client-manager.test.ts @@ -5,96 +5,119 @@ */ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { - A2AClientManager, - type SendMessageResult, -} from './a2a-client-manager.js'; -import type { AgentCard, Task } from '@a2a-js/sdk'; -import { - ClientFactory, - DefaultAgentCardResolver, - createAuthenticatingFetchWithRetry, - ClientFactoryOptions, - type AuthenticationHandler, - type Client, -} from '@a2a-js/sdk/client'; +import { A2AClientManager } from './a2a-client-manager.js'; +import type { AgentCard } from '@a2a-js/sdk'; +import * as sdkClient from '@a2a-js/sdk/client'; +import * as dnsPromises from 'node:dns/promises'; +import type { LookupOptions } from 'node:dns'; import { debugLogger } from '../utils/debugLogger.js'; +interface MockClient { + sendMessageStream: ReturnType; + getTask: ReturnType; + cancelTask: ReturnType; +} + +vi.mock('@a2a-js/sdk/client', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...(actual as Record), + createAuthenticatingFetchWithRetry: vi.fn(), + ClientFactory: vi.fn(), + DefaultAgentCardResolver: vi.fn(), + ClientFactoryOptions: { + createFrom: vi.fn(), + default: {}, + }, + }; +}); + vi.mock('../utils/debugLogger.js', () => ({ debugLogger: { debug: vi.fn(), }, })); -vi.mock('@a2a-js/sdk/client', () => { - const ClientFactory = vi.fn(); - const DefaultAgentCardResolver = vi.fn(); - const RestTransportFactory = vi.fn(); - const JsonRpcTransportFactory = vi.fn(); - const ClientFactoryOptions = { - default: {}, - createFrom: vi.fn(), - }; - const createAuthenticatingFetchWithRetry = vi.fn(); - - DefaultAgentCardResolver.prototype.resolve = vi.fn(); - ClientFactory.prototype.createFromUrl = vi.fn(); - - return { - ClientFactory, - ClientFactoryOptions, - DefaultAgentCardResolver, - RestTransportFactory, - JsonRpcTransportFactory, - createAuthenticatingFetchWithRetry, - }; -}); +vi.mock('node:dns/promises', () => ({ + lookup: vi.fn(), +})); describe('A2AClientManager', () => { let manager: A2AClientManager; + const mockAgentCard: AgentCard = { + name: 'test-agent', + description: 'A test agent', + url: 'http://test.agent', + version: '1.0.0', + protocolVersion: '0.1.0', + capabilities: {}, + skills: [], + defaultInputModes: [], + defaultOutputModes: [], + }; + + const mockClient: MockClient = { + sendMessageStream: vi.fn(), + getTask: vi.fn(), + cancelTask: vi.fn(), + }; - // Stable mocks initialized once - const sendMessageStreamMock = vi.fn(); - const getTaskMock = vi.fn(); - const cancelTaskMock = vi.fn(); - const getAgentCardMock = vi.fn(); const authFetchMock = vi.fn(); - const mockClient = { - sendMessageStream: sendMessageStreamMock, - getTask: getTaskMock, - cancelTask: cancelTaskMock, - getAgentCard: getAgentCardMock, - } as unknown as Client; - - const mockAgentCard: Partial = { name: 'TestAgent' }; - beforeEach(() => { vi.clearAllMocks(); - A2AClientManager.resetInstanceForTesting(); manager = A2AClientManager.getInstance(); + manager.clearCache(); - // Default mock implementations - getAgentCardMock.mockResolvedValue({ + // Default DNS mock: resolve to public IP. + // Use any cast only for return value due to complex multi-signature overloads. + vi.mocked(dnsPromises.lookup).mockImplementation( + async (_h: string, options?: LookupOptions | number) => { + const addr = { address: '93.184.216.34', family: 4 }; + const isAll = typeof options === 'object' && options?.all; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return (isAll ? [addr] : addr) as any; + }, + ); + + // Re-create the instances as plain objects that can be spied on + const factoryInstance = { + createFromUrl: vi.fn(), + createFromAgentCard: vi.fn(), + }; + const resolverInstance = { + resolve: vi.fn(), + }; + + vi.mocked(sdkClient.ClientFactory).mockReturnValue( + factoryInstance as unknown as sdkClient.ClientFactory, + ); + vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as sdkClient.DefaultAgentCardResolver, + ); + + vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue( + mockClient as unknown as sdkClient.Client, + ); + vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue( + mockClient as unknown as sdkClient.Client, + ); + vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({ ...mockAgentCard, url: 'http://test.agent/real/endpoint', } as AgentCard); - vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue( - mockClient, + vi.spyOn(sdkClient.ClientFactoryOptions, 'createFrom').mockImplementation( + (_defaults, overrides) => + overrides as unknown as sdkClient.ClientFactoryOptions, ); - vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({ - ...mockAgentCard, - url: 'http://test.agent/real/endpoint', - } as AgentCard); - - vi.mocked(ClientFactoryOptions.createFrom).mockImplementation( - (_defaults, overrides) => overrides as ClientFactoryOptions, - ); - - vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue( - authFetchMock, + vi.mocked(sdkClient.createAuthenticatingFetchWithRetry).mockImplementation( + () => + authFetchMock.mockResolvedValue({ + ok: true, + json: async () => ({}), + } as Response), ); vi.stubGlobal( @@ -123,21 +146,27 @@ describe('A2AClientManager', () => { 'TestAgent', 'http://test.agent/card', ); - expect(agentCard).toMatchObject(mockAgentCard); expect(manager.getAgentCard('TestAgent')).toBe(agentCard); expect(manager.getClient('TestAgent')).toBeDefined(); }); + it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => { + await manager.loadAgent('TestAgent', 'http://test.agent/card'); + expect(sdkClient.ClientFactoryOptions.createFrom).toHaveBeenCalled(); + }); + it('should throw an error if an agent with the same name is already loaded', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); await expect( - manager.loadAgent('TestAgent', 'http://another.agent/card'), + manager.loadAgent('TestAgent', 'http://test.agent/card'), ).rejects.toThrow("Agent with name 'TestAgent' is already loaded."); }); it('should use native fetch by default', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); - expect(createAuthenticatingFetchWithRetry).not.toHaveBeenCalled(); + expect( + sdkClient.createAuthenticatingFetchWithRetry, + ).not.toHaveBeenCalled(); }); it('should use provided custom authentication handler for transports only', async () => { @@ -146,21 +175,13 @@ describe('A2AClientManager', () => { shouldRetryWithHeaders: vi.fn(), }; await manager.loadAgent( - 'CustomAuthAgent', - 'http://custom.agent/card', - customAuthHandler as unknown as AuthenticationHandler, - ); - - expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith( - expect.anything(), - customAuthHandler, + 'TestAgent', + 'http://test.agent/card', + customAuthHandler as unknown as sdkClient.AuthenticationHandler, ); // Card resolver should NOT use the authenticated fetch by default. - const resolverInstance = vi.mocked(DefaultAgentCardResolver).mock - .instances[0]; - expect(resolverInstance).toBeDefined(); - const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock + const resolverOptions = vi.mocked(sdkClient.DefaultAgentCardResolver).mock .calls[0][0]; expect(resolverOptions?.fetchImpl).not.toBe(authFetchMock); }); @@ -173,10 +194,10 @@ describe('A2AClientManager', () => { await manager.loadAgent( 'AuthCardAgent', 'http://authcard.agent/card', - customAuthHandler as unknown as AuthenticationHandler, + customAuthHandler as unknown as sdkClient.AuthenticationHandler, ); - const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock + const resolverOptions = vi.mocked(sdkClient.DefaultAgentCardResolver).mock .calls[0][0]; const cardFetch = resolverOptions?.fetchImpl as typeof fetch; @@ -204,10 +225,10 @@ describe('A2AClientManager', () => { await manager.loadAgent( 'AuthCardAgent401', 'http://authcard.agent/card', - customAuthHandler as unknown as AuthenticationHandler, + customAuthHandler as unknown as sdkClient.AuthenticationHandler, ); - const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock + const resolverOptions = vi.mocked(sdkClient.DefaultAgentCardResolver).mock .calls[0][0]; const cardFetch = resolverOptions?.fetchImpl as typeof fetch; @@ -220,100 +241,155 @@ describe('A2AClientManager', () => { it('should log a debug message upon loading an agent', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); expect(debugLogger.debug).toHaveBeenCalledWith( - "[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card", + expect.stringContaining("Loaded agent 'TestAgent'"), ); }); it('should clear the cache', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); - expect(manager.getAgentCard('TestAgent')).toBeDefined(); - expect(manager.getClient('TestAgent')).toBeDefined(); - manager.clearCache(); - expect(manager.getAgentCard('TestAgent')).toBeUndefined(); expect(manager.getClient('TestAgent')).toBeUndefined(); - expect(debugLogger.debug).toHaveBeenCalledWith( - '[A2AClientManager] Cache cleared.', + }); + + it('should throw if resolveAgentCard fails', async () => { + const resolverInstance = { + resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')), + }; + vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as sdkClient.DefaultAgentCardResolver, ); + + await expect( + manager.loadAgent('FailAgent', 'http://fail.agent'), + ).rejects.toThrow('Resolution failed'); + }); + + it('should throw if factory.createFromAgentCard fails', async () => { + const factoryInstance = { + createFromAgentCard: vi + .fn() + .mockRejectedValue(new Error('Factory failed')), + }; + vi.mocked(sdkClient.ClientFactory).mockReturnValue( + factoryInstance as unknown as sdkClient.ClientFactory, + ); + + await expect( + manager.loadAgent('FailAgent', 'http://fail.agent'), + ).rejects.toThrow('Factory failed'); + }); + }); + + describe('getAgentCard and getClient', () => { + it('should return undefined if agent is not found', () => { + expect(manager.getAgentCard('Unknown')).toBeUndefined(); + expect(manager.getClient('Unknown')).toBeUndefined(); }); }); describe('sendMessageStream', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should send a message and return a stream', async () => { - const mockResult = { - kind: 'message', - messageId: 'a', - parts: [], - role: 'agent', - } as SendMessageResult; - - sendMessageStreamMock.mockReturnValue( + mockClient.sendMessageStream.mockReturnValue( (async function* () { - yield mockResult; + yield { kind: 'message' }; })(), ); const stream = manager.sendMessageStream('TestAgent', 'Hello'); const results = []; - for await (const res of stream) { - results.push(res); + for await (const result of stream) { + results.push(result); } - expect(results).toEqual([mockResult]); - expect(sendMessageStreamMock).toHaveBeenCalledWith( + expect(results).toHaveLength(1); + expect(mockClient.sendMessageStream).toHaveBeenCalled(); + }); + + it('should use contextId and taskId when provided', async () => { + mockClient.sendMessageStream.mockReturnValue( + (async function* () { + yield { kind: 'message' }; + })(), + ); + + const stream = manager.sendMessageStream('TestAgent', 'Hello', { + contextId: 'ctx123', + taskId: 'task456', + }); + // trigger execution + for await (const _ of stream) { + break; + } + + expect(mockClient.sendMessageStream).toHaveBeenCalledWith( expect.objectContaining({ - message: expect.anything(), + message: expect.objectContaining({ + contextId: 'ctx123', + taskId: 'task456', + }), }), expect.any(Object), ); }); - it('should use contextId and taskId when provided', async () => { - sendMessageStreamMock.mockReturnValue( + it('should correctly propagate AbortSignal to the stream', async () => { + mockClient.sendMessageStream.mockReturnValue( (async function* () { - yield { - kind: 'message', - messageId: 'a', - parts: [], - role: 'agent', - } as SendMessageResult; + yield { kind: 'message' }; })(), ); - const expectedContextId = 'user-context-id'; - const expectedTaskId = 'user-task-id'; - + const controller = new AbortController(); const stream = manager.sendMessageStream('TestAgent', 'Hello', { - contextId: expectedContextId, - taskId: expectedTaskId, + signal: controller.signal, }); - + // trigger execution for await (const _ of stream) { - // consume stream + break; } - const call = sendMessageStreamMock.mock.calls[0][0]; - expect(call.message.contextId).toBe(expectedContextId); - expect(call.message.taskId).toBe(expectedTaskId); + expect(mockClient.sendMessageStream).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ signal: controller.signal }), + ); + }); + + it('should handle a multi-chunk stream with different event types', async () => { + mockClient.sendMessageStream.mockReturnValue( + (async function* () { + yield { kind: 'message', messageId: 'm1' }; + yield { kind: 'status-update', taskId: 't1' }; + })(), + ); + + const stream = manager.sendMessageStream('TestAgent', 'Hello'); + const results = []; + for await (const result of stream) { + results.push(result); + } + + expect(results).toHaveLength(2); + expect(results[0].kind).toBe('message'); + expect(results[1].kind).toBe('status-update'); }); it('should throw prefixed error on failure', async () => { - sendMessageStreamMock.mockImplementationOnce(() => { - throw new Error('Network error'); + mockClient.sendMessageStream.mockImplementation(() => { + throw new Error('Network failure'); }); const stream = manager.sendMessageStream('TestAgent', 'Hello'); await expect(async () => { for await (const _ of stream) { - // consume + // empty } }).rejects.toThrow( - '[A2AClientManager] sendMessageStream Error [TestAgent]: Network error', + '[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure', ); }); @@ -321,7 +397,7 @@ describe('A2AClientManager', () => { const stream = manager.sendMessageStream('NonExistentAgent', 'Hello'); await expect(async () => { for await (const _ of stream) { - // consume + // empty } }).rejects.toThrow("Agent 'NonExistentAgent' not found."); }); @@ -329,28 +405,23 @@ describe('A2AClientManager', () => { describe('getTask', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should get a task from the correct agent', async () => { - getTaskMock.mockResolvedValue({ - id: 'task123', - contextId: 'a', - kind: 'task', - status: { state: 'completed' }, - } as Task); + const mockTask = { id: 'task123', kind: 'task' }; + mockClient.getTask.mockResolvedValue(mockTask); - await manager.getTask('TestAgent', 'task123'); - expect(getTaskMock).toHaveBeenCalledWith({ - id: 'task123', - }); + const result = await manager.getTask('TestAgent', 'task123'); + expect(result).toBe(mockTask); + expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' }); }); it('should throw prefixed error on failure', async () => { - getTaskMock.mockRejectedValueOnce(new Error('Network error')); + mockClient.getTask.mockRejectedValue(new Error('Not found')); await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow( - 'A2AClient getTask Error [TestAgent]: Network error', + 'A2AClient getTask Error [TestAgent]: Not found', ); }); @@ -363,28 +434,23 @@ describe('A2AClientManager', () => { describe('cancelTask', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should cancel a task on the correct agent', async () => { - cancelTaskMock.mockResolvedValue({ - id: 'task123', - contextId: 'a', - kind: 'task', - status: { state: 'canceled' }, - } as Task); + const mockTask = { id: 'task123', kind: 'task' }; + mockClient.cancelTask.mockResolvedValue(mockTask); - await manager.cancelTask('TestAgent', 'task123'); - expect(cancelTaskMock).toHaveBeenCalledWith({ - id: 'task123', - }); + const result = await manager.cancelTask('TestAgent', 'task123'); + expect(result).toBe(mockTask); + expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' }); }); it('should throw prefixed error on failure', async () => { - cancelTaskMock.mockRejectedValueOnce(new Error('Network error')); + mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel')); await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow( - 'A2AClient cancelTask Error [TestAgent]: Network error', + 'A2AClient cancelTask Error [TestAgent]: Cannot cancel', ); }); @@ -394,4 +460,82 @@ describe('A2AClientManager', () => { ).rejects.toThrow("Agent 'NonExistentAgent' not found."); }); }); + + describe('Protocol Routing & URL Logic', () => { + it('should correctly split URLs to prevent .well-known doubling', async () => { + const fullUrl = 'http://localhost:9001/.well-known/agent-card.json'; + const resolverInstance = { + resolve: vi.fn().mockResolvedValue({ name: 'test' } as AgentCard), + }; + vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as sdkClient.DefaultAgentCardResolver, + ); + + await manager.loadAgent('test-doubling', fullUrl); + + expect(resolverInstance.resolve).toHaveBeenCalledWith( + 'http://localhost:9001/', + undefined, + ); + }); + + it('should throw if a remote agent uses a private IP (SSRF protection)', async () => { + const privateUrl = 'http://169.254.169.254/.well-known/agent-card.json'; + await expect(manager.loadAgent('ssrf-agent', privateUrl)).rejects.toThrow( + /Refusing to load agent 'ssrf-agent' from private IP range/, + ); + }); + + it('should throw if a domain resolves to a private IP (DNS SSRF protection)', async () => { + const maliciousDomainUrl = + 'http://malicious.com/.well-known/agent-card.json'; + + vi.mocked(dnsPromises.lookup).mockImplementationOnce( + async (_h: string, options?: LookupOptions | number) => { + const addr = { address: '10.0.0.1', family: 4 }; + const isAll = typeof options === 'object' && options?.all; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return (isAll ? [addr] : addr) as any; + }, + ); + + await expect( + manager.loadAgent('dns-ssrf-agent', maliciousDomainUrl), + ).rejects.toThrow(/private IP range/); + }); + + it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => { + const publicUrl = 'https://public.agent.com/card.json'; + const resolverInstance = { + resolve: vi.fn().mockResolvedValue({ + ...mockAgentCard, + url: 'http://192.168.1.1/api', // Malicious private transport in public card + } as AgentCard), + }; + vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as sdkClient.DefaultAgentCardResolver, + ); + + // DNS for public.agent.com is public + vi.mocked(dnsPromises.lookup).mockImplementation( + async (hostname: string, options?: LookupOptions | number) => { + const isAll = typeof options === 'object' && options?.all; + if (hostname === 'public.agent.com') { + const addr = { address: '1.1.1.1', family: 4 }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return (isAll ? [addr] : addr) as any; + } + const addr = { address: '192.168.1.1', family: 4 }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return (isAll ? [addr] : addr) as any; + }, + ); + + await expect( + manager.loadAgent('malicious-agent', publicUrl), + ).rejects.toThrow( + /contains transport URL pointing to private IP range: http:\/\/192.168.1.1\/api/, + ); + }); + }); }); diff --git a/packages/core/src/agents/a2a-client-manager.ts b/packages/core/src/agents/a2a-client-manager.ts index 7d8f27f02b..32ec766eb0 100644 --- a/packages/core/src/agents/a2a-client-manager.ts +++ b/packages/core/src/agents/a2a-client-manager.ts @@ -12,20 +12,52 @@ import type { TaskStatusUpdateEvent, TaskArtifactUpdateEvent, } from '@a2a-js/sdk'; +import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client'; import { - type Client, ClientFactory, ClientFactoryOptions, DefaultAgentCardResolver, - RestTransportFactory, JsonRpcTransportFactory, - type AuthenticationHandler, + RestTransportFactory, createAuthenticatingFetchWithRetry, } from '@a2a-js/sdk/client'; +import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc'; import { v4 as uuidv4 } from 'uuid'; import { Agent as UndiciAgent } from 'undici'; +import { + getGrpcChannelOptions, + getGrpcCredentials, + normalizeAgentCard, + pinUrlToIp, + splitAgentCardUrl, +} from './a2aUtils.js'; +import { + isPrivateIpAsync, + safeLookup, + isLoopbackHost, + safeFetch, +} from '../utils/fetch.js'; import { debugLogger } from '../utils/debugLogger.js'; -import { safeLookup } from '../utils/fetch.js'; + +/** + * Result of sending a message, which can be a full message, a task, + * or an incremental status/artifact update. + */ +export type SendMessageResult = + | Message + | Task + | TaskStatusUpdateEvent + | TaskArtifactUpdateEvent; + +/** + * Internal interface representing properties we inject into the SDK + * to enable DNS rebinding protection for gRPC connections. + * TODO: Replace with official SDK pinning API once available. + */ +interface InternalGrpcExtensions { + target: string; + grpcChannelOptions: Record; +} // Remote agents can take 10+ minutes (e.g. Deep Research). // Use a dedicated dispatcher so the global 5-min timeout isn't affected. @@ -34,22 +66,16 @@ const a2aDispatcher = new UndiciAgent({ headersTimeout: A2A_TIMEOUT, bodyTimeout: A2A_TIMEOUT, connect: { - lookup: safeLookup, // SSRF protection at connection level + // SSRF protection at the connection level (mitigates DNS rebinding) + lookup: safeLookup, }, }); const a2aFetch: typeof fetch = (input, init) => - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection - fetch(input, { ...init, dispatcher: a2aDispatcher } as RequestInit); - -export type SendMessageResult = - | Message - | Task - | TaskStatusUpdateEvent - | TaskArtifactUpdateEvent; + safeFetch(input, { ...init, dispatcher: a2aDispatcher }); /** - * Manages A2A clients and caches loaded agent information. - * Follows a singleton pattern to ensure a single client instance. + * Orchestrates communication with remote A2A agents. + * Manages protocol negotiation, authentication, and transport selection. */ export class A2AClientManager { private static instance: A2AClientManager; @@ -70,19 +96,10 @@ export class A2AClientManager { return A2AClientManager.instance; } - /** - * Resets the singleton instance. Only for testing purposes. - * @internal - */ - static resetInstanceForTesting() { - // @ts-expect-error - Resetting singleton for testing - A2AClientManager.instance = undefined; - } - /** * Loads an agent by fetching its AgentCard and caches the client. * @param name The name to assign to the agent. - * @param agentCardUrl The full URL to the agent's card. + * @param agentCardUrl {string} The full URL to the agent's card. * @param authHandler Optional authentication handler to use for this agent. * @returns The loaded AgentCard. */ @@ -119,21 +136,50 @@ export class A2AClientManager { }; const resolver = new DefaultAgentCardResolver({ fetchImpl: cardFetch }); + const agentCard = await this.resolveAgentCard(name, agentCardUrl, resolver); - const options = ClientFactoryOptions.createFrom( + // Pin URL to IP to prevent DNS rebinding for gRPC (connection-level SSRF protection) + const grpcInterface = agentCard.additionalInterfaces?.find( + (i) => i.transport === 'GRPC', + ); + const urlToPin = grpcInterface?.url ?? agentCard.url; + const { pinnedUrl, hostname } = await pinUrlToIp(urlToPin, name); + + // Prepare base gRPC options + const baseGrpcOptions: ConstructorParameters< + typeof GrpcTransportFactory + >[0] = { + grpcChannelCredentials: getGrpcCredentials(urlToPin), + }; + + // We inject additional properties into the transport options to force + // the use of a pinned IP address and matching SSL authority. This is + // required for robust DNS Rebinding protection. + const transportOptions = { + ...baseGrpcOptions, + target: pinnedUrl, + grpcChannelOptions: getGrpcChannelOptions(hostname), + } as ConstructorParameters[0] & + InternalGrpcExtensions; + + // Configure standard SDK client for tool registration and discovery + const clientOptions = ClientFactoryOptions.createFrom( ClientFactoryOptions.default, { transports: [ new RestTransportFactory({ fetchImpl: authFetch }), new JsonRpcTransportFactory({ fetchImpl: authFetch }), + new GrpcTransportFactory( + transportOptions as ConstructorParameters< + typeof GrpcTransportFactory + >[0], + ), ], cardResolver: resolver, }, ); - - const factory = new ClientFactory(options); - const client = await factory.createFromUrl(agentCardUrl, ''); - const agentCard = await client.getAgentCard(); + const factory = new ClientFactory(clientOptions); + const client = await factory.createFromAgentCard(agentCard); this.clients.set(name, client); this.agentCards.set(name, agentCard); @@ -168,9 +214,7 @@ export class A2AClientManager { options?: { contextId?: string; taskId?: string; signal?: AbortSignal }, ): AsyncIterable { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); const messageParams: MessageSendParams = { message: { @@ -186,7 +230,7 @@ export class A2AClientManager { try { yield* client.sendMessageStream(messageParams, { signal: options?.signal, - }); + }) as AsyncIterable; } catch (error: unknown) { const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`; if (error instanceof Error) { @@ -224,9 +268,7 @@ export class A2AClientManager { */ async getTask(agentName: string, taskId: string): Promise { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); try { return await client.getTask({ id: taskId }); } catch (error: unknown) { @@ -246,9 +288,7 @@ export class A2AClientManager { */ async cancelTask(agentName: string, taskId: string): Promise { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); try { return await client.cancelTask({ id: taskId }); } catch (error: unknown) { @@ -259,4 +299,66 @@ export class A2AClientManager { throw new Error(`${prefix}: Unexpected error: ${String(error)}`); } } + + /** + * Resolves and normalizes an agent card from a given URL. + * Handles splitting the URL if it already contains the standard .well-known path. + * Also performs basic SSRF validation to prevent internal IP access. + */ + private async resolveAgentCard( + agentName: string, + url: string, + resolver: DefaultAgentCardResolver, + ): Promise { + // Validate URL to prevent SSRF (with DNS resolution) + if (await isPrivateIpAsync(url)) { + // Local/private IPs are allowed ONLY for localhost for testing. + const parsed = new URL(url); + if (!isLoopbackHost(parsed.hostname)) { + throw new Error( + `Refusing to load agent '${agentName}' from private IP range: ${url}. Remote agents must use public URLs.`, + ); + } + } + + const { baseUrl, path } = splitAgentCardUrl(url); + const rawCard = await resolver.resolve(baseUrl, path); + const agentCard = normalizeAgentCard(rawCard); + + // Deep validation of all transport URLs within the card to prevent SSRF + await this.validateAgentCardUrls(agentName, agentCard); + + return agentCard; + } + + /** + * Validates all URLs (top-level and interfaces) within an AgentCard for SSRF. + */ + private async validateAgentCardUrls( + agentName: string, + card: AgentCard, + ): Promise { + const urlsToValidate = [card.url]; + if (card.additionalInterfaces) { + for (const intf of card.additionalInterfaces) { + if (intf.url) urlsToValidate.push(intf.url); + } + } + + for (const url of urlsToValidate) { + if (!url) continue; + + // Ensure URL has a scheme for the parser (gRPC often provides raw IP:port) + const validationUrl = url.includes('://') ? url : `http://${url}`; + + if (await isPrivateIpAsync(validationUrl)) { + const parsed = new URL(validationUrl); + if (!isLoopbackHost(parsed.hostname)) { + throw new Error( + `Refusing to load agent '${agentName}': contains transport URL pointing to private IP range: ${url}.`, + ); + } + } + } + } } diff --git a/packages/core/src/utils/fetch.ts b/packages/core/src/utils/fetch.ts index a324172d94..86abd519b6 100644 --- a/packages/core/src/utils/fetch.ts +++ b/packages/core/src/utils/fetch.ts @@ -240,16 +240,16 @@ function handleFetchError(error: unknown, url: string): never { */ export async function safeFetch( input: RequestInfo | URL, - init?: RequestInit, + init?: NodeFetchInit, ): Promise { const nodeInit: NodeFetchInit = { - ...init, dispatcher: safeDispatcher, + ...init, }; try { // eslint-disable-next-line no-restricted-syntax - return await fetch(input, nodeInit); + return await fetch(input, nodeInit as RequestInit); } catch (error) { const url = input instanceof Request