diff --git a/packages/cli/src/ui/hooks/useAgentStream.ts b/packages/cli/src/ui/hooks/useAgentStream.ts index 1d664f7815..30fdac3c67 100644 --- a/packages/cli/src/ui/hooks/useAgentStream.ts +++ b/packages/cli/src/ui/hooks/useAgentStream.ts @@ -30,12 +30,18 @@ import type { } from '../types.js'; import { StreamingState, MessageType } from '../types.js'; import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; +import { getToolGroupBorderAppearance } from '../utils/borderStyles.js'; import { type BackgroundShell } from './shellCommandProcessor.js'; import type { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useLogger } from './useLogger.js'; +import { mapToDisplay as mapTrackedToolCallsToDisplay } from './toolMapping.js'; import { useToolScheduler, } from './useToolScheduler.js'; +import type { + TrackedToolCall, +} from './useToolScheduler.js'; + import { useSessionStats } from '../contexts/SessionContext.js'; import type { LoadedSettings } from '../../config/settings.js'; import type { SlashCommandProcessorResult } from '../types.js'; @@ -86,6 +92,13 @@ export const useAgentStream = ( const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef(null); + const [trackedTools, , setTrackedTools] = + useStateAndRef([]); + const [pushedToolCallIds, pushedToolCallIdsRef, setPushedToolCallIds] = + useStateAndRef>(new Set()); + const [_isFirstToolInGroup, isFirstToolInGroupRef, setIsFirstToolInGroup] = + useStateAndRef(true); + const [ toolCalls, _schedule, @@ -189,10 +202,58 @@ export const useAgentStream = ( } break; case 'tool_request': - // UI state is handled automatically by useToolScheduler via MessageBus + setTrackedTools((prev) => [ + ...prev, + { + request: { + callId: event.requestId, + name: event.name, + args: event.args, + isClientInitiated: false, + originalRequestName: event.name, + }, + status: 'executing', + tool: { + displayName: (event._meta?.['displayName'] as string) ?? event.name, + isOutputMarkdown: (event._meta?.['isOutputMarkdown'] as boolean) ?? false, + }, + invocation: { + getDescription: () => (event._meta?.['description'] as string) ?? '', + }, + } as unknown as TrackedToolCall, + ]); + break; + case 'tool_update': + setTrackedTools((prev) => + prev.map((tc) => + tc.request.callId === event.requestId + ? ({ + ...tc, + liveOutput: event.displayContent?.[0]?.type === 'text' ? event.displayContent[0].text : undefined, + progressMessage: event.data?.['progressMessage'] as string | undefined, + progress: event.data?.['progress'] as number | undefined, + progressTotal: event.data?.['progressTotal'] as number | undefined, + pid: event.data?.['pid'] as number | undefined, + } as unknown as TrackedToolCall) + : tc, + ), + ); break; case 'tool_response': - // UI state is handled automatically by useToolScheduler via MessageBus + setTrackedTools((prev) => + prev.map((tc) => + tc.request.callId === event.requestId + ? ({ + ...tc, + status: event.isError ? 'error' : 'success', + response: { + resultDisplay: event.displayContent?.[0]?.type === 'text' ? event.displayContent[0].text : undefined, + }, + responseSubmittedToGemini: true, + } as unknown as TrackedToolCall) + : tc, + ), + ); break; case 'error': addItem( @@ -263,9 +324,152 @@ export const useAgentStream = ( [addItem, logger, startNewPrompt], ); + useEffect(() => { + if (trackedTools.length > 0) { + const isNewBatch = !trackedTools.some((tc) => + pushedToolCallIdsRef.current.has(tc.request.callId), + ); + if (isNewBatch) { + setPushedToolCallIds(new Set()); + setIsFirstToolInGroup(true); + } + } else if (streamingState === StreamingState.Idle) { + setPushedToolCallIds(new Set()); + setIsFirstToolInGroup(true); + } + }, [ + trackedTools, + pushedToolCallIdsRef, + setPushedToolCallIds, + setIsFirstToolInGroup, + streamingState, + ]); + + // Push completed tools to history + useEffect(() => { + const toolsToPush: TrackedToolCall[] = []; + for (let i = 0; i < trackedTools.length; i++) { + const tc = trackedTools[i]; + if (pushedToolCallIdsRef.current.has(tc.request.callId)) continue; + + if ( + tc.status === 'success' || + tc.status === 'error' || + tc.status === 'cancelled' + ) { + toolsToPush.push(tc); + } else { + break; + } + } + + if (toolsToPush.length > 0) { + const newPushed = new Set(pushedToolCallIdsRef.current); + for (const tc of toolsToPush) { + newPushed.add(tc.request.callId); + } + + const isLastInBatch = + toolsToPush[toolsToPush.length - 1] === trackedTools[trackedTools.length - 1]; + + const historyItem = mapTrackedToolCallsToDisplay(toolsToPush, { + borderTop: isFirstToolInGroupRef.current, + borderBottom: isLastInBatch, + ...getToolGroupBorderAppearance( + { type: 'tool_group', tools: trackedTools as any[] }, + activePtyId, + !!_isShellFocused, + [], + backgroundShells, + ), + }); + + addItem(historyItem); + setPushedToolCallIds(newPushed); + setIsFirstToolInGroup(false); + } + }, [ + trackedTools, + pushedToolCallIdsRef, + isFirstToolInGroupRef, + setPushedToolCallIds, + setIsFirstToolInGroup, + addItem, + activePtyId, + _isShellFocused, + backgroundShells, + ]); + + const pendingToolGroupItems = useMemo((): HistoryItemWithoutId[] => { + const remainingTools = trackedTools.filter( + (tc) => !pushedToolCallIds.has(tc.request.callId), + ); + + const items: HistoryItemWithoutId[] = []; + + const appearance = getToolGroupBorderAppearance( + { type: 'tool_group', tools: trackedTools as any[] }, + activePtyId, + !!_isShellFocused, + [], + backgroundShells, + ); + + if (remainingTools.length > 0) { + items.push( + mapTrackedToolCallsToDisplay(remainingTools, { + borderTop: pushedToolCallIds.size === 0, + borderBottom: false, + ...appearance, + }), + ); + } + + const allTerminal = + trackedTools.length > 0 && + trackedTools.every( + (tc) => + tc.status === 'success' || + tc.status === 'error' || + tc.status === 'cancelled', + ); + + const allPushed = + trackedTools.length > 0 && + trackedTools.every((tc) => pushedToolCallIds.has(tc.request.callId)); + + const anyVisibleInHistory = pushedToolCallIds.size > 0; + const anyVisibleInPending = remainingTools.length > 0; + + if ( + trackedTools.length > 0 && + !(allTerminal && allPushed) && + (anyVisibleInHistory || anyVisibleInPending) + ) { + items.push({ + type: 'tool_group' as const, + tools: [], + borderTop: false, + borderBottom: true, + ...appearance, + }); + } + + return items; + }, [ + trackedTools, + pushedToolCallIds, + isFirstToolInGroupRef, + activePtyId, + _isShellFocused, + backgroundShells, + ]); + const pendingHistoryItems = useMemo(() => { - return pendingHistoryItem ? [pendingHistoryItem] : []; - }, [pendingHistoryItem]); + return [pendingHistoryItem, ...pendingToolGroupItems].filter( + (i): i is HistoryItemWithoutId => i !== undefined && i !== null, + ); + }, [pendingHistoryItem, pendingToolGroupItems]); return { streamingState, diff --git a/packages/core/src/agent/legacy-agent-session.ts b/packages/core/src/agent/legacy-agent-session.ts index fa1d652eb2..36760b2629 100644 --- a/packages/core/src/agent/legacy-agent-session.ts +++ b/packages/core/src/agent/legacy-agent-session.ts @@ -18,6 +18,8 @@ import type { Scheduler } from '../scheduler/scheduler.js'; import { recordToolCallInteractions } from '../code_assist/telemetry.js'; import { ToolErrorType, isFatalToolError } from '../tools/tool-error.js'; import { debugLogger } from '../utils/debugLogger.js'; +import { MessageBusType } from '../confirmation-bus/types.js'; +import type { ToolCallsUpdateMessage } from '../confirmation-bus/types.js'; import { buildToolResponseData, contentPartsToGeminiParts, @@ -163,142 +165,182 @@ class LegacyAgentProtocol implements AgentProtocol { let turnCount = 0; const maxTurns = this._config.getMaxSessionTurns(); - while (true) { - turnCount++; - if (maxTurns >= 0 && turnCount > maxTurns) { - this._finishStream('max_turns', { - code: 'MAX_TURNS_EXCEEDED', - maxTurns, - turnCount: turnCount - 1, - }); - return; + const handleToolCallsUpdate = (event: ToolCallsUpdateMessage) => { + const toolUpdates: AgentEvent[] = []; + for (const tc of event.toolCalls) { + if (tc.status === 'executing') { + toolUpdates.push( + this._makeToolUpdateEvent({ + requestId: tc.request.callId, + displayContent: toolResultDisplayToContentParts(tc.liveOutput), + data: { + progressMessage: tc.progressMessage, + progress: tc.progress, + progressTotal: tc.progressTotal, + pid: tc.pid, + }, + }), + ); + } } + this._emit(toolUpdates); + }; - const toolCallRequests: ToolCallRequestInfo[] = []; - const responseStream = this._client.sendMessageStream( - currentParts, - this._abortController.signal, - this._promptId, - undefined, - false, - currentDisplayContent, - ); - currentDisplayContent = undefined; + this._config.getMessageBus().subscribe(MessageBusType.TOOL_CALLS_UPDATE, handleToolCallsUpdate); + + try { + while (true) { + turnCount++; + if (maxTurns >= 0 && turnCount > maxTurns) { + this._finishStream('max_turns', { + code: 'MAX_TURNS_EXCEEDED', + maxTurns, + turnCount: turnCount - 1, + }); + return; + } + + const toolCallRequests: ToolCallRequestInfo[] = []; + const responseStream = this._client.sendMessageStream( + currentParts, + this._abortController.signal, + this._promptId, + undefined, + false, + currentDisplayContent, + ); + currentDisplayContent = undefined; + + for await (const event of responseStream) { + if (this._abortController.signal.aborted) { + this._finishStream('aborted'); + return; + } + + if (event.type === GeminiEventType.ToolCallRequest) { + toolCallRequests.push(event.value); + } + + const translatedEvents = translateEvent(event, this._translationState); + + for (const ev of translatedEvents) { + if (ev.type === 'tool_request') { + const tool = this._config.getToolRegistry().getTool(ev.name); + ev._meta = { + displayName: tool?.displayName ?? ev.name, + description: tool?.description ?? '', + isOutputMarkdown: tool?.isOutputMarkdown ?? false, + }; + } + } + + this._emit(translatedEvents); + + switch (event.type) { + case GeminiEventType.Error: + case GeminiEventType.InvalidStream: + case GeminiEventType.ContextWindowWillOverflow: + this._finishStream('failed'); + return; + case GeminiEventType.Finished: + if (toolCallRequests.length === 0) { + this._finishStream(mapFinishReason(event.value.reason)); + return; + } + break; + case GeminiEventType.AgentExecutionStopped: + case GeminiEventType.UserCancelled: + case GeminiEventType.MaxSessionTurns: + this._clearActiveStream(); + return; + default: + break; + } + } - for await (const event of responseStream) { if (this._abortController.signal.aborted) { this._finishStream('aborted'); return; } - if (event.type === GeminiEventType.ToolCallRequest) { - toolCallRequests.push(event.value); + if (toolCallRequests.length === 0) { + this._finishStream('completed'); + return; } - this._emit(translateEvent(event, this._translationState)); - - switch (event.type) { - case GeminiEventType.Error: - case GeminiEventType.InvalidStream: - case GeminiEventType.ContextWindowWillOverflow: - this._finishStream('failed'); - return; - case GeminiEventType.Finished: - if (toolCallRequests.length === 0) { - this._finishStream(mapFinishReason(event.value.reason)); - return; - } - break; - case GeminiEventType.AgentExecutionStopped: - case GeminiEventType.UserCancelled: - case GeminiEventType.MaxSessionTurns: - this._clearActiveStream(); - return; - default: - break; - } - } - - if (this._abortController.signal.aborted) { - this._finishStream('aborted'); - return; - } - - if (toolCallRequests.length === 0) { - this._finishStream('completed'); - return; - } - - const completedToolCalls = await this._scheduler.schedule( - toolCallRequests, - this._abortController.signal, - ); - - if (this._abortController.signal.aborted) { - this._finishStream('aborted'); - return; - } - - const toolResponseParts: Part[] = []; - for (const tc of completedToolCalls) { - const response = tc.response; - const request = tc.request; - const content: ContentPart[] = response.error - ? [{ type: 'text', text: response.error.message }] - : geminiPartsToContentParts(response.responseParts); - const displayContent = toolResultDisplayToContentParts( - response.resultDisplay, + const completedToolCalls = await this._scheduler.schedule( + toolCallRequests, + this._abortController.signal, ); - const data = buildToolResponseData(response); - this._emit([ - this._makeToolResponseEvent({ - requestId: request.callId, - name: request.name, - content, - isError: response.error !== undefined, - ...(displayContent ? { displayContent } : {}), - ...(data ? { data } : {}), - }), - ]); - - if (response.responseParts) { - toolResponseParts.push(...response.responseParts); + if (this._abortController.signal.aborted) { + this._finishStream('aborted'); + return; } - } - try { - const currentModel = - this._client.getCurrentSequenceModel() ?? this._config.getModel(); - this._client - .getChat() - .recordCompletedToolCalls(currentModel, completedToolCalls); - await recordToolCallInteractions(this._config, completedToolCalls); - } catch (error) { - debugLogger.error( - `Error recording completed tool call information: ${error}`, + const toolResponseParts: Part[] = []; + for (const tc of completedToolCalls) { + const response = tc.response; + const request = tc.request; + const content: ContentPart[] = response.error + ? [{ type: 'text', text: response.error.message }] + : geminiPartsToContentParts(response.responseParts); + const displayContent = toolResultDisplayToContentParts( + response.resultDisplay, + ); + const data = buildToolResponseData(response); + + this._emit([ + this._makeToolResponseEvent({ + requestId: request.callId, + name: request.name, + content, + isError: response.error !== undefined, + ...(displayContent ? { displayContent } : {}), + ...(data ? { data } : {}), + }), + ]); + + if (response.responseParts) { + toolResponseParts.push(...response.responseParts); + } + } + + try { + const currentModel = + this._client.getCurrentSequenceModel() ?? this._config.getModel(); + this._client + .getChat() + .recordCompletedToolCalls(currentModel, completedToolCalls); + await recordToolCallInteractions(this._config, completedToolCalls); + } catch (error) { + debugLogger.error( + `Error recording completed tool call information: ${error}`, + ); + } + + const stopTool = completedToolCalls.find( + (tc) => + tc.response.errorType === ToolErrorType.STOP_EXECUTION && + tc.response.error !== undefined, ); - } + if (stopTool) { + this._finishStream('completed'); + return; + } - const stopTool = completedToolCalls.find( - (tc) => - tc.response.errorType === ToolErrorType.STOP_EXECUTION && - tc.response.error !== undefined, - ); - if (stopTool) { - this._finishStream('completed'); - return; - } + const fatalTool = completedToolCalls.find((tc) => + isFatalToolError(tc.response.errorType), + ); + if (fatalTool) { + this._finishStream('failed'); + return; + } - const fatalTool = completedToolCalls.find((tc) => - isFatalToolError(tc.response.errorType), - ); - if (fatalTool) { - this._finishStream('failed'); - return; + currentParts = toolResponseParts; } - - currentParts = toolResponseParts; + } finally { + this._config.getMessageBus().unsubscribe(MessageBusType.TOOL_CALLS_UPDATE, handleToolCallsUpdate); } } @@ -434,6 +476,20 @@ class LegacyAgentProtocol implements AgentProtocol { return event; } + private _makeToolUpdateEvent( + payload: Omit< + AgentEvent<'tool_update'>, + 'id' | 'timestamp' | 'streamId' | 'type' + >, + ): AgentEvent<'tool_update'> { + const event = { + ...this._nextEventFields(), + type: 'tool_update', + ...payload, + } satisfies AgentEvent<'tool_update'>; + return event; + } + private _makeAgentStartEvent(): AgentEvent<'agent_start'> { const event = { ...this._nextEventFields(),