feat(core): Failed Response Retry via Extra Prompt (#10828)

Co-authored-by: Sandy Tao <sandytao520@icloud.com>
This commit is contained in:
Victor May
2025-10-09 18:04:08 -04:00
committed by GitHub
parent 1f6716f98a
commit 0b6c02000f
7 changed files with 227 additions and 2 deletions

View File

@@ -759,6 +759,7 @@ export const useGeminiStream = (
loopDetectedRef.current = true;
break;
case ServerGeminiEventType.Retry:
case ServerGeminiEventType.InvalidStream:
// Will add the missing logic later
break;
default: {

View File

@@ -619,6 +619,31 @@ describe('Server Config (config.ts)', () => {
});
});
// Covers how Config resolves the optional `continueOnFailedApiCall` parameter.
describe('ContinueOnFailedApiCall Configuration', () => {
  // When the parameter is omitted the Config constructor falls back to `true`,
  // so the title must describe a `true` default (it previously said `false`
  // while asserting `true`).
  it('should default continueOnFailedApiCall to true when not provided', () => {
    const config = new Config(baseParams);
    expect(config.getContinueOnFailedApiCall()).toBe(true);
  });

  // An explicit `true` is passed through unchanged.
  it('should set continueOnFailedApiCall to true when provided as true', () => {
    const paramsWithContinueOnFailedApiCall: ConfigParameters = {
      ...baseParams,
      continueOnFailedApiCall: true,
    };
    const config = new Config(paramsWithContinueOnFailedApiCall);
    expect(config.getContinueOnFailedApiCall()).toBe(true);
  });

  // An explicit `false` must win over the default.
  it('should set continueOnFailedApiCall to false when explicitly provided as false', () => {
    const paramsWithContinueOnFailedApiCall: ConfigParameters = {
      ...baseParams,
      continueOnFailedApiCall: false,
    };
    const config = new Config(paramsWithContinueOnFailedApiCall);
    expect(config.getContinueOnFailedApiCall()).toBe(false);
  });
});
describe('createToolRegistry', () => {
it('should register a tool if coreTools contains an argument-specific pattern', async () => {
const params: ConfigParameters = {

View File

@@ -270,6 +270,7 @@ export interface ConfigParameters {
useModelRouter?: boolean;
enableMessageBusIntegration?: boolean;
enableSubagents?: boolean;
continueOnFailedApiCall?: boolean;
}
export class Config {
@@ -363,6 +364,7 @@ export class Config {
private readonly useModelRouter: boolean;
private readonly enableMessageBusIntegration: boolean;
private readonly enableSubagents: boolean;
private readonly continueOnFailedApiCall: boolean;
constructor(params: ConfigParameters) {
this.sessionId = params.sessionId;
@@ -453,6 +455,7 @@ export class Config {
this.enableMessageBusIntegration =
params.enableMessageBusIntegration ?? false;
this.enableSubagents = params.enableSubagents ?? false;
this.continueOnFailedApiCall = params.continueOnFailedApiCall ?? true;
this.extensionManagement = params.extensionManagement ?? true;
this.storage = new Storage(this.targetDir);
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
@@ -951,6 +954,10 @@ export class Config {
return this.skipNextSpeakerCheck;
}
/**
 * Tells callers whether an invalid model stream should be followed up with
 * an extra "continue" prompt. The flag is assigned once in the constructor.
 *
 * @returns the configured `continueOnFailedApiCall` flag.
 */
getContinueOnFailedApiCall(): boolean {
  const { continueOnFailedApiCall } = this;
  return continueOnFailedApiCall;
}
getShellExecutionConfig(): ShellExecutionConfig {
return this.shellExecutionConfig;
}

View File

@@ -315,6 +315,7 @@ describe('Gemini Client (client.ts)', () => {
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
getUseSmartEdit: vi.fn().mockReturnValue(false),
getUseModelRouter: vi.fn().mockReturnValue(false),
getContinueOnFailedApiCall: vi.fn(),
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
storage: {
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
@@ -1288,6 +1289,9 @@ ${JSON.stringify(
});
it('should stop infinite loop after MAX_TURNS when nextSpeaker always returns model', async () => {
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
true,
);
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
const { checkNextSpeaker } = await import(
'../utils/nextSpeakerChecker.js'
@@ -1784,6 +1788,131 @@ ${JSON.stringify(
});
});
it('should recursively call sendMessageStream with "Please continue." when InvalidStream event is received', async () => {
  // Turn the retry-on-invalid-stream behaviour on for this scenario.
  vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
    true,
  );

  // Arrange: the first turn reports an invalid stream, the second produces content.
  async function* invalidTurn() {
    yield { type: GeminiEventType.InvalidStream };
  }
  async function* recoveredTurn() {
    yield { type: GeminiEventType.Content, value: 'Continued content' };
  }
  mockTurnRunFn.mockReturnValueOnce(invalidTurn());
  mockTurnRunFn.mockReturnValueOnce(recoveredTurn());

  const chatStub: Partial<GeminiChat> = {
    addHistory: vi.fn(),
    getHistory: vi.fn().mockReturnValue([]),
  };
  client['chat'] = chatStub as GeminiChat;

  const firstRequest = [{ text: 'Hi' }];
  const { signal } = new AbortController();

  // Act: drain the whole event stream.
  const received = await fromAsync(
    client.sendMessageStream(firstRequest, signal, 'prompt-id-invalid-stream'),
  );

  // Assert: both the failure and the continued content are surfaced, in order.
  expect(received).toEqual([
    { type: GeminiEventType.InvalidStream },
    { type: GeminiEventType.Content, value: 'Continued content' },
  ]);

  // Exactly one retry happened.
  expect(mockTurnRunFn).toHaveBeenCalledTimes(2);

  // Call #1 carries the caller's request untouched.
  expect(mockTurnRunFn).toHaveBeenNthCalledWith(
    1,
    expect.any(String),
    firstRequest,
    expect.any(Object),
  );

  // Call #2 carries the injected continuation prompt.
  expect(mockTurnRunFn).toHaveBeenNthCalledWith(
    2,
    expect.any(String),
    [{ text: 'System: Please continue.' }],
    expect.any(Object),
  );
});
it('should not recursively call sendMessageStream with "Please continue." when InvalidStream event is received and flag is false', async () => {
  // Turn the retry-on-invalid-stream behaviour off for this scenario.
  vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
    false,
  );

  // Arrange: a single turn that only reports an invalid stream.
  async function* invalidTurn() {
    yield { type: GeminiEventType.InvalidStream };
  }
  mockTurnRunFn.mockReturnValueOnce(invalidTurn());

  const chatStub: Partial<GeminiChat> = {
    addHistory: vi.fn(),
    getHistory: vi.fn().mockReturnValue([]),
  };
  client['chat'] = chatStub as GeminiChat;

  const firstRequest = [{ text: 'Hi' }];
  const { signal } = new AbortController();

  // Act: drain the whole event stream.
  const received = await fromAsync(
    client.sendMessageStream(firstRequest, signal, 'prompt-id-invalid-stream'),
  );

  // Assert: the failure is surfaced and no follow-up turn is started.
  expect(received).toEqual([{ type: GeminiEventType.InvalidStream }]);
  expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
});
it('should stop recursing after one retry when InvalidStream events are repeatedly received', async () => {
  // Turn the retry-on-invalid-stream behaviour on for this scenario.
  vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
    true,
  );

  // Arrange: every turn, however many are started, reports an invalid stream.
  async function* invalidTurn() {
    yield { type: GeminiEventType.InvalidStream };
  }
  mockTurnRunFn.mockImplementation(() => invalidTurn());

  const chatStub: Partial<GeminiChat> = {
    addHistory: vi.fn(),
    getHistory: vi.fn().mockReturnValue([]),
  };
  client['chat'] = chatStub as GeminiChat;

  const firstRequest = [{ text: 'Hi' }];
  const { signal } = new AbortController();

  // Act: drain the whole event stream.
  const received = await fromAsync(
    client.sendMessageStream(
      firstRequest,
      signal,
      'prompt-id-infinite-invalid-stream',
    ),
  );

  // Assert: one InvalidStream from the original attempt plus one from the retry.
  expect(received).toHaveLength(2);
  const hasOtherEvent = received.some(
    (e) => e.type !== GeminiEventType.InvalidStream,
  );
  expect(hasOtherEvent).toBe(false);

  // The original turn and a single retry, nothing more.
  expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
});
describe('Editor context delta', () => {
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };

View File

@@ -40,9 +40,11 @@ import { LoopDetectionService } from '../services/loopDetectionService.js';
import { ideContextStore } from '../ide/ideContext.js';
import {
logChatCompression,
logContentRetryFailure,
logNextSpeakerCheck,
} from '../telemetry/loggers.js';
import {
ContentRetryFailureEvent,
makeChatCompressionEvent,
NextSpeakerCheckEvent,
} from '../telemetry/types.js';
@@ -476,6 +478,7 @@ My setup is complete. I will provide my first command in the next turn.
signal: AbortSignal,
prompt_id: string,
turns: number = MAX_TURNS,
isInvalidStreamRetry: boolean = false,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (this.lastPromptId !== prompt_id) {
this.loopDetector.reset(prompt_id);
@@ -586,6 +589,31 @@ My setup is complete. I will provide my first command in the next turn.
return turn;
}
yield event;
// An invalid stream ends this turn. When the config allows it, one follow-up
// prompt is injected first; a second failure is only logged, never retried.
if (
  event.type === GeminiEventType.InvalidStream &&
  this.config.getContinueOnFailedApiCall()
) {
  if (!isInvalidStreamRetry) {
    // First failure: re-enter with the continuation prompt and mark the
    // nested call as a retry so it cannot recurse again.
    yield* this.sendMessageStream(
      [{ text: 'System: Please continue.' }],
      signal,
      prompt_id,
      boundedTurns - 1,
      true,
    );
  } else {
    // The injected prompt failed as well: record the failure and give up.
    const retryFailure = new ContentRetryFailureEvent(
      4, // 2 initial + 2 after injections
      'FAILED_AFTER_PROMPT_INJECTION',
      modelToUse,
    );
    logContentRetryFailure(this.config, retryFailure);
  }
  // Either way the caller receives this call's own turn object.
  return turn;
}
if (event.type === GeminiEventType.Error) {
return turn;
}
@@ -623,6 +651,7 @@ My setup is complete. I will provide my first command in the next turn.
signal,
prompt_id,
boundedTurns - 1,
// isInvalidStreamRetry is false here, as this is a next speaker check
);
}
}

View File

@@ -13,7 +13,7 @@ import { Turn, GeminiEventType } from './turn.js';
import type { GenerateContentResponse, Part, Content } from '@google/genai';
import { reportError } from '../utils/errorReporting.js';
import type { GeminiChat } from './geminiChat.js';
import { StreamEventType } from './geminiChat.js';
import { InvalidStreamError, StreamEventType } from './geminiChat.js';
const mockSendMessageStream = vi.fn();
const mockGetHistory = vi.fn();
@@ -223,6 +223,28 @@ describe('Turn', () => {
expect(turn.getDebugResponses().length).toBe(1);
});
it('should yield InvalidStream event if sendMessageStream throws InvalidStreamError', async () => {
  // Make the mocked chat reject with the dedicated invalid-stream error.
  mockSendMessageStream.mockRejectedValue(
    new InvalidStreamError('Test invalid stream', 'NO_FINISH_REASON'),
  );

  const requestParts: Part[] = [{ text: 'Trigger invalid stream' }];
  const { signal } = new AbortController();

  // Drain everything the turn emits.
  const collected = [];
  for await (const emitted of turn.run('test-model', requestParts, signal)) {
    collected.push(emitted);
  }

  // Only the InvalidStream marker is emitted, and no debug response is recorded.
  expect(collected).toEqual([{ type: GeminiEventType.InvalidStream }]);
  expect(turn.getDebugResponses().length).toBe(0);
  expect(reportError).not.toHaveBeenCalled(); // Should not report as error
});
it('should yield Error event and report if sendMessageStream throws', async () => {
const error = new Error('API Error');
mockSendMessageStream.mockRejectedValue(error);

View File

@@ -27,6 +27,7 @@ import {
toFriendlyError,
} from '../utils/errors.js';
import type { GeminiChat } from './geminiChat.js';
import { InvalidStreamError } from './geminiChat.js';
import { parseThought, type ThoughtSummary } from '../utils/thoughtUtils.js';
import { createUserContent } from '@google/genai';
@@ -60,6 +61,7 @@ export enum GeminiEventType {
Citation = 'citation',
Retry = 'retry',
ContextWindowWillOverflow = 'context_window_will_overflow',
InvalidStream = 'invalid_stream',
}
export type ServerGeminiRetryEvent = {
@@ -74,6 +76,10 @@ export type ServerGeminiContextWindowWillOverflowEvent = {
};
};
/**
 * Stream event signalling that the model response stream was invalid.
 * It carries no payload beyond its discriminating `type` tag.
 */
export type ServerGeminiInvalidStreamEvent = {
type: GeminiEventType.InvalidStream;
};
export interface StructuredError {
message: string;
status?: number;
@@ -203,7 +209,8 @@ export type ServerGeminiStreamEvent =
| ServerGeminiToolCallResponseEvent
| ServerGeminiUserCancelledEvent
| ServerGeminiRetryEvent
| ServerGeminiContextWindowWillOverflowEvent;
| ServerGeminiContextWindowWillOverflowEvent
| ServerGeminiInvalidStreamEvent;
// A turn manages the agentic loop turn within the server context.
export class Turn {
@@ -312,6 +319,11 @@ export class Turn {
return;
}
if (e instanceof InvalidStreamError) {
yield { type: GeminiEventType.InvalidStream };
return;
}
const error = toFriendlyError(e);
if (error instanceof UnauthorizedError) {
throw error;