feat(cli): implement useAgentStream hook (#24292)

Co-authored-by: Adam Weidman <adamfweidman@gmail.com>
Co-authored-by: Adam Weidman <adamfweidman@google.com>
This commit is contained in:
Michael Bleigh
2026-04-09 12:06:27 -07:00
committed by GitHub
parent f387e456be
commit e406856343
7 changed files with 758 additions and 9 deletions

View File

@@ -0,0 +1,207 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { act } from 'react';
import type { LegacyAgentProtocol } from '@google/gemini-cli-core';
import { renderHookWithProviders } from '../../test-utils/render.js';
// --- MOCKS ---
const mockLegacyAgentProtocol = vi.hoisted(() => ({
send: vi.fn().mockResolvedValue({ streamId: 'test-stream-id' }),
subscribe: vi.fn().mockReturnValue(() => {}),
abort: vi.fn().mockResolvedValue(undefined),
}));
vi.mock('../contexts/SessionContext.js', async (importOriginal) => {
const actual = await importOriginal<Record<string, unknown>>();
return {
...actual,
useSessionStats: vi.fn(() => ({
startNewPrompt: vi.fn(),
})),
};
});
// --- END MOCKS ---
import { useAgentStream } from './useAgentStream.js';
import { MessageType, StreamingState } from '../types.js';
describe('useAgentStream', () => {
const mockAddItem = vi.fn();
const mockOnCancelSubmit = vi.fn();
beforeEach(() => {
vi.clearAllMocks();
});
it('should initialize on mount', async () => {
await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
addItem: mockAddItem,
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
expect(mockLegacyAgentProtocol.subscribe).toHaveBeenCalled();
});
it('should call agent.send when submitQuery is called', async () => {
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
addItem: mockAddItem,
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
await act(async () => {
await result.current.submitQuery('hello');
});
expect(mockLegacyAgentProtocol.send).toHaveBeenCalledWith({
message: { content: [{ type: 'text', text: 'hello' }] },
});
expect(mockAddItem).toHaveBeenCalledWith(
expect.objectContaining({ type: MessageType.USER, text: 'hello' }),
expect.any(Number),
);
});
it('should update streamingState based on agent_start and agent_end events', async () => {
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
addItem: mockAddItem,
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
const eventHandler = vi.mocked(mockLegacyAgentProtocol.subscribe).mock
.calls[0][0];
expect(result.current.streamingState).toBe(StreamingState.Idle);
act(() => {
eventHandler({
type: 'agent_start',
id: '1',
timestamp: '',
streamId: '',
});
});
expect(result.current.streamingState).toBe(StreamingState.Responding);
act(() => {
eventHandler({
type: 'agent_end',
reason: 'completed',
id: '2',
timestamp: '',
streamId: '',
});
});
expect(result.current.streamingState).toBe(StreamingState.Idle);
});
it('should accumulate text content and update pendingHistoryItems', async () => {
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
addItem: mockAddItem,
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
const eventHandler = vi.mocked(mockLegacyAgentProtocol.subscribe).mock
.calls[0][0];
act(() => {
eventHandler({
type: 'message',
role: 'agent',
content: [{ type: 'text', text: 'Hello' }],
id: '1',
timestamp: '',
streamId: '',
});
});
expect(result.current.pendingHistoryItems).toHaveLength(1);
expect(result.current.pendingHistoryItems[0]).toMatchObject({
type: 'gemini',
text: 'Hello',
});
act(() => {
eventHandler({
type: 'message',
role: 'agent',
content: [{ type: 'text', text: ' world' }],
id: '2',
timestamp: '',
streamId: '',
});
});
expect(result.current.pendingHistoryItems[0].text).toBe('Hello world');
});
it('should process thought events and update thought state', async () => {
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
addItem: mockAddItem,
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
const eventHandler = vi.mocked(mockLegacyAgentProtocol.subscribe).mock
.calls[0][0];
act(() => {
eventHandler({
type: 'message',
role: 'agent',
content: [{ type: 'thought', thought: '**Thinking** about tests' }],
id: '1',
timestamp: '',
streamId: '',
});
});
expect(result.current.thought).toEqual({
subject: 'Thinking',
description: 'about tests',
});
});
it('should call agent.abort when cancelOngoingRequest is called', async () => {
const { result } = await renderHookWithProviders(() =>
useAgentStream({
agent: mockLegacyAgentProtocol as unknown as LegacyAgentProtocol,
addItem: mockAddItem,
onCancelSubmit: mockOnCancelSubmit,
isShellFocused: false,
}),
);
await act(async () => {
await result.current.cancelOngoingRequest();
});
expect(mockLegacyAgentProtocol.abort).toHaveBeenCalled();
expect(mockOnCancelSubmit).toHaveBeenCalledWith(false);
});
});

View File

@@ -0,0 +1,528 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import {
getErrorMessage,
MessageSenderType,
debugLogger,
geminiPartsToContentParts,
parseThought,
CoreToolCallStatus,
type ApprovalMode,
Kind,
type ThoughtSummary,
type RetryAttemptPayload,
type AgentEvent,
type AgentProtocol,
type Logger,
type Part,
} from '@google/gemini-cli-core';
import type {
HistoryItemWithoutId,
LoopDetectionConfirmationRequest,
IndividualToolCallDisplay,
HistoryItemToolGroup,
} from '../types.js';
import { StreamingState, MessageType } from '../types.js';
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { getToolGroupBorderAppearance } from '../utils/borderStyles.js';
import { type BackgroundTask } from './useExecutionLifecycle.js';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
import { useSessionStats } from '../contexts/SessionContext.js';
import { useStateAndRef } from './useStateAndRef.js';
import { type MinimalTrackedToolCall } from './useTurnActivityMonitor.js';
export interface UseAgentStreamOptions {
agent?: AgentProtocol;
addItem: UseHistoryManagerReturn['addItem'];
onCancelSubmit: (shouldRestorePrompt?: boolean) => void;
isShellFocused?: boolean;
logger?: Logger | null;
}
/**
* useAgentStream implements the interactive agent loop using an AgentProtocol.
* It is completely agnostic to the specific agent implementation.
*/
export const useAgentStream = ({
agent,
addItem,
onCancelSubmit,
isShellFocused,
logger,
}: UseAgentStreamOptions) => {
const [initError] = useState<string | null>(null);
const [retryStatus] = useState<RetryAttemptPayload | null>(null);
const [streamingState, setStreamingState] = useState<StreamingState>(
StreamingState.Idle,
);
const [thought, setThought] = useState<ThoughtSummary | null>(null);
const [lastOutputTime, setLastOutputTime] = useState<number>(Date.now());
const currentStreamIdRef = useRef<string | null>(null);
const userMessageTimestampRef = useRef<number>(0);
const geminiMessageBufferRef = useRef<string>('');
const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
const [trackedTools, , setTrackedTools] = useStateAndRef<
IndividualToolCallDisplay[]
>([]);
const [pushedToolCallIds, pushedToolCallIdsRef, setPushedToolCallIds] =
useStateAndRef<Set<string>>(new Set());
const [_isFirstToolInGroup, isFirstToolInGroupRef, setIsFirstToolInGroup] =
useStateAndRef<boolean>(true);
const { startNewPrompt } = useSessionStats();
// TODO: Implement dynamic shell-related state derivation from trackedTools or dedicated refs.
// This includes activePtyId, backgroundTasks, and related visibility states to restore
// parity with legacy terminal focus detection and background task tracking.
// Note: Avoid checking ITERM_SESSION_ID for terminal detection and ensure context is sanitized.
const activePtyId = undefined;
const backgroundTaskCount = 0;
const isBackgroundTaskVisible = false;
const toggleBackgroundTasks = useCallback(() => {}, []);
const backgroundCurrentExecution = undefined;
const backgroundTasks = useMemo(() => new Map<number, BackgroundTask>(), []);
const dismissBackgroundTask = useCallback(async (_pid: number) => {}, []);
// Use the trackedTools to mock pendingToolCalls for inactivity monitors
const pendingToolCalls = useMemo(
(): MinimalTrackedToolCall[] =>
trackedTools.map((t) => ({
request: {
name: t.originalRequestName || t.name,
args: { command: t.description },
callId: t.callId,
isClientInitiated: t.isClientInitiated ?? false,
prompt_id: '',
},
status: t.status,
})),
[trackedTools],
);
// TODO: Support LoopDetection confirmation requests
const [loopDetectionConfirmationRequest] =
useState<LoopDetectionConfirmationRequest | null>(null);
const flushPendingText = useCallback(() => {
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestampRef.current);
setPendingHistoryItem(null);
geminiMessageBufferRef.current = '';
}
}, [addItem, pendingHistoryItemRef, setPendingHistoryItem]);
const cancelOngoingRequest = useCallback(async () => {
if (agent) {
await agent.abort();
setStreamingState(StreamingState.Idle);
onCancelSubmit(false);
}
}, [agent, onCancelSubmit]);
// TODO: Support native handleApprovalModeChange for Plan Mode
const handleApprovalModeChange = useCallback(
async (newApprovalMode: ApprovalMode) => {
debugLogger.debug(`Approval mode changed to ${newApprovalMode} (stub)`);
},
[],
);
const handleEvent = useCallback(
(event: AgentEvent) => {
setLastOutputTime(Date.now());
switch (event.type) {
case 'agent_start':
setStreamingState(StreamingState.Responding);
break;
case 'agent_end':
setStreamingState(StreamingState.Idle);
flushPendingText();
break;
case 'message':
if (event.role === 'agent') {
for (const part of event.content) {
if (part.type === 'text') {
geminiMessageBufferRef.current += part.text;
// Update pending history item with incremental text
const splitPoint = findLastSafeSplitPoint(
geminiMessageBufferRef.current,
);
if (splitPoint === geminiMessageBufferRef.current.length) {
setPendingHistoryItem({
type: 'gemini',
text: geminiMessageBufferRef.current,
});
} else {
const before = geminiMessageBufferRef.current.substring(
0,
splitPoint,
);
const after =
geminiMessageBufferRef.current.substring(splitPoint);
addItem(
{ type: 'gemini', text: before },
userMessageTimestampRef.current,
);
geminiMessageBufferRef.current = after;
setPendingHistoryItem({
type: 'gemini_content',
text: after,
});
}
} else if (part.type === 'thought') {
setThought(parseThought(part.thought));
}
}
}
break;
case 'tool_request': {
flushPendingText();
const legacyState = event._meta?.legacyState;
const displayName = legacyState?.displayName ?? event.name;
const isOutputMarkdown = legacyState?.isOutputMarkdown ?? false;
const desc = legacyState?.description ?? '';
const fallbackKind = Kind.Other;
const newCall: IndividualToolCallDisplay = {
callId: event.requestId,
name: displayName,
originalRequestName: event.name,
description: desc,
status: CoreToolCallStatus.Scheduled,
isClientInitiated: false,
renderOutputAsMarkdown: isOutputMarkdown,
kind: legacyState?.kind ?? fallbackKind,
confirmationDetails: undefined,
resultDisplay: undefined,
};
setTrackedTools((prev) => [...prev, newCall]);
break;
}
case 'tool_update': {
setTrackedTools((prev) =>
prev.map((tc): IndividualToolCallDisplay => {
if (tc.callId !== event.requestId) return tc;
const legacyState = event._meta?.legacyState;
const evtStatus = legacyState?.status;
let status = tc.status;
if (evtStatus === 'executing')
status = CoreToolCallStatus.Executing;
else if (evtStatus === 'error') status = CoreToolCallStatus.Error;
else if (evtStatus === 'success')
status = CoreToolCallStatus.Success;
const liveOutput =
event.displayContent?.[0]?.type === 'text'
? event.displayContent[0].text
: tc.resultDisplay;
const progressMessage =
legacyState?.progressMessage ?? tc.progressMessage;
const progress = legacyState?.progress ?? tc.progress;
const progressTotal =
legacyState?.progressTotal ?? tc.progressTotal;
const ptyId = legacyState?.pid ?? tc.ptyId;
const description = legacyState?.description ?? tc.description;
return {
...tc,
status,
resultDisplay: liveOutput,
progressMessage,
progress,
progressTotal,
ptyId,
description,
};
}),
);
break;
}
case 'tool_response': {
setTrackedTools((prev) =>
prev.map((tc): IndividualToolCallDisplay => {
if (tc.callId !== event.requestId) return tc;
const legacyState = event._meta?.legacyState;
const outputFile = legacyState?.outputFile;
const resultDisplay =
event.displayContent?.[0]?.type === 'text'
? event.displayContent[0].text
: tc.resultDisplay;
return {
...tc,
status: event.isError
? CoreToolCallStatus.Error
: CoreToolCallStatus.Success,
resultDisplay,
outputFile,
};
}),
);
break;
}
case 'error':
addItem(
{ type: MessageType.ERROR, text: event.message },
userMessageTimestampRef.current,
);
break;
case 'initialize':
case 'session_update':
case 'elicitation_request':
case 'elicitation_response':
case 'usage':
case 'custom':
// These events are currently not handled in the UI
break;
default:
debugLogger.error('Unknown agent event type:', event);
event satisfies never;
break;
}
},
[
addItem,
flushPendingText,
setPendingHistoryItem,
setTrackedTools,
setStreamingState,
setThought,
setLastOutputTime,
],
);
useEffect(() => {
const unsubscribe = agent?.subscribe(handleEvent);
return () => unsubscribe?.();
}, [agent, handleEvent]);
const submitQuery = useCallback(
async (
query: Part[] | string,
options?: { isContinuation: boolean },
_prompt_id?: string,
) => {
if (!agent) return;
const timestamp = Date.now();
setLastOutputTime(timestamp);
userMessageTimestampRef.current = timestamp;
geminiMessageBufferRef.current = '';
if (!options?.isContinuation) {
if (typeof query === 'string') {
addItem({ type: MessageType.USER, text: query }, timestamp);
void logger?.logMessage(MessageSenderType.USER, query);
}
startNewPrompt();
}
const parts = geminiPartsToContentParts(
typeof query === 'string' ? [{ text: query }] : query,
);
try {
const { streamId } = await agent.send({
message: { content: parts },
});
currentStreamIdRef.current = streamId;
} catch (err) {
addItem(
{ type: MessageType.ERROR, text: getErrorMessage(err) },
timestamp,
);
}
},
[agent, addItem, logger, startNewPrompt],
);
useEffect(() => {
if (trackedTools.length > 0) {
const isNewBatch = !trackedTools.some((tc) =>
pushedToolCallIdsRef.current.has(tc.callId),
);
if (isNewBatch) {
setPushedToolCallIds(new Set());
setIsFirstToolInGroup(true);
}
} else if (streamingState === StreamingState.Idle) {
setPushedToolCallIds(new Set());
setIsFirstToolInGroup(true);
}
}, [
trackedTools,
pushedToolCallIdsRef,
setPushedToolCallIds,
setIsFirstToolInGroup,
streamingState,
]);
// Push completed tools to history
useEffect(() => {
const toolsToPush: IndividualToolCallDisplay[] = [];
for (let i = 0; i < trackedTools.length; i++) {
const tc = trackedTools[i];
if (pushedToolCallIdsRef.current.has(tc.callId)) continue;
if (
tc.status === 'success' ||
tc.status === 'error' ||
tc.status === 'cancelled'
) {
toolsToPush.push(tc);
} else {
break;
}
}
if (toolsToPush.length > 0) {
const newPushed = new Set(pushedToolCallIdsRef.current);
for (const tc of toolsToPush) {
newPushed.add(tc.callId);
}
const isLastInBatch =
toolsToPush[toolsToPush.length - 1] ===
trackedTools[trackedTools.length - 1];
const appearance = getToolGroupBorderAppearance(
{ type: 'tool_group', tools: trackedTools },
activePtyId,
!!isShellFocused,
[],
backgroundTasks,
);
const historyItem: HistoryItemToolGroup = {
type: 'tool_group',
tools: toolsToPush,
borderTop: isFirstToolInGroupRef.current,
borderBottom: isLastInBatch,
...appearance,
};
addItem(historyItem);
setPushedToolCallIds(newPushed);
setIsFirstToolInGroup(false);
}
}, [
trackedTools,
pushedToolCallIdsRef,
isFirstToolInGroupRef,
setPushedToolCallIds,
setIsFirstToolInGroup,
addItem,
activePtyId,
isShellFocused,
backgroundTasks,
]);
const pendingToolGroupItems = useMemo((): HistoryItemWithoutId[] => {
const remainingTools = trackedTools.filter(
(tc) => !pushedToolCallIds.has(tc.callId),
);
const items: HistoryItemWithoutId[] = [];
const appearance = getToolGroupBorderAppearance(
{ type: 'tool_group', tools: trackedTools },
activePtyId,
!!isShellFocused,
[],
backgroundTasks,
);
if (remainingTools.length > 0) {
items.push({
type: 'tool_group',
tools: remainingTools,
borderTop: pushedToolCallIds.size === 0,
borderBottom: false,
...appearance,
});
}
const allTerminal =
trackedTools.length > 0 &&
trackedTools.every(
(tc) =>
tc.status === 'success' ||
tc.status === 'error' ||
tc.status === 'cancelled',
);
const allPushed =
trackedTools.length > 0 &&
trackedTools.every((tc) => pushedToolCallIds.has(tc.callId));
const anyVisibleInHistory = pushedToolCallIds.size > 0;
const anyVisibleInPending = remainingTools.length > 0;
if (
trackedTools.length > 0 &&
!(allTerminal && allPushed) &&
(anyVisibleInHistory || anyVisibleInPending)
) {
items.push({
type: 'tool_group' as const,
tools: [],
borderTop: false,
borderBottom: true,
...appearance,
});
}
return items;
}, [
trackedTools,
pushedToolCallIds,
activePtyId,
isShellFocused,
backgroundTasks,
]);
const pendingHistoryItems = useMemo(
() =>
[pendingHistoryItem, ...pendingToolGroupItems].filter(
(i): i is HistoryItemWithoutId => i !== undefined && i !== null,
),
[pendingHistoryItem, pendingToolGroupItems],
);
return {
streamingState,
submitQuery,
initError,
pendingHistoryItems,
thought,
cancelOngoingRequest,
pendingToolCalls,
handleApprovalModeChange,
activePtyId,
loopDetectionConfirmationRequest,
lastOutputTime,
backgroundTaskCount,
isBackgroundTaskVisible,
toggleBackgroundTasks,
backgroundCurrentExecution,
backgroundTasks,
retryStatus,
dismissBackgroundTask,
};
};

