mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
fix(core): improve API response error handling and retry logic (#14563)
This commit is contained in:
@@ -69,9 +69,13 @@ const { mockRetryWithBackoff } = vi.hoisted(() => ({
|
||||
mockRetryWithBackoff: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/retry.js', () => ({
|
||||
retryWithBackoff: mockRetryWithBackoff,
|
||||
}));
|
||||
vi.mock('../utils/retry.js', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('../utils/retry.js')>();
|
||||
return {
|
||||
...actual,
|
||||
retryWithBackoff: mockRetryWithBackoff,
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../fallback/handler.js', () => ({
|
||||
handleFallback: mockHandleFallback,
|
||||
|
||||
@@ -19,7 +19,7 @@ import type {
|
||||
import { ThinkingLevel } from '@google/genai';
|
||||
import { toParts } from '../code_assist/converter.js';
|
||||
import { createUserContent, FinishReason } from '@google/genai';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { retryWithBackoff, isRetryableError } from '../utils/retry.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
@@ -310,6 +310,7 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
let isConnectionPhase = true;
|
||||
try {
|
||||
if (attempt > 0) {
|
||||
yield { type: StreamEventType.RETRY };
|
||||
@@ -320,13 +321,14 @@ export class GeminiChat {
|
||||
generateContentConfig.temperature = 1;
|
||||
}
|
||||
|
||||
isConnectionPhase = true;
|
||||
const stream = await this.makeApiCallAndProcessStream(
|
||||
model,
|
||||
generateContentConfig,
|
||||
requestContents,
|
||||
prompt_id,
|
||||
);
|
||||
|
||||
isConnectionPhase = false;
|
||||
for await (const chunk of stream) {
|
||||
yield { type: StreamEventType.CHUNK, value: chunk };
|
||||
}
|
||||
@@ -334,27 +336,33 @@ export class GeminiChat {
|
||||
lastError = null;
|
||||
break;
|
||||
} catch (error) {
|
||||
if (isConnectionPhase) {
|
||||
throw error;
|
||||
}
|
||||
lastError = error;
|
||||
const isContentError = error instanceof InvalidStreamError;
|
||||
const isRetryable = isRetryableError(
|
||||
error,
|
||||
this.config.getRetryFetchErrors(),
|
||||
);
|
||||
|
||||
if (isContentError && isGemini2Model(model)) {
|
||||
if (
|
||||
(isContentError && isGemini2Model(model)) ||
|
||||
(isRetryable && !signal.aborted)
|
||||
) {
|
||||
// Check if we have more attempts left.
|
||||
if (attempt < maxAttempts - 1) {
|
||||
const delayMs = INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs;
|
||||
const retryType = isContentError
|
||||
? (error as InvalidStreamError).type
|
||||
: 'NETWORK_ERROR';
|
||||
|
||||
logContentRetry(
|
||||
this.config,
|
||||
new ContentRetryEvent(
|
||||
attempt,
|
||||
(error as InvalidStreamError).type,
|
||||
INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs,
|
||||
model,
|
||||
),
|
||||
new ContentRetryEvent(attempt, retryType, delayMs, model),
|
||||
);
|
||||
await new Promise((res) =>
|
||||
setTimeout(
|
||||
res,
|
||||
INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs *
|
||||
(attempt + 1),
|
||||
),
|
||||
setTimeout(res, delayMs * (attempt + 1)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
271
packages/core/src/core/geminiChat_network_retry.test.ts
Normal file
271
packages/core/src/core/geminiChat_network_retry.test.ts
Normal file
@@ -0,0 +1,271 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import type { GenerateContentResponse } from '@google/genai';
|
||||
import { ApiError } from '@google/genai';
|
||||
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
import { GeminiChat, StreamEventType, type StreamEvent } from './geminiChat.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { setSimulate429 } from '../utils/testUtils.js';
|
||||
import { HookSystem } from '../hooks/hookSystem.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
// Mock fs module
|
||||
vi.mock('node:fs', () => ({
|
||||
default: {
|
||||
mkdirSync: vi.fn(),
|
||||
writeFileSync: vi.fn(),
|
||||
readFileSync: vi.fn(() => {
|
||||
const error = new Error('ENOENT');
|
||||
(error as NodeJS.ErrnoException).code = 'ENOENT';
|
||||
throw error;
|
||||
}),
|
||||
existsSync: vi.fn(() => false),
|
||||
},
|
||||
}));
|
||||
|
||||
const { mockRetryWithBackoff } = vi.hoisted(() => ({
|
||||
mockRetryWithBackoff: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/retry.js', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('../utils/retry.js')>();
|
||||
return {
|
||||
...actual,
|
||||
retryWithBackoff: mockRetryWithBackoff,
|
||||
};
|
||||
});
|
||||
|
||||
// Mock loggers
|
||||
const { mockLogContentRetry, mockLogContentRetryFailure } = vi.hoisted(() => ({
|
||||
mockLogContentRetry: vi.fn(),
|
||||
mockLogContentRetryFailure: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../telemetry/loggers.js', () => ({
|
||||
logContentRetry: mockLogContentRetry,
|
||||
logContentRetryFailure: mockLogContentRetryFailure,
|
||||
}));
|
||||
|
||||
describe('GeminiChat Network Retries', () => {
|
||||
let mockContentGenerator: ContentGenerator;
|
||||
let chat: GeminiChat;
|
||||
let mockConfig: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
mockContentGenerator = {
|
||||
generateContent: vi.fn(),
|
||||
generateContentStream: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
// Default mock implementation: execute the function immediately
|
||||
mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall());
|
||||
|
||||
mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getTelemetryLogPromptsEnabled: () => true,
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getPreviewFeatures: () => false,
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({
|
||||
authType: 'oauth-personal',
|
||||
model: 'test-model',
|
||||
}),
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
isInFallbackMode: vi.fn().mockReturnValue(false),
|
||||
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
|
||||
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
|
||||
storage: {
|
||||
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
|
||||
},
|
||||
getToolRegistry: vi.fn().mockReturnValue({ getTool: vi.fn() }),
|
||||
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
|
||||
getRetryFetchErrors: vi.fn().mockReturnValue(false), // Default false
|
||||
modelConfigService: {
|
||||
getResolvedConfig: vi.fn().mockImplementation((modelConfigKey) => ({
|
||||
model: modelConfigKey.model,
|
||||
generateContentConfig: { temperature: 0 },
|
||||
})),
|
||||
},
|
||||
isPreviewModelBypassMode: vi.fn().mockReturnValue(false),
|
||||
setPreviewModelBypassMode: vi.fn(),
|
||||
isPreviewModelFallbackMode: vi.fn().mockReturnValue(false),
|
||||
getEnableHooks: vi.fn().mockReturnValue(false),
|
||||
setPreviewModelFallbackMode: vi.fn(),
|
||||
} as unknown as Config;
|
||||
|
||||
const mockMessageBus = createMockMessageBus();
|
||||
mockConfig.getMessageBus = vi.fn().mockReturnValue(mockMessageBus);
|
||||
mockConfig.getHookSystem = vi
|
||||
.fn()
|
||||
.mockReturnValue(new HookSystem(mockConfig));
|
||||
|
||||
setSimulate429(false);
|
||||
chat = new GeminiChat(mockConfig);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should retry when a 503 ApiError occurs during stream iteration', async () => {
|
||||
// 1. Mock the API to yield one chunk, then throw a 503 error.
|
||||
const error503 = new ApiError({
|
||||
message: 'Service Unavailable',
|
||||
status: 503,
|
||||
});
|
||||
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockImplementationOnce(async () =>
|
||||
(async function* () {
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: 'First part' }] } }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
throw error503;
|
||||
})(),
|
||||
)
|
||||
.mockImplementationOnce(async () =>
|
||||
(async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Retry success' }] },
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})(),
|
||||
);
|
||||
|
||||
// 2. Execute sendMessageStream
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-retry-network',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
const events: StreamEvent[] = [];
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// 3. Assertions
|
||||
// Expected sequence: CHUNK('First part') -> RETRY -> CHUNK('Retry success')
|
||||
expect(events.length).toBeGreaterThanOrEqual(3);
|
||||
|
||||
const firstChunk = events.find(
|
||||
(e) =>
|
||||
e.type === StreamEventType.CHUNK &&
|
||||
e.value.candidates?.[0]?.content?.parts?.[0]?.text === 'First part',
|
||||
);
|
||||
expect(firstChunk).toBeDefined();
|
||||
|
||||
const retryEvent = events.find((e) => e.type === StreamEventType.RETRY);
|
||||
expect(retryEvent).toBeDefined();
|
||||
|
||||
const successChunk = events.find(
|
||||
(e) =>
|
||||
e.type === StreamEventType.CHUNK &&
|
||||
e.value.candidates?.[0]?.content?.parts?.[0]?.text === 'Retry success',
|
||||
);
|
||||
expect(successChunk).toBeDefined();
|
||||
|
||||
// Verify retry logging
|
||||
expect(mockLogContentRetry).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
error_type: 'NETWORK_ERROR',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should retry on generic network error if retryFetchErrors is true', async () => {
|
||||
vi.mocked(mockConfig.getRetryFetchErrors).mockReturnValue(true);
|
||||
|
||||
const fetchError = new Error('fetch failed: socket hang up');
|
||||
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockImplementationOnce(async () =>
|
||||
(async function* () {
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: '' }] } }],
|
||||
} as GenerateContentResponse; // Dummy yield
|
||||
throw fetchError;
|
||||
})(),
|
||||
)
|
||||
.mockImplementationOnce(async () =>
|
||||
(async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Success' }] },
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-retry-fetch',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
const events: StreamEvent[] = [];
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const retryEvent = events.find((e) => e.type === StreamEventType.RETRY);
|
||||
expect(retryEvent).toBeDefined();
|
||||
|
||||
const successChunk = events.find(
|
||||
(e) =>
|
||||
e.type === StreamEventType.CHUNK &&
|
||||
e.value.candidates?.[0]?.content?.parts?.[0]?.text === 'Success',
|
||||
);
|
||||
expect(successChunk).toBeDefined();
|
||||
});
|
||||
|
||||
it('should NOT retry on 400 ApiError', async () => {
|
||||
const error400 = new ApiError({
|
||||
message: 'Bad Request',
|
||||
status: 400,
|
||||
});
|
||||
|
||||
vi.mocked(
|
||||
mockContentGenerator.generateContentStream,
|
||||
).mockImplementationOnce(async () =>
|
||||
(async function* () {
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: '' }] } }],
|
||||
} as GenerateContentResponse; // Dummy yield
|
||||
throw error400;
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-no-retry',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
}
|
||||
}).rejects.toThrow(error400);
|
||||
|
||||
expect(mockLogContentRetry).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -134,6 +134,7 @@ describe('WebFetchTool', () => {
|
||||
setApprovalMode: vi.fn(),
|
||||
getProxy: vi.fn(),
|
||||
getGeminiClient: mockGetGeminiClient,
|
||||
getRetryFetchErrors: vi.fn().mockReturnValue(false),
|
||||
modelConfigService: {
|
||||
getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
|
||||
model,
|
||||
|
||||
@@ -29,6 +29,7 @@ import {
|
||||
} from '../telemetry/index.js';
|
||||
import { WEB_FETCH_TOOL_NAME } from './tool-names.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
|
||||
const URL_FETCH_TIMEOUT_MS = 10000;
|
||||
const MAX_CONTENT_LENGTH = 100000;
|
||||
@@ -102,6 +103,10 @@ export interface WebFetchToolParams {
|
||||
prompt: string;
|
||||
}
|
||||
|
||||
interface ErrorWithStatus extends Error {
|
||||
status?: number;
|
||||
}
|
||||
|
||||
class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
WebFetchToolParams,
|
||||
ToolResult
|
||||
@@ -129,12 +134,22 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Request failed with status code ${response.status} ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
const response = await retryWithBackoff(
|
||||
async () => {
|
||||
const res = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
|
||||
if (!res.ok) {
|
||||
const error = new Error(
|
||||
`Request failed with status code ${res.status} ${res.statusText}`,
|
||||
);
|
||||
(error as ErrorWithStatus).status = res.status;
|
||||
throw error;
|
||||
}
|
||||
return res;
|
||||
},
|
||||
{
|
||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||
},
|
||||
);
|
||||
|
||||
const rawContent = await response.text();
|
||||
const contentType = response.headers.get('content-type') || '';
|
||||
|
||||
@@ -22,8 +22,9 @@ export class FetchError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public code?: string,
|
||||
options?: ErrorOptions,
|
||||
) {
|
||||
super(message);
|
||||
super(message, options);
|
||||
this.name = 'FetchError';
|
||||
}
|
||||
}
|
||||
@@ -51,7 +52,7 @@ export async function fetchWithTimeout(
|
||||
if (isNodeError(error) && error.code === 'ABORT_ERR') {
|
||||
throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT');
|
||||
}
|
||||
throw new FetchError(getErrorMessage(error));
|
||||
throw new FetchError(getErrorMessage(error), undefined, { cause: error });
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
|
||||
@@ -307,7 +307,7 @@ describe('retryWithBackoff', () => {
|
||||
});
|
||||
|
||||
describe('Fetch error retries', () => {
|
||||
it('should retry on specific fetch error when retryFetchErrors is true', async () => {
|
||||
it("should retry on 'fetch failed' when retryFetchErrors is true", async () => {
|
||||
const mockFn = vi.fn();
|
||||
mockFn.mockRejectedValueOnce(new TypeError('fetch failed'));
|
||||
mockFn.mockResolvedValueOnce('success');
|
||||
@@ -365,19 +365,48 @@ describe('retryWithBackoff', () => {
|
||||
expect(mockFn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it.each([false, undefined])(
|
||||
'should not retry on specific fetch error when retryFetchErrors is %s',
|
||||
async (retryFetchErrors) => {
|
||||
const mockFn = vi.fn().mockRejectedValue(new TypeError('fetch failed'));
|
||||
it("should retry on 'fetch failed' when retryFetchErrors is true (short delays)", async () => {
|
||||
const mockFn = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(new TypeError('fetch failed'))
|
||||
.mockResolvedValue('success');
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
retryFetchErrors,
|
||||
});
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
retryFetchErrors: true,
|
||||
initialDelayMs: 1,
|
||||
maxDelayMs: 1,
|
||||
});
|
||||
await vi.runAllTimersAsync();
|
||||
await expect(promise).resolves.toBe('success');
|
||||
});
|
||||
|
||||
await expect(promise).rejects.toThrow('fetch failed');
|
||||
expect(mockFn).toHaveBeenCalledTimes(1);
|
||||
},
|
||||
);
|
||||
it("should not retry on 'fetch failed' when retryFetchErrors is false", async () => {
|
||||
const mockFn = vi.fn().mockRejectedValue(new TypeError('fetch failed'));
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
retryFetchErrors: false,
|
||||
initialDelayMs: 1,
|
||||
maxDelayMs: 1,
|
||||
});
|
||||
await expect(promise).rejects.toThrow('fetch failed');
|
||||
expect(mockFn).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should retry on network error code (ETIMEDOUT) even when retryFetchErrors is false', async () => {
|
||||
const error = new Error('connect ETIMEDOUT');
|
||||
(error as any).code = 'ETIMEDOUT';
|
||||
const mockFn = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(error)
|
||||
.mockResolvedValue('success');
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
retryFetchErrors: false,
|
||||
initialDelayMs: 1,
|
||||
maxDelayMs: 1,
|
||||
});
|
||||
await vi.runAllTimersAsync();
|
||||
await expect(promise).resolves.toBe('success');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Flash model fallback for OAuth users', () => {
|
||||
|
||||
@@ -35,7 +35,7 @@ const DEFAULT_RETRY_OPTIONS: RetryOptions = {
|
||||
maxAttempts: 3,
|
||||
initialDelayMs: 5000,
|
||||
maxDelayMs: 30000, // 30 seconds
|
||||
shouldRetryOnError: defaultShouldRetry,
|
||||
shouldRetryOnError: isRetryableError,
|
||||
};
|
||||
|
||||
const RETRYABLE_NETWORK_CODES = [
|
||||
@@ -79,21 +79,21 @@ const FETCH_FAILED_MESSAGE = 'fetch failed';
|
||||
* @param retryFetchErrors Whether to retry on specific fetch errors.
|
||||
* @returns True if the error is a transient error, false otherwise.
|
||||
*/
|
||||
function defaultShouldRetry(
|
||||
export function isRetryableError(
|
||||
error: Error | unknown,
|
||||
retryFetchErrors?: boolean,
|
||||
): boolean {
|
||||
// Check for common network error codes
|
||||
const errorCode = getNetworkErrorCode(error);
|
||||
if (errorCode && RETRYABLE_NETWORK_CODES.includes(errorCode)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (retryFetchErrors && error instanceof Error) {
|
||||
// Check for generic fetch failed message (case-insensitive)
|
||||
if (error.message.toLowerCase().includes(FETCH_FAILED_MESSAGE)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check for common network error codes
|
||||
const errorCode = getNetworkErrorCode(error);
|
||||
if (errorCode && RETRYABLE_NETWORK_CODES.includes(errorCode)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Priority check for ApiError
|
||||
@@ -147,6 +147,7 @@ export async function retryWithBackoff<T>(
|
||||
signal,
|
||||
} = {
|
||||
...DEFAULT_RETRY_OPTIONS,
|
||||
shouldRetryOnError: isRetryableError,
|
||||
...cleanOptions,
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user