mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-13 23:51:16 -07:00
feat(core): Failed Response Retry via Extra Prompt (#10828)
Co-authored-by: Sandy Tao <sandytao520@icloud.com>
This commit is contained in:
@@ -759,6 +759,7 @@ export const useGeminiStream = (
|
||||
loopDetectedRef.current = true;
|
||||
break;
|
||||
case ServerGeminiEventType.Retry:
|
||||
case ServerGeminiEventType.InvalidStream:
|
||||
// TODO: Retry and InvalidStream events are intentionally ignored here for now; handling will be added later.
|
||||
break;
|
||||
default: {
|
||||
|
||||
@@ -619,6 +619,31 @@ describe('Server Config (config.ts)', () => {
|
||||
});
|
||||
});
|
||||
|
||||
// Verifies how Config resolves the optional `continueOnFailedApiCall` parameter:
// it falls back to `true` when omitted and otherwise honors the explicit value.
describe('ContinueOnFailedApiCall Configuration', () => {
  // The Config constructor applies `params.continueOnFailedApiCall ?? true`,
  // so an omitted parameter must resolve to `true`.
  it('should default continueOnFailedApiCall to true when not provided', () => {
    const config = new Config(baseParams);
    expect(config.getContinueOnFailedApiCall()).toBe(true);
  });

  it('should set continueOnFailedApiCall to true when provided as true', () => {
    const paramsWithContinueOnFailedApiCall: ConfigParameters = {
      ...baseParams,
      continueOnFailedApiCall: true,
    };
    const config = new Config(paramsWithContinueOnFailedApiCall);
    expect(config.getContinueOnFailedApiCall()).toBe(true);
  });

  // An explicit `false` must survive the `??` fallback (only null/undefined
  // trigger the default).
  it('should set continueOnFailedApiCall to false when explicitly provided as false', () => {
    const paramsWithContinueOnFailedApiCall: ConfigParameters = {
      ...baseParams,
      continueOnFailedApiCall: false,
    };
    const config = new Config(paramsWithContinueOnFailedApiCall);
    expect(config.getContinueOnFailedApiCall()).toBe(false);
  });
});
|
||||
|
||||
describe('createToolRegistry', () => {
|
||||
it('should register a tool if coreTools contains an argument-specific pattern', async () => {
|
||||
const params: ConfigParameters = {
|
||||
|
||||
@@ -270,6 +270,7 @@ export interface ConfigParameters {
|
||||
useModelRouter?: boolean;
|
||||
enableMessageBusIntegration?: boolean;
|
||||
enableSubagents?: boolean;
|
||||
continueOnFailedApiCall?: boolean;
|
||||
}
|
||||
|
||||
export class Config {
|
||||
@@ -363,6 +364,7 @@ export class Config {
|
||||
private readonly useModelRouter: boolean;
|
||||
private readonly enableMessageBusIntegration: boolean;
|
||||
private readonly enableSubagents: boolean;
|
||||
private readonly continueOnFailedApiCall: boolean;
|
||||
|
||||
constructor(params: ConfigParameters) {
|
||||
this.sessionId = params.sessionId;
|
||||
@@ -453,6 +455,7 @@ export class Config {
|
||||
this.enableMessageBusIntegration =
|
||||
params.enableMessageBusIntegration ?? false;
|
||||
this.enableSubagents = params.enableSubagents ?? false;
|
||||
this.continueOnFailedApiCall = params.continueOnFailedApiCall ?? true;
|
||||
this.extensionManagement = params.extensionManagement ?? true;
|
||||
this.storage = new Storage(this.targetDir);
|
||||
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
|
||||
@@ -951,6 +954,10 @@ export class Config {
|
||||
return this.skipNextSpeakerCheck;
|
||||
}
|
||||
|
||||
/**
 * Returns the `continueOnFailedApiCall` flag captured at construction time
 * (the constructor falls back to `true` when the parameter is omitted).
 *
 * @returns the stored boolean flag.
 */
getContinueOnFailedApiCall(): boolean {
|
||||
return this.continueOnFailedApiCall;
|
||||
}
|
||||
|
||||
/**
 * Returns the shell execution configuration stored on this Config instance.
 *
 * @returns the current `shellExecutionConfig` field.
 */
getShellExecutionConfig(): ShellExecutionConfig {
|
||||
return this.shellExecutionConfig;
|
||||
}
|
||||
|
||||
@@ -315,6 +315,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
||||
getUseSmartEdit: vi.fn().mockReturnValue(false),
|
||||
getUseModelRouter: vi.fn().mockReturnValue(false),
|
||||
getContinueOnFailedApiCall: vi.fn(),
|
||||
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
|
||||
storage: {
|
||||
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
|
||||
@@ -1288,6 +1289,9 @@ ${JSON.stringify(
|
||||
});
|
||||
|
||||
it('should stop infinite loop after MAX_TURNS when nextSpeaker always returns model', async () => {
|
||||
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
|
||||
true,
|
||||
);
|
||||
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
|
||||
const { checkNextSpeaker } = await import(
|
||||
'../utils/nextSpeakerChecker.js'
|
||||
@@ -1784,6 +1788,131 @@ ${JSON.stringify(
|
||||
});
|
||||
});
|
||||
|
||||
// With the flag enabled, an InvalidStream event from the first turn must trigger
// exactly one follow-up turn whose request is the "System: Please continue." prompt.
it('should recursively call sendMessageStream with "Please continue." when InvalidStream event is received', async () => {
  vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
    true,
  );

  // Arrange: first turn fails with InvalidStream, second turn yields content.
  const failingTurn = async function* () {
    yield { type: GeminiEventType.InvalidStream };
  };
  const recoveringTurn = async function* () {
    yield { type: GeminiEventType.Content, value: 'Continued content' };
  };
  const firstStream = failingTurn();
  const secondStream = recoveringTurn();
  mockTurnRunFn.mockReturnValueOnce(firstStream);
  mockTurnRunFn.mockReturnValueOnce(secondStream);

  const chatStub: Partial<GeminiChat> = {
    addHistory: vi.fn(),
    getHistory: vi.fn().mockReturnValue([]),
  };
  client['chat'] = chatStub as GeminiChat;

  const initialRequest = [{ text: 'Hi' }];
  const { signal } = new AbortController();

  // Act: drain the whole event stream.
  const received = await fromAsync(
    client.sendMessageStream(
      initialRequest,
      signal,
      'prompt-id-invalid-stream',
    ),
  );

  // Assert: both the failure event and the retried content are surfaced, in order.
  expect(received).toEqual([
    { type: GeminiEventType.InvalidStream },
    { type: GeminiEventType.Content, value: 'Continued content' },
  ]);

  // turn.run ran once for the original request and once for the retry prompt.
  expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
  expect(mockTurnRunFn).toHaveBeenNthCalledWith(
    1,
    expect.any(String),
    initialRequest,
    expect.any(Object),
  );
  expect(mockTurnRunFn).toHaveBeenNthCalledWith(
    2,
    expect.any(String),
    [{ text: 'System: Please continue.' }],
    expect.any(Object),
  );
});
|
||||
|
||||
// With the flag disabled, an InvalidStream event is passed through to the caller
// and no follow-up turn is started.
it('should not recursively call sendMessageStream with "Please continue." when InvalidStream event is received and flag is false', async () => {
  vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
    false,
  );

  // Arrange: the only turn yields a single InvalidStream event.
  const failingTurn = async function* () {
    yield { type: GeminiEventType.InvalidStream };
  };
  const onlyStream = failingTurn();
  mockTurnRunFn.mockReturnValueOnce(onlyStream);

  const chatStub: Partial<GeminiChat> = {
    addHistory: vi.fn(),
    getHistory: vi.fn().mockReturnValue([]),
  };
  client['chat'] = chatStub as GeminiChat;

  const { signal } = new AbortController();

  // Act
  const received = await fromAsync(
    client.sendMessageStream(
      [{ text: 'Hi' }],
      signal,
      'prompt-id-invalid-stream',
    ),
  );

  // Assert: only the original failure event, and turn.run ran exactly once.
  expect(received).toEqual([{ type: GeminiEventType.InvalidStream }]);
  expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
});
|
||||
|
||||
// Even when every turn fails with InvalidStream, the client must retry only once
// and then stop instead of recursing indefinitely.
it('should stop recursing after one retry when InvalidStream events are repeatedly received', async () => {
  vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
    true,
  );

  // Arrange: every call to turn.run produces a fresh stream that fails.
  const failingTurn = async function* () {
    yield { type: GeminiEventType.InvalidStream };
  };
  mockTurnRunFn.mockImplementation(() => failingTurn());

  const chatStub: Partial<GeminiChat> = {
    addHistory: vi.fn(),
    getHistory: vi.fn().mockReturnValue([]),
  };
  client['chat'] = chatStub as GeminiChat;

  const { signal } = new AbortController();

  // Act
  const received = await fromAsync(
    client.sendMessageStream(
      [{ text: 'Hi' }],
      signal,
      'prompt-id-infinite-invalid-stream',
    ),
  );

  // Assert: exactly two InvalidStream events (original attempt + one retry).
  expect(received.length).toBe(2);
  const sawOtherEventType = received.some(
    (evt) => evt.type !== GeminiEventType.InvalidStream,
  );
  expect(!sawOtherEventType).toBe(true);

  // turn.run ran once per attempt.
  expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
});
|
||||
|
||||
describe('Editor context delta', () => {
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
|
||||
@@ -40,9 +40,11 @@ import { LoopDetectionService } from '../services/loopDetectionService.js';
|
||||
import { ideContextStore } from '../ide/ideContext.js';
|
||||
import {
|
||||
logChatCompression,
|
||||
logContentRetryFailure,
|
||||
logNextSpeakerCheck,
|
||||
} from '../telemetry/loggers.js';
|
||||
import {
|
||||
ContentRetryFailureEvent,
|
||||
makeChatCompressionEvent,
|
||||
NextSpeakerCheckEvent,
|
||||
} from '../telemetry/types.js';
|
||||
@@ -476,6 +478,7 @@ My setup is complete. I will provide my first command in the next turn.
|
||||
signal: AbortSignal,
|
||||
prompt_id: string,
|
||||
turns: number = MAX_TURNS,
|
||||
isInvalidStreamRetry: boolean = false,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (this.lastPromptId !== prompt_id) {
|
||||
this.loopDetector.reset(prompt_id);
|
||||
@@ -586,6 +589,31 @@ My setup is complete. I will provide my first command in the next turn.
|
||||
return turn;
|
||||
}
|
||||
yield event;
|
||||
if (event.type === GeminiEventType.InvalidStream) {
|
||||
if (this.config.getContinueOnFailedApiCall()) {
|
||||
if (isInvalidStreamRetry) {
|
||||
// We already retried once, so stop here.
|
||||
logContentRetryFailure(
|
||||
this.config,
|
||||
new ContentRetryFailureEvent(
|
||||
4, // 2 initial + 2 after injections
|
||||
'FAILED_AFTER_PROMPT_INJECTION',
|
||||
modelToUse,
|
||||
),
|
||||
);
|
||||
return turn;
|
||||
}
|
||||
const nextRequest = [{ text: 'System: Please continue.' }];
|
||||
yield* this.sendMessageStream(
|
||||
nextRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
true, // Set isInvalidStreamRetry to true
|
||||
);
|
||||
return turn;
|
||||
}
|
||||
}
|
||||
if (event.type === GeminiEventType.Error) {
|
||||
return turn;
|
||||
}
|
||||
@@ -623,6 +651,7 @@ My setup is complete. I will provide my first command in the next turn.
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
// isInvalidStreamRetry is false here, as this is a next speaker check
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import { Turn, GeminiEventType } from './turn.js';
|
||||
import type { GenerateContentResponse, Part, Content } from '@google/genai';
|
||||
import { reportError } from '../utils/errorReporting.js';
|
||||
import type { GeminiChat } from './geminiChat.js';
|
||||
import { StreamEventType } from './geminiChat.js';
|
||||
import { InvalidStreamError, StreamEventType } from './geminiChat.js';
|
||||
|
||||
const mockSendMessageStream = vi.fn();
|
||||
const mockGetHistory = vi.fn();
|
||||
@@ -223,6 +223,28 @@ describe('Turn', () => {
|
||||
expect(turn.getDebugResponses().length).toBe(1);
|
||||
});
|
||||
|
||||
// An InvalidStreamError raised by the chat layer must surface as a single
// InvalidStream event rather than being reported as a regular error.
it('should yield InvalidStream event if sendMessageStream throws InvalidStreamError', async () => {
  mockSendMessageStream.mockRejectedValue(
    new InvalidStreamError('Test invalid stream', 'NO_FINISH_REASON'),
  );
  const requestParts: Part[] = [{ text: 'Trigger invalid stream' }];
  const { signal } = new AbortController();

  // Drain the turn and keep every event it emits.
  const collected = [];
  for await (const evt of turn.run('test-model', requestParts, signal)) {
    collected.push(evt);
  }

  expect(collected).toEqual([{ type: GeminiEventType.InvalidStream }]);
  expect(turn.getDebugResponses().length).toBe(0);
  expect(reportError).not.toHaveBeenCalled(); // Should not report as error
});
|
||||
|
||||
it('should yield Error event and report if sendMessageStream throws', async () => {
|
||||
const error = new Error('API Error');
|
||||
mockSendMessageStream.mockRejectedValue(error);
|
||||
|
||||
@@ -27,6 +27,7 @@ import {
|
||||
toFriendlyError,
|
||||
} from '../utils/errors.js';
|
||||
import type { GeminiChat } from './geminiChat.js';
|
||||
import { InvalidStreamError } from './geminiChat.js';
|
||||
import { parseThought, type ThoughtSummary } from '../utils/thoughtUtils.js';
|
||||
import { createUserContent } from '@google/genai';
|
||||
|
||||
@@ -60,6 +61,7 @@ export enum GeminiEventType {
|
||||
Citation = 'citation',
|
||||
Retry = 'retry',
|
||||
ContextWindowWillOverflow = 'context_window_will_overflow',
|
||||
InvalidStream = 'invalid_stream',
|
||||
}
|
||||
|
||||
export type ServerGeminiRetryEvent = {
|
||||
@@ -74,6 +76,10 @@ export type ServerGeminiContextWindowWillOverflowEvent = {
|
||||
};
|
||||
};
|
||||
|
||||
/**
 * Stream event that carries only the `InvalidStream` type tag and no payload.
 * In this change it is yielded when the caught error is an `InvalidStreamError`.
 */
export type ServerGeminiInvalidStreamEvent = {
|
||||
type: GeminiEventType.InvalidStream;
|
||||
};
|
||||
|
||||
export interface StructuredError {
|
||||
message: string;
|
||||
status?: number;
|
||||
@@ -203,7 +209,8 @@ export type ServerGeminiStreamEvent =
|
||||
| ServerGeminiToolCallResponseEvent
|
||||
| ServerGeminiUserCancelledEvent
|
||||
| ServerGeminiRetryEvent
|
||||
| ServerGeminiContextWindowWillOverflowEvent;
|
||||
| ServerGeminiContextWindowWillOverflowEvent
|
||||
| ServerGeminiInvalidStreamEvent;
|
||||
|
||||
// A turn manages the agentic loop turn within the server context.
|
||||
export class Turn {
|
||||
@@ -312,6 +319,11 @@ export class Turn {
|
||||
return;
|
||||
}
|
||||
|
||||
if (e instanceof InvalidStreamError) {
|
||||
yield { type: GeminiEventType.InvalidStream };
|
||||
return;
|
||||
}
|
||||
|
||||
const error = toFriendlyError(e);
|
||||
if (error instanceof UnauthorizedError) {
|
||||
throw error;
|
||||
|
||||
Reference in New Issue
Block a user