chore: do not retry the model request if the user has aborted the request (#11224)

This commit is contained in:
Adam Weidman
2025-10-20 17:21:49 +02:00
committed by GitHub
parent a788a6df48
commit 8731309d7c
5 changed files with 199 additions and 12 deletions

View File

@@ -383,6 +383,7 @@ export class GeminiChat {
onPersistent429: onPersistent429Callback,
authType: this.config.getContentGeneratorConfig()?.authType,
retryFetchErrors: this.config.getRetryFetchErrors(),
signal: params.config?.abortSignal,
});
return this.processStreamResponse(model, streamResponse);

View File

@@ -0,0 +1,112 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest';
import { delay } from './delay.js';
describe('abortableDelay', () => {
beforeEach(() => {
vi.useFakeTimers();
});
afterEach(() => {
vi.useRealTimers();
vi.restoreAllMocks();
});
it('resolves after the specified duration without a signal', async () => {
const promise = delay(100);
await vi.advanceTimersByTimeAsync(100);
await expect(promise).resolves.toBeUndefined();
});
it('resolves when a non-aborted signal is provided', async () => {
const controller = new AbortController();
const promise = delay(200, controller.signal);
await vi.advanceTimersByTimeAsync(200);
await expect(promise).resolves.toBeUndefined();
});
it('rejects immediately if the signal is already aborted', async () => {
const controller = new AbortController();
controller.abort();
await expect(delay(50, controller.signal)).rejects.toMatchObject({
name: 'AbortError',
message: 'Aborted',
});
});
it('rejects if the signal aborts while waiting', async () => {
const controller = new AbortController();
const promise = delay(500, controller.signal);
await vi.advanceTimersByTimeAsync(100);
controller.abort();
await expect(promise).rejects.toMatchObject({
name: 'AbortError',
message: 'Aborted',
});
});
it('cleans up signal listeners after resolving', async () => {
const removeEventListener = vi.fn();
const mockSignal = {
aborted: false,
addEventListener: vi
.fn()
.mockImplementation((_type: string, listener: () => void) => {
mockSignal.__listener = listener;
}),
removeEventListener,
__listener: undefined as (() => void) | undefined,
} as unknown as AbortSignal & { __listener?: () => void };
const promise = delay(150, mockSignal);
await vi.advanceTimersByTimeAsync(150);
await promise;
expect(mockSignal.addEventListener).toHaveBeenCalledTimes(1);
expect(removeEventListener).toHaveBeenCalledTimes(1);
expect(removeEventListener.mock.calls[0][1]).toBe(mockSignal.__listener);
});
// Technically unnecessary due to `onceTrue` but good sanity check
it('cleans up signal listeners when aborted before completion', async () => {
const controller = new AbortController();
const removeEventListenerSpy = vi.spyOn(
controller.signal,
'removeEventListener',
);
const promise = delay(400, controller.signal);
await vi.advanceTimersByTimeAsync(50);
controller.abort();
await expect(promise).rejects.toMatchObject({
name: 'AbortError',
});
expect(removeEventListenerSpy).toHaveBeenCalledTimes(1);
});
it('cleans up timeout when aborted before completion', async () => {
const clearTimeoutSpy = vi.spyOn(global, 'clearTimeout');
const controller = new AbortController();
const promise = delay(400, controller.signal);
await vi.advanceTimersByTimeAsync(50);
controller.abort();
await expect(promise).rejects.toMatchObject({
name: 'AbortError',
});
expect(clearTimeoutSpy).toHaveBeenCalled();
});
});

View File

@@ -0,0 +1,48 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Factory to create a standard abort error for delay helpers.
*/
export function createAbortError(): Error {
const abortError = new Error('Aborted');
abortError.name = 'AbortError';
return abortError;
}
/**
* Returns a promise that resolves after the provided duration unless aborted.
*
* @param ms Delay duration in milliseconds.
* @param signal Optional abort signal to cancel the wait early.
*/
export function delay(ms: number, signal?: AbortSignal): Promise<void> {
// If no abort signal is provided, set simple delay
if (!signal) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
// Immediately reject if signal has already been aborted
if (signal.aborted) {
return Promise.reject(createAbortError());
}
// remove abort and timeout listeners to prevent memory-leaks
return new Promise((resolve, reject) => {
const onAbort = () => {
clearTimeout(timeoutId);
signal.removeEventListener('abort', onAbort);
reject(createAbortError());
};
const timeoutId = setTimeout(() => {
signal.removeEventListener('abort', onAbort);
resolve();
}, ms);
signal.addEventListener('abort', onAbort, { once: true });
});
}

