fix(core): improve API response error handling and retry logic (#14563)

This commit is contained in:
matt korwel
2025-12-05 09:49:08 -08:00
committed by GitHub
parent 738354ff65
commit dd3fd73ffe
8 changed files with 375 additions and 45 deletions

View File

@@ -69,9 +69,13 @@ const { mockRetryWithBackoff } = vi.hoisted(() => ({
mockRetryWithBackoff: vi.fn(),
}));
vi.mock('../utils/retry.js', () => ({
retryWithBackoff: mockRetryWithBackoff,
}));
vi.mock('../utils/retry.js', async (importOriginal) => {
const actual = await importOriginal<typeof import('../utils/retry.js')>();
return {
...actual,
retryWithBackoff: mockRetryWithBackoff,
};
});
vi.mock('../fallback/handler.js', () => ({
handleFallback: mockHandleFallback,

View File

@@ -19,7 +19,7 @@ import type {
import { ThinkingLevel } from '@google/genai';
import { toParts } from '../code_assist/converter.js';
import { createUserContent, FinishReason } from '@google/genai';
import { retryWithBackoff } from '../utils/retry.js';
import { retryWithBackoff, isRetryableError } from '../utils/retry.js';
import type { Config } from '../config/config.js';
import {
DEFAULT_GEMINI_MODEL,
@@ -310,6 +310,7 @@ export class GeminiChat {
}
for (let attempt = 0; attempt < maxAttempts; attempt++) {
let isConnectionPhase = true;
try {
if (attempt > 0) {
yield { type: StreamEventType.RETRY };
@@ -320,13 +321,14 @@ export class GeminiChat {
generateContentConfig.temperature = 1;
}
isConnectionPhase = true;
const stream = await this.makeApiCallAndProcessStream(
model,
generateContentConfig,
requestContents,
prompt_id,
);
isConnectionPhase = false;
for await (const chunk of stream) {
yield { type: StreamEventType.CHUNK, value: chunk };
}
@@ -334,27 +336,33 @@ export class GeminiChat {
lastError = null;
break;
} catch (error) {
if (isConnectionPhase) {
throw error;
}
lastError = error;
const isContentError = error instanceof InvalidStreamError;
const isRetryable = isRetryableError(
error,
this.config.getRetryFetchErrors(),
);
if (isContentError && isGemini2Model(model)) {
if (
(isContentError && isGemini2Model(model)) ||
(isRetryable && !signal.aborted)
) {
// Check if we have more attempts left.
if (attempt < maxAttempts - 1) {
const delayMs = INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs;
const retryType = isContentError
? (error as InvalidStreamError).type
: 'NETWORK_ERROR';
logContentRetry(
this.config,
new ContentRetryEvent(
attempt,
(error as InvalidStreamError).type,
INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs,
model,
),
new ContentRetryEvent(attempt, retryType, delayMs, model),
);
await new Promise((res) =>
setTimeout(
res,
INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs *
(attempt + 1),
),
setTimeout(res, delayMs * (attempt + 1)),
);
continue;
}

View File

@@ -0,0 +1,271 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import type { GenerateContentResponse } from '@google/genai';
import { ApiError } from '@google/genai';
import type { ContentGenerator } from '../core/contentGenerator.js';
import { GeminiChat, StreamEventType, type StreamEvent } from './geminiChat.js';
import type { Config } from '../config/config.js';
import { setSimulate429 } from '../utils/testUtils.js';
import { HookSystem } from '../hooks/hookSystem.js';
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
// Mock fs module
vi.mock('node:fs', () => ({
default: {
mkdirSync: vi.fn(),
writeFileSync: vi.fn(),
readFileSync: vi.fn(() => {
const error = new Error('ENOENT');
(error as NodeJS.ErrnoException).code = 'ENOENT';
throw error;
}),
existsSync: vi.fn(() => false),
},
}));
const { mockRetryWithBackoff } = vi.hoisted(() => ({
mockRetryWithBackoff: vi.fn(),
}));
vi.mock('../utils/retry.js', async (importOriginal) => {
const actual = await importOriginal<typeof import('../utils/retry.js')>();
return {
...actual,
retryWithBackoff: mockRetryWithBackoff,
};
});
// Mock loggers
const { mockLogContentRetry, mockLogContentRetryFailure } = vi.hoisted(() => ({
mockLogContentRetry: vi.fn(),
mockLogContentRetryFailure: vi.fn(),
}));
vi.mock('../telemetry/loggers.js', () => ({
logContentRetry: mockLogContentRetry,
logContentRetryFailure: mockLogContentRetryFailure,
}));
describe('GeminiChat Network Retries', () => {
let mockContentGenerator: ContentGenerator;
let chat: GeminiChat;
let mockConfig: Config;
beforeEach(() => {
vi.clearAllMocks();
mockContentGenerator = {
generateContent: vi.fn(),
generateContentStream: vi.fn(),
} as unknown as ContentGenerator;
// Default mock implementation: execute the function immediately
mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall());
mockConfig = {
getSessionId: () => 'test-session-id',
getTelemetryLogPromptsEnabled: () => true,
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getPreviewFeatures: () => false,
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'oauth-personal',
model: 'test-model',
}),
getModel: vi.fn().mockReturnValue('gemini-pro'),
isInFallbackMode: vi.fn().mockReturnValue(false),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
storage: {
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
},
getToolRegistry: vi.fn().mockReturnValue({ getTool: vi.fn() }),
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
getRetryFetchErrors: vi.fn().mockReturnValue(false), // Default false
modelConfigService: {
getResolvedConfig: vi.fn().mockImplementation((modelConfigKey) => ({
model: modelConfigKey.model,
generateContentConfig: { temperature: 0 },
})),
},
isPreviewModelBypassMode: vi.fn().mockReturnValue(false),
setPreviewModelBypassMode: vi.fn(),
isPreviewModelFallbackMode: vi.fn().mockReturnValue(false),
getEnableHooks: vi.fn().mockReturnValue(false),
setPreviewModelFallbackMode: vi.fn(),
} as unknown as Config;
const mockMessageBus = createMockMessageBus();
mockConfig.getMessageBus = vi.fn().mockReturnValue(mockMessageBus);
mockConfig.getHookSystem = vi
.fn()
.mockReturnValue(new HookSystem(mockConfig));
setSimulate429(false);
chat = new GeminiChat(mockConfig);
});
afterEach(() => {
vi.restoreAllMocks();
});
it('should retry when a 503 ApiError occurs during stream iteration', async () => {
// 1. Mock the API to yield one chunk, then throw a 503 error.
const error503 = new ApiError({
message: 'Service Unavailable',
status: 503,
});
vi.mocked(mockContentGenerator.generateContentStream)
.mockImplementationOnce(async () =>
(async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'First part' }] } }],
} as unknown as GenerateContentResponse;
throw error503;
})(),
)
.mockImplementationOnce(async () =>
(async function* () {
yield {
candidates: [
{
content: { parts: [{ text: 'Retry success' }] },
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})(),
);
// 2. Execute sendMessageStream
const stream = await chat.sendMessageStream(
{ model: 'test-model' },
'test message',
'prompt-id-retry-network',
new AbortController().signal,
);
const events: StreamEvent[] = [];
for await (const event of stream) {
events.push(event);
}
// 3. Assertions
// Expected sequence: CHUNK('First part') -> RETRY -> CHUNK('Retry success')
expect(events.length).toBeGreaterThanOrEqual(3);
const firstChunk = events.find(
(e) =>
e.type === StreamEventType.CHUNK &&
e.value.candidates?.[0]?.content?.parts?.[0]?.text === 'First part',
);
expect(firstChunk).toBeDefined();
const retryEvent = events.find((e) => e.type === StreamEventType.RETRY);
expect(retryEvent).toBeDefined();
const successChunk = events.find(
(e) =>
e.type === StreamEventType.CHUNK &&
e.value.candidates?.[0]?.content?.parts?.[0]?.text === 'Retry success',
);
expect(successChunk).toBeDefined();
// Verify retry logging
expect(mockLogContentRetry).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({
error_type: 'NETWORK_ERROR',
}),
);
});
it('should retry on generic network error if retryFetchErrors is true', async () => {
vi.mocked(mockConfig.getRetryFetchErrors).mockReturnValue(true);
const fetchError = new Error('fetch failed: socket hang up');
vi.mocked(mockContentGenerator.generateContentStream)
.mockImplementationOnce(async () =>
(async function* () {
yield {
candidates: [{ content: { parts: [{ text: '' }] } }],
} as GenerateContentResponse; // Dummy yield
throw fetchError;
})(),
)
.mockImplementationOnce(async () =>
(async function* () {
yield {
candidates: [
{
content: { parts: [{ text: 'Success' }] },
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})(),
);
const stream = await chat.sendMessageStream(
{ model: 'test-model' },
'test message',
'prompt-id-retry-fetch',
new AbortController().signal,
);
const events: StreamEvent[] = [];
for await (const event of stream) {
events.push(event);
}
const retryEvent = events.find((e) => e.type === StreamEventType.RETRY);
expect(retryEvent).toBeDefined();
const successChunk = events.find(
(e) =>
e.type === StreamEventType.CHUNK &&
e.value.candidates?.[0]?.content?.parts?.[0]?.text === 'Success',
);
expect(successChunk).toBeDefined();
});
it('should NOT retry on 400 ApiError', async () => {
const error400 = new ApiError({
message: 'Bad Request',
status: 400,
});
vi.mocked(
mockContentGenerator.generateContentStream,
).mockImplementationOnce(async () =>
(async function* () {
yield {
candidates: [{ content: { parts: [{ text: '' }] } }],
} as GenerateContentResponse; // Dummy yield
throw error400;
})(),
);
const stream = await chat.sendMessageStream(
{ model: 'test-model' },
'test message',
'prompt-id-no-retry',
new AbortController().signal,
);
await expect(async () => {
for await (const _ of stream) {
// consume
}
}).rejects.toThrow(error400);
expect(mockLogContentRetry).not.toHaveBeenCalled();
});
});

View File

@@ -134,6 +134,7 @@ describe('WebFetchTool', () => {
setApprovalMode: vi.fn(),
getProxy: vi.fn(),
getGeminiClient: mockGetGeminiClient,
getRetryFetchErrors: vi.fn().mockReturnValue(false),
modelConfigService: {
getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
model,

View File

@@ -29,6 +29,7 @@ import {
} from '../telemetry/index.js';
import { WEB_FETCH_TOOL_NAME } from './tool-names.js';
import { debugLogger } from '../utils/debugLogger.js';
import { retryWithBackoff } from '../utils/retry.js';
const URL_FETCH_TIMEOUT_MS = 10000;
const MAX_CONTENT_LENGTH = 100000;
@@ -102,6 +103,10 @@ export interface WebFetchToolParams {
prompt: string;
}
interface ErrorWithStatus extends Error {
status?: number;
}
class WebFetchToolInvocation extends BaseToolInvocation<
WebFetchToolParams,
ToolResult
@@ -129,12 +134,22 @@ class WebFetchToolInvocation extends BaseToolInvocation<
}
try {
const response = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
if (!response.ok) {
throw new Error(
`Request failed with status code ${response.status} ${response.statusText}`,
);
}
const response = await retryWithBackoff(
async () => {
const res = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS);
if (!res.ok) {
const error = new Error(
`Request failed with status code ${res.status} ${res.statusText}`,
);
(error as ErrorWithStatus).status = res.status;
throw error;
}
return res;
},
{
retryFetchErrors: this.config.getRetryFetchErrors(),
},
);
const rawContent = await response.text();
const contentType = response.headers.get('content-type') || '';

View File

@@ -22,8 +22,9 @@ export class FetchError extends Error {
constructor(
message: string,
public code?: string,
options?: ErrorOptions,
) {
super(message);
super(message, options);
this.name = 'FetchError';
}
}
@@ -51,7 +52,7 @@ export async function fetchWithTimeout(
if (isNodeError(error) && error.code === 'ABORT_ERR') {
throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT');
}
throw new FetchError(getErrorMessage(error));
throw new FetchError(getErrorMessage(error), undefined, { cause: error });
} finally {
clearTimeout(timeoutId);
}

View File

@@ -307,7 +307,7 @@ describe('retryWithBackoff', () => {
});
describe('Fetch error retries', () => {
it('should retry on specific fetch error when retryFetchErrors is true', async () => {
it("should retry on 'fetch failed' when retryFetchErrors is true", async () => {
const mockFn = vi.fn();
mockFn.mockRejectedValueOnce(new TypeError('fetch failed'));
mockFn.mockResolvedValueOnce('success');
@@ -365,19 +365,48 @@ describe('retryWithBackoff', () => {
expect(mockFn).toHaveBeenCalledTimes(2);
});
it.each([false, undefined])(
'should not retry on specific fetch error when retryFetchErrors is %s',
async (retryFetchErrors) => {
const mockFn = vi.fn().mockRejectedValue(new TypeError('fetch failed'));
it("should retry on 'fetch failed' when retryFetchErrors is true (short delays)", async () => {
const mockFn = vi
.fn()
.mockRejectedValueOnce(new TypeError('fetch failed'))
.mockResolvedValue('success');
const promise = retryWithBackoff(mockFn, {
retryFetchErrors,
});
const promise = retryWithBackoff(mockFn, {
retryFetchErrors: true,
initialDelayMs: 1,
maxDelayMs: 1,
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
});
await expect(promise).rejects.toThrow('fetch failed');
expect(mockFn).toHaveBeenCalledTimes(1);
},
);
it("should not retry on 'fetch failed' when retryFetchErrors is false", async () => {
const mockFn = vi.fn().mockRejectedValue(new TypeError('fetch failed'));
const promise = retryWithBackoff(mockFn, {
retryFetchErrors: false,
initialDelayMs: 1,
maxDelayMs: 1,
});
await expect(promise).rejects.toThrow('fetch failed');
expect(mockFn).toHaveBeenCalledTimes(1);
});
it('should retry on network error code (ETIMEDOUT) even when retryFetchErrors is false', async () => {
const error = new Error('connect ETIMEDOUT');
(error as any).code = 'ETIMEDOUT';
const mockFn = vi
.fn()
.mockRejectedValueOnce(error)
.mockResolvedValue('success');
const promise = retryWithBackoff(mockFn, {
retryFetchErrors: false,
initialDelayMs: 1,
maxDelayMs: 1,
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
});
});
describe('Flash model fallback for OAuth users', () => {

View File

@@ -35,7 +35,7 @@ const DEFAULT_RETRY_OPTIONS: RetryOptions = {
maxAttempts: 3,
initialDelayMs: 5000,
maxDelayMs: 30000, // 30 seconds
shouldRetryOnError: defaultShouldRetry,
shouldRetryOnError: isRetryableError,
};
const RETRYABLE_NETWORK_CODES = [
@@ -79,21 +79,21 @@ const FETCH_FAILED_MESSAGE = 'fetch failed';
* @param retryFetchErrors Whether to retry on specific fetch errors.
* @returns True if the error is a transient error, false otherwise.
*/
function defaultShouldRetry(
export function isRetryableError(
error: Error | unknown,
retryFetchErrors?: boolean,
): boolean {
// Check for common network error codes
const errorCode = getNetworkErrorCode(error);
if (errorCode && RETRYABLE_NETWORK_CODES.includes(errorCode)) {
return true;
}
if (retryFetchErrors && error instanceof Error) {
// Check for generic fetch failed message (case-insensitive)
if (error.message.toLowerCase().includes(FETCH_FAILED_MESSAGE)) {
return true;
}
// Check for common network error codes
const errorCode = getNetworkErrorCode(error);
if (errorCode && RETRYABLE_NETWORK_CODES.includes(errorCode)) {
return true;
}
}
// Priority check for ApiError
@@ -147,6 +147,7 @@ export async function retryWithBackoff<T>(
signal,
} = {
...DEFAULT_RETRY_OPTIONS,
shouldRetryOnError: isRetryableError,
...cleanOptions,
};