mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-28 22:14:52 -07:00
feat(core): Failed Response Retry via Extra Prompt (#10828)
Co-authored-by: Sandy Tao <sandytao520@icloud.com>
This commit is contained in:
@@ -759,6 +759,7 @@ export const useGeminiStream = (
|
|||||||
loopDetectedRef.current = true;
|
loopDetectedRef.current = true;
|
||||||
break;
|
break;
|
||||||
case ServerGeminiEventType.Retry:
|
case ServerGeminiEventType.Retry:
|
||||||
|
case ServerGeminiEventType.InvalidStream:
|
||||||
// Will add the missing logic later
|
// Will add the missing logic later
|
||||||
break;
|
break;
|
||||||
default: {
|
default: {
|
||||||
|
|||||||
@@ -619,6 +619,31 @@ describe('Server Config (config.ts)', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('ContinueOnFailedApiCall Configuration', () => {
|
||||||
|
it('should default continueOnFailedApiCall to true when not provided', () => {
|
||||||
|
const config = new Config(baseParams);
|
||||||
|
expect(config.getContinueOnFailedApiCall()).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should set continueOnFailedApiCall to true when provided as true', () => {
|
||||||
|
const paramsWithContinueOnFailedApiCall: ConfigParameters = {
|
||||||
|
...baseParams,
|
||||||
|
continueOnFailedApiCall: true,
|
||||||
|
};
|
||||||
|
const config = new Config(paramsWithContinueOnFailedApiCall);
|
||||||
|
expect(config.getContinueOnFailedApiCall()).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should set continueOnFailedApiCall to false when explicitly provided as false', () => {
|
||||||
|
const paramsWithContinueOnFailedApiCall: ConfigParameters = {
|
||||||
|
...baseParams,
|
||||||
|
continueOnFailedApiCall: false,
|
||||||
|
};
|
||||||
|
const config = new Config(paramsWithContinueOnFailedApiCall);
|
||||||
|
expect(config.getContinueOnFailedApiCall()).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('createToolRegistry', () => {
|
describe('createToolRegistry', () => {
|
||||||
it('should register a tool if coreTools contains an argument-specific pattern', async () => {
|
it('should register a tool if coreTools contains an argument-specific pattern', async () => {
|
||||||
const params: ConfigParameters = {
|
const params: ConfigParameters = {
|
||||||
|
|||||||
@@ -270,6 +270,7 @@ export interface ConfigParameters {
|
|||||||
useModelRouter?: boolean;
|
useModelRouter?: boolean;
|
||||||
enableMessageBusIntegration?: boolean;
|
enableMessageBusIntegration?: boolean;
|
||||||
enableSubagents?: boolean;
|
enableSubagents?: boolean;
|
||||||
|
continueOnFailedApiCall?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class Config {
|
export class Config {
|
||||||
@@ -363,6 +364,7 @@ export class Config {
|
|||||||
private readonly useModelRouter: boolean;
|
private readonly useModelRouter: boolean;
|
||||||
private readonly enableMessageBusIntegration: boolean;
|
private readonly enableMessageBusIntegration: boolean;
|
||||||
private readonly enableSubagents: boolean;
|
private readonly enableSubagents: boolean;
|
||||||
|
private readonly continueOnFailedApiCall: boolean;
|
||||||
|
|
||||||
constructor(params: ConfigParameters) {
|
constructor(params: ConfigParameters) {
|
||||||
this.sessionId = params.sessionId;
|
this.sessionId = params.sessionId;
|
||||||
@@ -453,6 +455,7 @@ export class Config {
|
|||||||
this.enableMessageBusIntegration =
|
this.enableMessageBusIntegration =
|
||||||
params.enableMessageBusIntegration ?? false;
|
params.enableMessageBusIntegration ?? false;
|
||||||
this.enableSubagents = params.enableSubagents ?? false;
|
this.enableSubagents = params.enableSubagents ?? false;
|
||||||
|
this.continueOnFailedApiCall = params.continueOnFailedApiCall ?? true;
|
||||||
this.extensionManagement = params.extensionManagement ?? true;
|
this.extensionManagement = params.extensionManagement ?? true;
|
||||||
this.storage = new Storage(this.targetDir);
|
this.storage = new Storage(this.targetDir);
|
||||||
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
|
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
|
||||||
@@ -951,6 +954,10 @@ export class Config {
|
|||||||
return this.skipNextSpeakerCheck;
|
return this.skipNextSpeakerCheck;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getContinueOnFailedApiCall(): boolean {
|
||||||
|
return this.continueOnFailedApiCall;
|
||||||
|
}
|
||||||
|
|
||||||
getShellExecutionConfig(): ShellExecutionConfig {
|
getShellExecutionConfig(): ShellExecutionConfig {
|
||||||
return this.shellExecutionConfig;
|
return this.shellExecutionConfig;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -315,6 +315,7 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
||||||
getUseSmartEdit: vi.fn().mockReturnValue(false),
|
getUseSmartEdit: vi.fn().mockReturnValue(false),
|
||||||
getUseModelRouter: vi.fn().mockReturnValue(false),
|
getUseModelRouter: vi.fn().mockReturnValue(false),
|
||||||
|
getContinueOnFailedApiCall: vi.fn(),
|
||||||
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
|
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
|
||||||
storage: {
|
storage: {
|
||||||
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
|
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
|
||||||
@@ -1288,6 +1289,9 @@ ${JSON.stringify(
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should stop infinite loop after MAX_TURNS when nextSpeaker always returns model', async () => {
|
it('should stop infinite loop after MAX_TURNS when nextSpeaker always returns model', async () => {
|
||||||
|
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
|
||||||
|
true,
|
||||||
|
);
|
||||||
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
|
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
|
||||||
const { checkNextSpeaker } = await import(
|
const { checkNextSpeaker } = await import(
|
||||||
'../utils/nextSpeakerChecker.js'
|
'../utils/nextSpeakerChecker.js'
|
||||||
@@ -1784,6 +1788,131 @@ ${JSON.stringify(
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should recursively call sendMessageStream with "Please continue." when InvalidStream event is received', async () => {
|
||||||
|
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
|
||||||
|
true,
|
||||||
|
);
|
||||||
|
// Arrange
|
||||||
|
const mockStream1 = (async function* () {
|
||||||
|
yield { type: GeminiEventType.InvalidStream };
|
||||||
|
})();
|
||||||
|
const mockStream2 = (async function* () {
|
||||||
|
yield { type: GeminiEventType.Content, value: 'Continued content' };
|
||||||
|
})();
|
||||||
|
|
||||||
|
mockTurnRunFn
|
||||||
|
.mockReturnValueOnce(mockStream1)
|
||||||
|
.mockReturnValueOnce(mockStream2);
|
||||||
|
|
||||||
|
const mockChat: Partial<GeminiChat> = {
|
||||||
|
addHistory: vi.fn(),
|
||||||
|
getHistory: vi.fn().mockReturnValue([]),
|
||||||
|
};
|
||||||
|
client['chat'] = mockChat as GeminiChat;
|
||||||
|
|
||||||
|
const initialRequest = [{ text: 'Hi' }];
|
||||||
|
const promptId = 'prompt-id-invalid-stream';
|
||||||
|
const signal = new AbortController().signal;
|
||||||
|
|
||||||
|
// Act
|
||||||
|
const stream = client.sendMessageStream(initialRequest, signal, promptId);
|
||||||
|
const events = await fromAsync(stream);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
expect(events).toEqual([
|
||||||
|
{ type: GeminiEventType.InvalidStream },
|
||||||
|
{ type: GeminiEventType.Content, value: 'Continued content' },
|
||||||
|
]);
|
||||||
|
|
||||||
|
// Verify that turn.run was called twice
|
||||||
|
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
|
||||||
|
|
||||||
|
// First call with original request
|
||||||
|
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
|
||||||
|
1,
|
||||||
|
expect.any(String),
|
||||||
|
initialRequest,
|
||||||
|
expect.any(Object),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Second call with the injected "System: Please continue." prompt
|
||||||
|
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
|
||||||
|
2,
|
||||||
|
expect.any(String),
|
||||||
|
[{ text: 'System: Please continue.' }],
|
||||||
|
expect.any(Object),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not recursively call sendMessageStream with "Please continue." when InvalidStream event is received and flag is false', async () => {
|
||||||
|
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
// Arrange
|
||||||
|
const mockStream1 = (async function* () {
|
||||||
|
yield { type: GeminiEventType.InvalidStream };
|
||||||
|
})();
|
||||||
|
|
||||||
|
mockTurnRunFn.mockReturnValueOnce(mockStream1);
|
||||||
|
|
||||||
|
const mockChat: Partial<GeminiChat> = {
|
||||||
|
addHistory: vi.fn(),
|
||||||
|
getHistory: vi.fn().mockReturnValue([]),
|
||||||
|
};
|
||||||
|
client['chat'] = mockChat as GeminiChat;
|
||||||
|
|
||||||
|
const initialRequest = [{ text: 'Hi' }];
|
||||||
|
const promptId = 'prompt-id-invalid-stream';
|
||||||
|
const signal = new AbortController().signal;
|
||||||
|
|
||||||
|
// Act
|
||||||
|
const stream = client.sendMessageStream(initialRequest, signal, promptId);
|
||||||
|
const events = await fromAsync(stream);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
expect(events).toEqual([{ type: GeminiEventType.InvalidStream }]);
|
||||||
|
|
||||||
|
// Verify that turn.run was called only once
|
||||||
|
expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should stop recursing after one retry when InvalidStream events are repeatedly received', async () => {
|
||||||
|
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
|
||||||
|
true,
|
||||||
|
);
|
||||||
|
// Arrange
|
||||||
|
// Always return a new invalid stream
|
||||||
|
mockTurnRunFn.mockImplementation(() =>
|
||||||
|
(async function* () {
|
||||||
|
yield { type: GeminiEventType.InvalidStream };
|
||||||
|
})(),
|
||||||
|
);
|
||||||
|
|
||||||
|
const mockChat: Partial<GeminiChat> = {
|
||||||
|
addHistory: vi.fn(),
|
||||||
|
getHistory: vi.fn().mockReturnValue([]),
|
||||||
|
};
|
||||||
|
client['chat'] = mockChat as GeminiChat;
|
||||||
|
|
||||||
|
const initialRequest = [{ text: 'Hi' }];
|
||||||
|
const promptId = 'prompt-id-infinite-invalid-stream';
|
||||||
|
const signal = new AbortController().signal;
|
||||||
|
|
||||||
|
// Act
|
||||||
|
const stream = client.sendMessageStream(initialRequest, signal, promptId);
|
||||||
|
const events = await fromAsync(stream);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
// We expect 2 InvalidStream events (original + 1 retry)
|
||||||
|
expect(events.length).toBe(2);
|
||||||
|
expect(
|
||||||
|
events.every((e) => e.type === GeminiEventType.InvalidStream),
|
||||||
|
).toBe(true);
|
||||||
|
|
||||||
|
// Verify that turn.run was called twice
|
||||||
|
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
|
||||||
|
});
|
||||||
|
|
||||||
describe('Editor context delta', () => {
|
describe('Editor context delta', () => {
|
||||||
const mockStream = (async function* () {
|
const mockStream = (async function* () {
|
||||||
yield { type: 'content', value: 'Hello' };
|
yield { type: 'content', value: 'Hello' };
|
||||||
|
|||||||
@@ -40,9 +40,11 @@ import { LoopDetectionService } from '../services/loopDetectionService.js';
|
|||||||
import { ideContextStore } from '../ide/ideContext.js';
|
import { ideContextStore } from '../ide/ideContext.js';
|
||||||
import {
|
import {
|
||||||
logChatCompression,
|
logChatCompression,
|
||||||
|
logContentRetryFailure,
|
||||||
logNextSpeakerCheck,
|
logNextSpeakerCheck,
|
||||||
} from '../telemetry/loggers.js';
|
} from '../telemetry/loggers.js';
|
||||||
import {
|
import {
|
||||||
|
ContentRetryFailureEvent,
|
||||||
makeChatCompressionEvent,
|
makeChatCompressionEvent,
|
||||||
NextSpeakerCheckEvent,
|
NextSpeakerCheckEvent,
|
||||||
} from '../telemetry/types.js';
|
} from '../telemetry/types.js';
|
||||||
@@ -476,6 +478,7 @@ My setup is complete. I will provide my first command in the next turn.
|
|||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
prompt_id: string,
|
prompt_id: string,
|
||||||
turns: number = MAX_TURNS,
|
turns: number = MAX_TURNS,
|
||||||
|
isInvalidStreamRetry: boolean = false,
|
||||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||||
if (this.lastPromptId !== prompt_id) {
|
if (this.lastPromptId !== prompt_id) {
|
||||||
this.loopDetector.reset(prompt_id);
|
this.loopDetector.reset(prompt_id);
|
||||||
@@ -586,6 +589,31 @@ My setup is complete. I will provide my first command in the next turn.
|
|||||||
return turn;
|
return turn;
|
||||||
}
|
}
|
||||||
yield event;
|
yield event;
|
||||||
|
if (event.type === GeminiEventType.InvalidStream) {
|
||||||
|
if (this.config.getContinueOnFailedApiCall()) {
|
||||||
|
if (isInvalidStreamRetry) {
|
||||||
|
// We already retried once, so stop here.
|
||||||
|
logContentRetryFailure(
|
||||||
|
this.config,
|
||||||
|
new ContentRetryFailureEvent(
|
||||||
|
4, // 2 initial + 2 after injections
|
||||||
|
'FAILED_AFTER_PROMPT_INJECTION',
|
||||||
|
modelToUse,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
return turn;
|
||||||
|
}
|
||||||
|
const nextRequest = [{ text: 'System: Please continue.' }];
|
||||||
|
yield* this.sendMessageStream(
|
||||||
|
nextRequest,
|
||||||
|
signal,
|
||||||
|
prompt_id,
|
||||||
|
boundedTurns - 1,
|
||||||
|
true, // Set isInvalidStreamRetry to true
|
||||||
|
);
|
||||||
|
return turn;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (event.type === GeminiEventType.Error) {
|
if (event.type === GeminiEventType.Error) {
|
||||||
return turn;
|
return turn;
|
||||||
}
|
}
|
||||||
@@ -623,6 +651,7 @@ My setup is complete. I will provide my first command in the next turn.
|
|||||||
signal,
|
signal,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
boundedTurns - 1,
|
boundedTurns - 1,
|
||||||
|
// isInvalidStreamRetry is false here, as this is a next speaker check
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import { Turn, GeminiEventType } from './turn.js';
|
|||||||
import type { GenerateContentResponse, Part, Content } from '@google/genai';
|
import type { GenerateContentResponse, Part, Content } from '@google/genai';
|
||||||
import { reportError } from '../utils/errorReporting.js';
|
import { reportError } from '../utils/errorReporting.js';
|
||||||
import type { GeminiChat } from './geminiChat.js';
|
import type { GeminiChat } from './geminiChat.js';
|
||||||
import { StreamEventType } from './geminiChat.js';
|
import { InvalidStreamError, StreamEventType } from './geminiChat.js';
|
||||||
|
|
||||||
const mockSendMessageStream = vi.fn();
|
const mockSendMessageStream = vi.fn();
|
||||||
const mockGetHistory = vi.fn();
|
const mockGetHistory = vi.fn();
|
||||||
@@ -223,6 +223,28 @@ describe('Turn', () => {
|
|||||||
expect(turn.getDebugResponses().length).toBe(1);
|
expect(turn.getDebugResponses().length).toBe(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should yield InvalidStream event if sendMessageStream throws InvalidStreamError', async () => {
|
||||||
|
const error = new InvalidStreamError(
|
||||||
|
'Test invalid stream',
|
||||||
|
'NO_FINISH_REASON',
|
||||||
|
);
|
||||||
|
mockSendMessageStream.mockRejectedValue(error);
|
||||||
|
const reqParts: Part[] = [{ text: 'Trigger invalid stream' }];
|
||||||
|
|
||||||
|
const events = [];
|
||||||
|
for await (const event of turn.run(
|
||||||
|
'test-model',
|
||||||
|
reqParts,
|
||||||
|
new AbortController().signal,
|
||||||
|
)) {
|
||||||
|
events.push(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(events).toEqual([{ type: GeminiEventType.InvalidStream }]);
|
||||||
|
expect(turn.getDebugResponses().length).toBe(0);
|
||||||
|
expect(reportError).not.toHaveBeenCalled(); // Should not report as error
|
||||||
|
});
|
||||||
|
|
||||||
it('should yield Error event and report if sendMessageStream throws', async () => {
|
it('should yield Error event and report if sendMessageStream throws', async () => {
|
||||||
const error = new Error('API Error');
|
const error = new Error('API Error');
|
||||||
mockSendMessageStream.mockRejectedValue(error);
|
mockSendMessageStream.mockRejectedValue(error);
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import {
|
|||||||
toFriendlyError,
|
toFriendlyError,
|
||||||
} from '../utils/errors.js';
|
} from '../utils/errors.js';
|
||||||
import type { GeminiChat } from './geminiChat.js';
|
import type { GeminiChat } from './geminiChat.js';
|
||||||
|
import { InvalidStreamError } from './geminiChat.js';
|
||||||
import { parseThought, type ThoughtSummary } from '../utils/thoughtUtils.js';
|
import { parseThought, type ThoughtSummary } from '../utils/thoughtUtils.js';
|
||||||
import { createUserContent } from '@google/genai';
|
import { createUserContent } from '@google/genai';
|
||||||
|
|
||||||
@@ -60,6 +61,7 @@ export enum GeminiEventType {
|
|||||||
Citation = 'citation',
|
Citation = 'citation',
|
||||||
Retry = 'retry',
|
Retry = 'retry',
|
||||||
ContextWindowWillOverflow = 'context_window_will_overflow',
|
ContextWindowWillOverflow = 'context_window_will_overflow',
|
||||||
|
InvalidStream = 'invalid_stream',
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ServerGeminiRetryEvent = {
|
export type ServerGeminiRetryEvent = {
|
||||||
@@ -74,6 +76,10 @@ export type ServerGeminiContextWindowWillOverflowEvent = {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ServerGeminiInvalidStreamEvent = {
|
||||||
|
type: GeminiEventType.InvalidStream;
|
||||||
|
};
|
||||||
|
|
||||||
export interface StructuredError {
|
export interface StructuredError {
|
||||||
message: string;
|
message: string;
|
||||||
status?: number;
|
status?: number;
|
||||||
@@ -203,7 +209,8 @@ export type ServerGeminiStreamEvent =
|
|||||||
| ServerGeminiToolCallResponseEvent
|
| ServerGeminiToolCallResponseEvent
|
||||||
| ServerGeminiUserCancelledEvent
|
| ServerGeminiUserCancelledEvent
|
||||||
| ServerGeminiRetryEvent
|
| ServerGeminiRetryEvent
|
||||||
| ServerGeminiContextWindowWillOverflowEvent;
|
| ServerGeminiContextWindowWillOverflowEvent
|
||||||
|
| ServerGeminiInvalidStreamEvent;
|
||||||
|
|
||||||
// A turn manages the agentic loop turn within the server context.
|
// A turn manages the agentic loop turn within the server context.
|
||||||
export class Turn {
|
export class Turn {
|
||||||
@@ -312,6 +319,11 @@ export class Turn {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (e instanceof InvalidStreamError) {
|
||||||
|
yield { type: GeminiEventType.InvalidStream };
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const error = toFriendlyError(e);
|
const error = toFriendlyError(e);
|
||||||
if (error instanceof UnauthorizedError) {
|
if (error instanceof UnauthorizedError) {
|
||||||
throw error;
|
throw error;
|
||||||
|
|||||||
Reference in New Issue
Block a user