View File

@@ -500,4 +500,25 @@ describe('retryWithBackoff', () => {
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
});
});
it('should abort the retry loop when the signal is aborted', async () => {
const abortController = new AbortController();
const mockFn = vi.fn().mockImplementation(async () => {
const error: HttpError = new Error('Server error');
error.status = 500;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 5,
initialDelayMs: 100,
signal: abortController.signal,
});
await vi.advanceTimersByTimeAsync(50);
abortController.abort();
await expect(promise).rejects.toThrow(
expect.objectContaining({ name: 'AbortError' }),
);
expect(mockFn).toHaveBeenCalledTimes(1);
});
});

View File

@@ -11,6 +11,7 @@ import {
isProQuotaExceededError,
isGenericQuotaExceededError,
} from './quotaErrorDetection.js';
import { delay, createAbortError } from './delay.js';
const FETCH_FAILED_MESSAGE =
'exception TypeError: fetch failed sending request';
@@ -31,6 +32,7 @@ export interface RetryOptions {
) => Promise<string | boolean | null>;
authType?: string;
retryFetchErrors?: boolean;
signal?: AbortSignal;
}
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
@@ -75,15 +77,6 @@ function defaultShouldRetry(
return false;
}
/**
* Delays execution for a specified number of milliseconds.
* @param ms The number of milliseconds to delay.
* @returns A promise that resolves after the delay.
*/
function delay(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
/**
* Retries a function with exponential backoff and jitter.
* @param fn The asynchronous function to retry.
@@ -95,6 +88,10 @@ export async function retryWithBackoff<T>(
fn: () => Promise<T>,
options?: Partial<RetryOptions>,
): Promise<T> {
if (options?.signal?.aborted) {
throw createAbortError();
}
if (options?.maxAttempts !== undefined && options.maxAttempts <= 0) {
throw new Error('maxAttempts must be a positive number.');
}
@@ -112,6 +109,7 @@ export async function retryWithBackoff<T>(
shouldRetryOnError,
shouldRetryOnContent,
retryFetchErrors,
signal,
} = {
...DEFAULT_RETRY_OPTIONS,
...cleanOptions,
@@ -122,6 +120,9 @@ export async function retryWithBackoff<T>(
let consecutive429Count = 0;
while (attempt < maxAttempts) {
if (signal?.aborted) {
throw createAbortError();
}
attempt++;
try {
const result = await fn();
@@ -132,13 +133,17 @@ export async function retryWithBackoff<T>(
) {
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
const delayWithJitter = Math.max(0, currentDelay + jitter);
await delay(delayWithJitter);
await delay(delayWithJitter, signal);
currentDelay = Math.min(maxDelayMs, currentDelay * 2);
continue;
}
return result;
} catch (error) {
if (error instanceof Error && error.name === 'AbortError') {
throw error;
}
const errorStatus = getErrorStatus(error);
// Check for Pro quota exceeded error first - immediate fallback for OAuth users
@@ -243,7 +248,7 @@ export async function retryWithBackoff<T>(
`Attempt ${attempt} failed with status ${delayErrorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
error,
);
await delay(delayDurationMs);
await delay(delayDurationMs, signal);
// Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time
currentDelay = initialDelayMs;
} else {
@@ -252,7 +257,7 @@ export async function retryWithBackoff<T>(
// Add jitter: +/- 30% of currentDelay
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
const delayWithJitter = Math.max(0, currentDelay + jitter);
await delay(delayWithJitter);
await delay(delayWithJitter, signal);
currentDelay = Math.min(maxDelayMs, currentDelay * 2);
}
}