feat: implement tool display parity for agent stream

This commit achieves visual parity for tool execution in the interactive stream when using the experimental 'useAgentProtocol' flag. It removes the UI's direct dependency on the tool scheduler's internal state.

Key changes:
- Core 'LegacyAgentSession' now attaches display metadata (displayName, description, etc.) to 'tool_request' AgentEvents.
- Core 'LegacyAgentSession' listens to the MessageBus to emit 'tool_update' AgentEvents for live output (e.g., shell commands).
- UI 'useAgentStream' now maintains its own 'trackedTools' local state, constructed entirely from incoming 'tool_request', 'tool_update', and 'tool_response' events.
- The local 'trackedTools' state is mapped to 'pendingToolGroupItems' using the existing 'mapToDisplay' function for seamless visual parity.
This commit is contained in:
Michael Bleigh
2026-03-24 11:52:16 -07:00
parent d3a7fc39e3
commit 6975224e45
2 changed files with 382 additions and 122 deletions
+208 -4
View File
@@ -30,12 +30,18 @@ import type {
} from '../types.js';
import { StreamingState, MessageType } from '../types.js';
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { getToolGroupBorderAppearance } from '../utils/borderStyles.js';
import { type BackgroundShell } from './shellCommandProcessor.js';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
import { useLogger } from './useLogger.js';
import { mapToDisplay as mapTrackedToolCallsToDisplay } from './toolMapping.js';
import {
useToolScheduler,
} from './useToolScheduler.js';
import type {
TrackedToolCall,
} from './useToolScheduler.js';
import { useSessionStats } from '../contexts/SessionContext.js';
import type { LoadedSettings } from '../../config/settings.js';
import type { SlashCommandProcessorResult } from '../types.js';
@@ -86,6 +92,13 @@ export const useAgentStream = (
// History item that is still pending; it is listed ahead of the pending
// tool-group items when the combined pending list is built.
const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
// Tool calls tracked locally for this stream. Entries are appended on
// 'tool_request' events and updated in place on 'tool_update' and
// 'tool_response' events.
const [trackedTools, , setTrackedTools] =
useStateAndRef<TrackedToolCall[]>([]);
// callIds of tracked tools that have already been added to history.
const [pushedToolCallIds, pushedToolCallIdsRef, setPushedToolCallIds] =
useStateAndRef<Set<string>>(new Set());
// True until the first group of the current batch has been pushed to history;
// read through the ref to decide whether a pushed group gets a top border.
// The state value itself is not read (hence the leading underscore).
const [_isFirstToolInGroup, isFirstToolInGroupRef, setIsFirstToolInGroup] =
useStateAndRef<boolean>(true);
const [
toolCalls,
_schedule,
@@ -189,10 +202,58 @@ export const useAgentStream = (
}
break;
case 'tool_request':
// Start tracking the tool locally in the 'executing' state. Display metadata
// is read from event._meta, falling back to the event name / defaults when
// a key is absent.
setTrackedTools((prev) => [
...prev,
{
request: {
callId: event.requestId,
name: event.name,
args: event.args,
isClientInitiated: false,
originalRequestName: event.name,
},
status: 'executing',
tool: {
displayName: (event._meta?.['displayName'] as string) ?? event.name,
isOutputMarkdown: (event._meta?.['isOutputMarkdown'] as boolean) ?? false,
},
invocation: {
getDescription: () => (event._meta?.['description'] as string) ?? '',
},
// Only the fields set above are populated; the double cast bypasses type
// checking for the remainder of TrackedToolCall.
} as unknown as TrackedToolCall,
]);
break;
case 'tool_update':
// Refresh live output and progress fields on the tool whose callId matches
// the event. Each field is overwritten on every update, so a value missing
// from the event becomes undefined rather than keeping its previous value.
// liveOutput uses only the first display content part, and only if it is text.
setTrackedTools((prev) =>
prev.map((tc) =>
tc.request.callId === event.requestId
? ({
...tc,
liveOutput: event.displayContent?.[0]?.type === 'text' ? event.displayContent[0].text : undefined,
progressMessage: event.data?.['progressMessage'] as string | undefined,
progress: event.data?.['progress'] as number | undefined,
progressTotal: event.data?.['progressTotal'] as number | undefined,
pid: event.data?.['pid'] as number | undefined,
} as unknown as TrackedToolCall)
: tc,
),
);
break;
case 'tool_response':
// Mark the matching tool as finished ('error' or 'success') and record its
// result display (first display content part, text only).
setTrackedTools((prev) =>
prev.map((tc) =>
tc.request.callId === event.requestId
? ({
...tc,
status: event.isError ? 'error' : 'success',
response: {
resultDisplay: event.displayContent?.[0]?.type === 'text' ? event.displayContent[0].text : undefined,
},
responseSubmittedToGemini: true,
} as unknown as TrackedToolCall)
: tc,
),
);
break;
case 'error':
addItem(
@@ -263,9 +324,152 @@ export const useAgentStream = (
[addItem, logger, startNewPrompt],
);
// Resets the "already pushed" bookkeeping when a new batch of tools starts
// (no tracked tool has been pushed yet), or when nothing is tracked and the
// stream is idle.
useEffect(() => {
  const resetBatchTracking = (): void => {
    setPushedToolCallIds(new Set());
    setIsFirstToolInGroup(true);
  };

  if (trackedTools.length === 0) {
    if (streamingState === StreamingState.Idle) {
      resetBatchTracking();
    }
    return;
  }

  // The batch is new when none of its tools appears in the pushed-id set.
  const sharesPushedTool = trackedTools.some((tracked) =>
    pushedToolCallIdsRef.current.has(tracked.request.callId),
  );
  if (!sharesPushedTool) {
    resetBatchTracking();
  }
}, [
  trackedTools,
  pushedToolCallIdsRef,
  setPushedToolCallIds,
  setIsFirstToolInGroup,
  streamingState,
]);
// Push completed tools to history.
// Walks the tracked tools in order, skipping those already pushed, and
// collects finished tools until the first unfinished one is reached. The
// collected tools are added to history as one tool group.
useEffect(() => {
  const hasFinished = (tracked: TrackedToolCall): boolean =>
    tracked.status === 'success' ||
    tracked.status === 'error' ||
    tracked.status === 'cancelled';

  const finishedTools: TrackedToolCall[] = [];
  for (const tracked of trackedTools) {
    if (pushedToolCallIdsRef.current.has(tracked.request.callId)) {
      continue;
    }
    if (!hasFinished(tracked)) {
      break;
    }
    finishedTools.push(tracked);
  }

  if (finishedTools.length === 0) {
    return;
  }

  const updatedPushedIds = new Set(pushedToolCallIdsRef.current);
  for (const tracked of finishedTools) {
    updatedPushedIds.add(tracked.request.callId);
  }

  // The group gets a bottom border only when it ends with the batch's last tool.
  const lastFinished = finishedTools[finishedTools.length - 1];
  const lastTracked = trackedTools[trackedTools.length - 1];
  const closesBatch = lastFinished === lastTracked;

  addItem(
    mapTrackedToolCallsToDisplay(finishedTools, {
      borderTop: isFirstToolInGroupRef.current,
      borderBottom: closesBatch,
      ...getToolGroupBorderAppearance(
        { type: 'tool_group', tools: trackedTools as any[] },
        activePtyId,
        !!_isShellFocused,
        [],
        backgroundShells,
      ),
    }),
  );
  setPushedToolCallIds(updatedPushedIds);
  setIsFirstToolInGroup(false);
}, [
  trackedTools,
  pushedToolCallIdsRef,
  isFirstToolInGroupRef,
  setPushedToolCallIds,
  setIsFirstToolInGroup,
  addItem,
  activePtyId,
  _isShellFocused,
  backgroundShells,
]);
/**
 * Pending (not yet in history) display items for the tracked tool calls.
 *
 * Produces up to two items:
 *  1. A tool group holding every tracked tool whose callId has not been pushed
 *     to history. It has a top border only when nothing has been pushed yet,
 *     and never a bottom border.
 *  2. An empty tool group that only draws the bottom border. It is added while
 *     the batch is still open (not every tool is both finished and pushed) and
 *     at least one tool is visible in history or in the pending group.
 *
 * The dependency list contains only values the callback reads;
 * `isFirstToolInGroupRef` was listed previously but is never used here.
 */
const pendingToolGroupItems = useMemo((): HistoryItemWithoutId[] => {
  const remainingTools = trackedTools.filter(
    (tc) => !pushedToolCallIds.has(tc.request.callId),
  );

  const items: HistoryItemWithoutId[] = [];

  // Border appearance is derived from the whole batch, not just the remainder.
  const appearance = getToolGroupBorderAppearance(
    { type: 'tool_group', tools: trackedTools as any[] },
    activePtyId,
    !!_isShellFocused,
    [],
    backgroundShells,
  );

  if (remainingTools.length > 0) {
    items.push(
      mapTrackedToolCallsToDisplay(remainingTools, {
        borderTop: pushedToolCallIds.size === 0,
        borderBottom: false,
        ...appearance,
      }),
    );
  }

  const allTerminal =
    trackedTools.length > 0 &&
    trackedTools.every(
      (tc) =>
        tc.status === 'success' ||
        tc.status === 'error' ||
        tc.status === 'cancelled',
    );

  const allPushed =
    trackedTools.length > 0 &&
    trackedTools.every((tc) => pushedToolCallIds.has(tc.request.callId));

  const anyVisibleInHistory = pushedToolCallIds.size > 0;
  const anyVisibleInPending = remainingTools.length > 0;

  // Close the box with a border-only group while the batch is still open.
  if (
    trackedTools.length > 0 &&
    !(allTerminal && allPushed) &&
    (anyVisibleInHistory || anyVisibleInPending)
  ) {
    items.push({
      type: 'tool_group' as const,
      tools: [],
      borderTop: false,
      borderBottom: true,
      ...appearance,
    });
  }

  return items;
}, [
  trackedTools,
  pushedToolCallIds,
  activePtyId,
  _isShellFocused,
  backgroundShells,
]);
const pendingHistoryItems = useMemo(() => {
return pendingHistoryItem ? [pendingHistoryItem] : [];
}, [pendingHistoryItem]);
return [pendingHistoryItem, ...pendingToolGroupItems].filter(
(i): i is HistoryItemWithoutId => i !== undefined && i !== null,
);
}, [pendingHistoryItem, pendingToolGroupItems]);
return {
streamingState,
+174 -118
View File
@@ -18,6 +18,8 @@ import type { Scheduler } from '../scheduler/scheduler.js';
import { recordToolCallInteractions } from '../code_assist/telemetry.js';
import { ToolErrorType, isFatalToolError } from '../tools/tool-error.js';
import { debugLogger } from '../utils/debugLogger.js';
import { MessageBusType } from '../confirmation-bus/types.js';
import type { ToolCallsUpdateMessage } from '../confirmation-bus/types.js';
import {
buildToolResponseData,
contentPartsToGeminiParts,
@@ -163,142 +165,182 @@ class LegacyAgentProtocol implements AgentProtocol {
let turnCount = 0;
const maxTurns = this._config.getMaxSessionTurns();
while (true) {
turnCount++;
if (maxTurns >= 0 && turnCount > maxTurns) {
this._finishStream('max_turns', {
code: 'MAX_TURNS_EXCEEDED',
maxTurns,
turnCount: turnCount - 1,
});
return;
// Message-bus handler: for every tool call in the update that is currently
// 'executing', builds a 'tool_update' agent event carrying its live output and
// progress fields, then emits the collected events in one call.
const handleToolCallsUpdate = (event: ToolCallsUpdateMessage) => {
  const pendingEvents: AgentEvent[] = [];

  for (const call of event.toolCalls) {
    if (call.status !== 'executing') {
      continue;
    }

    const updateEvent = this._makeToolUpdateEvent({
      requestId: call.request.callId,
      displayContent: toolResultDisplayToContentParts(call.liveOutput),
      data: {
        progressMessage: call.progressMessage,
        progress: call.progress,
        progressTotal: call.progressTotal,
        pid: call.pid,
      },
    });
    pendingEvents.push(updateEvent);
  }

  // Emitted even when no call is executing (the list is then empty).
  this._emit(pendingEvents);
};
const toolCallRequests: ToolCallRequestInfo[] = [];
const responseStream = this._client.sendMessageStream(
currentParts,
this._abortController.signal,
this._promptId,
undefined,
false,
currentDisplayContent,
);
currentDisplayContent = undefined;
this._config.getMessageBus().subscribe(MessageBusType.TOOL_CALLS_UPDATE, handleToolCallsUpdate);
try {
while (true) {
turnCount++;
if (maxTurns >= 0 && turnCount > maxTurns) {
this._finishStream('max_turns', {
code: 'MAX_TURNS_EXCEEDED',
maxTurns,
turnCount: turnCount - 1,
});
return;
}
const toolCallRequests: ToolCallRequestInfo[] = [];
const responseStream = this._client.sendMessageStream(
currentParts,
this._abortController.signal,
this._promptId,
undefined,
false,
currentDisplayContent,
);
currentDisplayContent = undefined;
for await (const event of responseStream) {
if (this._abortController.signal.aborted) {
this._finishStream('aborted');
return;
}
if (event.type === GeminiEventType.ToolCallRequest) {
toolCallRequests.push(event.value);
}
const translatedEvents = translateEvent(event, this._translationState);
for (const ev of translatedEvents) {
if (ev.type === 'tool_request') {
const tool = this._config.getToolRegistry().getTool(ev.name);
ev._meta = {
displayName: tool?.displayName ?? ev.name,
description: tool?.description ?? '',
isOutputMarkdown: tool?.isOutputMarkdown ?? false,
};
}
}
this._emit(translatedEvents);
switch (event.type) {
case GeminiEventType.Error:
case GeminiEventType.InvalidStream:
case GeminiEventType.ContextWindowWillOverflow:
this._finishStream('failed');
return;
case GeminiEventType.Finished:
if (toolCallRequests.length === 0) {
this._finishStream(mapFinishReason(event.value.reason));
return;
}
break;
case GeminiEventType.AgentExecutionStopped:
case GeminiEventType.UserCancelled:
case GeminiEventType.MaxSessionTurns:
this._clearActiveStream();
return;
default:
break;
}
}
for await (const event of responseStream) {
if (this._abortController.signal.aborted) {
this._finishStream('aborted');
return;
}
if (event.type === GeminiEventType.ToolCallRequest) {
toolCallRequests.push(event.value);
if (toolCallRequests.length === 0) {
this._finishStream('completed');
return;
}
this._emit(translateEvent(event, this._translationState));
switch (event.type) {
case GeminiEventType.Error:
case GeminiEventType.InvalidStream:
case GeminiEventType.ContextWindowWillOverflow:
this._finishStream('failed');
return;
case GeminiEventType.Finished:
if (toolCallRequests.length === 0) {
this._finishStream(mapFinishReason(event.value.reason));
return;
}
break;
case GeminiEventType.AgentExecutionStopped:
case GeminiEventType.UserCancelled:
case GeminiEventType.MaxSessionTurns:
this._clearActiveStream();
return;
default:
break;
}
}
if (this._abortController.signal.aborted) {
this._finishStream('aborted');
return;
}
if (toolCallRequests.length === 0) {
this._finishStream('completed');
return;
}
const completedToolCalls = await this._scheduler.schedule(
toolCallRequests,
this._abortController.signal,
);
if (this._abortController.signal.aborted) {
this._finishStream('aborted');
return;
}
const toolResponseParts: Part[] = [];
for (const tc of completedToolCalls) {
const response = tc.response;
const request = tc.request;
const content: ContentPart[] = response.error
? [{ type: 'text', text: response.error.message }]
: geminiPartsToContentParts(response.responseParts);
const displayContent = toolResultDisplayToContentParts(
response.resultDisplay,
const completedToolCalls = await this._scheduler.schedule(
toolCallRequests,
this._abortController.signal,
);
const data = buildToolResponseData(response);
this._emit([
this._makeToolResponseEvent({
requestId: request.callId,
name: request.name,
content,
isError: response.error !== undefined,
...(displayContent ? { displayContent } : {}),
...(data ? { data } : {}),
}),
]);
if (response.responseParts) {
toolResponseParts.push(...response.responseParts);
if (this._abortController.signal.aborted) {
this._finishStream('aborted');
return;
}
}
try {
const currentModel =
this._client.getCurrentSequenceModel() ?? this._config.getModel();
this._client
.getChat()
.recordCompletedToolCalls(currentModel, completedToolCalls);
await recordToolCallInteractions(this._config, completedToolCalls);
} catch (error) {
debugLogger.error(
`Error recording completed tool call information: ${error}`,
const toolResponseParts: Part[] = [];
for (const tc of completedToolCalls) {
const response = tc.response;
const request = tc.request;
const content: ContentPart[] = response.error
? [{ type: 'text', text: response.error.message }]
: geminiPartsToContentParts(response.responseParts);
const displayContent = toolResultDisplayToContentParts(
response.resultDisplay,
);
const data = buildToolResponseData(response);
this._emit([
this._makeToolResponseEvent({
requestId: request.callId,
name: request.name,
content,
isError: response.error !== undefined,
...(displayContent ? { displayContent } : {}),
...(data ? { data } : {}),
}),
]);
if (response.responseParts) {
toolResponseParts.push(...response.responseParts);
}
}
try {
const currentModel =
this._client.getCurrentSequenceModel() ?? this._config.getModel();
this._client
.getChat()
.recordCompletedToolCalls(currentModel, completedToolCalls);
await recordToolCallInteractions(this._config, completedToolCalls);
} catch (error) {
debugLogger.error(
`Error recording completed tool call information: ${error}`,
);
}
const stopTool = completedToolCalls.find(
(tc) =>
tc.response.errorType === ToolErrorType.STOP_EXECUTION &&
tc.response.error !== undefined,
);
}
if (stopTool) {
this._finishStream('completed');
return;
}
const stopTool = completedToolCalls.find(
(tc) =>
tc.response.errorType === ToolErrorType.STOP_EXECUTION &&
tc.response.error !== undefined,
);
if (stopTool) {
this._finishStream('completed');
return;
}
const fatalTool = completedToolCalls.find((tc) =>
isFatalToolError(tc.response.errorType),
);
if (fatalTool) {
this._finishStream('failed');
return;
}
const fatalTool = completedToolCalls.find((tc) =>
isFatalToolError(tc.response.errorType),
);
if (fatalTool) {
this._finishStream('failed');
return;
currentParts = toolResponseParts;
}
currentParts = toolResponseParts;
} finally {
this._config.getMessageBus().unsubscribe(MessageBusType.TOOL_CALLS_UPDATE, handleToolCallsUpdate);
}
}
@@ -434,6 +476,20 @@ class LegacyAgentProtocol implements AgentProtocol {
return event;
}
/**
 * Builds a 'tool_update' agent event.
 *
 * @param payload Event fields other than 'id', 'timestamp', 'streamId' and
 *   'type'; those come from `_nextEventFields()` and the fixed 'tool_update'
 *   type. Payload fields are spread last.
 * @returns The assembled 'tool_update' event.
 */
private _makeToolUpdateEvent(
  payload: Omit<
    AgentEvent<'tool_update'>,
    'id' | 'timestamp' | 'streamId' | 'type'
  >,
): AgentEvent<'tool_update'> {
  const baseFields = this._nextEventFields();
  return {
    ...baseFields,
    type: 'tool_update',
    ...payload,
  } satisfies AgentEvent<'tool_update'>;
}
private _makeAgentStartEvent(): AgentEvent<'agent_start'> {
const event = {
...this._nextEventFields(),