From 96b9be3ec43937dae2f31bdfbd88a4ecfd48a0f4 Mon Sep 17 00:00:00 2001 From: Adam Weidman <65992621+adamfweidman@users.noreply.github.com> Date: Tue, 6 Jan 2026 18:45:05 -0500 Subject: [PATCH] feat(agents): add support for remote agents (#16013) --- .../src/agents/a2a-client-manager.test.ts | 39 +++ .../core/src/agents/a2a-client-manager.ts | 151 +++++++- packages/core/src/agents/a2aUtils.test.ts | 171 ++++++++++ packages/core/src/agents/a2aUtils.ts | 142 ++++++++ .../src/agents/delegate-to-agent-tool.test.ts | 43 +++ .../core/src/agents/remote-invocation.test.ts | 321 ++++++++++++++++-- packages/core/src/agents/remote-invocation.ts | 149 +++++++- packages/core/src/agents/types.ts | 5 + 8 files changed, 980 insertions(+), 41 deletions(-) create mode 100644 packages/core/src/agents/a2aUtils.test.ts create mode 100644 packages/core/src/agents/a2aUtils.ts diff --git a/packages/core/src/agents/a2a-client-manager.test.ts b/packages/core/src/agents/a2a-client-manager.test.ts index 1fe55a42ba..fb0f2829a4 100644 --- a/packages/core/src/agents/a2a-client-manager.test.ts +++ b/packages/core/src/agents/a2a-client-manager.test.ts @@ -8,6 +8,7 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; import { A2AClientManager, type SendMessageResult, + createAdapterFetch, } from './a2a-client-manager.js'; import type { AgentCard, Task } from '@a2a-js/sdk'; import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client'; @@ -302,4 +303,42 @@ describe('A2AClientManager', () => { ).rejects.toThrow("Agent 'NonExistentAgent' not found."); }); }); + + describe('createAdapterFetch', () => { + it('normalizes TASK_STATE_ enums to lower-case', async () => { + const baseFetch = vi + .fn() + .mockResolvedValue( + new Response( + JSON.stringify({ status: { state: 'TASK_STATE_WORKING' } }), + ), + ); + + const adapter = createAdapterFetch(baseFetch as typeof fetch); + const response = await adapter('http://example.com', { + method: 'POST', + body: '{}', + }); + const data = await response.json(); + + expect(data.status.state).toBe('working'); + }); + + it('lowercases non-prefixed task states', async () => { + const baseFetch = vi + .fn() + .mockResolvedValue( + new Response(JSON.stringify({ status: { state: 'WORKING' } })), + ); + + const adapter = createAdapterFetch(baseFetch as typeof fetch); + const response = await adapter('http://example.com', { + method: 'POST', + body: '{}', + }); + const data = await response.json(); + + expect(data.status.state).toBe('working'); + }); + }); }); diff --git a/packages/core/src/agents/a2a-client-manager.ts b/packages/core/src/agents/a2a-client-manager.ts index 9eccca4ad4..ef93522e03 100644 --- a/packages/core/src/agents/a2a-client-manager.ts +++ b/packages/core/src/agents/a2a-client-manager.ts @@ -68,11 +68,15 @@ export class A2AClientManager { throw new Error(`Agent with name '${name}' is already loaded.`); } - let fetchImpl = fetch; + let fetchImpl: typeof fetch = fetch; if (authHandler) { fetchImpl = createAuthenticatingFetchWithRetry(fetch, authHandler); } + // Wrap with custom adapter for ADK Reasoning Engine compatibility + // TODO: Remove this when a2a-js fixes compatibility + fetchImpl = createAdapterFetch(fetchImpl); + const resolver = new DefaultAgentCardResolver({ fetchImpl }); const options = ClientFactoryOptions.createFrom( @@ -207,3 +211,148 @@ export class A2AClientManager { } } } + +/** + * Maps TaskState proto-JSON enums to lower-case strings. + */ +function mapTaskState(state: string | undefined): string | undefined { + if (!state) return state; + if (state.startsWith('TASK_STATE_')) { + return state.replace('TASK_STATE_', '').toLowerCase(); + } + return state.toLowerCase(); +} + +/** + * Creates a fetch implementation that adapts standard A2A SDK requests to the + * proto-JSON dialect and endpoint shapes required by Vertex AI Agent Engine. + */ +export function createAdapterFetch(baseFetch: typeof fetch): typeof fetch { + return async ( + input: RequestInfo | URL, + init?: RequestInit, + ): Promise => { + const urlStr = input as string; + + // 2. Dialect Mapping (Request) + let body = init?.body; + let isRpc = false; + let rpcId: string | number | undefined; + + if (typeof body === 'string') { + try { + let jsonBody = JSON.parse(body); + + // Unwrap JSON-RPC if present + if (jsonBody.jsonrpc === '2.0') { + isRpc = true; + rpcId = jsonBody.id; + jsonBody = jsonBody.params; + } + + // Apply dialect translation to the message object + const message = jsonBody.message || jsonBody; + if (message && typeof message === 'object') { + // Role: user -> ROLE_USER, agent/model -> ROLE_AGENT + if (message.role === 'user') message.role = 'ROLE_USER'; + if (message.role === 'agent' || message.role === 'model') { + message.role = 'ROLE_AGENT'; + } + + // Strip SDK-specific 'kind' field + delete message.kind; + + // Map 'parts' to 'content' (Proto-JSON dialect often uses 'content' or typed parts) + // Also strip 'kind' from parts. + if (Array.isArray(message.parts)) { + message.content = message.parts.map( + (p: { kind?: string; text?: string }) => { + const { kind: _k, ...rest } = p; + // If it's a simple text part, ensure it matches { text: "..." } + if (p.kind === 'text') return { text: p.text }; + return rest; + }, + ); + delete message.parts; + } + } + + body = JSON.stringify(jsonBody); + } catch (error) { + debugLogger.debug( + '[A2AClientManager] Failed to parse request body for dialect translation:', + error, + ); + // Non-JSON or parse error; let the baseFetch handle it. + } + } + + const response = await baseFetch(urlStr, { ...init, body }); + + // Map response back + if (response.ok) { + try { + const responseData = await response.clone().json(); + + const result = + responseData.task || responseData.message || responseData; + + // Restore 'kind' for the SDK and a2aUtils parsing + if (result && typeof result === 'object' && !result.kind) { + if (responseData.task || (result.id && result.status)) { + result.kind = 'task'; + } else if (responseData.message || result.messageId) { + result.kind = 'message'; + } + } + + // Restore 'kind' on parts so extractMessageText works + if (result?.parts && Array.isArray(result.parts)) { + for (const part of result.parts) { + if (!part.kind) { + if (part.file) part.kind = 'file'; + else if (part.data) part.kind = 'data'; + else if (part.text) part.kind = 'text'; + } + } + } + + // Recursively restore 'kind' on artifact parts + if (result?.artifacts && Array.isArray(result.artifacts)) { + for (const artifact of result.artifacts) { + if (artifact.parts && Array.isArray(artifact.parts)) { + for (const part of artifact.parts) { + if (!part.kind) { + if (part.file) part.kind = 'file'; + else if (part.data) part.kind = 'data'; + else if (part.text) part.kind = 'text'; + } + } + } + } + } + + // Map Task States back to SDK expectations + if (result && typeof result === 'object' && result.status) { + result.status.state = mapTaskState(result.status.state); + } + + if (isRpc) { + return new Response( + JSON.stringify({ + jsonrpc: '2.0', + id: rpcId, + result, + }), + response, + ); + } + return new Response(JSON.stringify(result), response); + } catch (_e) { + // Non-JSON response or unwrapping failure + } + } + + return response; + }; +} diff --git a/packages/core/src/agents/a2aUtils.test.ts b/packages/core/src/agents/a2aUtils.test.ts new file mode 100644 index 0000000000..0527b54bdd --- /dev/null +++ b/packages/core/src/agents/a2aUtils.test.ts @@ -0,0 +1,171 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { + extractMessageText, + extractTaskText, + extractIdsFromResponse, +} from './a2aUtils.js'; +import type { Message, Task, TextPart, DataPart, FilePart } from '@a2a-js/sdk'; + +describe('a2aUtils', () => { + describe('extractIdsFromResponse', () => { + it('should extract IDs from a message response', () => { + const message: Message = { + kind: 'message', + role: 'agent', + messageId: 'm1', + contextId: 'ctx-1', + taskId: 'task-1', + parts: [], + }; + + const result = extractIdsFromResponse(message); + expect(result).toEqual({ contextId: 'ctx-1', taskId: 'task-1' }); + }); + + it('should extract IDs from an in-progress task response', () => { + const task: Task = { + id: 'task-2', + contextId: 'ctx-2', + kind: 'task', + status: { state: 'working' }, + }; + + const result = extractIdsFromResponse(task); + expect(result).toEqual({ contextId: 'ctx-2', taskId: 'task-2' }); + }); + }); + + describe('extractMessageText', () => { + it('should extract text from simple text parts', () => { + const message: Message = { + kind: 'message', + role: 'user', + messageId: '1', + parts: [ + { kind: 'text', text: 'Hello' } as TextPart, + { kind: 'text', text: 'World' } as TextPart, + ], + }; + expect(extractMessageText(message)).toBe('Hello\nWorld'); + }); + + it('should extract data from data parts', () => { + const message: Message = { + kind: 'message', + role: 'user', + messageId: '1', + parts: [{ kind: 'data', data: { foo: 'bar' } } as DataPart], + }; + expect(extractMessageText(message)).toBe('Data: {"foo":"bar"}'); + }); + + it('should extract file info from file parts', () => { + const message: Message = { + kind: 'message', + role: 'user', + messageId: '1', + parts: [ + { + kind: 'file', + file: { + name: 'test.txt', + uri: 'file://test.txt', + mimeType: 'text/plain', + }, + } as FilePart, + { + kind: 'file', + file: { + uri: 'http://example.com/doc', + mimeType: 'application/pdf', + }, + } as FilePart, + ], + }; + // The formatting logic in a2aUtils prefers name over uri + expect(extractMessageText(message)).toContain('File: test.txt'); + expect(extractMessageText(message)).toContain( + 'File: http://example.com/doc', + ); + }); + + it('should handle mixed parts', () => { + const message: Message = { + kind: 'message', + role: 'user', + messageId: '1', + parts: [ + { kind: 'text', text: 'Here is data:' } as TextPart, + { kind: 'data', data: { value: 123 } } as DataPart, + ], + }; + expect(extractMessageText(message)).toBe( + 'Here is data:\nData: {"value":123}', + ); + }); + + it('should return empty string for undefined or empty message', () => { + expect(extractMessageText(undefined)).toBe(''); + expect( + extractMessageText({ + kind: 'message', + role: 'user', + messageId: '1', + parts: [], + } as Message), + ).toBe(''); + }); + }); + + describe('extractTaskText', () => { + it('should extract basic task info', () => { + const task: Task = { + id: 'task-1', + contextId: 'ctx-1', + kind: 'task', + status: { + state: 'working', + message: { + kind: 'message', + role: 'agent', + messageId: 'm1', + parts: [{ kind: 'text', text: 'Processing...' } as TextPart], + }, + }, + }; + + const result = extractTaskText(task); + expect(result).toContain('ID: task-1'); + expect(result).toContain('State: working'); + expect(result).toContain('Status Message: Processing...'); + }); + + it('should extract artifacts', () => { + const task: Task = { + id: 'task-1', + contextId: 'ctx-1', + kind: 'task', + status: { state: 'completed' }, + artifacts: [ + { + artifactId: 'art-1', + name: 'Report', + parts: [{ kind: 'text', text: 'This is the report.' } as TextPart], + }, + ], + }; + + const result = extractTaskText(task); + expect(result).toContain('Artifacts:'); + expect(result).toContain(' - Name: Report'); + expect(result).toContain(' Content:'); + expect(result).toContain(' This is the report.'); + }); + }); +}); diff --git a/packages/core/src/agents/a2aUtils.ts b/packages/core/src/agents/a2aUtils.ts new file mode 100644 index 0000000000..fc19eceb05 --- /dev/null +++ b/packages/core/src/agents/a2aUtils.ts @@ -0,0 +1,142 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + Message, + Task, + Part, + TextPart, + DataPart, + FilePart, +} from '@a2a-js/sdk'; + +/** + * Extracts a human-readable text representation from a Message object. + * Handles Text, Data (JSON), and File parts. + */ +export function extractMessageText(message: Message | undefined): string { + if (!message || !message.parts) { + return ''; + } + + const parts = message.parts + .map((part) => extractPartText(part)) + .filter(Boolean); + return parts.join('\n'); +} + +/** + * Extracts text from a single Part. + */ +export function extractPartText(part: Part): string { + if (isTextPart(part)) { + return part.text; + } + + if (isDataPart(part)) { + // Attempt to format known data types if metadata exists, otherwise JSON stringify + return `Data: ${JSON.stringify(part.data)}`; + } + + if (isFilePart(part)) { + const fileData = part.file; + if (fileData.name) { + return `File: ${fileData.name}`; + } + if ('uri' in fileData && fileData.uri) { + return `File: ${fileData.uri}`; + } + return `File: [binary/unnamed]`; + } + + return ''; +} + +/** + * Extracts a human-readable text summary from a Task object. + * Includes status, ID, and any artifact content. + */ +export function extractTaskText(task: Task): string { + let output = `ID: ${task.id}\n`; + output += `State: ${task.status.state}\n`; + + // Status Message + const statusMessageText = extractMessageText(task.status.message); + if (statusMessageText) { + output += `Status Message: ${statusMessageText}\n`; + } + + // Artifacts + if (task.artifacts && task.artifacts.length > 0) { + output += `Artifacts:\n`; + for (const artifact of task.artifacts) { + output += ` - Name: ${artifact.name}\n`; + if (artifact.parts && artifact.parts.length > 0) { + // Treat artifact parts as a message for extraction + const artifactContent = artifact.parts + .map((p) => extractPartText(p)) + .filter(Boolean) + .join('\n'); + + if (artifactContent) { + // Indent content for readability + const indentedContent = artifactContent.replace(/^/gm, ' '); + output += ` Content:\n${indentedContent}\n`; + } + } + } + } + + return output; +} + +// Type Guards + +function isTextPart(part: Part): part is TextPart { + return part.kind === 'text'; +} + +function isDataPart(part: Part): part is DataPart { + return part.kind === 'data'; +} + +function isFilePart(part: Part): part is FilePart { + return part.kind === 'file'; +} + +/** + * Extracts contextId and taskId from a Message or Task response. + * Follows the pattern from the A2A CLI sample to maintain conversational continuity. + */ +export function extractIdsFromResponse(result: Message | Task): { + contextId?: string; + taskId?: string; +} { + let contextId: string | undefined; + let taskId: string | undefined; + + if (result.kind === 'message') { + taskId = result.taskId; + contextId = result.contextId; + } else if (result.kind === 'task') { + taskId = result.id; + contextId = result.contextId; + + // If the task is in a final state (and not input-required), we clear the taskId + // so that the next interaction starts a fresh task (or keeps context without being bound to the old task). + if ( + result.status && + result.status.state !== 'input-required' && + (result.status.state === 'completed' || + result.status.state === 'failed' || + result.status.state === 'canceled') + ) { + taskId = undefined; + } + } + + return { contextId, taskId }; +} diff --git a/packages/core/src/agents/delegate-to-agent-tool.test.ts b/packages/core/src/agents/delegate-to-agent-tool.test.ts index 5c8601f217..3722ae2e88 100644 --- a/packages/core/src/agents/delegate-to-agent-tool.test.ts +++ b/packages/core/src/agents/delegate-to-agent-tool.test.ts @@ -13,6 +13,7 @@ import { LocalSubagentInvocation } from './local-invocation.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { MessageBusType } from '../confirmation-bus/types.js'; import { DELEGATE_TO_AGENT_TOOL_NAME } from '../tools/tool-names.js'; +import { RemoteAgentInvocation } from './remote-invocation.js'; import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; vi.mock('./local-invocation.js', () => ({ @@ -23,6 +24,15 @@ vi.mock('./local-invocation.js', () => ({ })), })); +vi.mock('./remote-invocation.js', () => ({ + RemoteAgentInvocation: vi.fn().mockImplementation(() => ({ + execute: vi.fn().mockResolvedValue({ + content: [{ type: 'text', text: 'Remote Success' }], + }), + shouldConfirmExecute: vi.fn().mockResolvedValue(true), + })), +})); + describe('DelegateToAgentTool', () => { let registry: AgentRegistry; let config: Config; @@ -45,6 +55,18 @@ describe('DelegateToAgentTool', () => { toolConfig: { tools: [] }, }; + const mockRemoteAgentDef: AgentDefinition = { + kind: 'remote', + name: 'remote_agent', + description: 'A remote agent', + agentCardUrl: 'https://example.com/agent.json', + inputConfig: { + inputs: { + query: { type: 'string', description: 'Query', required: true }, + }, + }, + }; + beforeEach(() => { config = { getDebugMode: () => false, @@ -58,6 +80,8 @@ describe('DelegateToAgentTool', () => { // Manually register the mock agent (bypassing protected method for testing) // eslint-disable-next-line @typescript-eslint/no-explicit-any (registry as any).agents.set(mockAgentDef.name, mockAgentDef); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (registry as any).agents.set(mockRemoteAgentDef.name, mockRemoteAgentDef); messageBus = createMockMessageBus(); @@ -176,4 +200,23 @@ describe('DelegateToAgentTool', () => { }), ); }); + + it('should delegate to remote agent correctly', async () => { + const invocation = tool.build({ + agent_name: 'remote_agent', + query: 'hello remote', + }); + + const result = await invocation.execute(new AbortController().signal); + expect(result).toEqual({ + content: [{ type: 'text', text: 'Remote Success' }], + }); + expect(RemoteAgentInvocation).toHaveBeenCalledWith( + mockRemoteAgentDef, + { query: 'hello remote' }, + messageBus, + 'remote_agent', + 'remote_agent', + ); + }); }); diff --git a/packages/core/src/agents/remote-invocation.test.ts b/packages/core/src/agents/remote-invocation.test.ts index 610961e440..f3c998e41b 100644 --- a/packages/core/src/agents/remote-invocation.test.ts +++ b/packages/core/src/agents/remote-invocation.test.ts @@ -4,55 +4,310 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect } from 'vitest'; -import type { ToolCallConfirmationDetails } from '../tools/tools.js'; +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; import { RemoteAgentInvocation } from './remote-invocation.js'; +import { A2AClientManager } from './a2a-client-manager.js'; import type { RemoteAgentDefinition } from './types.js'; import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; -class TestableRemoteAgentInvocation extends RemoteAgentInvocation { - override async getConfirmationDetails( - abortSignal: AbortSignal, - ): Promise { - return super.getConfirmationDetails(abortSignal); - } -} +// Mock A2AClientManager +vi.mock('./a2a-client-manager.js', () => { + const A2AClientManager = { + getInstance: vi.fn(), + }; + return { A2AClientManager }; +}); describe('RemoteAgentInvocation', () => { const mockDefinition: RemoteAgentDefinition = { + name: 'test-agent', kind: 'remote', - name: 'test-remote-agent', - description: 'A test remote agent', - displayName: 'Test Remote Agent', - agentCardUrl: 'https://example.com/agent-card', + agentCardUrl: 'http://test-agent/card', + displayName: 'Test Agent', + description: 'A test agent', inputConfig: { inputs: {}, }, }; + const mockClientManager = { + getClient: vi.fn(), + loadAgent: vi.fn(), + sendMessage: vi.fn(), + }; const mockMessageBus = createMockMessageBus(); - it('should be instantiated with correct params', () => { - const invocation = new RemoteAgentInvocation( - mockDefinition, - {}, - mockMessageBus, - ); - expect(invocation).toBeDefined(); - expect(invocation.getDescription()).toBe( - 'Calling remote agent Test Remote Agent', - ); + beforeEach(() => { + vi.clearAllMocks(); + (A2AClientManager.getInstance as Mock).mockReturnValue(mockClientManager); + ( + RemoteAgentInvocation as unknown as { + sessionState?: Map; + } + ).sessionState?.clear(); }); - it('should return false for confirmation details (not yet implemented)', async () => { - const invocation = new TestableRemoteAgentInvocation( - mockDefinition, - {}, - mockMessageBus, - ); - const details = await invocation.getConfirmationDetails( - new AbortController().signal, - ); - expect(details).toBe(false); + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('Constructor Validation', () => { + it('accepts valid input with string query', () => { + expect(() => { + new RemoteAgentInvocation( + mockDefinition, + { query: 'valid' }, + mockMessageBus, + ); + }).not.toThrow(); + }); + + it('throws if query is missing', () => { + expect(() => { + new RemoteAgentInvocation(mockDefinition, {}, mockMessageBus); + }).toThrow("requires a string 'query' input"); + }); + + it('throws if query is not a string', () => { + expect(() => { + new RemoteAgentInvocation( + mockDefinition, + { query: 123 }, + mockMessageBus, + ); + }).toThrow("requires a string 'query' input"); + }); + }); + + describe('Execution Logic', () => { + it('should lazy load the agent with ADCHandler if not present', async () => { + mockClientManager.getClient.mockReturnValue(undefined); + mockClientManager.sendMessage.mockResolvedValue({ + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Hello' }], + }); + + const invocation = new RemoteAgentInvocation( + mockDefinition, + { + query: 'hi', + }, + mockMessageBus, + ); + await invocation.execute(new AbortController().signal); + + expect(mockClientManager.loadAgent).toHaveBeenCalledWith( + 'test-agent', + 'http://test-agent/card', + expect.objectContaining({ + headers: expect.any(Function), + shouldRetryWithHeaders: expect.any(Function), + }), + ); + }); + + it('should not load the agent if already present', async () => { + mockClientManager.getClient.mockReturnValue({}); + mockClientManager.sendMessage.mockResolvedValue({ + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Hello' }], + }); + + const invocation = new RemoteAgentInvocation( + mockDefinition, + { + query: 'hi', + }, + mockMessageBus, + ); + await invocation.execute(new AbortController().signal); + + expect(mockClientManager.loadAgent).not.toHaveBeenCalled(); + }); + + it('should persist contextId and taskId across invocations', async () => { + mockClientManager.getClient.mockReturnValue({}); + + // First call return values + mockClientManager.sendMessage.mockResolvedValueOnce({ + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Response 1' }], + contextId: 'ctx-1', + taskId: 'task-1', + }); + + const invocation1 = new RemoteAgentInvocation( + mockDefinition, + { + query: 'first', + }, + mockMessageBus, + ); + + // Execute first time + const result1 = await invocation1.execute(new AbortController().signal); + expect(result1.returnDisplay).toBe('Response 1'); + expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith( + 'test-agent', + 'first', + { contextId: undefined, taskId: undefined }, + ); + + // Prepare for second call with simulated state persistence + mockClientManager.sendMessage.mockResolvedValueOnce({ + kind: 'message', + messageId: 'msg-2', + role: 'agent', + parts: [{ kind: 'text', text: 'Response 2' }], + contextId: 'ctx-1', + taskId: 'task-2', + }); + + const invocation2 = new RemoteAgentInvocation( + mockDefinition, + { + query: 'second', + }, + mockMessageBus, + ); + const result2 = await invocation2.execute(new AbortController().signal); + expect(result2.returnDisplay).toBe('Response 2'); + + expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith( + 'test-agent', + 'second', + { contextId: 'ctx-1', taskId: 'task-1' }, // Used state from first call + ); + + // Third call: Task completes + mockClientManager.sendMessage.mockResolvedValueOnce({ + kind: 'task', + id: 'task-2', + contextId: 'ctx-1', + status: { state: 'completed', message: undefined }, + artifacts: [], + history: [], + }); + + const invocation3 = new RemoteAgentInvocation( + mockDefinition, + { + query: 'third', + }, + mockMessageBus, + ); + await invocation3.execute(new AbortController().signal); + + // Fourth call: Should start new task (taskId undefined) + mockClientManager.sendMessage.mockResolvedValueOnce({ + kind: 'message', + messageId: 'msg-3', + role: 'agent', + parts: [{ kind: 'text', text: 'New Task' }], + }); + + const invocation4 = new RemoteAgentInvocation( + mockDefinition, + { + query: 'fourth', + }, + mockMessageBus, + ); + await invocation4.execute(new AbortController().signal); + + expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith( + 'test-agent', + 'fourth', + { contextId: 'ctx-1', taskId: undefined }, // taskId cleared! + ); + }); + + it('should handle errors gracefully', async () => { + mockClientManager.getClient.mockReturnValue({}); + mockClientManager.sendMessage.mockRejectedValue( + new Error('Network error'), + ); + + const invocation = new RemoteAgentInvocation( + mockDefinition, + { + query: 'hi', + }, + mockMessageBus, + ); + const result = await invocation.execute(new AbortController().signal); + + expect(result.error).toBeDefined(); + expect(result.error?.message).toContain('Network error'); + expect(result.returnDisplay).toContain('Network error'); + }); + + it('should use a2a helpers for extracting text', async () => { + mockClientManager.getClient.mockReturnValue({}); + // Mock a complex message part that needs extraction + mockClientManager.sendMessage.mockResolvedValue({ + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [ + { kind: 'text', text: 'Extracted text' }, + { kind: 'data', data: { foo: 'bar' } }, + ], + }); + + const invocation = new RemoteAgentInvocation( + mockDefinition, + { + query: 'hi', + }, + mockMessageBus, + ); + const result = await invocation.execute(new AbortController().signal); + + // Just check that text is present, exact formatting depends on helper + expect(result.returnDisplay).toContain('Extracted text'); + }); + }); + + describe('Confirmations', () => { + it('should return info confirmation details', async () => { + const invocation = new RemoteAgentInvocation( + mockDefinition, + { + query: 'hi', + }, + mockMessageBus, + ); + // @ts-expect-error - getConfirmationDetails is protected + const confirmation = await invocation.getConfirmationDetails( + new AbortController().signal, + ); + + expect(confirmation).not.toBe(false); + if ( + confirmation && + typeof confirmation === 'object' && + confirmation.type === 'info' + ) { + expect(confirmation.title).toContain('Test Agent'); + expect(confirmation.prompt).toContain('http://test-agent/card'); + } else { + throw new Error('Expected confirmation to be of type info'); + } + }); }); }); diff --git a/packages/core/src/agents/remote-invocation.ts b/packages/core/src/agents/remote-invocation.ts index 28ee8de6bb..4bc23f7fb1 100644 --- a/packages/core/src/agents/remote-invocation.ts +++ b/packages/core/src/agents/remote-invocation.ts @@ -9,8 +9,54 @@ import { type ToolResult, type ToolCallConfirmationDetails, } from '../tools/tools.js'; -import type { AgentInputs, RemoteAgentDefinition } from './types.js'; +import type { + RemoteAgentInputs, + RemoteAgentDefinition, + AgentInputs, +} from './types.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { A2AClientManager } from './a2a-client-manager.js'; +import { + extractMessageText, + extractTaskText, + extractIdsFromResponse, +} from './a2aUtils.js'; +import { GoogleAuth } from 'google-auth-library'; +import type { AuthenticationHandler } from '@a2a-js/sdk/client'; +import { debugLogger } from '../utils/debugLogger.js'; + +/** + * Authentication handler implementation using Google Application Default Credentials (ADC). + */ +export class ADCHandler implements AuthenticationHandler { + private auth = new GoogleAuth({ + scopes: ['https://www.googleapis.com/auth/cloud-platform'], + }); + + async headers(): Promise> { + try { + const client = await this.auth.getClient(); + const token = await client.getAccessToken(); + if (token.token) { + return { Authorization: `Bearer ${token.token}` }; + } + throw new Error('Failed to retrieve ADC access token.'); + } catch (e) { + const errorMessage = `Failed to get ADC token: ${ + e instanceof Error ? e.message : String(e) + }`; + debugLogger.log('ERROR', errorMessage); + throw new Error(errorMessage); + } + } + + async shouldRetryWithHeaders( + _response: unknown, + ): Promise | undefined> { + // For ADC, we usually just re-fetch the token if needed. + return this.headers(); + } +} /** * A tool invocation that proxies to a remote A2A agent. @@ -19,9 +65,22 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js'; * invokes the configured A2A tool. */ export class RemoteAgentInvocation extends BaseToolInvocation< - AgentInputs, + RemoteAgentInputs, ToolResult > { + // Persist state across ephemeral invocation instances. + private static readonly sessionState = new Map< + string, + { contextId?: string; taskId?: string } + >(); + // State for the ongoing conversation with the remote agent + private contextId: string | undefined; + private taskId: string | undefined; + // TODO: See if we can reuse the singleton from AppContainer or similar, but for now use getInstance directly + // as per the current pattern in the codebase. + private readonly clientManager = A2AClientManager.getInstance(); + private readonly authHandler = new ADCHandler(); + constructor( private readonly definition: RemoteAgentDefinition, params: AgentInputs, @@ -29,8 +88,15 @@ export class RemoteAgentInvocation extends BaseToolInvocation< _toolName?: string, _toolDisplayName?: string, ) { + const query = params['query']; + if (typeof query !== 'string') { + throw new Error( + `Remote agent '${definition.name}' requires a string 'query' input.`, + ); + } + // Safe to pass strict object to super super( - params, + { query }, messageBus, _toolName ?? definition.name, _toolDisplayName ?? definition.displayName, @@ -44,12 +110,81 @@ export class RemoteAgentInvocation extends BaseToolInvocation< protected override async getConfirmationDetails( _abortSignal: AbortSignal, ): Promise { - // TODO: Implement confirmation logic for remote agents. - return false; + // For now, always require confirmation for remote agents until we have a policy system for them. + return { + type: 'info', + title: `Call Remote Agent: ${this.definition.displayName ?? this.definition.name}`, + prompt: `This will send a message to the external agent at ${this.definition.agentCardUrl}.`, + onConfirm: async () => {}, // No-op for now, just informational + }; } async execute(_signal: AbortSignal): Promise { - // TODO: Implement remote agent invocation logic. - throw new Error(`Remote agent invocation not implemented.`); + // 1. Ensure the agent is loaded (cached by manager) + // We assume the user has provided an access token via some mechanism (TODO), + // or we rely on ADC. + try { + const priorState = RemoteAgentInvocation.sessionState.get( + this.definition.name, + ); + if (priorState) { + this.contextId = priorState.contextId; + this.taskId = priorState.taskId; + } + + if (!this.clientManager.getClient(this.definition.name)) { + await this.clientManager.loadAgent( + this.definition.name, + this.definition.agentCardUrl, + this.authHandler, + ); + } + + const message = this.params.query; + + const response = await this.clientManager.sendMessage( + this.definition.name, + message, + { + contextId: this.contextId, + taskId: this.taskId, + }, + ); + + // Extracts IDs, taskID will be undefined if the task is completed/failed/canceled. + const { contextId, taskId } = extractIdsFromResponse(response); + + this.contextId = contextId ?? this.contextId; + this.taskId = taskId; + + RemoteAgentInvocation.sessionState.set(this.definition.name, { + contextId: this.contextId, + taskId: this.taskId, + }); + + // Extract the output text + const resultData = response; + let outputText = ''; + + if (resultData.kind === 'message') { + outputText = extractMessageText(resultData); + } else if (resultData.kind === 'task') { + outputText = extractTaskText(resultData); + } else { + outputText = JSON.stringify(resultData); + } + + return { + llmContent: [{ text: outputText }], + returnDisplay: outputText, + }; + } catch (error: unknown) { + const errorMessage = `Error calling remote agent: ${error instanceof Error ? error.message : String(error)}`; + return { + llmContent: [{ text: errorMessage }], + returnDisplay: errorMessage, + error: { message: errorMessage }, + }; + } } } diff --git a/packages/core/src/agents/types.ts b/packages/core/src/agents/types.ts index c42a3103ac..f0d2743662 100644 --- a/packages/core/src/agents/types.ts +++ b/packages/core/src/agents/types.ts @@ -38,6 +38,11 @@ export interface OutputObject { */ export type AgentInputs = Record; +/** + * Simplified input structure for Remote Agents, which consumes a single string query. + */ +export type RemoteAgentInputs = { query: string }; + /** * Structured events emitted during subagent execution for user observability. */