feat(a2a): add robust A2A V0 support, gRPC transport, and config validation

This commit is contained in:
Alisa Novikova
2026-03-05 16:33:19 -08:00
parent 6691fac50e
commit dae8d85a18
12 changed files with 686 additions and 205 deletions
@@ -5,96 +5,102 @@
*/
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import {
A2AClientManager,
type SendMessageResult,
} from './a2a-client-manager.js';
import type { AgentCard, Task } from '@a2a-js/sdk';
import {
ClientFactory,
DefaultAgentCardResolver,
createAuthenticatingFetchWithRetry,
ClientFactoryOptions,
type AuthenticationHandler,
type Client,
} from '@a2a-js/sdk/client';
import { A2AClientManager } from './a2a-client-manager.js';
import type { AgentCard } from '@a2a-js/sdk';
import * as sdkClient from '@a2a-js/sdk/client';
import { debugLogger } from '../utils/debugLogger.js';
interface MockClient {
sendMessageStream: ReturnType<typeof vi.fn>;
getTask: ReturnType<typeof vi.fn>;
cancelTask: ReturnType<typeof vi.fn>;
}
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
const actual = await importOriginal();
return {
...(actual as Record<string, unknown>),
createAuthenticatingFetchWithRetry: vi.fn(),
ClientFactory: vi.fn(),
DefaultAgentCardResolver: vi.fn(),
ClientFactoryOptions: {
createFrom: vi.fn(),
default: {},
},
};
});
vi.mock('../utils/debugLogger.js', () => ({
debugLogger: {
debug: vi.fn(),
},
}));
vi.mock('@a2a-js/sdk/client', () => {
const ClientFactory = vi.fn();
const DefaultAgentCardResolver = vi.fn();
const RestTransportFactory = vi.fn();
const JsonRpcTransportFactory = vi.fn();
const ClientFactoryOptions = {
default: {},
createFrom: vi.fn(),
};
const createAuthenticatingFetchWithRetry = vi.fn();
DefaultAgentCardResolver.prototype.resolve = vi.fn();
ClientFactory.prototype.createFromUrl = vi.fn();
return {
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
createAuthenticatingFetchWithRetry,
};
});
describe('A2AClientManager', () => {
let manager: A2AClientManager;
const mockAgentCard: AgentCard = {
name: 'test-agent',
description: 'A test agent',
url: 'http://test.agent',
version: '1.0.0',
protocolVersion: '0.1.0',
capabilities: {},
skills: [],
defaultInputModes: [],
defaultOutputModes: [],
};
const mockClient: MockClient = {
sendMessageStream: vi.fn(),
getTask: vi.fn(),
cancelTask: vi.fn(),
};
// Stable mocks initialized once
const sendMessageStreamMock = vi.fn();
const getTaskMock = vi.fn();
const cancelTaskMock = vi.fn();
const getAgentCardMock = vi.fn();
const authFetchMock = vi.fn();
const mockClient = {
sendMessageStream: sendMessageStreamMock,
getTask: getTaskMock,
cancelTask: cancelTaskMock,
getAgentCard: getAgentCardMock,
} as unknown as Client;
const mockAgentCard: Partial<AgentCard> = { name: 'TestAgent' };
beforeEach(() => {
vi.clearAllMocks();
A2AClientManager.resetInstanceForTesting();
manager = A2AClientManager.getInstance();
// Default mock implementations
getAgentCardMock.mockResolvedValue({
// Re-create the instances as plain objects that can be spied on
const factoryInstance = {
createFromUrl: vi.fn(),
createFromAgentCard: vi.fn(),
};
const resolverInstance = {
resolve: vi.fn(),
};
vi.mocked(sdkClient.ClientFactory).mockReturnValue(
factoryInstance as unknown as sdkClient.ClientFactory,
);
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue(
mockClient as unknown as sdkClient.Client,
);
vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue(
mockClient as unknown as sdkClient.Client,
);
vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue(
mockClient,
vi.spyOn(sdkClient.ClientFactoryOptions, 'createFrom').mockImplementation(
(_defaults, overrides) =>
overrides as unknown as sdkClient.ClientFactoryOptions,
);
vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactoryOptions.createFrom).mockImplementation(
(_defaults, overrides) => overrides as ClientFactoryOptions,
);
vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue(
authFetchMock,
vi.mocked(sdkClient.createAuthenticatingFetchWithRetry).mockImplementation(
() =>
authFetchMock.mockResolvedValue({
ok: true,
json: async () => ({}),
} as Response),
);
vi.stubGlobal(
@@ -123,137 +129,188 @@ describe('A2AClientManager', () => {
'TestAgent',
'http://test.agent/card',
);
expect(agentCard).toMatchObject(mockAgentCard);
expect(manager.getAgentCard('TestAgent')).toBe(agentCard);
expect(manager.getClient('TestAgent')).toBeDefined();
});
it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(sdkClient.ClientFactoryOptions.createFrom).toHaveBeenCalled();
});
it('should throw an error if an agent with the same name is already loaded', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
await expect(
manager.loadAgent('TestAgent', 'http://another.agent/card'),
manager.loadAgent('TestAgent', 'http://test.agent/card'),
).rejects.toThrow("Agent with name 'TestAgent' is already loaded.");
});
it('should use native fetch by default', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(createAuthenticatingFetchWithRetry).not.toHaveBeenCalled();
expect(
sdkClient.createAuthenticatingFetchWithRetry,
).not.toHaveBeenCalled();
});
it('should use provided custom authentication handler', async () => {
const customAuthHandler = {
headers: vi.fn(),
shouldRetryWithHeaders: vi.fn(),
const authHandler: sdkClient.AuthenticationHandler = {
headers: async () => ({}),
shouldRetryWithHeaders: async () => undefined,
};
await manager.loadAgent(
'CustomAuthAgent',
'http://custom.agent/card',
customAuthHandler as unknown as AuthenticationHandler,
);
expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith(
expect.anything(),
customAuthHandler,
'TestAgent',
'http://test.agent/card',
authHandler,
);
expect(sdkClient.createAuthenticatingFetchWithRetry).toHaveBeenCalled();
});
it('should log a debug message upon loading an agent', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(debugLogger.debug).toHaveBeenCalledWith(
"[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card",
expect.stringContaining("Loaded agent 'TestAgent'"),
);
});
it('should clear the cache', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(manager.getAgentCard('TestAgent')).toBeDefined();
expect(manager.getClient('TestAgent')).toBeDefined();
manager.clearCache();
expect(manager.getAgentCard('TestAgent')).toBeUndefined();
expect(manager.getClient('TestAgent')).toBeUndefined();
expect(debugLogger.debug).toHaveBeenCalledWith(
'[A2AClientManager] Cache cleared.',
});
it('should throw if resolveAgentCard fails', async () => {
const resolverInstance = {
resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')),
};
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
await expect(
manager.loadAgent('FailAgent', 'http://fail.agent'),
).rejects.toThrow('Resolution failed');
});
it('should throw if factory.createFromAgentCard fails', async () => {
const factoryInstance = {
createFromAgentCard: vi
.fn()
.mockRejectedValue(new Error('Factory failed')),
};
vi.mocked(sdkClient.ClientFactory).mockReturnValue(
factoryInstance as unknown as sdkClient.ClientFactory,
);
await expect(
manager.loadAgent('FailAgent', 'http://fail.agent'),
).rejects.toThrow('Factory failed');
});
});
describe('getAgentCard and getClient', () => {
it('should return undefined if agent is not found', () => {
expect(manager.getAgentCard('Unknown')).toBeUndefined();
expect(manager.getClient('Unknown')).toBeUndefined();
});
});
describe('sendMessageStream', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should send a message and return a stream', async () => {
const mockResult = {
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
sendMessageStreamMock.mockReturnValue(
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield mockResult;
yield { kind: 'message' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = [];
for await (const res of stream) {
results.push(res);
for await (const result of stream) {
results.push(result);
}
expect(results).toEqual([mockResult]);
expect(sendMessageStreamMock).toHaveBeenCalledWith(
expect(results).toHaveLength(1);
expect(mockClient.sendMessageStream).toHaveBeenCalled();
});
it('should use contextId and taskId when provided', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield { kind: 'message' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: 'ctx123',
taskId: 'task456',
});
await stream[Symbol.asyncIterator]().next();
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.anything(),
message: expect.objectContaining({
contextId: 'ctx123',
taskId: 'task456',
}),
}),
expect.any(Object),
);
});
it('should use contextId and taskId when provided', async () => {
sendMessageStreamMock.mockReturnValue(
it('should correctly propagate AbortSignal to the stream', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield {
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
yield { kind: 'message' };
})(),
);
const expectedContextId = 'user-context-id';
const expectedTaskId = 'user-task-id';
const controller = new AbortController();
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: expectedContextId,
taskId: expectedTaskId,
signal: controller.signal,
});
await stream[Symbol.asyncIterator]().next();
for await (const _ of stream) {
// consume stream
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({ signal: controller.signal }),
);
});
it('should handle a multi-chunk stream with different event types', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield { kind: 'message', messageId: 'm1' };
yield { kind: 'status-update', taskId: 't1' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = [];
for await (const result of stream) {
results.push(result);
}
const call = sendMessageStreamMock.mock.calls[0][0];
expect(call.message.contextId).toBe(expectedContextId);
expect(call.message.taskId).toBe(expectedTaskId);
expect(results).toHaveLength(2);
expect(results[0].kind).toBe('message');
expect(results[1].kind).toBe('status-update');
});
it('should throw prefixed error on failure', async () => {
sendMessageStreamMock.mockImplementationOnce(() => {
throw new Error('Network error');
mockClient.sendMessageStream.mockImplementation(() => {
throw new Error('Network failure');
});
const stream = manager.sendMessageStream('TestAgent', 'Hello');
await expect(async () => {
for await (const _ of stream) {
// consume
// empty
}
}).rejects.toThrow(
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network error',
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure',
);
});
@@ -261,7 +318,7 @@ describe('A2AClientManager', () => {
const stream = manager.sendMessageStream('NonExistentAgent', 'Hello');
await expect(async () => {
for await (const _ of stream) {
// consume
// empty
}
}).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
@@ -269,28 +326,23 @@ describe('A2AClientManager', () => {
describe('getTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should get a task from the correct agent', async () => {
getTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'completed' },
} as Task);
const mockTask = { id: 'task123', kind: 'task' };
mockClient.getTask.mockResolvedValue(mockTask);
await manager.getTask('TestAgent', 'task123');
expect(getTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
const result = await manager.getTask('TestAgent', 'task123');
expect(result).toBe(mockTask);
expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' });
});
it('should throw prefixed error on failure', async () => {
getTaskMock.mockRejectedValueOnce(new Error('Network error'));
mockClient.getTask.mockRejectedValue(new Error('Not found'));
await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient getTask Error [TestAgent]: Network error',
'A2AClient getTask Error [TestAgent]: Not found',
);
});
@@ -303,28 +355,23 @@ describe('A2AClientManager', () => {
describe('cancelTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should cancel a task on the correct agent', async () => {
cancelTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'canceled' },
} as Task);
const mockTask = { id: 'task123', kind: 'task' };
mockClient.cancelTask.mockResolvedValue(mockTask);
await manager.cancelTask('TestAgent', 'task123');
expect(cancelTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
const result = await manager.cancelTask('TestAgent', 'task123');
expect(result).toBe(mockTask);
expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' });
});
it('should throw prefixed error on failure', async () => {
cancelTaskMock.mockRejectedValueOnce(new Error('Network error'));
mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel'));
await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient cancelTask Error [TestAgent]: Network error',
'A2AClient cancelTask Error [TestAgent]: Cannot cancel',
);
});
@@ -334,4 +381,23 @@ describe('A2AClientManager', () => {
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
});
describe('Protocol Routing & URL Logic', () => {
it('should correctly split URLs to prevent .well-known doubling', async () => {
const fullUrl = 'http://localhost:9001/.well-known/agent-card.json';
const resolverInstance = {
resolve: vi.fn().mockResolvedValue({ name: 'test' } as AgentCard),
};
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
await manager.loadAgent('test-doubling', fullUrl);
expect(resolverInstance.resolve).toHaveBeenCalledWith(
'http://localhost:9001/',
'.well-known/agent-card.json',
);
});
});
});
+50 -22
View File
@@ -13,7 +13,6 @@ import type {
TaskArtifactUpdateEvent,
} from '@a2a-js/sdk';
import {
type Client,
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
@@ -21,9 +20,12 @@ import {
JsonRpcTransportFactory,
type AuthenticationHandler,
createAuthenticatingFetchWithRetry,
type Client,
} from '@a2a-js/sdk/client';
import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
import { v4 as uuidv4 } from 'uuid';
import { Agent as UndiciAgent } from 'undici';
import { getGrpcCredentials, normalizeAgentCard } from './a2aUtils.js';
import { debugLogger } from '../utils/debugLogger.js';
// Remote agents can take 10+ minutes (e.g. Deep Research).
@@ -44,8 +46,10 @@ export type SendMessageResult =
| TaskArtifactUpdateEvent;
/**
* Manages A2A clients and caches loaded agent information.
* Follows a singleton pattern to ensure a single client instance.
* Orchestrates communication with A2A agents.
*
* This manager handles agent discovery, card caching, and client lifecycle.
* It provides a unified messaging interface using the standard A2A SDK.
*/
export class A2AClientManager {
private static instance: A2AClientManager;
@@ -91,27 +95,26 @@ export class A2AClientManager {
throw new Error(`Agent with name '${name}' is already loaded.`);
}
let fetchImpl: typeof fetch = a2aFetch;
if (authHandler) {
fetchImpl = createAuthenticatingFetchWithRetry(a2aFetch, authHandler);
}
const fetchImpl = this.getFetchImpl(authHandler);
const resolver = new DefaultAgentCardResolver({ fetchImpl });
const agentCard = await this.resolveAgentCard(agentCardUrl, resolver);
const options = ClientFactoryOptions.createFrom(
// Configure standard SDK client for tool registration and discovery
const clientOptions = ClientFactoryOptions.createFrom(
ClientFactoryOptions.default,
{
transports: [
new RestTransportFactory({ fetchImpl }),
new JsonRpcTransportFactory({ fetchImpl }),
new GrpcTransportFactory({
grpcChannelCredentials: getGrpcCredentials(agentCard.url),
}),
],
cardResolver: resolver,
},
);
const factory = new ClientFactory(options);
const client = await factory.createFromUrl(agentCardUrl, '');
const agentCard = await client.getAgentCard();
const factory = new ClientFactory(clientOptions);
const client = await factory.createFromAgentCard(agentCard);
this.clients.set(name, client);
this.agentCards.set(name, agentCard);
@@ -146,9 +149,7 @@ export class A2AClientManager {
options?: { contextId?: string; taskId?: string; signal?: AbortSignal },
): AsyncIterable<SendMessageResult> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
const messageParams: MessageSendParams = {
message: {
@@ -202,9 +203,7 @@ export class A2AClientManager {
*/
async getTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
try {
return await client.getTask({ id: taskId });
} catch (error: unknown) {
@@ -224,9 +223,7 @@ export class A2AClientManager {
*/
async cancelTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
try {
return await client.cancelTask({ id: taskId });
} catch (error: unknown) {
@@ -237,4 +234,35 @@ export class A2AClientManager {
throw new Error(`${prefix}: Unexpected error: ${String(error)}`);
}
}
/**
* Resolves the appropriate fetch implementation for an agent.
*/
private getFetchImpl(authHandler?: AuthenticationHandler): typeof fetch {
return authHandler
? createAuthenticatingFetchWithRetry(a2aFetch, authHandler)
: a2aFetch;
}
/**
* Resolves and normalizes an agent card from a given URL.
* Handles splitting the URL if it already contains the standard .well-known path.
*/
private async resolveAgentCard(
url: string,
resolver: DefaultAgentCardResolver,
): Promise<AgentCard> {
const standardPath = '.well-known/agent-card.json';
let baseUrl = url;
let path: string | undefined;
if (baseUrl.includes(standardPath)) {
const parts = baseUrl.split(standardPath);
baseUrl = parts[0] || '';
path = standardPath;
}
const rawCard = await resolver.resolve(baseUrl, path);
return normalizeAgentCard(rawCard);
}
}
+140
View File
@@ -11,6 +11,8 @@ import {
isTerminalState,
A2AResultReassembler,
AUTH_REQUIRED_MSG,
normalizeAgentCard,
getGrpcCredentials,
} from './a2aUtils.js';
import type { SendMessageResult } from './a2a-client-manager.js';
import type {
@@ -24,6 +26,18 @@ import type {
} from '@a2a-js/sdk';
describe('a2aUtils', () => {
describe('getGrpcCredentials', () => {
it('should return secure credentials for https', () => {
const credentials = getGrpcCredentials('https://test.agent');
expect(credentials).toBeDefined();
});
it('should return insecure credentials for http', () => {
const credentials = getGrpcCredentials('http://test.agent');
expect(credentials).toBeDefined();
});
});
describe('isTerminalState', () => {
it('should return true for completed, failed, canceled, and rejected', () => {
expect(isTerminalState('completed')).toBe(true);
@@ -223,6 +237,119 @@ describe('a2aUtils', () => {
} as Message),
).toBe('');
});
it('should handle file parts with neither name nor uri', () => {
const message: Message = {
kind: 'message',
role: 'user',
messageId: '1',
parts: [
{
kind: 'file',
file: {
mimeType: 'text/plain',
},
} as FilePart,
],
};
expect(extractMessageText(message)).toBe('File: [binary/unnamed]');
});
});
describe('normalizeAgentCard', () => {
it('should throw if input is not an object', () => {
expect(() => normalizeAgentCard(null)).toThrow('Agent card is missing.');
expect(() => normalizeAgentCard(undefined)).toThrow(
'Agent card is missing.',
);
expect(() => normalizeAgentCard('not an object')).toThrow(
'Agent card is missing.',
);
});
it('should preserve unknown fields while providing defaults for mandatory ones', () => {
const raw = {
name: 'my-agent',
customField: 'keep-me',
};
const normalized = normalizeAgentCard(raw);
expect(normalized.name).toBe('my-agent');
// @ts-expect-error - testing dynamic preservation
expect(normalized.customField).toBe('keep-me');
expect(normalized.description).toBe('');
expect(normalized.skills).toEqual([]);
expect(normalized.defaultInputModes).toEqual([]);
});
it('should normalize and synchronize interfaces while preserving other fields', () => {
const raw = {
name: 'test',
supportedInterfaces: [
{
url: 'grpc://test',
protocolBinding: 'GRPC',
protocolVersion: '1.0',
},
],
};
const normalized = normalizeAgentCard(raw);
// Should exist in both fields
expect(normalized.additionalInterfaces).toHaveLength(1);
expect(
(normalized as unknown as Record<string, unknown>)[
'supportedInterfaces'
],
).toHaveLength(1);
const intf = normalized.additionalInterfaces?.[0] as unknown as Record<
string,
unknown
>;
expect(intf['transport']).toBe('GRPC');
expect(intf['url']).toBe('grpc://test');
// Should fallback top-level url
expect(normalized.url).toBe('grpc://test');
});
it('should preserve existing top-level url if present', () => {
const raw = {
name: 'test',
url: 'http://existing',
supportedInterfaces: [{ url: 'http://other', transport: 'REST' }],
};
const normalized = normalizeAgentCard(raw);
expect(normalized.url).toBe('http://existing');
});
it('should NOT prepend http:// scheme to raw IP:port strings for gRPC interfaces', () => {
const raw = {
name: 'raw-ip-grpc',
supportedInterfaces: [{ url: '127.0.0.1:9000', transport: 'GRPC' }],
};
const normalized = normalizeAgentCard(raw);
expect(normalized.additionalInterfaces?.[0].url).toBe('127.0.0.1:9000');
expect(normalized.url).toBe('127.0.0.1:9000');
});
it('should prepend http:// scheme to raw IP:port strings for REST interfaces', () => {
const raw = {
name: 'raw-ip-rest',
supportedInterfaces: [{ url: '127.0.0.1:8080', transport: 'REST' }],
};
const normalized = normalizeAgentCard(raw);
expect(normalized.additionalInterfaces?.[0].url).toBe(
'http://127.0.0.1:8080',
);
});
});
describe('A2AResultReassembler', () => {
@@ -233,6 +360,7 @@ describe('a2aUtils', () => {
reassembler.update({
kind: 'status-update',
taskId: 't1',
contextId: 'ctx1',
status: {
state: 'working',
message: {
@@ -247,6 +375,7 @@ describe('a2aUtils', () => {
reassembler.update({
kind: 'artifact-update',
taskId: 't1',
contextId: 'ctx1',
append: false,
artifact: {
artifactId: 'a1',
@@ -259,6 +388,7 @@ describe('a2aUtils', () => {
reassembler.update({
kind: 'status-update',
taskId: 't1',
contextId: 'ctx1',
status: {
state: 'working',
message: {
@@ -273,6 +403,7 @@ describe('a2aUtils', () => {
reassembler.update({
kind: 'artifact-update',
taskId: 't1',
contextId: 'ctx1',
append: true,
artifact: {
artifactId: 'a1',
@@ -291,6 +422,7 @@ describe('a2aUtils', () => {
reassembler.update({
kind: 'status-update',
contextId: 'ctx1',
status: {
state: 'auth-required',
message: {
@@ -310,6 +442,7 @@ describe('a2aUtils', () => {
reassembler.update({
kind: 'status-update',
contextId: 'ctx1',
status: {
state: 'auth-required',
},
@@ -323,6 +456,7 @@ describe('a2aUtils', () => {
const chunk = {
kind: 'status-update',
contextId: 'ctx1',
status: {
state: 'auth-required',
message: {
@@ -351,6 +485,8 @@ describe('a2aUtils', () => {
reassembler.update({
kind: 'task',
id: 'task-1',
contextId: 'ctx1',
status: { state: 'completed' },
history: [
{
@@ -369,6 +505,8 @@ describe('a2aUtils', () => {
reassembler.update({
kind: 'task',
id: 'task-1',
contextId: 'ctx1',
status: { state: 'working' },
history: [
{
@@ -387,6 +525,8 @@ describe('a2aUtils', () => {
reassembler.update({
kind: 'task',
id: 'task-1',
contextId: 'ctx1',
status: { state: 'completed' },
artifacts: [
{
+154 -22
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import * as grpc from '@grpc/grpc-js';
import type {
Message,
Part,
@@ -13,6 +14,8 @@ import type {
Artifact,
TaskState,
TaskStatusUpdateEvent,
AgentCard,
AgentInterface,
} from '@a2a-js/sdk';
import type { SendMessageResult } from './a2a-client-manager.js';
@@ -210,36 +213,61 @@ function extractPartText(part: Part): string {
return '';
}
// Type Guards
/**
* Normalizes an agent card by ensuring it has the required properties
* and resolving any inconsistencies between protocol versions.
*/
export function normalizeAgentCard(card: unknown): AgentCard {
if (!isObject(card)) {
throw new Error('Agent card is missing.');
}
function isTextPart(part: Part): part is TextPart {
return part.kind === 'text';
}
// Double-cast to bypass strict linter while bootstrapping the object.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const result = { ...card } as unknown as AgentCard;
function isDataPart(part: Part): part is DataPart {
return part.kind === 'data';
}
// Ensure mandatory fields exist with safe defaults.
if (typeof result.name !== 'string') result.name = 'unknown';
if (typeof result.description !== 'string') result.description = '';
if (typeof result.url !== 'string') result.url = '';
if (typeof result.version !== 'string') result.version = '';
if (typeof result.protocolVersion !== 'string') result.protocolVersion = '';
if (!isObject(result.capabilities)) result.capabilities = {};
if (!Array.isArray(result.skills)) result.skills = [];
if (!Array.isArray(result.defaultInputModes)) result.defaultInputModes = [];
if (!Array.isArray(result.defaultOutputModes)) result.defaultOutputModes = [];
function isFilePart(part: Part): part is FilePart {
return part.kind === 'file';
}
// Normalize interfaces and synchronize both interface fields.
const normalizedInterfaces = extractNormalizedInterfaces(card);
result.additionalInterfaces = normalizedInterfaces;
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
(result as unknown as Record<string, AgentInterface[]>)[
'supportedInterfaces'
] = normalizedInterfaces;
function isStatusUpdateEvent(
result: SendMessageResult,
): result is TaskStatusUpdateEvent {
return result.kind === 'status-update';
// Fallback preferredTransport: If not specified, default to GRPC if available.
if (
!result.preferredTransport &&
normalizedInterfaces.some((i) => i.transport === 'GRPC')
) {
result.preferredTransport = 'GRPC';
}
// Fallback: If top-level URL is missing, use the first interface's URL.
if (result.url === '' && normalizedInterfaces.length > 0) {
result.url = normalizedInterfaces[0].url;
}
return result;
}
/**
* Returns true if the given state is a terminal state for a task.
* Returns gRPC channel credentials based on the URL scheme.
*/
export function isTerminalState(state: TaskState | undefined): boolean {
return (
state === 'completed' ||
state === 'failed' ||
state === 'canceled' ||
state === 'rejected'
);
export function getGrpcCredentials(url: string): grpc.ChannelCredentials {
return url.startsWith('https://')
? grpc.credentials.createSsl()
: grpc.credentials.createInsecure();
}
/**
@@ -279,3 +307,107 @@ export function extractIdsFromResponse(result: SendMessageResult): {
return { contextId, taskId, clearTaskId };
}
/**
* Extracts and normalizes interfaces from the card, handling protocol version fallbacks.
* Preserves all original fields to maintain SDK compatibility.
*/
function extractNormalizedInterfaces(
card: Record<string, unknown>,
): AgentInterface[] {
const rawInterfaces =
getArray(card, 'additionalInterfaces') ||
getArray(card, 'supportedInterfaces');
if (!rawInterfaces) {
return [];
}
const mapped: AgentInterface[] = [];
for (const i of rawInterfaces) {
if (isObject(i)) {
// Create a copy to preserve all original fields.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const normalized = { ...i } as unknown as AgentInterface & {
protocolBinding?: string;
};
// Ensure 'url' exists
if (typeof normalized.url !== 'string') {
normalized.url = '';
}
// Normalize 'transport' from 'protocolBinding'
const transport = normalized.transport || normalized.protocolBinding;
if (transport) {
normalized.transport = transport;
}
// Robust URL: Ensure the URL has a scheme (except for gRPC).
// Some agent implementations (like a2a-go samples) may provide raw IP:port strings.
// gRPC targets MUST NOT have a scheme (e.g. 'http://'), or they will fail name resolution.
if (
normalized.url &&
!normalized.url.includes('://') &&
!normalized.url.startsWith('/') &&
normalized.transport !== 'GRPC'
) {
// Default to http:// for insecure REST/JSON-RPC if scheme is missing.
normalized.url = `http://${normalized.url}`;
}
mapped.push(normalized);
}
}
return mapped;
}
/**
* Safely extracts an array property from an object.
*/
function getArray(
obj: Record<string, unknown>,
key: string,
): unknown[] | undefined {
const val = obj[key];
return Array.isArray(val) ? val : undefined;
}
// Type Guards
function isTextPart(part: Part): part is TextPart {
return part.kind === 'text';
}
function isDataPart(part: Part): part is DataPart {
return part.kind === 'data';
}
function isFilePart(part: Part): part is FilePart {
return part.kind === 'file';
}
function isStatusUpdateEvent(
result: SendMessageResult,
): result is TaskStatusUpdateEvent {
return result.kind === 'status-update';
}
/**
* Returns true if the given state is a terminal state for a task.
*/
export function isTerminalState(state: TaskState | undefined): boolean {
return (
state === 'completed' ||
state === 'failed' ||
state === 'canceled' ||
state === 'rejected'
);
}
/**
* Type guard to check if a value is a non-array object.
*/
function isObject(val: unknown): val is Record<string, unknown> {
return typeof val === 'object' && val !== null && !Array.isArray(val);
}
+25
View File
@@ -874,6 +874,31 @@ describe('AgentRegistry', () => {
);
});
it('should maintain registration under canonical name even if overrides are applied', async () => {
const originalName = 'my-agent';
const definition = { ...MOCK_AGENT_V1, name: originalName };
// Mock overrides in settings
vi.spyOn(mockConfig, 'getAgentsSettings').mockReturnValue({
overrides: {
[originalName]: {
enabled: true,
modelConfig: { model: 'overridden-model' },
},
},
});
await registry.testRegisterAgent(definition);
const registered = registry.getDefinition(originalName);
expect(registered).toBeDefined();
expect((registered as LocalAgentDefinition).modelConfig.model).toBe(
'overridden-model',
);
// Ensure it is NOT registered under some other key
expect(registry.getAllAgentNames()).toEqual([originalName]);
});
it('should reject an agent definition missing a name', async () => {
const invalidAgent = { ...MOCK_AGENT_V1, name: '' };
const debugWarnSpy = vi
+1 -1
View File
@@ -312,7 +312,7 @@ export class AgentRegistry {
}
const mergedDefinition = this.applyOverrides(definition, settingsOverrides);
this.agents.set(mergedDefinition.name, mergedDefinition);
this.agents.set(definition.name, mergedDefinition);
this.registerModelConfigs(mergedDefinition);
this.addAgentPolicy(mergedDefinition);