diff --git a/package-lock.json b/package-lock.json index 85448711c7..036fdc8ec1 100644 --- a/package-lock.json +++ b/package-lock.json @@ -84,9 +84,9 @@ } }, "node_modules/@a2a-js/sdk": { - "version": "0.3.8", - "resolved": "https://registry.npmjs.org/@a2a-js/sdk/-/sdk-0.3.8.tgz", - "integrity": "sha512-vAg6JQbhOnHTzApsB7nGzCQ9r7PuY4GMr8gt88dIR8Wc8G8RSqVTyTmFeMurgzcYrtHYXS3ru2rnDoGj9UDeSw==", + "version": "0.3.10", + "resolved": "https://registry.npmjs.org/@a2a-js/sdk/-/sdk-0.3.10.tgz", + "integrity": "sha512-t6w5ctnwJkSOMRl6M9rn95C1FTHCPqixxMR0yWXtzhZXEnF6mF1NAK0CfKlG3cz+tcwTxkmn287QZC3t9XPgrA==", "license": "Apache-2.0", "dependencies": { "uuid": "^11.1.0" @@ -95,9 +95,17 @@ "node": ">=18" }, "peerDependencies": { + "@bufbuild/protobuf": "^2.10.2", + "@grpc/grpc-js": "^1.11.0", "express": "^4.21.2 || ^5.1.0" }, "peerDependenciesMeta": { + "@bufbuild/protobuf": { + "optional": true + }, + "@grpc/grpc-js": { + "optional": true + }, "express": { "optional": true } @@ -515,6 +523,12 @@ "node": ">=18" } }, + "node_modules/@bufbuild/protobuf": { + "version": "2.11.0", + "resolved": "https://registry.npmjs.org/@bufbuild/protobuf/-/protobuf-2.11.0.tgz", + "integrity": "sha512-sBXGT13cpmPR5BMgHE6UEEfEaShh5Ror6rfN3yEK5si7QVrtZg8LEPQb0VVhiLRUslD2yLnXtnRzG035J/mZXQ==", + "license": "(Apache-2.0 AND BSD-3-Clause)" + }, "node_modules/@bundled-es-modules/cookie": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/@bundled-es-modules/cookie/-/cookie-2.0.1.tgz", @@ -1582,18 +1596,36 @@ } }, "node_modules/@grpc/grpc-js": { - "version": "1.13.4", - "resolved": "https://registry.npmjs.org/@grpc/grpc-js/-/grpc-js-1.13.4.tgz", - "integrity": "sha512-GsFaMXCkMqkKIvwCQjCrwH+GHbPKBjhwo/8ZuUkWHqbI73Kky9I+pQltrlT0+MWpedCoosda53lgjYfyEPgxBg==", + "version": "1.14.3", + "resolved": "https://registry.npmjs.org/@grpc/grpc-js/-/grpc-js-1.14.3.tgz", + "integrity": "sha512-Iq8QQQ/7X3Sac15oB6p0FmUg/klxQvXLeileoqrTRGJYLV+/9tubbr9ipz0GKHjmXVsgFPo/+W+2cA8eNcR+XA==", "license": "Apache-2.0", "dependencies": { - "@grpc/proto-loader": "^0.7.13", + "@grpc/proto-loader": "^0.8.0", "@js-sdsl/ordered-map": "^4.4.2" }, "engines": { "node": ">=12.10.0" } }, + "node_modules/@grpc/grpc-js/node_modules/@grpc/proto-loader": { + "version": "0.8.0", + "resolved": "https://registry.npmjs.org/@grpc/proto-loader/-/proto-loader-0.8.0.tgz", + "integrity": "sha512-rc1hOQtjIWGxcxpb9aHAfLpIctjEnsDehj0DAiVfBlmT84uvR0uUtN2hEi/ecvWVjXUGf5qPF4qEgiLOx1YIMQ==", + "license": "Apache-2.0", + "dependencies": { + "lodash.camelcase": "^4.3.0", + "long": "^5.0.0", + "protobufjs": "^7.5.3", + "yargs": "^17.7.2" + }, + "bin": { + "proto-loader-gen-types": "build/bin/proto-loader-gen-types.js" + }, + "engines": { + "node": ">=6" + } + }, "node_modules/@grpc/proto-loader": { "version": "0.7.15", "resolved": "https://registry.npmjs.org/@grpc/proto-loader/-/proto-loader-0.7.15.tgz", @@ -17447,11 +17479,13 @@ "version": "0.34.0-nightly.20260304.28af4e127", "license": "Apache-2.0", "dependencies": { - "@a2a-js/sdk": "^0.3.8", + "@a2a-js/sdk": "^0.3.10", + "@bufbuild/protobuf": "^2.11.0", "@google-cloud/logging": "^11.2.1", "@google-cloud/opentelemetry-cloud-monitoring-exporter": "^0.21.0", "@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0", "@google/genai": "1.41.0", + "@grpc/grpc-js": "^1.14.3", "@iarna/toml": "^2.2.5", "@joshua.litt/get-ripgrep": "^0.0.3", "@modelcontextprotocol/sdk": "^1.23.0", diff --git a/packages/a2a-server/src/config/config.test.ts b/packages/a2a-server/src/config/config.test.ts index ee63df36f7..e7f9227ae2 100644 --- a/packages/a2a-server/src/config/config.test.ts +++ b/packages/a2a-server/src/config/config.test.ts @@ -316,6 +316,35 @@ describe('loadConfig', () => { ); }); + it('should pass agent settings to Config', async () => { + const settings: Settings = { + experimental: { + enableAgents: false, + }, + agents: { + overrides: { + test_agent: { enabled: true }, + }, + }, + }; + await loadConfig(settings, mockExtensionLoader, taskId); + expect(Config).toHaveBeenCalledWith( + expect.objectContaining({ + enableAgents: false, + agents: settings.agents, + }), + ); + }); + + it('should default enableAgents to true if not specified', async () => { + await loadConfig(mockSettings, mockExtensionLoader, taskId); + expect(Config).toHaveBeenCalledWith( + expect.objectContaining({ + enableAgents: true, + }), + ); + }); + describe('interactivity', () => { it('should set interactive true when not headless', async () => { vi.mocked(isHeadlessMode).mockReturnValue(false); diff --git a/packages/a2a-server/src/config/config.ts b/packages/a2a-server/src/config/config.ts index 1b236f9ac7..225dbdbcf3 100644 --- a/packages/a2a-server/src/config/config.ts +++ b/packages/a2a-server/src/config/config.ts @@ -109,6 +109,8 @@ export async function loadConfig( interactive: !isHeadlessMode(), enableInteractiveShell: !isHeadlessMode(), ptyInfo: 'auto', + enableAgents: settings.experimental?.enableAgents ?? true, + agents: settings.agents, }; const fileService = new FileDiscoveryService(workspaceDir, { diff --git a/packages/a2a-server/src/config/settings.test.ts b/packages/a2a-server/src/config/settings.test.ts index 7c51950535..131c220b14 100644 --- a/packages/a2a-server/src/config/settings.test.ts +++ b/packages/a2a-server/src/config/settings.test.ts @@ -112,6 +112,24 @@ describe('loadSettings', () => { expect(result.fileFiltering?.respectGitIgnore).toBe(true); }); + it('should load experimental and agents settings correctly', () => { + const settings = { + experimental: { + enableAgents: true, + }, + agents: { + overrides: { + test_agent: { enabled: false }, + }, + }, + }; + fs.writeFileSync(USER_SETTINGS_PATH, JSON.stringify(settings)); + + const result = loadSettings(mockWorkspaceDir); + expect(result.experimental?.enableAgents).toBe(true); + expect(result.agents?.overrides?.['test_agent']?.enabled).toBe(false); + }); + it('should overwrite top-level settings from workspace (shallow merge)', () => { const userSettings = { showMemoryUsage: false, diff --git a/packages/a2a-server/src/config/settings.ts b/packages/a2a-server/src/config/settings.ts index b3c44cc177..a67be2f67b 100644 --- a/packages/a2a-server/src/config/settings.ts +++ b/packages/a2a-server/src/config/settings.ts @@ -14,6 +14,7 @@ import { getErrorMessage, type TelemetrySettings, homedir, + type AgentSettings, } from '@google/gemini-cli-core'; import stripJsonComments from 'strip-json-comments'; @@ -45,6 +46,10 @@ export interface Settings { enableRecursiveFileSearch?: boolean; customIgnoreFilePaths?: string[]; }; + experimental?: { + enableAgents?: boolean; + }; + agents?: AgentSettings; } export interface SettingsError { diff --git a/packages/core/package.json b/packages/core/package.json index 827c09bc61..dd2aa29706 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -21,11 +21,13 @@ "dist" ], "dependencies": { - "@a2a-js/sdk": "^0.3.8", + "@a2a-js/sdk": "^0.3.10", + "@bufbuild/protobuf": "^2.11.0", "@google-cloud/logging": "^11.2.1", "@google-cloud/opentelemetry-cloud-monitoring-exporter": "^0.21.0", "@google-cloud/opentelemetry-cloud-trace-exporter": "^3.0.0", "@google/genai": "1.41.0", + "@grpc/grpc-js": "^1.14.3", "@iarna/toml": "^2.2.5", "@joshua.litt/get-ripgrep": "^0.0.3", "@modelcontextprotocol/sdk": "^1.23.0", diff --git a/packages/core/src/agents/a2a-client-manager.test.ts b/packages/core/src/agents/a2a-client-manager.test.ts index 68189a6771..34c322fe89 100644 --- a/packages/core/src/agents/a2a-client-manager.test.ts +++ b/packages/core/src/agents/a2a-client-manager.test.ts @@ -5,96 +5,102 @@ */ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { - A2AClientManager, - type SendMessageResult, -} from './a2a-client-manager.js'; -import type { AgentCard, Task } from '@a2a-js/sdk'; -import { - ClientFactory, - DefaultAgentCardResolver, - createAuthenticatingFetchWithRetry, - ClientFactoryOptions, - type AuthenticationHandler, - type Client, -} from '@a2a-js/sdk/client'; +import { A2AClientManager } from './a2a-client-manager.js'; +import type { AgentCard } from '@a2a-js/sdk'; +import * as sdkClient from '@a2a-js/sdk/client'; import { debugLogger } from '../utils/debugLogger.js'; +interface MockClient { + sendMessageStream: ReturnType; + getTask: ReturnType; + cancelTask: ReturnType; +} + +vi.mock('@a2a-js/sdk/client', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...(actual as Record), + createAuthenticatingFetchWithRetry: vi.fn(), + ClientFactory: vi.fn(), + DefaultAgentCardResolver: vi.fn(), + ClientFactoryOptions: { + createFrom: vi.fn(), + default: {}, + }, + }; +}); + vi.mock('../utils/debugLogger.js', () => ({ debugLogger: { debug: vi.fn(), }, })); -vi.mock('@a2a-js/sdk/client', () => { - const ClientFactory = vi.fn(); - const DefaultAgentCardResolver = vi.fn(); - const RestTransportFactory = vi.fn(); - const JsonRpcTransportFactory = vi.fn(); - const ClientFactoryOptions = { - default: {}, - createFrom: vi.fn(), - }; - const createAuthenticatingFetchWithRetry = vi.fn(); - - DefaultAgentCardResolver.prototype.resolve = vi.fn(); - ClientFactory.prototype.createFromUrl = vi.fn(); - - return { - ClientFactory, - ClientFactoryOptions, - DefaultAgentCardResolver, - RestTransportFactory, - JsonRpcTransportFactory, - createAuthenticatingFetchWithRetry, - }; -}); - describe('A2AClientManager', () => { let manager: A2AClientManager; + const mockAgentCard: AgentCard = { + name: 'test-agent', + description: 'A test agent', + url: 'http://test.agent', + version: '1.0.0', + protocolVersion: '0.1.0', + capabilities: {}, + skills: [], + defaultInputModes: [], + defaultOutputModes: [], + }; + + const mockClient: MockClient = { + sendMessageStream: vi.fn(), + getTask: vi.fn(), + cancelTask: vi.fn(), + }; - // Stable mocks initialized once - const sendMessageStreamMock = vi.fn(); - const getTaskMock = vi.fn(); - const cancelTaskMock = vi.fn(); - const getAgentCardMock = vi.fn(); const authFetchMock = vi.fn(); - const mockClient = { - sendMessageStream: sendMessageStreamMock, - getTask: getTaskMock, - cancelTask: cancelTaskMock, - getAgentCard: getAgentCardMock, - } as unknown as Client; - - const mockAgentCard: Partial = { name: 'TestAgent' }; - beforeEach(() => { vi.clearAllMocks(); A2AClientManager.resetInstanceForTesting(); manager = A2AClientManager.getInstance(); - // Default mock implementations - getAgentCardMock.mockResolvedValue({ + // Re-create the instances as plain objects that can be spied on + const factoryInstance = { + createFromUrl: vi.fn(), + createFromAgentCard: vi.fn(), + }; + const resolverInstance = { + resolve: vi.fn(), + }; + + vi.mocked(sdkClient.ClientFactory).mockReturnValue( + factoryInstance as unknown as sdkClient.ClientFactory, + ); + vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as sdkClient.DefaultAgentCardResolver, + ); + + vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue( + mockClient as unknown as sdkClient.Client, + ); + vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue( + mockClient as unknown as sdkClient.Client, + ); + vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({ ...mockAgentCard, url: 'http://test.agent/real/endpoint', } as AgentCard); - vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue( - mockClient, + vi.spyOn(sdkClient.ClientFactoryOptions, 'createFrom').mockImplementation( + (_defaults, overrides) => + overrides as unknown as sdkClient.ClientFactoryOptions, ); - vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({ - ...mockAgentCard, - url: 'http://test.agent/real/endpoint', - } as AgentCard); - - vi.mocked(ClientFactoryOptions.createFrom).mockImplementation( - (_defaults, overrides) => overrides as ClientFactoryOptions, - ); - - vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue( - authFetchMock, + vi.mocked(sdkClient.createAuthenticatingFetchWithRetry).mockImplementation( + () => + authFetchMock.mockResolvedValue({ + ok: true, + json: async () => ({}), + } as Response), ); vi.stubGlobal( @@ -123,137 +129,188 @@ describe('A2AClientManager', () => { 'TestAgent', 'http://test.agent/card', ); - expect(agentCard).toMatchObject(mockAgentCard); expect(manager.getAgentCard('TestAgent')).toBe(agentCard); expect(manager.getClient('TestAgent')).toBeDefined(); }); + it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => { + await manager.loadAgent('TestAgent', 'http://test.agent/card'); + expect(sdkClient.ClientFactoryOptions.createFrom).toHaveBeenCalled(); + }); + it('should throw an error if an agent with the same name is already loaded', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); await expect( - manager.loadAgent('TestAgent', 'http://another.agent/card'), + manager.loadAgent('TestAgent', 'http://test.agent/card'), ).rejects.toThrow("Agent with name 'TestAgent' is already loaded."); }); it('should use native fetch by default', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); - expect(createAuthenticatingFetchWithRetry).not.toHaveBeenCalled(); + expect( + sdkClient.createAuthenticatingFetchWithRetry, + ).not.toHaveBeenCalled(); }); it('should use provided custom authentication handler', async () => { - const customAuthHandler = { - headers: vi.fn(), - shouldRetryWithHeaders: vi.fn(), + const authHandler: sdkClient.AuthenticationHandler = { + headers: async () => ({}), + shouldRetryWithHeaders: async () => undefined, }; await manager.loadAgent( - 'CustomAuthAgent', - 'http://custom.agent/card', - customAuthHandler as unknown as AuthenticationHandler, - ); - - expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith( - expect.anything(), - customAuthHandler, + 'TestAgent', + 'http://test.agent/card', + authHandler, ); + expect(sdkClient.createAuthenticatingFetchWithRetry).toHaveBeenCalled(); }); it('should log a debug message upon loading an agent', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); expect(debugLogger.debug).toHaveBeenCalledWith( - "[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card", + expect.stringContaining("Loaded agent 'TestAgent'"), ); }); it('should clear the cache', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); - expect(manager.getAgentCard('TestAgent')).toBeDefined(); - expect(manager.getClient('TestAgent')).toBeDefined(); - manager.clearCache(); - expect(manager.getAgentCard('TestAgent')).toBeUndefined(); expect(manager.getClient('TestAgent')).toBeUndefined(); - expect(debugLogger.debug).toHaveBeenCalledWith( - '[A2AClientManager] Cache cleared.', + }); + + it('should throw if resolveAgentCard fails', async () => { + const resolverInstance = { + resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')), + }; + vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as sdkClient.DefaultAgentCardResolver, ); + + await expect( + manager.loadAgent('FailAgent', 'http://fail.agent'), + ).rejects.toThrow('Resolution failed'); + }); + + it('should throw if factory.createFromAgentCard fails', async () => { + const factoryInstance = { + createFromAgentCard: vi + .fn() + .mockRejectedValue(new Error('Factory failed')), + }; + vi.mocked(sdkClient.ClientFactory).mockReturnValue( + factoryInstance as unknown as sdkClient.ClientFactory, + ); + + await expect( + manager.loadAgent('FailAgent', 'http://fail.agent'), + ).rejects.toThrow('Factory failed'); + }); + }); + + describe('getAgentCard and getClient', () => { + it('should return undefined if agent is not found', () => { + expect(manager.getAgentCard('Unknown')).toBeUndefined(); + expect(manager.getClient('Unknown')).toBeUndefined(); }); }); describe('sendMessageStream', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should send a message and return a stream', async () => { - const mockResult = { - kind: 'message', - messageId: 'a', - parts: [], - role: 'agent', - } as SendMessageResult; - - sendMessageStreamMock.mockReturnValue( + mockClient.sendMessageStream.mockReturnValue( (async function* () { - yield mockResult; + yield { kind: 'message' }; })(), ); const stream = manager.sendMessageStream('TestAgent', 'Hello'); const results = []; - for await (const res of stream) { - results.push(res); + for await (const result of stream) { + results.push(result); } - expect(results).toEqual([mockResult]); - expect(sendMessageStreamMock).toHaveBeenCalledWith( + expect(results).toHaveLength(1); + expect(mockClient.sendMessageStream).toHaveBeenCalled(); + }); + + it('should use contextId and taskId when provided', async () => { + mockClient.sendMessageStream.mockReturnValue( + (async function* () { + yield { kind: 'message' }; + })(), + ); + + const stream = manager.sendMessageStream('TestAgent', 'Hello', { + contextId: 'ctx123', + taskId: 'task456', + }); + await stream[Symbol.asyncIterator]().next(); + + expect(mockClient.sendMessageStream).toHaveBeenCalledWith( expect.objectContaining({ - message: expect.anything(), + message: expect.objectContaining({ + contextId: 'ctx123', + taskId: 'task456', + }), }), expect.any(Object), ); }); - it('should use contextId and taskId when provided', async () => { - sendMessageStreamMock.mockReturnValue( + it('should correctly propagate AbortSignal to the stream', async () => { + mockClient.sendMessageStream.mockReturnValue( (async function* () { - yield { - kind: 'message', - messageId: 'a', - parts: [], - role: 'agent', - } as SendMessageResult; + yield { kind: 'message' }; })(), ); - const expectedContextId = 'user-context-id'; - const expectedTaskId = 'user-task-id'; - + const controller = new AbortController(); const stream = manager.sendMessageStream('TestAgent', 'Hello', { - contextId: expectedContextId, - taskId: expectedTaskId, + signal: controller.signal, }); + await stream[Symbol.asyncIterator]().next(); - for await (const _ of stream) { - // consume stream + expect(mockClient.sendMessageStream).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ signal: controller.signal }), + ); + }); + + it('should handle a multi-chunk stream with different event types', async () => { + mockClient.sendMessageStream.mockReturnValue( + (async function* () { + yield { kind: 'message', messageId: 'm1' }; + yield { kind: 'status-update', taskId: 't1' }; + })(), + ); + + const stream = manager.sendMessageStream('TestAgent', 'Hello'); + const results = []; + for await (const result of stream) { + results.push(result); } - const call = sendMessageStreamMock.mock.calls[0][0]; - expect(call.message.contextId).toBe(expectedContextId); - expect(call.message.taskId).toBe(expectedTaskId); + expect(results).toHaveLength(2); + expect(results[0].kind).toBe('message'); + expect(results[1].kind).toBe('status-update'); }); it('should throw prefixed error on failure', async () => { - sendMessageStreamMock.mockImplementationOnce(() => { - throw new Error('Network error'); + mockClient.sendMessageStream.mockImplementation(() => { + throw new Error('Network failure'); }); const stream = manager.sendMessageStream('TestAgent', 'Hello'); await expect(async () => { for await (const _ of stream) { - // consume + // empty } }).rejects.toThrow( - '[A2AClientManager] sendMessageStream Error [TestAgent]: Network error', + '[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure', ); }); @@ -261,7 +318,7 @@ describe('A2AClientManager', () => { const stream = manager.sendMessageStream('NonExistentAgent', 'Hello'); await expect(async () => { for await (const _ of stream) { - // consume + // empty } }).rejects.toThrow("Agent 'NonExistentAgent' not found."); }); @@ -269,28 +326,23 @@ describe('A2AClientManager', () => { describe('getTask', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should get a task from the correct agent', async () => { - getTaskMock.mockResolvedValue({ - id: 'task123', - contextId: 'a', - kind: 'task', - status: { state: 'completed' }, - } as Task); + const mockTask = { id: 'task123', kind: 'task' }; + mockClient.getTask.mockResolvedValue(mockTask); - await manager.getTask('TestAgent', 'task123'); - expect(getTaskMock).toHaveBeenCalledWith({ - id: 'task123', - }); + const result = await manager.getTask('TestAgent', 'task123'); + expect(result).toBe(mockTask); + expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' }); }); it('should throw prefixed error on failure', async () => { - getTaskMock.mockRejectedValueOnce(new Error('Network error')); + mockClient.getTask.mockRejectedValue(new Error('Not found')); await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow( - 'A2AClient getTask Error [TestAgent]: Network error', + 'A2AClient getTask Error [TestAgent]: Not found', ); }); @@ -303,28 +355,23 @@ describe('A2AClientManager', () => { describe('cancelTask', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should cancel a task on the correct agent', async () => { - cancelTaskMock.mockResolvedValue({ - id: 'task123', - contextId: 'a', - kind: 'task', - status: { state: 'canceled' }, - } as Task); + const mockTask = { id: 'task123', kind: 'task' }; + mockClient.cancelTask.mockResolvedValue(mockTask); - await manager.cancelTask('TestAgent', 'task123'); - expect(cancelTaskMock).toHaveBeenCalledWith({ - id: 'task123', - }); + const result = await manager.cancelTask('TestAgent', 'task123'); + expect(result).toBe(mockTask); + expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' }); }); it('should throw prefixed error on failure', async () => { - cancelTaskMock.mockRejectedValueOnce(new Error('Network error')); + mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel')); await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow( - 'A2AClient cancelTask Error [TestAgent]: Network error', + 'A2AClient cancelTask Error [TestAgent]: Cannot cancel', ); }); @@ -334,4 +381,23 @@ describe('A2AClientManager', () => { ).rejects.toThrow("Agent 'NonExistentAgent' not found."); }); }); + + describe('Protocol Routing & URL Logic', () => { + it('should correctly split URLs to prevent .well-known doubling', async () => { + const fullUrl = 'http://localhost:9001/.well-known/agent-card.json'; + const resolverInstance = { + resolve: vi.fn().mockResolvedValue({ name: 'test' } as AgentCard), + }; + vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as sdkClient.DefaultAgentCardResolver, + ); + + await manager.loadAgent('test-doubling', fullUrl); + + expect(resolverInstance.resolve).toHaveBeenCalledWith( + 'http://localhost:9001/', + '.well-known/agent-card.json', + ); + }); + }); }); diff --git a/packages/core/src/agents/a2a-client-manager.ts b/packages/core/src/agents/a2a-client-manager.ts index e7070f3dfa..922b0fb7dc 100644 --- a/packages/core/src/agents/a2a-client-manager.ts +++ b/packages/core/src/agents/a2a-client-manager.ts @@ -13,7 +13,6 @@ import type { TaskArtifactUpdateEvent, } from '@a2a-js/sdk'; import { - type Client, ClientFactory, ClientFactoryOptions, DefaultAgentCardResolver, @@ -21,9 +20,12 @@ import { JsonRpcTransportFactory, type AuthenticationHandler, createAuthenticatingFetchWithRetry, + type Client, } from '@a2a-js/sdk/client'; +import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc'; import { v4 as uuidv4 } from 'uuid'; import { Agent as UndiciAgent } from 'undici'; +import { getGrpcCredentials, normalizeAgentCard } from './a2aUtils.js'; import { debugLogger } from '../utils/debugLogger.js'; // Remote agents can take 10+ minutes (e.g. Deep Research). @@ -44,8 +46,10 @@ export type SendMessageResult = | TaskArtifactUpdateEvent; /** - * Manages A2A clients and caches loaded agent information. - * Follows a singleton pattern to ensure a single client instance. + * Orchestrates communication with A2A agents. + * + * This manager handles agent discovery, card caching, and client lifecycle. + * It provides a unified messaging interface using the standard A2A SDK. */ export class A2AClientManager { private static instance: A2AClientManager; @@ -91,27 +95,26 @@ export class A2AClientManager { throw new Error(`Agent with name '${name}' is already loaded.`); } - let fetchImpl: typeof fetch = a2aFetch; - if (authHandler) { - fetchImpl = createAuthenticatingFetchWithRetry(a2aFetch, authHandler); - } - + const fetchImpl = this.getFetchImpl(authHandler); const resolver = new DefaultAgentCardResolver({ fetchImpl }); + const agentCard = await this.resolveAgentCard(agentCardUrl, resolver); - const options = ClientFactoryOptions.createFrom( + // Configure standard SDK client for tool registration and discovery + const clientOptions = ClientFactoryOptions.createFrom( ClientFactoryOptions.default, { transports: [ new RestTransportFactory({ fetchImpl }), new JsonRpcTransportFactory({ fetchImpl }), + new GrpcTransportFactory({ + grpcChannelCredentials: getGrpcCredentials(agentCard.url), + }), ], cardResolver: resolver, }, ); - - const factory = new ClientFactory(options); - const client = await factory.createFromUrl(agentCardUrl, ''); - const agentCard = await client.getAgentCard(); + const factory = new ClientFactory(clientOptions); + const client = await factory.createFromAgentCard(agentCard); this.clients.set(name, client); this.agentCards.set(name, agentCard); @@ -146,9 +149,7 @@ export class A2AClientManager { options?: { contextId?: string; taskId?: string; signal?: AbortSignal }, ): AsyncIterable { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); const messageParams: MessageSendParams = { message: { @@ -202,9 +203,7 @@ export class A2AClientManager { */ async getTask(agentName: string, taskId: string): Promise { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); try { return await client.getTask({ id: taskId }); } catch (error: unknown) { @@ -224,9 +223,7 @@ export class A2AClientManager { */ async cancelTask(agentName: string, taskId: string): Promise { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); try { return await client.cancelTask({ id: taskId }); } catch (error: unknown) { @@ -237,4 +234,35 @@ export class A2AClientManager { throw new Error(`${prefix}: Unexpected error: ${String(error)}`); } } + + /** + * Resolves the appropriate fetch implementation for an agent. + */ + private getFetchImpl(authHandler?: AuthenticationHandler): typeof fetch { + return authHandler + ? createAuthenticatingFetchWithRetry(a2aFetch, authHandler) + : a2aFetch; + } + + /** + * Resolves and normalizes an agent card from a given URL. + * Handles splitting the URL if it already contains the standard .well-known path. + */ + private async resolveAgentCard( + url: string, + resolver: DefaultAgentCardResolver, + ): Promise { + const standardPath = '.well-known/agent-card.json'; + let baseUrl = url; + let path: string | undefined; + + if (baseUrl.includes(standardPath)) { + const parts = baseUrl.split(standardPath); + baseUrl = parts[0] || ''; + path = standardPath; + } + + const rawCard = await resolver.resolve(baseUrl, path); + return normalizeAgentCard(rawCard); + } } diff --git a/packages/core/src/agents/a2aUtils.test.ts b/packages/core/src/agents/a2aUtils.test.ts index 2bcdad2c40..e9bae9a083 100644 --- a/packages/core/src/agents/a2aUtils.test.ts +++ b/packages/core/src/agents/a2aUtils.test.ts @@ -11,6 +11,8 @@ import { isTerminalState, A2AResultReassembler, AUTH_REQUIRED_MSG, + normalizeAgentCard, + getGrpcCredentials, } from './a2aUtils.js'; import type { SendMessageResult } from './a2a-client-manager.js'; import type { @@ -24,6 +26,18 @@ import type { } from '@a2a-js/sdk'; describe('a2aUtils', () => { + describe('getGrpcCredentials', () => { + it('should return secure credentials for https', () => { + const credentials = getGrpcCredentials('https://test.agent'); + expect(credentials).toBeDefined(); + }); + + it('should return insecure credentials for http', () => { + const credentials = getGrpcCredentials('http://test.agent'); + expect(credentials).toBeDefined(); + }); + }); + describe('isTerminalState', () => { it('should return true for completed, failed, canceled, and rejected', () => { expect(isTerminalState('completed')).toBe(true); @@ -223,6 +237,119 @@ describe('a2aUtils', () => { } as Message), ).toBe(''); }); + + it('should handle file parts with neither name nor uri', () => { + const message: Message = { + kind: 'message', + role: 'user', + messageId: '1', + parts: [ + { + kind: 'file', + file: { + mimeType: 'text/plain', + }, + } as FilePart, + ], + }; + expect(extractMessageText(message)).toBe('File: [binary/unnamed]'); + }); + }); + + describe('normalizeAgentCard', () => { + it('should throw if input is not an object', () => { + expect(() => normalizeAgentCard(null)).toThrow('Agent card is missing.'); + expect(() => normalizeAgentCard(undefined)).toThrow( + 'Agent card is missing.', + ); + expect(() => normalizeAgentCard('not an object')).toThrow( + 'Agent card is missing.', + ); + }); + + it('should preserve unknown fields while providing defaults for mandatory ones', () => { + const raw = { + name: 'my-agent', + customField: 'keep-me', + }; + + const normalized = normalizeAgentCard(raw); + + expect(normalized.name).toBe('my-agent'); + // @ts-expect-error - testing dynamic preservation + expect(normalized.customField).toBe('keep-me'); + expect(normalized.description).toBe(''); + expect(normalized.skills).toEqual([]); + expect(normalized.defaultInputModes).toEqual([]); + }); + + it('should normalize and synchronize interfaces while preserving other fields', () => { + const raw = { + name: 'test', + supportedInterfaces: [ + { + url: 'grpc://test', + protocolBinding: 'GRPC', + protocolVersion: '1.0', + }, + ], + }; + + const normalized = normalizeAgentCard(raw); + + // Should exist in both fields + expect(normalized.additionalInterfaces).toHaveLength(1); + expect( + (normalized as unknown as Record)[ + 'supportedInterfaces' + ], + ).toHaveLength(1); + + const intf = normalized.additionalInterfaces?.[0] as unknown as Record< + string, + unknown + >; + + expect(intf['transport']).toBe('GRPC'); + expect(intf['url']).toBe('grpc://test'); + + // Should fallback top-level url + expect(normalized.url).toBe('grpc://test'); + }); + + it('should preserve existing top-level url if present', () => { + const raw = { + name: 'test', + url: 'http://existing', + supportedInterfaces: [{ url: 'http://other', transport: 'REST' }], + }; + + const normalized = normalizeAgentCard(raw); + expect(normalized.url).toBe('http://existing'); + }); + + it('should NOT prepend http:// scheme to raw IP:port strings for gRPC interfaces', () => { + const raw = { + name: 'raw-ip-grpc', + supportedInterfaces: [{ url: '127.0.0.1:9000', transport: 'GRPC' }], + }; + + const normalized = normalizeAgentCard(raw); + expect(normalized.additionalInterfaces?.[0].url).toBe('127.0.0.1:9000'); + expect(normalized.url).toBe('127.0.0.1:9000'); + }); + + it('should prepend http:// scheme to raw IP:port strings for REST interfaces', () => { + const raw = { + name: 'raw-ip-rest', + supportedInterfaces: [{ url: '127.0.0.1:8080', transport: 'REST' }], + }; + + const normalized = normalizeAgentCard(raw); + expect(normalized.additionalInterfaces?.[0].url).toBe( + 'http://127.0.0.1:8080', + ); + }); }); describe('A2AResultReassembler', () => { @@ -233,6 +360,7 @@ describe('a2aUtils', () => { reassembler.update({ kind: 'status-update', taskId: 't1', + contextId: 'ctx1', status: { state: 'working', message: { @@ -247,6 +375,7 @@ describe('a2aUtils', () => { reassembler.update({ kind: 'artifact-update', taskId: 't1', + contextId: 'ctx1', append: false, artifact: { artifactId: 'a1', @@ -259,6 +388,7 @@ describe('a2aUtils', () => { reassembler.update({ kind: 'status-update', taskId: 't1', + contextId: 'ctx1', status: { state: 'working', message: { @@ -273,6 +403,7 @@ describe('a2aUtils', () => { reassembler.update({ kind: 'artifact-update', taskId: 't1', + contextId: 'ctx1', append: true, artifact: { artifactId: 'a1', @@ -291,6 +422,7 @@ describe('a2aUtils', () => { reassembler.update({ kind: 'status-update', + contextId: 'ctx1', status: { state: 'auth-required', message: { @@ -310,6 +442,7 @@ describe('a2aUtils', () => { reassembler.update({ kind: 'status-update', + contextId: 'ctx1', status: { state: 'auth-required', }, @@ -323,6 +456,7 @@ describe('a2aUtils', () => { const chunk = { kind: 'status-update', + contextId: 'ctx1', status: { state: 'auth-required', message: { @@ -351,6 +485,8 @@ describe('a2aUtils', () => { reassembler.update({ kind: 'task', + id: 'task-1', + contextId: 'ctx1', status: { state: 'completed' }, history: [ { @@ -369,6 +505,8 @@ describe('a2aUtils', () => { reassembler.update({ kind: 'task', + id: 'task-1', + contextId: 'ctx1', status: { state: 'working' }, history: [ { @@ -387,6 +525,8 @@ describe('a2aUtils', () => { reassembler.update({ kind: 'task', + id: 'task-1', + contextId: 'ctx1', status: { state: 'completed' }, artifacts: [ { diff --git a/packages/core/src/agents/a2aUtils.ts b/packages/core/src/agents/a2aUtils.ts index dc39f4e660..917dfb33a2 100644 --- a/packages/core/src/agents/a2aUtils.ts +++ b/packages/core/src/agents/a2aUtils.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import * as grpc from '@grpc/grpc-js'; import type { Message, Part, @@ -13,6 +14,8 @@ import type { Artifact, TaskState, TaskStatusUpdateEvent, + AgentCard, + AgentInterface, } from '@a2a-js/sdk'; import type { SendMessageResult } from './a2a-client-manager.js'; @@ -210,36 +213,61 @@ function extractPartText(part: Part): string { return ''; } -// Type Guards +/** + * Normalizes an agent card by ensuring it has the required properties + * and resolving any inconsistencies between protocol versions. + */ +export function normalizeAgentCard(card: unknown): AgentCard { + if (!isObject(card)) { + throw new Error('Agent card is missing.'); + } -function isTextPart(part: Part): part is TextPart { - return part.kind === 'text'; -} + // Double-cast to bypass strict linter while bootstrapping the object. + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const result = { ...card } as unknown as AgentCard; -function isDataPart(part: Part): part is DataPart { - return part.kind === 'data'; -} + // Ensure mandatory fields exist with safe defaults. + if (typeof result.name !== 'string') result.name = 'unknown'; + if (typeof result.description !== 'string') result.description = ''; + if (typeof result.url !== 'string') result.url = ''; + if (typeof result.version !== 'string') result.version = ''; + if (typeof result.protocolVersion !== 'string') result.protocolVersion = ''; + if (!isObject(result.capabilities)) result.capabilities = {}; + if (!Array.isArray(result.skills)) result.skills = []; + if (!Array.isArray(result.defaultInputModes)) result.defaultInputModes = []; + if (!Array.isArray(result.defaultOutputModes)) result.defaultOutputModes = []; -function isFilePart(part: Part): part is FilePart { - return part.kind === 'file'; -} + // Normalize interfaces and synchronize both interface fields. + const normalizedInterfaces = extractNormalizedInterfaces(card); + result.additionalInterfaces = normalizedInterfaces; + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + (result as unknown as Record)[ + 'supportedInterfaces' + ] = normalizedInterfaces; -function isStatusUpdateEvent( - result: SendMessageResult, -): result is TaskStatusUpdateEvent { - return result.kind === 'status-update'; + // Fallback preferredTransport: If not specified, default to GRPC if available. + if ( + !result.preferredTransport && + normalizedInterfaces.some((i) => i.transport === 'GRPC') + ) { + result.preferredTransport = 'GRPC'; + } + + // Fallback: If top-level URL is missing, use the first interface's URL. + if (result.url === '' && normalizedInterfaces.length > 0) { + result.url = normalizedInterfaces[0].url; + } + + return result; } /** - * Returns true if the given state is a terminal state for a task. + * Returns gRPC channel credentials based on the URL scheme. */ -export function isTerminalState(state: TaskState | undefined): boolean { - return ( - state === 'completed' || - state === 'failed' || - state === 'canceled' || - state === 'rejected' - ); +export function getGrpcCredentials(url: string): grpc.ChannelCredentials { + return url.startsWith('https://') + ? grpc.credentials.createSsl() + : grpc.credentials.createInsecure(); } /** @@ -279,3 +307,107 @@ export function extractIdsFromResponse(result: SendMessageResult): { return { contextId, taskId, clearTaskId }; } + +/** + * Extracts and normalizes interfaces from the card, handling protocol version fallbacks. + * Preserves all original fields to maintain SDK compatibility. + */ +function extractNormalizedInterfaces( + card: Record, +): AgentInterface[] { + const rawInterfaces = + getArray(card, 'additionalInterfaces') || + getArray(card, 'supportedInterfaces'); + + if (!rawInterfaces) { + return []; + } + + const mapped: AgentInterface[] = []; + for (const i of rawInterfaces) { + if (isObject(i)) { + // Create a copy to preserve all original fields. + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const normalized = { ...i } as unknown as AgentInterface & { + protocolBinding?: string; + }; + + // Ensure 'url' exists + if (typeof normalized.url !== 'string') { + normalized.url = ''; + } + + // Normalize 'transport' from 'protocolBinding' + const transport = normalized.transport || normalized.protocolBinding; + if (transport) { + normalized.transport = transport; + } + + // Robust URL: Ensure the URL has a scheme (except for gRPC). + // Some agent implementations (like a2a-go samples) may provide raw IP:port strings. + // gRPC targets MUST NOT have a scheme (e.g. 'http://'), or they will fail name resolution. + if ( + normalized.url && + !normalized.url.includes('://') && + !normalized.url.startsWith('/') && + normalized.transport !== 'GRPC' + ) { + // Default to http:// for insecure REST/JSON-RPC if scheme is missing. + normalized.url = `http://${normalized.url}`; + } + + mapped.push(normalized); + } + } + return mapped; +} + +/** + * Safely extracts an array property from an object. + */ +function getArray( + obj: Record, + key: string, +): unknown[] | undefined { + const val = obj[key]; + return Array.isArray(val) ? val : undefined; +} + +// Type Guards + +function isTextPart(part: Part): part is TextPart { + return part.kind === 'text'; +} + +function isDataPart(part: Part): part is DataPart { + return part.kind === 'data'; +} + +function isFilePart(part: Part): part is FilePart { + return part.kind === 'file'; +} + +function isStatusUpdateEvent( + result: SendMessageResult, +): result is TaskStatusUpdateEvent { + return result.kind === 'status-update'; +} + +/** + * Returns true if the given state is a terminal state for a task. + */ +export function isTerminalState(state: TaskState | undefined): boolean { + return ( + state === 'completed' || + state === 'failed' || + state === 'canceled' || + state === 'rejected' + ); +} + +/** + * Type guard to check if a value is a non-array object. + */ +function isObject(val: unknown): val is Record { + return typeof val === 'object' && val !== null && !Array.isArray(val); +} diff --git a/packages/core/src/agents/registry.test.ts b/packages/core/src/agents/registry.test.ts index edae478f2a..f7070ed895 100644 --- a/packages/core/src/agents/registry.test.ts +++ b/packages/core/src/agents/registry.test.ts @@ -874,6 +874,31 @@ describe('AgentRegistry', () => { ); }); + it('should maintain registration under canonical name even if overrides are applied', async () => { + const originalName = 'my-agent'; + const definition = { ...MOCK_AGENT_V1, name: originalName }; + + // Mock overrides in settings + vi.spyOn(mockConfig, 'getAgentsSettings').mockReturnValue({ + overrides: { + [originalName]: { + enabled: true, + modelConfig: { model: 'overridden-model' }, + }, + }, + }); + + await registry.testRegisterAgent(definition); + + const registered = registry.getDefinition(originalName); + expect(registered).toBeDefined(); + expect((registered as LocalAgentDefinition).modelConfig.model).toBe( + 'overridden-model', + ); + // Ensure it is NOT registered under some other key + expect(registry.getAllAgentNames()).toEqual([originalName]); + }); + it('should reject an agent definition missing a name', async () => { const invalidAgent = { ...MOCK_AGENT_V1, name: '' }; const debugWarnSpy = vi diff --git a/packages/core/src/agents/registry.ts b/packages/core/src/agents/registry.ts index bf7e669150..11b1e74b07 100644 --- a/packages/core/src/agents/registry.ts +++ b/packages/core/src/agents/registry.ts @@ -312,7 +312,7 @@ export class AgentRegistry { } const mergedDefinition = this.applyOverrides(definition, settingsOverrides); - this.agents.set(mergedDefinition.name, mergedDefinition); + this.agents.set(definition.name, mergedDefinition); this.registerModelConfigs(mergedDefinition); this.addAgentPolicy(mergedDefinition);