View File

@@ -5,20 +5,22 @@
*/
import { useInactivityTimer } from './useInactivityTimer.js';
import { useTurnActivityMonitor } from './useTurnActivityMonitor.js';
import {
useTurnActivityMonitor,
type MinimalTrackedToolCall,
} from './useTurnActivityMonitor.js';
import {
SHELL_FOCUS_HINT_DELAY_MS,
SHELL_ACTION_REQUIRED_TITLE_DELAY_MS,
SHELL_SILENT_WORKING_TITLE_DELAY_MS,
} from '../constants.js';
import type { StreamingState } from '../types.js';
import { type TrackedToolCall } from './useToolScheduler.js';
interface ShellInactivityStatusProps {
activePtyId: number | string | null | undefined;
lastOutputTime: number;
streamingState: StreamingState;
pendingToolCalls: TrackedToolCall[];
pendingToolCalls: MinimalTrackedToolCall[];
embeddedShellFocused: boolean;
isInteractiveShellEnabled: boolean;
}

View File

@@ -79,6 +79,7 @@ export function useToolScheduler(
React.Dispatch<React.SetStateAction<TrackedToolCall[]>>,
CancelAllFn,
number,
Scheduler,
] {
// State stores tool calls organized by their originating schedulerId
const [toolCallsMap, setToolCallsMap] = useState<
@@ -319,6 +320,7 @@ export function useToolScheduler(
setToolCallsForDisplay,
cancelAll,
lastToolOutputTime,
scheduler,
];
}

