mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 05:12:55 -07:00
fix(core): remove "System: Please continue." injection on InvalidStream events (#26340)
This commit is contained in:
@@ -263,7 +263,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-1',
|
'prompt-id-1',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'Test input',
|
'Test input',
|
||||||
);
|
);
|
||||||
expect(getWrittenOutput()).toBe('Hello World\n');
|
expect(getWrittenOutput()).toBe('Hello World\n');
|
||||||
@@ -382,7 +381,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-2',
|
'prompt-id-2',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
undefined,
|
undefined,
|
||||||
);
|
);
|
||||||
expect(getWrittenOutput()).toBe('Final answer\n');
|
expect(getWrittenOutput()).toBe('Final answer\n');
|
||||||
@@ -542,7 +540,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-3',
|
'prompt-id-3',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
undefined,
|
undefined,
|
||||||
);
|
);
|
||||||
expect(getWrittenOutput()).toBe('Sorry, let me try again.\n');
|
expect(getWrittenOutput()).toBe('Sorry, let me try again.\n');
|
||||||
@@ -684,7 +681,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-7',
|
'prompt-id-7',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
rawInput,
|
rawInput,
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -720,7 +716,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-1',
|
'prompt-id-1',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'Test input',
|
'Test input',
|
||||||
);
|
);
|
||||||
expect(processStdoutSpy).toHaveBeenCalledWith(
|
expect(processStdoutSpy).toHaveBeenCalledWith(
|
||||||
@@ -853,7 +848,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-empty',
|
'prompt-id-empty',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'Empty response test',
|
'Empty response test',
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -990,7 +984,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-slash',
|
'prompt-id-slash',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'/testcommand',
|
'/testcommand',
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -1036,7 +1029,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-slash',
|
'prompt-id-slash',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'/help',
|
'/help',
|
||||||
);
|
);
|
||||||
expect(getWrittenOutput()).toBe('Response to slash command\n');
|
expect(getWrittenOutput()).toBe('Response to slash command\n');
|
||||||
@@ -1214,7 +1206,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-unknown',
|
'prompt-id-unknown',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'/unknowncommand',
|
'/unknowncommand',
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -319,7 +319,6 @@ export async function runNonInteractive(
|
|||||||
abortController.signal,
|
abortController.signal,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
turnCount === 1 ? input : undefined,
|
turnCount === 1 ? input : undefined,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -269,7 +269,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-1',
|
'prompt-id-1',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'Test input',
|
'Test input',
|
||||||
);
|
);
|
||||||
expect(getWrittenOutput()).toBe('Hello World\n');
|
expect(getWrittenOutput()).toBe('Hello World\n');
|
||||||
@@ -436,7 +435,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-2',
|
'prompt-id-2',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
undefined,
|
undefined,
|
||||||
);
|
);
|
||||||
expect(getWrittenOutput()).toBe('Final answer\n');
|
expect(getWrittenOutput()).toBe('Final answer\n');
|
||||||
@@ -596,7 +594,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-3',
|
'prompt-id-3',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
undefined,
|
undefined,
|
||||||
);
|
);
|
||||||
expect(getWrittenOutput()).toBe('Sorry, let me try again.\n');
|
expect(getWrittenOutput()).toBe('Sorry, let me try again.\n');
|
||||||
@@ -738,7 +735,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-7',
|
'prompt-id-7',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
rawInput,
|
rawInput,
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -774,7 +770,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-1',
|
'prompt-id-1',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'Test input',
|
'Test input',
|
||||||
);
|
);
|
||||||
expect(processStdoutSpy).toHaveBeenCalledWith(
|
expect(processStdoutSpy).toHaveBeenCalledWith(
|
||||||
@@ -980,7 +975,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-empty',
|
'prompt-id-empty',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'Empty response test',
|
'Empty response test',
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -1117,7 +1111,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-slash',
|
'prompt-id-slash',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'/testcommand',
|
'/testcommand',
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -1163,7 +1156,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-slash',
|
'prompt-id-slash',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'/help',
|
'/help',
|
||||||
);
|
);
|
||||||
expect(getWrittenOutput()).toBe('Response to slash command\n');
|
expect(getWrittenOutput()).toBe('Response to slash command\n');
|
||||||
@@ -1383,7 +1375,6 @@ describe('runNonInteractive', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-unknown',
|
'prompt-id-unknown',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'/unknowncommand',
|
'/unknowncommand',
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -135,7 +135,6 @@ export const createMockConfig = (overrides: Partial<Config> = {}): Config =>
|
|||||||
getUseRipgrep: vi.fn().mockReturnValue(false),
|
getUseRipgrep: vi.fn().mockReturnValue(false),
|
||||||
getEnableInteractiveShell: vi.fn().mockReturnValue(false),
|
getEnableInteractiveShell: vi.fn().mockReturnValue(false),
|
||||||
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
||||||
getContinueOnFailedApiCall: vi.fn().mockReturnValue(false),
|
|
||||||
getRetryFetchErrors: vi.fn().mockReturnValue(true),
|
getRetryFetchErrors: vi.fn().mockReturnValue(true),
|
||||||
getEnableShellOutputEfficiency: vi.fn().mockReturnValue(true),
|
getEnableShellOutputEfficiency: vi.fn().mockReturnValue(true),
|
||||||
getShellToolInactivityTimeout: vi.fn().mockReturnValue(300000),
|
getShellToolInactivityTimeout: vi.fn().mockReturnValue(300000),
|
||||||
|
|||||||
@@ -805,7 +805,6 @@ describe('useGeminiStream', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-2',
|
'prompt-id-2',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
expectedMergedResponse,
|
expectedMergedResponse,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -1532,7 +1531,6 @@ describe('useGeminiStream', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'prompt-id-4',
|
'prompt-id-4',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
toolCallResponseParts,
|
toolCallResponseParts,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -2027,7 +2025,6 @@ describe('useGeminiStream', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
expect.any(String),
|
expect.any(String),
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'/my-custom-command',
|
'/my-custom-command',
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -2056,7 +2053,6 @@ describe('useGeminiStream', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
expect.any(String),
|
expect.any(String),
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'/emptycmd',
|
'/emptycmd',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -2077,7 +2073,6 @@ describe('useGeminiStream', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
expect.any(String),
|
expect.any(String),
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'// This is a line comment',
|
'// This is a line comment',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -2098,7 +2093,6 @@ describe('useGeminiStream', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
expect.any(String),
|
expect.any(String),
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'/* This is a block comment */',
|
'/* This is a block comment */',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -3058,7 +3052,6 @@ describe('useGeminiStream', () => {
|
|||||||
expect.any(AbortSignal), // Argument 2: An AbortSignal
|
expect.any(AbortSignal), // Argument 2: An AbortSignal
|
||||||
expect.any(String), // Argument 3: The prompt_id string
|
expect.any(String), // Argument 3: The prompt_id string
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
rawQuery,
|
rawQuery,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -3709,7 +3702,6 @@ describe('useGeminiStream', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
expect.any(String),
|
expect.any(String),
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'test query',
|
'test query',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -3859,7 +3851,6 @@ describe('useGeminiStream', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
expect.any(String),
|
expect.any(String),
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'second query',
|
'second query',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -4004,7 +3995,6 @@ describe('useGeminiStream', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
expect.any(String),
|
expect.any(String),
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'test query',
|
'test query',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1670,7 +1670,6 @@ export const useGeminiStream = (
|
|||||||
abortSignal,
|
abortSignal,
|
||||||
prompt_id!,
|
prompt_id!,
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
query,
|
query,
|
||||||
);
|
);
|
||||||
const processingStatus = await processGeminiStreamEvents(
|
const processingStatus = await processGeminiStreamEvents(
|
||||||
|
|||||||
@@ -200,7 +200,6 @@ describe('LegacyAgentSession', () => {
|
|||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
'test-prompt',
|
'test-prompt',
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
'raw input',
|
'raw input',
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -196,7 +196,6 @@ export class LegacyAgentProtocol implements AgentProtocol {
|
|||||||
this._abortController.signal,
|
this._abortController.signal,
|
||||||
this._promptId,
|
this._promptId,
|
||||||
undefined,
|
undefined,
|
||||||
false,
|
|
||||||
currentDisplayContent,
|
currentDisplayContent,
|
||||||
);
|
);
|
||||||
currentDisplayContent = undefined;
|
currentDisplayContent = undefined;
|
||||||
|
|||||||
@@ -1437,31 +1437,6 @@ describe('Server Config (config.ts)', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('ContinueOnFailedApiCall Configuration', () => {
|
|
||||||
it('should default continueOnFailedApiCall to false when not provided', () => {
|
|
||||||
const config = new Config(baseParams);
|
|
||||||
expect(config.getContinueOnFailedApiCall()).toBe(true);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should set continueOnFailedApiCall to true when provided as true', () => {
|
|
||||||
const paramsWithContinueOnFailedApiCall: ConfigParameters = {
|
|
||||||
...baseParams,
|
|
||||||
continueOnFailedApiCall: true,
|
|
||||||
};
|
|
||||||
const config = new Config(paramsWithContinueOnFailedApiCall);
|
|
||||||
expect(config.getContinueOnFailedApiCall()).toBe(true);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should set continueOnFailedApiCall to false when explicitly provided as false', () => {
|
|
||||||
const paramsWithContinueOnFailedApiCall: ConfigParameters = {
|
|
||||||
...baseParams,
|
|
||||||
continueOnFailedApiCall: false,
|
|
||||||
};
|
|
||||||
const config = new Config(paramsWithContinueOnFailedApiCall);
|
|
||||||
expect(config.getContinueOnFailedApiCall()).toBe(false);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('createToolRegistry', () => {
|
describe('createToolRegistry', () => {
|
||||||
it('should register a tool if coreTools contains an argument-specific pattern', async () => {
|
it('should register a tool if coreTools contains an argument-specific pattern', async () => {
|
||||||
const params: ConfigParameters = {
|
const params: ConfigParameters = {
|
||||||
|
|||||||
@@ -681,7 +681,6 @@ export interface ConfigParameters {
|
|||||||
gemmaModelRouter?: GemmaModelRouterSettings;
|
gemmaModelRouter?: GemmaModelRouterSettings;
|
||||||
adk?: ADKSettings;
|
adk?: ADKSettings;
|
||||||
disableModelRouterForAuth?: AuthType[];
|
disableModelRouterForAuth?: AuthType[];
|
||||||
continueOnFailedApiCall?: boolean;
|
|
||||||
retryFetchErrors?: boolean;
|
retryFetchErrors?: boolean;
|
||||||
maxAttempts?: number;
|
maxAttempts?: number;
|
||||||
enableShellOutputEfficiency?: boolean;
|
enableShellOutputEfficiency?: boolean;
|
||||||
@@ -911,7 +910,6 @@ export class Config implements McpContext, AgentLoopContext {
|
|||||||
private readonly agentSessionNoninteractiveEnabled: boolean;
|
private readonly agentSessionNoninteractiveEnabled: boolean;
|
||||||
private readonly agentSessionInteractiveEnabled: boolean;
|
private readonly agentSessionInteractiveEnabled: boolean;
|
||||||
|
|
||||||
private readonly continueOnFailedApiCall: boolean;
|
|
||||||
private readonly retryFetchErrors: boolean;
|
private readonly retryFetchErrors: boolean;
|
||||||
private readonly maxAttempts: number;
|
private readonly maxAttempts: number;
|
||||||
private readonly enableShellOutputEfficiency: boolean;
|
private readonly enableShellOutputEfficiency: boolean;
|
||||||
@@ -1288,7 +1286,6 @@ export class Config implements McpContext, AgentLoopContext {
|
|||||||
this.enableHooks = params.enableHooks ?? true;
|
this.enableHooks = params.enableHooks ?? true;
|
||||||
this.disabledHooks = params.disabledHooks ?? [];
|
this.disabledHooks = params.disabledHooks ?? [];
|
||||||
|
|
||||||
this.continueOnFailedApiCall = params.continueOnFailedApiCall ?? true;
|
|
||||||
this.enableShellOutputEfficiency =
|
this.enableShellOutputEfficiency =
|
||||||
params.enableShellOutputEfficiency ?? true;
|
params.enableShellOutputEfficiency ?? true;
|
||||||
this.shellToolInactivityTimeout =
|
this.shellToolInactivityTimeout =
|
||||||
@@ -3449,10 +3446,6 @@ export class Config implements McpContext, AgentLoopContext {
|
|||||||
return this.skipNextSpeakerCheck;
|
return this.skipNextSpeakerCheck;
|
||||||
}
|
}
|
||||||
|
|
||||||
getContinueOnFailedApiCall(): boolean {
|
|
||||||
return this.continueOnFailedApiCall;
|
|
||||||
}
|
|
||||||
|
|
||||||
getRetryFetchErrors(): boolean {
|
getRetryFetchErrors(): boolean {
|
||||||
return this.retryFetchErrors;
|
return this.retryFetchErrors;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -259,7 +259,6 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
getCompressionThreshold: vi.fn().mockReturnValue(undefined),
|
getCompressionThreshold: vi.fn().mockReturnValue(undefined),
|
||||||
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
||||||
getShowModelInfoInChat: vi.fn().mockReturnValue(false),
|
getShowModelInfoInChat: vi.fn().mockReturnValue(false),
|
||||||
getContinueOnFailedApiCall: vi.fn(),
|
|
||||||
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
|
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
|
||||||
getIncludeDirectoryTree: vi.fn().mockReturnValue(true),
|
getIncludeDirectoryTree: vi.fn().mockReturnValue(true),
|
||||||
storage: {
|
storage: {
|
||||||
@@ -1304,9 +1303,6 @@ ${JSON.stringify(
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should stop infinite loop after MAX_TURNS when nextSpeaker always returns model', async () => {
|
it('should stop infinite loop after MAX_TURNS when nextSpeaker always returns model', async () => {
|
||||||
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
|
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
|
||||||
const { checkNextSpeaker } = await import(
|
const { checkNextSpeaker } = await import(
|
||||||
'../utils/nextSpeakerChecker.js'
|
'../utils/nextSpeakerChecker.js'
|
||||||
@@ -2059,26 +2055,13 @@ ${JSON.stringify(
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should recursively call sendMessageStream with "Please continue." when InvalidStream event is received for Gemini 2 models', async () => {
|
it('should propagate InvalidStream events without injecting "Please continue." or recursing', async () => {
|
||||||
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
|
// Arrange: a single turn that yields an InvalidStream event.
|
||||||
true,
|
const mockStream = (async function* () {
|
||||||
);
|
|
||||||
// Arrange - router must return a Gemini 2 model for retry to trigger
|
|
||||||
mockRouterService.route.mockResolvedValue({
|
|
||||||
model: 'gemini-2.0-flash',
|
|
||||||
reason: 'test',
|
|
||||||
});
|
|
||||||
|
|
||||||
const mockStream1 = (async function* () {
|
|
||||||
yield { type: GeminiEventType.InvalidStream };
|
yield { type: GeminiEventType.InvalidStream };
|
||||||
})();
|
})();
|
||||||
const mockStream2 = (async function* () {
|
|
||||||
yield { type: GeminiEventType.Content, value: 'Continued content' };
|
|
||||||
})();
|
|
||||||
|
|
||||||
mockTurnRunFn
|
mockTurnRunFn.mockReturnValueOnce(mockStream);
|
||||||
.mockReturnValueOnce(mockStream1)
|
|
||||||
.mockReturnValueOnce(mockStream2);
|
|
||||||
|
|
||||||
const mockChat: Partial<GeminiChat> = {
|
const mockChat: Partial<GeminiChat> = {
|
||||||
addHistory: vi.fn(),
|
addHistory: vi.fn(),
|
||||||
@@ -2096,117 +2079,16 @@ ${JSON.stringify(
|
|||||||
const stream = client.sendMessageStream(initialRequest, signal, promptId);
|
const stream = client.sendMessageStream(initialRequest, signal, promptId);
|
||||||
const events = await fromAsync(stream);
|
const events = await fromAsync(stream);
|
||||||
|
|
||||||
// Assert
|
// Assert: the InvalidStream event is forwarded to the consumer and the
|
||||||
expect(events).toEqual([
|
// turn ends. No "System: Please continue." is injected and turn.run is
|
||||||
{ type: GeminiEventType.ModelInfo, value: 'gemini-2.0-flash' },
|
// not called a second time.
|
||||||
{ type: GeminiEventType.InvalidStream },
|
|
||||||
{ type: GeminiEventType.Content, value: 'Continued content' },
|
|
||||||
]);
|
|
||||||
|
|
||||||
// Verify that turn.run was called twice
|
|
||||||
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
|
|
||||||
|
|
||||||
// First call with original request
|
|
||||||
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
|
|
||||||
1,
|
|
||||||
{ model: 'gemini-2.0-flash', isChatModel: true },
|
|
||||||
initialRequest,
|
|
||||||
expect.any(AbortSignal),
|
|
||||||
undefined,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Second call with "Please continue."
|
|
||||||
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
|
|
||||||
2,
|
|
||||||
{ model: 'gemini-2.0-flash', isChatModel: true },
|
|
||||||
[{ text: 'System: Please continue.' }],
|
|
||||||
expect.any(AbortSignal),
|
|
||||||
undefined,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should not recursively call sendMessageStream with "Please continue." when InvalidStream event is received and flag is false', async () => {
|
|
||||||
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
// Arrange
|
|
||||||
const mockStream1 = (async function* () {
|
|
||||||
yield { type: GeminiEventType.InvalidStream };
|
|
||||||
})();
|
|
||||||
|
|
||||||
mockTurnRunFn.mockReturnValueOnce(mockStream1);
|
|
||||||
|
|
||||||
const mockChat: Partial<GeminiChat> = {
|
|
||||||
addHistory: vi.fn(),
|
|
||||||
setTools: vi.fn(),
|
|
||||||
getHistory: vi.fn().mockReturnValue([]),
|
|
||||||
getLastPromptTokenCount: vi.fn(),
|
|
||||||
};
|
|
||||||
client['chat'] = mockChat as GeminiChat;
|
|
||||||
|
|
||||||
const initialRequest = [{ text: 'Hi' }];
|
|
||||||
const promptId = 'prompt-id-invalid-stream';
|
|
||||||
const signal = new AbortController().signal;
|
|
||||||
|
|
||||||
// Act
|
|
||||||
const stream = client.sendMessageStream(initialRequest, signal, promptId);
|
|
||||||
const events = await fromAsync(stream);
|
|
||||||
|
|
||||||
// Assert
|
|
||||||
expect(events).toEqual([
|
expect(events).toEqual([
|
||||||
{ type: GeminiEventType.ModelInfo, value: 'default-routed-model' },
|
{ type: GeminiEventType.ModelInfo, value: 'default-routed-model' },
|
||||||
{ type: GeminiEventType.InvalidStream },
|
{ type: GeminiEventType.InvalidStream },
|
||||||
]);
|
]);
|
||||||
|
|
||||||
// Verify that turn.run was called only once
|
|
||||||
expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
|
expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should stop recursing after one retry when InvalidStream events are repeatedly received', async () => {
|
|
||||||
vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
// Arrange - router must return a Gemini 2 model for retry to trigger
|
|
||||||
mockRouterService.route.mockResolvedValue({
|
|
||||||
model: 'gemini-2.0-flash',
|
|
||||||
reason: 'test',
|
|
||||||
});
|
|
||||||
// Always return a new invalid stream
|
|
||||||
mockTurnRunFn.mockImplementation(() =>
|
|
||||||
(async function* () {
|
|
||||||
yield { type: GeminiEventType.InvalidStream };
|
|
||||||
})(),
|
|
||||||
);
|
|
||||||
|
|
||||||
const mockChat: Partial<GeminiChat> = {
|
|
||||||
addHistory: vi.fn(),
|
|
||||||
setTools: vi.fn(),
|
|
||||||
getHistory: vi.fn().mockReturnValue([]),
|
|
||||||
getLastPromptTokenCount: vi.fn(),
|
|
||||||
};
|
|
||||||
client['chat'] = mockChat as GeminiChat;
|
|
||||||
|
|
||||||
const initialRequest = [{ text: 'Hi' }];
|
|
||||||
const promptId = 'prompt-id-infinite-invalid-stream';
|
|
||||||
const signal = new AbortController().signal;
|
|
||||||
|
|
||||||
// Act
|
|
||||||
const stream = client.sendMessageStream(initialRequest, signal, promptId);
|
|
||||||
const events = await fromAsync(stream);
|
|
||||||
|
|
||||||
// Assert
|
|
||||||
// We expect 3 events (model_info + original + 1 retry)
|
|
||||||
expect(events.length).toBe(3);
|
|
||||||
expect(
|
|
||||||
events
|
|
||||||
.filter((e) => e.type === GeminiEventType.ModelInfo)
|
|
||||||
.map((e) => e.value),
|
|
||||||
).toEqual(['gemini-2.0-flash']);
|
|
||||||
|
|
||||||
// Verify that turn.run was called twice
|
|
||||||
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('Editor context delta', () => {
|
describe('Editor context delta', () => {
|
||||||
const mockStream = (async function* () {
|
const mockStream = (async function* () {
|
||||||
yield { type: 'content', value: 'Hello' };
|
yield { type: 'content', value: 'Hello' };
|
||||||
@@ -2584,42 +2466,6 @@ ${JSON.stringify(
|
|||||||
|
|
||||||
expect(mockConfig.resetTurn).toHaveBeenCalled();
|
expect(mockConfig.resetTurn).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should NOT reset turn on invalid stream retry', async () => {
|
|
||||||
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue(
|
|
||||||
{
|
|
||||||
selectedModel: 'model-a',
|
|
||||||
skipped: [],
|
|
||||||
},
|
|
||||||
);
|
|
||||||
// We simulate a retry by calling sendMessageStream with isInvalidStreamRetry=true
|
|
||||||
// But the public API doesn't expose that argument directly unless we use the private method or simulate the recursion.
|
|
||||||
// We can simulate recursion by mocking turn run to return invalid stream once.
|
|
||||||
|
|
||||||
vi.spyOn(
|
|
||||||
client['config'],
|
|
||||||
'getContinueOnFailedApiCall',
|
|
||||||
).mockReturnValue(true);
|
|
||||||
const mockStream1 = (async function* () {
|
|
||||||
yield { type: GeminiEventType.InvalidStream };
|
|
||||||
})();
|
|
||||||
const mockStream2 = (async function* () {
|
|
||||||
yield { type: 'content', value: 'ok' };
|
|
||||||
})();
|
|
||||||
mockTurnRunFn
|
|
||||||
.mockReturnValueOnce(mockStream1)
|
|
||||||
.mockReturnValueOnce(mockStream2);
|
|
||||||
|
|
||||||
const stream = client.sendMessageStream(
|
|
||||||
[{ text: 'Hi' }],
|
|
||||||
new AbortController().signal,
|
|
||||||
'prompt-retry',
|
|
||||||
);
|
|
||||||
await fromAsync(stream);
|
|
||||||
|
|
||||||
// resetTurn should be called once (for the initial call) but NOT for the recursive call
|
|
||||||
expect(mockConfig.resetTurn).toHaveBeenCalledTimes(1);
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('IDE context with pending tool calls', () => {
|
describe('IDE context with pending tool calls', () => {
|
||||||
|
|||||||
@@ -47,19 +47,12 @@ import { ChatCompressionService } from '../context/chatCompressionService.js';
|
|||||||
import { AgentHistoryProvider } from '../context/agentHistoryProvider.js';
|
import { AgentHistoryProvider } from '../context/agentHistoryProvider.js';
|
||||||
import type { ContextManager } from '../context/contextManager.js';
|
import type { ContextManager } from '../context/contextManager.js';
|
||||||
import { ideContextStore } from '../ide/ideContext.js';
|
import { ideContextStore } from '../ide/ideContext.js';
|
||||||
import {
|
import { logNextSpeakerCheck } from '../telemetry/loggers.js';
|
||||||
logContentRetryFailure,
|
|
||||||
logNextSpeakerCheck,
|
|
||||||
} from '../telemetry/loggers.js';
|
|
||||||
import type {
|
import type {
|
||||||
DefaultHookOutput,
|
DefaultHookOutput,
|
||||||
AfterAgentHookOutput,
|
AfterAgentHookOutput,
|
||||||
} from '../hooks/types.js';
|
} from '../hooks/types.js';
|
||||||
import {
|
import { NextSpeakerCheckEvent, type LlmRole } from '../telemetry/types.js';
|
||||||
ContentRetryFailureEvent,
|
|
||||||
NextSpeakerCheckEvent,
|
|
||||||
type LlmRole,
|
|
||||||
} from '../telemetry/types.js';
|
|
||||||
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
||||||
import type { IdeContext, File } from '../ide/types.js';
|
import type { IdeContext, File } from '../ide/types.js';
|
||||||
import { handleFallback } from '../fallback/handler.js';
|
import { handleFallback } from '../fallback/handler.js';
|
||||||
@@ -603,7 +596,6 @@ export class GeminiClient {
|
|||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
prompt_id: string,
|
prompt_id: string,
|
||||||
boundedTurns: number,
|
boundedTurns: number,
|
||||||
isInvalidStreamRetry: boolean,
|
|
||||||
displayContent?: PartListUnion,
|
displayContent?: PartListUnion,
|
||||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||||
// Re-initialize turn (it was empty before if in loop, or new instance)
|
// Re-initialize turn (it was empty before if in loop, or new instance)
|
||||||
@@ -708,7 +700,6 @@ export class GeminiClient {
|
|||||||
signal,
|
signal,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
boundedTurns,
|
boundedTurns,
|
||||||
isInvalidStreamRetry,
|
|
||||||
displayContent,
|
displayContent,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -758,7 +749,6 @@ export class GeminiClient {
|
|||||||
displayContent,
|
displayContent,
|
||||||
);
|
);
|
||||||
let isError = false;
|
let isError = false;
|
||||||
let isInvalidStream = false;
|
|
||||||
|
|
||||||
let loopDetectedAbort = false;
|
let loopDetectedAbort = false;
|
||||||
let loopRecoverResult: { detail?: string } | undefined;
|
let loopRecoverResult: { detail?: string } | undefined;
|
||||||
@@ -781,9 +771,6 @@ export class GeminiClient {
|
|||||||
|
|
||||||
this.updateTelemetryTokenCount();
|
this.updateTelemetryTokenCount();
|
||||||
|
|
||||||
if (event.type === GeminiEventType.InvalidStream) {
|
|
||||||
isInvalidStream = true;
|
|
||||||
}
|
|
||||||
if (event.type === GeminiEventType.Error) {
|
if (event.type === GeminiEventType.Error) {
|
||||||
isError = true;
|
isError = true;
|
||||||
}
|
}
|
||||||
@@ -799,7 +786,6 @@ export class GeminiClient {
|
|||||||
signal,
|
signal,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
boundedTurns,
|
boundedTurns,
|
||||||
isInvalidStreamRetry,
|
|
||||||
displayContent,
|
displayContent,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -821,33 +807,6 @@ export class GeminiClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isInvalidStream) {
|
|
||||||
if (this.config.getContinueOnFailedApiCall()) {
|
|
||||||
if (isInvalidStreamRetry) {
|
|
||||||
logContentRetryFailure(
|
|
||||||
this.config,
|
|
||||||
new ContentRetryFailureEvent(
|
|
||||||
4,
|
|
||||||
'FAILED_AFTER_PROMPT_INJECTION',
|
|
||||||
modelToUse,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
return turn;
|
|
||||||
}
|
|
||||||
const nextRequest = [{ text: 'System: Please continue.' }];
|
|
||||||
// Recursive call - update turn with result
|
|
||||||
turn = yield* this.sendMessageStream(
|
|
||||||
nextRequest,
|
|
||||||
signal,
|
|
||||||
prompt_id,
|
|
||||||
boundedTurns - 1,
|
|
||||||
true,
|
|
||||||
displayContent,
|
|
||||||
);
|
|
||||||
return turn;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||||
if (
|
if (
|
||||||
!this.config.getQuotaErrorOccurred() &&
|
!this.config.getQuotaErrorOccurred() &&
|
||||||
@@ -874,7 +833,6 @@ export class GeminiClient {
|
|||||||
signal,
|
signal,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
boundedTurns - 1,
|
boundedTurns - 1,
|
||||||
false, // isInvalidStreamRetry is false
|
|
||||||
displayContent,
|
displayContent,
|
||||||
);
|
);
|
||||||
return turn;
|
return turn;
|
||||||
@@ -889,13 +847,10 @@ export class GeminiClient {
|
|||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
prompt_id: string,
|
prompt_id: string,
|
||||||
turns: number = MAX_TURNS,
|
turns: number = MAX_TURNS,
|
||||||
isInvalidStreamRetry: boolean = false,
|
|
||||||
displayContent?: PartListUnion,
|
displayContent?: PartListUnion,
|
||||||
stopHookActive: boolean = false,
|
stopHookActive: boolean = false,
|
||||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||||
if (!isInvalidStreamRetry) {
|
this.config.resetTurn();
|
||||||
this.config.resetTurn();
|
|
||||||
}
|
|
||||||
|
|
||||||
const hooksEnabled = this.config.getEnableHooks();
|
const hooksEnabled = this.config.getEnableHooks();
|
||||||
const messageBus = this.context.messageBus;
|
const messageBus = this.context.messageBus;
|
||||||
@@ -947,7 +902,6 @@ export class GeminiClient {
|
|||||||
signal,
|
signal,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
boundedTurns,
|
boundedTurns,
|
||||||
isInvalidStreamRetry,
|
|
||||||
displayContent,
|
displayContent,
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -1009,7 +963,6 @@ export class GeminiClient {
|
|||||||
signal,
|
signal,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
boundedTurns - 1,
|
boundedTurns - 1,
|
||||||
false,
|
|
||||||
displayContent,
|
displayContent,
|
||||||
true, // stopHookActive: signal retry to AfterAgent hooks
|
true, // stopHookActive: signal retry to AfterAgent hooks
|
||||||
);
|
);
|
||||||
@@ -1254,7 +1207,6 @@ export class GeminiClient {
|
|||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
prompt_id: string,
|
prompt_id: string,
|
||||||
boundedTurns: number,
|
boundedTurns: number,
|
||||||
isInvalidStreamRetry: boolean,
|
|
||||||
displayContent?: PartListUnion,
|
displayContent?: PartListUnion,
|
||||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||||
// Clear the detection flag so the recursive turn can proceed, but the count remains 1.
|
// Clear the detection flag so the recursive turn can proceed, but the count remains 1.
|
||||||
@@ -1276,7 +1228,6 @@ export class GeminiClient {
|
|||||||
signal,
|
signal,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
boundedTurns - 1,
|
boundedTurns - 1,
|
||||||
isInvalidStreamRetry,
|
|
||||||
displayContent,
|
displayContent,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -744,25 +744,41 @@ describe('GeminiChat', () => {
|
|||||||
).rejects.toThrow(InvalidStreamError);
|
).rejects.toThrow(InvalidStreamError);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should throw InvalidStreamError when no tool call and empty response text', async () => {
|
it('should throw InvalidStreamError without retrying when no tool call and empty response text', async () => {
|
||||||
// Setup: Stream with finish reason but empty response (only thoughts)
|
vi.mocked(mockContentGenerator.generateContentStream)
|
||||||
const streamWithEmptyResponse = (async function* () {
|
.mockImplementationOnce(async () =>
|
||||||
yield {
|
// First attempt: finish reason is present, but the stream has no
|
||||||
candidates: [
|
// non-thought text, which is NO_RESPONSE_TEXT.
|
||||||
{
|
(async function* () {
|
||||||
content: {
|
yield {
|
||||||
role: 'model',
|
candidates: [
|
||||||
parts: [{ thought: 'thinking...' }],
|
{
|
||||||
},
|
content: {
|
||||||
finishReason: 'STOP',
|
role: 'model',
|
||||||
},
|
parts: [{ thought: true, text: 'thinking...' }],
|
||||||
],
|
},
|
||||||
} as unknown as GenerateContentResponse;
|
finishReason: 'STOP',
|
||||||
})();
|
},
|
||||||
|
],
|
||||||
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
|
} as unknown as GenerateContentResponse;
|
||||||
streamWithEmptyResponse,
|
})(),
|
||||||
);
|
)
|
||||||
|
.mockImplementationOnce(async () =>
|
||||||
|
// This would succeed if NO_RESPONSE_TEXT were retried.
|
||||||
|
(async function* () {
|
||||||
|
yield {
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'valid response after retry' }],
|
||||||
|
},
|
||||||
|
finishReason: 'STOP',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
} as unknown as GenerateContentResponse;
|
||||||
|
})(),
|
||||||
|
);
|
||||||
|
|
||||||
const stream = await chat.sendMessageStream(
|
const stream = await chat.sendMessageStream(
|
||||||
{ model: 'gemini-2.0-flash' },
|
{ model: 'gemini-2.0-flash' },
|
||||||
@@ -779,6 +795,11 @@ describe('GeminiChat', () => {
|
|||||||
}
|
}
|
||||||
})(),
|
})(),
|
||||||
).rejects.toThrow(InvalidStreamError);
|
).rejects.toThrow(InvalidStreamError);
|
||||||
|
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
|
||||||
|
1,
|
||||||
|
);
|
||||||
|
expect(mockLogContentRetry).not.toHaveBeenCalled();
|
||||||
|
expect(mockLogContentRetryFailure).toHaveBeenCalledTimes(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should succeed when there is finish reason and response text', async () => {
|
it('should succeed when there is finish reason and response text', async () => {
|
||||||
|
|||||||
@@ -424,11 +424,13 @@ export class GeminiChat {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const isContentError = error instanceof InvalidStreamError;
|
const isContentError = error instanceof InvalidStreamError;
|
||||||
|
const isRetryableContentError =
|
||||||
|
isContentError && error.type !== 'NO_RESPONSE_TEXT';
|
||||||
const errorType = isContentError
|
const errorType = isContentError
|
||||||
? error.type
|
? error.type
|
||||||
: getRetryErrorType(error);
|
: getRetryErrorType(error);
|
||||||
|
|
||||||
if (isContentError || (isRetryable && !signal.aborted)) {
|
if (isRetryableContentError || (isRetryable && !signal.aborted)) {
|
||||||
// The issue requests exactly 3 retries (4 attempts) for API errors during stream iteration.
|
// The issue requests exactly 3 retries (4 attempts) for API errors during stream iteration.
|
||||||
// Regardless of the global maxAttempts (e.g. 10), we only want to retry these mid-stream API errors
|
// Regardless of the global maxAttempts (e.g. 10), we only want to retry these mid-stream API errors
|
||||||
// up to 3 times before finally throwing the error to the user.
|
// up to 3 times before finally throwing the error to the user.
|
||||||
|
|||||||
Reference in New Issue
Block a user