From 6975224e452544ea2ed4016dd9baea2be7f35b32 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 24 Mar 2026 11:52:16 -0700 Subject: [PATCH] feat: implement tool display parity for agent stream This commit achieves visual parity for tool execution in the interactive stream when using the experimental 'useAgentProtocol' flag. It removes direct UI dependency on the tool scheduler's internal state. Key changes: - Core 'LegacyAgentSession' now attaches display metadata (displayName, description, etc.) to 'tool_request' AgentEvents. - Core 'LegacyAgentSession' listens to the MessageBus to emit 'tool_update' AgentEvents for live output (e.g., shell commands). - UI 'useAgentStream' now maintains its own 'trackedTools' local state, constructed entirely from incoming 'tool_request', 'tool_update', and 'tool_response' events. - The local 'trackedTools' state is mapped to 'pendingToolGroupItems' using the existing 'mapToDisplay' function for seamless visual parity. --- packages/cli/src/ui/hooks/useAgentStream.ts | 212 ++++++++++++- .../core/src/agent/legacy-agent-session.ts | 292 +++++++++++------- 2 files changed, 382 insertions(+), 122 deletions(-) diff --git a/packages/cli/src/ui/hooks/useAgentStream.ts b/packages/cli/src/ui/hooks/useAgentStream.ts index 1d664f7815..30fdac3c67 100644 --- a/packages/cli/src/ui/hooks/useAgentStream.ts +++ b/packages/cli/src/ui/hooks/useAgentStream.ts @@ -30,12 +30,18 @@ import type { } from '../types.js'; import { StreamingState, MessageType } from '../types.js'; import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; +import { getToolGroupBorderAppearance } from '../utils/borderStyles.js'; import { type BackgroundShell } from './shellCommandProcessor.js'; import type { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useLogger } from './useLogger.js'; +import { mapToDisplay as mapTrackedToolCallsToDisplay } from './toolMapping.js'; import { useToolScheduler, } from './useToolScheduler.js'; +import type { + TrackedToolCall, +} from './useToolScheduler.js'; + import { useSessionStats } from '../contexts/SessionContext.js'; import type { LoadedSettings } from '../../config/settings.js'; import type { SlashCommandProcessorResult } from '../types.js'; @@ -86,6 +92,13 @@ export const useAgentStream = ( const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef(null); + const [trackedTools, , setTrackedTools] = + useStateAndRef([]); + const [pushedToolCallIds, pushedToolCallIdsRef, setPushedToolCallIds] = + useStateAndRef>(new Set()); + const [_isFirstToolInGroup, isFirstToolInGroupRef, setIsFirstToolInGroup] = + useStateAndRef(true); + const [ toolCalls, _schedule, @@ -189,10 +202,58 @@ export const useAgentStream = ( } break; case 'tool_request': - // UI state is handled automatically by useToolScheduler via MessageBus + setTrackedTools((prev) => [ + ...prev, + { + request: { + callId: event.requestId, + name: event.name, + args: event.args, + isClientInitiated: false, + originalRequestName: event.name, + }, + status: 'executing', + tool: { + displayName: (event._meta?.['displayName'] as string) ?? event.name, + isOutputMarkdown: (event._meta?.['isOutputMarkdown'] as boolean) ?? false, + }, + invocation: { + getDescription: () => (event._meta?.['description'] as string) ?? '', + }, + } as unknown as TrackedToolCall, + ]); + break; + case 'tool_update': + setTrackedTools((prev) => + prev.map((tc) => + tc.request.callId === event.requestId + ? ({ + ...tc, + liveOutput: event.displayContent?.[0]?.type === 'text' ? event.displayContent[0].text : undefined, + progressMessage: event.data?.['progressMessage'] as string | undefined, + progress: event.data?.['progress'] as number | undefined, + progressTotal: event.data?.['progressTotal'] as number | undefined, + pid: event.data?.['pid'] as number | undefined, + } as unknown as TrackedToolCall) + : tc, + ), + ); break; case 'tool_response': - // UI state is handled automatically by useToolScheduler via MessageBus + setTrackedTools((prev) => + prev.map((tc) => + tc.request.callId === event.requestId + ? ({ + ...tc, + status: event.isError ? 'error' : 'success', + response: { + resultDisplay: event.displayContent?.[0]?.type === 'text' ? event.displayContent[0].text : undefined, + }, + responseSubmittedToGemini: true, + } as unknown as TrackedToolCall) + : tc, + ), + ); break; case 'error': addItem( @@ -263,9 +324,152 @@ export const useAgentStream = ( [addItem, logger, startNewPrompt], ); + useEffect(() => { + if (trackedTools.length > 0) { + const isNewBatch = !trackedTools.some((tc) => + pushedToolCallIdsRef.current.has(tc.request.callId), + ); + if (isNewBatch) { + setPushedToolCallIds(new Set()); + setIsFirstToolInGroup(true); + } + } else if (streamingState === StreamingState.Idle) { + setPushedToolCallIds(new Set()); + setIsFirstToolInGroup(true); + } + }, [ + trackedTools, + pushedToolCallIdsRef, + setPushedToolCallIds, + setIsFirstToolInGroup, + streamingState, + ]); + + // Push completed tools to history + useEffect(() => { + const toolsToPush: TrackedToolCall[] = []; + for (let i = 0; i < trackedTools.length; i++) { + const tc = trackedTools[i]; + if (pushedToolCallIdsRef.current.has(tc.request.callId)) continue; + + if ( + tc.status === 'success' || + tc.status === 'error' || + tc.status === 'cancelled' + ) { + toolsToPush.push(tc); + } else { + break; + } + } + + if (toolsToPush.length > 0) { + const newPushed = new Set(pushedToolCallIdsRef.current); + for (const tc of toolsToPush) { + newPushed.add(tc.request.callId); + } + + const isLastInBatch = + toolsToPush[toolsToPush.length - 1] === trackedTools[trackedTools.length - 1]; + + const historyItem = mapTrackedToolCallsToDisplay(toolsToPush, { + borderTop: isFirstToolInGroupRef.current, + borderBottom: isLastInBatch, + ...getToolGroupBorderAppearance( + { type: 'tool_group', tools: trackedTools as any[] }, + activePtyId, + !!_isShellFocused, + [], + backgroundShells, + ), + }); + + addItem(historyItem); + setPushedToolCallIds(newPushed); + setIsFirstToolInGroup(false); + } + }, [ + trackedTools, + pushedToolCallIdsRef, + isFirstToolInGroupRef, + setPushedToolCallIds, + setIsFirstToolInGroup, + addItem, + activePtyId, + _isShellFocused, + backgroundShells, + ]); + + const pendingToolGroupItems = useMemo((): HistoryItemWithoutId[] => { + const remainingTools = trackedTools.filter( + (tc) => !pushedToolCallIds.has(tc.request.callId), + ); + + const items: HistoryItemWithoutId[] = []; + + const appearance = getToolGroupBorderAppearance( + { type: 'tool_group', tools: trackedTools as any[] }, + activePtyId, + !!_isShellFocused, + [], + backgroundShells, + ); + + if (remainingTools.length > 0) { + items.push( + mapTrackedToolCallsToDisplay(remainingTools, { + borderTop: pushedToolCallIds.size === 0, + borderBottom: false, + ...appearance, + }), + ); + } + + const allTerminal = + trackedTools.length > 0 && + trackedTools.every( + (tc) => + tc.status === 'success' || + tc.status === 'error' || + tc.status === 'cancelled', + ); + + const allPushed = + trackedTools.length > 0 && + trackedTools.every((tc) => pushedToolCallIds.has(tc.request.callId)); + + const anyVisibleInHistory = pushedToolCallIds.size > 0; + const anyVisibleInPending = remainingTools.length > 0; + + if ( + trackedTools.length > 0 && + !(allTerminal && allPushed) && + (anyVisibleInHistory || anyVisibleInPending) + ) { + items.push({ + type: 'tool_group' as const, + tools: [], + borderTop: false, + borderBottom: true, + ...appearance, + }); + } + + return items; + }, [ + trackedTools, + pushedToolCallIds, + isFirstToolInGroupRef, + activePtyId, + _isShellFocused, + backgroundShells, + ]); + const pendingHistoryItems = useMemo(() => { - return pendingHistoryItem ? [pendingHistoryItem] : []; - }, [pendingHistoryItem]); + return [pendingHistoryItem, ...pendingToolGroupItems].filter( + (i): i is HistoryItemWithoutId => i !== undefined && i !== null, + ); + }, [pendingHistoryItem, pendingToolGroupItems]); return { streamingState, diff --git a/packages/core/src/agent/legacy-agent-session.ts b/packages/core/src/agent/legacy-agent-session.ts index fa1d652eb2..36760b2629 100644 --- a/packages/core/src/agent/legacy-agent-session.ts +++ b/packages/core/src/agent/legacy-agent-session.ts @@ -18,6 +18,8 @@ import type { Scheduler } from '../scheduler/scheduler.js'; import { recordToolCallInteractions } from '../code_assist/telemetry.js'; import { ToolErrorType, isFatalToolError } from '../tools/tool-error.js'; import { debugLogger } from '../utils/debugLogger.js'; +import { MessageBusType } from '../confirmation-bus/types.js'; +import type { ToolCallsUpdateMessage } from '../confirmation-bus/types.js'; import { buildToolResponseData, contentPartsToGeminiParts, @@ -163,142 +165,182 @@ class LegacyAgentProtocol implements AgentProtocol { let turnCount = 0; const maxTurns = this._config.getMaxSessionTurns(); - while (true) { - turnCount++; - if (maxTurns >= 0 && turnCount > maxTurns) { - this._finishStream('max_turns', { - code: 'MAX_TURNS_EXCEEDED', - maxTurns, - turnCount: turnCount - 1, - }); - return; + const handleToolCallsUpdate = (event: ToolCallsUpdateMessage) => { + const toolUpdates: AgentEvent[] = []; + for (const tc of event.toolCalls) { + if (tc.status === 'executing') { + toolUpdates.push( + this._makeToolUpdateEvent({ + requestId: tc.request.callId, + displayContent: toolResultDisplayToContentParts(tc.liveOutput), + data: { + progressMessage: tc.progressMessage, + progress: tc.progress, + progressTotal: tc.progressTotal, + pid: tc.pid, + }, + }), + ); + } } + this._emit(toolUpdates); + }; - const toolCallRequests: ToolCallRequestInfo[] = []; - const responseStream = this._client.sendMessageStream( - currentParts, - this._abortController.signal, - this._promptId, - undefined, - false, - currentDisplayContent, - ); - currentDisplayContent = undefined; + this._config.getMessageBus().subscribe(MessageBusType.TOOL_CALLS_UPDATE, handleToolCallsUpdate); + + try { + while (true) { + turnCount++; + if (maxTurns >= 0 && turnCount > maxTurns) { + this._finishStream('max_turns', { + code: 'MAX_TURNS_EXCEEDED', + maxTurns, + turnCount: turnCount - 1, + }); + return; + } + + const toolCallRequests: ToolCallRequestInfo[] = []; + const responseStream = this._client.sendMessageStream( + currentParts, + this._abortController.signal, + this._promptId, + undefined, + false, + currentDisplayContent, + ); + currentDisplayContent = undefined; + + for await (const event of responseStream) { + if (this._abortController.signal.aborted) { + this._finishStream('aborted'); + return; + } + + if (event.type === GeminiEventType.ToolCallRequest) { + toolCallRequests.push(event.value); + } + + const translatedEvents = translateEvent(event, this._translationState); + + for (const ev of translatedEvents) { + if (ev.type === 'tool_request') { + const tool = this._config.getToolRegistry().getTool(ev.name); + ev._meta = { + displayName: tool?.displayName ?? ev.name, + description: tool?.description ?? '', + isOutputMarkdown: tool?.isOutputMarkdown ?? false, + }; + } + } + + this._emit(translatedEvents); + + switch (event.type) { + case GeminiEventType.Error: + case GeminiEventType.InvalidStream: + case GeminiEventType.ContextWindowWillOverflow: + this._finishStream('failed'); + return; + case GeminiEventType.Finished: + if (toolCallRequests.length === 0) { + this._finishStream(mapFinishReason(event.value.reason)); + return; + } + break; + case GeminiEventType.AgentExecutionStopped: + case GeminiEventType.UserCancelled: + case GeminiEventType.MaxSessionTurns: + this._clearActiveStream(); + return; + default: + break; + } + } - for await (const event of responseStream) { if (this._abortController.signal.aborted) { this._finishStream('aborted'); return; } - if (event.type === GeminiEventType.ToolCallRequest) { - toolCallRequests.push(event.value); + if (toolCallRequests.length === 0) { + this._finishStream('completed'); + return; } - this._emit(translateEvent(event, this._translationState)); - - switch (event.type) { - case GeminiEventType.Error: - case GeminiEventType.InvalidStream: - case GeminiEventType.ContextWindowWillOverflow: - this._finishStream('failed'); - return; - case GeminiEventType.Finished: - if (toolCallRequests.length === 0) { - this._finishStream(mapFinishReason(event.value.reason)); - return; - } - break; - case GeminiEventType.AgentExecutionStopped: - case GeminiEventType.UserCancelled: - case GeminiEventType.MaxSessionTurns: - this._clearActiveStream(); - return; - default: - break; - } - } - - if (this._abortController.signal.aborted) { - this._finishStream('aborted'); - return; - } - - if (toolCallRequests.length === 0) { - this._finishStream('completed'); - return; - } - - const completedToolCalls = await this._scheduler.schedule( - toolCallRequests, - this._abortController.signal, - ); - - if (this._abortController.signal.aborted) { - this._finishStream('aborted'); - return; - } - - const toolResponseParts: Part[] = []; - for (const tc of completedToolCalls) { - const response = tc.response; - const request = tc.request; - const content: ContentPart[] = response.error - ? [{ type: 'text', text: response.error.message }] - : geminiPartsToContentParts(response.responseParts); - const displayContent = toolResultDisplayToContentParts( - response.resultDisplay, + const completedToolCalls = await this._scheduler.schedule( + toolCallRequests, + this._abortController.signal, ); - const data = buildToolResponseData(response); - this._emit([ - this._makeToolResponseEvent({ - requestId: request.callId, - name: request.name, - content, - isError: response.error !== undefined, - ...(displayContent ? { displayContent } : {}), - ...(data ? { data } : {}), - }), - ]); - - if (response.responseParts) { - toolResponseParts.push(...response.responseParts); + if (this._abortController.signal.aborted) { + this._finishStream('aborted'); + return; } - } - try { - const currentModel = - this._client.getCurrentSequenceModel() ?? this._config.getModel(); - this._client - .getChat() - .recordCompletedToolCalls(currentModel, completedToolCalls); - await recordToolCallInteractions(this._config, completedToolCalls); - } catch (error) { - debugLogger.error( - `Error recording completed tool call information: ${error}`, + const toolResponseParts: Part[] = []; + for (const tc of completedToolCalls) { + const response = tc.response; + const request = tc.request; + const content: ContentPart[] = response.error + ? [{ type: 'text', text: response.error.message }] + : geminiPartsToContentParts(response.responseParts); + const displayContent = toolResultDisplayToContentParts( + response.resultDisplay, + ); + const data = buildToolResponseData(response); + + this._emit([ + this._makeToolResponseEvent({ + requestId: request.callId, + name: request.name, + content, + isError: response.error !== undefined, + ...(displayContent ? { displayContent } : {}), + ...(data ? { data } : {}), + }), + ]); + + if (response.responseParts) { + toolResponseParts.push(...response.responseParts); + } + } + + try { + const currentModel = + this._client.getCurrentSequenceModel() ?? this._config.getModel(); + this._client + .getChat() + .recordCompletedToolCalls(currentModel, completedToolCalls); + await recordToolCallInteractions(this._config, completedToolCalls); + } catch (error) { + debugLogger.error( + `Error recording completed tool call information: ${error}`, + ); + } + + const stopTool = completedToolCalls.find( + (tc) => + tc.response.errorType === ToolErrorType.STOP_EXECUTION && + tc.response.error !== undefined, ); - } + if (stopTool) { + this._finishStream('completed'); + return; + } - const stopTool = completedToolCalls.find( - (tc) => - tc.response.errorType === ToolErrorType.STOP_EXECUTION && - tc.response.error !== undefined, - ); - if (stopTool) { - this._finishStream('completed'); - return; - } + const fatalTool = completedToolCalls.find((tc) => + isFatalToolError(tc.response.errorType), + ); + if (fatalTool) { + this._finishStream('failed'); + return; + } - const fatalTool = completedToolCalls.find((tc) => - isFatalToolError(tc.response.errorType), - ); - if (fatalTool) { - this._finishStream('failed'); - return; + currentParts = toolResponseParts; } - - currentParts = toolResponseParts; + } finally { + this._config.getMessageBus().unsubscribe(MessageBusType.TOOL_CALLS_UPDATE, handleToolCallsUpdate); } } @@ -434,6 +476,20 @@ class LegacyAgentProtocol implements AgentProtocol { return event; } + private _makeToolUpdateEvent( + payload: Omit< + AgentEvent<'tool_update'>, + 'id' | 'timestamp' | 'streamId' | 'type' + >, + ): AgentEvent<'tool_update'> { + const event = { + ...this._nextEventFields(), + type: 'tool_update', + ...payload, + } satisfies AgentEvent<'tool_update'>; + return event; + } + private _makeAgentStartEvent(): AgentEvent<'agent_start'> { const event = { ...this._nextEventFields(),