mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-11 05:41:08 -07:00
feat(core, cli): Implement sequential approval. (#11593)
This commit is contained in:
@@ -37,7 +37,7 @@ import {
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { Part, PartListUnion } from '@google/genai';
|
||||
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import type { HistoryItem, SlashCommandProcessorResult } from '../types.js';
|
||||
import type { SlashCommandProcessorResult } from '../types.js';
|
||||
import { MessageType, StreamingState } from '../types.js';
|
||||
import type { LoadedSettings } from '../../config/settings.js';
|
||||
|
||||
@@ -231,8 +231,9 @@ describe('useGeminiStream', () => {
|
||||
mockUseReactToolScheduler.mockReturnValue([
|
||||
[], // Default to empty array for toolCalls
|
||||
mockScheduleToolCalls,
|
||||
mockCancelAllToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
vi.fn(), // setToolCallsForDisplay
|
||||
mockCancelAllToolCalls,
|
||||
]);
|
||||
|
||||
// Reset mocks for GeminiClient instance methods (startChat and sendMessageStream)
|
||||
@@ -259,38 +260,71 @@ describe('useGeminiStream', () => {
|
||||
initialToolCalls: TrackedToolCall[] = [],
|
||||
geminiClient?: any,
|
||||
) => {
|
||||
let currentToolCalls = initialToolCalls;
|
||||
const setToolCalls = (newToolCalls: TrackedToolCall[]) => {
|
||||
currentToolCalls = newToolCalls;
|
||||
};
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation(() => [
|
||||
currentToolCalls,
|
||||
mockScheduleToolCalls,
|
||||
mockCancelAllToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
]);
|
||||
|
||||
const client = geminiClient || mockConfig.getGeminiClient();
|
||||
|
||||
const initialProps = {
|
||||
client,
|
||||
history: [],
|
||||
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
config: mockConfig,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
handleSlashCommand: mockHandleSlashCommand as unknown as (
|
||||
cmd: PartListUnion,
|
||||
) => Promise<SlashCommandProcessorResult | false>,
|
||||
shellModeActive: false,
|
||||
loadedSettings: mockLoadedSettings,
|
||||
toolCalls: initialToolCalls,
|
||||
};
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
(props: {
|
||||
client: any;
|
||||
history: HistoryItem[];
|
||||
addItem: UseHistoryManagerReturn['addItem'];
|
||||
config: Config;
|
||||
onDebugMessage: (message: string) => void;
|
||||
handleSlashCommand: (
|
||||
cmd: PartListUnion,
|
||||
) => Promise<SlashCommandProcessorResult | false>;
|
||||
shellModeActive: boolean;
|
||||
loadedSettings: LoadedSettings;
|
||||
toolCalls?: TrackedToolCall[]; // Allow passing updated toolCalls
|
||||
}) => {
|
||||
// Update the mock's return value if new toolCalls are passed in props
|
||||
if (props.toolCalls) {
|
||||
setToolCalls(props.toolCalls);
|
||||
}
|
||||
(props: typeof initialProps) => {
|
||||
// This mock needs to be stateful. When setToolCallsForDisplay is called,
|
||||
// it should trigger a rerender with the new state.
|
||||
const mockSetToolCallsForDisplay = vi.fn((updater) => {
|
||||
const newToolCalls =
|
||||
typeof updater === 'function' ? updater(props.toolCalls) : updater;
|
||||
rerender({ ...props, toolCalls: newToolCalls });
|
||||
});
|
||||
|
||||
// Create a stateful mock for cancellation that updates the toolCalls state.
|
||||
const statefulCancelAllToolCalls = vi.fn((...args) => {
|
||||
// Call the original spy so `toHaveBeenCalled` checks still work.
|
||||
mockCancelAllToolCalls(...args);
|
||||
|
||||
const newToolCalls = props.toolCalls.map((tc) => {
|
||||
// Only cancel tools that are in a cancellable state.
|
||||
if (
|
||||
tc.status === 'awaiting_approval' ||
|
||||
tc.status === 'executing' ||
|
||||
tc.status === 'scheduled' ||
|
||||
tc.status === 'validating'
|
||||
) {
|
||||
// A real cancelled tool call has a response object.
|
||||
// We need to simulate this to avoid type errors downstream.
|
||||
return {
|
||||
...tc,
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
callId: tc.request.callId,
|
||||
responseParts: [],
|
||||
resultDisplay: 'Request cancelled.',
|
||||
},
|
||||
responseSubmittedToGemini: true, // Mark as "processed"
|
||||
} as any as TrackedCancelledToolCall;
|
||||
}
|
||||
return tc;
|
||||
});
|
||||
rerender({ ...props, toolCalls: newToolCalls });
|
||||
});
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation(() => [
|
||||
props.toolCalls,
|
||||
mockScheduleToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
mockSetToolCallsForDisplay,
|
||||
statefulCancelAllToolCalls, // Use the stateful mock
|
||||
]);
|
||||
|
||||
return useGeminiStream(
|
||||
props.client,
|
||||
props.history,
|
||||
@@ -313,19 +347,7 @@ describe('useGeminiStream', () => {
|
||||
);
|
||||
},
|
||||
{
|
||||
initialProps: {
|
||||
client,
|
||||
history: [],
|
||||
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
config: mockConfig,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
handleSlashCommand: mockHandleSlashCommand as unknown as (
|
||||
cmd: PartListUnion,
|
||||
) => Promise<SlashCommandProcessorResult | false>,
|
||||
shellModeActive: false,
|
||||
loadedSettings: mockLoadedSettings,
|
||||
toolCalls: initialToolCalls,
|
||||
},
|
||||
initialProps,
|
||||
},
|
||||
);
|
||||
return {
|
||||
@@ -452,7 +474,7 @@ describe('useGeminiStream', () => {
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation((onComplete) => {
|
||||
capturedOnComplete = onComplete;
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
|
||||
});
|
||||
|
||||
renderHook(() =>
|
||||
@@ -535,7 +557,7 @@ describe('useGeminiStream', () => {
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation((onComplete) => {
|
||||
capturedOnComplete = onComplete;
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
|
||||
});
|
||||
|
||||
renderHook(() =>
|
||||
@@ -647,7 +669,7 @@ describe('useGeminiStream', () => {
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation((onComplete) => {
|
||||
capturedOnComplete = onComplete;
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
|
||||
});
|
||||
|
||||
renderHook(() =>
|
||||
@@ -760,6 +782,7 @@ describe('useGeminiStream', () => {
|
||||
currentToolCalls,
|
||||
mockScheduleToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
vi.fn(), // setToolCallsForDisplay
|
||||
];
|
||||
});
|
||||
|
||||
@@ -797,6 +820,7 @@ describe('useGeminiStream', () => {
|
||||
completedToolCalls,
|
||||
mockScheduleToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
vi.fn(), // setToolCallsForDisplay
|
||||
];
|
||||
});
|
||||
|
||||
@@ -1031,7 +1055,7 @@ describe('useGeminiStream', () => {
|
||||
expect(result.current.streamingState).toBe(StreamingState.Idle);
|
||||
});
|
||||
|
||||
it('should not cancel if a tool call is in progress (not just responding)', async () => {
|
||||
it('should cancel if a tool call is in progress', async () => {
|
||||
const toolCalls: TrackedToolCall[] = [
|
||||
{
|
||||
request: { callId: 'call1', name: 'tool1', args: {} },
|
||||
@@ -1052,7 +1076,6 @@ describe('useGeminiStream', () => {
|
||||
} as TrackedExecutingToolCall,
|
||||
];
|
||||
|
||||
const abortSpy = vi.spyOn(AbortController.prototype, 'abort');
|
||||
const { result } = renderTestHook(toolCalls);
|
||||
|
||||
// State is `Responding` because a tool is running
|
||||
@@ -1061,8 +1084,71 @@ describe('useGeminiStream', () => {
|
||||
// Try to cancel
|
||||
simulateEscapeKeyPress();
|
||||
|
||||
// Nothing should happen because the state is not `Responding`
|
||||
expect(abortSpy).not.toHaveBeenCalled();
|
||||
// The cancel function should be called
|
||||
expect(mockCancelAllToolCalls).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should cancel a request when a tool is awaiting confirmation', async () => {
|
||||
const mockOnConfirm = vi.fn().mockResolvedValue(undefined);
|
||||
const toolCalls: TrackedToolCall[] = [
|
||||
{
|
||||
request: {
|
||||
callId: 'confirm-call',
|
||||
name: 'some_tool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
status: 'awaiting_approval',
|
||||
responseSubmittedToGemini: false,
|
||||
tool: {
|
||||
name: 'some_tool',
|
||||
description: 'a tool',
|
||||
build: vi.fn().mockImplementation((_) => ({
|
||||
getDescription: () => `Mock description`,
|
||||
})),
|
||||
} as any,
|
||||
invocation: {
|
||||
getDescription: () => `Mock description`,
|
||||
} as unknown as AnyToolInvocation,
|
||||
confirmationDetails: {
|
||||
type: 'edit',
|
||||
title: 'Confirm Edit',
|
||||
onConfirm: mockOnConfirm,
|
||||
fileName: 'file.txt',
|
||||
filePath: '/test/file.txt',
|
||||
fileDiff: 'fake diff',
|
||||
originalContent: 'old',
|
||||
newContent: 'new',
|
||||
},
|
||||
} as TrackedWaitingToolCall,
|
||||
];
|
||||
|
||||
const { result } = renderTestHook(toolCalls);
|
||||
|
||||
// State is `WaitingForConfirmation` because a tool is awaiting approval
|
||||
expect(result.current.streamingState).toBe(
|
||||
StreamingState.WaitingForConfirmation,
|
||||
);
|
||||
|
||||
// Try to cancel
|
||||
simulateEscapeKeyPress();
|
||||
|
||||
// The imperative cancel function should be called on the scheduler
|
||||
expect(mockCancelAllToolCalls).toHaveBeenCalled();
|
||||
|
||||
// A cancellation message should be added to history
|
||||
await waitFor(() => {
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
text: 'Request cancelled.',
|
||||
}),
|
||||
expect.any(Number),
|
||||
);
|
||||
});
|
||||
|
||||
// The final state should be idle
|
||||
expect(result.current.streamingState).toBe(StreamingState.Idle);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1282,7 +1368,7 @@ describe('useGeminiStream', () => {
|
||||
|
||||
mockUseReactToolScheduler.mockImplementation((onComplete) => {
|
||||
capturedOnComplete = onComplete;
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted];
|
||||
return [[], mockScheduleToolCalls, mockMarkToolsAsSubmitted, vi.fn()];
|
||||
});
|
||||
|
||||
renderHook(() =>
|
||||
|
||||
@@ -111,6 +111,7 @@ export const useGeminiStream = (
|
||||
const [initError, setInitError] = useState<string | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const turnCancelledRef = useRef(false);
|
||||
const activeQueryIdRef = useRef<string | null>(null);
|
||||
const [isResponding, setIsResponding] = useState<boolean>(false);
|
||||
const [thought, setThought] = useState<ThoughtSummary | null>(null);
|
||||
const [pendingHistoryItem, pendingHistoryItemRef, setPendingHistoryItem] =
|
||||
@@ -126,47 +127,55 @@ export const useGeminiStream = (
|
||||
return new GitService(config.getProjectRoot(), storage);
|
||||
}, [config, storage]);
|
||||
|
||||
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
|
||||
useReactToolScheduler(
|
||||
async (completedToolCallsFromScheduler) => {
|
||||
// This onComplete is called when ALL scheduled tools for a given batch are done.
|
||||
if (completedToolCallsFromScheduler.length > 0) {
|
||||
// Add the final state of these tools to the history for display.
|
||||
addItem(
|
||||
mapTrackedToolCallsToDisplay(
|
||||
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||
),
|
||||
Date.now(),
|
||||
);
|
||||
|
||||
// Record tool calls with full metadata before sending responses.
|
||||
try {
|
||||
const currentModel =
|
||||
config.getGeminiClient().getCurrentSequenceModel() ??
|
||||
config.getModel();
|
||||
config
|
||||
.getGeminiClient()
|
||||
.getChat()
|
||||
.recordCompletedToolCalls(
|
||||
currentModel,
|
||||
completedToolCallsFromScheduler,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error recording completed tool call information: ${error}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Handle tool response submission immediately when tools complete
|
||||
await handleCompletedTools(
|
||||
const [
|
||||
toolCalls,
|
||||
scheduleToolCalls,
|
||||
markToolsAsSubmitted,
|
||||
setToolCallsForDisplay,
|
||||
cancelAllToolCalls,
|
||||
] = useReactToolScheduler(
|
||||
async (completedToolCallsFromScheduler) => {
|
||||
// This onComplete is called when ALL scheduled tools for a given batch are done.
|
||||
if (completedToolCallsFromScheduler.length > 0) {
|
||||
// Add the final state of these tools to the history for display.
|
||||
addItem(
|
||||
mapTrackedToolCallsToDisplay(
|
||||
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||
),
|
||||
Date.now(),
|
||||
);
|
||||
|
||||
// Clear the live-updating display now that the final state is in history.
|
||||
setToolCallsForDisplay([]);
|
||||
|
||||
// Record tool calls with full metadata before sending responses.
|
||||
try {
|
||||
const currentModel =
|
||||
config.getGeminiClient().getCurrentSequenceModel() ??
|
||||
config.getModel();
|
||||
config
|
||||
.getGeminiClient()
|
||||
.getChat()
|
||||
.recordCompletedToolCalls(
|
||||
currentModel,
|
||||
completedToolCallsFromScheduler,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error recording completed tool call information: ${error}`,
|
||||
);
|
||||
}
|
||||
},
|
||||
config,
|
||||
getPreferredEditor,
|
||||
onEditorClose,
|
||||
);
|
||||
|
||||
// Handle tool response submission immediately when tools complete
|
||||
await handleCompletedTools(
|
||||
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||
);
|
||||
}
|
||||
},
|
||||
config,
|
||||
getPreferredEditor,
|
||||
onEditorClose,
|
||||
);
|
||||
|
||||
const pendingToolCallGroupDisplay = useMemo(
|
||||
() =>
|
||||
@@ -265,27 +274,54 @@ export const useGeminiStream = (
|
||||
}, [streamingState, config, history]);
|
||||
|
||||
const cancelOngoingRequest = useCallback(() => {
|
||||
if (streamingState !== StreamingState.Responding) {
|
||||
if (
|
||||
streamingState !== StreamingState.Responding &&
|
||||
streamingState !== StreamingState.WaitingForConfirmation
|
||||
) {
|
||||
return;
|
||||
}
|
||||
if (turnCancelledRef.current) {
|
||||
return;
|
||||
}
|
||||
turnCancelledRef.current = true;
|
||||
abortControllerRef.current?.abort();
|
||||
|
||||
// A full cancellation means no tools have produced a final result yet.
|
||||
// This determines if we show a generic "Request cancelled" message.
|
||||
const isFullCancellation = !toolCalls.some(
|
||||
(tc) => tc.status === 'success' || tc.status === 'error',
|
||||
);
|
||||
|
||||
// Ensure we have an abort controller, creating one if it doesn't exist.
|
||||
if (!abortControllerRef.current) {
|
||||
abortControllerRef.current = new AbortController();
|
||||
}
|
||||
|
||||
// The order is important here.
|
||||
// 1. Fire the signal to interrupt any active async operations.
|
||||
abortControllerRef.current.abort();
|
||||
// 2. Call the imperative cancel to clear the queue of pending tools.
|
||||
cancelAllToolCalls(abortControllerRef.current.signal);
|
||||
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, Date.now());
|
||||
}
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: 'Request cancelled.',
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
setPendingHistoryItem(null);
|
||||
|
||||
// If it was a full cancellation, add the info message now.
|
||||
// Otherwise, we let handleCompletedTools figure out the next step,
|
||||
// which might involve sending partial results back to the model.
|
||||
if (isFullCancellation) {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: 'Request cancelled.',
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
setIsResponding(false);
|
||||
}
|
||||
|
||||
onCancelSubmit();
|
||||
setIsResponding(false);
|
||||
setShellInputFocused(false);
|
||||
}, [
|
||||
streamingState,
|
||||
@@ -294,6 +330,8 @@ export const useGeminiStream = (
|
||||
onCancelSubmit,
|
||||
pendingHistoryItemRef,
|
||||
setShellInputFocused,
|
||||
cancelAllToolCalls,
|
||||
toolCalls,
|
||||
]);
|
||||
|
||||
useKeypress(
|
||||
@@ -302,7 +340,11 @@ export const useGeminiStream = (
|
||||
cancelOngoingRequest();
|
||||
}
|
||||
},
|
||||
{ isActive: streamingState === StreamingState.Responding },
|
||||
{
|
||||
isActive:
|
||||
streamingState === StreamingState.Responding ||
|
||||
streamingState === StreamingState.WaitingForConfirmation,
|
||||
},
|
||||
);
|
||||
|
||||
const prepareQueryForGemini = useCallback(
|
||||
@@ -764,6 +806,8 @@ export const useGeminiStream = (
|
||||
options?: { isContinuation: boolean },
|
||||
prompt_id?: string,
|
||||
) => {
|
||||
const queryId = `${Date.now()}-${Math.random()}`;
|
||||
activeQueryIdRef.current = queryId;
|
||||
if (
|
||||
(streamingState === StreamingState.Responding ||
|
||||
streamingState === StreamingState.WaitingForConfirmation) &&
|
||||
@@ -901,7 +945,9 @@ export const useGeminiStream = (
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
setIsResponding(false);
|
||||
if (activeQueryIdRef.current === queryId) {
|
||||
setIsResponding(false);
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
@@ -963,10 +1009,6 @@ export const useGeminiStream = (
|
||||
|
||||
const handleCompletedTools = useCallback(
|
||||
async (completedToolCallsFromScheduler: TrackedToolCall[]) => {
|
||||
if (isResponding) {
|
||||
return;
|
||||
}
|
||||
|
||||
const completedAndReadyToSubmitTools =
|
||||
completedToolCallsFromScheduler.filter(
|
||||
(
|
||||
@@ -1028,6 +1070,19 @@ export const useGeminiStream = (
|
||||
);
|
||||
|
||||
if (allToolsCancelled) {
|
||||
// If the turn was cancelled via the imperative escape key flow,
|
||||
// the cancellation message is added there. We check the ref to avoid duplication.
|
||||
if (!turnCancelledRef.current) {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: 'Request cancelled.',
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
}
|
||||
setIsResponding(false);
|
||||
|
||||
if (geminiClient) {
|
||||
// We need to manually add the function responses to the history
|
||||
// so the model knows the tools were cancelled.
|
||||
@@ -1074,12 +1129,12 @@ export const useGeminiStream = (
|
||||
);
|
||||
},
|
||||
[
|
||||
isResponding,
|
||||
submitQuery,
|
||||
markToolsAsSubmitted,
|
||||
geminiClient,
|
||||
performMemoryRefresh,
|
||||
modelSwitchedFromQuotaError,
|
||||
addItem,
|
||||
],
|
||||
);
|
||||
|
||||
|
||||
@@ -62,12 +62,20 @@ export type TrackedToolCall =
|
||||
| TrackedCompletedToolCall
|
||||
| TrackedCancelledToolCall;
|
||||
|
||||
export type CancelAllFn = (signal: AbortSignal) => void;
|
||||
|
||||
export function useReactToolScheduler(
|
||||
onComplete: (tools: CompletedToolCall[]) => Promise<void>,
|
||||
config: Config,
|
||||
getPreferredEditor: () => EditorType | undefined,
|
||||
onEditorClose: () => void,
|
||||
): [TrackedToolCall[], ScheduleFn, MarkToolsAsSubmittedFn] {
|
||||
): [
|
||||
TrackedToolCall[],
|
||||
ScheduleFn,
|
||||
MarkToolsAsSubmittedFn,
|
||||
React.Dispatch<React.SetStateAction<TrackedToolCall[]>>,
|
||||
CancelAllFn,
|
||||
] {
|
||||
const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
|
||||
TrackedToolCall[]
|
||||
>([]);
|
||||
@@ -112,37 +120,36 @@ export function useReactToolScheduler(
|
||||
);
|
||||
|
||||
const toolCallsUpdateHandler: ToolCallsUpdateHandler = useCallback(
|
||||
(updatedCoreToolCalls: ToolCall[]) => {
|
||||
setToolCallsForDisplay((prevTrackedCalls) =>
|
||||
updatedCoreToolCalls.map((coreTc) => {
|
||||
const existingTrackedCall = prevTrackedCalls.find(
|
||||
(ptc) => ptc.request.callId === coreTc.request.callId,
|
||||
);
|
||||
// Start with the new core state, then layer on the existing UI state
|
||||
// to ensure UI-only properties like pid are preserved.
|
||||
(allCoreToolCalls: ToolCall[]) => {
|
||||
setToolCallsForDisplay((prevTrackedCalls) => {
|
||||
const prevCallsMap = new Map(
|
||||
prevTrackedCalls.map((c) => [c.request.callId, c]),
|
||||
);
|
||||
|
||||
return allCoreToolCalls.map((coreTc): TrackedToolCall => {
|
||||
const existingTrackedCall = prevCallsMap.get(coreTc.request.callId);
|
||||
|
||||
const responseSubmittedToGemini =
|
||||
existingTrackedCall?.responseSubmittedToGemini ?? false;
|
||||
|
||||
if (coreTc.status === 'executing') {
|
||||
// Preserve live output if it exists from a previous render.
|
||||
const liveOutput = (existingTrackedCall as TrackedExecutingToolCall)
|
||||
?.liveOutput;
|
||||
return {
|
||||
...coreTc,
|
||||
responseSubmittedToGemini,
|
||||
liveOutput: (existingTrackedCall as TrackedExecutingToolCall)
|
||||
?.liveOutput,
|
||||
liveOutput,
|
||||
pid: (coreTc as ExecutingToolCall).pid,
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
...coreTc,
|
||||
responseSubmittedToGemini,
|
||||
};
|
||||
}
|
||||
|
||||
// For other statuses, explicitly set liveOutput and pid to undefined
|
||||
// to ensure they are not carried over from a previous executing state.
|
||||
return {
|
||||
...coreTc,
|
||||
responseSubmittedToGemini,
|
||||
liveOutput: undefined,
|
||||
pid: undefined,
|
||||
};
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
},
|
||||
[setToolCallsForDisplay],
|
||||
);
|
||||
@@ -178,9 +185,10 @@ export function useReactToolScheduler(
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
) => {
|
||||
setToolCallsForDisplay([]);
|
||||
void scheduler.schedule(request, signal);
|
||||
},
|
||||
[scheduler],
|
||||
[scheduler, setToolCallsForDisplay],
|
||||
);
|
||||
|
||||
const markToolsAsSubmitted: MarkToolsAsSubmittedFn = useCallback(
|
||||
@@ -196,7 +204,20 @@ export function useReactToolScheduler(
|
||||
[],
|
||||
);
|
||||
|
||||
return [toolCallsForDisplay, schedule, markToolsAsSubmitted];
|
||||
const cancelAllToolCalls = useCallback(
|
||||
(signal: AbortSignal) => {
|
||||
scheduler.cancelAll(signal);
|
||||
},
|
||||
[scheduler],
|
||||
);
|
||||
|
||||
return [
|
||||
toolCallsForDisplay,
|
||||
schedule,
|
||||
markToolsAsSubmitted,
|
||||
setToolCallsForDisplay,
|
||||
cancelAllToolCalls,
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -260,9 +260,15 @@ describe('useReactToolScheduler', () => {
|
||||
args: { param: 'value' },
|
||||
} as any;
|
||||
|
||||
let completedToolCalls: ToolCall[] = [];
|
||||
onComplete.mockImplementation((calls) => {
|
||||
completedToolCalls = calls;
|
||||
});
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
@@ -292,7 +298,110 @@ describe('useReactToolScheduler', () => {
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
expect(result.current[0]).toEqual([]);
|
||||
expect(completedToolCalls).toHaveLength(1);
|
||||
expect(completedToolCalls[0].status).toBe('success');
|
||||
expect(completedToolCalls[0].request).toBe(request);
|
||||
});
|
||||
|
||||
it('should clear previous tool calls when scheduling new ones', async () => {
|
||||
mockToolRegistry.getTool.mockReturnValue(mockTool);
|
||||
(mockTool.execute as Mock).mockResolvedValue({
|
||||
llmContent: 'Tool output',
|
||||
returnDisplay: 'Formatted tool output',
|
||||
} as ToolResult);
|
||||
|
||||
const { result } = renderScheduler();
|
||||
const schedule = result.current[1];
|
||||
const setToolCallsForDisplay = result.current[3];
|
||||
|
||||
// Manually set a tool call in the display.
|
||||
const oldToolCall = {
|
||||
request: { callId: 'oldCall' },
|
||||
status: 'success',
|
||||
} as any;
|
||||
act(() => {
|
||||
setToolCallsForDisplay([oldToolCall]);
|
||||
});
|
||||
expect(result.current[0]).toEqual([oldToolCall]);
|
||||
|
||||
const newRequest: ToolCallRequestInfo = {
|
||||
callId: 'newCall',
|
||||
name: 'mockTool',
|
||||
args: {},
|
||||
} as any;
|
||||
act(() => {
|
||||
schedule(newRequest, new AbortController().signal);
|
||||
});
|
||||
|
||||
// After scheduling, the old call should be gone,
|
||||
// and the new one should be in the display in its initial state.
|
||||
expect(result.current[0].length).toBe(1);
|
||||
expect(result.current[0][0].request.callId).toBe('newCall');
|
||||
expect(result.current[0][0].request.callId).not.toBe('oldCall');
|
||||
|
||||
// Let the new call finish.
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
expect(onComplete).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should cancel all running tool calls', async () => {
|
||||
mockToolRegistry.getTool.mockReturnValue(mockTool);
|
||||
|
||||
let resolveExecute: (value: ToolResult) => void = () => {};
|
||||
const executePromise = new Promise<ToolResult>((resolve) => {
|
||||
resolveExecute = resolve;
|
||||
});
|
||||
(mockTool.execute as Mock).mockReturnValue(executePromise);
|
||||
(mockTool.shouldConfirmExecute as Mock).mockResolvedValue(null);
|
||||
|
||||
const { result } = renderScheduler();
|
||||
const schedule = result.current[1];
|
||||
const cancelAllToolCalls = result.current[4];
|
||||
const request: ToolCallRequestInfo = {
|
||||
callId: 'cancelCall',
|
||||
name: 'mockTool',
|
||||
args: {},
|
||||
} as any;
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
}); // validation
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
}); // scheduling
|
||||
|
||||
// At this point, the tool is 'executing' and waiting on the promise.
|
||||
expect(result.current[0][0].status).toBe('executing');
|
||||
|
||||
const cancelController = new AbortController();
|
||||
act(() => {
|
||||
cancelAllToolCalls(cancelController.signal);
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith([
|
||||
expect.objectContaining({
|
||||
status: 'cancelled',
|
||||
request,
|
||||
}),
|
||||
]);
|
||||
|
||||
// Clean up the pending promise to avoid open handles.
|
||||
resolveExecute({ llmContent: 'output', returnDisplay: 'display' });
|
||||
});
|
||||
|
||||
it('should handle tool not found', async () => {
|
||||
@@ -305,6 +414,11 @@ describe('useReactToolScheduler', () => {
|
||||
args: {},
|
||||
} as any;
|
||||
|
||||
let completedToolCalls: ToolCall[] = [];
|
||||
onComplete.mockImplementation((calls) => {
|
||||
completedToolCalls = calls;
|
||||
});
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
@@ -315,24 +429,15 @@ describe('useReactToolScheduler', () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith([
|
||||
expect.objectContaining({
|
||||
status: 'error',
|
||||
request,
|
||||
response: expect.objectContaining({
|
||||
error: expect.objectContaining({
|
||||
message: expect.stringMatching(
|
||||
/Tool "nonexistentTool" not found in registry/,
|
||||
),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
const errorMessage = onComplete.mock.calls[0][0][0].response.error.message;
|
||||
expect(errorMessage).toContain('Did you mean one of:');
|
||||
expect(errorMessage).toContain('"mockTool"');
|
||||
expect(errorMessage).toContain('"anotherTool"');
|
||||
expect(result.current[0]).toEqual([]);
|
||||
expect(completedToolCalls).toHaveLength(1);
|
||||
expect(completedToolCalls[0].status).toBe('error');
|
||||
expect(completedToolCalls[0].request).toBe(request);
|
||||
expect((completedToolCalls[0] as any).response.error.message).toContain(
|
||||
'Tool "nonexistentTool" not found in registry',
|
||||
);
|
||||
expect((completedToolCalls[0] as any).response.error.message).toContain(
|
||||
'Did you mean one of:',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle error during shouldConfirmExecute', async () => {
|
||||
@@ -348,6 +453,11 @@ describe('useReactToolScheduler', () => {
|
||||
args: {},
|
||||
} as any;
|
||||
|
||||
let completedToolCalls: ToolCall[] = [];
|
||||
onComplete.mockImplementation((calls) => {
|
||||
completedToolCalls = calls;
|
||||
});
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
@@ -358,16 +468,10 @@ describe('useReactToolScheduler', () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith([
|
||||
expect.objectContaining({
|
||||
status: 'error',
|
||||
request,
|
||||
response: expect.objectContaining({
|
||||
error: confirmError,
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
expect(result.current[0]).toEqual([]);
|
||||
expect(completedToolCalls).toHaveLength(1);
|
||||
expect(completedToolCalls[0].status).toBe('error');
|
||||
expect(completedToolCalls[0].request).toBe(request);
|
||||
expect((completedToolCalls[0] as any).response.error).toBe(confirmError);
|
||||
});
|
||||
|
||||
it('should handle error during execute', async () => {
|
||||
@@ -384,6 +488,11 @@ describe('useReactToolScheduler', () => {
|
||||
args: {},
|
||||
} as any;
|
||||
|
||||
let completedToolCalls: ToolCall[] = [];
|
||||
onComplete.mockImplementation((calls) => {
|
||||
completedToolCalls = calls;
|
||||
});
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
});
|
||||
@@ -397,16 +506,10 @@ describe('useReactToolScheduler', () => {
|
||||
await vi.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith([
|
||||
expect.objectContaining({
|
||||
status: 'error',
|
||||
request,
|
||||
response: expect.objectContaining({
|
||||
error: execError,
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
expect(result.current[0]).toEqual([]);
|
||||
expect(completedToolCalls).toHaveLength(1);
|
||||
expect(completedToolCalls[0].status).toBe('error');
|
||||
expect(completedToolCalls[0].request).toBe(request);
|
||||
expect((completedToolCalls[0] as any).response.error).toBe(execError);
|
||||
});
|
||||
|
||||
it('should handle tool requiring confirmation - approved', async () => {
|
||||
@@ -518,7 +621,7 @@ describe('useReactToolScheduler', () => {
|
||||
functionResponse: expect.objectContaining({
|
||||
response: expect.objectContaining({
|
||||
error:
|
||||
'[Operation Cancelled] Reason: User did not allow tool call',
|
||||
'[Operation Cancelled] Reason: User cancelled the operation.',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
@@ -705,7 +808,9 @@ describe('useReactToolScheduler', () => {
|
||||
],
|
||||
}),
|
||||
});
|
||||
expect(result.current[0]).toEqual([]);
|
||||
|
||||
expect(completedCalls).toHaveLength(2);
|
||||
expect(completedCalls.every((t) => t.status === 'success')).toBe(true);
|
||||
});
|
||||
|
||||
it('should queue if scheduling while already running', async () => {
|
||||
@@ -774,7 +879,8 @@ describe('useReactToolScheduler', () => {
|
||||
response: expect.objectContaining({ resultDisplay: 'done display' }),
|
||||
}),
|
||||
]);
|
||||
expect(result.current[0]).toEqual([]);
|
||||
const toolCalls = result.current[0];
|
||||
expect(toolCalls).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user