refactor(core): address PR comments on a2aUtils and tests

This commit is contained in:
Alisa Novikova
2026-03-10 11:18:28 -07:00
parent 9027b1ad84
commit 25aa217306
2 changed files with 156 additions and 135 deletions
+30 -17
View File
@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { describe, it, expect, vi, beforeEach } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { import {
extractMessageText, extractMessageText,
extractIdsFromResponse, extractIdsFromResponse,
@@ -27,6 +27,7 @@ import type {
TaskArtifactUpdateEvent, TaskArtifactUpdateEvent,
} from '@a2a-js/sdk'; } from '@a2a-js/sdk';
import * as dnsPromises from 'node:dns/promises'; import * as dnsPromises from 'node:dns/promises';
import type { LookupAddress } from 'node:dns';
vi.mock('node:dns/promises', () => ({ vi.mock('node:dns/promises', () => ({
lookup: vi.fn(), lookup: vi.fn(),
@@ -37,6 +38,10 @@ describe('a2aUtils', () => {
vi.clearAllMocks(); vi.clearAllMocks();
}); });
afterEach(() => {
vi.restoreAllMocks();
});
describe('getGrpcCredentials', () => { describe('getGrpcCredentials', () => {
it('should return secure credentials for https', () => { it('should return secure credentials for https', () => {
const credentials = getGrpcCredentials('https://test.agent'); const credentials = getGrpcCredentials('https://test.agent');
@@ -51,10 +56,12 @@ describe('a2aUtils', () => {
describe('pinUrlToIp', () => { describe('pinUrlToIp', () => {
it('should resolve and pin hostname to IP', async () => { it('should resolve and pin hostname to IP', async () => {
vi.mocked(dnsPromises.lookup).mockResolvedValue([ vi.mocked(
{ address: '93.184.216.34', family: 4 }, dnsPromises.lookup as unknown as (
// eslint-disable-next-line @typescript-eslint/no-explicit-any hostname: string,
] as any); options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp( const { pinnedUrl, hostname } = await pinUrlToIp(
'http://example.com:9000', 'http://example.com:9000',
@@ -65,10 +72,12 @@ describe('a2aUtils', () => {
}); });
it('should handle raw host:port strings (standard for gRPC)', async () => { it('should handle raw host:port strings (standard for gRPC)', async () => {
vi.mocked(dnsPromises.lookup).mockResolvedValue([ vi.mocked(
{ address: '93.184.216.34', family: 4 }, dnsPromises.lookup as unknown as (
// eslint-disable-next-line @typescript-eslint/no-explicit-any hostname: string,
] as any); options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp( const { pinnedUrl, hostname } = await pinUrlToIp(
'example.com:9000', 'example.com:9000',
@@ -87,10 +96,12 @@ describe('a2aUtils', () => {
}); });
it('should throw error if resolved to private IP', async () => { it('should throw error if resolved to private IP', async () => {
vi.mocked(dnsPromises.lookup).mockResolvedValue([ vi.mocked(
{ address: '10.0.0.1', family: 4 }, dnsPromises.lookup as unknown as (
// eslint-disable-next-line @typescript-eslint/no-explicit-any hostname: string,
] as any); options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '10.0.0.1', family: 4 }]);
await expect( await expect(
pinUrlToIp('http://malicious.com', 'test-agent'), pinUrlToIp('http://malicious.com', 'test-agent'),
@@ -98,10 +109,12 @@ describe('a2aUtils', () => {
}); });
it('should allow localhost/127.0.0.1/::1 exceptions', async () => { it('should allow localhost/127.0.0.1/::1 exceptions', async () => {
vi.mocked(dnsPromises.lookup).mockResolvedValue([ vi.mocked(
{ address: '127.0.0.1', family: 4 }, dnsPromises.lookup as unknown as (
// eslint-disable-next-line @typescript-eslint/no-explicit-any hostname: string,
] as any); options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '127.0.0.1', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp( const { pinnedUrl, hostname } = await pinUrlToIp(
'http://localhost:9000', 'http://localhost:9000',
+126 -118
View File
@@ -6,6 +6,7 @@
import * as grpc from '@grpc/grpc-js'; import * as grpc from '@grpc/grpc-js';
import { lookup } from 'node:dns/promises'; import { lookup } from 'node:dns/promises';
import { z } from 'zod';
import type { import type {
Message, Message,
Part, Part,
@@ -14,17 +15,40 @@ import type {
FilePart, FilePart,
Artifact, Artifact,
TaskState, TaskState,
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
AgentCard, AgentCard,
AgentInterface, AgentInterface,
Task,
} from '@a2a-js/sdk'; } from '@a2a-js/sdk';
import { isAddressPrivate } from '../utils/fetch.js'; import { isAddressPrivate } from '../utils/fetch.js';
import type { SendMessageResult } from './a2a-client-manager.js'; import type { SendMessageResult } from './a2a-client-manager.js';
export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`; export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`;
const AgentInterfaceSchema = z
.object({
url: z.string().default(''),
transport: z.string().optional(),
protocolBinding: z.string().optional(),
})
.passthrough();
const AgentCardSchema = z
.object({
name: z.string().default('unknown'),
description: z.string().default(''),
url: z.string().default(''),
version: z.string().default(''),
protocolVersion: z.string().default(''),
capabilities: z.record(z.unknown()).default({}),
skills: z.array(z.union([z.string(), z.record(z.unknown())])).default([]),
defaultInputModes: z.array(z.string()).default([]),
defaultOutputModes: z.array(z.string()).default([]),
additionalInterfaces: z.array(AgentInterfaceSchema).optional(),
supportedInterfaces: z.array(AgentInterfaceSchema).optional(),
preferredTransport: z.string().optional(),
})
.passthrough();
/** /**
* Reassembles incremental A2A streaming updates into a coherent result. * Reassembles incremental A2A streaming updates into a coherent result.
* Shows sequential status/messages followed by all reassembled artifacts. * Shows sequential status/messages followed by all reassembled artifacts.
@@ -40,68 +64,79 @@ export class A2AResultReassembler {
update(chunk: SendMessageResult) { update(chunk: SendMessageResult) {
if (!('kind' in chunk)) return; if (!('kind' in chunk)) return;
if (isStatusUpdateEvent(chunk)) { switch (chunk.kind) {
this.appendStateInstructions(chunk.status?.state); case 'status-update':
this.pushMessage(chunk.status?.message); this.appendStateInstructions(chunk.status?.state);
} else if (isArtifactUpdateEvent(chunk)) { this.pushMessage(chunk.status?.message);
if (chunk.artifact) { break;
const id = chunk.artifact.artifactId;
const existing = this.artifacts.get(id);
if (chunk.append && existing) { case 'artifact-update':
for (const part of chunk.artifact.parts) { if (chunk.artifact) {
existing.parts.push(structuredClone(part)); const id = chunk.artifact.artifactId;
const existing = this.artifacts.get(id);
if (chunk.append && existing) {
for (const part of chunk.artifact.parts) {
existing.parts.push(structuredClone(part));
}
} else {
this.artifacts.set(id, structuredClone(chunk.artifact));
} }
} else {
this.artifacts.set(id, structuredClone(chunk.artifact));
}
const newText = extractPartsText(chunk.artifact.parts, ''); const newText = extractPartsText(chunk.artifact.parts, '');
let chunks = this.artifactChunks.get(id); let chunks = this.artifactChunks.get(id);
if (!chunks) { if (!chunks) {
chunks = []; chunks = [];
this.artifactChunks.set(id, chunks); this.artifactChunks.set(id, chunks);
}
if (chunk.append) {
chunks.push(newText);
} else {
chunks.length = 0;
chunks.push(newText);
}
} }
if (chunk.append) { break;
chunks.push(newText);
} else { case 'task':
chunks.length = 0; this.appendStateInstructions(chunk.status?.state);
chunks.push(newText); this.pushMessage(chunk.status?.message);
if (chunk.artifacts) {
for (const art of chunk.artifacts) {
this.artifacts.set(art.artifactId, structuredClone(art));
this.artifactChunks.set(art.artifactId, [
extractPartsText(art.parts, ''),
]);
}
} }
} // History Fallback: Some agent implementations do not populate the
} else if (isTask(chunk)) { // status.message in their final terminal response, instead archiving
this.appendStateInstructions(chunk.status?.state); // the final answer in the task's history array. To ensure we don't
this.pushMessage(chunk.status?.message); // present an empty result, we fallback to the most recent agent message
if (chunk.artifacts) { // in the history only when the task is terminal and no other content
for (const art of chunk.artifacts) { // (message log or artifacts) has been reassembled.
this.artifacts.set(art.artifactId, structuredClone(art)); if (
this.artifactChunks.set(art.artifactId, [ isTerminalState(chunk.status?.state) &&
extractPartsText(art.parts, ''), this.messageLog.length === 0 &&
]); this.artifacts.size === 0 &&
chunk.history &&
chunk.history.length > 0
) {
const lastAgentMsg = [...chunk.history]
.reverse()
.find((m) => m.role?.toLowerCase().includes('agent'));
if (lastAgentMsg) {
this.pushMessage(lastAgentMsg);
}
} }
} break;
// History Fallback: Some agent implementations do not populate the
// status.message in their final terminal response, instead archiving case 'message':
// the final answer in the task's history array. To ensure we don't this.pushMessage(chunk);
// present an empty result, we fallback to the most recent agent message break;
// in the history only when the task is terminal and no other content default:
// (message log or artifacts) has been reassembled. // Handle unknown kinds gracefully
if ( break;
isTerminalState(chunk.status?.state) &&
this.messageLog.length === 0 &&
this.artifacts.size === 0 &&
chunk.history &&
chunk.history.length > 0
) {
const lastAgentMsg = [...chunk.history]
.reverse()
.find((m) => m.role?.toLowerCase().includes('agent'));
if (lastAgentMsg) {
this.pushMessage(lastAgentMsg);
}
}
} else if (isMessage(chunk)) {
this.pushMessage(chunk);
} }
} }
@@ -214,28 +249,20 @@ export function normalizeAgentCard(card: unknown): AgentCard {
throw new Error('Agent card is missing.'); throw new Error('Agent card is missing.');
} }
// Double-cast to bypass strict linter while bootstrapping the object. // Use Zod to validate and parse the card, ensuring safe defaults and narrowing types.
const parsed = AgentCardSchema.parse(card);
// Narrowing to AgentCard interface after runtime validation.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const result = { ...card } as unknown as AgentCard; const result = parsed as unknown as AgentCard;
// Ensure mandatory fields exist with safe defaults.
if (typeof result.name !== 'string') result.name = 'unknown';
if (typeof result.description !== 'string') result.description = '';
if (typeof result.url !== 'string') result.url = '';
if (typeof result.version !== 'string') result.version = '';
if (typeof result.protocolVersion !== 'string') result.protocolVersion = '';
if (!isObject(result.capabilities)) result.capabilities = {};
if (!Array.isArray(result.skills)) result.skills = [];
if (!Array.isArray(result.defaultInputModes)) result.defaultInputModes = [];
if (!Array.isArray(result.defaultOutputModes)) result.defaultOutputModes = [];
// Normalize interfaces and synchronize both interface fields. // Normalize interfaces and synchronize both interface fields.
const normalizedInterfaces = extractNormalizedInterfaces(card); const normalizedInterfaces = extractNormalizedInterfaces(parsed);
result.additionalInterfaces = normalizedInterfaces; result.additionalInterfaces = normalizedInterfaces;
// Sync supportedInterfaces for backward compatibility.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
(result as unknown as Record<string, AgentInterface[]>)[ const legacyResult = result as unknown as Record<string, AgentInterface[]>;
'supportedInterfaces' legacyResult['supportedInterfaces'] = normalizedInterfaces;
] = normalizedInterfaces;
// Fallback preferredTransport: If not specified, default to GRPC if available. // Fallback preferredTransport: If not specified, default to GRPC if available.
if ( if (
@@ -387,26 +414,33 @@ export function extractIdsFromResponse(result: SendMessageResult): {
let taskId: string | undefined; let taskId: string | undefined;
let clearTaskId = false; let clearTaskId = false;
if ('kind' in result) { if (!('kind' in result)) return { contextId, taskId, clearTaskId };
const kind = result.kind;
if (kind === 'message' || isArtifactUpdateEvent(result)) { switch (result.kind) {
case 'message':
case 'artifact-update':
taskId = result.taskId; taskId = result.taskId;
contextId = result.contextId; contextId = result.contextId;
} else if (kind === 'task') { break;
case 'task':
taskId = result.id; taskId = result.id;
contextId = result.contextId; contextId = result.contextId;
if (isTerminalState(result.status?.state)) { if (isTerminalState(result.status?.state)) {
clearTaskId = true; clearTaskId = true;
} }
} else if (isStatusUpdateEvent(result)) { break;
case 'status-update':
taskId = result.taskId; taskId = result.taskId;
contextId = result.contextId; contextId = result.contextId;
// Note: We ignore the 'final' flag here per A2A protocol best practices,
// as a stream can close while a task is still in a 'working' state.
if (isTerminalState(result.status?.state)) { if (isTerminalState(result.status?.state)) {
clearTaskId = true; clearTaskId = true;
} }
} break;
default:
// Handle other kind values if any
break;
} }
return { contextId, taskId, clearTaskId }; return { contextId, taskId, clearTaskId };
@@ -430,26 +464,20 @@ function extractNormalizedInterfaces(
const mapped: AgentInterface[] = []; const mapped: AgentInterface[] = [];
for (const i of rawInterfaces) { for (const i of rawInterfaces) {
if (isObject(i)) { if (isObject(i)) {
// Create a copy to preserve all original fields. // Use schema to validate interface object.
const parsed = AgentInterfaceSchema.parse(i);
// Narrowing to AgentInterface after runtime validation.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const normalized = { ...i } as unknown as AgentInterface & { const normalized = parsed as unknown as AgentInterface & {
protocolBinding?: string; protocolBinding?: string;
}; };
// Ensure 'url' exists // Normalize 'transport' from 'protocolBinding' if missing.
if (typeof normalized.url !== 'string') { if (!normalized.transport && normalized.protocolBinding) {
normalized.url = ''; normalized.transport = normalized.protocolBinding;
}
// Normalize 'transport' from 'protocolBinding'
const transport = normalized.transport || normalized.protocolBinding;
if (transport) {
normalized.transport = transport;
} }
// Robust URL: Ensure the URL has a scheme (except for gRPC). // Robust URL: Ensure the URL has a scheme (except for gRPC).
// Some agent implementations (like a2a-go samples) may provide raw IP:port strings.
// gRPC targets MUST NOT have a scheme (e.g. 'http://'), or they will fail name resolution.
if ( if (
normalized.url && normalized.url &&
!normalized.url.includes('://') && !normalized.url.includes('://') &&
@@ -460,7 +488,7 @@ function extractNormalizedInterfaces(
normalized.url = `http://${normalized.url}`; normalized.url = `http://${normalized.url}`;
} }
mapped.push(normalized); mapped.push(normalized as AgentInterface);
} }
} }
return mapped; return mapped;
@@ -491,26 +519,6 @@ function isFilePart(part: Part): part is FilePart {
return part.kind === 'file'; return part.kind === 'file';
} }
function isStatusUpdateEvent(
result: SendMessageResult,
): result is TaskStatusUpdateEvent {
return result.kind === 'status-update';
}
function isArtifactUpdateEvent(
result: SendMessageResult,
): result is TaskArtifactUpdateEvent {
return result.kind === 'artifact-update';
}
function isMessage(result: SendMessageResult): result is Message {
return result.kind === 'message';
}
function isTask(result: SendMessageResult): result is Task {
return result.kind === 'task';
}
/** /**
* Returns true if the given state is a terminal state for a task. * Returns true if the given state is a terminal state for a task.
*/ */