diff --git a/packages/core/src/agent/agent-session.test.ts b/packages/core/src/agent/agent-session.test.ts new file mode 100644 index 0000000000..c390d719d4 --- /dev/null +++ b/packages/core/src/agent/agent-session.test.ts @@ -0,0 +1,279 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, expect, it } from 'vitest'; +import { AgentSession } from './agent-session.js'; +import { MockAgentProtocol } from './mock.js'; +import type { AgentEvent } from './types.js'; + +describe('AgentSession', () => { + it('should passthrough simple methods', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + protocol.pushResponse([{ type: 'message' }]); + await session.send({ update: { title: 't' } }); + // update, agent_start, message, agent_end = 4 events + expect(session.events).toHaveLength(4); + + let emitted = false; + session.subscribe(() => { + emitted = true; + }); + protocol.pushResponse([]); + await session.send({ update: { title: 't' } }); + expect(emitted).toBe(true); + + protocol.pushResponse([], { keepOpen: true }); + await session.send({ update: { title: 't' } }); + await session.abort(); + expect( + session.events.some( + (e) => + e.type === 'agent_end' && + (e as AgentEvent<'agent_end'>).reason === 'aborted', + ), + ).toBe(true); + }); + + it('should yield events via sendStream', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + protocol.pushResponse([ + { + type: 'message', + role: 'agent', + content: [{ type: 'text', text: 'hello' }], + }, + ]); + + const events: AgentEvent[] = []; + for await (const event of session.sendStream({ + message: [{ type: 'text', text: 'hi' }], + })) { + events.push(event); + } + + // agent_start, agent message, agent_end = 3 events (user message skipped) + expect(events).toHaveLength(3); + expect(events[0].type).toBe('agent_start'); + expect(events[1].type).toBe('message'); + expect((events[1] as AgentEvent<'message'>).role).toBe('agent'); + expect(events[2].type).toBe('agent_end'); + }); + + it('should filter events by streamId in sendStream', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + protocol.pushResponse([{ type: 'message' }]); + + const events: AgentEvent[] = []; + const stream = session.sendStream({ update: { title: 'foo' } }); + + for await (const event of stream) { + events.push(event); + } + + expect(events).toHaveLength(3); // agent_start, message, agent_end (update skipped) + const streamId = events[0].streamId; + expect(streamId).not.toBeNull(); + expect(events.every((e) => e.streamId === streamId)).toBe(true); + }); + + it('should handle events arriving before send() resolves', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + protocol.pushResponse([{ type: 'message' }]); + + const events: AgentEvent[] = []; + for await (const event of session.sendStream({ + update: { title: 'foo' }, + })) { + events.push(event); + } + + expect(events).toHaveLength(3); // agent_start, message, agent_end (update skipped) + expect(events[0].type).toBe('agent_start'); + expect(events[1].type).toBe('message'); + expect(events[2].type).toBe('agent_end'); + }); + + it('should return immediately from sendStream if streamId is null', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + // No response queued, so send() returns streamId: null + const events: AgentEvent[] = []; + for await (const event of session.sendStream({ + update: { title: 'foo' }, + })) { + events.push(event); + } + + expect(events).toHaveLength(0); + expect(protocol.events).toHaveLength(1); + expect(protocol.events[0].type).toBe('session_update'); + }); + + it('should skip events that occur before agent_start', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + // Custom emission to ensure events happen before agent_start + protocol.pushResponse([ + { + type: 'message', + role: 'agent', + content: [{ type: 'text', text: 'hello' }], + }, + ]); + + // We can't easily inject events before agent_start with MockAgentProtocol.pushResponse + // because it emits them all together. + // But we know session_update is emitted first. + + const events: AgentEvent[] = []; + for await (const event of session.sendStream({ + message: [{ type: 'text', text: 'hi' }], + })) { + events.push(event); + } + + // The session_update (from the 'hi' message) should be skipped. + expect(events.some((e) => e.type === 'session_update')).toBe(false); + expect(events[0].type).toBe('agent_start'); + }); + + describe('stream()', () => { + it('should replay events after eventId', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + // Create some events + protocol.pushResponse([{ type: 'message' }]); + await session.send({ update: { title: 't1' } }); + // Wait for events to be emitted + await new Promise((resolve) => setTimeout(resolve, 10)); + + const allEvents = session.events; + expect(allEvents.length).toBeGreaterThan(2); + const eventId = allEvents[1].id; + + const streamedEvents: AgentEvent[] = []; + for await (const event of session.stream({ eventId })) { + streamedEvents.push(event); + } + + expect(streamedEvents).toEqual(allEvents.slice(2)); + }); + + it('should replay events for streamId starting with agent_start', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + protocol.pushResponse([{ type: 'message' }]); + const { streamId } = await session.send({ update: { title: 't1' } }); + await new Promise((resolve) => setTimeout(resolve, 10)); + + const allEvents = session.events; + const startEventIndex = allEvents.findIndex( + (e) => e.type === 'agent_start' && e.streamId === streamId, + ); + expect(startEventIndex).toBeGreaterThan(-1); + + const streamedEvents: AgentEvent[] = []; + for await (const event of session.stream({ streamId: streamId! })) { + streamedEvents.push(event); + } + + expect(streamedEvents).toEqual(allEvents.slice(startEventIndex)); + }); + + it('should continue listening for active stream after replay', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + // Start a stream but keep it open + protocol.pushResponse([{ type: 'message' }], { keepOpen: true }); + const { streamId } = await session.send({ update: { title: 't1' } }); + await new Promise((resolve) => setTimeout(resolve, 10)); + + const streamedEvents: AgentEvent[] = []; + const streamPromise = (async () => { + for await (const event of session.stream({ streamId: streamId! })) { + streamedEvents.push(event); + } + })(); + + // Push more to the stream + await new Promise((resolve) => setTimeout(resolve, 20)); + protocol.pushToStream(streamId!, [{ type: 'message' }], { close: true }); + + await streamPromise; + + const allEvents = session.events; + const startEventIndex = allEvents.findIndex( + (e) => e.type === 'agent_start' && e.streamId === streamId, + ); + expect(streamedEvents).toEqual(allEvents.slice(startEventIndex)); + expect(streamedEvents.at(-1)?.type).toBe('agent_end'); + }); + + it('should follow an active stream if no options provided', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + protocol.pushResponse([{ type: 'message' }], { keepOpen: true }); + const { streamId } = await session.send({ update: { title: 't1' } }); + await new Promise((resolve) => setTimeout(resolve, 10)); + + const streamedEvents: AgentEvent[] = []; + const streamPromise = (async () => { + for await (const event of session.stream()) { + streamedEvents.push(event); + } + })(); + + await new Promise((resolve) => setTimeout(resolve, 20)); + protocol.pushToStream(streamId!, [{ type: 'message' }], { close: true }); + await streamPromise; + + expect(streamedEvents.length).toBeGreaterThan(0); + expect(streamedEvents.at(-1)?.type).toBe('agent_end'); + }); + + it('should ONLY yield events for specific streamId even if newer streams exist', async () => { + const protocol = new MockAgentProtocol(); + const session = new AgentSession(protocol); + + // Stream 1 + protocol.pushResponse([{ type: 'message' }]); + const { streamId: streamId1 } = await session.send({ + update: { title: 's1' }, + }); + + // Stream 2 + protocol.pushResponse([{ type: 'message' }]); + const { streamId: streamId2 } = await session.send({ + update: { title: 's2' }, + }); + + await new Promise((resolve) => setTimeout(resolve, 20)); + + const streamedEvents: AgentEvent[] = []; + for await (const event of session.stream({ streamId: streamId1! })) { + streamedEvents.push(event); + } + + expect(streamedEvents.every((e) => e.streamId === streamId1)).toBe(true); + expect(streamedEvents.some((e) => e.type === 'agent_end')).toBe(true); + expect(streamedEvents.some((e) => e.streamId === streamId2)).toBe(false); + }); + }); +}); diff --git a/packages/core/src/agent/agent-session.ts b/packages/core/src/agent/agent-session.ts new file mode 100644 index 0000000000..0d9fc86bb0 --- /dev/null +++ b/packages/core/src/agent/agent-session.ts @@ -0,0 +1,212 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + AgentProtocol, + AgentSend, + AgentEvent, + Unsubscribe, +} from './types.js'; + +/** + * AgentSession is a wrapper around AgentProtocol that provides a more + * convenient API for consuming agent activity as an AsyncIterable. + */ +export class AgentSession implements AgentProtocol { + private _protocol: AgentProtocol; + + constructor(protocol: AgentProtocol) { + this._protocol = protocol; + } + + async send(payload: AgentSend): Promise<{ streamId: string | null }> { + return this._protocol.send(payload); + } + + subscribe(callback: (event: AgentEvent) => void): Unsubscribe { + return this._protocol.subscribe(callback); + } + + async abort(): Promise { + return this._protocol.abort(); + } + + get events(): AgentEvent[] { + return this._protocol.events; + } + + /** + * Sends a payload to the agent and returns an AsyncIterable that yields + * events for the resulting stream. + * + * @param payload The payload to send to the agent. + */ + async *sendStream(payload: AgentSend): AsyncIterable { + const result = await this._protocol.send(payload); + const streamId = result.streamId; + + if (streamId === null) { + return; + } + + yield* this.stream({ streamId }); + } + + /** + * Returns an AsyncIterable that yields events from the agent session, + * optionally replaying events from history or reattaching to an existing stream. + * + * @param options Options for replaying or reattaching to the event stream. + */ + async *stream( + options: { + eventId?: string; + streamId?: string; + } = {}, + ): AsyncIterable { + let resolve: (() => void) | undefined; + let next = new Promise((res) => { + resolve = res; + }); + + let eventQueue: AgentEvent[] = []; + const earlyEvents: AgentEvent[] = []; + let done = false; + let trackedStreamId = options.streamId; + let started = false; + + // 1. Subscribe early to avoid missing any events that occur during replay setup + const unsubscribe = this._protocol.subscribe((event) => { + if (done) return; + + if (!started) { + earlyEvents.push(event); + return; + } + + if (trackedStreamId && event.streamId !== trackedStreamId) return; + + // If we don't have a tracked stream yet, the first agent_start we see becomes it. + if (!trackedStreamId && event.type === 'agent_start') { + trackedStreamId = event.streamId ?? undefined; + } + + // If we still don't have a tracked stream and we aren't replaying everything (eventId), ignore. + if (!trackedStreamId && !options.eventId) return; + + eventQueue.push(event); + if ( + event.type === 'agent_end' && + event.streamId === (trackedStreamId ?? null) + ) { + done = true; + } + + const currentResolve = resolve; + next = new Promise((r) => { + resolve = r; + }); + currentResolve?.(); + }); + + try { + const currentEvents = this._protocol.events; + let replayStartIndex = -1; + + if (options.eventId) { + const index = currentEvents.findIndex((e) => e.id === options.eventId); + if (index !== -1) { + replayStartIndex = index + 1; + } + } else if (options.streamId) { + const index = currentEvents.findIndex( + (e) => e.type === 'agent_start' && e.streamId === options.streamId, + ); + if (index !== -1) { + replayStartIndex = index; + } + } + + if (replayStartIndex !== -1) { + for (let i = replayStartIndex; i < currentEvents.length; i++) { + const event = currentEvents[i]; + if (options.streamId && event.streamId !== options.streamId) continue; + + eventQueue.push(event); + if (event.type === 'agent_start' && !trackedStreamId) { + trackedStreamId = event.streamId ?? undefined; + } + if ( + event.type === 'agent_end' && + event.streamId === (trackedStreamId ?? null) + ) { + done = true; + break; + } + } + } + + if (!done && !trackedStreamId) { + // Find active stream in history + const activeStarts = currentEvents.filter( + (e) => e.type === 'agent_start', + ); + for (let i = activeStarts.length - 1; i >= 0; i--) { + const start = activeStarts[i]; + if ( + !currentEvents.some( + (e) => e.type === 'agent_end' && e.streamId === start.streamId, + ) + ) { + trackedStreamId = start.streamId ?? undefined; + break; + } + } + } + + // If we replayed to the end and no stream is active, and we were specifically + // replaying from an eventId (or we've already finished the stream we were looking for), we are done. + if (!done && !trackedStreamId && options.eventId) { + done = true; + } + + started = true; + + // Process events that arrived while we were replaying + for (const event of earlyEvents) { + if (done) break; + if (trackedStreamId && event.streamId !== trackedStreamId) continue; + if (!trackedStreamId && event.type === 'agent_start') { + trackedStreamId = event.streamId ?? undefined; + } + if (!trackedStreamId && !options.eventId) continue; + + eventQueue.push(event); + if ( + event.type === 'agent_end' && + event.streamId === (trackedStreamId ?? null) + ) { + done = true; + } + } + + while (true) { + if (eventQueue.length > 0) { + const eventsToYield = eventQueue; + eventQueue = []; + for (const event of eventsToYield) { + yield event; + } + } + + if (done) break; + await next; + } + } finally { + unsubscribe(); + } + } +} diff --git a/packages/core/src/agent/mock.test.ts b/packages/core/src/agent/mock.test.ts index 41672223a9..4f102d5dbd 100644 --- a/packages/core/src/agent/mock.test.ts +++ b/packages/core/src/agent/mock.test.ts @@ -5,12 +5,24 @@ */ import { describe, expect, it } from 'vitest'; -import { MockAgentSession } from './mock.js'; -import type { AgentEvent } from './types.js'; +import { MockAgentProtocol } from './mock.js'; +import type { AgentEvent, AgentProtocol } from './types.js'; -describe('MockAgentSession', () => { - it('should yield queued events on send and stream', async () => { - const session = new MockAgentSession(); +const waitForStreamEnd = (session: AgentProtocol): Promise => + new Promise((resolve) => { + const events: AgentEvent[] = []; + const unsubscribe = session.subscribe((e) => { + events.push(e); + if (e.type === 'agent_end') { + unsubscribe(); + resolve(events); + } + }); + }); + +describe('MockAgentProtocol', () => { + it('should emit queued events on send and subscribe', async () => { + const session = new MockAgentProtocol(); const event1 = { type: 'message', role: 'agent', @@ -19,31 +31,30 @@ describe('MockAgentSession', () => { session.pushResponse([event1]); + const streamPromise = waitForStreamEnd(session); + const { streamId } = await session.send({ message: [{ type: 'text', text: 'hi' }], }); expect(streamId).toBeDefined(); - const streamedEvents: AgentEvent[] = []; - for await (const event of session.stream()) { - streamedEvents.push(event); - } + const streamedEvents = await streamPromise; - // Auto stream_start, auto user message, agent message, auto stream_end = 4 events + // Ordered: user message, agent_start, agent message, agent_end = 4 events expect(streamedEvents).toHaveLength(4); - expect(streamedEvents[0].type).toBe('stream_start'); - expect(streamedEvents[1].type).toBe('message'); - expect((streamedEvents[1] as AgentEvent<'message'>).role).toBe('user'); + expect(streamedEvents[0].type).toBe('message'); + expect((streamedEvents[0] as AgentEvent<'message'>).role).toBe('user'); + expect(streamedEvents[1].type).toBe('agent_start'); expect(streamedEvents[2].type).toBe('message'); expect((streamedEvents[2] as AgentEvent<'message'>).role).toBe('agent'); - expect(streamedEvents[3].type).toBe('stream_end'); + expect(streamedEvents[3].type).toBe('agent_end'); expect(session.events).toHaveLength(4); expect(session.events).toEqual(streamedEvents); }); it('should handle multiple responses', async () => { - const session = new MockAgentSession(); + const session = new MockAgentProtocol(); // Test with empty payload (no message injected) session.pushResponse([]); @@ -57,204 +68,154 @@ describe('MockAgentSession', () => { ]); // First send + const stream1Promise = waitForStreamEnd(session); const { streamId: s1 } = await session.send({ - update: {}, + update: { title: 't1' }, }); - const events1: AgentEvent[] = []; - for await (const e of session.stream()) events1.push(e); - expect(events1).toHaveLength(3); // stream_start, session_update, stream_end - expect(events1[0].type).toBe('stream_start'); - expect(events1[1].type).toBe('session_update'); - expect(events1[2].type).toBe('stream_end'); + const events1 = await stream1Promise; + expect(events1).toHaveLength(3); // session_update, agent_start, agent_end + expect(events1[0].type).toBe('session_update'); + expect(events1[1].type).toBe('agent_start'); + expect(events1[2].type).toBe('agent_end'); // Second send + const stream2Promise = waitForStreamEnd(session); const { streamId: s2 } = await session.send({ - update: {}, + update: { title: 't2' }, }); expect(s1).not.toBe(s2); - const events2: AgentEvent[] = []; - for await (const e of session.stream()) events2.push(e); - expect(events2).toHaveLength(4); // stream_start, session_update, error, stream_end - expect(events2[1].type).toBe('session_update'); + const events2 = await stream2Promise; + expect(events2).toHaveLength(4); // session_update, agent_start, error, agent_end + expect(events2[0].type).toBe('session_update'); + expect(events2[1].type).toBe('agent_start'); expect(events2[2].type).toBe('error'); + expect(events2[3].type).toBe('agent_end'); expect(session.events).toHaveLength(7); }); - it('should allow streaming by streamId', async () => { - const session = new MockAgentSession(); - session.pushResponse([{ type: 'message' }]); - - const { streamId } = await session.send({ - update: {}, - }); + it('should handle abort on a waiting stream', async () => { + const session = new MockAgentProtocol(); + // Use keepOpen to prevent auto agent_end + session.pushResponse([{ type: 'message' }], { keepOpen: true }); const events: AgentEvent[] = []; - for await (const e of session.stream({ streamId })) { + let resolveStream: (evs: AgentEvent[]) => void; + const streamPromise = new Promise((res) => { + resolveStream = res; + }); + + session.subscribe((e) => { events.push(e); - } - expect(events).toHaveLength(4); // start, update, message, end - }); + if (e.type === 'agent_end') { + resolveStream(events); + } + }); - it('should throw when streaming non-existent streamId', async () => { - const session = new MockAgentSession(); - await expect(async () => { - const stream = session.stream({ streamId: 'invalid' }); - await stream.next(); - }).rejects.toThrow('Stream not found: invalid'); - }); + const { streamId: _streamId } = await session.send({ + update: { title: 't' }, + }); - it('should throw when streaming non-existent eventId', async () => { - const session = new MockAgentSession(); - session.pushResponse([{ type: 'message' }]); - await session.send({ update: {} }); - - await expect(async () => { - const stream = session.stream({ eventId: 'invalid' }); - await stream.next(); - }).rejects.toThrow('Event not found: invalid'); - }); - - it('should handle abort on a waiting stream', async () => { - const session = new MockAgentSession(); - // Use keepOpen to prevent auto stream_end - session.pushResponse([{ type: 'message' }], { keepOpen: true }); - const { streamId } = await session.send({ update: {} }); - - const stream = session.stream({ streamId }); - - // Read initial events - const e1 = await stream.next(); - expect(e1.value.type).toBe('stream_start'); - const e2 = await stream.next(); - expect(e2.value.type).toBe('session_update'); - const e3 = await stream.next(); - expect(e3.value.type).toBe('message'); + // Initial events should have been emitted + expect(events.map((e) => e.type)).toEqual([ + 'session_update', + 'agent_start', + 'message', + ]); // At this point, the stream should be "waiting" for more events because it's still active - // and hasn't seen a stream_end. - const abortPromise = session.abort(); - const e4 = await stream.next(); - expect(e4.value.type).toBe('stream_end'); - expect((e4.value as AgentEvent<'stream_end'>).reason).toBe('aborted'); + // and hasn't seen an agent_end. + await session.abort(); - await abortPromise; - expect(await stream.next()).toEqual({ done: true, value: undefined }); + const finalEvents = await streamPromise; + expect(finalEvents[3].type).toBe('agent_end'); + expect((finalEvents[3] as AgentEvent<'agent_end'>).reason).toBe('aborted'); }); it('should handle pushToStream on a waiting stream', async () => { - const session = new MockAgentSession(); + const session = new MockAgentProtocol(); session.pushResponse([], { keepOpen: true }); - const { streamId } = await session.send({ update: {} }); - const stream = session.stream({ streamId }); - await stream.next(); // start - await stream.next(); // update + const events: AgentEvent[] = []; + session.subscribe((e) => events.push(e)); + + const { streamId } = await session.send({ update: { title: 't' } }); + + expect(events.map((e) => e.type)).toEqual([ + 'session_update', + 'agent_start', + ]); // Push new event to active stream - session.pushToStream(streamId, [{ type: 'message' }]); + session.pushToStream(streamId!, [{ type: 'message' }]); - const e3 = await stream.next(); - expect(e3.value.type).toBe('message'); + expect(events).toHaveLength(3); + expect(events[2].type).toBe('message'); await session.abort(); - const e4 = await stream.next(); - expect(e4.value.type).toBe('stream_end'); + expect(events).toHaveLength(4); + expect(events[3].type).toBe('agent_end'); }); it('should handle pushToStream with close option', async () => { - const session = new MockAgentSession(); + const session = new MockAgentProtocol(); session.pushResponse([], { keepOpen: true }); - const { streamId } = await session.send({ update: {} }); - const stream = session.stream({ streamId }); - await stream.next(); // start - await stream.next(); // update + const streamPromise = waitForStreamEnd(session); + const { streamId } = await session.send({ update: { title: 't' } }); // Push new event and close - session.pushToStream(streamId, [{ type: 'message' }], { close: true }); + session.pushToStream(streamId!, [{ type: 'message' }], { close: true }); - const e3 = await stream.next(); - expect(e3.value.type).toBe('message'); - - const e4 = await stream.next(); - expect(e4.value.type).toBe('stream_end'); - expect((e4.value as AgentEvent<'stream_end'>).reason).toBe('completed'); - - expect(await stream.next()).toEqual({ done: true, value: undefined }); + const events = await streamPromise; + expect(events.map((e) => e.type)).toEqual([ + 'session_update', + 'agent_start', + 'message', + 'agent_end', + ]); + expect((events[3] as AgentEvent<'agent_end'>).reason).toBe('completed'); }); - it('should not double up on stream_end if provided manually', async () => { - const session = new MockAgentSession(); + it('should not double up on agent_end if provided manually', async () => { + const session = new MockAgentProtocol(); session.pushResponse([ { type: 'message' }, - { type: 'stream_end', reason: 'completed' }, + { type: 'agent_end', reason: 'completed' }, ]); - const { streamId } = await session.send({ update: {} }); - const events: AgentEvent[] = []; - for await (const e of session.stream({ streamId })) { - events.push(e); - } + const streamPromise = waitForStreamEnd(session); + await session.send({ update: { title: 't' } }); - const endEvents = events.filter((e) => e.type === 'stream_end'); + const events = await streamPromise; + const endEvents = events.filter((e) => e.type === 'agent_end'); expect(endEvents).toHaveLength(1); }); - it('should stream after eventId', async () => { - const session = new MockAgentSession(); - // Use manual IDs to test resumption - session.pushResponse([ - { type: 'stream_start', id: 'e1' }, - { type: 'message', id: 'e2' }, - { type: 'stream_end', id: 'e3' }, - ]); - - await session.send({ update: {} }); - - // Stream first event only - const first: AgentEvent[] = []; - for await (const e of session.stream()) { - first.push(e); - if (e.id === 'e1') break; - } - expect(first).toHaveLength(1); - expect(first[0].id).toBe('e1'); - - // Resume from e1 - const second: AgentEvent[] = []; - for await (const e of session.stream({ eventId: 'e1' })) { - second.push(e); - } - expect(second).toHaveLength(3); // update, message, end - expect(second[0].type).toBe('session_update'); - expect(second[1].id).toBe('e2'); - expect(second[2].id).toBe('e3'); - }); - it('should handle elicitations', async () => { - const session = new MockAgentSession(); + const session = new MockAgentProtocol(); session.pushResponse([]); + const streamPromise = waitForStreamEnd(session); await session.send({ elicitations: [ { requestId: 'r1', action: 'accept', content: { foo: 'bar' } }, ], }); - const events: AgentEvent[] = []; - for await (const e of session.stream()) events.push(e); - - expect(events[1].type).toBe('elicitation_response'); - expect((events[1] as AgentEvent<'elicitation_response'>).requestId).toBe( + const events = await streamPromise; + expect(events[0].type).toBe('elicitation_response'); + expect((events[0] as AgentEvent<'elicitation_response'>).requestId).toBe( 'r1', ); + expect(events[1].type).toBe('agent_start'); }); it('should handle updates and track state', async () => { - const session = new MockAgentSession(); + const session = new MockAgentProtocol(); session.pushResponse([]); + const streamPromise = waitForStreamEnd(session); await session.send({ update: { title: 'New Title', model: 'gpt-4', config: { x: 1 } }, }); @@ -263,15 +224,24 @@ describe('MockAgentSession', () => { expect(session.model).toBe('gpt-4'); expect(session.config).toEqual({ x: 1 }); - const events: AgentEvent[] = []; - for await (const e of session.stream()) events.push(e); - expect(events[1].type).toBe('session_update'); + const events = await streamPromise; + expect(events[0].type).toBe('session_update'); + expect(events[1].type).toBe('agent_start'); + }); + + it('should return streamId: null if no response queued', async () => { + const session = new MockAgentProtocol(); + const { streamId } = await session.send({ update: { title: 'foo' } }); + expect(streamId).toBeNull(); + expect(session.events).toHaveLength(1); + expect(session.events[0].type).toBe('session_update'); + expect(session.events[0].streamId).toBeNull(); }); it('should throw on action', async () => { - const session = new MockAgentSession(); + const session = new MockAgentProtocol(); await expect( session.send({ action: { type: 'foo', data: {} } }), - ).rejects.toThrow('Actions not supported in MockAgentSession: foo'); + ).rejects.toThrow('Actions not supported in MockAgentProtocol: foo'); }); }); diff --git a/packages/core/src/agent/mock.ts b/packages/core/src/agent/mock.ts index 7baeb61a83..f29e87f878 100644 --- a/packages/core/src/agent/mock.ts +++ b/packages/core/src/agent/mock.ts @@ -9,31 +9,32 @@ import type { AgentEventCommon, AgentEventData, AgentSend, - AgentSession, + AgentProtocol, + Unsubscribe, } from './types.js'; export type MockAgentEvent = Partial & AgentEventData; export interface PushResponseOptions { - /** If true, does not automatically add a stream_end event. */ + /** If true, does not automatically add an agent_end event. */ keepOpen?: boolean; } /** - * A mock implementation of AgentSession for testing. + * A mock implementation of AgentProtocol for testing. * Allows queuing responses that will be yielded when send() is called. */ -export class MockAgentSession implements AgentSession { +export class MockAgentProtocol implements AgentProtocol { private _events: AgentEvent[] = []; private _responses: Array<{ events: MockAgentEvent[]; options?: PushResponseOptions; }> = []; - private _streams = new Map(); + private _subscribers = new Set<(event: AgentEvent) => void>(); private _activeStreamIds = new Set(); - private _lastStreamId?: string; + private _lastStreamId?: string | null; private _nextEventId = 1; - private _streamResolvers = new Map void>>(); + private _nextStreamId = 1; title?: string; model?: string; @@ -50,12 +51,28 @@ export class MockAgentSession implements AgentSession { return this._events; } + subscribe(callback: (event: AgentEvent) => void): Unsubscribe { + this._subscribers.add(callback); + return () => this._subscribers.delete(callback); + } + + private _emit(event: AgentEvent) { + if (!this._events.some((e) => e.id === event.id)) { + this._events.push(event); + } + for (const callback of this._subscribers) { + callback(event); + } + if (event.type === 'agent_end' && event.streamId) { + this._activeStreamIds.delete(event.streamId); + } + } + /** * Queues a sequence of events to be "emitted" by the agent in response to the * next send() call. */ pushResponse(events: MockAgentEvent[], options?: PushResponseOptions) { - // We store them as data and normalize them when send() is called this._responses.push({ events, options }); } @@ -67,11 +84,6 @@ export class MockAgentSession implements AgentSession { events: MockAgentEvent[], options?: { close?: boolean }, ) { - const stream = this._streams.get(streamId); - if (!stream) { - throw new Error(`Stream not found: ${streamId}`); - } - const now = new Date().toISOString(); for (const eventData of events) { const event: AgentEvent = { @@ -80,205 +92,147 @@ export class MockAgentSession implements AgentSession { timestamp: eventData.timestamp ?? now, streamId: eventData.streamId ?? streamId, } as AgentEvent; - stream.push(event); + this._emit(event); } if ( options?.close && - !events.some((eventData) => eventData.type === 'stream_end') + !events.some((eventData) => eventData.type === 'agent_end') ) { - stream.push({ + this._emit({ id: `e-${this._nextEventId++}`, timestamp: now, streamId, - type: 'stream_end', + type: 'agent_end', reason: 'completed', } as AgentEvent); } - - this._notify(streamId); } - private _notify(streamId: string) { - const resolvers = this._streamResolvers.get(streamId); - if (resolvers) { - this._streamResolvers.delete(streamId); - for (const resolve of resolvers) resolve(); - } - } - - async send(payload: AgentSend): Promise<{ streamId: string }> { - const { events: response, options } = this._responses.shift() ?? { + async send(payload: AgentSend): Promise<{ streamId: string | null }> { + const responseData = this._responses.shift(); + const { events: response, options } = responseData ?? { events: [], }; - const streamId = - response[0]?.streamId ?? `mock-stream-${this._streams.size + 1}`; + + // If there were queued responses (even if empty array), we trigger a stream. + const hasResponseEvents = responseData !== undefined; + const streamId = hasResponseEvents + ? (response[0]?.streamId ?? `mock-stream-${this._nextStreamId++}`) + : null; const now = new Date().toISOString(); + const eventsToEmit: AgentEvent[] = []; - if (!response.some((eventData) => eventData.type === 'stream_start')) { - response.unshift({ - type: 'stream_start', - streamId, - }); - } - - const startIndex = response.findIndex( - (eventData) => eventData.type === 'stream_start', - ); + // Helper to normalize and prepare for emission + const normalize = (eventData: MockAgentEvent): AgentEvent => + ({ + ...eventData, + id: eventData.id ?? `e-${this._nextEventId++}`, + timestamp: eventData.timestamp ?? now, + streamId: eventData.streamId ?? streamId, + }) as AgentEvent; + // 1. User/Update event (BEFORE agent_start) if ('message' in payload && payload.message) { - response.splice(startIndex + 1, 0, { - type: 'message', - role: 'user', - content: payload.message, - _meta: payload._meta, - }); - } else if ('elicitations' in payload && payload.elicitations) { - payload.elicitations.forEach((elicitation, i) => { - response.splice(startIndex + 1 + i, 0, { - type: 'elicitation_response', - ...elicitation, + eventsToEmit.push( + normalize({ + type: 'message', + role: 'user', + content: payload.message, _meta: payload._meta, - }); + }), + ); + } else if ('elicitations' in payload && payload.elicitations) { + payload.elicitations.forEach((elicitation) => { + eventsToEmit.push( + normalize({ + type: 'elicitation_response', + ...elicitation, + _meta: payload._meta, + }), + ); }); - } else if ('update' in payload && payload.update) { + } else if ( + 'update' in payload && + payload.update && + Object.keys(payload.update).length > 0 + ) { if (payload.update.title) this.title = payload.update.title; if (payload.update.model) this.model = payload.update.model; if (payload.update.config) { this.config = payload.update.config; } - response.splice(startIndex + 1, 0, { - type: 'session_update', - ...payload.update, - _meta: payload._meta, - }); + eventsToEmit.push( + normalize({ + type: 'session_update', + ...payload.update, + _meta: payload._meta, + }), + ); } else if ('action' in payload && payload.action) { throw new Error( - `Actions not supported in MockAgentSession: ${payload.action.type}`, + `Actions not supported in MockAgentProtocol: ${payload.action.type}`, ); } - if ( - !options?.keepOpen && - !response.some((eventData) => eventData.type === 'stream_end') - ) { - response.push({ - type: 'stream_end', - reason: 'completed', - streamId, - }); - } - - const normalizedResponse: AgentEvent[] = []; - for (const eventData of response) { - const event: AgentEvent = { - ...eventData, - id: eventData.id ?? `e-${this._nextEventId++}`, - timestamp: eventData.timestamp ?? now, - streamId: eventData.streamId ?? streamId, - } as AgentEvent; - normalizedResponse.push(event); - } - - this._streams.set(streamId, normalizedResponse); - this._activeStreamIds.add(streamId); - this._lastStreamId = streamId; - - return { streamId }; - } - - async *stream(options?: { - streamId?: string; - eventId?: string; - }): AsyncIterableIterator { - let streamId = options?.streamId; - - if (options?.eventId) { - const event = this._events.find( - (eventData) => eventData.id === options.eventId, - ); - if (!event) { - throw new Error(`Event not found: ${options.eventId}`); - } - streamId = streamId ?? event.streamId; - } - - streamId = streamId ?? this._lastStreamId; - - if (!streamId) { - return; - } - - const events = this._streams.get(streamId); - if (!events) { - throw new Error(`Stream not found: ${streamId}`); - } - - let i = 0; - if (options?.eventId) { - const idx = events.findIndex( - (eventData) => eventData.id === options.eventId, - ); - if (idx !== -1) { - i = idx + 1; - } else { - // This should theoretically not happen if the event was found in this._events - // but the trajectories match. - throw new Error( - `Event ${options.eventId} not found in stream ${streamId}`, + // 2. agent_start (if stream) + if (streamId) { + if (!response.some((eventData) => eventData.type === 'agent_start')) { + eventsToEmit.push( + normalize({ + type: 'agent_start', + streamId, + }), ); } } - while (true) { - if (i < events.length) { - const event = events[i++]; - // Add to session trajectory if not already present - if (!this._events.some((eventData) => eventData.id === event.id)) { - this._events.push(event); - } - yield event; + // 3. Response events + for (const eventData of response) { + eventsToEmit.push(normalize(eventData)); + } - // If it's a stream_end, we're done with this stream - if (event.type === 'stream_end') { - this._activeStreamIds.delete(streamId); - return; - } - } else { - // No more events in the array currently. Check if we're still active. - if (!this._activeStreamIds.has(streamId)) { - // If we weren't terminated by a stream_end but we're no longer active, - // it was an abort. - const abortEvent: AgentEvent = { - id: `e-${this._nextEventId++}`, - timestamp: new Date().toISOString(), + // 4. agent_end (if stream and not manual) + if (streamId && !options?.keepOpen) { + if (!eventsToEmit.some((e) => e.type === 'agent_end')) { + eventsToEmit.push( + normalize({ + type: 'agent_end', + reason: 'completed', streamId, - type: 'stream_end', - reason: 'aborted', - } as AgentEvent; - if (!this._events.some((e) => e.id === abortEvent.id)) { - this._events.push(abortEvent); - } - yield abortEvent; - return; - } - - // Wait for notification (new event or abort) - await new Promise((resolve) => { - const resolvers = this._streamResolvers.get(streamId) ?? []; - resolvers.push(resolve); - this._streamResolvers.set(streamId, resolvers); - }); + }), + ); } } + + if (streamId) { + this._activeStreamIds.add(streamId); + } + this._lastStreamId = streamId; + + // Emit events asynchronously so the caller receives the streamId first. + if (eventsToEmit.length > 0) { + void Promise.resolve().then(() => { + for (const event of eventsToEmit) { + this._emit(event); + } + }); + } + + return { streamId }; } async abort(): Promise { - if (this._lastStreamId) { + if (this._lastStreamId && this._activeStreamIds.has(this._lastStreamId)) { const streamId = this._lastStreamId; - this._activeStreamIds.delete(streamId); - this._notify(streamId); + this._emit({ + id: `e-${this._nextEventId++}`, + timestamp: new Date().toISOString(), + streamId, + type: 'agent_end', + reason: 'aborted', + } as AgentEvent); } } } diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts index 8b698a8e48..3b1c740ad4 100644 --- a/packages/core/src/agent/types.ts +++ b/packages/core/src/agent/types.ts @@ -6,25 +6,27 @@ export type WithMeta = { _meta?: Record }; -export interface AgentSession extends Trajectory { +export type Unsubscribe = () => void; + +export interface AgentProtocol extends Trajectory { /** * Send data to the agent. Promise resolves when action is acknowledged. - * Returns the `streamId` of the stream the message was correlated to -- this may - * be a new stream if idle or an existing stream. - */ - send(payload: AgentSend): Promise<{ streamId: string }>; - /** - * Begin listening to actively streaming data. Stream must have the following - * properties: + * Returns the `streamId` of the stream the message was correlated to -- + * this may be a new stream if idle, an existing stream, or null if no + * stream was triggered. * - * - If no arguments are provided, streams events from an active stream. - * - If a {streamId} is provided, streams ALL events from that stream. - * - If an {eventId} is provided, streams all events AFTER that event. + * When a new stream is created by a send, the streamId MUST be returned + * before the `agent_start` event is emitted for the stream. */ - stream(options?: { - streamId?: string; - eventId?: string; - }): AsyncIterableIterator; + send(payload: AgentSend): Promise<{ streamId: string | null }>; + + /** + * Subscribes the provided callback to all future events emitted by this + * session. Returns an unsubscribe function. + * + * @param callback The callback function to listen to events. + */ + subscribe(callback: (event: AgentEvent) => void): Unsubscribe; /** * Aborts an active stream of agent activity. @@ -32,7 +34,7 @@ export interface AgentSession extends Trajectory { abort(): Promise; /** - * AgentSession implements the Trajectory interface and can retrieve existing events. + * AgentProtocol implements the Trajectory interface and can retrieve existing events. */ readonly events: AgentEvent[]; } @@ -61,7 +63,7 @@ export interface AgentEventCommon { /** Identifies the subagent thread, omitted for "main thread" events. */ threadId?: string; /** Identifies a particular stream of a particular thread. */ - streamId?: string; + streamId?: string | null; /** ISO Timestamp for the time at which the event occurred. */ timestamp: string; /** The concrete type of the event. */ @@ -90,10 +92,10 @@ export interface AgentEvents { session_update: SessionUpdate; /** Message content provided by user, agent, or developer. */ message: Message; - /** Event indicating the start of a new stream. */ - stream_start: StreamStart; - /** Event indicating the end of a running stream. */ - stream_end: StreamEnd; + /** Event indicating the start of agent activity on a stream. */ + agent_start: AgentStart; + /** Event indicating the end of agent activity on a stream. */ + agent_end: AgentEnd; /** Tool request issued by the agent. */ tool_request: ToolRequest; /** Tool update issued by the agent. */ @@ -257,7 +259,7 @@ export interface Usage { cost?: { amount: number; currency?: string }; } -export interface StreamStart { +export interface AgentStart { streamId: string; } @@ -272,7 +274,7 @@ type StreamEndReason = | 'elicitation' | (string & {}); -export interface StreamEnd { +export interface AgentEnd { streamId: string; reason: StreamEndReason; elicitationIds?: string[];