mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
feat(agents): add support for remote agents (#16013)
This commit is contained in:
@@ -8,6 +8,7 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
A2AClientManager,
|
||||
type SendMessageResult,
|
||||
createAdapterFetch,
|
||||
} from './a2a-client-manager.js';
|
||||
import type { AgentCard, Task } from '@a2a-js/sdk';
|
||||
import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client';
|
||||
@@ -302,4 +303,42 @@ describe('A2AClientManager', () => {
|
||||
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
|
||||
});
|
||||
});
|
||||
|
||||
describe('createAdapterFetch', () => {
|
||||
it('normalizes TASK_STATE_ enums to lower-case', async () => {
|
||||
const baseFetch = vi
|
||||
.fn()
|
||||
.mockResolvedValue(
|
||||
new Response(
|
||||
JSON.stringify({ status: { state: 'TASK_STATE_WORKING' } }),
|
||||
),
|
||||
);
|
||||
|
||||
const adapter = createAdapterFetch(baseFetch as typeof fetch);
|
||||
const response = await adapter('http://example.com', {
|
||||
method: 'POST',
|
||||
body: '{}',
|
||||
});
|
||||
const data = await response.json();
|
||||
|
||||
expect(data.status.state).toBe('working');
|
||||
});
|
||||
|
||||
it('lowercases non-prefixed task states', async () => {
|
||||
const baseFetch = vi
|
||||
.fn()
|
||||
.mockResolvedValue(
|
||||
new Response(JSON.stringify({ status: { state: 'WORKING' } })),
|
||||
);
|
||||
|
||||
const adapter = createAdapterFetch(baseFetch as typeof fetch);
|
||||
const response = await adapter('http://example.com', {
|
||||
method: 'POST',
|
||||
body: '{}',
|
||||
});
|
||||
const data = await response.json();
|
||||
|
||||
expect(data.status.state).toBe('working');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -68,11 +68,15 @@ export class A2AClientManager {
|
||||
throw new Error(`Agent with name '${name}' is already loaded.`);
|
||||
}
|
||||
|
||||
let fetchImpl = fetch;
|
||||
let fetchImpl: typeof fetch = fetch;
|
||||
if (authHandler) {
|
||||
fetchImpl = createAuthenticatingFetchWithRetry(fetch, authHandler);
|
||||
}
|
||||
|
||||
// Wrap with custom adapter for ADK Reasoning Engine compatibility
|
||||
// TODO: Remove this when a2a-js fixes compatibility
|
||||
fetchImpl = createAdapterFetch(fetchImpl);
|
||||
|
||||
const resolver = new DefaultAgentCardResolver({ fetchImpl });
|
||||
|
||||
const options = ClientFactoryOptions.createFrom(
|
||||
@@ -207,3 +211,148 @@ export class A2AClientManager {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps TaskState proto-JSON enums to lower-case strings.
|
||||
*/
|
||||
function mapTaskState(state: string | undefined): string | undefined {
|
||||
if (!state) return state;
|
||||
if (state.startsWith('TASK_STATE_')) {
|
||||
return state.replace('TASK_STATE_', '').toLowerCase();
|
||||
}
|
||||
return state.toLowerCase();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a fetch implementation that adapts standard A2A SDK requests to the
|
||||
* proto-JSON dialect and endpoint shapes required by Vertex AI Agent Engine.
|
||||
*/
|
||||
export function createAdapterFetch(baseFetch: typeof fetch): typeof fetch {
|
||||
return async (
|
||||
input: RequestInfo | URL,
|
||||
init?: RequestInit,
|
||||
): Promise<Response> => {
|
||||
const urlStr = input as string;
|
||||
|
||||
// 2. Dialect Mapping (Request)
|
||||
let body = init?.body;
|
||||
let isRpc = false;
|
||||
let rpcId: string | number | undefined;
|
||||
|
||||
if (typeof body === 'string') {
|
||||
try {
|
||||
let jsonBody = JSON.parse(body);
|
||||
|
||||
// Unwrap JSON-RPC if present
|
||||
if (jsonBody.jsonrpc === '2.0') {
|
||||
isRpc = true;
|
||||
rpcId = jsonBody.id;
|
||||
jsonBody = jsonBody.params;
|
||||
}
|
||||
|
||||
// Apply dialect translation to the message object
|
||||
const message = jsonBody.message || jsonBody;
|
||||
if (message && typeof message === 'object') {
|
||||
// Role: user -> ROLE_USER, agent/model -> ROLE_AGENT
|
||||
if (message.role === 'user') message.role = 'ROLE_USER';
|
||||
if (message.role === 'agent' || message.role === 'model') {
|
||||
message.role = 'ROLE_AGENT';
|
||||
}
|
||||
|
||||
// Strip SDK-specific 'kind' field
|
||||
delete message.kind;
|
||||
|
||||
// Map 'parts' to 'content' (Proto-JSON dialect often uses 'content' or typed parts)
|
||||
// Also strip 'kind' from parts.
|
||||
if (Array.isArray(message.parts)) {
|
||||
message.content = message.parts.map(
|
||||
(p: { kind?: string; text?: string }) => {
|
||||
const { kind: _k, ...rest } = p;
|
||||
// If it's a simple text part, ensure it matches { text: "..." }
|
||||
if (p.kind === 'text') return { text: p.text };
|
||||
return rest;
|
||||
},
|
||||
);
|
||||
delete message.parts;
|
||||
}
|
||||
}
|
||||
|
||||
body = JSON.stringify(jsonBody);
|
||||
} catch (error) {
|
||||
debugLogger.debug(
|
||||
'[A2AClientManager] Failed to parse request body for dialect translation:',
|
||||
error,
|
||||
);
|
||||
// Non-JSON or parse error; let the baseFetch handle it.
|
||||
}
|
||||
}
|
||||
|
||||
const response = await baseFetch(urlStr, { ...init, body });
|
||||
|
||||
// Map response back
|
||||
if (response.ok) {
|
||||
try {
|
||||
const responseData = await response.clone().json();
|
||||
|
||||
const result =
|
||||
responseData.task || responseData.message || responseData;
|
||||
|
||||
// Restore 'kind' for the SDK and a2aUtils parsing
|
||||
if (result && typeof result === 'object' && !result.kind) {
|
||||
if (responseData.task || (result.id && result.status)) {
|
||||
result.kind = 'task';
|
||||
} else if (responseData.message || result.messageId) {
|
||||
result.kind = 'message';
|
||||
}
|
||||
}
|
||||
|
||||
// Restore 'kind' on parts so extractMessageText works
|
||||
if (result?.parts && Array.isArray(result.parts)) {
|
||||
for (const part of result.parts) {
|
||||
if (!part.kind) {
|
||||
if (part.file) part.kind = 'file';
|
||||
else if (part.data) part.kind = 'data';
|
||||
else if (part.text) part.kind = 'text';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively restore 'kind' on artifact parts
|
||||
if (result?.artifacts && Array.isArray(result.artifacts)) {
|
||||
for (const artifact of result.artifacts) {
|
||||
if (artifact.parts && Array.isArray(artifact.parts)) {
|
||||
for (const part of artifact.parts) {
|
||||
if (!part.kind) {
|
||||
if (part.file) part.kind = 'file';
|
||||
else if (part.data) part.kind = 'data';
|
||||
else if (part.text) part.kind = 'text';
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map Task States back to SDK expectations
|
||||
if (result && typeof result === 'object' && result.status) {
|
||||
result.status.state = mapTaskState(result.status.state);
|
||||
}
|
||||
|
||||
if (isRpc) {
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
id: rpcId,
|
||||
result,
|
||||
}),
|
||||
response,
|
||||
);
|
||||
}
|
||||
return new Response(JSON.stringify(result), response);
|
||||
} catch (_e) {
|
||||
// Non-JSON response or unwrapping failure
|
||||
}
|
||||
}
|
||||
|
||||
return response;
|
||||
};
|
||||
}
|
||||
|
||||
171
packages/core/src/agents/a2aUtils.test.ts
Normal file
171
packages/core/src/agents/a2aUtils.test.ts
Normal file
@@ -0,0 +1,171 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import {
|
||||
extractMessageText,
|
||||
extractTaskText,
|
||||
extractIdsFromResponse,
|
||||
} from './a2aUtils.js';
|
||||
import type { Message, Task, TextPart, DataPart, FilePart } from '@a2a-js/sdk';
|
||||
|
||||
describe('a2aUtils', () => {
|
||||
describe('extractIdsFromResponse', () => {
|
||||
it('should extract IDs from a message response', () => {
|
||||
const message: Message = {
|
||||
kind: 'message',
|
||||
role: 'agent',
|
||||
messageId: 'm1',
|
||||
contextId: 'ctx-1',
|
||||
taskId: 'task-1',
|
||||
parts: [],
|
||||
};
|
||||
|
||||
const result = extractIdsFromResponse(message);
|
||||
expect(result).toEqual({ contextId: 'ctx-1', taskId: 'task-1' });
|
||||
});
|
||||
|
||||
it('should extract IDs from an in-progress task response', () => {
|
||||
const task: Task = {
|
||||
id: 'task-2',
|
||||
contextId: 'ctx-2',
|
||||
kind: 'task',
|
||||
status: { state: 'working' },
|
||||
};
|
||||
|
||||
const result = extractIdsFromResponse(task);
|
||||
expect(result).toEqual({ contextId: 'ctx-2', taskId: 'task-2' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractMessageText', () => {
|
||||
it('should extract text from simple text parts', () => {
|
||||
const message: Message = {
|
||||
kind: 'message',
|
||||
role: 'user',
|
||||
messageId: '1',
|
||||
parts: [
|
||||
{ kind: 'text', text: 'Hello' } as TextPart,
|
||||
{ kind: 'text', text: 'World' } as TextPart,
|
||||
],
|
||||
};
|
||||
expect(extractMessageText(message)).toBe('Hello\nWorld');
|
||||
});
|
||||
|
||||
it('should extract data from data parts', () => {
|
||||
const message: Message = {
|
||||
kind: 'message',
|
||||
role: 'user',
|
||||
messageId: '1',
|
||||
parts: [{ kind: 'data', data: { foo: 'bar' } } as DataPart],
|
||||
};
|
||||
expect(extractMessageText(message)).toBe('Data: {"foo":"bar"}');
|
||||
});
|
||||
|
||||
it('should extract file info from file parts', () => {
|
||||
const message: Message = {
|
||||
kind: 'message',
|
||||
role: 'user',
|
||||
messageId: '1',
|
||||
parts: [
|
||||
{
|
||||
kind: 'file',
|
||||
file: {
|
||||
name: 'test.txt',
|
||||
uri: 'file://test.txt',
|
||||
mimeType: 'text/plain',
|
||||
},
|
||||
} as FilePart,
|
||||
{
|
||||
kind: 'file',
|
||||
file: {
|
||||
uri: 'http://example.com/doc',
|
||||
mimeType: 'application/pdf',
|
||||
},
|
||||
} as FilePart,
|
||||
],
|
||||
};
|
||||
// The formatting logic in a2aUtils prefers name over uri
|
||||
expect(extractMessageText(message)).toContain('File: test.txt');
|
||||
expect(extractMessageText(message)).toContain(
|
||||
'File: http://example.com/doc',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle mixed parts', () => {
|
||||
const message: Message = {
|
||||
kind: 'message',
|
||||
role: 'user',
|
||||
messageId: '1',
|
||||
parts: [
|
||||
{ kind: 'text', text: 'Here is data:' } as TextPart,
|
||||
{ kind: 'data', data: { value: 123 } } as DataPart,
|
||||
],
|
||||
};
|
||||
expect(extractMessageText(message)).toBe(
|
||||
'Here is data:\nData: {"value":123}',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return empty string for undefined or empty message', () => {
|
||||
expect(extractMessageText(undefined)).toBe('');
|
||||
expect(
|
||||
extractMessageText({
|
||||
kind: 'message',
|
||||
role: 'user',
|
||||
messageId: '1',
|
||||
parts: [],
|
||||
} as Message),
|
||||
).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractTaskText', () => {
|
||||
it('should extract basic task info', () => {
|
||||
const task: Task = {
|
||||
id: 'task-1',
|
||||
contextId: 'ctx-1',
|
||||
kind: 'task',
|
||||
status: {
|
||||
state: 'working',
|
||||
message: {
|
||||
kind: 'message',
|
||||
role: 'agent',
|
||||
messageId: 'm1',
|
||||
parts: [{ kind: 'text', text: 'Processing...' } as TextPart],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = extractTaskText(task);
|
||||
expect(result).toContain('ID: task-1');
|
||||
expect(result).toContain('State: working');
|
||||
expect(result).toContain('Status Message: Processing...');
|
||||
});
|
||||
|
||||
it('should extract artifacts', () => {
|
||||
const task: Task = {
|
||||
id: 'task-1',
|
||||
contextId: 'ctx-1',
|
||||
kind: 'task',
|
||||
status: { state: 'completed' },
|
||||
artifacts: [
|
||||
{
|
||||
artifactId: 'art-1',
|
||||
name: 'Report',
|
||||
parts: [{ kind: 'text', text: 'This is the report.' } as TextPart],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const result = extractTaskText(task);
|
||||
expect(result).toContain('Artifacts:');
|
||||
expect(result).toContain(' - Name: Report');
|
||||
expect(result).toContain(' Content:');
|
||||
expect(result).toContain(' This is the report.');
|
||||
});
|
||||
});
|
||||
});
|
||||
142
packages/core/src/agents/a2aUtils.ts
Normal file
142
packages/core/src/agents/a2aUtils.ts
Normal file
@@ -0,0 +1,142 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
Message,
|
||||
Task,
|
||||
Part,
|
||||
TextPart,
|
||||
DataPart,
|
||||
FilePart,
|
||||
} from '@a2a-js/sdk';
|
||||
|
||||
/**
|
||||
* Extracts a human-readable text representation from a Message object.
|
||||
* Handles Text, Data (JSON), and File parts.
|
||||
*/
|
||||
export function extractMessageText(message: Message | undefined): string {
|
||||
if (!message || !message.parts) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const parts = message.parts
|
||||
.map((part) => extractPartText(part))
|
||||
.filter(Boolean);
|
||||
return parts.join('\n');
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts text from a single Part.
|
||||
*/
|
||||
export function extractPartText(part: Part): string {
|
||||
if (isTextPart(part)) {
|
||||
return part.text;
|
||||
}
|
||||
|
||||
if (isDataPart(part)) {
|
||||
// Attempt to format known data types if metadata exists, otherwise JSON stringify
|
||||
return `Data: ${JSON.stringify(part.data)}`;
|
||||
}
|
||||
|
||||
if (isFilePart(part)) {
|
||||
const fileData = part.file;
|
||||
if (fileData.name) {
|
||||
return `File: ${fileData.name}`;
|
||||
}
|
||||
if ('uri' in fileData && fileData.uri) {
|
||||
return `File: ${fileData.uri}`;
|
||||
}
|
||||
return `File: [binary/unnamed]`;
|
||||
}
|
||||
|
||||
return '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts a human-readable text summary from a Task object.
|
||||
* Includes status, ID, and any artifact content.
|
||||
*/
|
||||
export function extractTaskText(task: Task): string {
|
||||
let output = `ID: ${task.id}\n`;
|
||||
output += `State: ${task.status.state}\n`;
|
||||
|
||||
// Status Message
|
||||
const statusMessageText = extractMessageText(task.status.message);
|
||||
if (statusMessageText) {
|
||||
output += `Status Message: ${statusMessageText}\n`;
|
||||
}
|
||||
|
||||
// Artifacts
|
||||
if (task.artifacts && task.artifacts.length > 0) {
|
||||
output += `Artifacts:\n`;
|
||||
for (const artifact of task.artifacts) {
|
||||
output += ` - Name: ${artifact.name}\n`;
|
||||
if (artifact.parts && artifact.parts.length > 0) {
|
||||
// Treat artifact parts as a message for extraction
|
||||
const artifactContent = artifact.parts
|
||||
.map((p) => extractPartText(p))
|
||||
.filter(Boolean)
|
||||
.join('\n');
|
||||
|
||||
if (artifactContent) {
|
||||
// Indent content for readability
|
||||
const indentedContent = artifactContent.replace(/^/gm, ' ');
|
||||
output += ` Content:\n${indentedContent}\n`;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
// Type Guards
|
||||
|
||||
function isTextPart(part: Part): part is TextPart {
|
||||
return part.kind === 'text';
|
||||
}
|
||||
|
||||
function isDataPart(part: Part): part is DataPart {
|
||||
return part.kind === 'data';
|
||||
}
|
||||
|
||||
function isFilePart(part: Part): part is FilePart {
|
||||
return part.kind === 'file';
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts contextId and taskId from a Message or Task response.
|
||||
* Follows the pattern from the A2A CLI sample to maintain conversational continuity.
|
||||
*/
|
||||
export function extractIdsFromResponse(result: Message | Task): {
|
||||
contextId?: string;
|
||||
taskId?: string;
|
||||
} {
|
||||
let contextId: string | undefined;
|
||||
let taskId: string | undefined;
|
||||
|
||||
if (result.kind === 'message') {
|
||||
taskId = result.taskId;
|
||||
contextId = result.contextId;
|
||||
} else if (result.kind === 'task') {
|
||||
taskId = result.id;
|
||||
contextId = result.contextId;
|
||||
|
||||
// If the task is in a final state (and not input-required), we clear the taskId
|
||||
// so that the next interaction starts a fresh task (or keeps context without being bound to the old task).
|
||||
if (
|
||||
result.status &&
|
||||
result.status.state !== 'input-required' &&
|
||||
(result.status.state === 'completed' ||
|
||||
result.status.state === 'failed' ||
|
||||
result.status.state === 'canceled')
|
||||
) {
|
||||
taskId = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
return { contextId, taskId };
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import { LocalSubagentInvocation } from './local-invocation.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { MessageBusType } from '../confirmation-bus/types.js';
|
||||
import { DELEGATE_TO_AGENT_TOOL_NAME } from '../tools/tool-names.js';
|
||||
import { RemoteAgentInvocation } from './remote-invocation.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
vi.mock('./local-invocation.js', () => ({
|
||||
@@ -23,6 +24,15 @@ vi.mock('./local-invocation.js', () => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('./remote-invocation.js', () => ({
|
||||
RemoteAgentInvocation: vi.fn().mockImplementation(() => ({
|
||||
execute: vi.fn().mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'Remote Success' }],
|
||||
}),
|
||||
shouldConfirmExecute: vi.fn().mockResolvedValue(true),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('DelegateToAgentTool', () => {
|
||||
let registry: AgentRegistry;
|
||||
let config: Config;
|
||||
@@ -45,6 +55,18 @@ describe('DelegateToAgentTool', () => {
|
||||
toolConfig: { tools: [] },
|
||||
};
|
||||
|
||||
const mockRemoteAgentDef: AgentDefinition = {
|
||||
kind: 'remote',
|
||||
name: 'remote_agent',
|
||||
description: 'A remote agent',
|
||||
agentCardUrl: 'https://example.com/agent.json',
|
||||
inputConfig: {
|
||||
inputs: {
|
||||
query: { type: 'string', description: 'Query', required: true },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
config = {
|
||||
getDebugMode: () => false,
|
||||
@@ -58,6 +80,8 @@ describe('DelegateToAgentTool', () => {
|
||||
// Manually register the mock agent (bypassing protected method for testing)
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(registry as any).agents.set(mockAgentDef.name, mockAgentDef);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(registry as any).agents.set(mockRemoteAgentDef.name, mockRemoteAgentDef);
|
||||
|
||||
messageBus = createMockMessageBus();
|
||||
|
||||
@@ -176,4 +200,23 @@ describe('DelegateToAgentTool', () => {
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should delegate to remote agent correctly', async () => {
|
||||
const invocation = tool.build({
|
||||
agent_name: 'remote_agent',
|
||||
query: 'hello remote',
|
||||
});
|
||||
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result).toEqual({
|
||||
content: [{ type: 'text', text: 'Remote Success' }],
|
||||
});
|
||||
expect(RemoteAgentInvocation).toHaveBeenCalledWith(
|
||||
mockRemoteAgentDef,
|
||||
{ query: 'hello remote' },
|
||||
messageBus,
|
||||
'remote_agent',
|
||||
'remote_agent',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,55 +4,310 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import type { ToolCallConfirmationDetails } from '../tools/tools.js';
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { RemoteAgentInvocation } from './remote-invocation.js';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import type { RemoteAgentDefinition } from './types.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
class TestableRemoteAgentInvocation extends RemoteAgentInvocation {
|
||||
override async getConfirmationDetails(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
return super.getConfirmationDetails(abortSignal);
|
||||
}
|
||||
}
|
||||
// Mock A2AClientManager
|
||||
vi.mock('./a2a-client-manager.js', () => {
|
||||
const A2AClientManager = {
|
||||
getInstance: vi.fn(),
|
||||
};
|
||||
return { A2AClientManager };
|
||||
});
|
||||
|
||||
describe('RemoteAgentInvocation', () => {
|
||||
const mockDefinition: RemoteAgentDefinition = {
|
||||
name: 'test-agent',
|
||||
kind: 'remote',
|
||||
name: 'test-remote-agent',
|
||||
description: 'A test remote agent',
|
||||
displayName: 'Test Remote Agent',
|
||||
agentCardUrl: 'https://example.com/agent-card',
|
||||
agentCardUrl: 'http://test-agent/card',
|
||||
displayName: 'Test Agent',
|
||||
description: 'A test agent',
|
||||
inputConfig: {
|
||||
inputs: {},
|
||||
},
|
||||
};
|
||||
|
||||
const mockClientManager = {
|
||||
getClient: vi.fn(),
|
||||
loadAgent: vi.fn(),
|
||||
sendMessage: vi.fn(),
|
||||
};
|
||||
const mockMessageBus = createMockMessageBus();
|
||||
|
||||
it('should be instantiated with correct params', () => {
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{},
|
||||
mockMessageBus,
|
||||
);
|
||||
expect(invocation).toBeDefined();
|
||||
expect(invocation.getDescription()).toBe(
|
||||
'Calling remote agent Test Remote Agent',
|
||||
);
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
(A2AClientManager.getInstance as Mock).mockReturnValue(mockClientManager);
|
||||
(
|
||||
RemoteAgentInvocation as unknown as {
|
||||
sessionState?: Map<string, { contextId?: string; taskId?: string }>;
|
||||
}
|
||||
).sessionState?.clear();
|
||||
});
|
||||
|
||||
it('should return false for confirmation details (not yet implemented)', async () => {
|
||||
const invocation = new TestableRemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{},
|
||||
mockMessageBus,
|
||||
);
|
||||
const details = await invocation.getConfirmationDetails(
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(details).toBe(false);
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('Constructor Validation', () => {
|
||||
it('accepts valid input with string query', () => {
|
||||
expect(() => {
|
||||
new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{ query: 'valid' },
|
||||
mockMessageBus,
|
||||
);
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
it('throws if query is missing', () => {
|
||||
expect(() => {
|
||||
new RemoteAgentInvocation(mockDefinition, {}, mockMessageBus);
|
||||
}).toThrow("requires a string 'query' input");
|
||||
});
|
||||
|
||||
it('throws if query is not a string', () => {
|
||||
expect(() => {
|
||||
new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{ query: 123 },
|
||||
mockMessageBus,
|
||||
);
|
||||
}).toThrow("requires a string 'query' input");
|
||||
});
|
||||
});
|
||||
|
||||
describe('Execution Logic', () => {
|
||||
it('should lazy load the agent with ADCHandler if not present', async () => {
|
||||
mockClientManager.getClient.mockReturnValue(undefined);
|
||||
mockClientManager.sendMessage.mockResolvedValue({
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello' }],
|
||||
});
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{
|
||||
query: 'hi',
|
||||
},
|
||||
mockMessageBus,
|
||||
);
|
||||
await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(mockClientManager.loadAgent).toHaveBeenCalledWith(
|
||||
'test-agent',
|
||||
'http://test-agent/card',
|
||||
expect.objectContaining({
|
||||
headers: expect.any(Function),
|
||||
shouldRetryWithHeaders: expect.any(Function),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not load the agent if already present', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.sendMessage.mockResolvedValue({
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello' }],
|
||||
});
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{
|
||||
query: 'hi',
|
||||
},
|
||||
mockMessageBus,
|
||||
);
|
||||
await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(mockClientManager.loadAgent).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should persist contextId and taskId across invocations', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
|
||||
// First call return values
|
||||
mockClientManager.sendMessage.mockResolvedValueOnce({
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Response 1' }],
|
||||
contextId: 'ctx-1',
|
||||
taskId: 'task-1',
|
||||
});
|
||||
|
||||
const invocation1 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{
|
||||
query: 'first',
|
||||
},
|
||||
mockMessageBus,
|
||||
);
|
||||
|
||||
// Execute first time
|
||||
const result1 = await invocation1.execute(new AbortController().signal);
|
||||
expect(result1.returnDisplay).toBe('Response 1');
|
||||
expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith(
|
||||
'test-agent',
|
||||
'first',
|
||||
{ contextId: undefined, taskId: undefined },
|
||||
);
|
||||
|
||||
// Prepare for second call with simulated state persistence
|
||||
mockClientManager.sendMessage.mockResolvedValueOnce({
|
||||
kind: 'message',
|
||||
messageId: 'msg-2',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Response 2' }],
|
||||
contextId: 'ctx-1',
|
||||
taskId: 'task-2',
|
||||
});
|
||||
|
||||
const invocation2 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{
|
||||
query: 'second',
|
||||
},
|
||||
mockMessageBus,
|
||||
);
|
||||
const result2 = await invocation2.execute(new AbortController().signal);
|
||||
expect(result2.returnDisplay).toBe('Response 2');
|
||||
|
||||
expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith(
|
||||
'test-agent',
|
||||
'second',
|
||||
{ contextId: 'ctx-1', taskId: 'task-1' }, // Used state from first call
|
||||
);
|
||||
|
||||
// Third call: Task completes
|
||||
mockClientManager.sendMessage.mockResolvedValueOnce({
|
||||
kind: 'task',
|
||||
id: 'task-2',
|
||||
contextId: 'ctx-1',
|
||||
status: { state: 'completed', message: undefined },
|
||||
artifacts: [],
|
||||
history: [],
|
||||
});
|
||||
|
||||
const invocation3 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{
|
||||
query: 'third',
|
||||
},
|
||||
mockMessageBus,
|
||||
);
|
||||
await invocation3.execute(new AbortController().signal);
|
||||
|
||||
// Fourth call: Should start new task (taskId undefined)
|
||||
mockClientManager.sendMessage.mockResolvedValueOnce({
|
||||
kind: 'message',
|
||||
messageId: 'msg-3',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'New Task' }],
|
||||
});
|
||||
|
||||
const invocation4 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{
|
||||
query: 'fourth',
|
||||
},
|
||||
mockMessageBus,
|
||||
);
|
||||
await invocation4.execute(new AbortController().signal);
|
||||
|
||||
expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith(
|
||||
'test-agent',
|
||||
'fourth',
|
||||
{ contextId: 'ctx-1', taskId: undefined }, // taskId cleared!
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.sendMessage.mockRejectedValue(
|
||||
new Error('Network error'),
|
||||
);
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{
|
||||
query: 'hi',
|
||||
},
|
||||
mockMessageBus,
|
||||
);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(result.error).toBeDefined();
|
||||
expect(result.error?.message).toContain('Network error');
|
||||
expect(result.returnDisplay).toContain('Network error');
|
||||
});
|
||||
|
||||
it('should use a2a helpers for extracting text', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
// Mock a complex message part that needs extraction
|
||||
mockClientManager.sendMessage.mockResolvedValue({
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [
|
||||
{ kind: 'text', text: 'Extracted text' },
|
||||
{ kind: 'data', data: { foo: 'bar' } },
|
||||
],
|
||||
});
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{
|
||||
query: 'hi',
|
||||
},
|
||||
mockMessageBus,
|
||||
);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
// Just check that text is present, exact formatting depends on helper
|
||||
expect(result.returnDisplay).toContain('Extracted text');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Confirmations', () => {
|
||||
it('should return info confirmation details', async () => {
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{
|
||||
query: 'hi',
|
||||
},
|
||||
mockMessageBus,
|
||||
);
|
||||
// @ts-expect-error - getConfirmationDetails is protected
|
||||
const confirmation = await invocation.getConfirmationDetails(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
expect(confirmation).not.toBe(false);
|
||||
if (
|
||||
confirmation &&
|
||||
typeof confirmation === 'object' &&
|
||||
confirmation.type === 'info'
|
||||
) {
|
||||
expect(confirmation.title).toContain('Test Agent');
|
||||
expect(confirmation.prompt).toContain('http://test-agent/card');
|
||||
} else {
|
||||
throw new Error('Expected confirmation to be of type info');
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,8 +9,54 @@ import {
|
||||
type ToolResult,
|
||||
type ToolCallConfirmationDetails,
|
||||
} from '../tools/tools.js';
|
||||
import type { AgentInputs, RemoteAgentDefinition } from './types.js';
|
||||
import type {
|
||||
RemoteAgentInputs,
|
||||
RemoteAgentDefinition,
|
||||
AgentInputs,
|
||||
} from './types.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import {
|
||||
extractMessageText,
|
||||
extractTaskText,
|
||||
extractIdsFromResponse,
|
||||
} from './a2aUtils.js';
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Authentication handler implementation using Google Application Default Credentials (ADC).
|
||||
*/
|
||||
export class ADCHandler implements AuthenticationHandler {
|
||||
private auth = new GoogleAuth({
|
||||
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
|
||||
});
|
||||
|
||||
async headers(): Promise<Record<string, string>> {
|
||||
try {
|
||||
const client = await this.auth.getClient();
|
||||
const token = await client.getAccessToken();
|
||||
if (token.token) {
|
||||
return { Authorization: `Bearer ${token.token}` };
|
||||
}
|
||||
throw new Error('Failed to retrieve ADC access token.');
|
||||
} catch (e) {
|
||||
const errorMessage = `Failed to get ADC token: ${
|
||||
e instanceof Error ? e.message : String(e)
|
||||
}`;
|
||||
debugLogger.log('ERROR', errorMessage);
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
}
|
||||
|
||||
async shouldRetryWithHeaders(
|
||||
_response: unknown,
|
||||
): Promise<Record<string, string> | undefined> {
|
||||
// For ADC, we usually just re-fetch the token if needed.
|
||||
return this.headers();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A tool invocation that proxies to a remote A2A agent.
|
||||
@@ -19,9 +65,22 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
* invokes the configured A2A tool.
|
||||
*/
|
||||
export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
AgentInputs,
|
||||
RemoteAgentInputs,
|
||||
ToolResult
|
||||
> {
|
||||
// Persist state across ephemeral invocation instances.
|
||||
private static readonly sessionState = new Map<
|
||||
string,
|
||||
{ contextId?: string; taskId?: string }
|
||||
>();
|
||||
// State for the ongoing conversation with the remote agent
|
||||
private contextId: string | undefined;
|
||||
private taskId: string | undefined;
|
||||
// TODO: See if we can reuse the singleton from AppContainer or similar, but for now use getInstance directly
|
||||
// as per the current pattern in the codebase.
|
||||
private readonly clientManager = A2AClientManager.getInstance();
|
||||
private readonly authHandler = new ADCHandler();
|
||||
|
||||
constructor(
|
||||
private readonly definition: RemoteAgentDefinition,
|
||||
params: AgentInputs,
|
||||
@@ -29,8 +88,15 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
const query = params['query'];
|
||||
if (typeof query !== 'string') {
|
||||
throw new Error(
|
||||
`Remote agent '${definition.name}' requires a string 'query' input.`,
|
||||
);
|
||||
}
|
||||
// Safe to pass strict object to super
|
||||
super(
|
||||
params,
|
||||
{ query },
|
||||
messageBus,
|
||||
_toolName ?? definition.name,
|
||||
_toolDisplayName ?? definition.displayName,
|
||||
@@ -44,12 +110,81 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// TODO: Implement confirmation logic for remote agents.
|
||||
return false;
|
||||
// For now, always require confirmation for remote agents until we have a policy system for them.
|
||||
return {
|
||||
type: 'info',
|
||||
title: `Call Remote Agent: ${this.definition.displayName ?? this.definition.name}`,
|
||||
prompt: `This will send a message to the external agent at ${this.definition.agentCardUrl}.`,
|
||||
onConfirm: async () => {}, // No-op for now, just informational
|
||||
};
|
||||
}
|
||||
|
||||
async execute(_signal: AbortSignal): Promise<ToolResult> {
|
||||
// TODO: Implement remote agent invocation logic.
|
||||
throw new Error(`Remote agent invocation not implemented.`);
|
||||
// 1. Ensure the agent is loaded (cached by manager)
|
||||
// We assume the user has provided an access token via some mechanism (TODO),
|
||||
// or we rely on ADC.
|
||||
try {
|
||||
const priorState = RemoteAgentInvocation.sessionState.get(
|
||||
this.definition.name,
|
||||
);
|
||||
if (priorState) {
|
||||
this.contextId = priorState.contextId;
|
||||
this.taskId = priorState.taskId;
|
||||
}
|
||||
|
||||
if (!this.clientManager.getClient(this.definition.name)) {
|
||||
await this.clientManager.loadAgent(
|
||||
this.definition.name,
|
||||
this.definition.agentCardUrl,
|
||||
this.authHandler,
|
||||
);
|
||||
}
|
||||
|
||||
const message = this.params.query;
|
||||
|
||||
const response = await this.clientManager.sendMessage(
|
||||
this.definition.name,
|
||||
message,
|
||||
{
|
||||
contextId: this.contextId,
|
||||
taskId: this.taskId,
|
||||
},
|
||||
);
|
||||
|
||||
// Extracts IDs, taskID will be undefined if the task is completed/failed/canceled.
|
||||
const { contextId, taskId } = extractIdsFromResponse(response);
|
||||
|
||||
this.contextId = contextId ?? this.contextId;
|
||||
this.taskId = taskId;
|
||||
|
||||
RemoteAgentInvocation.sessionState.set(this.definition.name, {
|
||||
contextId: this.contextId,
|
||||
taskId: this.taskId,
|
||||
});
|
||||
|
||||
// Extract the output text
|
||||
const resultData = response;
|
||||
let outputText = '';
|
||||
|
||||
if (resultData.kind === 'message') {
|
||||
outputText = extractMessageText(resultData);
|
||||
} else if (resultData.kind === 'task') {
|
||||
outputText = extractTaskText(resultData);
|
||||
} else {
|
||||
outputText = JSON.stringify(resultData);
|
||||
}
|
||||
|
||||
return {
|
||||
llmContent: [{ text: outputText }],
|
||||
returnDisplay: outputText,
|
||||
};
|
||||
} catch (error: unknown) {
|
||||
const errorMessage = `Error calling remote agent: ${error instanceof Error ? error.message : String(error)}`;
|
||||
return {
|
||||
llmContent: [{ text: errorMessage }],
|
||||
returnDisplay: errorMessage,
|
||||
error: { message: errorMessage },
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,11 @@ export interface OutputObject {
|
||||
*/
|
||||
export type AgentInputs = Record<string, unknown>;
|
||||
|
||||
/**
|
||||
* Simplified input structure for Remote Agents, which consumes a single string query.
|
||||
*/
|
||||
export type RemoteAgentInputs = { query: string };
|
||||
|
||||
/**
|
||||
* Structured events emitted during subagent execution for user observability.
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user