From b23bcc7ae5e4f8dd9261f47ff50b8fb2a0c666ca Mon Sep 17 00:00:00 2001 From: Abhi Date: Thu, 19 Feb 2026 16:12:53 -0500 Subject: [PATCH] WIP - draft for agent factory --- packages/cli/src/nonInteractiveCli.test.ts | 116 +++ packages/cli/src/nonInteractiveCli.ts | 742 +++++++++++------- .../cli/src/ui/commands/rewindCommand.tsx | 6 +- .../__snapshots__/MainContent.test.tsx.snap | 1 - .../src/ui/hooks/useSessionBrowser.test.ts | 30 +- .../cli/src/ui/hooks/useSessionBrowser.ts | 7 +- packages/cli/src/ui/hooks/useSessionResume.ts | 3 +- packages/cli/src/utils/sessionUtils.ts | 114 +-- .../cli/src/zed-integration/acpResume.test.ts | 14 +- .../cli/src/zed-integration/zedIntegration.ts | 10 +- packages/core/src/agents/agent.test.ts | 73 ++ packages/core/src/agents/agent.ts | 41 + packages/core/src/agents/session.test.ts | 271 +++++++ packages/core/src/agents/session.ts | 297 +++++++ packages/core/src/agents/types.ts | 55 +- packages/core/src/index.ts | 2 + packages/core/src/utils/sessionUtils.test.ts | 122 +++ packages/core/src/utils/sessionUtils.ts | 111 +++ 18 files changed, 1598 insertions(+), 417 deletions(-) create mode 100644 packages/core/src/agents/agent.test.ts create mode 100644 packages/core/src/agents/agent.ts create mode 100644 packages/core/src/agents/session.test.ts create mode 100644 packages/core/src/agents/session.ts create mode 100644 packages/core/src/utils/sessionUtils.test.ts create mode 100644 packages/core/src/utils/sessionUtils.ts diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index 206d011e63..3180bd7a1d 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -82,6 +82,19 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { stdout: process.stdout, stderr: process.stderr, })), + AgentSession: vi.fn(), + AgentTerminateMode: { + ERROR: 'ERROR', + TIMEOUT: 'TIMEOUT', + GOAL: 'GOAL', + MAX_TURNS: 'MAX_TURNS', + ABORTED: 'ABORTED', + }, + CoreToolCallStatus: { + Success: 'success', + Error: 'error', + Cancelled: 'cancelled', + }, }; }); @@ -190,6 +203,7 @@ describe('runNonInteractive', () => { isTrustedFolder: vi.fn().mockReturnValue(false), getRawOutput: vi.fn().mockReturnValue(false), getAcceptRawOutputRisk: vi.fn().mockReturnValue(false), + isAgentsEnabled: vi.fn().mockReturnValue(false), } as unknown as Config; mockSettings = { @@ -2231,4 +2245,106 @@ describe('runNonInteractive', () => { expect(output).toContain('"status":"success"'); }); }); + describe('runNonInteractive (AgentSession)', () => { + let mockAgentSession: { prompt: Mock; resume: Mock }; + + beforeEach(async () => { + vi.mocked(mockConfig.isAgentsEnabled).mockReturnValue(true); + + // Get the mocked AgentSession class to spy on instances + const core = await import('@google/gemini-cli-core'); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const MockAgentSessionClass = core.AgentSession as any; + + // Mock the prompt method logic + mockAgentSession = { + prompt: vi.fn(), + resume: vi.fn(), + }; + + // When new AgentSession() is called, return our mock instance + MockAgentSessionClass.mockImplementation(() => mockAgentSession); + }); + + it('should process input and write text output from agent events', async () => { + const events = [ + { type: GeminiEventType.Content, value: 'Hello' }, + { type: GeminiEventType.Content, value: ' World' }, + { + type: 'agent_finish', + value: { + reason: 'GOAL', + sessionId: 'test-session-id', + totalTurns: 1, + }, + }, + ]; + + async function* eventGenerator() { + for (const event of events) { + yield event; + } + } + + mockAgentSession.prompt.mockReturnValue(eventGenerator()); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-agent', + }); + + expect(mockAgentSession.prompt).toHaveBeenCalledWith( + [{ text: 'Test input' }], + expect.any(AbortSignal), + ); + expect(getWrittenOutput()).toBe('Hello World\n'); + }); + + it('should write JSON output with stats from agent events', async () => { + const events = [ + { type: GeminiEventType.Content, value: 'JSON Response' }, + { + type: 'agent_finish', + value: { + reason: 'GOAL', + sessionId: 'test-session-id', + totalTurns: 1, + }, + }, + ]; + + async function* eventGenerator() { + for (const event of events) { + yield event; + } + } + + mockAgentSession.prompt.mockReturnValue(eventGenerator()); + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-agent-json', + }); + + expect(processStdoutSpy).toHaveBeenCalledWith( + JSON.stringify( + { + session_id: 'test-session-id', + response: 'JSON Response', + stats: MOCK_SESSION_METRICS, + }, + null, + 2, + ), + ); + }); + }); }); diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 44af6bc81e..986d79df5b 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -24,18 +24,20 @@ import { debugLogger, coreEvents, CoreEvent, - createWorkingStdio, recordToolCallInteractions, ToolErrorType, Scheduler, ROOT_SCHEDULER_ID, + convertSessionToClientHistory, + createWorkingStdio, + AgentSession, + AgentTerminateMode, } from '@google/gemini-cli-core'; import type { Content, Part } from '@google/genai'; import readline from 'node:readline'; import stripAnsi from 'strip-ansi'; -import { convertSessionToHistoryFormats } from './ui/hooks/useSessionBrowser.js'; import { handleSlashCommand } from './nonInteractiveCliCommands.js'; import { ConsolePatcher } from './ui/utils/ConsolePatcher.js'; import { handleAtCommand } from './ui/hooks/atCommandProcessor.js'; @@ -55,6 +57,16 @@ interface RunNonInteractiveParams { resumedSessionData?: ResumedSessionData; } +interface LoopContext { + consolePatcher: ConsolePatcher; + textOutput: TextOutput; + abortController: AbortController; + streamFormatter: StreamJsonFormatter | null; + startTime: number; + setupStdinCancellation: () => void; + cleanupStdinCancellation: () => void; +} + export async function runNonInteractive({ config, settings, @@ -209,36 +221,20 @@ export async function runNonInteractive({ } }); - const geminiClient = config.getGeminiClient(); - const scheduler = new Scheduler({ - config, - messageBus: config.getMessageBus(), - getPreferredEditor: () => undefined, - schedulerId: ROOT_SCHEDULER_ID, - }); - - // Initialize chat. Resume if resume data is passed. - if (resumedSessionData) { - await geminiClient.resumeChat( - convertSessionToHistoryFormats( - resumedSessionData.conversation.messages, - ).clientHistory, - resumedSessionData, - ); - } - - // Emit init event for streaming JSON - if (streamFormatter) { - streamFormatter.emitEvent({ - type: JsonStreamEventType.INIT, - timestamp: new Date().toISOString(), - session_id: config.getSessionId(), - model: config.getModel(), - }); - } + const loopContext: LoopContext = { + consolePatcher, + textOutput, + abortController, + streamFormatter, + startTime, + setupStdinCancellation, + cleanupStdinCancellation, + }; + // --- Input Processing (Lifted) --- let query: Part[] | undefined; + // 1. Slash Commands if (isSlashCommand(input)) { const slashCommandResult = await handleSlashCommand( input, @@ -247,14 +243,13 @@ export async function runNonInteractive({ settings, ); // If a slash command is found and returns a prompt, use it. - // Otherwise, slashCommandResult falls through to the default prompt - // handling. if (slashCommandResult) { // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion query = slashCommandResult as Part[]; } } + // 2. @ Commands (if query not set by slash command) if (!query) { const { processedQuery, error } = await handleAtCommand({ query: input, @@ -266,8 +261,6 @@ export async function runNonInteractive({ }); if (error || !processedQuery) { - // An error occurred during @include processing (e.g., file not found). - // The error message is already logged by handleAtCommand. throw new FatalInputError( error || 'Exiting due to an error processing the @ command.', ); @@ -276,246 +269,18 @@ export async function runNonInteractive({ query = processedQuery as Part[]; } - // Emit user message event for streaming JSON - if (streamFormatter) { - streamFormatter.emitEvent({ - type: JsonStreamEventType.MESSAGE, - timestamp: new Date().toISOString(), - role: 'user', - content: input, - }); - } - - let currentMessages: Content[] = [{ role: 'user', parts: query }]; - - let turnCount = 0; - while (true) { - turnCount++; - if ( - config.getMaxSessionTurns() >= 0 && - turnCount > config.getMaxSessionTurns() - ) { - handleMaxTurnsExceededError(config); - } - const toolCallRequests: ToolCallRequestInfo[] = []; - - const responseStream = geminiClient.sendMessageStream( - currentMessages[0]?.parts || [], - abortController.signal, - prompt_id, - undefined, - false, - turnCount === 1 ? input : undefined, + if (config.isAgentsEnabled()) { + await runAgentSessionFlow( + loopContext, + { config, settings, input, prompt_id, resumedSessionData, query }, // API change: pass query + handleUserFeedback, + ); + } else { + await runLegacyManualLoop( + loopContext, + { config, settings, input, prompt_id, resumedSessionData, query }, // API change: pass query + handleUserFeedback, ); - - let responseText = ''; - for await (const event of responseStream) { - if (abortController.signal.aborted) { - handleCancellationError(config); - } - - if (event.type === GeminiEventType.Content) { - const isRaw = - config.getRawOutput() || config.getAcceptRawOutputRisk(); - const output = isRaw ? event.value : stripAnsi(event.value); - if (streamFormatter) { - streamFormatter.emitEvent({ - type: JsonStreamEventType.MESSAGE, - timestamp: new Date().toISOString(), - role: 'assistant', - content: output, - delta: true, - }); - } else if (config.getOutputFormat() === OutputFormat.JSON) { - responseText += output; - } else { - if (event.value) { - textOutput.write(output); - } - } - } else if (event.type === GeminiEventType.ToolCallRequest) { - if (streamFormatter) { - streamFormatter.emitEvent({ - type: JsonStreamEventType.TOOL_USE, - timestamp: new Date().toISOString(), - tool_name: event.value.name, - tool_id: event.value.callId, - parameters: event.value.args, - }); - } - toolCallRequests.push(event.value); - } else if (event.type === GeminiEventType.LoopDetected) { - if (streamFormatter) { - streamFormatter.emitEvent({ - type: JsonStreamEventType.ERROR, - timestamp: new Date().toISOString(), - severity: 'warning', - message: 'Loop detected, stopping execution', - }); - } - } else if (event.type === GeminiEventType.MaxSessionTurns) { - if (streamFormatter) { - streamFormatter.emitEvent({ - type: JsonStreamEventType.ERROR, - timestamp: new Date().toISOString(), - severity: 'error', - message: 'Maximum session turns exceeded', - }); - } - } else if (event.type === GeminiEventType.Error) { - throw event.value.error; - } else if (event.type === GeminiEventType.AgentExecutionStopped) { - const stopMessage = `Agent execution stopped: ${event.value.systemMessage?.trim() || event.value.reason}`; - if (config.getOutputFormat() === OutputFormat.TEXT) { - process.stderr.write(`${stopMessage}\n`); - } - // Emit final result event for streaming JSON if needed - if (streamFormatter) { - const metrics = uiTelemetryService.getMetrics(); - const durationMs = Date.now() - startTime; - streamFormatter.emitEvent({ - type: JsonStreamEventType.RESULT, - timestamp: new Date().toISOString(), - status: 'success', - stats: streamFormatter.convertToStreamStats( - metrics, - durationMs, - ), - }); - } - return; - } else if (event.type === GeminiEventType.AgentExecutionBlocked) { - const blockMessage = `Agent execution blocked: ${event.value.systemMessage?.trim() || event.value.reason}`; - if (config.getOutputFormat() === OutputFormat.TEXT) { - process.stderr.write(`[WARNING] ${blockMessage}\n`); - } - } - } - - if (toolCallRequests.length > 0) { - textOutput.ensureTrailingNewline(); - const completedToolCalls = await scheduler.schedule( - toolCallRequests, - abortController.signal, - ); - const toolResponseParts: Part[] = []; - - for (const completedToolCall of completedToolCalls) { - const toolResponse = completedToolCall.response; - const requestInfo = completedToolCall.request; - - if (streamFormatter) { - streamFormatter.emitEvent({ - type: JsonStreamEventType.TOOL_RESULT, - timestamp: new Date().toISOString(), - tool_id: requestInfo.callId, - status: - completedToolCall.status === 'error' ? 'error' : 'success', - output: - typeof toolResponse.resultDisplay === 'string' - ? toolResponse.resultDisplay - : undefined, - error: toolResponse.error - ? { - type: toolResponse.errorType || 'TOOL_EXECUTION_ERROR', - message: toolResponse.error.message, - } - : undefined, - }); - } - - if (toolResponse.error) { - handleToolError( - requestInfo.name, - toolResponse.error, - config, - toolResponse.errorType || 'TOOL_EXECUTION_ERROR', - typeof toolResponse.resultDisplay === 'string' - ? toolResponse.resultDisplay - : undefined, - ); - } - - if (toolResponse.responseParts) { - toolResponseParts.push(...toolResponse.responseParts); - } - } - - // Record tool calls with full metadata before sending responses to Gemini - try { - const currentModel = - geminiClient.getCurrentSequenceModel() ?? config.getModel(); - geminiClient - .getChat() - .recordCompletedToolCalls(currentModel, completedToolCalls); - - await recordToolCallInteractions(config, completedToolCalls); - } catch (error) { - debugLogger.error( - `Error recording completed tool call information: ${error}`, - ); - } - - // Check if any tool requested to stop execution immediately - const stopExecutionTool = completedToolCalls.find( - (tc) => tc.response.errorType === ToolErrorType.STOP_EXECUTION, - ); - - if (stopExecutionTool && stopExecutionTool.response.error) { - const stopMessage = `Agent execution stopped: ${stopExecutionTool.response.error.message}`; - - if (config.getOutputFormat() === OutputFormat.TEXT) { - process.stderr.write(`${stopMessage}\n`); - } - - // Emit final result event for streaming JSON - if (streamFormatter) { - const metrics = uiTelemetryService.getMetrics(); - const durationMs = Date.now() - startTime; - streamFormatter.emitEvent({ - type: JsonStreamEventType.RESULT, - timestamp: new Date().toISOString(), - status: 'success', - stats: streamFormatter.convertToStreamStats( - metrics, - durationMs, - ), - }); - } else if (config.getOutputFormat() === OutputFormat.JSON) { - const formatter = new JsonFormatter(); - const stats = uiTelemetryService.getMetrics(); - textOutput.write( - formatter.format(config.getSessionId(), responseText, stats), - ); - } else { - textOutput.ensureTrailingNewline(); // Ensure a final newline - } - return; - } - - currentMessages = [{ role: 'user', parts: toolResponseParts }]; - } else { - // Emit final result event for streaming JSON - if (streamFormatter) { - const metrics = uiTelemetryService.getMetrics(); - const durationMs = Date.now() - startTime; - streamFormatter.emitEvent({ - type: JsonStreamEventType.RESULT, - timestamp: new Date().toISOString(), - status: 'success', - stats: streamFormatter.convertToStreamStats(metrics, durationMs), - }); - } else if (config.getOutputFormat() === OutputFormat.JSON) { - const formatter = new JsonFormatter(); - const stats = uiTelemetryService.getMetrics(); - textOutput.write( - formatter.format(config.getSessionId(), responseText, stats), - ); - } else { - textOutput.ensureTrailingNewline(); // Ensure a final newline - } - return; - } } } catch (error) { errorToHandle = error; @@ -532,3 +297,436 @@ export async function runNonInteractive({ } }); } + +async function runAgentSessionFlow( + { textOutput, abortController, streamFormatter, startTime }: LoopContext, + { + config, + settings: _settings, + input, + prompt_id: _prompt_id, + resumedSessionData, + query, + }: RunNonInteractiveParams & { query: Part[] }, + _handleUserFeedback: (payload: UserFeedbackPayload) => void, +) { + const session = new AgentSession( + config.getSessionId(), + { + name: 'cli-agent', + maxTurns: config.getMaxSessionTurns(), + }, + config, + ); + + if (resumedSessionData) { + await session.resume(resumedSessionData); + } + + let finalResponseText = ''; + + // Handle initialization for stream JSON format + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.INIT, + timestamp: new Date().toISOString(), + session_id: config.getSessionId(), + model: config.getModel(), + }); + } + + // NOTE: Input processing (Slash commands, @ commands) is now handled in `runNonInteractive` + // and passed in via `query`. + + // Emit user message event for streaming JSON + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.MESSAGE, + timestamp: new Date().toISOString(), + role: 'user', + content: input, + }); + } + + // Start Agent Loop + const stream = session.prompt(query, abortController.signal); + + for await (const event of stream) { + if (event.type === GeminiEventType.Content) { + const isRaw = config.getRawOutput() || config.getAcceptRawOutputRisk(); + const output = isRaw ? event.value : stripAnsi(event.value); + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.MESSAGE, + timestamp: new Date().toISOString(), + role: 'assistant', + content: output, + delta: true, + }); + } else if (config.getOutputFormat() === OutputFormat.JSON) { + finalResponseText += output; + } else { + if (event.value) { + textOutput.write(output); + } + } + } else if (event.type === GeminiEventType.ToolCallRequest) { + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.TOOL_USE, + timestamp: new Date().toISOString(), + tool_name: event.value.name, + tool_id: event.value.callId, + parameters: event.value.args, + }); + } + } else if (event.type === 'tool_suite_finish') { + // Replicates the "TOOL_RESULT" emission from legacy loop + // The legacy loop emits this *after* execution. + // AgentSession emits 'tool_suite_finish' after execution. + for (const response of event.value.responses) { + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.TOOL_RESULT, + timestamp: new Date().toISOString(), + tool_id: response.callId, + status: response.error ? 'error' : 'success', + output: + typeof response.resultDisplay === 'string' + ? response.resultDisplay + : undefined, + error: response.error + ? { + type: response.errorType || 'TOOL_EXECUTION_ERROR', + message: response.error.message, + } + : undefined, + }); + } + // Handle explicit error printing for TEXT mode + if (response.error && config.getOutputFormat() === OutputFormat.TEXT) { + handleToolError( + response.callId, // Using callId as name fallback since name is not available in response + new Error(response.error.message), + config, + response.errorType || 'TOOL_EXECUTION_ERROR', + ); + } + } + } else if (event.type === 'agent_finish') { + const { reason, message, error: _error } = event.value; + + if (reason === AgentTerminateMode.MAX_TURNS) { + handleMaxTurnsExceededError(config); + } else if ( + reason === AgentTerminateMode.ERROR || + reason === AgentTerminateMode.ABORTED + ) { + if (config.getOutputFormat() === OutputFormat.TEXT && message) { + process.stderr.write(`Agent execution stopped: ${message}\n`); + } + } + + // Emit Final JSON + if (streamFormatter) { + const metrics = uiTelemetryService.getMetrics(); + const durationMs = Date.now() - startTime; + streamFormatter.emitEvent({ + type: JsonStreamEventType.RESULT, + timestamp: new Date().toISOString(), + status: reason === AgentTerminateMode.ERROR ? 'error' : 'success', + stats: streamFormatter.convertToStreamStats(metrics, durationMs), + }); + } else if (config.getOutputFormat() === OutputFormat.JSON) { + const formatter = new JsonFormatter(); + const stats = uiTelemetryService.getMetrics(); + textOutput.write( + formatter.format(config.getSessionId(), finalResponseText, stats), + ); + } else { + textOutput.ensureTrailingNewline(); + } + } + } +} + +async function runLegacyManualLoop( + { + textOutput, + abortController, + streamFormatter, + startTime, + cleanupStdinCancellation: _cleanupStdinCancellation, + }: LoopContext, + { + config, + settings: _settings, + input, + prompt_id, + resumedSessionData, + query, + }: RunNonInteractiveParams & { query: Part[] }, + _handleUserFeedback: (payload: UserFeedbackPayload) => void, +) { + const geminiClient = config.getGeminiClient(); + const scheduler = new Scheduler({ + config, + messageBus: config.getMessageBus(), + getPreferredEditor: () => undefined, + schedulerId: ROOT_SCHEDULER_ID, + }); + + // Initialize chat. Resume if resume data is passed. + if (resumedSessionData) { + await geminiClient.resumeChat( + convertSessionToClientHistory(resumedSessionData.conversation.messages), + resumedSessionData, + ); + } + + // Emit init event for streaming JSON + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.INIT, + timestamp: new Date().toISOString(), + session_id: config.getSessionId(), + model: config.getModel(), + }); + } + + // NOTE: Input processing now handled upstream in `runNonInteractive` + + // Emit user message event for streaming JSON + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.MESSAGE, + timestamp: new Date().toISOString(), + role: 'user', + content: input, + }); + } + + let currentMessages: Content[] = [{ role: 'user', parts: query }]; + + let turnCount = 0; + while (true) { + turnCount++; + if ( + config.getMaxSessionTurns() >= 0 && + turnCount > config.getMaxSessionTurns() + ) { + handleMaxTurnsExceededError(config); + } + const toolCallRequests: ToolCallRequestInfo[] = []; + + const responseStream = geminiClient.sendMessageStream( + currentMessages[0]?.parts || [], + abortController.signal, + prompt_id, + undefined, + false, + turnCount === 1 ? input : undefined, + ); + + let responseText = ''; + for await (const event of responseStream) { + if (abortController.signal.aborted) { + handleCancellationError(config); + } + + if (event.type === GeminiEventType.Content) { + const isRaw = config.getRawOutput() || config.getAcceptRawOutputRisk(); + const output = isRaw ? event.value : stripAnsi(event.value); + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.MESSAGE, + timestamp: new Date().toISOString(), + role: 'assistant', + content: output, + delta: true, + }); + } else if (config.getOutputFormat() === OutputFormat.JSON) { + responseText += output; + } else { + if (event.value) { + textOutput.write(output); + } + } + } else if (event.type === GeminiEventType.ToolCallRequest) { + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.TOOL_USE, + timestamp: new Date().toISOString(), + tool_name: event.value.name, + tool_id: event.value.callId, + parameters: event.value.args, + }); + } + toolCallRequests.push(event.value); + } else if (event.type === GeminiEventType.LoopDetected) { + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.ERROR, + timestamp: new Date().toISOString(), + severity: 'warning', + message: 'Loop detected, stopping execution', + }); + } + } else if (event.type === GeminiEventType.MaxSessionTurns) { + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.ERROR, + timestamp: new Date().toISOString(), + severity: 'error', + message: 'Maximum session turns exceeded', + }); + } + } else if (event.type === GeminiEventType.Error) { + throw event.value.error; + } else if (event.type === GeminiEventType.AgentExecutionStopped) { + const stopMessage = `Agent execution stopped: ${event.value.systemMessage?.trim() || event.value.reason}`; + if (config.getOutputFormat() === OutputFormat.TEXT) { + process.stderr.write(`${stopMessage}\n`); + } + // Emit final result event for streaming JSON if needed + if (streamFormatter) { + const metrics = uiTelemetryService.getMetrics(); + const durationMs = Date.now() - startTime; + streamFormatter.emitEvent({ + type: JsonStreamEventType.RESULT, + timestamp: new Date().toISOString(), + status: 'success', + stats: streamFormatter.convertToStreamStats(metrics, durationMs), + }); + } + return; + } else if (event.type === GeminiEventType.AgentExecutionBlocked) { + const blockMessage = `Agent execution blocked: ${event.value.systemMessage?.trim() || event.value.reason}`; + if (config.getOutputFormat() === OutputFormat.TEXT) { + process.stderr.write(`[WARNING] ${blockMessage}\n`); + } + } + } + + if (toolCallRequests.length > 0) { + textOutput.ensureTrailingNewline(); + const completedToolCalls = await scheduler.schedule( + toolCallRequests, + abortController.signal, + ); + const toolResponseParts: Part[] = []; + + for (const completedToolCall of completedToolCalls) { + const toolResponse = completedToolCall.response; + const requestInfo = completedToolCall.request; + + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.TOOL_RESULT, + timestamp: new Date().toISOString(), + tool_id: requestInfo.callId, + status: completedToolCall.status === 'error' ? 'error' : 'success', + output: + typeof toolResponse.resultDisplay === 'string' + ? toolResponse.resultDisplay + : undefined, + error: toolResponse.error + ? { + type: toolResponse.errorType || 'TOOL_EXECUTION_ERROR', + message: toolResponse.error.message, + } + : undefined, + }); + } + + if (toolResponse.error) { + handleToolError( + requestInfo.name, + toolResponse.error, + config, + toolResponse.errorType || 'TOOL_EXECUTION_ERROR', + typeof toolResponse.resultDisplay === 'string' + ? toolResponse.resultDisplay + : undefined, + ); + } + + if (toolResponse.responseParts) { + toolResponseParts.push(...toolResponse.responseParts); + } + } + + // Record tool calls with full metadata before sending responses to Gemini + try { + const currentModel = + geminiClient.getCurrentSequenceModel() ?? config.getModel(); + geminiClient + .getChat() + .recordCompletedToolCalls(currentModel, completedToolCalls); + + await recordToolCallInteractions(config, completedToolCalls); + } catch (error) { + debugLogger.error( + `Error recording completed tool call information: ${error}`, + ); + } + + // Check if any tool requested to stop execution immediately + const stopExecutionTool = completedToolCalls.find( + (tc) => tc.response.errorType === ToolErrorType.STOP_EXECUTION, + ); + + if (stopExecutionTool && stopExecutionTool.response.error) { + const stopMessage = `Agent execution stopped: ${stopExecutionTool.response.error.message}`; + + if (config.getOutputFormat() === OutputFormat.TEXT) { + process.stderr.write(`${stopMessage}\n`); + } + + // Emit final result event for streaming JSON + if (streamFormatter) { + const metrics = uiTelemetryService.getMetrics(); + const durationMs = Date.now() - startTime; + streamFormatter.emitEvent({ + type: JsonStreamEventType.RESULT, + timestamp: new Date().toISOString(), + status: 'success', + stats: streamFormatter.convertToStreamStats(metrics, durationMs), + }); + } else if (config.getOutputFormat() === OutputFormat.JSON) { + const formatter = new JsonFormatter(); + const stats = uiTelemetryService.getMetrics(); + textOutput.write( + formatter.format(config.getSessionId(), responseText, stats), + ); + } else { + textOutput.ensureTrailingNewline(); // Ensure a final newline + } + return; + } + + currentMessages = [{ role: 'user', parts: toolResponseParts }]; + } else { + // Emit final result event for streaming JSON + if (streamFormatter) { + const metrics = uiTelemetryService.getMetrics(); + const durationMs = Date.now() - startTime; + streamFormatter.emitEvent({ + type: JsonStreamEventType.RESULT, + timestamp: new Date().toISOString(), + status: 'success', + stats: streamFormatter.convertToStreamStats(metrics, durationMs), + }); + } else if (config.getOutputFormat() === OutputFormat.JSON) { + const formatter = new JsonFormatter(); + const stats = uiTelemetryService.getMetrics(); + textOutput.write( + formatter.format(config.getSessionId(), responseText, stats), + ); + } else { + textOutput.ensureTrailingNewline(); // Ensure a final newline + } + return; + } + } +} diff --git a/packages/cli/src/ui/commands/rewindCommand.tsx b/packages/cli/src/ui/commands/rewindCommand.tsx index d405172661..c4af3e845d 100644 --- a/packages/cli/src/ui/commands/rewindCommand.tsx +++ b/packages/cli/src/ui/commands/rewindCommand.tsx @@ -23,6 +23,7 @@ import { RewindEvent, type ChatRecordingService, type GeminiClient, + convertSessionToClientHistory, } from '@google/gemini-cli-core'; /** @@ -54,9 +55,8 @@ async function rewindConversation( } // Convert to UI and Client formats - const { uiHistory, clientHistory } = convertSessionToHistoryFormats( - conversation.messages, - ); + const { uiHistory } = convertSessionToHistoryFormats(conversation.messages); + const clientHistory = convertSessionToClientHistory(conversation.messages); client.setHistory(clientHistory as Content[]); diff --git a/packages/cli/src/ui/components/__snapshots__/MainContent.test.tsx.snap b/packages/cli/src/ui/components/__snapshots__/MainContent.test.tsx.snap index 29eda04bab..8c10a69589 100644 --- a/packages/cli/src/ui/components/__snapshots__/MainContent.test.tsx.snap +++ b/packages/cli/src/ui/components/__snapshots__/MainContent.test.tsx.snap @@ -26,7 +26,6 @@ AppHeader(full) │ Line 18 │ │ Line 19 │ │ Line 20 │ -│ │ ╰──────────────────────────────────────────────────────────────────────────────────────────────╯ ShowMoreLines " diff --git a/packages/cli/src/ui/hooks/useSessionBrowser.test.ts b/packages/cli/src/ui/hooks/useSessionBrowser.test.ts index d6326bdfef..ceff3e9c8c 100644 --- a/packages/cli/src/ui/hooks/useSessionBrowser.test.ts +++ b/packages/cli/src/ui/hooks/useSessionBrowser.test.ts @@ -20,7 +20,10 @@ import { type MessageRecord, CoreToolCallStatus, } from '@google/gemini-cli-core'; -import { coreEvents } from '@google/gemini-cli-core'; +import { + coreEvents, + convertSessionToClientHistory, +} from '@google/gemini-cli-core'; // Mock modules vi.mock('fs/promises'); @@ -157,7 +160,7 @@ describe('convertSessionToHistoryFormats', () => { it('should convert empty messages array', () => { const result = convertSessionToHistoryFormats([]); expect(result.uiHistory).toEqual([]); - expect(result.clientHistory).toEqual([]); + expect(convertSessionToClientHistory([])).toEqual([]); }); it('should convert basic user and model messages', () => { @@ -175,12 +178,13 @@ describe('convertSessionToHistoryFormats', () => { text: 'Hi there', }); - expect(result.clientHistory).toHaveLength(2); - expect(result.clientHistory[0]).toEqual({ + const clientHistory = convertSessionToClientHistory(messages); + expect(clientHistory).toHaveLength(2); + expect(clientHistory[0]).toEqual({ role: 'user', parts: [{ text: 'Hello' }], }); - expect(result.clientHistory[1]).toEqual({ + expect(clientHistory[1]).toEqual({ role: 'model', parts: [{ text: 'Hi there' }], }); @@ -203,8 +207,9 @@ describe('convertSessionToHistoryFormats', () => { text: 'User input', }); - expect(result.clientHistory).toHaveLength(1); - expect(result.clientHistory[0]).toEqual({ + const clientHistory = convertSessionToClientHistory(messages); + expect(clientHistory).toHaveLength(1); + expect(clientHistory[0]).toEqual({ role: 'user', parts: [{ text: 'Expanded content' }], }); @@ -225,7 +230,7 @@ describe('convertSessionToHistoryFormats', () => { text: 'Help text', }); - expect(result.clientHistory).toHaveLength(0); + expect(convertSessionToClientHistory(messages)).toHaveLength(0); }); it('should handle tool calls and responses', () => { @@ -264,12 +269,13 @@ describe('convertSessionToHistoryFormats', () => { ], }); - expect(result.clientHistory).toHaveLength(3); // User, Model (call), User (response) - expect(result.clientHistory[0]).toEqual({ + const clientHistory = convertSessionToClientHistory(messages); + expect(clientHistory).toHaveLength(3); // User, Model (call), User (response) + expect(clientHistory[0]).toEqual({ role: 'user', parts: [{ text: 'What time is it?' }], }); - expect(result.clientHistory[1]).toEqual({ + expect(clientHistory[1]).toEqual({ role: 'model', parts: [ { @@ -281,7 +287,7 @@ describe('convertSessionToHistoryFormats', () => { }, ], }); - expect(result.clientHistory[2]).toEqual({ + expect(clientHistory[2]).toEqual({ role: 'user', parts: [ { diff --git a/packages/cli/src/ui/hooks/useSessionBrowser.ts b/packages/cli/src/ui/hooks/useSessionBrowser.ts index de6495c3b9..c2a78e0c64 100644 --- a/packages/cli/src/ui/hooks/useSessionBrowser.ts +++ b/packages/cli/src/ui/hooks/useSessionBrowser.ts @@ -13,7 +13,10 @@ import type { ConversationRecord, ResumedSessionData, } from '@google/gemini-cli-core'; -import { coreEvents } from '@google/gemini-cli-core'; +import { + coreEvents, + convertSessionToClientHistory, +} from '@google/gemini-cli-core'; import type { SessionInfo } from '../../utils/sessionUtils.js'; import { convertSessionToHistoryFormats } from '../../utils/sessionUtils.js'; import type { Part } from '@google/genai'; @@ -77,7 +80,7 @@ export const useSessionBrowser = ( ); await onLoadHistory( historyData.uiHistory, - historyData.clientHistory, + convertSessionToClientHistory(conversation.messages), resumedSessionData, ); } catch (error) { diff --git a/packages/cli/src/ui/hooks/useSessionResume.ts b/packages/cli/src/ui/hooks/useSessionResume.ts index 9889c4bd12..055686773b 100644 --- a/packages/cli/src/ui/hooks/useSessionResume.ts +++ b/packages/cli/src/ui/hooks/useSessionResume.ts @@ -9,6 +9,7 @@ import { coreEvents, type Config, type ResumedSessionData, + convertSessionToClientHistory, } from '@google/gemini-cli-core'; import type { Part } from '@google/genai'; import type { HistoryItemWithoutId } from '../types.js'; @@ -113,7 +114,7 @@ export function useSessionResume({ ); void loadHistoryForResume( historyData.uiHistory, - historyData.clientHistory, + convertSessionToClientHistory(resumedSessionData.conversation.messages), resumedSessionData, ); } diff --git a/packages/cli/src/utils/sessionUtils.ts b/packages/cli/src/utils/sessionUtils.ts index bc2cfbc0e2..18a53b048b 100644 --- a/packages/cli/src/utils/sessionUtils.ts +++ b/packages/cli/src/utils/sessionUtils.ts @@ -16,7 +16,6 @@ import { import * as fs from 'node:fs/promises'; import path from 'node:path'; import { stripUnsafeCharacters } from '../ui/utils/textUtils.js'; -import type { Part } from '@google/genai'; import { MessageType, type HistoryItemWithoutId } from '../ui/types.js'; /** @@ -518,13 +517,12 @@ export class SessionSelector { } /** - * Converts session/conversation data into UI history and Gemini client history formats. + * Converts session/conversation data into UI history format. */ export function convertSessionToHistoryFormats( messages: ConversationRecord['messages'], ): { uiHistory: HistoryItemWithoutId[]; - clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }>; } { const uiHistory: HistoryItemWithoutId[] = []; @@ -591,117 +589,7 @@ export function convertSessionToHistoryFormats( } } - // Convert to Gemini client history format - const clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }> = []; - - for (const msg of messages) { - // Skip system/error messages and user slash commands - if (msg.type === 'info' || msg.type === 'error' || msg.type === 'warning') { - continue; - } - - if (msg.type === 'user') { - // Skip user slash commands - const contentString = partListUnionToString(msg.content); - if ( - contentString.trim().startsWith('/') || - contentString.trim().startsWith('?') - ) { - continue; - } - - // Add regular user message - clientHistory.push({ - role: 'user', - parts: Array.isArray(msg.content) - ? // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - (msg.content as Part[]) - : [{ text: contentString }], - }); - } else if (msg.type === 'gemini') { - // Handle Gemini messages with potential tool calls - const hasToolCalls = msg.toolCalls && msg.toolCalls.length > 0; - - if (hasToolCalls) { - // Create model message with function calls - const modelParts: Part[] = []; - - // Add text content if present - const contentString = partListUnionToString(msg.content); - if (msg.content && contentString.trim()) { - modelParts.push({ text: contentString }); - } - - // Add function calls - for (const toolCall of msg.toolCalls!) { - modelParts.push({ - functionCall: { - name: toolCall.name, - args: toolCall.args, - ...(toolCall.id && { id: toolCall.id }), - }, - }); - } - - clientHistory.push({ - role: 'model', - parts: modelParts, - }); - - // Create single function response message with all tool call responses - const functionResponseParts: Part[] = []; - for (const toolCall of msg.toolCalls!) { - if (toolCall.result) { - // Convert PartListUnion result to function response format - let responseData: Part; - - if (typeof toolCall.result === 'string') { - responseData = { - functionResponse: { - id: toolCall.id, - name: toolCall.name, - response: { - output: toolCall.result, - }, - }, - }; - } else if (Array.isArray(toolCall.result)) { - // toolCall.result is an array containing properly formatted - // function responses - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - functionResponseParts.push(...(toolCall.result as Part[])); - continue; - } else { - // Fallback for non-array results - responseData = toolCall.result; - } - - functionResponseParts.push(responseData); - } - } - - // Only add user message if we have function responses - if (functionResponseParts.length > 0) { - clientHistory.push({ - role: 'user', - parts: functionResponseParts, - }); - } - } else { - // Regular Gemini message without tool calls - const contentString = partListUnionToString(msg.content); - if (msg.content && contentString.trim()) { - clientHistory.push({ - role: 'model', - parts: [{ text: contentString }], - }); - } - } - } - } - return { uiHistory, - clientHistory, }; } diff --git a/packages/cli/src/zed-integration/acpResume.test.ts b/packages/cli/src/zed-integration/acpResume.test.ts index 869c27bf52..f814a9e586 100644 --- a/packages/cli/src/zed-integration/acpResume.test.ts +++ b/packages/cli/src/zed-integration/acpResume.test.ts @@ -26,6 +26,7 @@ import { SessionSelector, convertSessionToHistoryFormats, } from '../utils/sessionUtils.js'; +import { convertSessionToClientHistory } from '@google/gemini-cli-core'; import type { LoadedSettings } from '../config/settings.js'; vi.mock('../config/config.js', () => ({ @@ -42,6 +43,15 @@ vi.mock('../utils/sessionUtils.js', async (importOriginal) => { }; }); +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + convertSessionToClientHistory: vi.fn(), + }; +}); + describe('GeminiAgent Session Resume', () => { let mockConfig: Mocked; let mockSettings: Mocked; @@ -142,9 +152,11 @@ describe('GeminiAgent Session Resume', () => { { role: 'model', parts: [{ text: 'Hi there' }] }, ]; (convertSessionToHistoryFormats as unknown as Mock).mockReturnValue({ - clientHistory: mockClientHistory, uiHistory: [], }); + (convertSessionToClientHistory as unknown as Mock).mockReturnValue( + mockClientHistory, + ); const response = await agent.loadSession({ sessionId, diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 44b1890ce2..3360b64513 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -37,6 +37,7 @@ import { partListUnionToString, LlmRole, ApprovalMode, + convertSessionToClientHistory, } from '@google/gemini-cli-core'; import * as acp from '@agentclientprotocol/sdk'; import { AcpFileSystemService } from './fileSystemService.js'; @@ -53,10 +54,7 @@ import { randomUUID } from 'node:crypto'; import type { CliArgs } from '../config/config.js'; import { loadCliConfig } from '../config/config.js'; import { runExitCleanup } from '../utils/cleanup.js'; -import { - SessionSelector, - convertSessionToHistoryFormats, -} from '../utils/sessionUtils.js'; +import { SessionSelector } from '../utils/sessionUtils.js'; export async function runZedIntegration( config: Config, @@ -258,9 +256,7 @@ export class GeminiAgent { config.setFileSystemService(acpFileSystemService); } - const { clientHistory } = convertSessionToHistoryFormats( - sessionData.messages, - ); + const clientHistory = convertSessionToClientHistory(sessionData.messages); const geminiClient = config.getGeminiClient(); await geminiClient.initialize(); diff --git a/packages/core/src/agents/agent.test.ts b/packages/core/src/agents/agent.test.ts new file mode 100644 index 0000000000..4c4de77da5 --- /dev/null +++ b/packages/core/src/agents/agent.test.ts @@ -0,0 +1,73 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { Agent } from './agent.js'; +import { AgentSession } from './session.js'; +import { makeFakeConfig } from '../test-utils/config.js'; +import { type AgentConfig } from './types.js'; + +vi.mock('./session.js', () => ({ + AgentSession: vi.fn().mockImplementation(() => ({ + prompt: vi.fn().mockImplementation(async function* () { + yield { type: 'agent_start', value: { sessionId: 'test-session' } }; + yield { + type: 'agent_finish', + value: { sessionId: 'test-session', totalTurns: 1 }, + }; + }), + })), +})); + +describe('Agent', () => { + let mockConfig: ReturnType; + const agentConfig: AgentConfig = { + name: 'TestAgent', + systemInstruction: 'You are a test agent.', + }; + + beforeEach(() => { + vi.clearAllMocks(); + mockConfig = makeFakeConfig(); + vi.spyOn(mockConfig, 'getSessionId').mockReturnValue('global-session-id'); + }); + + it('should create an AgentSession', () => { + const agent = new Agent(agentConfig, mockConfig); + const session = agent.createSession('custom-session-id'); + + expect(session).toBeDefined(); + expect(AgentSession).toHaveBeenCalledWith( + 'custom-session-id', + agentConfig, + mockConfig, + ); + }); + + it('should use global session ID if none provided to createSession', () => { + const agent = new Agent(agentConfig, mockConfig); + agent.createSession(); + + expect(AgentSession).toHaveBeenCalledWith( + 'global-session-id', + agentConfig, + mockConfig, + ); + }); + + it('should prompt through a new session', async () => { + const agent = new Agent(agentConfig, mockConfig); + const events = []; + for await (const event of agent.prompt('Hello')) { + events.push(event); + } + + expect(events).toHaveLength(2); + expect(events[0].type).toBe('agent_start'); + expect(events[1].type).toBe('agent_finish'); + expect(AgentSession).toHaveBeenCalled(); + }); +}); diff --git a/packages/core/src/agents/agent.ts b/packages/core/src/agents/agent.ts new file mode 100644 index 0000000000..39aef3368d --- /dev/null +++ b/packages/core/src/agents/agent.ts @@ -0,0 +1,41 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type Part } from '@google/genai'; +import { type Config } from '../config/config.js'; +import { type AgentEvent, type AgentConfig } from './types.js'; +import { AgentSession } from './session.js'; + +/** + * The Agent class is a factory for creating stateful AgentSessions. + * This represents a configured agent template. + */ +export class Agent { + constructor( + private readonly config: AgentConfig, + private readonly runtime: Config, + ) {} + + /** + * Creates a new stateful session for interacting with the agent. + */ + createSession(sessionId?: string): AgentSession { + const id = sessionId ?? this.runtime.getSessionId(); + return new AgentSession(id, this.config, this.runtime); + } + + /** + * Helper to quickly run a single prompt and get the results. + */ + async *prompt( + input: string | Part[], + sessionId?: string, + signal?: AbortSignal, + ): AsyncIterable { + const session = this.createSession(sessionId); + yield* session.prompt(input, signal); + } +} diff --git a/packages/core/src/agents/session.test.ts b/packages/core/src/agents/session.test.ts new file mode 100644 index 0000000000..af54db22d1 --- /dev/null +++ b/packages/core/src/agents/session.test.ts @@ -0,0 +1,271 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { AgentSession } from './session.js'; +import { makeFakeConfig } from '../test-utils/config.js'; +import { type AgentConfig } from './types.js'; +import { Scheduler } from '../scheduler/scheduler.js'; +import { GeminiEventType } from '../core/turn.js'; +import { ChatCompressionService } from '../services/chatCompressionService.js'; +import { CompressionStatus } from '../core/turn.js'; +import { AgentTerminateMode, type AgentEvent } from './types.js'; +import { ToolErrorType } from '../tools/tool-error.js'; + +vi.mock('../core/client.js'); +vi.mock('../scheduler/scheduler.js'); +vi.mock('../services/chatCompressionService.js'); + +describe('AgentSession', () => { + let mockConfig: ReturnType; + let mockClient: { + sendMessageStream: ReturnType; + getChat: ReturnType; + getCurrentSequenceModel: ReturnType; + getHistory: ReturnType; + }; + let mockScheduler: { + schedule: ReturnType; + }; + let mockCompressionService: { + compress: ReturnType; + }; + let session: AgentSession; + const agentConfig: AgentConfig = { + name: 'TestAgent', + capabilities: { compression: true }, + }; + + beforeEach(() => { + vi.clearAllMocks(); + mockConfig = makeFakeConfig(); + + mockClient = { + sendMessageStream: vi.fn(), + getChat: vi.fn().mockReturnValue({ + recordCompletedToolCalls: vi.fn(), + setHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }), + getCurrentSequenceModel: vi.fn().mockReturnValue('test-model'), + getHistory: vi.fn().mockReturnValue([]), + }; + + mockScheduler = { + schedule: vi.fn(), + }; + + mockCompressionService = { + compress: vi.fn().mockResolvedValue({ + newHistory: null, + info: { compressionStatus: CompressionStatus.NOOP }, + }), + }; + + vi.spyOn(mockConfig, 'getGeminiClient').mockReturnValue( + mockClient as unknown as import('../core/client.js').GeminiClient, + ); + vi.mocked(Scheduler).mockImplementation( + () => mockScheduler as unknown as Scheduler, + ); + vi.mocked(ChatCompressionService).mockImplementation( + () => mockCompressionService as unknown as ChatCompressionService, + ); + + session = new AgentSession('test-session', agentConfig, mockConfig); + }); + + it('should emit agent_start and agent_finish', async () => { + mockClient.sendMessageStream.mockImplementation(async function* () { + yield { type: GeminiEventType.Content, value: 'Hello' }; + yield { type: GeminiEventType.Finished, value: { reason: 'STOP' } }; + }); + + const events = []; + for await (const event of session.prompt('Hi')) { + events.push(event); + } + + const finishEvent = events[events.length - 1] as Extract< + AgentEvent, + { type: 'agent_finish' } + >; + expect(events[0].type).toBe('agent_start'); + expect(finishEvent.type).toBe('agent_finish'); + expect(finishEvent.value.reason).toBe(AgentTerminateMode.GOAL); + expect(mockClient.sendMessageStream).toHaveBeenCalled(); + }); + + it('should handle tool calls and execute them', async () => { + // Turn 1: Model calls a tool + mockClient.sendMessageStream.mockImplementationOnce(async function* () { + yield { + type: GeminiEventType.ToolCallRequest, + value: { callId: 'call1', name: 'test_tool', args: {} }, + }; + yield { type: GeminiEventType.Finished, value: { reason: 'STOP' } }; + }); + + // Turn 2: Model finishes + mockClient.sendMessageStream.mockImplementationOnce(async function* () { + yield { type: GeminiEventType.Content, value: 'Tool executed' }; + yield { type: GeminiEventType.Finished, value: { reason: 'STOP' } }; + }); + + mockScheduler.schedule.mockResolvedValueOnce([ + { + response: { + callId: 'call1', + responseParts: [ + { + functionResponse: { + name: 'test_tool', + response: { ok: true }, + id: 'call1', + }, + }, + ], + }, + }, + ]); + + const events = []; + for await (const event of session.prompt('Run tool')) { + events.push(event); + } + + expect(mockClient.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockScheduler.schedule).toHaveBeenCalledTimes(1); + + const suiteStart = events.find((e) => e.type === 'tool_suite_start'); + const suiteFinish = events.find((e) => e.type === 'tool_suite_finish'); + expect(suiteStart).toBeDefined(); + expect(suiteFinish).toBeDefined(); + expect(suiteFinish?.value.responses[0].callId).toBe('call1'); + }); + + it('should trigger compression if enabled', async () => { + mockClient.sendMessageStream.mockImplementation(async function* () { + yield { type: GeminiEventType.Content, value: 'Done' }; + yield { type: GeminiEventType.Finished, value: { reason: 'STOP' } }; + }); + + for await (const _ of session.prompt('Compress me')) { + // consume stream to trigger compression + } + + expect(mockCompressionService.compress).toHaveBeenCalled(); + }); + + it('should respect abort signal', async () => { + const controller = new AbortController(); + mockClient.sendMessageStream.mockImplementation(async function* () { + yield { type: GeminiEventType.Content, value: 'Thinking...' }; + controller.abort(); + yield { type: GeminiEventType.Content, value: 'Still thinking...' }; + }); + + const events = []; + for await (const event of session.prompt('Long task', controller.signal)) { + events.push(event); + } + + // Should finish early + const finishEvent = events[events.length - 1] as Extract< + AgentEvent, + { type: 'agent_finish' } + >; + expect(finishEvent.type).toBe('agent_finish'); + expect(finishEvent.value.reason).toBe(AgentTerminateMode.ABORTED); + // It might still yield the first chunk before the signal is processed in the loop + }); + + it('should emit ERROR reason when a tool requests stop', async () => { + // Turn 1: Model calls a tool + mockClient.sendMessageStream.mockImplementationOnce(async function* () { + yield { + type: GeminiEventType.ToolCallRequest, + value: { callId: 'call_stop', name: 'stop_tool', args: {} }, + }; + yield { type: GeminiEventType.Finished, value: { reason: 'STOP' } }; + }); + + mockScheduler.schedule.mockResolvedValueOnce([ + { + response: { + callId: 'call_stop', + errorType: ToolErrorType.STOP_EXECUTION, + error: new Error('Deny listed command'), + responseParts: [], + }, + }, + ]); + + const events = []; + for await (const event of session.prompt('Run tool')) { + events.push(event); + } + + const finishEvent = events.find( + (e) => e.type === 'agent_finish', + ) as Extract; + expect(finishEvent).toBeDefined(); + expect(finishEvent.value.reason).toBe(AgentTerminateMode.ERROR); + expect(finishEvent.value.message).toBe('Deny listed command'); + }); + + it('should respect maxTurns from config', async () => { + const customSession = new AgentSession( + 'test-session-2', + { ...agentConfig, maxTurns: 2 }, + mockConfig, + ); + + // Mock an infinite loop of tool calls from the model + mockClient.sendMessageStream.mockImplementation(async function* () { + yield { + type: GeminiEventType.ToolCallRequest, + value: { callId: 'call', name: 'test_tool', args: {} }, + }; + yield { type: GeminiEventType.Finished, value: { reason: 'STOP' } }; + }); + + mockScheduler.schedule.mockResolvedValue([ + { + response: { + callId: 'call', + responseParts: [ + { + functionResponse: { + name: 'test_tool', + response: { ok: true }, + id: 'call', + }, + }, + ], + }, + }, + ]); + + const events = []; + for await (const event of customSession.prompt('Start loop')) { + events.push(event); + } + + // It should perform exactly 2 turns, meaning mockScheduler.schedule is called twice + expect(mockScheduler.schedule).toHaveBeenCalledTimes(2); + + // The last event should be agent_finish + const finishEvent = events[events.length - 1] as Extract< + AgentEvent, + { type: 'agent_finish' } + >; + expect(finishEvent.type).toBe('agent_finish'); + expect(finishEvent.value.totalTurns).toBe(2); + expect(finishEvent.value.reason).toBe(AgentTerminateMode.MAX_TURNS); + expect(finishEvent.value.message).toBe('Maximum session turns exceeded.'); + }); +}); diff --git a/packages/core/src/agents/session.ts b/packages/core/src/agents/session.ts new file mode 100644 index 0000000000..f5545ed447 --- /dev/null +++ b/packages/core/src/agents/session.ts @@ -0,0 +1,297 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type Part } from '@google/genai'; +import { type Config } from '../config/config.js'; +import { type GeminiClient } from '../core/client.js'; +import { type AgentEvent, type AgentConfig } from './types.js'; +import { Scheduler } from '../scheduler/scheduler.js'; +import { + ROOT_SCHEDULER_ID, + type ToolCallRequestInfo, +} from '../scheduler/types.js'; +import { GeminiEventType, CompressionStatus } from '../core/turn.js'; +import { recordToolCallInteractions } from '../code_assist/telemetry.js'; +import { debugLogger } from '../utils/debugLogger.js'; +import { ToolErrorType } from '../tools/tool-error.js'; +import { ChatCompressionService } from '../services/chatCompressionService.js'; +import { AgentTerminateMode } from './types.js'; +import type { ResumedSessionData } from '../services/chatRecordingService.js'; +import { convertSessionToClientHistory } from '../utils/sessionUtils.js'; + +/** + * AgentSession manages the state of a conversation and orchestrates the agent + * loop. + */ +export class AgentSession { + private readonly client: GeminiClient; + private readonly scheduler: Scheduler; + private readonly compressionService: ChatCompressionService; + private totalTurns = 0; + private hasFailedCompressionAttempt = false; + + constructor( + private readonly sessionId: string, + private readonly config: AgentConfig, + private readonly runtime: Config, + ) { + // For now, we reuse the GeminiClient from the global config. + this.client = this.runtime.getGeminiClient(); + this.scheduler = new Scheduler({ + config: this.runtime, + messageBus: this.runtime.getMessageBus(), + getPreferredEditor: () => undefined, + schedulerId: ROOT_SCHEDULER_ID, + }); + this.compressionService = new ChatCompressionService(); + } + + /** + * Resumes the agent session from persistent storage data. + * Hydrates the internal language model client with the previously saved trajectory. + * + * @param resumedSessionData The raw payload of a previously saved session. + */ + async resume(resumedSessionData: ResumedSessionData): Promise { + const clientHistory = convertSessionToClientHistory( + resumedSessionData.conversation.messages, + ); + await this.client.resumeChat(clientHistory, resumedSessionData); + } + + /** + * Executes the ReAct loop for a given user input. + * Returns an AsyncIterable of events occurring during the session. + */ + async *prompt( + input: string | Part[], + signal?: AbortSignal, + ): AsyncIterable { + yield { + type: 'agent_start', + value: { sessionId: this.sessionId }, + }; + + let currentInput = input; + let isContinuation = false; + const maxTurns = this.config.maxTurns ?? -1; + + let terminationReason = AgentTerminateMode.GOAL; + let terminationMessage: string | undefined = undefined; + let terminationError: unknown | undefined = undefined; + + try { + while (maxTurns === -1 || this.totalTurns < maxTurns) { + if (signal?.aborted) { + terminationReason = AgentTerminateMode.ABORTED; + break; + } + + this.totalTurns++; + const promptId = `${this.sessionId}#${this.totalTurns}`; + + // Compression check (from LocalAgentExecutor / useGeminiStream patterns) + if (this.config.capabilities?.compression) { + await this.tryCompressChat(promptId); + } + + const { toolCalls, events } = await this.runModelTurn( + currentInput, + promptId, + isContinuation ? undefined : input, + signal, + ); + + for await (const event of events) { + yield event; + } + + if (signal?.aborted) { + terminationReason = AgentTerminateMode.ABORTED; + break; + } + + if (toolCalls.length > 0) { + const results = await this.executeTools(toolCalls, signal); + for await (const event of results.events) { + yield event; + } + + if (results.stopExecution || signal?.aborted) { + if (signal?.aborted) { + terminationReason = AgentTerminateMode.ABORTED; + } else if (results.stopExecutionInfo) { + terminationReason = AgentTerminateMode.ERROR; + terminationMessage = results.stopExecutionInfo.error?.message; + terminationError = results.stopExecutionInfo.error; + } + break; + } + + // Check if we hit the turn limit + if (maxTurns !== -1 && this.totalTurns >= maxTurns) { + terminationReason = AgentTerminateMode.MAX_TURNS; + terminationMessage = 'Maximum session turns exceeded.'; + break; + } + + currentInput = results.nextParts; + isContinuation = true; + } else { + // No more tool calls, turn is complete. + // If we completed naturally but were at the limit, it's still a GOAL + terminationReason = AgentTerminateMode.GOAL; + break; + } + } + } finally { + yield { + type: 'agent_finish', + value: { + sessionId: this.sessionId, + totalTurns: this.totalTurns, + reason: terminationReason, + message: terminationMessage, + error: terminationError, + }, + }; + } + } + + /** + * Calls the model and yields the event stream. + * Collects tool call requests for the next phase. + */ + private async runModelTurn( + input: string | Part[], + promptId: string, + displayContent?: string | Part[], + signal?: AbortSignal, + ) { + const parts = Array.isArray(input) ? input : [{ text: input }]; + const toolCalls: ToolCallRequestInfo[] = []; + + const stream = this.client.sendMessageStream( + parts, + signal ?? new AbortController().signal, + promptId, + undefined, // maxTurns (client handles its own) + false, // isInvalidStreamRetry + displayContent, + ); + + const eventGenerator = async function* (): AsyncIterable { + for await (const event of stream) { + if (event.type === GeminiEventType.ToolCallRequest) { + toolCalls.push(event.value); + } + yield event as AgentEvent; + } + }; + + return { + toolCalls, + events: eventGenerator(), + }; + } + + /** + * Executes a batch of tool calls via the Scheduler. + */ + private async executeTools( + toolCalls: ToolCallRequestInfo[], + signal?: AbortSignal, + ) { + const events: AgentEvent[] = []; + events.push({ + type: 'tool_suite_start', + value: { count: toolCalls.length }, + }); + + const completedCalls = await this.scheduler.schedule( + toolCalls, + signal ?? new AbortController().signal, + ); + + events.push({ + type: 'tool_suite_finish', + value: { responses: completedCalls.map((c) => c.response) }, + }); + + // Record tool call info for persistence/telemetry + try { + const currentModel = + this.client.getCurrentSequenceModel() ?? this.runtime.getModel(); + this.client + .getChat() + .recordCompletedToolCalls(currentModel, completedCalls); + await recordToolCallInteractions(this.runtime, completedCalls); + } catch (e) { + debugLogger.warn(`Error recording tool call information: ${e}`); + } + + const nextParts = completedCalls.flatMap((c) => c.response.responseParts); + const stopExecutionInfo = completedCalls.find( + (c) => c.response.errorType === ToolErrorType.STOP_EXECUTION, + )?.response; + + const eventGenerator = async function* () { + for (const event of events) { + yield event; + } + }; + + return { + nextParts, + stopExecution: !!stopExecutionInfo, + stopExecutionInfo, + events: eventGenerator(), + }; + } + + /** + * Attempts to compress the chat history if thresholds are exceeded. + */ + private async tryCompressChat(promptId: string): Promise { + const chat = this.client.getChat(); + const model = this.config.model ?? this.runtime.getModel(); + + const { newHistory, info } = await this.compressionService.compress( + chat, + promptId, + false, + model, + this.runtime, + this.hasFailedCompressionAttempt, + ); + + if ( + info.compressionStatus === + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT + ) { + this.hasFailedCompressionAttempt = true; + } else if (info.compressionStatus === CompressionStatus.COMPRESSED) { + if (newHistory) { + chat.setHistory(newHistory); + this.hasFailedCompressionAttempt = false; + } + } + } + + /** + * Returns the current message history for this session. + */ + getHistory() { + return this.client.getHistory(); + } + + /** + * Returns the current session ID. + */ + getSessionId(): string { + return this.sessionId; + } +} diff --git a/packages/core/src/agents/types.ts b/packages/core/src/agents/types.ts index b9994d8b4a..ef2cf3d381 100644 --- a/packages/core/src/agents/types.ts +++ b/packages/core/src/agents/types.ts @@ -1,19 +1,64 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ -/** - * @fileoverview Defines the core configuration interfaces and types for the agent architecture. - */ - import type { Content, FunctionDeclaration } from '@google/genai'; import type { AnyDeclarativeTool } from '../tools/tools.js'; import { type z } from 'zod'; import type { ModelConfig } from '../services/modelConfigService.js'; import type { AnySchema } from 'ajv'; import type { A2AAuthConfig } from './auth-provider/types.js'; +import { type ServerGeminiStreamEvent } from '../core/turn.js'; +import { type ToolCallResponseInfo } from '../scheduler/types.js'; + +/** + * Unified event type for the Agent loop. + * This extends the base Gemini stream events with higher-level agent lifecycle events. + */ +export type AgentEvent = + | ServerGeminiStreamEvent + | { type: 'agent_start'; value: { sessionId: string } } + | { + type: 'agent_finish'; + value: { + sessionId: string; + totalTurns: number; + reason: AgentTerminateMode; + message?: string; + error?: unknown; + }; + } + | { type: 'tool_suite_start'; value: { count: number } } + | { type: 'tool_suite_finish'; value: { responses: ToolCallResponseInfo[] } } + | { type: 'thought'; value: string } + | { type: 'loop_detected'; value: { sessionId: string } }; + +/** + * Configuration for an Agent. + */ +export interface AgentConfig { + /** The name of the agent. */ + name: string; + /** The system instruction (personality/rules) for the agent. */ + systemInstruction?: string; + /** Optional override for the model to use. */ + model?: string; + /** + * Optional maximum number of conversational turns. + * Set to -1 for no limit, defaults to -1 if not specified. + */ + maxTurns?: number; + /** + * Optional capabilities to enable for this agent. + */ + capabilities?: { + compression?: boolean; + loopDetection?: boolean; + ideContext?: boolean; + }; +} /** * Describes the possible termination modes for an agent. diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 8f82486173..159b8e3e9f 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -106,6 +106,7 @@ export * from './utils/secure-browser-launcher.js'; export * from './utils/apiConversionUtils.js'; export * from './utils/channel.js'; export * from './utils/constants.js'; +export * from './utils/sessionUtils.js'; // Export services export * from './services/fileDiscoveryService.js'; @@ -143,6 +144,7 @@ export * from './agents/types.js'; export * from './agents/agentLoader.js'; export * from './agents/local-executor.js'; export * from './agents/agent-scheduler.js'; +export * from './agents/session.js'; // Export specific tool logic export * from './tools/read-file.js'; diff --git a/packages/core/src/utils/sessionUtils.test.ts b/packages/core/src/utils/sessionUtils.test.ts new file mode 100644 index 0000000000..762842e616 --- /dev/null +++ b/packages/core/src/utils/sessionUtils.test.ts @@ -0,0 +1,122 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ +import { describe, it, expect } from 'vitest'; +import { convertSessionToClientHistory } from './sessionUtils.js'; +import { type ConversationRecord } from '../services/chatRecordingService.js'; +import { CoreToolCallStatus } from '../scheduler/types.js'; + +describe('convertSessionToClientHistory', () => { + it('should convert a simple conversation without tool calls', () => { + const messages: ConversationRecord['messages'] = [ + { + id: '1', + type: 'user', + timestamp: '2024-01-01T10:00:00Z', + content: 'Hello', + }, + { + id: '2', + type: 'gemini', + timestamp: '2024-01-01T10:01:00Z', + content: 'Hi there', + }, + ]; + + const history = convertSessionToClientHistory(messages); + + expect(history).toEqual([ + { role: 'user', parts: [{ text: 'Hello' }] }, + { role: 'model', parts: [{ text: 'Hi there' }] }, + ]); + }); + + it('should ignore info, error, and slash commands', () => { + const messages: ConversationRecord['messages'] = [ + { + id: '1', + type: 'info', + timestamp: '2024-01-01T10:00:00Z', + content: 'System info', + }, + { + id: '2', + type: 'user', + timestamp: '2024-01-01T10:01:00Z', + content: '/clear', + }, + { + id: '3', + type: 'user', + timestamp: '2024-01-01T10:02:00Z', + content: '?help', + }, + { + id: '4', + type: 'user', + timestamp: '2024-01-01T10:03:00Z', + content: 'Actual query', + }, + ]; + + const history = convertSessionToClientHistory(messages); + + expect(history).toEqual([ + { role: 'user', parts: [{ text: 'Actual query' }] }, + ]); + }); + + it('should correct map tool calls and their responses', () => { + const messages: ConversationRecord['messages'] = [ + { + id: 'msg1', + type: 'user', + timestamp: '2024-01-01T10:00:00Z', + content: 'List files', + }, + { + id: 'msg2', + type: 'gemini', + timestamp: '2024-01-01T10:01:00Z', + content: 'Let me check.', + toolCalls: [ + { + id: 'call123', + name: 'ls', + args: { dir: '.' }, + status: CoreToolCallStatus.Success, + timestamp: '2024-01-01T10:01:05Z', + result: 'file.txt', + }, + ], + }, + ]; + + const history = convertSessionToClientHistory(messages); + + expect(history).toEqual([ + { role: 'user', parts: [{ text: 'List files' }] }, + { + role: 'model', + parts: [ + { text: 'Let me check.' }, + { functionCall: { name: 'ls', args: { dir: '.' }, id: 'call123' } }, + ], + }, + { + role: 'user', + parts: [ + { + functionResponse: { + id: 'call123', + name: 'ls', + response: { output: 'file.txt' }, + }, + }, + ], + }, + ]); + }); +}); diff --git a/packages/core/src/utils/sessionUtils.ts b/packages/core/src/utils/sessionUtils.ts new file mode 100644 index 0000000000..c037f6657a --- /dev/null +++ b/packages/core/src/utils/sessionUtils.ts @@ -0,0 +1,111 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type Part } from '@google/genai'; +import { type ConversationRecord } from '../services/chatRecordingService.js'; +import { partListUnionToString } from '../core/geminiRequest.js'; + +/** + * Converts session/conversation data into Gemini client history formats. + */ +export function convertSessionToClientHistory( + messages: ConversationRecord['messages'], +): Array<{ role: 'user' | 'model'; parts: Part[] }> { + const clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }> = []; + + for (const msg of messages) { + if (msg.type === 'info' || msg.type === 'error' || msg.type === 'warning') { + continue; + } + + if (msg.type === 'user') { + const contentString = partListUnionToString(msg.content); + if ( + contentString.trim().startsWith('/') || + contentString.trim().startsWith('?') + ) { + continue; + } + + clientHistory.push({ + role: 'user', + parts: Array.isArray(msg.content) + ? // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + (msg.content as Part[]) + : [{ text: contentString }], + }); + } else if (msg.type === 'gemini') { + const hasToolCalls = msg.toolCalls && msg.toolCalls.length > 0; + + if (hasToolCalls) { + const modelParts: Part[] = []; + const contentString = partListUnionToString(msg.content); + if (msg.content && contentString.trim()) { + modelParts.push({ text: contentString }); + } + + for (const toolCall of msg.toolCalls!) { + modelParts.push({ + functionCall: { + name: toolCall.name, + args: toolCall.args, + ...(toolCall.id && { id: toolCall.id }), + }, + }); + } + + clientHistory.push({ + role: 'model', + parts: modelParts, + }); + + const functionResponseParts: Part[] = []; + for (const toolCall of msg.toolCalls!) { + if (toolCall.result) { + let responseData: Part; + + if (typeof toolCall.result === 'string') { + responseData = { + functionResponse: { + id: toolCall.id, + name: toolCall.name, + response: { + output: toolCall.result, + }, + }, + }; + } else if (Array.isArray(toolCall.result)) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + functionResponseParts.push(...(toolCall.result as Part[])); + continue; + } else { + responseData = toolCall.result; + } + + functionResponseParts.push(responseData); + } + } + + if (functionResponseParts.length > 0) { + clientHistory.push({ + role: 'user', + parts: functionResponseParts, + }); + } + } else { + const contentString = partListUnionToString(msg.content); + if (msg.content && contentString.trim()) { + clientHistory.push({ + role: 'model', + parts: [{ text: contentString }], + }); + } + } + } + } + + return clientHistory; +}