mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 14:40:52 -07:00
chore: do not retry the model request if the user has aborted the request (#11224)
This commit is contained in:
@@ -383,6 +383,7 @@ export class GeminiChat {
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||
signal: params.config?.abortSignal,
|
||||
});
|
||||
|
||||
return this.processStreamResponse(model, streamResponse);
|
||||
|
||||
112
packages/core/src/utils/delay.test.ts
Normal file
112
packages/core/src/utils/delay.test.ts
Normal file
@@ -0,0 +1,112 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import { delay } from './delay.js';
|
||||
|
||||
describe('abortableDelay', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('resolves after the specified duration without a signal', async () => {
|
||||
const promise = delay(100);
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
await expect(promise).resolves.toBeUndefined();
|
||||
});
|
||||
|
||||
it('resolves when a non-aborted signal is provided', async () => {
|
||||
const controller = new AbortController();
|
||||
const promise = delay(200, controller.signal);
|
||||
|
||||
await vi.advanceTimersByTimeAsync(200);
|
||||
|
||||
await expect(promise).resolves.toBeUndefined();
|
||||
});
|
||||
|
||||
it('rejects immediately if the signal is already aborted', async () => {
|
||||
const controller = new AbortController();
|
||||
controller.abort();
|
||||
|
||||
await expect(delay(50, controller.signal)).rejects.toMatchObject({
|
||||
name: 'AbortError',
|
||||
message: 'Aborted',
|
||||
});
|
||||
});
|
||||
|
||||
it('rejects if the signal aborts while waiting', async () => {
|
||||
const controller = new AbortController();
|
||||
const promise = delay(500, controller.signal);
|
||||
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
controller.abort();
|
||||
|
||||
await expect(promise).rejects.toMatchObject({
|
||||
name: 'AbortError',
|
||||
message: 'Aborted',
|
||||
});
|
||||
});
|
||||
|
||||
it('cleans up signal listeners after resolving', async () => {
|
||||
const removeEventListener = vi.fn();
|
||||
const mockSignal = {
|
||||
aborted: false,
|
||||
addEventListener: vi
|
||||
.fn()
|
||||
.mockImplementation((_type: string, listener: () => void) => {
|
||||
mockSignal.__listener = listener;
|
||||
}),
|
||||
removeEventListener,
|
||||
__listener: undefined as (() => void) | undefined,
|
||||
} as unknown as AbortSignal & { __listener?: () => void };
|
||||
|
||||
const promise = delay(150, mockSignal);
|
||||
await vi.advanceTimersByTimeAsync(150);
|
||||
await promise;
|
||||
|
||||
expect(mockSignal.addEventListener).toHaveBeenCalledTimes(1);
|
||||
expect(removeEventListener).toHaveBeenCalledTimes(1);
|
||||
expect(removeEventListener.mock.calls[0][1]).toBe(mockSignal.__listener);
|
||||
});
|
||||
|
||||
// Technically unnecessary due to `onceTrue` but good sanity check
|
||||
it('cleans up signal listeners when aborted before completion', async () => {
|
||||
const controller = new AbortController();
|
||||
const removeEventListenerSpy = vi.spyOn(
|
||||
controller.signal,
|
||||
'removeEventListener',
|
||||
);
|
||||
|
||||
const promise = delay(400, controller.signal);
|
||||
|
||||
await vi.advanceTimersByTimeAsync(50);
|
||||
controller.abort();
|
||||
|
||||
await expect(promise).rejects.toMatchObject({
|
||||
name: 'AbortError',
|
||||
});
|
||||
expect(removeEventListenerSpy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('cleans up timeout when aborted before completion', async () => {
|
||||
const clearTimeoutSpy = vi.spyOn(global, 'clearTimeout');
|
||||
const controller = new AbortController();
|
||||
const promise = delay(400, controller.signal);
|
||||
|
||||
await vi.advanceTimersByTimeAsync(50);
|
||||
controller.abort();
|
||||
|
||||
await expect(promise).rejects.toMatchObject({
|
||||
name: 'AbortError',
|
||||
});
|
||||
expect(clearTimeoutSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
48
packages/core/src/utils/delay.ts
Normal file
48
packages/core/src/utils/delay.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* Factory to create a standard abort error for delay helpers.
|
||||
*/
|
||||
export function createAbortError(): Error {
|
||||
const abortError = new Error('Aborted');
|
||||
abortError.name = 'AbortError';
|
||||
return abortError;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a promise that resolves after the provided duration unless aborted.
|
||||
*
|
||||
* @param ms Delay duration in milliseconds.
|
||||
* @param signal Optional abort signal to cancel the wait early.
|
||||
*/
|
||||
export function delay(ms: number, signal?: AbortSignal): Promise<void> {
|
||||
// If no abort signal is provided, set simple delay
|
||||
if (!signal) {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
// Immediately reject if signal has already been aborted
|
||||
if (signal.aborted) {
|
||||
return Promise.reject(createAbortError());
|
||||
}
|
||||
|
||||
// remove abort and timeout listeners to prevent memory-leaks
|
||||
return new Promise((resolve, reject) => {
|
||||
const onAbort = () => {
|
||||
clearTimeout(timeoutId);
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
reject(createAbortError());
|
||||
};
|
||||
|
||||
const timeoutId = setTimeout(() => {
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
resolve();
|
||||
}, ms);
|
||||
|
||||
signal.addEventListener('abort', onAbort, { once: true });
|
||||
});
|
||||
}
|
||||
@@ -500,4 +500,25 @@ describe('retryWithBackoff', () => {
|
||||
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
|
||||
});
|
||||
});
|
||||
it('should abort the retry loop when the signal is aborted', async () => {
|
||||
const abortController = new AbortController();
|
||||
const mockFn = vi.fn().mockImplementation(async () => {
|
||||
const error: HttpError = new Error('Server error');
|
||||
error.status = 500;
|
||||
throw error;
|
||||
});
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 5,
|
||||
initialDelayMs: 100,
|
||||
signal: abortController.signal,
|
||||
});
|
||||
await vi.advanceTimersByTimeAsync(50);
|
||||
abortController.abort();
|
||||
|
||||
await expect(promise).rejects.toThrow(
|
||||
expect.objectContaining({ name: 'AbortError' }),
|
||||
);
|
||||
expect(mockFn).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
isProQuotaExceededError,
|
||||
isGenericQuotaExceededError,
|
||||
} from './quotaErrorDetection.js';
|
||||
import { delay, createAbortError } from './delay.js';
|
||||
|
||||
const FETCH_FAILED_MESSAGE =
|
||||
'exception TypeError: fetch failed sending request';
|
||||
@@ -31,6 +32,7 @@ export interface RetryOptions {
|
||||
) => Promise<string | boolean | null>;
|
||||
authType?: string;
|
||||
retryFetchErrors?: boolean;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
|
||||
@@ -75,15 +77,6 @@ function defaultShouldRetry(
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Delays execution for a specified number of milliseconds.
|
||||
* @param ms The number of milliseconds to delay.
|
||||
* @returns A promise that resolves after the delay.
|
||||
*/
|
||||
function delay(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
/**
|
||||
* Retries a function with exponential backoff and jitter.
|
||||
* @param fn The asynchronous function to retry.
|
||||
@@ -95,6 +88,10 @@ export async function retryWithBackoff<T>(
|
||||
fn: () => Promise<T>,
|
||||
options?: Partial<RetryOptions>,
|
||||
): Promise<T> {
|
||||
if (options?.signal?.aborted) {
|
||||
throw createAbortError();
|
||||
}
|
||||
|
||||
if (options?.maxAttempts !== undefined && options.maxAttempts <= 0) {
|
||||
throw new Error('maxAttempts must be a positive number.');
|
||||
}
|
||||
@@ -112,6 +109,7 @@ export async function retryWithBackoff<T>(
|
||||
shouldRetryOnError,
|
||||
shouldRetryOnContent,
|
||||
retryFetchErrors,
|
||||
signal,
|
||||
} = {
|
||||
...DEFAULT_RETRY_OPTIONS,
|
||||
...cleanOptions,
|
||||
@@ -122,6 +120,9 @@ export async function retryWithBackoff<T>(
|
||||
let consecutive429Count = 0;
|
||||
|
||||
while (attempt < maxAttempts) {
|
||||
if (signal?.aborted) {
|
||||
throw createAbortError();
|
||||
}
|
||||
attempt++;
|
||||
try {
|
||||
const result = await fn();
|
||||
@@ -132,13 +133,17 @@ export async function retryWithBackoff<T>(
|
||||
) {
|
||||
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
|
||||
const delayWithJitter = Math.max(0, currentDelay + jitter);
|
||||
await delay(delayWithJitter);
|
||||
await delay(delayWithJitter, signal);
|
||||
currentDelay = Math.min(maxDelayMs, currentDelay * 2);
|
||||
continue;
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
throw error;
|
||||
}
|
||||
|
||||
const errorStatus = getErrorStatus(error);
|
||||
|
||||
// Check for Pro quota exceeded error first - immediate fallback for OAuth users
|
||||
@@ -243,7 +248,7 @@ export async function retryWithBackoff<T>(
|
||||
`Attempt ${attempt} failed with status ${delayErrorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
|
||||
error,
|
||||
);
|
||||
await delay(delayDurationMs);
|
||||
await delay(delayDurationMs, signal);
|
||||
// Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time
|
||||
currentDelay = initialDelayMs;
|
||||
} else {
|
||||
@@ -252,7 +257,7 @@ export async function retryWithBackoff<T>(
|
||||
// Add jitter: +/- 30% of currentDelay
|
||||
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
|
||||
const delayWithJitter = Math.max(0, currentDelay + jitter);
|
||||
await delay(delayWithJitter);
|
||||
await delay(delayWithJitter, signal);
|
||||
currentDelay = Math.min(maxDelayMs, currentDelay * 2);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user