diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 2743c6be10..5ecd615c0c 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -759,6 +759,7 @@ export const useGeminiStream = (
             loopDetectedRef.current = true;
             break;
           case ServerGeminiEventType.Retry:
+          case ServerGeminiEventType.InvalidStream:
             // Will add the missing logic later
             break;
           default: {
diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts
index 99d43b08aa..72c3e63628 100644
--- a/packages/core/src/config/config.test.ts
+++ b/packages/core/src/config/config.test.ts
@@ -619,6 +619,31 @@ describe('Server Config (config.ts)', () => {
     });
   });
 
+  describe('ContinueOnFailedApiCall Configuration', () => {
+    it('should default continueOnFailedApiCall to true when not provided', () => {
+      const config = new Config(baseParams);
+      expect(config.getContinueOnFailedApiCall()).toBe(true);
+    });
+
+    it('should set continueOnFailedApiCall to true when provided as true', () => {
+      const paramsWithContinueOnFailedApiCall: ConfigParameters = {
+        ...baseParams,
+        continueOnFailedApiCall: true,
+      };
+      const config = new Config(paramsWithContinueOnFailedApiCall);
+      expect(config.getContinueOnFailedApiCall()).toBe(true);
+    });
+
+    it('should set continueOnFailedApiCall to false when explicitly provided as false', () => {
+      const paramsWithContinueOnFailedApiCall: ConfigParameters = {
+        ...baseParams,
+        continueOnFailedApiCall: false,
+      };
+      const config = new Config(paramsWithContinueOnFailedApiCall);
+      expect(config.getContinueOnFailedApiCall()).toBe(false);
+    });
+  });
+
   describe('createToolRegistry', () => {
     it('should register a tool if coreTools contains an argument-specific pattern', async () => {
       const params: ConfigParameters = {
diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts
index 6873497460..eb996cf9e3 100644
--- a/packages/core/src/config/config.ts
+++ b/packages/core/src/config/config.ts
@@ -270,6 +270,7 @@ export interface ConfigParameters {
   useModelRouter?: boolean;
   enableMessageBusIntegration?: boolean;
   enableSubagents?: boolean;
+  continueOnFailedApiCall?: boolean;
 }
 
 export class Config {
@@ -363,6 +364,7 @@ export class Config {
   private readonly useModelRouter: boolean;
   private readonly enableMessageBusIntegration: boolean;
   private readonly enableSubagents: boolean;
+  private readonly continueOnFailedApiCall: boolean;
 
   constructor(params: ConfigParameters) {
     this.sessionId = params.sessionId;
@@ -453,6 +455,7 @@ export class Config {
     this.enableMessageBusIntegration =
       params.enableMessageBusIntegration ?? false;
     this.enableSubagents = params.enableSubagents ?? false;
+    this.continueOnFailedApiCall = params.continueOnFailedApiCall ?? true;
     this.extensionManagement = params.extensionManagement ?? true;
     this.storage = new Storage(this.targetDir);
     this.enablePromptCompletion = params.enablePromptCompletion ?? false;
@@ -951,6 +954,14 @@ export class Config {
     return this.skipNextSpeakerCheck;
   }
 
+  /**
+   * Whether the client re-prompts the model once after an invalid stream.
+   * Defaults to true when `continueOnFailedApiCall` is not provided.
+   */
+  getContinueOnFailedApiCall(): boolean {
+    return this.continueOnFailedApiCall;
+  }
+
   getShellExecutionConfig(): ShellExecutionConfig {
     return this.shellExecutionConfig;
   }
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index 039debac09..aec1976844 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -315,6 +315,7 @@ describe('Gemini Client (client.ts)', () => {
       getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
       getUseSmartEdit: vi.fn().mockReturnValue(false),
       getUseModelRouter: vi.fn().mockReturnValue(false),
+      getContinueOnFailedApiCall: vi.fn(),
       getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
       storage: {
         getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
@@ -1288,6 +1289,9 @@ ${JSON.stringify(
     });
 
     it('should stop infinite loop after MAX_TURNS when nextSpeaker always returns model', async () => {
+      vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
+        true,
+      );
       // Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
       const { checkNextSpeaker } = await import(
         '../utils/nextSpeakerChecker.js'
@@ -1784,6 +1788,131 @@ ${JSON.stringify(
       });
     });
 
+    it('should recursively call sendMessageStream with "Please continue." when InvalidStream event is received', async () => {
+      vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
+        true,
+      );
+      // Arrange
+      const mockStream1 = (async function* () {
+        yield { type: GeminiEventType.InvalidStream };
+      })();
+      const mockStream2 = (async function* () {
+        yield { type: GeminiEventType.Content, value: 'Continued content' };
+      })();
+
+      mockTurnRunFn
+        .mockReturnValueOnce(mockStream1)
+        .mockReturnValueOnce(mockStream2);
+
+      const mockChat: Partial<GeminiChat> = {
+        addHistory: vi.fn(),
+        getHistory: vi.fn().mockReturnValue([]),
+      };
+      client['chat'] = mockChat as GeminiChat;
+
+      const initialRequest = [{ text: 'Hi' }];
+      const promptId = 'prompt-id-invalid-stream';
+      const signal = new AbortController().signal;
+
+      // Act
+      const stream = client.sendMessageStream(initialRequest, signal, promptId);
+      const events = await fromAsync(stream);
+
+      // Assert
+      expect(events).toEqual([
+        { type: GeminiEventType.InvalidStream },
+        { type: GeminiEventType.Content, value: 'Continued content' },
+      ]);
+
+      // Verify that turn.run was called twice
+      expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
+
+      // First call with original request
+      expect(mockTurnRunFn).toHaveBeenNthCalledWith(
+        1,
+        expect.any(String),
+        initialRequest,
+        expect.any(Object),
+      );
+
+      // Second call with "Please continue."
+      expect(mockTurnRunFn).toHaveBeenNthCalledWith(
+        2,
+        expect.any(String),
+        [{ text: 'System: Please continue.' }],
+        expect.any(Object),
+      );
+    });
+
+    it('should not recursively call sendMessageStream with "Please continue." when InvalidStream event is received and flag is false', async () => {
+      vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
+        false,
+      );
+      // Arrange
+      const mockStream1 = (async function* () {
+        yield { type: GeminiEventType.InvalidStream };
+      })();
+
+      mockTurnRunFn.mockReturnValueOnce(mockStream1);
+
+      const mockChat: Partial<GeminiChat> = {
+        addHistory: vi.fn(),
+        getHistory: vi.fn().mockReturnValue([]),
+      };
+      client['chat'] = mockChat as GeminiChat;
+
+      const initialRequest = [{ text: 'Hi' }];
+      const promptId = 'prompt-id-invalid-stream';
+      const signal = new AbortController().signal;
+
+      // Act
+      const stream = client.sendMessageStream(initialRequest, signal, promptId);
+      const events = await fromAsync(stream);
+
+      // Assert
+      expect(events).toEqual([{ type: GeminiEventType.InvalidStream }]);
+
+      // Verify that turn.run was called only once
+      expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
+    });
+
+    it('should stop recursing after one retry when InvalidStream events are repeatedly received', async () => {
+      vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
+        true,
+      );
+      // Arrange
+      // Always return a new invalid stream
+      mockTurnRunFn.mockImplementation(() =>
+        (async function* () {
+          yield { type: GeminiEventType.InvalidStream };
+        })(),
+      );
+
+      const mockChat: Partial<GeminiChat> = {
+        addHistory: vi.fn(),
+        getHistory: vi.fn().mockReturnValue([]),
+      };
+      client['chat'] = mockChat as GeminiChat;
+
+      const initialRequest = [{ text: 'Hi' }];
+      const promptId = 'prompt-id-infinite-invalid-stream';
+      const signal = new AbortController().signal;
+
+      // Act
+      const stream = client.sendMessageStream(initialRequest, signal, promptId);
+      const events = await fromAsync(stream);
+
+      // Assert
+      // We expect 2 InvalidStream events (original + 1 retry)
+      expect(events.length).toBe(2);
+      expect(
+        events.every((e) => e.type === GeminiEventType.InvalidStream),
+      ).toBe(true);
+
+      // Verify that turn.run was called twice
+      expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
+    });
+
     describe('Editor context delta', () => {
       const mockStream = (async function* () {
         yield { type: 'content', value: 'Hello' };
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index 2df71de211..849249f4c9 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -40,9 +40,11 @@ import { LoopDetectionService } from '../services/loopDetectionService.js';
 import { ideContextStore } from '../ide/ideContext.js';
 import {
   logChatCompression,
+  logContentRetryFailure,
   logNextSpeakerCheck,
 } from '../telemetry/loggers.js';
 import {
+  ContentRetryFailureEvent,
   makeChatCompressionEvent,
   NextSpeakerCheckEvent,
 } from '../telemetry/types.js';
@@ -476,6 +478,7 @@ My setup is complete. I will provide my first command in the next turn.
     signal: AbortSignal,
     prompt_id: string,
     turns: number = MAX_TURNS,
+    isInvalidStreamRetry: boolean = false,
   ): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
     if (this.lastPromptId !== prompt_id) {
       this.loopDetector.reset(prompt_id);
@@ -586,6 +589,31 @@ My setup is complete. I will provide my first command in the next turn.
           return turn;
         }
         yield event;
+        if (event.type === GeminiEventType.InvalidStream) {
+          if (this.config.getContinueOnFailedApiCall()) {
+            if (isInvalidStreamRetry) {
+              // We already retried once, so stop here.
+              logContentRetryFailure(
+                this.config,
+                new ContentRetryFailureEvent(
+                  4, // 2 initial + 2 after injections
+                  'FAILED_AFTER_PROMPT_INJECTION',
+                  modelToUse,
+                ),
+              );
+              return turn;
+            }
+            const nextRequest = [{ text: 'System: Please continue.' }];
+            yield* this.sendMessageStream(
+              nextRequest,
+              signal,
+              prompt_id,
+              boundedTurns - 1,
+              true, // Set isInvalidStreamRetry to true
+            );
+            return turn;
+          }
+        }
         if (event.type === GeminiEventType.Error) {
           return turn;
         }
@@ -623,6 +651,7 @@ My setup is complete. I will provide my first command in the next turn.
           signal,
           prompt_id,
           boundedTurns - 1,
+          // isInvalidStreamRetry is false here, as this is a next speaker check
         );
       }
     }
diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts
index 45f78a26a5..97e7195d30 100644
--- a/packages/core/src/core/turn.test.ts
+++ b/packages/core/src/core/turn.test.ts
@@ -13,7 +13,7 @@ import { Turn, GeminiEventType } from './turn.js';
 import type { GenerateContentResponse, Part, Content } from '@google/genai';
 import { reportError } from '../utils/errorReporting.js';
 import type { GeminiChat } from './geminiChat.js';
-import { StreamEventType } from './geminiChat.js';
+import { InvalidStreamError, StreamEventType } from './geminiChat.js';
 
 const mockSendMessageStream = vi.fn();
 const mockGetHistory = vi.fn();
@@ -223,6 +223,28 @@ describe('Turn', () => {
       expect(turn.getDebugResponses().length).toBe(1);
     });
 
+    it('should yield InvalidStream event if sendMessageStream throws InvalidStreamError', async () => {
+      const error = new InvalidStreamError(
+        'Test invalid stream',
+        'NO_FINISH_REASON',
+      );
+      mockSendMessageStream.mockRejectedValue(error);
+      const reqParts: Part[] = [{ text: 'Trigger invalid stream' }];
+
+      const events = [];
+      for await (const event of turn.run(
+        'test-model',
+        reqParts,
+        new AbortController().signal,
+      )) {
+        events.push(event);
+      }
+
+      expect(events).toEqual([{ type: GeminiEventType.InvalidStream }]);
+      expect(turn.getDebugResponses().length).toBe(0);
+      expect(reportError).not.toHaveBeenCalled(); // Should not report as error
+    });
+
     it('should yield Error event and report if sendMessageStream throws', async () => {
       const error = new Error('API Error');
       mockSendMessageStream.mockRejectedValue(error);
diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts
index 66f2df4b46..fc6772bf3e 100644
--- a/packages/core/src/core/turn.ts
+++ b/packages/core/src/core/turn.ts
@@ -27,6 +27,7 @@ import {
   toFriendlyError,
 } from '../utils/errors.js';
 import type { GeminiChat } from './geminiChat.js';
+import { InvalidStreamError } from './geminiChat.js';
 import { parseThought, type ThoughtSummary } from '../utils/thoughtUtils.js';
 import { createUserContent } from '@google/genai';
 
@@ -60,6 +61,7 @@ export enum GeminiEventType {
   Citation = 'citation',
   Retry = 'retry',
   ContextWindowWillOverflow = 'context_window_will_overflow',
+  InvalidStream = 'invalid_stream',
 }
 
 export type ServerGeminiRetryEvent = {
@@ -74,6 +76,10 @@ export type ServerGeminiContextWindowWillOverflowEvent = {
   };
 };
 
+export type ServerGeminiInvalidStreamEvent = {
+  type: GeminiEventType.InvalidStream;
+};
+
 export interface StructuredError {
   message: string;
   status?: number;
@@ -203,7 +209,8 @@ export type ServerGeminiStreamEvent =
   | ServerGeminiToolCallResponseEvent
   | ServerGeminiUserCancelledEvent
   | ServerGeminiRetryEvent
-  | ServerGeminiContextWindowWillOverflowEvent;
+  | ServerGeminiContextWindowWillOverflowEvent
+  | ServerGeminiInvalidStreamEvent;
 
 // A turn manages the agentic loop turn within the server context.
 export class Turn {
@@ -312,6 +319,11 @@ export class Turn {
         return;
       }
 
+      if (e instanceof InvalidStreamError) {
+        yield { type: GeminiEventType.InvalidStream };
+        return;
+      }
+
       const error = toFriendlyError(e);
       if (error instanceof UnauthorizedError) {
         throw error;