mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-04 00:44:05 -07:00
feat(a2a): enable native gRPC support and protocol routing (#21403)
Co-authored-by: Adam Weidman <adamfweidman@google.com>
This commit is contained in:
@@ -5,11 +5,8 @@
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
A2AClientManager,
|
||||
type SendMessageResult,
|
||||
} from './a2a-client-manager.js';
|
||||
import type { AgentCard, Task } from '@a2a-js/sdk';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import type { AgentCard } from '@a2a-js/sdk';
|
||||
import {
|
||||
ClientFactory,
|
||||
DefaultAgentCardResolver,
|
||||
@@ -22,81 +19,95 @@ import type { Config } from '../config/config.js';
|
||||
import { Agent as UndiciAgent, ProxyAgent } from 'undici';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
interface MockClient {
|
||||
sendMessageStream: ReturnType<typeof vi.fn>;
|
||||
getTask: ReturnType<typeof vi.fn>;
|
||||
cancelTask: ReturnType<typeof vi.fn>;
|
||||
}
|
||||
|
||||
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
|
||||
const actual = await importOriginal();
|
||||
return {
|
||||
...(actual as Record<string, unknown>),
|
||||
createAuthenticatingFetchWithRetry: vi.fn(),
|
||||
ClientFactory: vi.fn(),
|
||||
DefaultAgentCardResolver: vi.fn(),
|
||||
ClientFactoryOptions: {
|
||||
createFrom: vi.fn(),
|
||||
default: {},
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../utils/debugLogger.js', () => ({
|
||||
debugLogger: {
|
||||
debug: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('@a2a-js/sdk/client', () => {
|
||||
const ClientFactory = vi.fn();
|
||||
const DefaultAgentCardResolver = vi.fn();
|
||||
const RestTransportFactory = vi.fn();
|
||||
const JsonRpcTransportFactory = vi.fn();
|
||||
const ClientFactoryOptions = {
|
||||
default: {},
|
||||
createFrom: vi.fn(),
|
||||
};
|
||||
const createAuthenticatingFetchWithRetry = vi.fn();
|
||||
|
||||
DefaultAgentCardResolver.prototype.resolve = vi.fn();
|
||||
ClientFactory.prototype.createFromUrl = vi.fn();
|
||||
|
||||
return {
|
||||
ClientFactory,
|
||||
ClientFactoryOptions,
|
||||
DefaultAgentCardResolver,
|
||||
RestTransportFactory,
|
||||
JsonRpcTransportFactory,
|
||||
createAuthenticatingFetchWithRetry,
|
||||
};
|
||||
});
|
||||
|
||||
describe('A2AClientManager', () => {
|
||||
let manager: A2AClientManager;
|
||||
const mockAgentCard: AgentCard = {
|
||||
name: 'test-agent',
|
||||
description: 'A test agent',
|
||||
url: 'http://test.agent',
|
||||
version: '1.0.0',
|
||||
protocolVersion: '0.1.0',
|
||||
capabilities: {},
|
||||
skills: [],
|
||||
defaultInputModes: [],
|
||||
defaultOutputModes: [],
|
||||
};
|
||||
|
||||
const mockClient: MockClient = {
|
||||
sendMessageStream: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
cancelTask: vi.fn(),
|
||||
};
|
||||
|
||||
// Stable mocks initialized once
|
||||
const sendMessageStreamMock = vi.fn();
|
||||
const getTaskMock = vi.fn();
|
||||
const cancelTaskMock = vi.fn();
|
||||
const getAgentCardMock = vi.fn();
|
||||
const authFetchMock = vi.fn();
|
||||
|
||||
const mockClient = {
|
||||
sendMessageStream: sendMessageStreamMock,
|
||||
getTask: getTaskMock,
|
||||
cancelTask: cancelTaskMock,
|
||||
getAgentCard: getAgentCardMock,
|
||||
} as unknown as Client;
|
||||
|
||||
const mockAgentCard: Partial<AgentCard> = { name: 'TestAgent' };
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
A2AClientManager.resetInstanceForTesting();
|
||||
manager = A2AClientManager.getInstance();
|
||||
|
||||
// Default mock implementations
|
||||
getAgentCardMock.mockResolvedValue({
|
||||
// Re-create the instances as plain objects that can be spied on
|
||||
const factoryInstance = {
|
||||
createFromUrl: vi.fn(),
|
||||
createFromAgentCard: vi.fn(),
|
||||
};
|
||||
const resolverInstance = {
|
||||
resolve: vi.fn(),
|
||||
};
|
||||
|
||||
vi.mocked(ClientFactory).mockReturnValue(
|
||||
factoryInstance as unknown as ClientFactory,
|
||||
);
|
||||
vi.mocked(DefaultAgentCardResolver).mockReturnValue(
|
||||
resolverInstance as unknown as DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue(
|
||||
mockClient as unknown as Client,
|
||||
);
|
||||
vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue(
|
||||
mockClient as unknown as Client,
|
||||
);
|
||||
vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({
|
||||
...mockAgentCard,
|
||||
url: 'http://test.agent/real/endpoint',
|
||||
} as AgentCard);
|
||||
|
||||
vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue(
|
||||
mockClient,
|
||||
vi.spyOn(ClientFactoryOptions, 'createFrom').mockImplementation(
|
||||
(_defaults, overrides) => overrides as unknown as ClientFactoryOptions,
|
||||
);
|
||||
|
||||
vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({
|
||||
...mockAgentCard,
|
||||
url: 'http://test.agent/real/endpoint',
|
||||
} as AgentCard);
|
||||
|
||||
vi.mocked(ClientFactoryOptions.createFrom).mockImplementation(
|
||||
(_defaults, overrides) => overrides as ClientFactoryOptions,
|
||||
);
|
||||
|
||||
vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue(
|
||||
authFetchMock,
|
||||
vi.mocked(createAuthenticatingFetchWithRetry).mockImplementation(() =>
|
||||
authFetchMock.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response),
|
||||
);
|
||||
|
||||
vi.stubGlobal(
|
||||
@@ -170,15 +181,19 @@ describe('A2AClientManager', () => {
|
||||
'TestAgent',
|
||||
'http://test.agent/card',
|
||||
);
|
||||
expect(agentCard).toMatchObject(mockAgentCard);
|
||||
expect(manager.getAgentCard('TestAgent')).toBe(agentCard);
|
||||
expect(manager.getClient('TestAgent')).toBeDefined();
|
||||
});
|
||||
|
||||
it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
expect(ClientFactoryOptions.createFrom).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should throw an error if an agent with the same name is already loaded', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
await expect(
|
||||
manager.loadAgent('TestAgent', 'http://another.agent/card'),
|
||||
manager.loadAgent('TestAgent', 'http://test.agent/card'),
|
||||
).rejects.toThrow("Agent with name 'TestAgent' is already loaded.");
|
||||
});
|
||||
|
||||
@@ -193,20 +208,12 @@ describe('A2AClientManager', () => {
|
||||
shouldRetryWithHeaders: vi.fn(),
|
||||
};
|
||||
await manager.loadAgent(
|
||||
'CustomAuthAgent',
|
||||
'http://custom.agent/card',
|
||||
'TestAgent',
|
||||
'http://test.agent/card',
|
||||
customAuthHandler as unknown as AuthenticationHandler,
|
||||
);
|
||||
|
||||
expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
customAuthHandler,
|
||||
);
|
||||
|
||||
// Card resolver should NOT use the authenticated fetch by default.
|
||||
const resolverInstance = vi.mocked(DefaultAgentCardResolver).mock
|
||||
.instances[0];
|
||||
expect(resolverInstance).toBeDefined();
|
||||
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
|
||||
.calls[0][0];
|
||||
expect(resolverOptions?.fetchImpl).not.toBe(authFetchMock);
|
||||
@@ -267,106 +274,163 @@ describe('A2AClientManager', () => {
|
||||
it('should log a debug message upon loading an agent', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
expect(debugLogger.debug).toHaveBeenCalledWith(
|
||||
"[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card",
|
||||
expect.stringContaining("Loaded agent 'TestAgent'"),
|
||||
);
|
||||
});
|
||||
|
||||
it('should clear the cache', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
expect(manager.getAgentCard('TestAgent')).toBeDefined();
|
||||
expect(manager.getClient('TestAgent')).toBeDefined();
|
||||
|
||||
manager.clearCache();
|
||||
|
||||
expect(manager.getAgentCard('TestAgent')).toBeUndefined();
|
||||
expect(manager.getClient('TestAgent')).toBeUndefined();
|
||||
expect(debugLogger.debug).toHaveBeenCalledWith(
|
||||
'[A2AClientManager] Cache cleared.',
|
||||
});
|
||||
|
||||
it('should throw if resolveAgentCard fails', async () => {
|
||||
const resolverInstance = {
|
||||
resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')),
|
||||
};
|
||||
vi.mocked(DefaultAgentCardResolver).mockReturnValue(
|
||||
resolverInstance as unknown as DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('FailAgent', 'http://fail.agent'),
|
||||
).rejects.toThrow('Resolution failed');
|
||||
});
|
||||
|
||||
it('should throw if factory.createFromAgentCard fails', async () => {
|
||||
const factoryInstance = {
|
||||
createFromAgentCard: vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Factory failed')),
|
||||
};
|
||||
vi.mocked(ClientFactory).mockReturnValue(
|
||||
factoryInstance as unknown as ClientFactory,
|
||||
);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('FailAgent', 'http://fail.agent'),
|
||||
).rejects.toThrow('Factory failed');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAgentCard and getClient', () => {
|
||||
it('should return undefined if agent is not found', () => {
|
||||
expect(manager.getAgentCard('Unknown')).toBeUndefined();
|
||||
expect(manager.getClient('Unknown')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessageStream', () => {
|
||||
beforeEach(async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent');
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
});
|
||||
|
||||
it('should send a message and return a stream', async () => {
|
||||
const mockResult = {
|
||||
kind: 'message',
|
||||
messageId: 'a',
|
||||
parts: [],
|
||||
role: 'agent',
|
||||
} as SendMessageResult;
|
||||
|
||||
sendMessageStreamMock.mockReturnValue(
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield mockResult;
|
||||
yield { kind: 'message' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
const results = [];
|
||||
for await (const res of stream) {
|
||||
results.push(res);
|
||||
for await (const result of stream) {
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
expect(results).toEqual([mockResult]);
|
||||
expect(sendMessageStreamMock).toHaveBeenCalledWith(
|
||||
expect(results).toHaveLength(1);
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use contextId and taskId when provided', async () => {
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { kind: 'message' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
|
||||
contextId: 'ctx123',
|
||||
taskId: 'task456',
|
||||
});
|
||||
// trigger execution
|
||||
for await (const _ of stream) {
|
||||
break;
|
||||
}
|
||||
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
message: expect.anything(),
|
||||
message: expect.objectContaining({
|
||||
contextId: 'ctx123',
|
||||
taskId: 'task456',
|
||||
}),
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use contextId and taskId when provided', async () => {
|
||||
sendMessageStreamMock.mockReturnValue(
|
||||
it('should correctly propagate AbortSignal to the stream', async () => {
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'a',
|
||||
parts: [],
|
||||
role: 'agent',
|
||||
} as SendMessageResult;
|
||||
yield { kind: 'message' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const expectedContextId = 'user-context-id';
|
||||
const expectedTaskId = 'user-task-id';
|
||||
|
||||
const controller = new AbortController();
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
|
||||
contextId: expectedContextId,
|
||||
taskId: expectedTaskId,
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
// trigger execution
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
break;
|
||||
}
|
||||
|
||||
const call = sendMessageStreamMock.mock.calls[0][0];
|
||||
expect(call.message.contextId).toBe(expectedContextId);
|
||||
expect(call.message.taskId).toBe(expectedTaskId);
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({ signal: controller.signal }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should propagate the original error on failure', async () => {
|
||||
sendMessageStreamMock.mockImplementationOnce(() => {
|
||||
throw new Error('Network error');
|
||||
it('should handle a multi-chunk stream with different event types', async () => {
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { kind: 'message', messageId: 'm1' };
|
||||
yield { kind: 'status-update', taskId: 't1' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
const results = [];
|
||||
for await (const result of stream) {
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
expect(results).toHaveLength(2);
|
||||
expect(results[0].kind).toBe('message');
|
||||
expect(results[1].kind).toBe('status-update');
|
||||
});
|
||||
|
||||
it('should throw prefixed error on failure', async () => {
|
||||
mockClient.sendMessageStream.mockImplementation(() => {
|
||||
throw new Error('Network failure');
|
||||
});
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
// empty
|
||||
}
|
||||
}).rejects.toThrow('Network error');
|
||||
}).rejects.toThrow(
|
||||
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if the agent is not found', async () => {
|
||||
const stream = manager.sendMessageStream('NonExistentAgent', 'Hello');
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
// empty
|
||||
}
|
||||
}).rejects.toThrow("Agent 'NonExistentAgent' not found.");
|
||||
});
|
||||
@@ -374,28 +438,23 @@ describe('A2AClientManager', () => {
|
||||
|
||||
describe('getTask', () => {
|
||||
beforeEach(async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent');
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
});
|
||||
|
||||
it('should get a task from the correct agent', async () => {
|
||||
getTaskMock.mockResolvedValue({
|
||||
id: 'task123',
|
||||
contextId: 'a',
|
||||
kind: 'task',
|
||||
status: { state: 'completed' },
|
||||
} as Task);
|
||||
const mockTask = { id: 'task123', kind: 'task' };
|
||||
mockClient.getTask.mockResolvedValue(mockTask);
|
||||
|
||||
await manager.getTask('TestAgent', 'task123');
|
||||
expect(getTaskMock).toHaveBeenCalledWith({
|
||||
id: 'task123',
|
||||
});
|
||||
const result = await manager.getTask('TestAgent', 'task123');
|
||||
expect(result).toBe(mockTask);
|
||||
expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' });
|
||||
});
|
||||
|
||||
it('should throw prefixed error on failure', async () => {
|
||||
getTaskMock.mockRejectedValueOnce(new Error('Network error'));
|
||||
mockClient.getTask.mockRejectedValue(new Error('Not found'));
|
||||
|
||||
await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow(
|
||||
'A2AClient getTask Error [TestAgent]: Network error',
|
||||
'A2AClient getTask Error [TestAgent]: Not found',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -408,28 +467,23 @@ describe('A2AClientManager', () => {
|
||||
|
||||
describe('cancelTask', () => {
|
||||
beforeEach(async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent');
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
});
|
||||
|
||||
it('should cancel a task on the correct agent', async () => {
|
||||
cancelTaskMock.mockResolvedValue({
|
||||
id: 'task123',
|
||||
contextId: 'a',
|
||||
kind: 'task',
|
||||
status: { state: 'canceled' },
|
||||
} as Task);
|
||||
const mockTask = { id: 'task123', kind: 'task' };
|
||||
mockClient.cancelTask.mockResolvedValue(mockTask);
|
||||
|
||||
await manager.cancelTask('TestAgent', 'task123');
|
||||
expect(cancelTaskMock).toHaveBeenCalledWith({
|
||||
id: 'task123',
|
||||
});
|
||||
const result = await manager.cancelTask('TestAgent', 'task123');
|
||||
expect(result).toBe(mockTask);
|
||||
expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' });
|
||||
});
|
||||
|
||||
it('should throw prefixed error on failure', async () => {
|
||||
cancelTaskMock.mockRejectedValueOnce(new Error('Network error'));
|
||||
mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel'));
|
||||
|
||||
await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow(
|
||||
'A2AClient cancelTask Error [TestAgent]: Network error',
|
||||
'A2AClient cancelTask Error [TestAgent]: Cannot cancel',
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user