View File

@@ -6,8 +6,16 @@
import { useState, useEffect, useRef, useMemo } from 'react';
import { StreamingState } from '../types.js';
import { hasRedirection } from '@google/gemini-cli-core';
import { type TrackedToolCall } from './useToolScheduler.js';
import {
hasRedirection,
type CoreToolCallStatus,
type ToolCallRequestInfo,
} from '@google/gemini-cli-core';
export interface MinimalTrackedToolCall {
status: CoreToolCallStatus;
request: ToolCallRequestInfo;
}
export interface TurnActivityStatus {
operationStartTime: number;
@@ -21,7 +29,7 @@ export interface TurnActivityStatus {
export const useTurnActivityMonitor = (
streamingState: StreamingState,
activePtyId: number | string | null | undefined,
pendingToolCalls: TrackedToolCall[] = [],
pendingToolCalls: MinimalTrackedToolCall[] = [],
): TurnActivityStatus => {
const [operationStartTime, setOperationStartTime] = useState(0);

View File

@@ -29,7 +29,10 @@ export function getToolGroupBorderAppearance(
item:
| HistoryItem
| HistoryItemWithoutId
| { type: 'tool_group'; tools: TrackedToolCall[] },
| {
type: 'tool_group';
tools: Array<IndividualToolCallDisplay | TrackedToolCall>;
},
activeShellPtyId: number | null | undefined,
embeddedShellFocused: boolean | undefined,
allPendingItems: HistoryItemWithoutId[] = [],
@@ -41,7 +44,7 @@ export function getToolGroupBorderAppearance(
// If this item has no tools, it's a closing slice for the current batch.
// We need to look at the last pending item to determine the batch's appearance.
const toolsToInspect: Array<IndividualToolCallDisplay | TrackedToolCall> =
const toolsToInspect =
item.tools.length > 0
? item.tools
: allPendingItems