mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-18 15:52:53 -07:00
!feat(core): tighten legacy agent session stream lifecycle
This commit is contained in:
@@ -101,24 +101,12 @@ async function collectEvents(
|
||||
options?: { streamId?: string; eventId?: string },
|
||||
): Promise<AgentEvent[]> {
|
||||
const events: AgentEvent[] = [];
|
||||
const streamOptions: { streamId?: string; eventId?: string } =
|
||||
options?.eventId
|
||||
? {
|
||||
eventId: options.eventId,
|
||||
...(options.streamId ? { streamId: options.streamId } : {}),
|
||||
}
|
||||
: {
|
||||
streamId:
|
||||
options?.streamId ??
|
||||
session.events.findLast((event) => event.type === 'agent_start')
|
||||
?.streamId,
|
||||
};
|
||||
const streamOptions =
|
||||
options?.eventId || options?.streamId ? options : undefined;
|
||||
|
||||
if (!streamOptions.eventId && !streamOptions.streamId) {
|
||||
return events;
|
||||
}
|
||||
|
||||
for await (const event of session.stream(streamOptions)) {
|
||||
for await (const event of streamOptions
|
||||
? session.stream(streamOptions)
|
||||
: session.stream()) {
|
||||
events.push(event);
|
||||
}
|
||||
return events;
|
||||
@@ -188,6 +176,40 @@ describe('LegacyAgentSession', () => {
|
||||
await collectEvents(session, { streamId: streamId ?? undefined });
|
||||
});
|
||||
|
||||
it('returns streamId before emitting agent_start', async () => {
|
||||
const sendMock = deps.client.sendMessageStream as ReturnType<
|
||||
typeof vi.fn
|
||||
>;
|
||||
sendMock.mockReturnValue(
|
||||
makeStream([
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: { reason: FinishReason.STOP, usageMetadata: undefined },
|
||||
},
|
||||
]),
|
||||
);
|
||||
|
||||
const session = new LegacyAgentSession(deps);
|
||||
const liveEvents: AgentEvent[] = [];
|
||||
session.subscribe((event) => {
|
||||
liveEvents.push(event);
|
||||
});
|
||||
|
||||
const { streamId } = await session.send({
|
||||
message: [{ type: 'text', text: 'hi' }],
|
||||
});
|
||||
|
||||
expect(streamId).toBe('test-stream');
|
||||
expect(liveEvents.some((event) => event.type === 'agent_start')).toBe(
|
||||
false,
|
||||
);
|
||||
|
||||
await collectEvents(session, { streamId: streamId ?? undefined });
|
||||
expect(liveEvents.some((event) => event.type === 'agent_start')).toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
it('throws for non-message payloads', async () => {
|
||||
const session = new LegacyAgentSession(deps);
|
||||
await expect(session.send({ update: { title: 'test' } })).rejects.toThrow(
|
||||
@@ -213,14 +235,17 @@ describe('LegacyAgentSession', () => {
|
||||
);
|
||||
|
||||
const session = new LegacyAgentSession(deps);
|
||||
await session.send({ message: [{ type: 'text', text: 'first' }] });
|
||||
const { streamId } = await session.send({
|
||||
message: [{ type: 'text', text: 'first' }],
|
||||
});
|
||||
await vi.advanceTimersByTimeAsync(0);
|
||||
|
||||
await expect(
|
||||
session.send({ message: [{ type: 'text', text: 'second' }] }),
|
||||
).rejects.toThrow('cannot be called while a stream is active');
|
||||
|
||||
resolveHang?.();
|
||||
await collectEvents(session);
|
||||
await collectEvents(session, { streamId: streamId ?? undefined });
|
||||
});
|
||||
|
||||
it('creates a new streamId after the previous stream completes', async () => {
|
||||
@@ -756,6 +781,48 @@ describe('LegacyAgentSession', () => {
|
||||
});
|
||||
|
||||
describe('abort', () => {
|
||||
it('treats abort before the first model event as aborted without fatal error', async () => {
|
||||
let releaseAbort: (() => void) | undefined;
|
||||
const sendMock = deps.client.sendMessageStream as ReturnType<
|
||||
typeof vi.fn
|
||||
>;
|
||||
sendMock.mockReturnValue(
|
||||
(async function* () {
|
||||
await new Promise<void>((resolve) => {
|
||||
releaseAbort = resolve;
|
||||
});
|
||||
yield* [];
|
||||
const abortError = new Error('Aborted');
|
||||
abortError.name = 'AbortError';
|
||||
throw abortError;
|
||||
})(),
|
||||
);
|
||||
|
||||
const session = new LegacyAgentSession(deps);
|
||||
const { streamId } = await session.send({
|
||||
message: [{ type: 'text', text: 'hi' }],
|
||||
});
|
||||
await vi.advanceTimersByTimeAsync(0);
|
||||
|
||||
await session.abort();
|
||||
releaseAbort?.();
|
||||
|
||||
const events = await collectEvents(session, {
|
||||
streamId: streamId ?? undefined,
|
||||
});
|
||||
expect(
|
||||
events.some(
|
||||
(event): event is AgentEvent<'error'> =>
|
||||
event.type === 'error' && event.fatal,
|
||||
),
|
||||
).toBe(false);
|
||||
|
||||
const streamEnd = events.findLast(
|
||||
(event): event is AgentEvent<'agent_end'> => event.type === 'agent_end',
|
||||
);
|
||||
expect(streamEnd?.reason).toBe('aborted');
|
||||
});
|
||||
|
||||
it('aborts the stream', async () => {
|
||||
const sendMock = deps.client.sendMessageStream as ReturnType<
|
||||
typeof vi.fn
|
||||
@@ -797,6 +864,59 @@ describe('LegacyAgentSession', () => {
|
||||
);
|
||||
expect(streamEnd?.reason).toBe('aborted');
|
||||
});
|
||||
|
||||
it('treats abort during pending scheduler work as aborted without fatal error', async () => {
|
||||
let resolveSchedule: ((value: CompletedToolCall[]) => void) | undefined;
|
||||
const sendMock = deps.client.sendMessageStream as ReturnType<
|
||||
typeof vi.fn
|
||||
>;
|
||||
sendMock.mockReturnValue(
|
||||
makeStream([
|
||||
{
|
||||
type: GeminiEventType.ToolCallRequest,
|
||||
value: makeToolRequest('call-1', 'slow_tool'),
|
||||
},
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: { reason: FinishReason.STOP, usageMetadata: undefined },
|
||||
},
|
||||
]),
|
||||
);
|
||||
|
||||
const scheduleMock = deps.scheduler.schedule as ReturnType<typeof vi.fn>;
|
||||
scheduleMock.mockReturnValue(
|
||||
new Promise<CompletedToolCall[]>((resolve) => {
|
||||
resolveSchedule = resolve;
|
||||
}),
|
||||
);
|
||||
|
||||
const session = new LegacyAgentSession(deps);
|
||||
const { streamId } = await session.send({
|
||||
message: [{ type: 'text', text: 'hi' }],
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 25));
|
||||
await session.abort();
|
||||
resolveSchedule?.([makeCompletedToolCall('call-1', 'slow_tool', 'done')]);
|
||||
|
||||
const events = await collectEvents(session, {
|
||||
streamId: streamId ?? undefined,
|
||||
});
|
||||
expect(
|
||||
events.some(
|
||||
(event): event is AgentEvent<'error'> =>
|
||||
event.type === 'error' && event.fatal,
|
||||
),
|
||||
).toBe(false);
|
||||
expect(events.some((event) => event.type === 'tool_response')).toBe(
|
||||
false,
|
||||
);
|
||||
|
||||
const streamEnd = events.findLast(
|
||||
(event): event is AgentEvent<'agent_end'> => event.type === 'agent_end',
|
||||
);
|
||||
expect(streamEnd?.reason).toBe('aborted');
|
||||
});
|
||||
});
|
||||
|
||||
describe('events property', () => {
|
||||
@@ -1273,5 +1393,25 @@ describe('LegacyAgentSession', () => {
|
||||
);
|
||||
expect(err?._meta?.['code']).toBe('ENOENT');
|
||||
});
|
||||
|
||||
it('preserves status in _meta for errors with status property', async () => {
|
||||
const sendMock = deps.client.sendMessageStream as ReturnType<
|
||||
typeof vi.fn
|
||||
>;
|
||||
const statusError = new Error('rate limited');
|
||||
(statusError as Error & { status: string }).status = 'RESOURCE_EXHAUSTED';
|
||||
sendMock.mockImplementation(() => {
|
||||
throw statusError;
|
||||
});
|
||||
|
||||
const session = new LegacyAgentSession(deps);
|
||||
await session.send({ message: [{ type: 'text', text: 'hi' }] });
|
||||
const events = await collectEvents(session);
|
||||
|
||||
const err = events.find(
|
||||
(e): e is AgentEvent<'error'> => e.type === 'error',
|
||||
);
|
||||
expect(err?._meta?.['status']).toBe('RESOURCE_EXHAUSTED');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -40,6 +40,10 @@ import type {
|
||||
Unsubscribe,
|
||||
} from './types.js';
|
||||
|
||||
function isAbortLikeError(err: unknown): boolean {
|
||||
return err instanceof Error && err.name === 'AbortError';
|
||||
}
|
||||
|
||||
export interface LegacySessionDeps {
|
||||
client: GeminiClient;
|
||||
scheduler: Scheduler;
|
||||
@@ -71,7 +75,7 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
this._promptId = deps.promptId;
|
||||
}
|
||||
|
||||
get events(): AgentEvent[] {
|
||||
get events(): readonly AgentEvent[] {
|
||||
return this._events;
|
||||
}
|
||||
|
||||
@@ -102,23 +106,11 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
this._beginNewStream();
|
||||
const streamId = this._translationState.streamId;
|
||||
const parts = contentPartsToGeminiParts(message);
|
||||
const userMessage = this._makeInternalEvent('message', {
|
||||
role: 'user',
|
||||
content: message,
|
||||
...(payload._meta ? { _meta: payload._meta } : {}),
|
||||
});
|
||||
const userMessage = this._makeUserMessageEvent(message, payload._meta);
|
||||
|
||||
this._emit([userMessage]);
|
||||
|
||||
void Promise.resolve().then(async () => {
|
||||
this._ensureAgentStart();
|
||||
try {
|
||||
await this._runLoop(parts);
|
||||
} catch (err: unknown) {
|
||||
this._emitErrorAndAgentEnd(err);
|
||||
this._markStreamDone();
|
||||
}
|
||||
});
|
||||
this._scheduleRunLoop(parts);
|
||||
|
||||
return { streamId };
|
||||
}
|
||||
@@ -127,6 +119,26 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
this._abortController.abort();
|
||||
}
|
||||
|
||||
private _scheduleRunLoop(initialParts: Part[]): void {
|
||||
setTimeout(() => {
|
||||
void this._runLoopInBackground(initialParts);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
private async _runLoopInBackground(initialParts: Part[]): Promise<void> {
|
||||
this._ensureAgentStart();
|
||||
try {
|
||||
await this._runLoop(initialParts);
|
||||
} catch (err: unknown) {
|
||||
if (this._abortController.signal.aborted || isAbortLikeError(err)) {
|
||||
this._ensureAgentEnd('aborted');
|
||||
} else {
|
||||
this._emitErrorAndAgentEnd(err);
|
||||
}
|
||||
this._markStreamDone();
|
||||
}
|
||||
}
|
||||
|
||||
private async _runLoop(initialParts: Part[]): Promise<void> {
|
||||
let currentParts: Part[] = initialParts;
|
||||
let turnCount = 0;
|
||||
@@ -135,17 +147,11 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
while (true) {
|
||||
turnCount++;
|
||||
if (maxTurns >= 0 && turnCount > maxTurns) {
|
||||
this._emit([
|
||||
this._makeInternalEvent('agent_end', {
|
||||
reason: 'max_turns',
|
||||
data: {
|
||||
code: 'MAX_TURNS_EXCEEDED',
|
||||
maxTurns,
|
||||
turnCount: turnCount - 1,
|
||||
},
|
||||
}),
|
||||
]);
|
||||
this._markStreamDone();
|
||||
this._finishStream('max_turns', {
|
||||
code: 'MAX_TURNS_EXCEEDED',
|
||||
maxTurns,
|
||||
turnCount: turnCount - 1,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -158,8 +164,7 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
|
||||
for await (const event of responseStream) {
|
||||
if (this._abortController.signal.aborted) {
|
||||
this._ensureAgentEnd('aborted');
|
||||
this._markStreamDone();
|
||||
this._finishStream('aborted');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -170,8 +175,7 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
this._emit(translateEvent(event, this._translationState));
|
||||
|
||||
if (event.type === GeminiEventType.Error) {
|
||||
this._ensureAgentEnd('failed');
|
||||
this._markStreamDone();
|
||||
this._finishStream('failed');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -179,15 +183,13 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
event.type === GeminiEventType.InvalidStream ||
|
||||
event.type === GeminiEventType.ContextWindowWillOverflow
|
||||
) {
|
||||
this._ensureAgentEnd('failed');
|
||||
this._markStreamDone();
|
||||
this._finishStream('failed');
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.type === GeminiEventType.Finished) {
|
||||
if (toolCallRequests.length === 0) {
|
||||
this._ensureAgentEnd(mapFinishReason(event.value.reason));
|
||||
this._markStreamDone();
|
||||
this._finishStream(mapFinishReason(event.value.reason));
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
@@ -203,9 +205,13 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
}
|
||||
}
|
||||
|
||||
if (this._abortController.signal.aborted) {
|
||||
this._finishStream('aborted');
|
||||
return;
|
||||
}
|
||||
|
||||
if (toolCallRequests.length === 0) {
|
||||
this._ensureAgentEnd('completed');
|
||||
this._markStreamDone();
|
||||
this._finishStream('completed');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -214,6 +220,11 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
this._abortController.signal,
|
||||
);
|
||||
|
||||
if (this._abortController.signal.aborted) {
|
||||
this._finishStream('aborted');
|
||||
return;
|
||||
}
|
||||
|
||||
const toolResponseParts: Part[] = [];
|
||||
for (const tc of completedToolCalls) {
|
||||
const response = tc.response;
|
||||
@@ -227,7 +238,7 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
const data = buildToolResponseData(response);
|
||||
|
||||
this._emit([
|
||||
this._makeInternalEvent('tool_response', {
|
||||
this._makeToolResponseEvent({
|
||||
requestId: request.callId,
|
||||
name: request.name,
|
||||
content,
|
||||
@@ -261,8 +272,7 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
tc.response.error !== undefined,
|
||||
);
|
||||
if (stopTool) {
|
||||
this._ensureAgentEnd('completed');
|
||||
this._markStreamDone();
|
||||
this._finishStream('completed');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -270,8 +280,7 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
isFatalToolError(tc.response.errorType),
|
||||
);
|
||||
if (fatalTool) {
|
||||
this._ensureAgentEnd('failed');
|
||||
this._markStreamDone();
|
||||
this._finishStream('failed');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -313,21 +322,29 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
private _ensureAgentStart(): void {
|
||||
if (!this._translationState.streamStartEmitted) {
|
||||
this._translationState.streamStartEmitted = true;
|
||||
this._emit([this._makeInternalEvent('agent_start', {})]);
|
||||
this._emit([this._makeAgentStartEvent()]);
|
||||
}
|
||||
}
|
||||
|
||||
private _ensureAgentEnd(reason: StreamEndReason = 'completed'): void {
|
||||
if (!this._agentEndEmitted && this._translationState.streamStartEmitted) {
|
||||
this._agentEndEmitted = true;
|
||||
this._emit([
|
||||
this._makeInternalEvent('agent_end', {
|
||||
reason,
|
||||
}),
|
||||
]);
|
||||
this._emit([this._makeAgentEndEvent(reason)]);
|
||||
}
|
||||
}
|
||||
|
||||
private _finishStream(
|
||||
reason: StreamEndReason,
|
||||
data?: Record<string, unknown>,
|
||||
): void {
|
||||
if (data && !this._agentEndEmitted) {
|
||||
this._emit([this._makeAgentEndEvent(reason, data)]);
|
||||
} else {
|
||||
this._ensureAgentEnd(reason);
|
||||
}
|
||||
this._markStreamDone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Preserve error identity fields in _meta so downstream consumers can
|
||||
* reconstruct fatal CLI errors.
|
||||
@@ -346,10 +363,13 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
if ('code' in err) {
|
||||
meta['code'] = err.code;
|
||||
}
|
||||
if ('status' in err) {
|
||||
meta['status'] = err.status;
|
||||
}
|
||||
}
|
||||
|
||||
this._emit([
|
||||
this._makeInternalEvent('error', {
|
||||
this._makeErrorEvent({
|
||||
status: 'INTERNAL',
|
||||
message,
|
||||
fatal: true,
|
||||
@@ -360,22 +380,76 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
this._ensureAgentEnd('failed');
|
||||
}
|
||||
|
||||
private _makeInternalEvent<T extends AgentEvent['type']>(
|
||||
type: T,
|
||||
payload: Omit<
|
||||
Partial<AgentEvent<T>>,
|
||||
'id' | 'timestamp' | 'streamId' | 'type'
|
||||
>,
|
||||
): AgentEvent<T> {
|
||||
const id = `${this._translationState.streamId}-${this._translationState.eventCounter++}`;
|
||||
private _nextEventFields() {
|
||||
return {
|
||||
...payload,
|
||||
id,
|
||||
id: `${this._translationState.streamId}-${this._translationState.eventCounter++}`,
|
||||
timestamp: new Date().toISOString(),
|
||||
streamId: this._translationState.streamId,
|
||||
type,
|
||||
};
|
||||
}
|
||||
|
||||
private _makeUserMessageEvent(
|
||||
content: ContentPart[],
|
||||
meta?: Record<string, unknown>,
|
||||
): AgentEvent<'message'> {
|
||||
const event = {
|
||||
...this._nextEventFields(),
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content,
|
||||
...(meta ? { _meta: meta } : {}),
|
||||
} satisfies AgentEvent<'message'>;
|
||||
return event;
|
||||
}
|
||||
|
||||
private _makeToolResponseEvent(
|
||||
payload: Omit<
|
||||
AgentEvent<'tool_response'>,
|
||||
'id' | 'timestamp' | 'streamId' | 'type'
|
||||
>,
|
||||
): AgentEvent<'tool_response'> {
|
||||
const event = {
|
||||
...this._nextEventFields(),
|
||||
type: 'tool_response',
|
||||
...payload,
|
||||
} satisfies AgentEvent<'tool_response'>;
|
||||
return event;
|
||||
}
|
||||
|
||||
private _makeAgentStartEvent(): AgentEvent<'agent_start'> {
|
||||
const event = {
|
||||
...this._nextEventFields(),
|
||||
type: 'agent_start',
|
||||
} satisfies AgentEvent<'agent_start'>;
|
||||
return event;
|
||||
}
|
||||
|
||||
private _makeAgentEndEvent(
|
||||
reason: StreamEndReason,
|
||||
data?: Record<string, unknown>,
|
||||
): AgentEvent<'agent_end'> {
|
||||
const event = {
|
||||
...this._nextEventFields(),
|
||||
type: 'agent_end',
|
||||
reason,
|
||||
...(data ? { data } : {}),
|
||||
} satisfies AgentEvent<'agent_end'>;
|
||||
return event;
|
||||
}
|
||||
|
||||
private _makeErrorEvent(
|
||||
payload: Omit<
|
||||
AgentEvent<'error'>,
|
||||
'id' | 'timestamp' | 'streamId' | 'type'
|
||||
>,
|
||||
): AgentEvent<'error'> {
|
||||
const event = {
|
||||
...this._nextEventFields(),
|
||||
type: 'error',
|
||||
...payload,
|
||||
} satisfies AgentEvent<'error'>;
|
||||
return event;
|
||||
}
|
||||
}
|
||||
|
||||
export class LegacyAgentSession extends AgentSession {
|
||||
|
||||
Reference in New Issue
Block a user