mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-17 23:32:43 -07:00
feat: implement tool display parity for agent stream
This commit achieves visual parity for tool execution in the interactive stream when using the experimental 'useAgentProtocol' flag. It removes direct UI dependency on the tool scheduler's internal state. Key changes: - Core 'LegacyAgentSession' now attaches display metadata (displayName, description, etc.) to 'tool_request' AgentEvents. - Core 'LegacyAgentSession' listens to the MessageBus to emit 'tool_update' AgentEvents for live output (e.g., shell commands). - UI 'useAgentStream' now maintains its own 'trackedTools' local state, constructed entirely from incoming 'tool_request', 'tool_update', and 'tool_response' events. - The local 'trackedTools' state is mapped to 'pendingToolGroupItems' using the existing 'mapToDisplay' function for seamless visual parity.
This commit is contained in:
@@ -30,12 +30,18 @@ import type {
|
||||
} from '../types.js';
|
||||
import { StreamingState, MessageType } from '../types.js';
|
||||
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
|
||||
import { getToolGroupBorderAppearance } from '../utils/borderStyles.js';
|
||||
import { type BackgroundShell } from './shellCommandProcessor.js';
|
||||
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { useLogger } from './useLogger.js';
|
||||
import { mapToDisplay as mapTrackedToolCallsToDisplay } from './toolMapping.js';
|
||||
import {
|
||||
useToolScheduler,
|
||||
} from './useToolScheduler.js';
|
||||
import type {
|
||||
TrackedToolCall,
|
||||
} from './useToolScheduler.js';
|
||||
|
||||
import { useSessionStats } from '../contexts/SessionContext.js';
|
||||
import type { LoadedSettings } from '../../config/settings.js';
|
||||
import type { SlashCommandProcessorResult } from '../types.js';
|
||||
@@ -86,6 +92,13 @@ export const useAgentStream = (
|
||||
const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] =
|
||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||
|
||||
const [trackedTools, , setTrackedTools] =
|
||||
useStateAndRef<TrackedToolCall[]>([]);
|
||||
const [pushedToolCallIds, pushedToolCallIdsRef, setPushedToolCallIds] =
|
||||
useStateAndRef<Set<string>>(new Set());
|
||||
const [_isFirstToolInGroup, isFirstToolInGroupRef, setIsFirstToolInGroup] =
|
||||
useStateAndRef<boolean>(true);
|
||||
|
||||
const [
|
||||
toolCalls,
|
||||
_schedule,
|
||||
@@ -189,10 +202,58 @@ export const useAgentStream = (
|
||||
}
|
||||
break;
|
||||
case 'tool_request':
|
||||
// UI state is handled automatically by useToolScheduler via MessageBus
|
||||
setTrackedTools((prev) => [
|
||||
...prev,
|
||||
{
|
||||
request: {
|
||||
callId: event.requestId,
|
||||
name: event.name,
|
||||
args: event.args,
|
||||
isClientInitiated: false,
|
||||
originalRequestName: event.name,
|
||||
},
|
||||
status: 'executing',
|
||||
tool: {
|
||||
displayName: (event._meta?.['displayName'] as string) ?? event.name,
|
||||
isOutputMarkdown: (event._meta?.['isOutputMarkdown'] as boolean) ?? false,
|
||||
},
|
||||
invocation: {
|
||||
getDescription: () => (event._meta?.['description'] as string) ?? '',
|
||||
},
|
||||
} as unknown as TrackedToolCall,
|
||||
]);
|
||||
break;
|
||||
case 'tool_update':
|
||||
setTrackedTools((prev) =>
|
||||
prev.map((tc) =>
|
||||
tc.request.callId === event.requestId
|
||||
? ({
|
||||
...tc,
|
||||
liveOutput: event.displayContent?.[0]?.type === 'text' ? event.displayContent[0].text : undefined,
|
||||
progressMessage: event.data?.['progressMessage'] as string | undefined,
|
||||
progress: event.data?.['progress'] as number | undefined,
|
||||
progressTotal: event.data?.['progressTotal'] as number | undefined,
|
||||
pid: event.data?.['pid'] as number | undefined,
|
||||
} as unknown as TrackedToolCall)
|
||||
: tc,
|
||||
),
|
||||
);
|
||||
break;
|
||||
case 'tool_response':
|
||||
// UI state is handled automatically by useToolScheduler via MessageBus
|
||||
setTrackedTools((prev) =>
|
||||
prev.map((tc) =>
|
||||
tc.request.callId === event.requestId
|
||||
? ({
|
||||
...tc,
|
||||
status: event.isError ? 'error' : 'success',
|
||||
response: {
|
||||
resultDisplay: event.displayContent?.[0]?.type === 'text' ? event.displayContent[0].text : undefined,
|
||||
},
|
||||
responseSubmittedToGemini: true,
|
||||
} as unknown as TrackedToolCall)
|
||||
: tc,
|
||||
),
|
||||
);
|
||||
break;
|
||||
case 'error':
|
||||
addItem(
|
||||
@@ -263,9 +324,152 @@ export const useAgentStream = (
|
||||
[addItem, logger, startNewPrompt],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (trackedTools.length > 0) {
|
||||
const isNewBatch = !trackedTools.some((tc) =>
|
||||
pushedToolCallIdsRef.current.has(tc.request.callId),
|
||||
);
|
||||
if (isNewBatch) {
|
||||
setPushedToolCallIds(new Set());
|
||||
setIsFirstToolInGroup(true);
|
||||
}
|
||||
} else if (streamingState === StreamingState.Idle) {
|
||||
setPushedToolCallIds(new Set());
|
||||
setIsFirstToolInGroup(true);
|
||||
}
|
||||
}, [
|
||||
trackedTools,
|
||||
pushedToolCallIdsRef,
|
||||
setPushedToolCallIds,
|
||||
setIsFirstToolInGroup,
|
||||
streamingState,
|
||||
]);
|
||||
|
||||
// Push completed tools to history
|
||||
useEffect(() => {
|
||||
const toolsToPush: TrackedToolCall[] = [];
|
||||
for (let i = 0; i < trackedTools.length; i++) {
|
||||
const tc = trackedTools[i];
|
||||
if (pushedToolCallIdsRef.current.has(tc.request.callId)) continue;
|
||||
|
||||
if (
|
||||
tc.status === 'success' ||
|
||||
tc.status === 'error' ||
|
||||
tc.status === 'cancelled'
|
||||
) {
|
||||
toolsToPush.push(tc);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (toolsToPush.length > 0) {
|
||||
const newPushed = new Set(pushedToolCallIdsRef.current);
|
||||
for (const tc of toolsToPush) {
|
||||
newPushed.add(tc.request.callId);
|
||||
}
|
||||
|
||||
const isLastInBatch =
|
||||
toolsToPush[toolsToPush.length - 1] === trackedTools[trackedTools.length - 1];
|
||||
|
||||
const historyItem = mapTrackedToolCallsToDisplay(toolsToPush, {
|
||||
borderTop: isFirstToolInGroupRef.current,
|
||||
borderBottom: isLastInBatch,
|
||||
...getToolGroupBorderAppearance(
|
||||
{ type: 'tool_group', tools: trackedTools as any[] },
|
||||
activePtyId,
|
||||
!!_isShellFocused,
|
||||
[],
|
||||
backgroundShells,
|
||||
),
|
||||
});
|
||||
|
||||
addItem(historyItem);
|
||||
setPushedToolCallIds(newPushed);
|
||||
setIsFirstToolInGroup(false);
|
||||
}
|
||||
}, [
|
||||
trackedTools,
|
||||
pushedToolCallIdsRef,
|
||||
isFirstToolInGroupRef,
|
||||
setPushedToolCallIds,
|
||||
setIsFirstToolInGroup,
|
||||
addItem,
|
||||
activePtyId,
|
||||
_isShellFocused,
|
||||
backgroundShells,
|
||||
]);
|
||||
|
||||
const pendingToolGroupItems = useMemo((): HistoryItemWithoutId[] => {
|
||||
const remainingTools = trackedTools.filter(
|
||||
(tc) => !pushedToolCallIds.has(tc.request.callId),
|
||||
);
|
||||
|
||||
const items: HistoryItemWithoutId[] = [];
|
||||
|
||||
const appearance = getToolGroupBorderAppearance(
|
||||
{ type: 'tool_group', tools: trackedTools as any[] },
|
||||
activePtyId,
|
||||
!!_isShellFocused,
|
||||
[],
|
||||
backgroundShells,
|
||||
);
|
||||
|
||||
if (remainingTools.length > 0) {
|
||||
items.push(
|
||||
mapTrackedToolCallsToDisplay(remainingTools, {
|
||||
borderTop: pushedToolCallIds.size === 0,
|
||||
borderBottom: false,
|
||||
...appearance,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
const allTerminal =
|
||||
trackedTools.length > 0 &&
|
||||
trackedTools.every(
|
||||
(tc) =>
|
||||
tc.status === 'success' ||
|
||||
tc.status === 'error' ||
|
||||
tc.status === 'cancelled',
|
||||
);
|
||||
|
||||
const allPushed =
|
||||
trackedTools.length > 0 &&
|
||||
trackedTools.every((tc) => pushedToolCallIds.has(tc.request.callId));
|
||||
|
||||
const anyVisibleInHistory = pushedToolCallIds.size > 0;
|
||||
const anyVisibleInPending = remainingTools.length > 0;
|
||||
|
||||
if (
|
||||
trackedTools.length > 0 &&
|
||||
!(allTerminal && allPushed) &&
|
||||
(anyVisibleInHistory || anyVisibleInPending)
|
||||
) {
|
||||
items.push({
|
||||
type: 'tool_group' as const,
|
||||
tools: [],
|
||||
borderTop: false,
|
||||
borderBottom: true,
|
||||
...appearance,
|
||||
});
|
||||
}
|
||||
|
||||
return items;
|
||||
}, [
|
||||
trackedTools,
|
||||
pushedToolCallIds,
|
||||
isFirstToolInGroupRef,
|
||||
activePtyId,
|
||||
_isShellFocused,
|
||||
backgroundShells,
|
||||
]);
|
||||
|
||||
const pendingHistoryItems = useMemo(() => {
|
||||
return pendingHistoryItem ? [pendingHistoryItem] : [];
|
||||
}, [pendingHistoryItem]);
|
||||
return [pendingHistoryItem, ...pendingToolGroupItems].filter(
|
||||
(i): i is HistoryItemWithoutId => i !== undefined && i !== null,
|
||||
);
|
||||
}, [pendingHistoryItem, pendingToolGroupItems]);
|
||||
|
||||
return {
|
||||
streamingState,
|
||||
|
||||
@@ -18,6 +18,8 @@ import type { Scheduler } from '../scheduler/scheduler.js';
|
||||
import { recordToolCallInteractions } from '../code_assist/telemetry.js';
|
||||
import { ToolErrorType, isFatalToolError } from '../tools/tool-error.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { MessageBusType } from '../confirmation-bus/types.js';
|
||||
import type { ToolCallsUpdateMessage } from '../confirmation-bus/types.js';
|
||||
import {
|
||||
buildToolResponseData,
|
||||
contentPartsToGeminiParts,
|
||||
@@ -163,142 +165,182 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
let turnCount = 0;
|
||||
const maxTurns = this._config.getMaxSessionTurns();
|
||||
|
||||
while (true) {
|
||||
turnCount++;
|
||||
if (maxTurns >= 0 && turnCount > maxTurns) {
|
||||
this._finishStream('max_turns', {
|
||||
code: 'MAX_TURNS_EXCEEDED',
|
||||
maxTurns,
|
||||
turnCount: turnCount - 1,
|
||||
});
|
||||
return;
|
||||
const handleToolCallsUpdate = (event: ToolCallsUpdateMessage) => {
|
||||
const toolUpdates: AgentEvent[] = [];
|
||||
for (const tc of event.toolCalls) {
|
||||
if (tc.status === 'executing') {
|
||||
toolUpdates.push(
|
||||
this._makeToolUpdateEvent({
|
||||
requestId: tc.request.callId,
|
||||
displayContent: toolResultDisplayToContentParts(tc.liveOutput),
|
||||
data: {
|
||||
progressMessage: tc.progressMessage,
|
||||
progress: tc.progress,
|
||||
progressTotal: tc.progressTotal,
|
||||
pid: tc.pid,
|
||||
},
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
this._emit(toolUpdates);
|
||||
};
|
||||
|
||||
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||
const responseStream = this._client.sendMessageStream(
|
||||
currentParts,
|
||||
this._abortController.signal,
|
||||
this._promptId,
|
||||
undefined,
|
||||
false,
|
||||
currentDisplayContent,
|
||||
);
|
||||
currentDisplayContent = undefined;
|
||||
this._config.getMessageBus().subscribe(MessageBusType.TOOL_CALLS_UPDATE, handleToolCallsUpdate);
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
turnCount++;
|
||||
if (maxTurns >= 0 && turnCount > maxTurns) {
|
||||
this._finishStream('max_turns', {
|
||||
code: 'MAX_TURNS_EXCEEDED',
|
||||
maxTurns,
|
||||
turnCount: turnCount - 1,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||
const responseStream = this._client.sendMessageStream(
|
||||
currentParts,
|
||||
this._abortController.signal,
|
||||
this._promptId,
|
||||
undefined,
|
||||
false,
|
||||
currentDisplayContent,
|
||||
);
|
||||
currentDisplayContent = undefined;
|
||||
|
||||
for await (const event of responseStream) {
|
||||
if (this._abortController.signal.aborted) {
|
||||
this._finishStream('aborted');
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.type === GeminiEventType.ToolCallRequest) {
|
||||
toolCallRequests.push(event.value);
|
||||
}
|
||||
|
||||
const translatedEvents = translateEvent(event, this._translationState);
|
||||
|
||||
for (const ev of translatedEvents) {
|
||||
if (ev.type === 'tool_request') {
|
||||
const tool = this._config.getToolRegistry().getTool(ev.name);
|
||||
ev._meta = {
|
||||
displayName: tool?.displayName ?? ev.name,
|
||||
description: tool?.description ?? '',
|
||||
isOutputMarkdown: tool?.isOutputMarkdown ?? false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
this._emit(translatedEvents);
|
||||
|
||||
switch (event.type) {
|
||||
case GeminiEventType.Error:
|
||||
case GeminiEventType.InvalidStream:
|
||||
case GeminiEventType.ContextWindowWillOverflow:
|
||||
this._finishStream('failed');
|
||||
return;
|
||||
case GeminiEventType.Finished:
|
||||
if (toolCallRequests.length === 0) {
|
||||
this._finishStream(mapFinishReason(event.value.reason));
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case GeminiEventType.AgentExecutionStopped:
|
||||
case GeminiEventType.UserCancelled:
|
||||
case GeminiEventType.MaxSessionTurns:
|
||||
this._clearActiveStream();
|
||||
return;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for await (const event of responseStream) {
|
||||
if (this._abortController.signal.aborted) {
|
||||
this._finishStream('aborted');
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.type === GeminiEventType.ToolCallRequest) {
|
||||
toolCallRequests.push(event.value);
|
||||
if (toolCallRequests.length === 0) {
|
||||
this._finishStream('completed');
|
||||
return;
|
||||
}
|
||||
|
||||
this._emit(translateEvent(event, this._translationState));
|
||||
|
||||
switch (event.type) {
|
||||
case GeminiEventType.Error:
|
||||
case GeminiEventType.InvalidStream:
|
||||
case GeminiEventType.ContextWindowWillOverflow:
|
||||
this._finishStream('failed');
|
||||
return;
|
||||
case GeminiEventType.Finished:
|
||||
if (toolCallRequests.length === 0) {
|
||||
this._finishStream(mapFinishReason(event.value.reason));
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case GeminiEventType.AgentExecutionStopped:
|
||||
case GeminiEventType.UserCancelled:
|
||||
case GeminiEventType.MaxSessionTurns:
|
||||
this._clearActiveStream();
|
||||
return;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (this._abortController.signal.aborted) {
|
||||
this._finishStream('aborted');
|
||||
return;
|
||||
}
|
||||
|
||||
if (toolCallRequests.length === 0) {
|
||||
this._finishStream('completed');
|
||||
return;
|
||||
}
|
||||
|
||||
const completedToolCalls = await this._scheduler.schedule(
|
||||
toolCallRequests,
|
||||
this._abortController.signal,
|
||||
);
|
||||
|
||||
if (this._abortController.signal.aborted) {
|
||||
this._finishStream('aborted');
|
||||
return;
|
||||
}
|
||||
|
||||
const toolResponseParts: Part[] = [];
|
||||
for (const tc of completedToolCalls) {
|
||||
const response = tc.response;
|
||||
const request = tc.request;
|
||||
const content: ContentPart[] = response.error
|
||||
? [{ type: 'text', text: response.error.message }]
|
||||
: geminiPartsToContentParts(response.responseParts);
|
||||
const displayContent = toolResultDisplayToContentParts(
|
||||
response.resultDisplay,
|
||||
const completedToolCalls = await this._scheduler.schedule(
|
||||
toolCallRequests,
|
||||
this._abortController.signal,
|
||||
);
|
||||
const data = buildToolResponseData(response);
|
||||
|
||||
this._emit([
|
||||
this._makeToolResponseEvent({
|
||||
requestId: request.callId,
|
||||
name: request.name,
|
||||
content,
|
||||
isError: response.error !== undefined,
|
||||
...(displayContent ? { displayContent } : {}),
|
||||
...(data ? { data } : {}),
|
||||
}),
|
||||
]);
|
||||
|
||||
if (response.responseParts) {
|
||||
toolResponseParts.push(...response.responseParts);
|
||||
if (this._abortController.signal.aborted) {
|
||||
this._finishStream('aborted');
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const currentModel =
|
||||
this._client.getCurrentSequenceModel() ?? this._config.getModel();
|
||||
this._client
|
||||
.getChat()
|
||||
.recordCompletedToolCalls(currentModel, completedToolCalls);
|
||||
await recordToolCallInteractions(this._config, completedToolCalls);
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Error recording completed tool call information: ${error}`,
|
||||
const toolResponseParts: Part[] = [];
|
||||
for (const tc of completedToolCalls) {
|
||||
const response = tc.response;
|
||||
const request = tc.request;
|
||||
const content: ContentPart[] = response.error
|
||||
? [{ type: 'text', text: response.error.message }]
|
||||
: geminiPartsToContentParts(response.responseParts);
|
||||
const displayContent = toolResultDisplayToContentParts(
|
||||
response.resultDisplay,
|
||||
);
|
||||
const data = buildToolResponseData(response);
|
||||
|
||||
this._emit([
|
||||
this._makeToolResponseEvent({
|
||||
requestId: request.callId,
|
||||
name: request.name,
|
||||
content,
|
||||
isError: response.error !== undefined,
|
||||
...(displayContent ? { displayContent } : {}),
|
||||
...(data ? { data } : {}),
|
||||
}),
|
||||
]);
|
||||
|
||||
if (response.responseParts) {
|
||||
toolResponseParts.push(...response.responseParts);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const currentModel =
|
||||
this._client.getCurrentSequenceModel() ?? this._config.getModel();
|
||||
this._client
|
||||
.getChat()
|
||||
.recordCompletedToolCalls(currentModel, completedToolCalls);
|
||||
await recordToolCallInteractions(this._config, completedToolCalls);
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Error recording completed tool call information: ${error}`,
|
||||
);
|
||||
}
|
||||
|
||||
const stopTool = completedToolCalls.find(
|
||||
(tc) =>
|
||||
tc.response.errorType === ToolErrorType.STOP_EXECUTION &&
|
||||
tc.response.error !== undefined,
|
||||
);
|
||||
}
|
||||
if (stopTool) {
|
||||
this._finishStream('completed');
|
||||
return;
|
||||
}
|
||||
|
||||
const stopTool = completedToolCalls.find(
|
||||
(tc) =>
|
||||
tc.response.errorType === ToolErrorType.STOP_EXECUTION &&
|
||||
tc.response.error !== undefined,
|
||||
);
|
||||
if (stopTool) {
|
||||
this._finishStream('completed');
|
||||
return;
|
||||
}
|
||||
const fatalTool = completedToolCalls.find((tc) =>
|
||||
isFatalToolError(tc.response.errorType),
|
||||
);
|
||||
if (fatalTool) {
|
||||
this._finishStream('failed');
|
||||
return;
|
||||
}
|
||||
|
||||
const fatalTool = completedToolCalls.find((tc) =>
|
||||
isFatalToolError(tc.response.errorType),
|
||||
);
|
||||
if (fatalTool) {
|
||||
this._finishStream('failed');
|
||||
return;
|
||||
currentParts = toolResponseParts;
|
||||
}
|
||||
|
||||
currentParts = toolResponseParts;
|
||||
} finally {
|
||||
this._config.getMessageBus().unsubscribe(MessageBusType.TOOL_CALLS_UPDATE, handleToolCallsUpdate);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,6 +476,20 @@ class LegacyAgentProtocol implements AgentProtocol {
|
||||
return event;
|
||||
}
|
||||
|
||||
private _makeToolUpdateEvent(
|
||||
payload: Omit<
|
||||
AgentEvent<'tool_update'>,
|
||||
'id' | 'timestamp' | 'streamId' | 'type'
|
||||
>,
|
||||
): AgentEvent<'tool_update'> {
|
||||
const event = {
|
||||
...this._nextEventFields(),
|
||||
type: 'tool_update',
|
||||
...payload,
|
||||
} satisfies AgentEvent<'tool_update'>;
|
||||
return event;
|
||||
}
|
||||
|
||||
private _makeAgentStartEvent(): AgentEvent<'agent_start'> {
|
||||
const event = {
|
||||
...this._nextEventFields(),
|
||||
|
||||
Reference in New Issue
Block a user