mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 14:10:37 -07:00
feat(a2a): implement standardized normalization and streaming reassembly
This commit is contained in:
@@ -4,13 +4,17 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import {
|
||||
extractMessageText,
|
||||
extractIdsFromResponse,
|
||||
isTerminalState,
|
||||
A2AResultReassembler,
|
||||
AUTH_REQUIRED_MSG,
|
||||
normalizeAgentCard,
|
||||
getGrpcCredentials,
|
||||
pinUrlToIp,
|
||||
splitAgentCardUrl,
|
||||
} from './a2aUtils.js';
|
||||
import type { SendMessageResult } from './a2a-client-manager.js';
|
||||
import type {
|
||||
@@ -22,8 +26,92 @@ import type {
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
} from '@a2a-js/sdk';
|
||||
import * as dnsPromises from 'node:dns/promises';
|
||||
|
||||
vi.mock('node:dns/promises', () => ({
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('a2aUtils', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('getGrpcCredentials', () => {
|
||||
it('should return secure credentials for https', () => {
|
||||
const credentials = getGrpcCredentials('https://test.agent');
|
||||
expect(credentials).toBeDefined();
|
||||
});
|
||||
|
||||
it('should return insecure credentials for http', () => {
|
||||
const credentials = getGrpcCredentials('http://test.agent');
|
||||
expect(credentials).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('pinUrlToIp', () => {
|
||||
it('should resolve and pin hostname to IP', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockResolvedValue([
|
||||
{ address: '93.184.216.34', family: 4 },
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
] as any);
|
||||
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(
|
||||
'http://example.com:9000',
|
||||
'test-agent',
|
||||
);
|
||||
expect(hostname).toBe('example.com');
|
||||
expect(pinnedUrl).toBe('http://93.184.216.34:9000/');
|
||||
});
|
||||
|
||||
it('should handle raw host:port strings (standard for gRPC)', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockResolvedValue([
|
||||
{ address: '93.184.216.34', family: 4 },
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
] as any);
|
||||
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(
|
||||
'example.com:9000',
|
||||
'test-agent',
|
||||
);
|
||||
expect(hostname).toBe('example.com');
|
||||
expect(pinnedUrl).toBe('93.184.216.34:9000');
|
||||
});
|
||||
|
||||
it('should throw error if resolution fails (fail closed)', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
|
||||
|
||||
await expect(
|
||||
pinUrlToIp('http://unreachable.com', 'test-agent'),
|
||||
).rejects.toThrow("Failed to resolve host for agent 'test-agent'");
|
||||
});
|
||||
|
||||
it('should throw error if resolved to private IP', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockResolvedValue([
|
||||
{ address: '10.0.0.1', family: 4 },
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
] as any);
|
||||
|
||||
await expect(
|
||||
pinUrlToIp('http://malicious.com', 'test-agent'),
|
||||
).rejects.toThrow('resolves to private IP range');
|
||||
});
|
||||
|
||||
it('should allow localhost/127.0.0.1/::1 exceptions', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockResolvedValue([
|
||||
{ address: '127.0.0.1', family: 4 },
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
] as any);
|
||||
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(
|
||||
'http://localhost:9000',
|
||||
'test-agent',
|
||||
);
|
||||
expect(hostname).toBe('localhost');
|
||||
expect(pinnedUrl).toBe('http://127.0.0.1:9000/');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTerminalState', () => {
|
||||
it('should return true for completed, failed, canceled, and rejected', () => {
|
||||
expect(isTerminalState('completed')).toBe(true);
|
||||
@@ -223,6 +311,173 @@ describe('a2aUtils', () => {
|
||||
} as Message),
|
||||
).toBe('');
|
||||
});
|
||||
|
||||
it('should handle file parts with neither name nor uri', () => {
|
||||
const message: Message = {
|
||||
kind: 'message',
|
||||
role: 'user',
|
||||
messageId: '1',
|
||||
parts: [
|
||||
{
|
||||
kind: 'file',
|
||||
file: {
|
||||
mimeType: 'text/plain',
|
||||
},
|
||||
} as FilePart,
|
||||
],
|
||||
};
|
||||
expect(extractMessageText(message)).toBe('File: [binary/unnamed]');
|
||||
});
|
||||
});
|
||||
|
||||
describe('normalizeAgentCard', () => {
|
||||
it('should throw if input is not an object', () => {
|
||||
expect(() => normalizeAgentCard(null)).toThrow('Agent card is missing.');
|
||||
expect(() => normalizeAgentCard(undefined)).toThrow(
|
||||
'Agent card is missing.',
|
||||
);
|
||||
expect(() => normalizeAgentCard('not an object')).toThrow(
|
||||
'Agent card is missing.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should preserve unknown fields while providing defaults for mandatory ones', () => {
|
||||
const raw = {
|
||||
name: 'my-agent',
|
||||
customField: 'keep-me',
|
||||
};
|
||||
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
|
||||
expect(normalized.name).toBe('my-agent');
|
||||
// @ts-expect-error - testing dynamic preservation
|
||||
expect(normalized.customField).toBe('keep-me');
|
||||
expect(normalized.description).toBe('');
|
||||
expect(normalized.skills).toEqual([]);
|
||||
expect(normalized.defaultInputModes).toEqual([]);
|
||||
});
|
||||
|
||||
it('should normalize and synchronize interfaces while preserving other fields', () => {
|
||||
const raw = {
|
||||
name: 'test',
|
||||
supportedInterfaces: [
|
||||
{
|
||||
url: 'grpc://test',
|
||||
protocolBinding: 'GRPC',
|
||||
protocolVersion: '1.0',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
|
||||
// Should exist in both fields
|
||||
expect(normalized.additionalInterfaces).toHaveLength(1);
|
||||
expect(
|
||||
(normalized as unknown as Record<string, unknown>)[
|
||||
'supportedInterfaces'
|
||||
],
|
||||
).toHaveLength(1);
|
||||
|
||||
const intf = normalized.additionalInterfaces?.[0] as unknown as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
|
||||
expect(intf['transport']).toBe('GRPC');
|
||||
expect(intf['url']).toBe('grpc://test');
|
||||
|
||||
// Should fallback top-level url
|
||||
expect(normalized.url).toBe('grpc://test');
|
||||
});
|
||||
|
||||
it('should preserve existing top-level url if present', () => {
|
||||
const raw = {
|
||||
name: 'test',
|
||||
url: 'http://existing',
|
||||
supportedInterfaces: [{ url: 'http://other', transport: 'REST' }],
|
||||
};
|
||||
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
expect(normalized.url).toBe('http://existing');
|
||||
});
|
||||
|
||||
it('should NOT prepend http:// scheme to raw IP:port strings for gRPC interfaces', () => {
|
||||
const raw = {
|
||||
name: 'raw-ip-grpc',
|
||||
supportedInterfaces: [{ url: '127.0.0.1:9000', transport: 'GRPC' }],
|
||||
};
|
||||
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
expect(normalized.additionalInterfaces?.[0].url).toBe('127.0.0.1:9000');
|
||||
expect(normalized.url).toBe('127.0.0.1:9000');
|
||||
});
|
||||
|
||||
it('should prepend http:// scheme to raw IP:port strings for REST interfaces', () => {
|
||||
const raw = {
|
||||
name: 'raw-ip-rest',
|
||||
supportedInterfaces: [{ url: '127.0.0.1:8080', transport: 'REST' }],
|
||||
};
|
||||
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
expect(normalized.additionalInterfaces?.[0].url).toBe(
|
||||
'http://127.0.0.1:8080',
|
||||
);
|
||||
});
|
||||
|
||||
it('should NOT override existing transport if protocolBinding is also present', () => {
|
||||
const raw = {
|
||||
name: 'priority-test',
|
||||
supportedInterfaces: [
|
||||
{ url: 'foo', transport: 'GRPC', protocolBinding: 'REST' },
|
||||
],
|
||||
};
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
expect(normalized.additionalInterfaces?.[0].transport).toBe('GRPC');
|
||||
});
|
||||
});
|
||||
|
||||
describe('splitAgentCardUrl', () => {
|
||||
const standard = '.well-known/agent-card.json';
|
||||
|
||||
it('should return baseUrl as-is if it does not end with standard path', () => {
|
||||
const url = 'http://localhost:9001/custom/path';
|
||||
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
|
||||
});
|
||||
|
||||
it('should split correctly if URL ends with standard path', () => {
|
||||
const url = `http://localhost:9001/${standard}`;
|
||||
expect(splitAgentCardUrl(url)).toEqual({
|
||||
baseUrl: 'http://localhost:9001/',
|
||||
path: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle trailing slash in baseUrl when splitting', () => {
|
||||
const url = `http://example.com/api/${standard}`;
|
||||
expect(splitAgentCardUrl(url)).toEqual({
|
||||
baseUrl: 'http://example.com/api/',
|
||||
path: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should ignore hashes and query params when splitting', () => {
|
||||
const url = `http://localhost:9001/${standard}?foo=bar#baz`;
|
||||
expect(splitAgentCardUrl(url)).toEqual({
|
||||
baseUrl: 'http://localhost:9001/',
|
||||
path: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should return original URL if parsing fails', () => {
|
||||
const url = 'not-a-url';
|
||||
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
|
||||
});
|
||||
|
||||
it('should handle standard path appearing earlier in the path', () => {
|
||||
const url = `http://localhost:9001/${standard}/something-else`;
|
||||
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
|
||||
});
|
||||
});
|
||||
|
||||
describe('A2AResultReassembler', () => {
|
||||
@@ -233,6 +488,7 @@ describe('a2aUtils', () => {
|
||||
reassembler.update({
|
||||
kind: 'status-update',
|
||||
taskId: 't1',
|
||||
contextId: 'ctx1',
|
||||
status: {
|
||||
state: 'working',
|
||||
message: {
|
||||
@@ -247,6 +503,7 @@ describe('a2aUtils', () => {
|
||||
reassembler.update({
|
||||
kind: 'artifact-update',
|
||||
taskId: 't1',
|
||||
contextId: 'ctx1',
|
||||
append: false,
|
||||
artifact: {
|
||||
artifactId: 'a1',
|
||||
@@ -259,6 +516,7 @@ describe('a2aUtils', () => {
|
||||
reassembler.update({
|
||||
kind: 'status-update',
|
||||
taskId: 't1',
|
||||
contextId: 'ctx1',
|
||||
status: {
|
||||
state: 'working',
|
||||
message: {
|
||||
@@ -273,6 +531,7 @@ describe('a2aUtils', () => {
|
||||
reassembler.update({
|
||||
kind: 'artifact-update',
|
||||
taskId: 't1',
|
||||
contextId: 'ctx1',
|
||||
append: true,
|
||||
artifact: {
|
||||
artifactId: 'a1',
|
||||
@@ -291,6 +550,7 @@ describe('a2aUtils', () => {
|
||||
|
||||
reassembler.update({
|
||||
kind: 'status-update',
|
||||
contextId: 'ctx1',
|
||||
status: {
|
||||
state: 'auth-required',
|
||||
message: {
|
||||
@@ -310,6 +570,7 @@ describe('a2aUtils', () => {
|
||||
|
||||
reassembler.update({
|
||||
kind: 'status-update',
|
||||
contextId: 'ctx1',
|
||||
status: {
|
||||
state: 'auth-required',
|
||||
},
|
||||
@@ -323,6 +584,7 @@ describe('a2aUtils', () => {
|
||||
|
||||
const chunk = {
|
||||
kind: 'status-update',
|
||||
contextId: 'ctx1',
|
||||
status: {
|
||||
state: 'auth-required',
|
||||
message: {
|
||||
@@ -351,6 +613,8 @@ describe('a2aUtils', () => {
|
||||
|
||||
reassembler.update({
|
||||
kind: 'task',
|
||||
id: 'task-1',
|
||||
contextId: 'ctx1',
|
||||
status: { state: 'completed' },
|
||||
history: [
|
||||
{
|
||||
@@ -369,6 +633,8 @@ describe('a2aUtils', () => {
|
||||
|
||||
reassembler.update({
|
||||
kind: 'task',
|
||||
id: 'task-1',
|
||||
contextId: 'ctx1',
|
||||
status: { state: 'working' },
|
||||
history: [
|
||||
{
|
||||
@@ -387,6 +653,8 @@ describe('a2aUtils', () => {
|
||||
|
||||
reassembler.update({
|
||||
kind: 'task',
|
||||
id: 'task-1',
|
||||
contextId: 'ctx1',
|
||||
status: { state: 'completed' },
|
||||
artifacts: [
|
||||
{
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as grpc from '@grpc/grpc-js';
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import type {
|
||||
Message,
|
||||
Part,
|
||||
@@ -13,7 +15,12 @@ import type {
|
||||
Artifact,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
AgentCard,
|
||||
AgentInterface,
|
||||
Task,
|
||||
} from '@a2a-js/sdk';
|
||||
import { isAddressPrivate } from '../utils/fetch.js';
|
||||
import type { SendMessageResult } from './a2a-client-manager.js';
|
||||
|
||||
export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`;
|
||||
@@ -33,80 +40,68 @@ export class A2AResultReassembler {
|
||||
update(chunk: SendMessageResult) {
|
||||
if (!('kind' in chunk)) return;
|
||||
|
||||
switch (chunk.kind) {
|
||||
case 'status-update':
|
||||
this.appendStateInstructions(chunk.status?.state);
|
||||
this.pushMessage(chunk.status?.message);
|
||||
break;
|
||||
if (isStatusUpdateEvent(chunk)) {
|
||||
this.appendStateInstructions(chunk.status?.state);
|
||||
this.pushMessage(chunk.status?.message);
|
||||
} else if (isArtifactUpdateEvent(chunk)) {
|
||||
if (chunk.artifact) {
|
||||
const id = chunk.artifact.artifactId;
|
||||
const existing = this.artifacts.get(id);
|
||||
|
||||
case 'artifact-update':
|
||||
if (chunk.artifact) {
|
||||
const id = chunk.artifact.artifactId;
|
||||
const existing = this.artifacts.get(id);
|
||||
|
||||
if (chunk.append && existing) {
|
||||
for (const part of chunk.artifact.parts) {
|
||||
existing.parts.push(structuredClone(part));
|
||||
}
|
||||
} else {
|
||||
this.artifacts.set(id, structuredClone(chunk.artifact));
|
||||
}
|
||||
|
||||
const newText = extractPartsText(chunk.artifact.parts, '');
|
||||
let chunks = this.artifactChunks.get(id);
|
||||
if (!chunks) {
|
||||
chunks = [];
|
||||
this.artifactChunks.set(id, chunks);
|
||||
}
|
||||
if (chunk.append) {
|
||||
chunks.push(newText);
|
||||
} else {
|
||||
chunks.length = 0;
|
||||
chunks.push(newText);
|
||||
if (chunk.append && existing) {
|
||||
for (const part of chunk.artifact.parts) {
|
||||
existing.parts.push(structuredClone(part));
|
||||
}
|
||||
} else {
|
||||
this.artifacts.set(id, structuredClone(chunk.artifact));
|
||||
}
|
||||
break;
|
||||
|
||||
case 'task':
|
||||
this.appendStateInstructions(chunk.status?.state);
|
||||
this.pushMessage(chunk.status?.message);
|
||||
if (chunk.artifacts) {
|
||||
for (const art of chunk.artifacts) {
|
||||
this.artifacts.set(art.artifactId, structuredClone(art));
|
||||
this.artifactChunks.set(art.artifactId, [
|
||||
extractPartsText(art.parts, ''),
|
||||
]);
|
||||
}
|
||||
const newText = extractPartsText(chunk.artifact.parts, '');
|
||||
let chunks = this.artifactChunks.get(id);
|
||||
if (!chunks) {
|
||||
chunks = [];
|
||||
this.artifactChunks.set(id, chunks);
|
||||
}
|
||||
// History Fallback: Some agent implementations do not populate the
|
||||
// status.message in their final terminal response, instead archiving
|
||||
// the final answer in the task's history array. To ensure we don't
|
||||
// present an empty result, we fallback to the most recent agent message
|
||||
// in the history only when the task is terminal and no other content
|
||||
// (message log or artifacts) has been reassembled.
|
||||
if (
|
||||
isTerminalState(chunk.status?.state) &&
|
||||
this.messageLog.length === 0 &&
|
||||
this.artifacts.size === 0 &&
|
||||
chunk.history &&
|
||||
chunk.history.length > 0
|
||||
) {
|
||||
const lastAgentMsg = [...chunk.history]
|
||||
.reverse()
|
||||
.find((m) => m.role?.toLowerCase().includes('agent'));
|
||||
if (lastAgentMsg) {
|
||||
this.pushMessage(lastAgentMsg);
|
||||
}
|
||||
if (chunk.append) {
|
||||
chunks.push(newText);
|
||||
} else {
|
||||
chunks.length = 0;
|
||||
chunks.push(newText);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'message': {
|
||||
this.pushMessage(chunk);
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
} else if (isTask(chunk)) {
|
||||
this.appendStateInstructions(chunk.status?.state);
|
||||
this.pushMessage(chunk.status?.message);
|
||||
if (chunk.artifacts) {
|
||||
for (const art of chunk.artifacts) {
|
||||
this.artifacts.set(art.artifactId, structuredClone(art));
|
||||
this.artifactChunks.set(art.artifactId, [
|
||||
extractPartsText(art.parts, ''),
|
||||
]);
|
||||
}
|
||||
}
|
||||
// History Fallback: Some agent implementations do not populate the
|
||||
// status.message in their final terminal response, instead archiving
|
||||
// the final answer in the task's history array. To ensure we don't
|
||||
// present an empty result, we fallback to the most recent agent message
|
||||
// in the history only when the task is terminal and no other content
|
||||
// (message log or artifacts) has been reassembled.
|
||||
if (
|
||||
isTerminalState(chunk.status?.state) &&
|
||||
this.messageLog.length === 0 &&
|
||||
this.artifacts.size === 0 &&
|
||||
chunk.history &&
|
||||
chunk.history.length > 0
|
||||
) {
|
||||
const lastAgentMsg = [...chunk.history]
|
||||
.reverse()
|
||||
.find((m) => m.role?.toLowerCase().includes('agent'));
|
||||
if (lastAgentMsg) {
|
||||
this.pushMessage(lastAgentMsg);
|
||||
}
|
||||
}
|
||||
} else if (isMessage(chunk)) {
|
||||
this.pushMessage(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,36 +205,173 @@ function extractPartText(part: Part): string {
|
||||
return '';
|
||||
}
|
||||
|
||||
// Type Guards
|
||||
/**
|
||||
* Normalizes an agent card by ensuring it has the required properties
|
||||
* and resolving any inconsistencies between protocol versions.
|
||||
*/
|
||||
export function normalizeAgentCard(card: unknown): AgentCard {
|
||||
if (!isObject(card)) {
|
||||
throw new Error('Agent card is missing.');
|
||||
}
|
||||
|
||||
function isTextPart(part: Part): part is TextPart {
|
||||
return part.kind === 'text';
|
||||
}
|
||||
// Double-cast to bypass strict linter while bootstrapping the object.
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const result = { ...card } as unknown as AgentCard;
|
||||
|
||||
function isDataPart(part: Part): part is DataPart {
|
||||
return part.kind === 'data';
|
||||
}
|
||||
// Ensure mandatory fields exist with safe defaults.
|
||||
if (typeof result.name !== 'string') result.name = 'unknown';
|
||||
if (typeof result.description !== 'string') result.description = '';
|
||||
if (typeof result.url !== 'string') result.url = '';
|
||||
if (typeof result.version !== 'string') result.version = '';
|
||||
if (typeof result.protocolVersion !== 'string') result.protocolVersion = '';
|
||||
if (!isObject(result.capabilities)) result.capabilities = {};
|
||||
if (!Array.isArray(result.skills)) result.skills = [];
|
||||
if (!Array.isArray(result.defaultInputModes)) result.defaultInputModes = [];
|
||||
if (!Array.isArray(result.defaultOutputModes)) result.defaultOutputModes = [];
|
||||
|
||||
function isFilePart(part: Part): part is FilePart {
|
||||
return part.kind === 'file';
|
||||
}
|
||||
// Normalize interfaces and synchronize both interface fields.
|
||||
const normalizedInterfaces = extractNormalizedInterfaces(card);
|
||||
result.additionalInterfaces = normalizedInterfaces;
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
(result as unknown as Record<string, AgentInterface[]>)[
|
||||
'supportedInterfaces'
|
||||
] = normalizedInterfaces;
|
||||
|
||||
function isStatusUpdateEvent(
|
||||
result: SendMessageResult,
|
||||
): result is TaskStatusUpdateEvent {
|
||||
return result.kind === 'status-update';
|
||||
// Fallback preferredTransport: If not specified, default to GRPC if available.
|
||||
if (
|
||||
!result.preferredTransport &&
|
||||
normalizedInterfaces.some((i) => i.transport === 'GRPC')
|
||||
) {
|
||||
result.preferredTransport = 'GRPC';
|
||||
}
|
||||
|
||||
// Fallback: If top-level URL is missing, use the first interface's URL.
|
||||
if (result.url === '' && normalizedInterfaces.length > 0) {
|
||||
result.url = normalizedInterfaces[0].url;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the given state is a terminal state for a task.
|
||||
* Returns gRPC channel credentials based on the URL scheme.
|
||||
*/
|
||||
export function isTerminalState(state: TaskState | undefined): boolean {
|
||||
return (
|
||||
state === 'completed' ||
|
||||
state === 'failed' ||
|
||||
state === 'canceled' ||
|
||||
state === 'rejected'
|
||||
);
|
||||
export function getGrpcCredentials(url: string): grpc.ChannelCredentials {
|
||||
return url.startsWith('https://')
|
||||
? grpc.credentials.createSsl()
|
||||
: grpc.credentials.createInsecure();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns gRPC channel options to ensure SSL/authority matches the original hostname
|
||||
* when connecting via a pinned IP address.
|
||||
*/
|
||||
export function getGrpcChannelOptions(
|
||||
hostname: string,
|
||||
): Record<string, unknown> {
|
||||
return {
|
||||
'grpc.default_authority': hostname,
|
||||
'grpc.ssl_target_name_override': hostname,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a hostname to its IP address and validates it against SSRF.
|
||||
* Returns the pinned IP-based URL and the original hostname.
|
||||
*/
|
||||
export async function pinUrlToIp(
|
||||
url: string,
|
||||
agentName: string,
|
||||
): Promise<{ pinnedUrl: string; hostname: string }> {
|
||||
if (!url) return { pinnedUrl: url, hostname: '' };
|
||||
|
||||
// gRPC URLs in A2A can be 'host:port' or 'dns:///host:port' or have schemes.
|
||||
// We normalize to host:port for resolution.
|
||||
const hasScheme = url.includes('://');
|
||||
const normalizedUrl = hasScheme ? url : `http://${url}`;
|
||||
|
||||
try {
|
||||
const parsed = new URL(normalizedUrl);
|
||||
const hostname = parsed.hostname;
|
||||
|
||||
const sanitizedHost =
|
||||
hostname.startsWith('[') && hostname.endsWith(']')
|
||||
? hostname.slice(1, -1)
|
||||
: hostname;
|
||||
|
||||
// Resolve DNS to check the actual target IP and pin it
|
||||
const addresses = await lookup(hostname, { all: true });
|
||||
const publicAddresses = addresses.filter(
|
||||
(addr) =>
|
||||
!isAddressPrivate(addr.address) ||
|
||||
sanitizedHost === 'localhost' ||
|
||||
sanitizedHost === '127.0.0.1' ||
|
||||
sanitizedHost === '::1',
|
||||
);
|
||||
|
||||
if (publicAddresses.length === 0) {
|
||||
if (addresses.length > 0) {
|
||||
throw new Error(
|
||||
`Refusing to load agent '${agentName}': transport URL '${url}' resolves to private IP range.`,
|
||||
);
|
||||
}
|
||||
throw new Error(
|
||||
`Failed to resolve any public IP addresses for host: ${hostname}`,
|
||||
);
|
||||
}
|
||||
|
||||
const pinnedIp = publicAddresses[0].address;
|
||||
const pinnedHostname = pinnedIp.includes(':') ? `[${pinnedIp}]` : pinnedIp;
|
||||
|
||||
// Reconstruct URL with IP
|
||||
parsed.hostname = pinnedHostname;
|
||||
let pinnedUrl = parsed.toString();
|
||||
|
||||
// If original didn't have scheme, remove it (standard for gRPC targets)
|
||||
if (!hasScheme) {
|
||||
pinnedUrl = pinnedUrl.replace(/^http:\/\//, '');
|
||||
// URL.toString() might append a trailing slash
|
||||
if (pinnedUrl.endsWith('/') && !url.endsWith('/')) {
|
||||
pinnedUrl = pinnedUrl.slice(0, -1);
|
||||
}
|
||||
}
|
||||
|
||||
return { pinnedUrl, hostname };
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.message.includes('Refusing')) throw e;
|
||||
throw new Error(`Failed to resolve host for agent '${agentName}': ${url}`, {
|
||||
cause: e,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Splts an agent card URL into a baseUrl and a standard path if it already
|
||||
* contains '.well-known/agent-card.json'.
|
||||
*/
|
||||
export function splitAgentCardUrl(url: string): {
|
||||
baseUrl: string;
|
||||
path?: string;
|
||||
} {
|
||||
const standardPath = '.well-known/agent-card.json';
|
||||
try {
|
||||
const parsedUrl = new URL(url);
|
||||
if (parsedUrl.pathname.endsWith(standardPath)) {
|
||||
// Reconstruct baseUrl from parsed components to avoid issues with hashes or query params.
|
||||
parsedUrl.pathname = parsedUrl.pathname.substring(
|
||||
0,
|
||||
parsedUrl.pathname.lastIndexOf(standardPath),
|
||||
);
|
||||
parsedUrl.search = '';
|
||||
parsedUrl.hash = '';
|
||||
// We return undefined for path if it's the standard one,
|
||||
// because the SDK's DefaultAgentCardResolver appends it automatically.
|
||||
return { baseUrl: parsedUrl.toString(), path: undefined };
|
||||
}
|
||||
} catch (_e) {
|
||||
// Ignore URL parsing errors here, let the resolver handle them.
|
||||
}
|
||||
return { baseUrl: url };
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -257,7 +389,7 @@ export function extractIdsFromResponse(result: SendMessageResult): {
|
||||
|
||||
if ('kind' in result) {
|
||||
const kind = result.kind;
|
||||
if (kind === 'message' || kind === 'artifact-update') {
|
||||
if (kind === 'message' || isArtifactUpdateEvent(result)) {
|
||||
taskId = result.taskId;
|
||||
contextId = result.contextId;
|
||||
} else if (kind === 'task') {
|
||||
@@ -279,3 +411,121 @@ export function extractIdsFromResponse(result: SendMessageResult): {
|
||||
|
||||
return { contextId, taskId, clearTaskId };
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts and normalizes interfaces from the card, handling protocol version fallbacks.
|
||||
* Preserves all original fields to maintain SDK compatibility.
|
||||
*/
|
||||
function extractNormalizedInterfaces(
|
||||
card: Record<string, unknown>,
|
||||
): AgentInterface[] {
|
||||
const rawInterfaces =
|
||||
getArray(card, 'additionalInterfaces') ||
|
||||
getArray(card, 'supportedInterfaces');
|
||||
|
||||
if (!rawInterfaces) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const mapped: AgentInterface[] = [];
|
||||
for (const i of rawInterfaces) {
|
||||
if (isObject(i)) {
|
||||
// Create a copy to preserve all original fields.
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const normalized = { ...i } as unknown as AgentInterface & {
|
||||
protocolBinding?: string;
|
||||
};
|
||||
|
||||
// Ensure 'url' exists
|
||||
if (typeof normalized.url !== 'string') {
|
||||
normalized.url = '';
|
||||
}
|
||||
|
||||
// Normalize 'transport' from 'protocolBinding'
|
||||
const transport = normalized.transport || normalized.protocolBinding;
|
||||
if (transport) {
|
||||
normalized.transport = transport;
|
||||
}
|
||||
|
||||
// Robust URL: Ensure the URL has a scheme (except for gRPC).
|
||||
// Some agent implementations (like a2a-go samples) may provide raw IP:port strings.
|
||||
// gRPC targets MUST NOT have a scheme (e.g. 'http://'), or they will fail name resolution.
|
||||
if (
|
||||
normalized.url &&
|
||||
!normalized.url.includes('://') &&
|
||||
!normalized.url.startsWith('/') &&
|
||||
normalized.transport !== 'GRPC'
|
||||
) {
|
||||
// Default to http:// for insecure REST/JSON-RPC if scheme is missing.
|
||||
normalized.url = `http://${normalized.url}`;
|
||||
}
|
||||
|
||||
mapped.push(normalized);
|
||||
}
|
||||
}
|
||||
return mapped;
|
||||
}
|
||||
|
||||
/**
|
||||
* Safely extracts an array property from an object.
|
||||
*/
|
||||
function getArray(
|
||||
obj: Record<string, unknown>,
|
||||
key: string,
|
||||
): unknown[] | undefined {
|
||||
const val = obj[key];
|
||||
return Array.isArray(val) ? val : undefined;
|
||||
}
|
||||
|
||||
// Type Guards
|
||||
|
||||
function isTextPart(part: Part): part is TextPart {
|
||||
return part.kind === 'text';
|
||||
}
|
||||
|
||||
function isDataPart(part: Part): part is DataPart {
|
||||
return part.kind === 'data';
|
||||
}
|
||||
|
||||
function isFilePart(part: Part): part is FilePart {
|
||||
return part.kind === 'file';
|
||||
}
|
||||
|
||||
function isStatusUpdateEvent(
|
||||
result: SendMessageResult,
|
||||
): result is TaskStatusUpdateEvent {
|
||||
return result.kind === 'status-update';
|
||||
}
|
||||
|
||||
function isArtifactUpdateEvent(
|
||||
result: SendMessageResult,
|
||||
): result is TaskArtifactUpdateEvent {
|
||||
return result.kind === 'artifact-update';
|
||||
}
|
||||
|
||||
function isMessage(result: SendMessageResult): result is Message {
|
||||
return result.kind === 'message';
|
||||
}
|
||||
|
||||
function isTask(result: SendMessageResult): result is Task {
|
||||
return result.kind === 'task';
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the given state is a terminal state for a task.
|
||||
*/
|
||||
export function isTerminalState(state: TaskState | undefined): boolean {
|
||||
return (
|
||||
state === 'completed' ||
|
||||
state === 'failed' ||
|
||||
state === 'canceled' ||
|
||||
state === 'rejected'
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Type guard to check if a value is a non-array object.
|
||||
*/
|
||||
function isObject(val: unknown): val is Record<string, unknown> {
|
||||
return typeof val === 'object' && val !== null && !Array.isArray(val);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user