From d3a7fc39e3b3c6b0b2ef6f9e53b234ccf9e9c1c2 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 24 Mar 2026 11:37:59 -0700 Subject: [PATCH] feat: implement experimental useAgentStream with unified agent protocol This change introduces an experimental 'useAgentProtocol' flag to interactive mode. When enabled, the UI uses the new 'useAgentStream' hook which leverages the core 'LegacyAgentSession' (AgentProtocol) instead of the custom 'useGeminiStream' logic. Key changes: - Added 'useAgentProtocol' experimental setting to CLI and Core config. - Implemented 'useAgentStream' hook with basic interaction and thought support. - Modified 'useToolScheduler' to expose its internal Scheduler instance to ensure implementation parity. - Updated 'AppContainer' to conditionally branch between implementations via ternary operator. - Added comprehensive unit tests for the new hook. --- packages/cli/src/nonInteractiveCli.ts | 4 +- packages/cli/src/ui/AppContainer.tsx | 68 ++-- .../cli/src/ui/hooks/useAgentStream.test.tsx | 325 ++++++++++++++++++ packages/cli/src/ui/hooks/useAgentStream.ts | 290 ++++++++++++++++ packages/cli/src/ui/hooks/useToolScheduler.ts | 2 + 5 files changed, 667 insertions(+), 22 deletions(-) create mode 100644 packages/cli/src/ui/hooks/useAgentStream.test.tsx create mode 100644 packages/cli/src/ui/hooks/useAgentStream.ts diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index bae45e1e32..2fae1f55ff 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -370,7 +370,9 @@ export async function runNonInteractive({ return errToThrow; }; - const runTerminalExitHandler = (handler: () => never): never => { + const runTerminalExitHandler = ( + handler: () => void | never, + ): void | never => { terminalProcessExitHandled = true; return handler(); }; diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index d5b34915bc..de8c7d2122 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -110,6 +110,7 @@ import { computeTerminalTitle } from '../utils/windowTitle.js'; import { useTextBuffer } from './components/shared/text-buffer.js'; import { useLogger } from './hooks/useLogger.js'; import { useGeminiStream } from './hooks/useGeminiStream.js'; +import { useAgentStream } from './hooks/useAgentStream.js'; import { type BackgroundShell } from './hooks/shellCommandProcessor.js'; import { useVim } from './hooks/vim.js'; import { type LoadableSettingScope, SettingScope } from '../config/settings.js'; @@ -1091,6 +1092,8 @@ Logging in with Google... Restarting Gemini CLI to continue. }; }, [config]); + const useAgentProtocol = config.getExperimentalUseAgentProtocol(); + const { streamingState, submitQuery, @@ -1110,27 +1113,50 @@ Logging in with Google... Restarting Gemini CLI to continue. backgroundShells, dismissBackgroundShell, retryStatus, - } = useGeminiStream( - config.getGeminiClient(), - historyManager.history, - historyManager.addItem, - config, - settings, - setDebugMessage, - handleSlashCommand, - shellModeActive, - getPreferredEditor, - onAuthError, - performMemoryRefresh, - modelSwitchedFromQuotaError, - setModelSwitchedFromQuotaError, - onCancelSubmit, - setEmbeddedShellFocused, - terminalWidth, - terminalHeight, - embeddedShellFocused, - consumePendingHints, - ); + // eslint-disable-next-line react-hooks/rules-of-hooks + } = useAgentProtocol + ? useAgentStream( + config.getGeminiClient(), + historyManager.history, + historyManager.addItem, + config, + settings, + setDebugMessage, + handleSlashCommand, + shellModeActive, + getPreferredEditor, + onAuthError, + performMemoryRefresh, + modelSwitchedFromQuotaError, + setModelSwitchedFromQuotaError, + onCancelSubmit, + setEmbeddedShellFocused, + terminalWidth, + terminalHeight, + embeddedShellFocused, + consumePendingHints, + ) + : useGeminiStream( + config.getGeminiClient(), + historyManager.history, + historyManager.addItem, + config, + settings, + setDebugMessage, + handleSlashCommand, + shellModeActive, + getPreferredEditor, + onAuthError, + performMemoryRefresh, + modelSwitchedFromQuotaError, + setModelSwitchedFromQuotaError, + onCancelSubmit, + setEmbeddedShellFocused, + terminalWidth, + terminalHeight, + embeddedShellFocused, + consumePendingHints, + ); const pendingHistoryItems = useMemo( () => [...pendingSlashCommandHistoryItems, ...pendingGeminiHistoryItems], diff --git a/packages/cli/src/ui/hooks/useAgentStream.test.tsx b/packages/cli/src/ui/hooks/useAgentStream.test.tsx new file mode 100644 index 0000000000..82fe921bf6 --- /dev/null +++ b/packages/cli/src/ui/hooks/useAgentStream.test.tsx @@ -0,0 +1,325 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, +} from 'vitest'; +import { act } from 'react'; +import { renderHookWithProviders } from '../../test-utils/render.js'; + +// --- MOCKS --- + +const mockScheduler = vi.hoisted(() => ({ + schedule: vi.fn(), + dispose: vi.fn(), + cancelAll: vi.fn(), +})); + +const mockLegacyAgentSession = vi.hoisted(() => ({ + send: vi.fn().mockResolvedValue({ streamId: 'test-stream-id' }), + subscribe: vi.fn().mockReturnValue(() => {}), + abort: vi.fn().mockResolvedValue(undefined), +})); + +vi.mock('./useToolScheduler.js', () => ({ + useToolScheduler: vi.fn().mockReturnValue([ + [], // toolCalls + vi.fn(), // schedule + vi.fn(), // markToolsAsSubmitted + vi.fn(), // setToolCallsForDisplay + vi.fn(), // cancelAll + 0, // lastToolOutputTime + mockScheduler, // scheduler + ]), +})); + +vi.mock('./useLogger.js', () => ({ + useLogger: vi.fn().mockReturnValue({ + logMessage: vi.fn().mockResolvedValue(undefined), + }), +})); + +vi.mock('../contexts/SessionContext.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...(actual as any), + useSessionStats: vi.fn(() => ({ + startNewPrompt: vi.fn(), + })), + }; +}); + +// Mock core classes properly +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = await importOriginal() as any; + return { + ...actual, + LegacyAgentSession: vi.fn().mockImplementation(() => mockLegacyAgentSession), + }; +}); + +// --- END MOCKS --- + +import { useAgentStream } from './useAgentStream.js'; +import { + LegacyAgentSession as MockLegacyAgentSession, +} from '@google/gemini-cli-core'; +import { MessageType, StreamingState } from '../types.js'; + +describe('useAgentStream', () => { + const mockAddItem = vi.fn(); + const mockOnDebugMessage = vi.fn(); + const mockHandleSlashCommand = vi.fn().mockResolvedValue(false); + const mockOnAuthError = vi.fn(); + const mockPerformMemoryRefresh = vi.fn(() => Promise.resolve()); + const mockSetModelSwitchedFromQuotaError = vi.fn(); + const mockOnCancelSubmit = vi.fn(); + const mockSetShellInputFocused = vi.fn(); + + const mockConfig = { + storage: {}, + getSessionId: () => 'test-session', + getExperimentalUseAgentProtocol: () => true, + getApprovalMode: () => 'default', + getMessageBus: () => ({}), + } as any; + + const mockSettings = { + merged: { + billing: { overageStrategy: 'stop' }, + ui: { errorVerbosity: 'full' }, + }, + } as any; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('should initialize LegacyAgentSession on mount', async () => { + await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + expect(MockLegacyAgentSession).toHaveBeenCalled(); + expect(mockLegacyAgentSession.subscribe).toHaveBeenCalled(); + }); + + it('should call session.send when submitQuery is called', async () => { + const { result } = await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + await act(async () => { + await result.current.submitQuery('hello'); + }); + + expect(mockLegacyAgentSession.send).toHaveBeenCalledWith({ + message: [{ type: 'text', text: 'hello' }], + }); + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ type: MessageType.USER, text: 'hello' }), + expect.any(Number), + ); + }); + + it('should update streamingState based on agent_start and agent_end events', async () => { + const { result } = await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + const eventHandler = (mockLegacyAgentSession.subscribe as any).mock.calls[0][0]; + + expect(result.current.streamingState).toBe(StreamingState.Idle); + + act(() => { + eventHandler({ type: 'agent_start' }); + }); + expect(result.current.streamingState).toBe(StreamingState.Responding); + + act(() => { + eventHandler({ type: 'agent_end', reason: 'completed' }); + }); + expect(result.current.streamingState).toBe(StreamingState.Idle); + }); + + it('should accumulate text content and update pendingHistoryItems', async () => { + const { result } = await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + const eventHandler = (mockLegacyAgentSession.subscribe as any).mock.calls[0][0]; + + act(() => { + eventHandler({ + type: 'message', + role: 'agent', + content: [{ type: 'text', text: 'Hello' }], + }); + }); + + expect(result.current.pendingHistoryItems).toHaveLength(1); + expect(result.current.pendingHistoryItems[0]).toMatchObject({ + type: 'gemini', + text: 'Hello', + }); + + act(() => { + eventHandler({ + type: 'message', + role: 'agent', + content: [{ type: 'text', text: ' world' }], + }); + }); + + expect(result.current.pendingHistoryItems[0].text).toBe('Hello world'); + }); + + it('should process thought events and update thought state', async () => { + const { result } = await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + const eventHandler = (mockLegacyAgentSession.subscribe as any).mock.calls[0][0]; + + act(() => { + eventHandler({ + type: 'message', + role: 'agent', + content: [{ type: 'thought', thought: '**Thinking** about tests' }], + }); + }); + + expect(result.current.thought).toEqual({ + subject: 'Thinking', + description: 'about tests', + }); + }); + + it('should call session.abort when cancelOngoingRequest is called', async () => { + const { result } = await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + await act(async () => { + await result.current.cancelOngoingRequest(); + }); + + expect(mockLegacyAgentSession.abort).toHaveBeenCalled(); + expect(mockOnCancelSubmit).toHaveBeenCalledWith(false); + }); +}); diff --git a/packages/cli/src/ui/hooks/useAgentStream.ts b/packages/cli/src/ui/hooks/useAgentStream.ts new file mode 100644 index 0000000000..1d664f7815 --- /dev/null +++ b/packages/cli/src/ui/hooks/useAgentStream.ts @@ -0,0 +1,290 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useState, useRef, useCallback, useEffect, useMemo } from 'react'; +import { + getErrorMessage, + MessageSenderType, + ApprovalMode, + debugLogger, + LegacyAgentSession, + geminiPartsToContentParts, + parseThought, +} from '@google/gemini-cli-core'; +import type { + Config, + EditorType, + GeminiClient, + ThoughtSummary, + RetryAttemptPayload, + AgentEvent, +} from '@google/gemini-cli-core'; +import { type PartListUnion } from '@google/genai'; +import type { + HistoryItem, + HistoryItemWithoutId, + LoopDetectionConfirmationRequest, +} from '../types.js'; +import { StreamingState, MessageType } from '../types.js'; +import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; +import { type BackgroundShell } from './shellCommandProcessor.js'; +import type { UseHistoryManagerReturn } from './useHistoryManager.js'; +import { useLogger } from './useLogger.js'; +import { + useToolScheduler, +} from './useToolScheduler.js'; +import { useSessionStats } from '../contexts/SessionContext.js'; +import type { LoadedSettings } from '../../config/settings.js'; +import type { SlashCommandProcessorResult } from '../types.js'; +import { useStateAndRef } from './useStateAndRef.js'; + +/** + * useAgentStream implements the interactive agent loop using the LegacyAgentSession (AgentProtocol). + * It attempts to maintain parity with useGeminiStream while consolidating model/tool orchestration + * into the unified core API. + */ +export const useAgentStream = ( + geminiClient: GeminiClient, + _history: HistoryItem[], + addItem: UseHistoryManagerReturn['addItem'], + config: Config, + _settings: LoadedSettings, + _onDebugMessage: (message: string) => void, + _handleSlashCommand: ( + cmd: PartListUnion, + ) => Promise, + _shellModeActive: boolean, + getPreferredEditor: () => EditorType | undefined, + _onAuthError: (error: string) => void, + _performMemoryRefresh: () => Promise, + _modelSwitchedFromQuotaError: boolean, + _setModelSwitchedFromQuotaError: React.Dispatch>, + onCancelSubmit: (shouldRestorePrompt?: boolean) => void, + _setShellInputFocused: (value: boolean) => void, + _terminalWidth: number, + _terminalHeight: number, + _isShellFocused?: boolean, + _consumeUserHint?: () => string | null, +) => { + const [initError] = useState(null); + const [retryStatus] = useState( + null, + ); + const [streamingState, setStreamingState] = useState( + StreamingState.Idle, + ); + const [thought, setThought] = useState(null); + + // Track the current session instance + const sessionRef = useRef(null); + const currentStreamIdRef = useRef(null); + const userMessageTimestampRef = useRef(0); + const geminiMessageBufferRef = useRef(''); + const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] = + useStateAndRef(null); + + const [ + toolCalls, + _schedule, + _markToolsAsSubmitted, + _setToolCallsForDisplay, + cancelAllToolCalls, + lastOutputTime, + scheduler, + ] = useToolScheduler( + async (_completedTools) => { + // LegacyAgentSession owns the loop, so we don't need to trigger next turns here. + }, + config, + getPreferredEditor, + ); + + const { startNewPrompt } = useSessionStats(); + const logger = useLogger(config.storage); + + const activePtyId = undefined; + const backgroundShellCount = 0; + const isBackgroundShellVisible = false; + const toggleBackgroundShell = useCallback(() => {}, []); + const backgroundCurrentShell = undefined; + const backgroundShells = new Map(); + const dismissBackgroundShell = useCallback(async (_pid: number) => {}, []); + + // TODO: Support LoopDetection confirmation requests + const [ + loopDetectionConfirmationRequest, + ] = useState(null); + + const cancelOngoingRequest = useCallback(async () => { + if (sessionRef.current) { + await sessionRef.current.abort(); + cancelAllToolCalls(new AbortController().signal); + setStreamingState(StreamingState.Idle); + onCancelSubmit(false); + } + }, [cancelAllToolCalls, onCancelSubmit]); + + // TODO: Support native handleApprovalModeChange for Plan Mode + const handleApprovalModeChange = useCallback( + async (newApprovalMode: ApprovalMode) => { + debugLogger.debug(`Approval mode changed to ${newApprovalMode} (stub)`); + }, + [], + ); + + const handleEvent = useCallback( + (event: AgentEvent) => { + switch (event.type) { + case 'agent_start': + setStreamingState(StreamingState.Responding); + break; + case 'agent_end': + setStreamingState(StreamingState.Idle); + if (pendingHistoryItemRef.current) { + addItem( + pendingHistoryItemRef.current, + userMessageTimestampRef.current, + ); + setPendingHistoryItem(null); + } + break; + case 'message': + if (event.role === 'agent') { + for (const part of event.content) { + if (part.type === 'text') { + geminiMessageBufferRef.current += part.text; + // Update pending history item with incremental text + const splitPoint = findLastSafeSplitPoint( + geminiMessageBufferRef.current, + ); + if (splitPoint === geminiMessageBufferRef.current.length) { + setPendingHistoryItem({ + type: 'gemini', + text: geminiMessageBufferRef.current, + }); + } else { + const before = geminiMessageBufferRef.current.substring( + 0, + splitPoint, + ); + const after = + geminiMessageBufferRef.current.substring(splitPoint); + addItem( + { type: 'gemini', text: before }, + userMessageTimestampRef.current, + ); + geminiMessageBufferRef.current = after; + setPendingHistoryItem({ + type: 'gemini_content', + text: after, + }); + } + } else if (part.type === 'thought') { + setThought(parseThought(part.thought)); + } + } + } + break; + case 'tool_request': + // UI state is handled automatically by useToolScheduler via MessageBus + break; + case 'tool_response': + // UI state is handled automatically by useToolScheduler via MessageBus + break; + case 'error': + addItem( + { type: MessageType.ERROR, text: event.message }, + userMessageTimestampRef.current, + ); + break; + default: + break; + } + }, + [addItem, pendingHistoryItemRef, setPendingHistoryItem], + ); + + useEffect(() => { + if (sessionRef.current) { + return sessionRef.current.subscribe(handleEvent); + } + return undefined; + }, [handleEvent]); + + // Handle initialization of the session + if (!sessionRef.current) { + sessionRef.current = new LegacyAgentSession({ + client: geminiClient, + scheduler, + config, + promptId: '', + }); + } + + const submitQuery = useCallback( + async ( + query: PartListUnion, + options?: { isContinuation: boolean }, + _prompt_id?: string, + ) => { + if (!sessionRef.current) return; + + const timestamp = Date.now(); + userMessageTimestampRef.current = timestamp; + geminiMessageBufferRef.current = ''; + + if (!options?.isContinuation) { + if (typeof query === 'string') { + addItem({ type: MessageType.USER, text: query }, timestamp); + void logger?.logMessage(MessageSenderType.USER, query); + } + startNewPrompt(); + } + + const parts = geminiPartsToContentParts( + typeof query === 'string' ? [{ text: query }] : (query as any[]), + ); + + try { + const { streamId } = await sessionRef.current.send({ + message: parts, + }); + currentStreamIdRef.current = streamId; + } catch (err) { + addItem( + { type: MessageType.ERROR, text: getErrorMessage(err) }, + timestamp, + ); + } + }, + [addItem, logger, startNewPrompt], + ); + + const pendingHistoryItems = useMemo(() => { + return pendingHistoryItem ? [pendingHistoryItem] : []; + }, [pendingHistoryItem]); + + return { + streamingState, + submitQuery, + initError, + pendingHistoryItems, + thought, + cancelOngoingRequest, + pendingToolCalls: toolCalls, + handleApprovalModeChange, + activePtyId, + loopDetectionConfirmationRequest, + lastOutputTime, + backgroundShellCount, + isBackgroundShellVisible, + toggleBackgroundShell, + backgroundCurrentShell, + backgroundShells, + retryStatus, + dismissBackgroundShell, + }; +}; diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts index 7d0933506a..670a8b76d5 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.ts @@ -75,6 +75,7 @@ export function useToolScheduler( React.Dispatch>, CancelAllFn, number, + Scheduler, ] { // State stores tool calls organized by their originating schedulerId const [toolCallsMap, setToolCallsMap] = useState< @@ -257,6 +258,7 @@ export function useToolScheduler( setToolCallsForDisplay, cancelAll, lastToolOutputTime, + scheduler, ]; }