diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index bae45e1e32..2fae1f55ff 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -370,7 +370,9 @@ export async function runNonInteractive({ return errToThrow; }; - const runTerminalExitHandler = (handler: () => never): never => { + const runTerminalExitHandler = ( + handler: () => void | never, + ): void | never => { terminalProcessExitHandled = true; return handler(); }; diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index d5b34915bc..de8c7d2122 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -110,6 +110,7 @@ import { computeTerminalTitle } from '../utils/windowTitle.js'; import { useTextBuffer } from './components/shared/text-buffer.js'; import { useLogger } from './hooks/useLogger.js'; import { useGeminiStream } from './hooks/useGeminiStream.js'; +import { useAgentStream } from './hooks/useAgentStream.js'; import { type BackgroundShell } from './hooks/shellCommandProcessor.js'; import { useVim } from './hooks/vim.js'; import { type LoadableSettingScope, SettingScope } from '../config/settings.js'; @@ -1091,6 +1092,8 @@ Logging in with Google... Restarting Gemini CLI to continue. }; }, [config]); + const useAgentProtocol = config.getExperimentalUseAgentProtocol(); + const { streamingState, submitQuery, @@ -1110,27 +1113,50 @@ Logging in with Google... Restarting Gemini CLI to continue. backgroundShells, dismissBackgroundShell, retryStatus, - } = useGeminiStream( - config.getGeminiClient(), - historyManager.history, - historyManager.addItem, - config, - settings, - setDebugMessage, - handleSlashCommand, - shellModeActive, - getPreferredEditor, - onAuthError, - performMemoryRefresh, - modelSwitchedFromQuotaError, - setModelSwitchedFromQuotaError, - onCancelSubmit, - setEmbeddedShellFocused, - terminalWidth, - terminalHeight, - embeddedShellFocused, - consumePendingHints, - ); + // eslint-disable-next-line react-hooks/rules-of-hooks + } = useAgentProtocol + ? useAgentStream( + config.getGeminiClient(), + historyManager.history, + historyManager.addItem, + config, + settings, + setDebugMessage, + handleSlashCommand, + shellModeActive, + getPreferredEditor, + onAuthError, + performMemoryRefresh, + modelSwitchedFromQuotaError, + setModelSwitchedFromQuotaError, + onCancelSubmit, + setEmbeddedShellFocused, + terminalWidth, + terminalHeight, + embeddedShellFocused, + consumePendingHints, + ) + : useGeminiStream( + config.getGeminiClient(), + historyManager.history, + historyManager.addItem, + config, + settings, + setDebugMessage, + handleSlashCommand, + shellModeActive, + getPreferredEditor, + onAuthError, + performMemoryRefresh, + modelSwitchedFromQuotaError, + setModelSwitchedFromQuotaError, + onCancelSubmit, + setEmbeddedShellFocused, + terminalWidth, + terminalHeight, + embeddedShellFocused, + consumePendingHints, + ); const pendingHistoryItems = useMemo( () => [...pendingSlashCommandHistoryItems, ...pendingGeminiHistoryItems], diff --git a/packages/cli/src/ui/hooks/useAgentStream.test.tsx b/packages/cli/src/ui/hooks/useAgentStream.test.tsx new file mode 100644 index 0000000000..82fe921bf6 --- /dev/null +++ b/packages/cli/src/ui/hooks/useAgentStream.test.tsx @@ -0,0 +1,325 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, +} from 'vitest'; +import { act } from 'react'; +import { renderHookWithProviders } from '../../test-utils/render.js'; + +// --- MOCKS --- + +const mockScheduler = vi.hoisted(() => ({ + schedule: vi.fn(), + dispose: vi.fn(), + cancelAll: vi.fn(), +})); + +const mockLegacyAgentSession = vi.hoisted(() => ({ + send: vi.fn().mockResolvedValue({ streamId: 'test-stream-id' }), + subscribe: vi.fn().mockReturnValue(() => {}), + abort: vi.fn().mockResolvedValue(undefined), +})); + +vi.mock('./useToolScheduler.js', () => ({ + useToolScheduler: vi.fn().mockReturnValue([ + [], // toolCalls + vi.fn(), // schedule + vi.fn(), // markToolsAsSubmitted + vi.fn(), // setToolCallsForDisplay + vi.fn(), // cancelAll + 0, // lastToolOutputTime + mockScheduler, // scheduler + ]), +})); + +vi.mock('./useLogger.js', () => ({ + useLogger: vi.fn().mockReturnValue({ + logMessage: vi.fn().mockResolvedValue(undefined), + }), +})); + +vi.mock('../contexts/SessionContext.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...(actual as any), + useSessionStats: vi.fn(() => ({ + startNewPrompt: vi.fn(), + })), + }; +}); + +// Mock core classes properly +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = await importOriginal() as any; + return { + ...actual, + LegacyAgentSession: vi.fn().mockImplementation(() => mockLegacyAgentSession), + }; +}); + +// --- END MOCKS --- + +import { useAgentStream } from './useAgentStream.js'; +import { + LegacyAgentSession as MockLegacyAgentSession, +} from '@google/gemini-cli-core'; +import { MessageType, StreamingState } from '../types.js'; + +describe('useAgentStream', () => { + const mockAddItem = vi.fn(); + const mockOnDebugMessage = vi.fn(); + const mockHandleSlashCommand = vi.fn().mockResolvedValue(false); + const mockOnAuthError = vi.fn(); + const mockPerformMemoryRefresh = vi.fn(() => Promise.resolve()); + const mockSetModelSwitchedFromQuotaError = vi.fn(); + const mockOnCancelSubmit = vi.fn(); + const mockSetShellInputFocused = vi.fn(); + + const mockConfig = { + storage: {}, + getSessionId: () => 'test-session', + getExperimentalUseAgentProtocol: () => true, + getApprovalMode: () => 'default', + getMessageBus: () => ({}), + } as any; + + const mockSettings = { + merged: { + billing: { overageStrategy: 'stop' }, + ui: { errorVerbosity: 'full' }, + }, + } as any; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('should initialize LegacyAgentSession on mount', async () => { + await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + expect(MockLegacyAgentSession).toHaveBeenCalled(); + expect(mockLegacyAgentSession.subscribe).toHaveBeenCalled(); + }); + + it('should call session.send when submitQuery is called', async () => { + const { result } = await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + await act(async () => { + await result.current.submitQuery('hello'); + }); + + expect(mockLegacyAgentSession.send).toHaveBeenCalledWith({ + message: [{ type: 'text', text: 'hello' }], + }); + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ type: MessageType.USER, text: 'hello' }), + expect.any(Number), + ); + }); + + it('should update streamingState based on agent_start and agent_end events', async () => { + const { result } = await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + const eventHandler = (mockLegacyAgentSession.subscribe as any).mock.calls[0][0]; + + expect(result.current.streamingState).toBe(StreamingState.Idle); + + act(() => { + eventHandler({ type: 'agent_start' }); + }); + expect(result.current.streamingState).toBe(StreamingState.Responding); + + act(() => { + eventHandler({ type: 'agent_end', reason: 'completed' }); + }); + expect(result.current.streamingState).toBe(StreamingState.Idle); + }); + + it('should accumulate text content and update pendingHistoryItems', async () => { + const { result } = await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + const eventHandler = (mockLegacyAgentSession.subscribe as any).mock.calls[0][0]; + + act(() => { + eventHandler({ + type: 'message', + role: 'agent', + content: [{ type: 'text', text: 'Hello' }], + }); + }); + + expect(result.current.pendingHistoryItems).toHaveLength(1); + expect(result.current.pendingHistoryItems[0]).toMatchObject({ + type: 'gemini', + text: 'Hello', + }); + + act(() => { + eventHandler({ + type: 'message', + role: 'agent', + content: [{ type: 'text', text: ' world' }], + }); + }); + + expect(result.current.pendingHistoryItems[0].text).toBe('Hello world'); + }); + + it('should process thought events and update thought state', async () => { + const { result } = await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + const eventHandler = (mockLegacyAgentSession.subscribe as any).mock.calls[0][0]; + + act(() => { + eventHandler({ + type: 'message', + role: 'agent', + content: [{ type: 'thought', thought: '**Thinking** about tests' }], + }); + }); + + expect(result.current.thought).toEqual({ + subject: 'Thinking', + description: 'about tests', + }); + }); + + it('should call session.abort when cancelOngoingRequest is called', async () => { + const { result } = await renderHookWithProviders(() => + useAgentStream( + {} as any, + [], + mockAddItem, + mockConfig, + mockSettings, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => undefined, + mockOnAuthError, + mockPerformMemoryRefresh, + false, + mockSetModelSwitchedFromQuotaError, + mockOnCancelSubmit, + mockSetShellInputFocused, + 80, + 24, + ), + ); + + await act(async () => { + await result.current.cancelOngoingRequest(); + }); + + expect(mockLegacyAgentSession.abort).toHaveBeenCalled(); + expect(mockOnCancelSubmit).toHaveBeenCalledWith(false); + }); +}); diff --git a/packages/cli/src/ui/hooks/useAgentStream.ts b/packages/cli/src/ui/hooks/useAgentStream.ts new file mode 100644 index 0000000000..1d664f7815 --- /dev/null +++ b/packages/cli/src/ui/hooks/useAgentStream.ts @@ -0,0 +1,290 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useState, useRef, useCallback, useEffect, useMemo } from 'react'; +import { + getErrorMessage, + MessageSenderType, + ApprovalMode, + debugLogger, + LegacyAgentSession, + geminiPartsToContentParts, + parseThought, +} from '@google/gemini-cli-core'; +import type { + Config, + EditorType, + GeminiClient, + ThoughtSummary, + RetryAttemptPayload, + AgentEvent, +} from '@google/gemini-cli-core'; +import { type PartListUnion } from '@google/genai'; +import type { + HistoryItem, + HistoryItemWithoutId, + LoopDetectionConfirmationRequest, +} from '../types.js'; +import { StreamingState, MessageType } from '../types.js'; +import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; +import { type BackgroundShell } from './shellCommandProcessor.js'; +import type { UseHistoryManagerReturn } from './useHistoryManager.js'; +import { useLogger } from './useLogger.js'; +import { + useToolScheduler, +} from './useToolScheduler.js'; +import { useSessionStats } from '../contexts/SessionContext.js'; +import type { LoadedSettings } from '../../config/settings.js'; +import type { SlashCommandProcessorResult } from '../types.js'; +import { useStateAndRef } from './useStateAndRef.js'; + +/** + * useAgentStream implements the interactive agent loop using the LegacyAgentSession (AgentProtocol). + * It attempts to maintain parity with useGeminiStream while consolidating model/tool orchestration + * into the unified core API. + */ +export const useAgentStream = ( + geminiClient: GeminiClient, + _history: HistoryItem[], + addItem: UseHistoryManagerReturn['addItem'], + config: Config, + _settings: LoadedSettings, + _onDebugMessage: (message: string) => void, + _handleSlashCommand: ( + cmd: PartListUnion, + ) => Promise, + _shellModeActive: boolean, + getPreferredEditor: () => EditorType | undefined, + _onAuthError: (error: string) => void, + _performMemoryRefresh: () => Promise, + _modelSwitchedFromQuotaError: boolean, + _setModelSwitchedFromQuotaError: React.Dispatch>, + onCancelSubmit: (shouldRestorePrompt?: boolean) => void, + _setShellInputFocused: (value: boolean) => void, + _terminalWidth: number, + _terminalHeight: number, + _isShellFocused?: boolean, + _consumeUserHint?: () => string | null, +) => { + const [initError] = useState(null); + const [retryStatus] = useState( + null, + ); + const [streamingState, setStreamingState] = useState( + StreamingState.Idle, + ); + const [thought, setThought] = useState(null); + + // Track the current session instance + const sessionRef = useRef(null); + const currentStreamIdRef = useRef(null); + const userMessageTimestampRef = useRef(0); + const geminiMessageBufferRef = useRef(''); + const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] = + useStateAndRef(null); + + const [ + toolCalls, + _schedule, + _markToolsAsSubmitted, + _setToolCallsForDisplay, + cancelAllToolCalls, + lastOutputTime, + scheduler, + ] = useToolScheduler( + async (_completedTools) => { + // LegacyAgentSession owns the loop, so we don't need to trigger next turns here. + }, + config, + getPreferredEditor, + ); + + const { startNewPrompt } = useSessionStats(); + const logger = useLogger(config.storage); + + const activePtyId = undefined; + const backgroundShellCount = 0; + const isBackgroundShellVisible = false; + const toggleBackgroundShell = useCallback(() => {}, []); + const backgroundCurrentShell = undefined; + const backgroundShells = new Map(); + const dismissBackgroundShell = useCallback(async (_pid: number) => {}, []); + + // TODO: Support LoopDetection confirmation requests + const [ + loopDetectionConfirmationRequest, + ] = useState(null); + + const cancelOngoingRequest = useCallback(async () => { + if (sessionRef.current) { + await sessionRef.current.abort(); + cancelAllToolCalls(new AbortController().signal); + setStreamingState(StreamingState.Idle); + onCancelSubmit(false); + } + }, [cancelAllToolCalls, onCancelSubmit]); + + // TODO: Support native handleApprovalModeChange for Plan Mode + const handleApprovalModeChange = useCallback( + async (newApprovalMode: ApprovalMode) => { + debugLogger.debug(`Approval mode changed to ${newApprovalMode} (stub)`); + }, + [], + ); + + const handleEvent = useCallback( + (event: AgentEvent) => { + switch (event.type) { + case 'agent_start': + setStreamingState(StreamingState.Responding); + break; + case 'agent_end': + setStreamingState(StreamingState.Idle); + if (pendingHistoryItemRef.current) { + addItem( + pendingHistoryItemRef.current, + userMessageTimestampRef.current, + ); + setPendingHistoryItem(null); + } + break; + case 'message': + if (event.role === 'agent') { + for (const part of event.content) { + if (part.type === 'text') { + geminiMessageBufferRef.current += part.text; + // Update pending history item with incremental text + const splitPoint = findLastSafeSplitPoint( + geminiMessageBufferRef.current, + ); + if (splitPoint === geminiMessageBufferRef.current.length) { + setPendingHistoryItem({ + type: 'gemini', + text: geminiMessageBufferRef.current, + }); + } else { + const before = geminiMessageBufferRef.current.substring( + 0, + splitPoint, + ); + const after = + geminiMessageBufferRef.current.substring(splitPoint); + addItem( + { type: 'gemini', text: before }, + userMessageTimestampRef.current, + ); + geminiMessageBufferRef.current = after; + setPendingHistoryItem({ + type: 'gemini_content', + text: after, + }); + } + } else if (part.type === 'thought') { + setThought(parseThought(part.thought)); + } + } + } + break; + case 'tool_request': + // UI state is handled automatically by useToolScheduler via MessageBus + break; + case 'tool_response': + // UI state is handled automatically by useToolScheduler via MessageBus + break; + case 'error': + addItem( + { type: MessageType.ERROR, text: event.message }, + userMessageTimestampRef.current, + ); + break; + default: + break; + } + }, + [addItem, pendingHistoryItemRef, setPendingHistoryItem], + ); + + useEffect(() => { + if (sessionRef.current) { + return sessionRef.current.subscribe(handleEvent); + } + return undefined; + }, [handleEvent]); + + // Handle initialization of the session + if (!sessionRef.current) { + sessionRef.current = new LegacyAgentSession({ + client: geminiClient, + scheduler, + config, + promptId: '', + }); + } + + const submitQuery = useCallback( + async ( + query: PartListUnion, + options?: { isContinuation: boolean }, + _prompt_id?: string, + ) => { + if (!sessionRef.current) return; + + const timestamp = Date.now(); + userMessageTimestampRef.current = timestamp; + geminiMessageBufferRef.current = ''; + + if (!options?.isContinuation) { + if (typeof query === 'string') { + addItem({ type: MessageType.USER, text: query }, timestamp); + void logger?.logMessage(MessageSenderType.USER, query); + } + startNewPrompt(); + } + + const parts = geminiPartsToContentParts( + typeof query === 'string' ? [{ text: query }] : (query as any[]), + ); + + try { + const { streamId } = await sessionRef.current.send({ + message: parts, + }); + currentStreamIdRef.current = streamId; + } catch (err) { + addItem( + { type: MessageType.ERROR, text: getErrorMessage(err) }, + timestamp, + ); + } + }, + [addItem, logger, startNewPrompt], + ); + + const pendingHistoryItems = useMemo(() => { + return pendingHistoryItem ? [pendingHistoryItem] : []; + }, [pendingHistoryItem]); + + return { + streamingState, + submitQuery, + initError, + pendingHistoryItems, + thought, + cancelOngoingRequest, + pendingToolCalls: toolCalls, + handleApprovalModeChange, + activePtyId, + loopDetectionConfirmationRequest, + lastOutputTime, + backgroundShellCount, + isBackgroundShellVisible, + toggleBackgroundShell, + backgroundCurrentShell, + backgroundShells, + retryStatus, + dismissBackgroundShell, + }; +}; diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts index 7d0933506a..670a8b76d5 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.ts @@ -75,6 +75,7 @@ export function useToolScheduler( React.Dispatch>, CancelAllFn, number, + Scheduler, ] { // State stores tool calls organized by their originating schedulerId const [toolCallsMap, setToolCallsMap] = useState< @@ -257,6 +258,7 @@ export function useToolScheduler( setToolCallsForDisplay, cancelAll, lastToolOutputTime, + scheduler, ]; }