diff --git a/packages/core/src/agents/agent.ts b/packages/core/src/agents/agent.ts index 39aef3368d..58a8ff7e48 100644 --- a/packages/core/src/agents/agent.ts +++ b/packages/core/src/agents/agent.ts @@ -30,12 +30,12 @@ export class Agent { /** * Helper to quickly run a single prompt and get the results. */ - async *prompt( + prompt( input: string | Part[], sessionId?: string, signal?: AbortSignal, ): AsyncIterable { const session = this.createSession(sessionId); - yield* session.prompt(input, signal); + return session.prompt(input, signal); } } diff --git a/packages/core/src/agents/local-executor.ts b/packages/core/src/agents/local-executor.ts index 919ede4ca6..bfb644e01b 100644 --- a/packages/core/src/agents/local-executor.ts +++ b/packages/core/src/agents/local-executor.ts @@ -546,12 +546,7 @@ export class LocalAgentExecutor { // === UNIFIED RECOVERY BLOCK === // Only attempt recovery if it's a known recoverable reason. // We don't recover from GOAL (already done) or ABORTED (user cancelled). - if ( - terminateReason !== AgentTerminateMode.ERROR && - terminateReason !== AgentTerminateMode.ABORTED && - terminateReason !== AgentTerminateMode.GOAL && - terminateReason !== AgentTerminateMode.LOOP - ) { + if (this.isRecoverableReason(terminateReason)) { const recoveryResult = await this.executeFinalWarningTurn( chat, turnCounter, // Use current turnCounter for the recovery attempt @@ -1257,6 +1252,23 @@ Important Rules: return null; } + /** + * Returns true if the agent should attempt a recovery turn for the given reason. + */ + private isRecoverableReason( + reason: AgentTerminateMode, + ): reason is + | AgentTerminateMode.TIMEOUT + | AgentTerminateMode.MAX_TURNS + | AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL { + return ( + reason !== AgentTerminateMode.ERROR && + reason !== AgentTerminateMode.ABORTED && + reason !== AgentTerminateMode.GOAL && + reason !== AgentTerminateMode.LOOP + ); + } + /** Emits an activity event to the configured callback. */ private emitActivity( type: SubagentActivityEvent['type'], diff --git a/packages/core/src/agents/session.test.ts b/packages/core/src/agents/session.test.ts index 413f4a704c..444288ab2b 100644 --- a/packages/core/src/agents/session.test.ts +++ b/packages/core/src/agents/session.test.ts @@ -11,6 +11,9 @@ import { type AgentConfig, AgentTerminateMode, type AgentEvent, + type AgentFinishEvent, + type ToolSuiteStartEvent, + type ToolCallFinishEvent, } from './types.js'; import { Scheduler } from '../scheduler/scheduler.js'; import { GeminiEventType, CompressionStatus } from '../core/turn.js'; @@ -50,6 +53,7 @@ describe('AgentSession', () => { let session: AgentSession; const agentConfig: AgentConfig = { name: 'TestAgent', + systemInstruction: 'You are a test agent.', capabilities: { compression: true }, }; @@ -111,10 +115,7 @@ describe('AgentSession', () => { events.push(event); } - const finishEvent = events[events.length - 1] as Extract< - AgentEvent, - { type: 'agent_finish' } - >; + const finishEvent = events[events.length - 1] as AgentFinishEvent; expect(events[0].type).toBe('agent_start'); expect(finishEvent.type).toBe('agent_finish'); expect(finishEvent.value.reason).toBe(AgentTerminateMode.GOAL); @@ -195,10 +196,7 @@ describe('AgentSession', () => { const callFinish = events.find((e) => e.type === 'tool_call_finish'); expect(callStart).toBeDefined(); expect(callFinish).toBeDefined(); - expect( - (callFinish as Extract).value - .callId, - ).toBe('call1'); + expect((callFinish as ToolCallFinishEvent).value.callId).toBe('call1'); }); it('should handle multiple consecutive ReAct turns', async () => { @@ -270,7 +268,7 @@ describe('AgentSession', () => { const suiteStart = events.find( (e) => e.type === 'tool_suite_start', - ) as Extract; + ) as ToolSuiteStartEvent; expect(suiteStart.value.count).toBe(2); expect(mockScheduler.schedule).toHaveBeenCalledTimes(1); expect(mockScheduler.schedule).toHaveBeenCalledWith( @@ -305,7 +303,7 @@ describe('AgentSession', () => { throw new Error('Model connection lost'); }); - const events = []; + const events: AgentEvent[] = []; try { for await (const event of session.prompt('Error test')) { events.push(event); @@ -316,8 +314,10 @@ describe('AgentSession', () => { const finishEvent = events.find( (e) => e.type === 'agent_finish', - ) as Extract; + ) as AgentFinishEvent; expect(finishEvent).toBeDefined(); + expect(finishEvent.value.reason).toBe(AgentTerminateMode.ERROR); + expect(finishEvent.value.message).toBe('Model connection lost'); }); it('should ignore MessageBus updates from other schedulers', async () => { @@ -371,10 +371,9 @@ describe('AgentSession', () => { for await (const event of session.prompt('Loop')) { events.push(event); } - const finishEvent = events.find( (e) => e.type === 'agent_finish', - ) as Extract; + ) as AgentFinishEvent; expect(finishEvent.value.reason).toBe(AgentTerminateMode.LOOP); expect(finishEvent.value.message).toContain('Loop detected'); }); @@ -426,10 +425,7 @@ describe('AgentSession', () => { events.push(event); } - const finishEvent = events[events.length - 1] as Extract< - AgentEvent, - { type: 'agent_finish' } - >; + const finishEvent = events[events.length - 1] as AgentFinishEvent; expect(finishEvent.type).toBe('agent_finish'); expect(finishEvent.value.reason).toBe(AgentTerminateMode.ABORTED); }); @@ -520,10 +516,7 @@ describe('AgentSession', () => { expect(mockScheduler.schedule).toHaveBeenCalledTimes(2); - const finishEvent = events[events.length - 1] as Extract< - AgentEvent, - { type: 'agent_finish' } - >; + const finishEvent = events[events.length - 1] as AgentFinishEvent; expect(finishEvent.type).toBe('agent_finish'); expect(finishEvent.value.totalTurns).toBe(2); expect(finishEvent.value.reason).toBe(AgentTerminateMode.MAX_TURNS); diff --git a/packages/core/src/agents/session.ts b/packages/core/src/agents/session.ts index f0132750f6..95144e4f91 100644 --- a/packages/core/src/agents/session.ts +++ b/packages/core/src/agents/session.ts @@ -13,6 +13,7 @@ import { type ToolCallRequestInfo, type ToolCallResponseInfo, CoreToolCallStatus, + type CompletedToolCall, } from '../scheduler/types.js'; import { GeminiEventType, CompressionStatus } from '../core/turn.js'; import { recordToolCallInteractions } from '../code_assist/telemetry.js'; @@ -27,11 +28,32 @@ import { type ToolCallsUpdateMessage, } from '../confirmation-bus/types.js'; +/** Result of a single model turn in the ReAct loop. */ +export interface ModelTurnResult { + /** The specific tool calls requested by the model. */ + toolCalls: ToolCallRequestInfo[]; + /** The unified event stream from this model turn. */ + events: AsyncIterable; + /** Whether an infinite tool loop was detected. */ + loopDetected: boolean; +} + +/** Result of executing a batch of tool calls. */ +export interface ToolExecutionResult { + /** The response parts from the tool execution to be sent back to the model. */ + nextParts: Part[]; + /** Whether execution should stop immediately (e.g. on fatal tool error). */ + stopExecution: boolean; + /** Optional details if execution was stopped. */ + stopExecutionInfo: ToolCallResponseInfo | undefined; +} + /** * AgentSession manages the state of a conversation and orchestrates the agent * loop. */ export class AgentSession { + readonly sessionId: string; private readonly client: GeminiClient; private readonly scheduler: Scheduler; private readonly schedulerId: string; @@ -40,10 +62,11 @@ export class AgentSession { private hasFailedCompressionAttempt = false; constructor( - private readonly sessionId: string, + sessionId: string, private readonly config: AgentConfig, private readonly runtime: Config, ) { + this.sessionId = sessionId; this.client = this.runtime.getGeminiClient(); this.schedulerId = `agent-scheduler-${this.sessionId}-${Math.random().toString(36).substring(2, 9)}`; this.scheduler = new Scheduler({ @@ -100,90 +123,38 @@ export class AgentSession { value: { sessionId: this.sessionId }, }; - let currentInput = input; - let isContinuation = false; - const maxTurns = this.config.maxTurns ?? -1; - let terminationReason = AgentTerminateMode.GOAL; let terminationMessage: string | undefined = undefined; let terminationError: unknown | undefined = undefined; try { - while (maxTurns === -1 || this.totalTurns < maxTurns) { - if (combinedSignal.aborted) { - terminationReason = AgentTerminateMode.ABORTED; - break; - } - - this.totalTurns++; - const promptId = `${this.sessionId}#${this.totalTurns}`; - - // Compression check (from LocalAgentExecutor / useGeminiStream patterns) - if (this.config.capabilities?.compression) { - await this.tryCompressChat(promptId); - } - - const results = await this.runModelTurn( - currentInput, - promptId, - isContinuation ? undefined : input, - combinedSignal, - ); - - for await (const event of results.events) { - yield event; - } - - if (results.loopDetected) { + const loop = this._runLoop(input, combinedSignal); + for await (const event of loop) { + if (event.type === GeminiEventType.LoopDetected) { terminationReason = AgentTerminateMode.LOOP; terminationMessage = 'Loop detected, stopping execution'; - break; - } - - if (combinedSignal.aborted) { - terminationReason = AgentTerminateMode.ABORTED; - break; - } - - if (results.toolCalls.length > 0) { - const toolRun = this.executeTools(results.toolCalls, combinedSignal); - let resultsTools; - while (true) { - const { value, done } = await toolRun.next(); - if (done) { - resultsTools = value; - break; - } - yield value; - } - - if (resultsTools.stopExecution || combinedSignal.aborted) { - if (combinedSignal.aborted) { - terminationReason = AgentTerminateMode.ABORTED; - } else if (resultsTools.stopExecutionInfo) { - terminationReason = AgentTerminateMode.ERROR; - terminationMessage = - resultsTools.stopExecutionInfo.error?.message; - terminationError = resultsTools.stopExecutionInfo.error; - } - break; - } - - // Check if we hit the turn limit - if (maxTurns !== -1 && this.totalTurns >= maxTurns) { - terminationReason = AgentTerminateMode.MAX_TURNS; - terminationMessage = 'Maximum session turns exceeded.'; - break; - } - - currentInput = resultsTools.nextParts; - isContinuation = true; - } else { - // No more tool calls, turn is complete. - terminationReason = AgentTerminateMode.GOAL; - break; } + yield event; } + + if (combinedSignal.aborted) { + terminationReason = AgentTerminateMode.ABORTED; + } else if ( + terminationReason === AgentTerminateMode.GOAL && + this.config.maxTurns && + this.config.maxTurns !== -1 && + this.totalTurns >= this.config.maxTurns + ) { + // Only set MAX_TURNS if we haven't already hit another reason (like LOOP) + // and we are actually at or above the turn limit. + terminationReason = AgentTerminateMode.MAX_TURNS; + terminationMessage = 'Maximum session turns exceeded.'; + } + } catch (e) { + terminationReason = AgentTerminateMode.ERROR; + terminationMessage = e instanceof Error ? e.message : String(e); + terminationError = e; + throw e; } finally { internalController.abort(); yield { @@ -199,6 +170,90 @@ export class AgentSession { } } + /** + * Internal generator managing the turn-by-turn ReAct loop. + */ + private async *_runLoop( + input: string | Part[], + signal: AbortSignal, + ): AsyncIterable { + let currentInput = input; + let isContinuation = false; + const maxTurns = this.config.maxTurns ?? -1; + + while (maxTurns === -1 || this.totalTurns < maxTurns) { + if (signal.aborted) return; + + this.totalTurns++; + const promptId = `${this.sessionId}#${this.totalTurns}`; + + if (this.config.capabilities?.compression) { + await this.tryCompressChat(promptId); + } + + const results = await this.runModelTurn( + currentInput, + promptId, + isContinuation ? undefined : input, + signal, + ); + + for await (const event of results.events) { + yield event; + } + + if (results.loopDetected) return; + if (signal.aborted) return; + + if (results.toolCalls.length > 0) { + const toolRun = this._handleToolCalls(results.toolCalls, signal); + let toolResults: ToolExecutionResult; + while (true) { + const { value, done } = await toolRun.next(); + if (done) { + toolResults = value; + break; + } + yield value; + } + + if (toolResults.stopExecution || signal.aborted) { + if (toolResults.stopExecution && toolResults.stopExecutionInfo) { + throw ( + toolResults.stopExecutionInfo.error ?? + new Error('Tool execution stopped') + ); + } + return; + } + + if (maxTurns !== -1 && this.totalTurns >= maxTurns) { + return; + } + + currentInput = toolResults.nextParts; + isContinuation = true; + } else { + return; + } + } + } + + /** + * Orchestrates tool execution turn and yields events. + */ + private async *_handleToolCalls( + toolCalls: ToolCallRequestInfo[], + signal: AbortSignal, + ): AsyncGenerator { + const toolRun = this.executeTools(toolCalls, signal); + while (true) { + const { value, done } = await toolRun.next(); + if (done) return value; + yield value; + } + } + /** * Calls the model and yields the event stream. * Collects tool call requests for the next phase. @@ -208,11 +263,7 @@ export class AgentSession { promptId: string, displayContent?: string | Part[], signal?: AbortSignal, - ): Promise<{ - toolCalls: ToolCallRequestInfo[]; - events: AsyncIterable; - loopDetected: boolean; - }> { + ): Promise { const parts = Array.isArray(input) ? input : [{ text: input }]; const toolCalls: ToolCallRequestInfo[] = []; let loopDetected = false; @@ -244,7 +295,7 @@ export class AgentSession { } else if (event.type === GeminiEventType.LoopDetected) { loopDetected = true; } - yield event as AgentEvent; + yield event; } }; @@ -265,14 +316,7 @@ export class AgentSession { private async *executeTools( toolCalls: ToolCallRequestInfo[], signal?: AbortSignal, - ): AsyncGenerator< - AgentEvent, - { - nextParts: Part[]; - stopExecution: boolean; - stopExecutionInfo: ToolCallResponseInfo | undefined; - } - > { + ): AsyncGenerator { yield { type: 'tool_suite_start', value: { count: toolCalls.length }, @@ -282,34 +326,11 @@ export class AgentSession { let resolveNext: (() => void) | undefined; let isFinished = false; - // Track seen status transitions to avoid duplicate events - const seenStatuses = new Map(); + const onToolUpdate = this._createToolUpdateHandler(eventQueue, () => + resolveNext?.(), + ); const messageBus = this.runtime.getMessageBus(); - const onToolUpdate = (message: ToolCallsUpdateMessage) => { - if (message.schedulerId !== this.schedulerId) return; - - for (const call of message.toolCalls) { - const prevStatus = seenStatuses.get(call.request.callId); - if (prevStatus === call.status) continue; - - if (call.status === CoreToolCallStatus.Executing) { - eventQueue.push({ type: 'tool_call_start', value: call.request }); - } else if ( - call.status === CoreToolCallStatus.Success || - call.status === CoreToolCallStatus.Error || - call.status === CoreToolCallStatus.Cancelled - ) { - eventQueue.push({ - type: 'tool_call_finish', - value: call.response, - }); - } - seenStatuses.set(call.request.callId, call.status); - } - resolveNext?.(); - }; - messageBus.subscribe(MessageBusType.TOOL_CALLS_UPDATE, onToolUpdate); const schedulePromise = this.scheduler.schedule( @@ -343,17 +364,7 @@ export class AgentSession { value: { responses: completedCalls.map((c) => c.response) }, }; - // Record tool call info for persistence/telemetry - try { - const currentModel = - this.client.getCurrentSequenceModel() ?? this.runtime.getModel(); - this.client - .getChat() - .recordCompletedToolCalls(currentModel, completedCalls); - await recordToolCallInteractions(this.runtime, completedCalls); - } catch (e) { - debugLogger.warn(`Error recording tool call information: ${e}`); - } + await this._recordTelemetry(completedCalls); const nextParts = completedCalls.flatMap((c) => c.response.responseParts); const stopExecutionInfo = completedCalls.find( @@ -370,6 +381,56 @@ export class AgentSession { } } + /** + * Creates a handler for MessageBus tool update events. + */ + private _createToolUpdateHandler( + eventQueue: AgentEvent[], + onNewEvents: () => void, + ) { + const seenStatuses = new Map(); + + return (message: ToolCallsUpdateMessage) => { + if (message.schedulerId !== this.schedulerId) return; + + for (const call of message.toolCalls) { + const prevStatus = seenStatuses.get(call.request.callId); + if (prevStatus === call.status) continue; + + if (call.status === CoreToolCallStatus.Executing) { + eventQueue.push({ type: 'tool_call_start', value: call.request }); + } else if ( + call.status === CoreToolCallStatus.Success || + call.status === CoreToolCallStatus.Error || + call.status === CoreToolCallStatus.Cancelled + ) { + eventQueue.push({ + type: 'tool_call_finish', + value: call.response, + }); + } + seenStatuses.set(call.request.callId, call.status); + } + onNewEvents(); + }; + } + + /** + * Records tool interaction telemetry and persistence data. + */ + private async _recordTelemetry(completedCalls: CompletedToolCall[]) { + try { + const currentModel = + this.client.getCurrentSequenceModel() ?? this.runtime.getModel(); + this.client + .getChat() + .recordCompletedToolCalls(currentModel, completedCalls); + await recordToolCallInteractions(this.runtime, completedCalls); + } catch (e) { + debugLogger.warn(`Error recording tool call information: ${e}`); + } + } + /** * Attempts to compress the chat history if thresholds are exceeded. */ @@ -405,11 +466,4 @@ export class AgentSession { getHistory() { return this.client.getHistory(); } - - /** - * Returns the current session ID. - */ - getSessionId(): string { - return this.sessionId; - } } diff --git a/packages/core/src/agents/types.ts b/packages/core/src/agents/types.ts index f9d6f907f0..a4bdde6711 100644 --- a/packages/core/src/agents/types.ts +++ b/packages/core/src/agents/types.ts @@ -16,29 +16,74 @@ import { type ToolCallRequestInfo, } from '../scheduler/types.js'; +/** Emitted when an agent session begins execution. */ +export interface AgentStartEvent { + type: 'agent_start'; + value: { sessionId: string }; +} + +/** Emitted when an agent session completes, providing termination details. */ +export interface AgentFinishEvent { + type: 'agent_finish'; + value: { + sessionId: string; + totalTurns: number; + reason: AgentTerminateMode; + message?: string; + error?: unknown; + }; +} + +/** Emitted when a group of tool calls is about to be executed. */ +export interface ToolSuiteStartEvent { + type: 'tool_suite_start'; + value: { count: number }; +} + +/** Emitted when a group of tool calls has finished executing. */ +export interface ToolSuiteFinishEvent { + type: 'tool_suite_finish'; + value: { responses: ToolCallResponseInfo[] }; +} + +/** Emitted when an individual tool call begins execution. */ +export interface ToolCallStartEvent { + type: 'tool_call_start'; + value: ToolCallRequestInfo; +} + +/** Emitted when an individual tool call has finished execution. */ +export interface ToolCallFinishEvent { + type: 'tool_call_finish'; + value: ToolCallResponseInfo; +} + +/** Emitted when the model generates internal reasoning or "thought" content. */ +export interface ThoughtEvent { + type: 'thought'; + value: string; +} + +/** Emitted when an infinite loop is detected in the model's tool calling patterns. */ +export interface LoopDetectedEvent { + type: 'loop_detected'; + value: { sessionId: string }; +} + /** * Unified event type for the Agent loop. * This extends the base Gemini stream events with higher-level agent lifecycle events. */ export type AgentEvent = | ServerGeminiStreamEvent - | { type: 'agent_start'; value: { sessionId: string } } - | { - type: 'agent_finish'; - value: { - sessionId: string; - totalTurns: number; - reason: AgentTerminateMode; - message?: string; - error?: unknown; - }; - } - | { type: 'tool_suite_start'; value: { count: number } } - | { type: 'tool_suite_finish'; value: { responses: ToolCallResponseInfo[] } } - | { type: 'tool_call_start'; value: ToolCallRequestInfo } - | { type: 'tool_call_finish'; value: ToolCallResponseInfo } - | { type: 'thought'; value: string } - | { type: 'loop_detected'; value: { sessionId: string } }; + | AgentStartEvent + | AgentFinishEvent + | ToolSuiteStartEvent + | ToolSuiteFinishEvent + | ToolCallStartEvent + | ToolCallFinishEvent + | ThoughtEvent + | LoopDetectedEvent; /** * Configuration for an Agent. @@ -47,7 +92,7 @@ export interface AgentConfig { /** The name of the agent. */ name: string; /** The system instruction (personality/rules) for the agent. */ - systemInstruction?: string; + systemInstruction: string; /** Optional override for the model to use. */ model?: string; /**