From d034c4dd96a133fdebb9fcfe77cca9f6fcaae686 Mon Sep 17 00:00:00 2001 From: Adam Weidman Date: Thu, 26 Mar 2026 15:56:14 -0400 Subject: [PATCH] test(core): add unit tests for LocalSubagentProtocol and RemoteSubagentProtocol Adds 58 tests covering lifecycle events, activity translation, config buffering, session state persistence, auth setup, error handling, abort, and concurrent send() guard for both new protocol implementations. Also fixes a timing bug in RemoteSubagentProtocol where _sessionState was persisted in a finally block, causing it to be written after the result promise settled. Moved _saveSessionState() calls to run synchronously before each _resultResolve/_resultReject so callers see updated contextId/taskId immediately after getResult() settles. --- .../agents/local-subagent-protocol.test.ts | 776 ++++++++++++++++++ .../agents/remote-subagent-protocol.test.ts | 776 ++++++++++++++++++ .../src/agents/remote-subagent-protocol.ts | 19 +- 3 files changed, 1566 insertions(+), 5 deletions(-) create mode 100644 packages/core/src/agents/local-subagent-protocol.test.ts create mode 100644 packages/core/src/agents/remote-subagent-protocol.test.ts diff --git a/packages/core/src/agents/local-subagent-protocol.test.ts b/packages/core/src/agents/local-subagent-protocol.test.ts new file mode 100644 index 0000000000..7c45561d22 --- /dev/null +++ b/packages/core/src/agents/local-subagent-protocol.test.ts @@ -0,0 +1,776 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { LocalSubagentSession } from './local-subagent-protocol.js'; +import { LocalAgentExecutor } from './local-executor.js'; +import { + AgentTerminateMode, + type LocalAgentDefinition, + type SubagentActivityEvent, +} from './types.js'; +import { makeFakeConfig } from '../test-utils/config.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; +import type { AgentLoopContext } from '../config/agent-loop-context.js'; +import type { AgentEvent } from '../agent/types.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import type { z } from 'zod'; +import type { Mocked } from 'vitest'; + +vi.mock('./local-executor.js'); + +const MockLocalAgentExecutor = vi.mocked(LocalAgentExecutor); + +// Captures the onActivity callback passed to LocalAgentExecutor.create(). +// Set via create.mockImplementation in beforeEach to avoid mock.calls index fragility. +let capturedOnActivity: ((activity: SubagentActivityEvent) => void) | undefined; + +const testDefinition: LocalAgentDefinition = { + kind: 'local', + name: 'TestProtocolAgent', + description: 'A test agent for protocol tests.', + inputConfig: { + inputSchema: { + type: 'object', + properties: { + task: { type: 'string' }, + priority: { type: 'number' }, + }, + }, + }, + modelConfig: { model: 'test', generateContentConfig: {} }, + runConfig: { maxTimeMinutes: 1 }, + promptConfig: { systemPrompt: 'test' }, +}; + +const GOAL_OUTPUT = { + result: 'Analysis complete.', + terminate_reason: AgentTerminateMode.GOAL, +}; + +describe('LocalSubagentSession (protocol)', () => { + let mockContext: AgentLoopContext; + let mockMessageBus: MessageBus; + let mockExecutorInstance: Mocked>; + + beforeEach(() => { + vi.clearAllMocks(); + capturedOnActivity = undefined; + + mockContext = makeFakeConfig() as unknown as AgentLoopContext; + mockMessageBus = createMockMessageBus(); + + mockExecutorInstance = { + run: vi.fn().mockResolvedValue(GOAL_OUTPUT), + definition: testDefinition, + } as unknown as Mocked>; + + // Use mockImplementation (not mockResolvedValue) so we can capture onActivity. + MockLocalAgentExecutor.create.mockImplementation( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + async (_def: any, _ctx: any, onActivity: any) => { + capturedOnActivity = onActivity; + + return mockExecutorInstance as unknown as LocalAgentExecutor; + }, + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + // --------------------------------------------------------------------------- + // Lifecycle events + // --------------------------------------------------------------------------- + + describe('lifecycle events', () => { + it('emits agent_start then agent_end(completed) for a GOAL run', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + + await session.send({ message: [{ type: 'text', text: 'query' }] }); + await session.getResult(); + + expect(events[0].type).toBe('agent_start'); + expect(events[events.length - 1].type).toBe('agent_end'); + const endEvent = events[events.length - 1]; + if (endEvent.type === 'agent_end') { + expect(endEvent.reason).toBe('completed'); + } + }); + + it('emits agent_start exactly once even if ensureAgentStart called twice internally', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + + await session.send({ message: [{ type: 'text', text: 'query' }] }); + await session.getResult(); + + const startEvents = events.filter((e) => e.type === 'agent_start'); + expect(startEvents).toHaveLength(1); + }); + + it('emits agent_end exactly once on error path', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + mockExecutorInstance.run.mockRejectedValue(new Error('executor failed')); + + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + + await session.send({ message: [{ type: 'text', text: 'query' }] }); + await expect(session.getResult()).rejects.toThrow('executor failed'); + + const endEvents = events.filter((e) => e.type === 'agent_end'); + expect(endEvents).toHaveLength(1); + }); + + it('all events share the same streamId', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + + await session.send({ message: [{ type: 'text', text: 'query' }] }); + await session.getResult(); + + const streamIds = new Set(events.map((e) => e.streamId)); + expect(streamIds.size).toBe(1); + }); + }); + + // --------------------------------------------------------------------------- + // Config buffering (update + message pattern) + // --------------------------------------------------------------------------- + + describe('config buffering', () => { + it('merges buffered config with message query', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + await session.send({ + update: { config: { task: 'analyze', priority: 5 } }, + }); + await session.send({ message: [{ type: 'text', text: 'my query' }] }); + await session.getResult(); + + expect(mockExecutorInstance.run).toHaveBeenCalledWith( + { task: 'analyze', priority: 5, query: 'my query' }, + expect.any(AbortSignal), + ); + }); + + it('omits query key when message text is empty', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + await session.send({ update: { config: { task: 'no-query-task' } } }); + await session.send({ message: [{ type: 'text', text: '' }] }); + await session.getResult(); + + const callArgs = mockExecutorInstance.run.mock.calls[0][0]; + expect(callArgs).not.toHaveProperty('query'); + expect(callArgs).toEqual({ task: 'no-query-task' }); + }); + + it('sends only query when no prior update', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + await session.send({ message: [{ type: 'text', text: 'just a query' }] }); + await session.getResult(); + + expect(mockExecutorInstance.run).toHaveBeenCalledWith( + { query: 'just a query' }, + expect.any(AbortSignal), + ); + }); + + it('multiple update calls are merged', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + await session.send({ update: { config: { field1: 'a' } } }); + await session.send({ update: { config: { field2: 'b' } } }); + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult(); + + expect(mockExecutorInstance.run).toHaveBeenCalledWith( + { field1: 'a', field2: 'b', query: 'q' }, + expect.any(AbortSignal), + ); + }); + + it('update returns streamId: null; message returns a streamId', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + const updateResult = await session.send({ update: { config: {} } }); + expect(updateResult.streamId).toBeNull(); + + const messageResult = await session.send({ + message: [{ type: 'text', text: 'q' }], + }); + expect(messageResult.streamId).not.toBeNull(); + expect(typeof messageResult.streamId).toBe('string'); + + // Await completion to prevent dangling execution affecting subsequent tests + await session.getResult(); + }); + }); + + // --------------------------------------------------------------------------- + // Activity translation + // --------------------------------------------------------------------------- + + describe('activity translation', () => { + function makeSession() { + const activityEvents: SubagentActivityEvent[] = []; + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + return { session, activityEvents }; + } + + async function runWithActivities( + session: LocalSubagentSession, + activities: SubagentActivityEvent[], + ) { + mockExecutorInstance.run.mockImplementation(async () => { + // capturedOnActivity is set by the create.mockImplementation in beforeEach + // and updated whenever create() is called. By the time run() is called, + // capturedOnActivity holds the onActivity closure for the most-recently + // created executor — which is the one associated with this session. + for (const act of activities) { + capturedOnActivity?.(act); + } + return GOAL_OUTPUT; + }); + + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult(); + return events; + } + + it('THOUGHT_CHUNK → message event with thought content', async () => { + const { session } = makeSession(); + const events = await runWithActivities(session, [ + { + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + type: 'THOUGHT_CHUNK', + data: { text: 'I am thinking...' }, + }, + ]); + + const msgEvent = events.find((e) => e.type === 'message'); + expect(msgEvent).toBeDefined(); + if (msgEvent?.type === 'message') { + expect(msgEvent.role).toBe('agent'); + expect(msgEvent.content).toContainEqual({ + type: 'thought', + thought: 'I am thinking...', + }); + } + }); + + it('TOOL_CALL_START → tool_request event', async () => { + const { session } = makeSession(); + const events = await runWithActivities(session, [ + { + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + type: 'TOOL_CALL_START', + data: { callId: 'call-123', name: 'read_file', args: { path: '/a' } }, + }, + ]); + + const reqEvent = events.find((e) => e.type === 'tool_request'); + expect(reqEvent).toBeDefined(); + if (reqEvent?.type === 'tool_request') { + expect(reqEvent.requestId).toBe('call-123'); + expect(reqEvent.name).toBe('read_file'); + expect(reqEvent.args).toEqual({ path: '/a' }); + } + }); + + it('TOOL_CALL_END → tool_response event', async () => { + const { session } = makeSession(); + const events = await runWithActivities(session, [ + { + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + type: 'TOOL_CALL_END', + data: { id: 'call-123', name: 'read_file', output: 'file contents' }, + }, + ]); + + const respEvent = events.find((e) => e.type === 'tool_response'); + expect(respEvent).toBeDefined(); + if (respEvent?.type === 'tool_response') { + expect(respEvent.requestId).toBe('call-123'); + expect(respEvent.name).toBe('read_file'); + expect(respEvent.content).toContainEqual({ + type: 'text', + text: 'file contents', + }); + } + }); + + it('ERROR activity → error event with INTERNAL status, fatal: false', async () => { + const { session } = makeSession(); + const events = await runWithActivities(session, [ + { + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + type: 'ERROR', + data: { error: 'something went wrong' }, + }, + ]); + + const errEvent = events.find((e) => e.type === 'error'); + expect(errEvent).toBeDefined(); + if (errEvent?.type === 'error') { + expect(errEvent.status).toBe('INTERNAL'); + expect(errEvent.message).toBe('something went wrong'); + expect(errEvent.fatal).toBe(false); + } + }); + + it('unknown activity type → no events emitted', async () => { + const { session } = makeSession(); + const events = await runWithActivities(session, [ + { + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + // eslint-disable-next-line @typescript-eslint/no-explicit-any + type: 'UNKNOWN_TYPE' as any, + data: {}, + }, + ]); + + // Only agent_start and agent_end should be present + const nonLifecycle = events.filter( + (e) => e.type !== 'agent_start' && e.type !== 'agent_end', + ); + expect(nonLifecycle).toHaveLength(0); + }); + + it('TOOL_CALL_START with non-object args defaults to {}', async () => { + const { session } = makeSession(); + const events = await runWithActivities(session, [ + { + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + type: 'TOOL_CALL_START', + data: { callId: 'x', name: 'tool', args: null }, + }, + ]); + + const reqEvent = events.find((e) => e.type === 'tool_request'); + if (reqEvent?.type === 'tool_request') { + expect(reqEvent.args).toEqual({}); + } + }); + }); + + // --------------------------------------------------------------------------- + // getResult() promise + // --------------------------------------------------------------------------- + + describe('getResult()', () => { + it('resolves with OutputObject on GOAL termination', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + const output = await session.getResult(); + + expect(output.result).toBe('Analysis complete.'); + expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL); + }); + + it('rejects when executor throws', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + mockExecutorInstance.run.mockRejectedValue(new Error('executor error')); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await expect(session.getResult()).rejects.toThrow('executor error'); + }); + }); + + // --------------------------------------------------------------------------- + // rawActivityCallback + // --------------------------------------------------------------------------- + + describe('rawActivityCallback', () => { + it('receives raw SubagentActivityEvent before AgentEvent translation', async () => { + const rawActivities: SubagentActivityEvent[] = []; + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + (activity) => rawActivities.push(activity), + ); + + const thoughtActivity: SubagentActivityEvent = { + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + type: 'THOUGHT_CHUNK', + data: { text: 'raw thought' }, + }; + + mockExecutorInstance.run.mockImplementation(async () => { + const onActivity = MockLocalAgentExecutor.create.mock.calls[0]?.[2]; + onActivity?.(thoughtActivity); + return GOAL_OUTPUT; + }); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult(); + + expect(rawActivities).toHaveLength(1); + expect(rawActivities[0]).toBe(thoughtActivity); + }); + + it('is called before AgentEvent translation (raw arrives first)', async () => { + const callOrder: string[] = []; + + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + () => callOrder.push('raw'), + ); + + session.subscribe((e) => { + if (e.type === 'message') callOrder.push('translated'); + }); + + mockExecutorInstance.run.mockImplementation(async () => { + const onActivity = MockLocalAgentExecutor.create.mock.calls[0]?.[2]; + onActivity?.({ + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + type: 'THOUGHT_CHUNK', + data: { text: 'thought' }, + }); + return GOAL_OUTPUT; + }); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult(); + + expect(callOrder).toEqual(['raw', 'translated']); + }); + + it('is optional — no callback causes no error', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + // no rawActivityCallback + ); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await expect(session.getResult()).resolves.toBeDefined(); + }); + }); + + // --------------------------------------------------------------------------- + // Subscription + // --------------------------------------------------------------------------- + + describe('subscription', () => { + it('unsubscribe stops event delivery', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + const received: AgentEvent[] = []; + const unsub = session.subscribe((e) => received.push(e)); + unsub(); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult(); + + expect(received).toHaveLength(0); + }); + + it('multiple subscribers all receive events', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + const received1: AgentEvent[] = []; + const received2: AgentEvent[] = []; + session.subscribe((e) => received1.push(e)); + session.subscribe((e) => received2.push(e)); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult(); + + expect(received1.length).toBeGreaterThan(0); + expect(received1).toEqual(received2); + }); + + it('events array accumulates all emitted events', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult(); + + expect(session.events.length).toBeGreaterThanOrEqual(2); // at least agent_start + agent_end + expect(session.events[0].type).toBe('agent_start'); + }); + }); + + // --------------------------------------------------------------------------- + // Terminate mode mapping + // --------------------------------------------------------------------------- + + describe('terminate mode → StreamEndReason mapping', () => { + const cases: Array<[AgentTerminateMode, string]> = [ + [AgentTerminateMode.GOAL, 'completed'], + [AgentTerminateMode.TIMEOUT, 'max_time'], + [AgentTerminateMode.MAX_TURNS, 'max_turns'], + [AgentTerminateMode.ABORTED, 'aborted'], + [AgentTerminateMode.ERROR, 'failed'], + [AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL, 'failed'], + ]; + + for (const [terminateMode, expectedReason] of cases) { + it(`${terminateMode} → agent_end(reason:'${expectedReason}')`, async () => { + mockExecutorInstance.run.mockResolvedValue({ + result: 'done', + terminate_reason: terminateMode, + }); + + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult().catch(() => { + // ABORTED results in rejection — catch to let test complete + }); + + const endEvent = events.find((e) => e.type === 'agent_end'); + expect(endEvent).toBeDefined(); + if (endEvent?.type === 'agent_end') { + expect(endEvent.reason).toBe(expectedReason); + } + }); + } + }); + + // --------------------------------------------------------------------------- + // Abort + // --------------------------------------------------------------------------- + + describe('abort()', () => { + it('abort() causes agent_end(reason:aborted)', async () => { + // Make run() wait until aborted + let abortSignal: AbortSignal | undefined; + mockExecutorInstance.run.mockImplementation( + (_params: unknown, signal: AbortSignal) => { + abortSignal = signal; + return new Promise((_resolve, reject) => { + signal.addEventListener('abort', () => { + const err = new Error('AbortError'); + err.name = 'AbortError'; + reject(err); + }); + }); + }, + ); + + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + + void session.send({ message: [{ type: 'text', text: 'q' }] }); + + // Wait for executor to be created and run started + await vi.waitFor(() => { + expect(abortSignal).toBeDefined(); + }); + + await session.abort(); + + await expect(session.getResult()).rejects.toThrow(); + + const endEvent = events.find((e) => e.type === 'agent_end'); + expect(endEvent).toBeDefined(); + if (endEvent?.type === 'agent_end') { + expect(endEvent.reason).toBe('aborted'); + } + }); + }); + + // --------------------------------------------------------------------------- + // Full event sequence + // --------------------------------------------------------------------------- + + describe('full event sequence', () => { + it('emits agent_start → message(thought) → tool_request → tool_response → agent_end in order', async () => { + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + mockExecutorInstance.run.mockImplementation(async () => { + const onActivity = MockLocalAgentExecutor.create.mock.calls[0]?.[2]; + onActivity?.({ + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + type: 'THOUGHT_CHUNK', + data: { text: 'thinking' }, + }); + onActivity?.({ + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + type: 'TOOL_CALL_START', + data: { callId: 'c1', name: 'tool', args: {} }, + }); + onActivity?.({ + isSubagentActivityEvent: true, + agentName: 'TestProtocolAgent', + type: 'TOOL_CALL_END', + data: { id: 'c1', name: 'tool', output: 'result' }, + }); + return GOAL_OUTPUT; + }); + + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + + await session.send({ message: [{ type: 'text', text: 'go' }] }); + await session.getResult(); + + const types = events.map((e) => e.type); + expect(types).toEqual([ + 'agent_start', + 'message', + 'tool_request', + 'tool_response', + 'agent_end', + ]); + }); + }); + + // --------------------------------------------------------------------------- + // Concurrent send() guard + // --------------------------------------------------------------------------- + + describe('concurrent send() guard', () => { + it('calling send() while a stream is active throws', async () => { + let abortSignal: AbortSignal | undefined; + mockExecutorInstance.run.mockImplementation( + (_params: unknown, signal: AbortSignal) => { + abortSignal = signal; + return new Promise((_resolve, reject) => { + // Reject when aborted so getResult() can settle during cleanup + signal.addEventListener('abort', () => { + const err = new Error('AbortError'); + err.name = 'AbortError'; + reject(err); + }); + }); + }, + ); + + const session = new LocalSubagentSession( + testDefinition, + mockContext, + mockMessageBus, + ); + + void session.send({ message: [{ type: 'text', text: 'first' }] }); + + // Wait for execution to start + await vi.waitFor(() => { + expect(abortSignal).toBeDefined(); + }); + + // Second send() while first stream is active must throw + await expect( + session.send({ message: [{ type: 'text', text: 'second' }] }), + ).rejects.toThrow('cannot be called while a stream is active'); + + // Clean up: abort to unblock the hanging executor + await session.abort(); + await session.getResult().catch(() => {}); + }); + }); +}); diff --git a/packages/core/src/agents/remote-subagent-protocol.test.ts b/packages/core/src/agents/remote-subagent-protocol.test.ts new file mode 100644 index 0000000000..317b1ad856 --- /dev/null +++ b/packages/core/src/agents/remote-subagent-protocol.test.ts @@ -0,0 +1,776 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import { RemoteSubagentSession } from './remote-subagent-protocol.js'; + +import { A2AAuthProviderFactory } from './auth-provider/factory.js'; +import type { RemoteAgentDefinition, SubagentProgress } from './types.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; +import type { AgentLoopContext } from '../config/agent-loop-context.js'; +import type { AgentEvent } from '../agent/types.js'; +import type { Config } from '../config/config.js'; +import type { A2AAuthProvider } from './auth-provider/types.js'; + +// Mock A2AClientManager at module level +vi.mock('./a2a-client-manager.js', () => ({ + A2AClientManager: vi.fn().mockImplementation(() => ({ + getClient: vi.fn(), + loadAgent: vi.fn(), + sendMessageStream: vi.fn(), + })), +})); + +// Mock A2AAuthProviderFactory +vi.mock('./auth-provider/factory.js', () => ({ + A2AAuthProviderFactory: { + create: vi.fn(), + }, +})); + +const mockDefinition: RemoteAgentDefinition = { + name: 'test-remote-agent', + kind: 'remote', + agentCardUrl: 'http://test-agent/card', + displayName: 'Test Remote Agent', + description: 'A test remote agent', + inputConfig: { + inputSchema: { type: 'object' }, + }, +}; + +function makeChunk(text: string) { + return { + kind: 'message' as const, + messageId: `msg-${Math.random()}`, + role: 'agent' as const, + parts: [{ kind: 'text' as const, text }], + }; +} + +describe('RemoteSubagentSession (protocol)', () => { + let mockClientManager: { + getClient: Mock; + loadAgent: Mock; + sendMessageStream: Mock; + }; + let mockContext: AgentLoopContext; + let mockMessageBus: ReturnType; + + beforeEach(() => { + vi.clearAllMocks(); + + // Static session state is not cleared between tests — each test uses a + // unique agent name to avoid cross-test contamination. + + mockClientManager = { + getClient: vi.fn().mockReturnValue(undefined), // client not yet loaded + loadAgent: vi.fn().mockResolvedValue(undefined), + sendMessageStream: vi.fn(), + }; + + const mockConfig = { + getA2AClientManager: vi.fn().mockReturnValue(mockClientManager), + injectionService: { + getLatestInjectionIndex: vi.fn().mockReturnValue(0), + }, + } as unknown as Config; + + mockContext = { config: mockConfig } as unknown as AgentLoopContext; + mockMessageBus = createMockMessageBus(); + + // Default: sendMessageStream yields one chunk with "Hello" + mockClientManager.sendMessageStream.mockImplementation(async function* () { + yield makeChunk('Hello'); + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + // Helper: run a session with the default or custom stream and collect events + async function runSession( + definition: RemoteAgentDefinition = mockDefinition, + query = 'test query', + ) { + const session = new RemoteSubagentSession( + definition, + mockContext, + mockMessageBus, + ); + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + await session.send({ message: [{ type: 'text', text: query }] }); + const result = await session.getResult(); + return { session, events, result }; + } + + // --------------------------------------------------------------------------- + // Lifecycle events + // --------------------------------------------------------------------------- + + describe('lifecycle events', () => { + it('emits agent_start then agent_end(completed) on success', async () => { + const { events } = await runSession(); + + const types = events.map((e) => e.type); + expect(types[0]).toBe('agent_start'); + expect(types[types.length - 1]).toBe('agent_end'); + const end = events[events.length - 1]; + if (end.type === 'agent_end') { + expect(end.reason).toBe('completed'); + } + }); + + it('emits agent_start exactly once', async () => { + const { events } = await runSession(); + expect(events.filter((e) => e.type === 'agent_start')).toHaveLength(1); + }); + + it('emits agent_end exactly once on error path', async () => { + mockClientManager.sendMessageStream.mockReturnValue({ + [Symbol.asyncIterator]() { + return { + async next(): Promise> { + throw new Error('stream error'); + }, + }; + }, + }); + + const session = new RemoteSubagentSession( + mockDefinition, + mockContext, + mockMessageBus, + ); + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await expect(session.getResult()).rejects.toThrow('stream error'); + + expect(events.filter((e) => e.type === 'agent_end')).toHaveLength(1); + }); + + it('all events share the same streamId', async () => { + const { events } = await runSession(); + const streamIds = new Set(events.map((e) => e.streamId)); + expect(streamIds.size).toBe(1); + }); + + it('message returns a non-null streamId; unsupported payload returns null', async () => { + const session = new RemoteSubagentSession( + mockDefinition, + mockContext, + mockMessageBus, + ); + const updateResult = await session.send({ + update: { config: { key: 'val' } }, + }); + expect(updateResult.streamId).toBeNull(); + + const messageResult = await session.send({ + message: [{ type: 'text', text: 'q' }], + }); + expect(messageResult.streamId).not.toBeNull(); + // complete the session to avoid dangling execution + await session.getResult(); + }); + }); + + // --------------------------------------------------------------------------- + // Chunk → AgentEvent translation + // --------------------------------------------------------------------------- + + describe('chunk → AgentEvent translation', () => { + it('each A2A chunk produces a message event with current accumulated text', async () => { + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + yield makeChunk('Hello'); + yield makeChunk(' world'); + }, + ); + + const { events } = await runSession(); + + const msgEvents = events.filter((e) => e.type === 'message'); + expect(msgEvents.length).toBeGreaterThanOrEqual(1); + // Final message event should contain the accumulated text + const lastMsg = msgEvents[msgEvents.length - 1]; + if (lastMsg?.type === 'message') { + const textContent = lastMsg.content.find((c) => c.type === 'text'); + expect(textContent).toBeDefined(); + if (textContent?.type === 'text') { + expect(textContent.text).toContain('Hello'); + } + } + }); + + it('getLatestProgress() is updated per chunk with state running', async () => { + let capturedProgress: SubagentProgress | undefined; + + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + yield makeChunk('Partial'); + }, + ); + + const session = new RemoteSubagentSession( + mockDefinition, + mockContext, + mockMessageBus, + ); + session.subscribe((e) => { + if (e.type === 'message') { + capturedProgress = session.getLatestProgress(); + } + }); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult(); + + // During streaming, progress should be 'running' + expect(capturedProgress).toBeDefined(); + // Note: by the time we check, progress may be 'completed'. + // During the message event, it was 'running'. + expect(capturedProgress?.isSubagentProgress).toBe(true); + expect(capturedProgress?.agentName).toBe('Test Remote Agent'); + }); + + it('getLatestProgress() state is completed after getResult() resolves', async () => { + const { session } = await runSession(); + const progress = session.getLatestProgress(); + expect(progress?.state).toBe('completed'); + expect(progress?.result).toBe('Hello'); + }); + }); + + // --------------------------------------------------------------------------- + // getResult() promise + // --------------------------------------------------------------------------- + + describe('getResult()', () => { + it('resolves with ToolResult containing llmContent and SubagentProgress returnDisplay', async () => { + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + yield makeChunk('Result text'); + }, + ); + + const { result } = await runSession(); + + expect(result.llmContent).toEqual([{ text: 'Result text' }]); + const display = result.returnDisplay as SubagentProgress; + expect(display.isSubagentProgress).toBe(true); + expect(display.state).toBe('completed'); + expect(display.result).toBe('Result text'); + expect(display.agentName).toBe('Test Remote Agent'); + }); + + it('rejects when stream throws a non-A2A error', async () => { + mockClientManager.sendMessageStream.mockReturnValue({ + [Symbol.asyncIterator]() { + return { + async next(): Promise> { + throw new Error('network failure'); + }, + }; + }, + }); + + const session = new RemoteSubagentSession( + mockDefinition, + mockContext, + mockMessageBus, + ); + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await expect(session.getResult()).rejects.toThrow(); + }); + + it('resolves even with empty stream (empty final output)', async () => { + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + // yield nothing + }, + ); + + const { result } = await runSession(); + expect(result.llmContent).toEqual([{ text: '' }]); + }); + }); + + // --------------------------------------------------------------------------- + // Session state persistence + // --------------------------------------------------------------------------- + + describe('session state persistence', () => { + it('second call reuses contextId captured from first call', async () => { + const agentName = 'persistent-agent'; + const persistDef: RemoteAgentDefinition = { + ...mockDefinition, + name: agentName, + }; + + let callCount = 0; + mockClientManager.sendMessageStream.mockImplementation(async function* ( + _name: string, + _query: string, + opts: { contextId?: string }, + ) { + callCount++; + if (callCount === 1) { + // First call: return a chunk that yields a contextId + yield { + kind: 'message' as const, + messageId: 'msg-1', + role: 'agent' as const, + contextId: 'ctx-from-server', + parts: [{ kind: 'text' as const, text: 'First response' }], + }; + } else { + // Second call: caller should have passed the contextId + expect(opts.contextId).toBe('ctx-from-server'); + yield makeChunk('Second response'); + } + }); + + // First call + const session1 = new RemoteSubagentSession( + persistDef, + mockContext, + mockMessageBus, + ); + await session1.send({ message: [{ type: 'text', text: 'first' }] }); + await session1.getResult(); + + // Second call — different session but same agent name → should reuse contextId + const session2 = new RemoteSubagentSession( + persistDef, + mockContext, + mockMessageBus, + ); + await session2.send({ message: [{ type: 'text', text: 'second' }] }); + await session2.getResult(); + + expect(callCount).toBe(2); + }); + + it('different agent names have independent session state', async () => { + const def1: RemoteAgentDefinition = { + ...mockDefinition, + name: 'agent-alpha', + }; + const def2: RemoteAgentDefinition = { + ...mockDefinition, + name: 'agent-beta', + }; + + const capturedContextIds: Array = []; + mockClientManager.sendMessageStream.mockImplementation(async function* ( + _name: string, + _query: string, + opts: { contextId?: string }, + ) { + capturedContextIds.push(opts.contextId); + yield { + kind: 'message' as const, + messageId: 'msg-1', + role: 'agent' as const, + contextId: `ctx-for-${_name}`, + parts: [{ kind: 'text' as const, text: 'ok' }], + }; + }); + + const session1 = new RemoteSubagentSession( + def1, + mockContext, + mockMessageBus, + ); + await session1.send({ message: [{ type: 'text', text: 'q' }] }); + await session1.getResult(); + + const session2 = new RemoteSubagentSession( + def2, + mockContext, + mockMessageBus, + ); + await session2.send({ message: [{ type: 'text', text: 'q' }] }); + await session2.getResult(); + + // Both start with no contextId (different agents, different state entries) + expect(capturedContextIds[0]).toBeUndefined(); + expect(capturedContextIds[1]).toBeUndefined(); + }); + + it('taskId is cleared when a terminal-state task chunk is received', async () => { + // A task chunk with a terminal status sets clearTaskId=true, which + // should clear this.taskId so it is NOT passed on the next call. + const agentName = 'clearTaskId-agent'; + const def: RemoteAgentDefinition = { ...mockDefinition, name: agentName }; + + let callCount = 0; + const capturedTaskIds: Array = []; + + mockClientManager.sendMessageStream.mockImplementation(async function* ( + _n: string, + _q: string, + opts: { taskId?: string }, + ) { + callCount++; + capturedTaskIds.push(opts.taskId); + if (callCount === 1) { + // First call: yield a task chunk with taskId + terminal status → clearTaskId + yield { + kind: 'task' as const, + id: 'task-123', + contextId: 'ctx-1', + status: { state: 'completed' as const }, + }; + } else { + yield makeChunk('done'); + } + }); + + const session1 = new RemoteSubagentSession( + def, + mockContext, + mockMessageBus, + ); + await session1.send({ message: [{ type: 'text', text: 'first' }] }); + await session1.getResult(); + + const session2 = new RemoteSubagentSession( + def, + mockContext, + mockMessageBus, + ); + await session2.send({ message: [{ type: 'text', text: 'second' }] }); + await session2.getResult(); + + expect(callCount).toBe(2); + // First call starts with no taskId + expect(capturedTaskIds[0]).toBeUndefined(); + // Second call: taskId was cleared because terminal-state task chunk was received + expect(capturedTaskIds[1]).toBeUndefined(); + }); + }); + + // --------------------------------------------------------------------------- + // Auth setup + // --------------------------------------------------------------------------- + + describe('auth setup', () => { + it('no auth → loadAgent called without auth handler', async () => { + await runSession(); + + expect(mockClientManager.loadAgent).toHaveBeenCalledWith( + 'test-remote-agent', + { type: 'url', url: 'http://test-agent/card' }, + undefined, + ); + }); + + it('definition.auth present → A2AAuthProviderFactory.create called', async () => { + const authDef: RemoteAgentDefinition = { + ...mockDefinition, + name: 'auth-agent', + auth: { + type: 'http' as const, + scheme: 'Bearer' as const, + token: 'secret', + }, + }; + + const mockProvider = { + type: 'http' as const, + headers: vi.fn().mockResolvedValue({ Authorization: 'Bearer secret' }), + shouldRetryWithHeaders: vi.fn(), + } as unknown as A2AAuthProvider; + (A2AAuthProviderFactory.create as Mock).mockResolvedValue(mockProvider); + + await runSession(authDef, 'q'); + + expect(A2AAuthProviderFactory.create).toHaveBeenCalledWith( + expect.objectContaining({ + agentName: 'auth-agent', + agentCardUrl: 'http://test-agent/card', + }), + ); + expect(mockClientManager.loadAgent).toHaveBeenCalledWith( + 'auth-agent', + expect.any(Object), + mockProvider, + ); + }); + + it('auth factory returns undefined → throws error that rejects getResult()', async () => { + const authDef: RemoteAgentDefinition = { + ...mockDefinition, + name: 'failing-auth-agent', + auth: { + type: 'http' as const, + scheme: 'Bearer' as const, + token: 'secret', + }, + }; + + (A2AAuthProviderFactory.create as Mock).mockResolvedValue(undefined); + + const session = new RemoteSubagentSession( + authDef, + mockContext, + mockMessageBus, + ); + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await expect(session.getResult()).rejects.toThrow( + "Failed to create auth provider for agent 'failing-auth-agent'", + ); + }); + + it('agent already loaded → loadAgent not called again', async () => { + // Return a client object (truthy) so getClient returns defined + mockClientManager.getClient.mockReturnValue({}); + + await runSession(); + + expect(mockClientManager.loadAgent).not.toHaveBeenCalled(); + }); + }); + + // --------------------------------------------------------------------------- + // Error handling + // --------------------------------------------------------------------------- + + describe('error handling', () => { + it('stream error → error event + agent_end(failed)', async () => { + mockClientManager.sendMessageStream.mockReturnValue({ + [Symbol.asyncIterator]() { + return { + async next(): Promise> { + throw new Error('network error'); + }, + }; + }, + }); + + const session = new RemoteSubagentSession( + mockDefinition, + mockContext, + mockMessageBus, + ); + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await expect(session.getResult()).rejects.toThrow(); + + const errEvent = events.find((e) => e.type === 'error'); + expect(errEvent).toBeDefined(); + + const endEvent = events.find((e) => e.type === 'agent_end'); + expect(endEvent).toBeDefined(); + if (endEvent?.type === 'agent_end') { + expect(endEvent.reason).toBe('failed'); + } + }); + + it('missing A2AClientManager → rejects getResult()', async () => { + const mockConfig = { + getA2AClientManager: vi.fn().mockReturnValue(undefined), + injectionService: { + getLatestInjectionIndex: vi.fn().mockReturnValue(0), + }, + } as unknown as Config; + const noClientContext = { + config: mockConfig, + } as unknown as AgentLoopContext; + + const session = new RemoteSubagentSession( + mockDefinition, + noClientContext, + mockMessageBus, + ); + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await expect(session.getResult()).rejects.toThrow( + 'A2AClientManager not available', + ); + }); + }); + + // --------------------------------------------------------------------------- + // Subscription + // --------------------------------------------------------------------------- + + describe('subscription', () => { + it('unsubscribe stops event delivery', async () => { + const session = new RemoteSubagentSession( + mockDefinition, + mockContext, + mockMessageBus, + ); + const received: AgentEvent[] = []; + const unsub = session.subscribe((e) => received.push(e)); + unsub(); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult(); + + expect(received).toHaveLength(0); + }); + + it('multiple subscribers all receive events', async () => { + const session = new RemoteSubagentSession( + mockDefinition, + mockContext, + mockMessageBus, + ); + const events1: AgentEvent[] = []; + const events2: AgentEvent[] = []; + session.subscribe((e) => events1.push(e)); + session.subscribe((e) => events2.push(e)); + + await session.send({ message: [{ type: 'text', text: 'q' }] }); + await session.getResult(); + + expect(events1.length).toBeGreaterThan(0); + expect(events1).toEqual(events2); + }); + }); + + // --------------------------------------------------------------------------- + // Abort + // --------------------------------------------------------------------------- + + describe('abort()', () => { + it('abort() causes agent_end(reason:aborted)', async () => { + let resolveChunk: (() => void) | undefined; + + // Stream that blocks until we abort + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + // Hang until aborted + await new Promise((resolve) => { + resolveChunk = resolve; + }); + yield makeChunk('Too late'); + }, + ); + + const session = new RemoteSubagentSession( + mockDefinition, + mockContext, + mockMessageBus, + ); + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + + void session.send({ message: [{ type: 'text', text: 'q' }] }); + + // Wait for agent_start to be emitted before aborting + await vi.waitFor(() => { + expect(events.some((e) => e.type === 'agent_start')).toBe(true); + }); + + await session.abort(); + + // Resolve the hanging chunk generator so it can check the signal + resolveChunk?.(); + + await expect(session.getResult()).rejects.toThrow(); + + const endEvent = events.find((e) => e.type === 'agent_end'); + expect(endEvent).toBeDefined(); + if (endEvent?.type === 'agent_end') { + expect(endEvent.reason).toBe('aborted'); + } + }); + }); + + // --------------------------------------------------------------------------- + // sendMessageStream call args + // --------------------------------------------------------------------------- + + describe('sendMessageStream call arguments', () => { + it('passes the query string from the message payload', async () => { + await runSession(mockDefinition, 'my specific query'); + + expect(mockClientManager.sendMessageStream).toHaveBeenCalledWith( + 'test-remote-agent', + 'my specific query', + expect.objectContaining({ signal: expect.any(Object) }), + ); + }); + + it('uses DEFAULT_QUERY_STRING when message text is empty', async () => { + const session = new RemoteSubagentSession( + mockDefinition, + mockContext, + mockMessageBus, + ); + await session.send({ message: [{ type: 'text', text: '' }] }); + await session.getResult(); + + // DEFAULT_QUERY_STRING = 'Get Started!' + expect(mockClientManager.sendMessageStream).toHaveBeenCalledWith( + 'test-remote-agent', + 'Get Started!', + expect.objectContaining({ signal: expect.any(Object) }), + ); + }); + }); + + // --------------------------------------------------------------------------- + // Concurrent send() guard + // --------------------------------------------------------------------------- + + describe('concurrent send() guard', () => { + it('calling send() while a stream is active throws', async () => { + let resolveChunk!: () => void; + + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + // Block until test releases the chunk + await new Promise((resolve) => { + resolveChunk = resolve; + }); + yield makeChunk('late'); + }, + ); + + const session = new RemoteSubagentSession( + mockDefinition, + mockContext, + mockMessageBus, + ); + + void session.send({ message: [{ type: 'text', text: 'first' }] }); + + // Wait for the stream to actually start (agent_start emitted) + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + await vi.waitFor(() => { + expect(events.some((e) => e.type === 'agent_start')).toBe(true); + }); + + // Second send() while first stream is active must throw + await expect( + session.send({ message: [{ type: 'text', text: 'second' }] }), + ).rejects.toThrow('cannot be called while a stream is active'); + + // Clean up: release the blocked generator so getResult() can settle + resolveChunk(); + await session.getResult().catch(() => {}); + }); + }); +}); diff --git a/packages/core/src/agents/remote-subagent-protocol.ts b/packages/core/src/agents/remote-subagent-protocol.ts index faeb29051d..825e67c5a9 100644 --- a/packages/core/src/agents/remote-subagent-protocol.ts +++ b/packages/core/src/agents/remote-subagent-protocol.ts @@ -180,6 +180,9 @@ class RemoteSubagentProtocol implements AgentProtocol { try { await this._runStream(query); } catch (err: unknown) { + // Save state before rejecting so callers see updated contextId/taskId + // immediately after the returned promise settles. + this._saveSessionState(); if (this._abortController.signal.aborted || isAbortLikeError(err)) { this._ensureAgentEnd('aborted'); this._resultReject(err); @@ -188,14 +191,16 @@ class RemoteSubagentProtocol implements AgentProtocol { this._resultReject(err); } this._clearActiveStream(); - } finally { - RemoteSubagentProtocol._sessionState.set(this.definition.name, { - contextId: this.contextId, - taskId: this.taskId, - }); } } + private _saveSessionState(): void { + RemoteSubagentProtocol._sessionState.set(this.definition.name, { + contextId: this.contextId, + taskId: this.taskId, + }); + } + private async _runStream(query: string): Promise { const clientManager = this.context.config.getA2AClientManager(); if (!clientManager) { @@ -228,6 +233,7 @@ class RemoteSubagentProtocol implements AgentProtocol { for await (const chunk of stream) { if (this._abortController.signal.aborted) { + this._saveSessionState(); this._finishStream('aborted'); this._resultReject(new Error('Operation aborted')); this._clearActiveStream(); @@ -283,6 +289,9 @@ class RemoteSubagentProtocol implements AgentProtocol { this._finishStream('completed'); + // Save state before resolving so callers see updated contextId/taskId + // immediately after the returned promise settles. + this._saveSessionState(); this._resultResolve({ llmContent: [{ text: finalOutput }], returnDisplay: finalProgress,