mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-30 15:04:16 -07:00
chore: do not retry the model request if the user has aborted the request (#11224)
This commit is contained in:
@@ -383,6 +383,7 @@ export class GeminiChat {
|
|||||||
onPersistent429: onPersistent429Callback,
|
onPersistent429: onPersistent429Callback,
|
||||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||||
|
signal: params.config?.abortSignal,
|
||||||
});
|
});
|
||||||
|
|
||||||
return this.processStreamResponse(model, streamResponse);
|
return this.processStreamResponse(model, streamResponse);
|
||||||
|
|||||||
@@ -0,0 +1,112 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest';
|
||||||
|
import { delay } from './delay.js';
|
||||||
|
|
||||||
|
describe('abortableDelay', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.useFakeTimers();
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.useRealTimers();
|
||||||
|
vi.restoreAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('resolves after the specified duration without a signal', async () => {
|
||||||
|
const promise = delay(100);
|
||||||
|
await vi.advanceTimersByTimeAsync(100);
|
||||||
|
await expect(promise).resolves.toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('resolves when a non-aborted signal is provided', async () => {
|
||||||
|
const controller = new AbortController();
|
||||||
|
const promise = delay(200, controller.signal);
|
||||||
|
|
||||||
|
await vi.advanceTimersByTimeAsync(200);
|
||||||
|
|
||||||
|
await expect(promise).resolves.toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('rejects immediately if the signal is already aborted', async () => {
|
||||||
|
const controller = new AbortController();
|
||||||
|
controller.abort();
|
||||||
|
|
||||||
|
await expect(delay(50, controller.signal)).rejects.toMatchObject({
|
||||||
|
name: 'AbortError',
|
||||||
|
message: 'Aborted',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('rejects if the signal aborts while waiting', async () => {
|
||||||
|
const controller = new AbortController();
|
||||||
|
const promise = delay(500, controller.signal);
|
||||||
|
|
||||||
|
await vi.advanceTimersByTimeAsync(100);
|
||||||
|
controller.abort();
|
||||||
|
|
||||||
|
await expect(promise).rejects.toMatchObject({
|
||||||
|
name: 'AbortError',
|
||||||
|
message: 'Aborted',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('cleans up signal listeners after resolving', async () => {
|
||||||
|
const removeEventListener = vi.fn();
|
||||||
|
const mockSignal = {
|
||||||
|
aborted: false,
|
||||||
|
addEventListener: vi
|
||||||
|
.fn()
|
||||||
|
.mockImplementation((_type: string, listener: () => void) => {
|
||||||
|
mockSignal.__listener = listener;
|
||||||
|
}),
|
||||||
|
removeEventListener,
|
||||||
|
__listener: undefined as (() => void) | undefined,
|
||||||
|
} as unknown as AbortSignal & { __listener?: () => void };
|
||||||
|
|
||||||
|
const promise = delay(150, mockSignal);
|
||||||
|
await vi.advanceTimersByTimeAsync(150);
|
||||||
|
await promise;
|
||||||
|
|
||||||
|
expect(mockSignal.addEventListener).toHaveBeenCalledTimes(1);
|
||||||
|
expect(removeEventListener).toHaveBeenCalledTimes(1);
|
||||||
|
expect(removeEventListener.mock.calls[0][1]).toBe(mockSignal.__listener);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Technically unnecessary due to `onceTrue` but good sanity check
|
||||||
|
it('cleans up signal listeners when aborted before completion', async () => {
|
||||||
|
const controller = new AbortController();
|
||||||
|
const removeEventListenerSpy = vi.spyOn(
|
||||||
|
controller.signal,
|
||||||
|
'removeEventListener',
|
||||||
|
);
|
||||||
|
|
||||||
|
const promise = delay(400, controller.signal);
|
||||||
|
|
||||||
|
await vi.advanceTimersByTimeAsync(50);
|
||||||
|
controller.abort();
|
||||||
|
|
||||||
|
await expect(promise).rejects.toMatchObject({
|
||||||
|
name: 'AbortError',
|
||||||
|
});
|
||||||
|
expect(removeEventListenerSpy).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('cleans up timeout when aborted before completion', async () => {
|
||||||
|
const clearTimeoutSpy = vi.spyOn(global, 'clearTimeout');
|
||||||
|
const controller = new AbortController();
|
||||||
|
const promise = delay(400, controller.signal);
|
||||||
|
|
||||||
|
await vi.advanceTimersByTimeAsync(50);
|
||||||
|
controller.abort();
|
||||||
|
|
||||||
|
await expect(promise).rejects.toMatchObject({
|
||||||
|
name: 'AbortError',
|
||||||
|
});
|
||||||
|
expect(clearTimeoutSpy).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Factory to create a standard abort error for delay helpers.
|
||||||
|
*/
|
||||||
|
export function createAbortError(): Error {
|
||||||
|
const abortError = new Error('Aborted');
|
||||||
|
abortError.name = 'AbortError';
|
||||||
|
return abortError;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a promise that resolves after the provided duration unless aborted.
|
||||||
|
*
|
||||||
|
* @param ms Delay duration in milliseconds.
|
||||||
|
* @param signal Optional abort signal to cancel the wait early.
|
||||||
|
*/
|
||||||
|
export function delay(ms: number, signal?: AbortSignal): Promise<void> {
|
||||||
|
// If no abort signal is provided, set simple delay
|
||||||
|
if (!signal) {
|
||||||
|
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Immediately reject if signal has already been aborted
|
||||||
|
if (signal.aborted) {
|
||||||
|
return Promise.reject(createAbortError());
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove abort and timeout listeners to prevent memory-leaks
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const onAbort = () => {
|
||||||
|
clearTimeout(timeoutId);
|
||||||
|
signal.removeEventListener('abort', onAbort);
|
||||||
|
reject(createAbortError());
|
||||||
|
};
|
||||||
|
|
||||||
|
const timeoutId = setTimeout(() => {
|
||||||
|
signal.removeEventListener('abort', onAbort);
|
||||||
|
resolve();
|
||||||
|
}, ms);
|
||||||
|
|
||||||
|
signal.addEventListener('abort', onAbort, { once: true });
|
||||||
|
});
|
||||||
|
}
|
||||||
@@ -500,4 +500,25 @@ describe('retryWithBackoff', () => {
|
|||||||
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
|
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
it('should abort the retry loop when the signal is aborted', async () => {
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const mockFn = vi.fn().mockImplementation(async () => {
|
||||||
|
const error: HttpError = new Error('Server error');
|
||||||
|
error.status = 500;
|
||||||
|
throw error;
|
||||||
|
});
|
||||||
|
|
||||||
|
const promise = retryWithBackoff(mockFn, {
|
||||||
|
maxAttempts: 5,
|
||||||
|
initialDelayMs: 100,
|
||||||
|
signal: abortController.signal,
|
||||||
|
});
|
||||||
|
await vi.advanceTimersByTimeAsync(50);
|
||||||
|
abortController.abort();
|
||||||
|
|
||||||
|
await expect(promise).rejects.toThrow(
|
||||||
|
expect.objectContaining({ name: 'AbortError' }),
|
||||||
|
);
|
||||||
|
expect(mockFn).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import {
|
|||||||
isProQuotaExceededError,
|
isProQuotaExceededError,
|
||||||
isGenericQuotaExceededError,
|
isGenericQuotaExceededError,
|
||||||
} from './quotaErrorDetection.js';
|
} from './quotaErrorDetection.js';
|
||||||
|
import { delay, createAbortError } from './delay.js';
|
||||||
|
|
||||||
const FETCH_FAILED_MESSAGE =
|
const FETCH_FAILED_MESSAGE =
|
||||||
'exception TypeError: fetch failed sending request';
|
'exception TypeError: fetch failed sending request';
|
||||||
@@ -31,6 +32,7 @@ export interface RetryOptions {
|
|||||||
) => Promise<string | boolean | null>;
|
) => Promise<string | boolean | null>;
|
||||||
authType?: string;
|
authType?: string;
|
||||||
retryFetchErrors?: boolean;
|
retryFetchErrors?: boolean;
|
||||||
|
signal?: AbortSignal;
|
||||||
}
|
}
|
||||||
|
|
||||||
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
|
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
|
||||||
@@ -75,15 +77,6 @@ function defaultShouldRetry(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Delays execution for a specified number of milliseconds.
|
|
||||||
* @param ms The number of milliseconds to delay.
|
|
||||||
* @returns A promise that resolves after the delay.
|
|
||||||
*/
|
|
||||||
function delay(ms: number): Promise<void> {
|
|
||||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retries a function with exponential backoff and jitter.
|
* Retries a function with exponential backoff and jitter.
|
||||||
* @param fn The asynchronous function to retry.
|
* @param fn The asynchronous function to retry.
|
||||||
@@ -95,6 +88,10 @@ export async function retryWithBackoff<T>(
|
|||||||
fn: () => Promise<T>,
|
fn: () => Promise<T>,
|
||||||
options?: Partial<RetryOptions>,
|
options?: Partial<RetryOptions>,
|
||||||
): Promise<T> {
|
): Promise<T> {
|
||||||
|
if (options?.signal?.aborted) {
|
||||||
|
throw createAbortError();
|
||||||
|
}
|
||||||
|
|
||||||
if (options?.maxAttempts !== undefined && options.maxAttempts <= 0) {
|
if (options?.maxAttempts !== undefined && options.maxAttempts <= 0) {
|
||||||
throw new Error('maxAttempts must be a positive number.');
|
throw new Error('maxAttempts must be a positive number.');
|
||||||
}
|
}
|
||||||
@@ -112,6 +109,7 @@ export async function retryWithBackoff<T>(
|
|||||||
shouldRetryOnError,
|
shouldRetryOnError,
|
||||||
shouldRetryOnContent,
|
shouldRetryOnContent,
|
||||||
retryFetchErrors,
|
retryFetchErrors,
|
||||||
|
signal,
|
||||||
} = {
|
} = {
|
||||||
...DEFAULT_RETRY_OPTIONS,
|
...DEFAULT_RETRY_OPTIONS,
|
||||||
...cleanOptions,
|
...cleanOptions,
|
||||||
@@ -122,6 +120,9 @@ export async function retryWithBackoff<T>(
|
|||||||
let consecutive429Count = 0;
|
let consecutive429Count = 0;
|
||||||
|
|
||||||
while (attempt < maxAttempts) {
|
while (attempt < maxAttempts) {
|
||||||
|
if (signal?.aborted) {
|
||||||
|
throw createAbortError();
|
||||||
|
}
|
||||||
attempt++;
|
attempt++;
|
||||||
try {
|
try {
|
||||||
const result = await fn();
|
const result = await fn();
|
||||||
@@ -132,13 +133,17 @@ export async function retryWithBackoff<T>(
|
|||||||
) {
|
) {
|
||||||
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
|
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
|
||||||
const delayWithJitter = Math.max(0, currentDelay + jitter);
|
const delayWithJitter = Math.max(0, currentDelay + jitter);
|
||||||
await delay(delayWithJitter);
|
await delay(delayWithJitter, signal);
|
||||||
currentDelay = Math.min(maxDelayMs, currentDelay * 2);
|
currentDelay = Math.min(maxDelayMs, currentDelay * 2);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (error instanceof Error && error.name === 'AbortError') {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
const errorStatus = getErrorStatus(error);
|
const errorStatus = getErrorStatus(error);
|
||||||
|
|
||||||
// Check for Pro quota exceeded error first - immediate fallback for OAuth users
|
// Check for Pro quota exceeded error first - immediate fallback for OAuth users
|
||||||
@@ -243,7 +248,7 @@ export async function retryWithBackoff<T>(
|
|||||||
`Attempt ${attempt} failed with status ${delayErrorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
|
`Attempt ${attempt} failed with status ${delayErrorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
|
||||||
error,
|
error,
|
||||||
);
|
);
|
||||||
await delay(delayDurationMs);
|
await delay(delayDurationMs, signal);
|
||||||
// Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time
|
// Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time
|
||||||
currentDelay = initialDelayMs;
|
currentDelay = initialDelayMs;
|
||||||
} else {
|
} else {
|
||||||
@@ -252,7 +257,7 @@ export async function retryWithBackoff<T>(
|
|||||||
// Add jitter: +/- 30% of currentDelay
|
// Add jitter: +/- 30% of currentDelay
|
||||||
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
|
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
|
||||||
const delayWithJitter = Math.max(0, currentDelay + jitter);
|
const delayWithJitter = Math.max(0, currentDelay + jitter);
|
||||||
await delay(delayWithJitter);
|
await delay(delayWithJitter, signal);
|
||||||
currentDelay = Math.min(maxDelayMs, currentDelay * 2);
|
currentDelay = Math.min(maxDelayMs, currentDelay * 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user