From dd3fd73ffe9a5a084cae650bee7dcb26f8f7effb Mon Sep 17 00:00:00 2001 From: matt korwel Date: Fri, 5 Dec 2025 09:49:08 -0800 Subject: [PATCH] fix(core): improve API response error handling and retry logic (#14563) --- packages/core/src/core/geminiChat.test.ts | 10 +- packages/core/src/core/geminiChat.ts | 36 ++- .../src/core/geminiChat_network_retry.test.ts | 271 ++++++++++++++++++ packages/core/src/tools/web-fetch.test.ts | 1 + packages/core/src/tools/web-fetch.ts | 27 +- packages/core/src/utils/fetch.ts | 5 +- packages/core/src/utils/retry.test.ts | 53 +++- packages/core/src/utils/retry.ts | 17 +- 8 files changed, 375 insertions(+), 45 deletions(-) create mode 100644 packages/core/src/core/geminiChat_network_retry.test.ts diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index faa2b29ad4..4c2dc2786d 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -69,9 +69,13 @@ const { mockRetryWithBackoff } = vi.hoisted(() => ({ mockRetryWithBackoff: vi.fn(), })); -vi.mock('../utils/retry.js', () => ({ - retryWithBackoff: mockRetryWithBackoff, -})); +vi.mock('../utils/retry.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + retryWithBackoff: mockRetryWithBackoff, + }; +}); vi.mock('../fallback/handler.js', () => ({ handleFallback: mockHandleFallback, diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index bd8edce3cd..98b99dbe2f 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -19,7 +19,7 @@ import type { import { ThinkingLevel } from '@google/genai'; import { toParts } from '../code_assist/converter.js'; import { createUserContent, FinishReason } from '@google/genai'; -import { retryWithBackoff } from '../utils/retry.js'; +import { retryWithBackoff, isRetryableError } from '../utils/retry.js'; import type { Config } from '../config/config.js'; import { DEFAULT_GEMINI_MODEL, @@ -310,6 +310,7 @@ export class GeminiChat { } for (let attempt = 0; attempt < maxAttempts; attempt++) { + let isConnectionPhase = true; try { if (attempt > 0) { yield { type: StreamEventType.RETRY }; @@ -320,13 +321,14 @@ export class GeminiChat { generateContentConfig.temperature = 1; } + isConnectionPhase = true; const stream = await this.makeApiCallAndProcessStream( model, generateContentConfig, requestContents, prompt_id, ); - + isConnectionPhase = false; for await (const chunk of stream) { yield { type: StreamEventType.CHUNK, value: chunk }; } @@ -334,27 +336,33 @@ export class GeminiChat { lastError = null; break; } catch (error) { + if (isConnectionPhase) { + throw error; + } lastError = error; const isContentError = error instanceof InvalidStreamError; + const isRetryable = isRetryableError( + error, + this.config.getRetryFetchErrors(), + ); - if (isContentError && isGemini2Model(model)) { + if ( + (isContentError && isGemini2Model(model)) || + (isRetryable && !signal.aborted) + ) { // Check if we have more attempts left. if (attempt < maxAttempts - 1) { + const delayMs = INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs; + const retryType = isContentError + ? (error as InvalidStreamError).type + : 'NETWORK_ERROR'; + logContentRetry( this.config, - new ContentRetryEvent( - attempt, - (error as InvalidStreamError).type, - INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs, - model, - ), + new ContentRetryEvent(attempt, retryType, delayMs, model), ); await new Promise((res) => - setTimeout( - res, - INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs * - (attempt + 1), - ), + setTimeout(res, delayMs * (attempt + 1)), ); continue; } diff --git a/packages/core/src/core/geminiChat_network_retry.test.ts b/packages/core/src/core/geminiChat_network_retry.test.ts new file mode 100644 index 0000000000..8952a1053a --- /dev/null +++ b/packages/core/src/core/geminiChat_network_retry.test.ts @@ -0,0 +1,271 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import type { GenerateContentResponse } from '@google/genai'; +import { ApiError } from '@google/genai'; +import type { ContentGenerator } from '../core/contentGenerator.js'; +import { GeminiChat, StreamEventType, type StreamEvent } from './geminiChat.js'; +import type { Config } from '../config/config.js'; +import { setSimulate429 } from '../utils/testUtils.js'; +import { HookSystem } from '../hooks/hookSystem.js'; +import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; + +// Mock fs module +vi.mock('node:fs', () => ({ + default: { + mkdirSync: vi.fn(), + writeFileSync: vi.fn(), + readFileSync: vi.fn(() => { + const error = new Error('ENOENT'); + (error as NodeJS.ErrnoException).code = 'ENOENT'; + throw error; + }), + existsSync: vi.fn(() => false), + }, +})); + +const { mockRetryWithBackoff } = vi.hoisted(() => ({ + mockRetryWithBackoff: vi.fn(), +})); + +vi.mock('../utils/retry.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + retryWithBackoff: mockRetryWithBackoff, + }; +}); + +// Mock loggers +const { mockLogContentRetry, mockLogContentRetryFailure } = vi.hoisted(() => ({ + mockLogContentRetry: vi.fn(), + mockLogContentRetryFailure: vi.fn(), +})); + +vi.mock('../telemetry/loggers.js', () => ({ + logContentRetry: mockLogContentRetry, + logContentRetryFailure: mockLogContentRetryFailure, +})); + +describe('GeminiChat Network Retries', () => { + let mockContentGenerator: ContentGenerator; + let chat: GeminiChat; + let mockConfig: Config; + + beforeEach(() => { + vi.clearAllMocks(); + + mockContentGenerator = { + generateContent: vi.fn(), + generateContentStream: vi.fn(), + } as unknown as ContentGenerator; + + // Default mock implementation: execute the function immediately + mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall()); + + mockConfig = { + getSessionId: () => 'test-session-id', + getTelemetryLogPromptsEnabled: () => true, + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getPreviewFeatures: () => false, + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'oauth-personal', + model: 'test-model', + }), + getModel: vi.fn().mockReturnValue('gemini-pro'), + isInFallbackMode: vi.fn().mockReturnValue(false), + getQuotaErrorOccurred: vi.fn().mockReturnValue(false), + getProjectRoot: vi.fn().mockReturnValue('/test/project/root'), + storage: { + getProjectTempDir: vi.fn().mockReturnValue('/test/temp'), + }, + getToolRegistry: vi.fn().mockReturnValue({ getTool: vi.fn() }), + getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator), + getRetryFetchErrors: vi.fn().mockReturnValue(false), // Default false + modelConfigService: { + getResolvedConfig: vi.fn().mockImplementation((modelConfigKey) => ({ + model: modelConfigKey.model, + generateContentConfig: { temperature: 0 }, + })), + }, + isPreviewModelBypassMode: vi.fn().mockReturnValue(false), + setPreviewModelBypassMode: vi.fn(), + isPreviewModelFallbackMode: vi.fn().mockReturnValue(false), + getEnableHooks: vi.fn().mockReturnValue(false), + setPreviewModelFallbackMode: vi.fn(), + } as unknown as Config; + + const mockMessageBus = createMockMessageBus(); + mockConfig.getMessageBus = vi.fn().mockReturnValue(mockMessageBus); + mockConfig.getHookSystem = vi + .fn() + .mockReturnValue(new HookSystem(mockConfig)); + + setSimulate429(false); + chat = new GeminiChat(mockConfig); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should retry when a 503 ApiError occurs during stream iteration', async () => { + // 1. Mock the API to yield one chunk, then throw a 503 error. + const error503 = new ApiError({ + message: 'Service Unavailable', + status: 503, + }); + + vi.mocked(mockContentGenerator.generateContentStream) + .mockImplementationOnce(async () => + (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'First part' }] } }], + } as unknown as GenerateContentResponse; + throw error503; + })(), + ) + .mockImplementationOnce(async () => + (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Retry success' }] }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + // 2. Execute sendMessageStream + const stream = await chat.sendMessageStream( + { model: 'test-model' }, + 'test message', + 'prompt-id-retry-network', + new AbortController().signal, + ); + + const events: StreamEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + // 3. Assertions + // Expected sequence: CHUNK('First part') -> RETRY -> CHUNK('Retry success') + expect(events.length).toBeGreaterThanOrEqual(3); + + const firstChunk = events.find( + (e) => + e.type === StreamEventType.CHUNK && + e.value.candidates?.[0]?.content?.parts?.[0]?.text === 'First part', + ); + expect(firstChunk).toBeDefined(); + + const retryEvent = events.find((e) => e.type === StreamEventType.RETRY); + expect(retryEvent).toBeDefined(); + + const successChunk = events.find( + (e) => + e.type === StreamEventType.CHUNK && + e.value.candidates?.[0]?.content?.parts?.[0]?.text === 'Retry success', + ); + expect(successChunk).toBeDefined(); + + // Verify retry logging + expect(mockLogContentRetry).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + error_type: 'NETWORK_ERROR', + }), + ); + }); + + it('should retry on generic network error if retryFetchErrors is true', async () => { + vi.mocked(mockConfig.getRetryFetchErrors).mockReturnValue(true); + + const fetchError = new Error('fetch failed: socket hang up'); + + vi.mocked(mockContentGenerator.generateContentStream) + .mockImplementationOnce(async () => + (async function* () { + yield { + candidates: [{ content: { parts: [{ text: '' }] } }], + } as GenerateContentResponse; // Dummy yield + throw fetchError; + })(), + ) + .mockImplementationOnce(async () => + (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Success' }] }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + const stream = await chat.sendMessageStream( + { model: 'test-model' }, + 'test message', + 'prompt-id-retry-fetch', + new AbortController().signal, + ); + + const events: StreamEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + const retryEvent = events.find((e) => e.type === StreamEventType.RETRY); + expect(retryEvent).toBeDefined(); + + const successChunk = events.find( + (e) => + e.type === StreamEventType.CHUNK && + e.value.candidates?.[0]?.content?.parts?.[0]?.text === 'Success', + ); + expect(successChunk).toBeDefined(); + }); + + it('should NOT retry on 400 ApiError', async () => { + const error400 = new ApiError({ + message: 'Bad Request', + status: 400, + }); + + vi.mocked( + mockContentGenerator.generateContentStream, + ).mockImplementationOnce(async () => + (async function* () { + yield { + candidates: [{ content: { parts: [{ text: '' }] } }], + } as GenerateContentResponse; // Dummy yield + throw error400; + })(), + ); + + const stream = await chat.sendMessageStream( + { model: 'test-model' }, + 'test message', + 'prompt-id-no-retry', + new AbortController().signal, + ); + + await expect(async () => { + for await (const _ of stream) { + // consume + } + }).rejects.toThrow(error400); + + expect(mockLogContentRetry).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/core/src/tools/web-fetch.test.ts b/packages/core/src/tools/web-fetch.test.ts index aecb0e24f1..f37db3d558 100644 --- a/packages/core/src/tools/web-fetch.test.ts +++ b/packages/core/src/tools/web-fetch.test.ts @@ -134,6 +134,7 @@ describe('WebFetchTool', () => { setApprovalMode: vi.fn(), getProxy: vi.fn(), getGeminiClient: mockGetGeminiClient, + getRetryFetchErrors: vi.fn().mockReturnValue(false), modelConfigService: { getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({ model, diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index dbfb173079..a1836c37ef 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -29,6 +29,7 @@ import { } from '../telemetry/index.js'; import { WEB_FETCH_TOOL_NAME } from './tool-names.js'; import { debugLogger } from '../utils/debugLogger.js'; +import { retryWithBackoff } from '../utils/retry.js'; const URL_FETCH_TIMEOUT_MS = 10000; const MAX_CONTENT_LENGTH = 100000; @@ -102,6 +103,10 @@ export interface WebFetchToolParams { prompt: string; } +interface ErrorWithStatus extends Error { + status?: number; +} + class WebFetchToolInvocation extends BaseToolInvocation< WebFetchToolParams, ToolResult @@ -129,12 +134,22 @@ class WebFetchToolInvocation extends BaseToolInvocation< } try { - const response = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS); - if (!response.ok) { - throw new Error( - `Request failed with status code ${response.status} ${response.statusText}`, - ); - } + const response = await retryWithBackoff( + async () => { + const res = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS); + if (!res.ok) { + const error = new Error( + `Request failed with status code ${res.status} ${res.statusText}`, + ); + (error as ErrorWithStatus).status = res.status; + throw error; + } + return res; + }, + { + retryFetchErrors: this.config.getRetryFetchErrors(), + }, + ); const rawContent = await response.text(); const contentType = response.headers.get('content-type') || ''; diff --git a/packages/core/src/utils/fetch.ts b/packages/core/src/utils/fetch.ts index ba9bb83c80..3c59b2ef31 100644 --- a/packages/core/src/utils/fetch.ts +++ b/packages/core/src/utils/fetch.ts @@ -22,8 +22,9 @@ export class FetchError extends Error { constructor( message: string, public code?: string, + options?: ErrorOptions, ) { - super(message); + super(message, options); this.name = 'FetchError'; } } @@ -51,7 +52,7 @@ export async function fetchWithTimeout( if (isNodeError(error) && error.code === 'ABORT_ERR') { throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT'); } - throw new FetchError(getErrorMessage(error)); + throw new FetchError(getErrorMessage(error), undefined, { cause: error }); } finally { clearTimeout(timeoutId); } diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts index fe607e6914..f0d78677e3 100644 --- a/packages/core/src/utils/retry.test.ts +++ b/packages/core/src/utils/retry.test.ts @@ -307,7 +307,7 @@ describe('retryWithBackoff', () => { }); describe('Fetch error retries', () => { - it('should retry on specific fetch error when retryFetchErrors is true', async () => { + it("should retry on 'fetch failed' when retryFetchErrors is true", async () => { const mockFn = vi.fn(); mockFn.mockRejectedValueOnce(new TypeError('fetch failed')); mockFn.mockResolvedValueOnce('success'); @@ -365,19 +365,48 @@ describe('retryWithBackoff', () => { expect(mockFn).toHaveBeenCalledTimes(2); }); - it.each([false, undefined])( - 'should not retry on specific fetch error when retryFetchErrors is %s', - async (retryFetchErrors) => { - const mockFn = vi.fn().mockRejectedValue(new TypeError('fetch failed')); + it("should retry on 'fetch failed' when retryFetchErrors is true (short delays)", async () => { + const mockFn = vi + .fn() + .mockRejectedValueOnce(new TypeError('fetch failed')) + .mockResolvedValue('success'); - const promise = retryWithBackoff(mockFn, { - retryFetchErrors, - }); + const promise = retryWithBackoff(mockFn, { + retryFetchErrors: true, + initialDelayMs: 1, + maxDelayMs: 1, + }); + await vi.runAllTimersAsync(); + await expect(promise).resolves.toBe('success'); + }); - await expect(promise).rejects.toThrow('fetch failed'); - expect(mockFn).toHaveBeenCalledTimes(1); - }, - ); + it("should not retry on 'fetch failed' when retryFetchErrors is false", async () => { + const mockFn = vi.fn().mockRejectedValue(new TypeError('fetch failed')); + const promise = retryWithBackoff(mockFn, { + retryFetchErrors: false, + initialDelayMs: 1, + maxDelayMs: 1, + }); + await expect(promise).rejects.toThrow('fetch failed'); + expect(mockFn).toHaveBeenCalledTimes(1); + }); + + it('should retry on network error code (ETIMEDOUT) even when retryFetchErrors is false', async () => { + const error = new Error('connect ETIMEDOUT'); + (error as any).code = 'ETIMEDOUT'; + const mockFn = vi + .fn() + .mockRejectedValueOnce(error) + .mockResolvedValue('success'); + + const promise = retryWithBackoff(mockFn, { + retryFetchErrors: false, + initialDelayMs: 1, + maxDelayMs: 1, + }); + await vi.runAllTimersAsync(); + await expect(promise).resolves.toBe('success'); + }); }); describe('Flash model fallback for OAuth users', () => { diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts index 99ed75ea7e..515eb92426 100644 --- a/packages/core/src/utils/retry.ts +++ b/packages/core/src/utils/retry.ts @@ -35,7 +35,7 @@ const DEFAULT_RETRY_OPTIONS: RetryOptions = { maxAttempts: 3, initialDelayMs: 5000, maxDelayMs: 30000, // 30 seconds - shouldRetryOnError: defaultShouldRetry, + shouldRetryOnError: isRetryableError, }; const RETRYABLE_NETWORK_CODES = [ @@ -79,21 +79,21 @@ const FETCH_FAILED_MESSAGE = 'fetch failed'; * @param retryFetchErrors Whether to retry on specific fetch errors. * @returns True if the error is a transient error, false otherwise. */ -function defaultShouldRetry( +export function isRetryableError( error: Error | unknown, retryFetchErrors?: boolean, ): boolean { + // Check for common network error codes + const errorCode = getNetworkErrorCode(error); + if (errorCode && RETRYABLE_NETWORK_CODES.includes(errorCode)) { + return true; + } + if (retryFetchErrors && error instanceof Error) { // Check for generic fetch failed message (case-insensitive) if (error.message.toLowerCase().includes(FETCH_FAILED_MESSAGE)) { return true; } - - // Check for common network error codes - const errorCode = getNetworkErrorCode(error); - if (errorCode && RETRYABLE_NETWORK_CODES.includes(errorCode)) { - return true; - } } // Priority check for ApiError @@ -147,6 +147,7 @@ export async function retryWithBackoff( signal, } = { ...DEFAULT_RETRY_OPTIONS, + shouldRetryOnError: isRetryableError, ...cleanOptions, };