feat: Add A2A Client Manager and tests (#15485)

This commit is contained in:
Adam Weidman
2025-12-23 15:27:16 -05:00
committed by GitHub
parent 563d81e08e
commit 02a36afc38
5 changed files with 561 additions and 21 deletions

View File

@@ -0,0 +1,305 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import {
A2AClientManager,
type SendMessageResult,
} from './a2a-client-manager.js';
import type { AgentCard, Task } from '@a2a-js/sdk';
import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client';
import { ClientFactory, DefaultAgentCardResolver } from '@a2a-js/sdk/client';
import { debugLogger } from '../utils/debugLogger.js';
import {
createAuthenticatingFetchWithRetry,
ClientFactoryOptions,
} from '@a2a-js/sdk/client';
vi.mock('../utils/debugLogger.js', () => ({
debugLogger: {
debug: vi.fn(),
},
}));
vi.mock('@a2a-js/sdk/client', () => {
const ClientFactory = vi.fn();
const DefaultAgentCardResolver = vi.fn();
const RestTransportFactory = vi.fn();
const JsonRpcTransportFactory = vi.fn();
const ClientFactoryOptions = {
default: {},
createFrom: vi.fn(),
};
const createAuthenticatingFetchWithRetry = vi.fn();
DefaultAgentCardResolver.prototype.resolve = vi.fn();
ClientFactory.prototype.createFromUrl = vi.fn();
return {
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
createAuthenticatingFetchWithRetry,
};
});
describe('A2AClientManager', () => {
let manager: A2AClientManager;
// Stable mocks initialized once
const sendMessageMock = vi.fn();
const getTaskMock = vi.fn();
const cancelTaskMock = vi.fn();
const getAgentCardMock = vi.fn();
const authFetchMock = vi.fn();
const mockClient = {
sendMessage: sendMessageMock,
getTask: getTaskMock,
cancelTask: cancelTaskMock,
getAgentCard: getAgentCardMock,
} as unknown as Client;
const mockAgentCard: Partial<AgentCard> = { name: 'TestAgent' };
beforeEach(() => {
vi.clearAllMocks();
A2AClientManager.resetInstanceForTesting();
manager = A2AClientManager.getInstance();
// Default mock implementations
getAgentCardMock.mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue(
mockClient,
);
vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactoryOptions.createFrom).mockImplementation(
(_defaults, overrides) => overrides as ClientFactoryOptions,
);
vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue(
authFetchMock,
);
vi.stubGlobal(
'fetch',
vi.fn().mockResolvedValue({
ok: true,
json: async () => ({}),
} as Response),
);
});
afterEach(() => {
vi.restoreAllMocks();
vi.unstubAllGlobals();
});
it('should enforce the singleton pattern', () => {
const instance1 = A2AClientManager.getInstance();
const instance2 = A2AClientManager.getInstance();
expect(instance1).toBe(instance2);
});
describe('loadAgent', () => {
it('should create and cache an A2AClient', async () => {
const agentCard = await manager.loadAgent(
'TestAgent',
'http://test.agent/card',
);
expect(agentCard).toMatchObject(mockAgentCard);
expect(manager.getAgentCard('TestAgent')).toBe(agentCard);
expect(manager.getClient('TestAgent')).toBeDefined();
});
it('should throw an error if an agent with the same name is already loaded', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
await expect(
manager.loadAgent('TestAgent', 'http://another.agent/card'),
).rejects.toThrow("Agent with name 'TestAgent' is already loaded.");
});
it('should use native fetch by default', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(createAuthenticatingFetchWithRetry).not.toHaveBeenCalled();
});
it('should use provided custom authentication handler', async () => {
const customAuthHandler = {
headers: vi.fn(),
shouldRetryWithHeaders: vi.fn(),
};
await manager.loadAgent(
'CustomAuthAgent',
'http://custom.agent/card',
customAuthHandler as unknown as AuthenticationHandler,
);
expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith(
expect.anything(),
customAuthHandler,
);
});
it('should log a debug message upon loading an agent', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(debugLogger.debug).toHaveBeenCalledWith(
"[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card",
);
});
});
describe('sendMessage', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
});
it('should send a message to the correct agent', async () => {
sendMessageMock.mockResolvedValue({
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult);
await manager.sendMessage('TestAgent', 'Hello');
expect(sendMessageMock).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.anything(),
}),
);
});
it('should use contextId and taskId when provided', async () => {
sendMessageMock.mockResolvedValue({
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult);
const expectedContextId = 'user-context-id';
const expectedTaskId = 'user-task-id';
await manager.sendMessage('TestAgent', 'Hello', {
contextId: expectedContextId,
taskId: expectedTaskId,
});
const call = sendMessageMock.mock.calls[0][0];
expect(call.message.contextId).toBe(expectedContextId);
expect(call.message.taskId).toBe(expectedTaskId);
});
it('should return result from client', async () => {
const mockResult = {
contextId: 'server-context-id',
id: 'ctx-1',
kind: 'task',
status: { state: 'working' },
};
sendMessageMock.mockResolvedValueOnce(mockResult as SendMessageResult);
const response = await manager.sendMessage('TestAgent', 'Hello');
expect(response).toEqual(mockResult);
});
it('should throw prefixed error on failure', async () => {
sendMessageMock.mockRejectedValueOnce(new Error('Network error'));
await expect(manager.sendMessage('TestAgent', 'Hello')).rejects.toThrow(
'A2AClient SendMessage Error [TestAgent]: Network error',
);
});
it('should throw an error if the agent is not found', async () => {
await expect(
manager.sendMessage('NonExistentAgent', 'Hello'),
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
});
describe('getTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
});
it('should get a task from the correct agent', async () => {
getTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'completed' },
} as Task);
await manager.getTask('TestAgent', 'task123');
expect(getTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
});
it('should throw prefixed error on failure', async () => {
getTaskMock.mockRejectedValueOnce(new Error('Network error'));
await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient getTask Error [TestAgent]: Network error',
);
});
it('should throw an error if the agent is not found', async () => {
await expect(
manager.getTask('NonExistentAgent', 'task123'),
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
});
describe('cancelTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
});
it('should cancel a task on the correct agent', async () => {
cancelTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'canceled' },
} as Task);
await manager.cancelTask('TestAgent', 'task123');
expect(cancelTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
});
it('should throw prefixed error on failure', async () => {
cancelTaskMock.mockRejectedValueOnce(new Error('Network error'));
await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient cancelTask Error [TestAgent]: Network error',
);
});
it('should throw an error if the agent is not found', async () => {
await expect(
manager.cancelTask('NonExistentAgent', 'task123'),
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
});
});

