Files
gemini-cli/packages/core/src/agents/a2a-client-manager.test.ts
2026-03-10 15:24:44 +00:00

398 lines
12 KiB
TypeScript

/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import {
A2AClientManager,
type SendMessageResult,
} from './a2a-client-manager.js';
import type { AgentCard, Task } from '@a2a-js/sdk';
import {
ClientFactory,
DefaultAgentCardResolver,
createAuthenticatingFetchWithRetry,
ClientFactoryOptions,
type AuthenticationHandler,
type Client,
} from '@a2a-js/sdk/client';
import { debugLogger } from '../utils/debugLogger.js';
vi.mock('../utils/debugLogger.js', () => ({
debugLogger: {
debug: vi.fn(),
},
}));
vi.mock('@a2a-js/sdk/client', () => {
const ClientFactory = vi.fn();
const DefaultAgentCardResolver = vi.fn();
const RestTransportFactory = vi.fn();
const JsonRpcTransportFactory = vi.fn();
const ClientFactoryOptions = {
default: {},
createFrom: vi.fn(),
};
const createAuthenticatingFetchWithRetry = vi.fn();
DefaultAgentCardResolver.prototype.resolve = vi.fn();
ClientFactory.prototype.createFromUrl = vi.fn();
return {
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
createAuthenticatingFetchWithRetry,
};
});
describe('A2AClientManager', () => {
let manager: A2AClientManager;
// Stable mocks initialized once
const sendMessageStreamMock = vi.fn();
const getTaskMock = vi.fn();
const cancelTaskMock = vi.fn();
const getAgentCardMock = vi.fn();
const authFetchMock = vi.fn();
const mockClient = {
sendMessageStream: sendMessageStreamMock,
getTask: getTaskMock,
cancelTask: cancelTaskMock,
getAgentCard: getAgentCardMock,
} as unknown as Client;
const mockAgentCard: Partial<AgentCard> = { name: 'TestAgent' };
beforeEach(() => {
vi.clearAllMocks();
A2AClientManager.resetInstanceForTesting();
manager = A2AClientManager.getInstance();
// Default mock implementations
getAgentCardMock.mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue(
mockClient,
);
vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactoryOptions.createFrom).mockImplementation(
(_defaults, overrides) => overrides as ClientFactoryOptions,
);
vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue(
authFetchMock,
);
vi.stubGlobal(
'fetch',
vi.fn().mockResolvedValue({
ok: true,
json: async () => ({}),
} as Response),
);
});
afterEach(() => {
vi.restoreAllMocks();
vi.unstubAllGlobals();
});
it('should enforce the singleton pattern', () => {
const instance1 = A2AClientManager.getInstance();
const instance2 = A2AClientManager.getInstance();
expect(instance1).toBe(instance2);
});
describe('loadAgent', () => {
it('should create and cache an A2AClient', async () => {
const agentCard = await manager.loadAgent(
'TestAgent',
'http://test.agent/card',
);
expect(agentCard).toMatchObject(mockAgentCard);
expect(manager.getAgentCard('TestAgent')).toBe(agentCard);
expect(manager.getClient('TestAgent')).toBeDefined();
});
it('should throw an error if an agent with the same name is already loaded', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
await expect(
manager.loadAgent('TestAgent', 'http://another.agent/card'),
).rejects.toThrow("Agent with name 'TestAgent' is already loaded.");
});
it('should use native fetch by default', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(createAuthenticatingFetchWithRetry).not.toHaveBeenCalled();
});
it('should use provided custom authentication handler for transports only', async () => {
const customAuthHandler = {
headers: vi.fn(),
shouldRetryWithHeaders: vi.fn(),
};
await manager.loadAgent(
'CustomAuthAgent',
'http://custom.agent/card',
customAuthHandler as unknown as AuthenticationHandler,
);
expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith(
expect.anything(),
customAuthHandler,
);
// Card resolver should NOT use the authenticated fetch by default.
const resolverInstance = vi.mocked(DefaultAgentCardResolver).mock
.instances[0];
expect(resolverInstance).toBeDefined();
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
.calls[0][0];
expect(resolverOptions?.fetchImpl).not.toBe(authFetchMock);
});
it('should use unauthenticated fetch for card resolver and avoid authenticated fetch if success', async () => {
const customAuthHandler = {
headers: vi.fn(),
shouldRetryWithHeaders: vi.fn(),
};
await manager.loadAgent(
'AuthCardAgent',
'http://authcard.agent/card',
customAuthHandler as unknown as AuthenticationHandler,
);
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
.calls[0][0];
const cardFetch = resolverOptions?.fetchImpl as typeof fetch;
expect(cardFetch).toBeDefined();
await cardFetch('http://test.url');
expect(fetch).toHaveBeenCalledWith('http://test.url', expect.anything());
expect(authFetchMock).not.toHaveBeenCalled();
});
it('should retry with authenticating fetch if agent card fetch returns 401', async () => {
const customAuthHandler = {
headers: vi.fn(),
shouldRetryWithHeaders: vi.fn(),
};
// Mock the initial unauthenticated fetch to fail with 401
vi.mocked(fetch).mockResolvedValueOnce({
ok: false,
status: 401,
json: async () => ({}),
} as Response);
await manager.loadAgent(
'AuthCardAgent401',
'http://authcard.agent/card',
customAuthHandler as unknown as AuthenticationHandler,
);
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
.calls[0][0];
const cardFetch = resolverOptions?.fetchImpl as typeof fetch;
await cardFetch('http://test.url');
expect(fetch).toHaveBeenCalledWith('http://test.url', expect.anything());
expect(authFetchMock).toHaveBeenCalledWith('http://test.url', undefined);
});
it('should log a debug message upon loading an agent', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(debugLogger.debug).toHaveBeenCalledWith(
"[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card",
);
});
it('should clear the cache', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(manager.getAgentCard('TestAgent')).toBeDefined();
expect(manager.getClient('TestAgent')).toBeDefined();
manager.clearCache();
expect(manager.getAgentCard('TestAgent')).toBeUndefined();
expect(manager.getClient('TestAgent')).toBeUndefined();
expect(debugLogger.debug).toHaveBeenCalledWith(
'[A2AClientManager] Cache cleared.',
);
});
});
describe('sendMessageStream', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
});
it('should send a message and return a stream', async () => {
const mockResult = {
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
sendMessageStreamMock.mockReturnValue(
(async function* () {
yield mockResult;
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = [];
for await (const res of stream) {
results.push(res);
}
expect(results).toEqual([mockResult]);
expect(sendMessageStreamMock).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.anything(),
}),
expect.any(Object),
);
});
it('should use contextId and taskId when provided', async () => {
sendMessageStreamMock.mockReturnValue(
(async function* () {
yield {
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
})(),
);
const expectedContextId = 'user-context-id';
const expectedTaskId = 'user-task-id';
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: expectedContextId,
taskId: expectedTaskId,
});
for await (const _ of stream) {
// consume stream
}
const call = sendMessageStreamMock.mock.calls[0][0];
expect(call.message.contextId).toBe(expectedContextId);
expect(call.message.taskId).toBe(expectedTaskId);
});
it('should throw prefixed error on failure', async () => {
sendMessageStreamMock.mockImplementationOnce(() => {
throw new Error('Network error');
});
const stream = manager.sendMessageStream('TestAgent', 'Hello');
await expect(async () => {
for await (const _ of stream) {
// consume
}
}).rejects.toThrow(
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network error',
);
});
it('should throw an error if the agent is not found', async () => {
const stream = manager.sendMessageStream('NonExistentAgent', 'Hello');
await expect(async () => {
for await (const _ of stream) {
// consume
}
}).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
});
describe('getTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
});
it('should get a task from the correct agent', async () => {
getTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'completed' },
} as Task);
await manager.getTask('TestAgent', 'task123');
expect(getTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
});
it('should throw prefixed error on failure', async () => {
getTaskMock.mockRejectedValueOnce(new Error('Network error'));
await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient getTask Error [TestAgent]: Network error',
);
});
it('should throw an error if the agent is not found', async () => {
await expect(
manager.getTask('NonExistentAgent', 'task123'),
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
});
describe('cancelTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
});
it('should cancel a task on the correct agent', async () => {
cancelTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'canceled' },
} as Task);
await manager.cancelTask('TestAgent', 'task123');
expect(cancelTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
});
it('should throw prefixed error on failure', async () => {
cancelTaskMock.mockRejectedValueOnce(new Error('Network error'));
await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient cancelTask Error [TestAgent]: Network error',
);
});
it('should throw an error if the agent is not found', async () => {
await expect(
manager.cancelTask('NonExistentAgent', 'task123'),
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
});
});