feat(cli): continue request after disabling loop detection (#11416)

This commit is contained in:
Sandy Tao
2025-10-21 13:27:57 -07:00
committed by GitHub
parent cf16d1678d
commit dd3b1cb653
2 changed files with 99 additions and 38 deletions

View File

@@ -2677,7 +2677,8 @@ describe('useGeminiStream', () => {
}; };
mockConfig.getGeminiClient = vi.fn().mockReturnValue(mockClient); mockConfig.getGeminiClient = vi.fn().mockReturnValue(mockClient);
mockSendMessageStream.mockReturnValue( // Mock for the initial request
mockSendMessageStream.mockReturnValueOnce(
(async function* () { (async function* () {
yield { yield {
type: ServerGeminiEventType.LoopDetected, type: ServerGeminiEventType.LoopDetected,
@@ -2685,6 +2686,20 @@ describe('useGeminiStream', () => {
})(), })(),
); );
// Mock for the retry request
mockSendMessageStream.mockReturnValueOnce(
(async function* () {
yield {
type: ServerGeminiEventType.Content,
value: 'Retry successful',
};
yield {
type: ServerGeminiEventType.Finished,
value: { reason: 'STOP' },
};
})(),
);
const { result } = renderTestHook(); const { result } = renderTestHook();
await act(async () => { await act(async () => {
@@ -2715,10 +2730,21 @@ describe('useGeminiStream', () => {
expect(mockAddItem).toHaveBeenCalledWith( expect(mockAddItem).toHaveBeenCalledWith(
{ {
type: 'info', type: 'info',
text: 'Loop detection has been disabled for this session. Please try your request again.', text: 'Loop detection has been disabled for this session. Retrying request...',
}, },
expect.any(Number), expect.any(Number),
); );
// Verify that the request was retried
await waitFor(() => {
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
expect(mockSendMessageStream).toHaveBeenNthCalledWith(
2,
'test query',
expect.any(AbortSignal),
expect.any(String),
);
});
}); });
it('should keep loop detection enabled and show message when user selects "keep"', async () => { it('should keep loop detection enabled and show message when user selects "keep"', async () => {
@@ -2771,6 +2797,9 @@ describe('useGeminiStream', () => {
}, },
expect.any(Number), expect.any(Number),
); );
// Verify that the request was NOT retried
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
}); });
it('should handle multiple loop detection events properly', async () => { it('should handle multiple loop detection events properly', async () => {
@@ -2821,6 +2850,20 @@ describe('useGeminiStream', () => {
})(), })(),
); );
// Mock for the retry request
mockSendMessageStream.mockReturnValueOnce(
(async function* () {
yield {
type: ServerGeminiEventType.Content,
value: 'Retry successful',
};
yield {
type: ServerGeminiEventType.Finished,
value: { reason: 'STOP' },
};
})(),
);
// Second loop detection // Second loop detection
await act(async () => { await act(async () => {
await result.current.submitQuery('second query'); await result.current.submitQuery('second query');
@@ -2843,10 +2886,21 @@ describe('useGeminiStream', () => {
expect(mockAddItem).toHaveBeenCalledWith( expect(mockAddItem).toHaveBeenCalledWith(
{ {
type: 'info', type: 'info',
text: 'Loop detection has been disabled for this session. Please try your request again.', text: 'Loop detection has been disabled for this session. Retrying request...',
}, },
expect.any(Number), expect.any(Number),
); );
// Verify that the request was retried
await waitFor(() => {
expect(mockSendMessageStream).toHaveBeenCalledTimes(3); // 1st query, 2nd query, retry of 2nd query
expect(mockSendMessageStream).toHaveBeenNthCalledWith(
3,
'second query',
expect.any(AbortSignal),
expect.any(String),
);
});
}); });
it('should process LoopDetected event after moving pending history to history', async () => { it('should process LoopDetected event after moving pending history to history', async () => {

View File

@@ -185,6 +185,8 @@ export const useGeminiStream = (
return undefined; return undefined;
}, [toolCalls]); }, [toolCalls]);
const lastQueryRef = useRef<PartListUnion | null>(null);
const lastPromptIdRef = useRef<string | null>(null);
const loopDetectedRef = useRef(false); const loopDetectedRef = useRef(false);
const [ const [
loopDetectionConfirmationRequest, loopDetectionConfirmationRequest,
@@ -668,39 +670,6 @@ export const useGeminiStream = (
[addItem, onCancelSubmit, config], [addItem, onCancelSubmit, config],
); );
const handleLoopDetectionConfirmation = useCallback(
(result: { userSelection: 'disable' | 'keep' }) => {
setLoopDetectionConfirmationRequest(null);
if (result.userSelection === 'disable') {
config.getGeminiClient().getLoopDetectionService().disableForSession();
addItem(
{
type: 'info',
text: `Loop detection has been disabled for this session. Please try your request again.`,
},
Date.now(),
);
} else {
addItem(
{
type: 'info',
text: `A potential loop was detected. This can happen due to repetitive tool calls or other model behavior. The request has been halted.`,
},
Date.now(),
);
}
},
[config, addItem],
);
const handleLoopDetectedEvent = useCallback(() => {
// Show the confirmation dialog to choose whether to disable loop detection
setLoopDetectionConfirmationRequest({
onComplete: handleLoopDetectionConfirmation,
});
}, [handleLoopDetectionConfirmation]);
const processGeminiStreamEvents = useCallback( const processGeminiStreamEvents = useCallback(
async ( async (
stream: AsyncIterable<GeminiEvent>, stream: AsyncIterable<GeminiEvent>,
@@ -850,6 +819,10 @@ export const useGeminiStream = (
setIsResponding(true); setIsResponding(true);
setInitError(null); setInitError(null);
// Store query and prompt_id for potential retry on loop detection
lastQueryRef.current = queryToSend;
lastPromptIdRef.current = prompt_id;
try { try {
const stream = geminiClient.sendMessageStream( const stream = geminiClient.sendMessageStream(
queryToSend, queryToSend,
@@ -872,7 +845,42 @@ export const useGeminiStream = (
} }
if (loopDetectedRef.current) { if (loopDetectedRef.current) {
loopDetectedRef.current = false; loopDetectedRef.current = false;
handleLoopDetectedEvent(); // Show the confirmation dialog to choose whether to disable loop detection
setLoopDetectionConfirmationRequest({
onComplete: (result: { userSelection: 'disable' | 'keep' }) => {
setLoopDetectionConfirmationRequest(null);
if (result.userSelection === 'disable') {
config
.getGeminiClient()
.getLoopDetectionService()
.disableForSession();
addItem(
{
type: 'info',
text: `Loop detection has been disabled for this session. Retrying request...`,
},
Date.now(),
);
if (lastQueryRef.current && lastPromptIdRef.current) {
submitQuery(
lastQueryRef.current,
{ isContinuation: true },
lastPromptIdRef.current,
);
}
} else {
addItem(
{
type: 'info',
text: `A potential loop was detected. This can happen due to repetitive tool calls or other model behavior. The request has been halted.`,
},
Date.now(),
);
}
},
});
} }
} catch (error: unknown) { } catch (error: unknown) {
if (error instanceof UnauthorizedError) { if (error instanceof UnauthorizedError) {
@@ -911,7 +919,6 @@ export const useGeminiStream = (
config, config,
startNewPrompt, startNewPrompt,
getPromptCount, getPromptCount,
handleLoopDetectedEvent,
], ],
); );