View File

@@ -0,0 +1,209 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { AgentCard, Message, MessageSendParams, Task } from '@a2a-js/sdk';
import {
type Client,
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
type AuthenticationHandler,
createAuthenticatingFetchWithRetry,
} from '@a2a-js/sdk/client';
import { v4 as uuidv4 } from 'uuid';
import { debugLogger } from '../utils/debugLogger.js';
export type SendMessageResult = Message | Task;
/**
* Manages A2A clients and caches loaded agent information.
* Follows a singleton pattern to ensure a single client instance.
*/
export class A2AClientManager {
private static instance: A2AClientManager;
// Each agent should manage their own context/taskIds/card/etc
private clients = new Map<string, Client>();
private agentCards = new Map<string, AgentCard>();
private constructor() {}
/**
* Gets the singleton instance of the A2AClientManager.
*/
static getInstance(): A2AClientManager {
if (!A2AClientManager.instance) {
A2AClientManager.instance = new A2AClientManager();
}
return A2AClientManager.instance;
}
/**
* Resets the singleton instance. Only for testing purposes.
* @internal
*/
static resetInstanceForTesting() {
// @ts-expect-error - Resetting singleton for testing
A2AClientManager.instance = undefined;
}
/**
* Loads an agent by fetching its AgentCard and caches the client.
* @param name The name to assign to the agent.
* @param agentCardUrl The full URL to the agent's card.
* @param authHandler Optional authentication handler to use for this agent.
* @returns The loaded AgentCard.
*/
async loadAgent(
name: string,
agentCardUrl: string,
authHandler?: AuthenticationHandler,
): Promise<AgentCard> {
if (this.clients.has(name)) {
throw new Error(`Agent with name '${name}' is already loaded.`);
}
let fetchImpl = fetch;
if (authHandler) {
fetchImpl = createAuthenticatingFetchWithRetry(fetch, authHandler);
}
const resolver = new DefaultAgentCardResolver({ fetchImpl });
const options = ClientFactoryOptions.createFrom(
ClientFactoryOptions.default,
{
transports: [
new RestTransportFactory({ fetchImpl }),
new JsonRpcTransportFactory({ fetchImpl }),
],
cardResolver: resolver,
},
);
const factory = new ClientFactory(options);
const client = await factory.createFromUrl(agentCardUrl, '');
const agentCard = await client.getAgentCard();
this.clients.set(name, client);
this.agentCards.set(name, agentCard);
debugLogger.debug(
`[A2AClientManager] Loaded agent '${name}' from ${agentCardUrl}`,
);
return agentCard;
}
/**
* Sends a message to a loaded agent.
* @param agentName The name of the agent to send the message to.
* @param message The message content.
* @param options Optional context and task IDs to maintain conversation state.
* @returns The response from the agent (Message or Task).
* @throws Error if the agent returns an error response.
*/
async sendMessage(
agentName: string,
message: string,
options?: { contextId?: string; taskId?: string },
): Promise<SendMessageResult> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
const messageParams: MessageSendParams = {
message: {
kind: 'message',
role: 'user',
messageId: uuidv4(),
parts: [{ kind: 'text', text: message }],
contextId: options?.contextId,
taskId: options?.taskId,
},
configuration: {
blocking: true,
},
};
try {
return await client.sendMessage(messageParams);
} catch (error: unknown) {
const prefix = `A2AClient SendMessage Error [${agentName}]`;
if (error instanceof Error) {
throw new Error(`${prefix}: ${error.message}`, { cause: error });
}
throw new Error(
`${prefix}: Unexpected error during sendMessage: ${String(error)}`,
);
}
}
/**
* Retrieves a loaded agent card.
* @param name The name of the agent.
* @returns The agent card, or undefined if not found.
*/
getAgentCard(name: string): AgentCard | undefined {
return this.agentCards.get(name);
}
/**
* Retrieves a loaded client.
* @param name The name of the agent.
* @returns The client, or undefined if not found.
*/
getClient(name: string): Client | undefined {
return this.clients.get(name);
}
/**
* Retrieves a task from an agent.
* @param agentName The name of the agent.
* @param taskId The ID of the task to retrieve.
* @returns The task details.
*/
async getTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
try {
return await client.getTask({ id: taskId });
} catch (error: unknown) {
const prefix = `A2AClient getTask Error [${agentName}]`;
if (error instanceof Error) {
throw new Error(`${prefix}: ${error.message}`, { cause: error });
}
throw new Error(`${prefix}: Unexpected error: ${String(error)}`);
}
}
/**
* Cancels a task on an agent.
* @param agentName The name of the agent.
* @param taskId The ID of the task to cancel.
* @returns The cancellation response.
*/
async cancelTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
try {
return await client.cancelTask({ id: taskId });
} catch (error: unknown) {
const prefix = `A2AClient cancelTask Error [${agentName}]`;
if (error instanceof Error) {
throw new Error(`${prefix}: ${error.message}`, { cause: error });
}
throw new Error(`${prefix}: Unexpected error: ${String(error)}`);
}
}
}