diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index a090728200..f4dc1c552c 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -383,6 +383,7 @@ export class GeminiChat { onPersistent429: onPersistent429Callback, authType: this.config.getContentGeneratorConfig()?.authType, retryFetchErrors: this.config.getRetryFetchErrors(), + signal: params.config?.abortSignal, }); return this.processStreamResponse(model, streamResponse); diff --git a/packages/core/src/utils/delay.test.ts b/packages/core/src/utils/delay.test.ts new file mode 100644 index 0000000000..2f9a4d0702 --- /dev/null +++ b/packages/core/src/utils/delay.test.ts @@ -0,0 +1,112 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { beforeEach, afterEach, describe, expect, it, vi } from 'vitest'; +import { delay } from './delay.js'; + +describe('abortableDelay', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + vi.restoreAllMocks(); + }); + + it('resolves after the specified duration without a signal', async () => { + const promise = delay(100); + await vi.advanceTimersByTimeAsync(100); + await expect(promise).resolves.toBeUndefined(); + }); + + it('resolves when a non-aborted signal is provided', async () => { + const controller = new AbortController(); + const promise = delay(200, controller.signal); + + await vi.advanceTimersByTimeAsync(200); + + await expect(promise).resolves.toBeUndefined(); + }); + + it('rejects immediately if the signal is already aborted', async () => { + const controller = new AbortController(); + controller.abort(); + + await expect(delay(50, controller.signal)).rejects.toMatchObject({ + name: 'AbortError', + message: 'Aborted', + }); + }); + + it('rejects if the signal aborts while waiting', async () => { + const controller = new AbortController(); + const promise = delay(500, controller.signal); + + await vi.advanceTimersByTimeAsync(100); + controller.abort(); + + await expect(promise).rejects.toMatchObject({ + name: 'AbortError', + message: 'Aborted', + }); + }); + + it('cleans up signal listeners after resolving', async () => { + const removeEventListener = vi.fn(); + const mockSignal = { + aborted: false, + addEventListener: vi + .fn() + .mockImplementation((_type: string, listener: () => void) => { + mockSignal.__listener = listener; + }), + removeEventListener, + __listener: undefined as (() => void) | undefined, + } as unknown as AbortSignal & { __listener?: () => void }; + + const promise = delay(150, mockSignal); + await vi.advanceTimersByTimeAsync(150); + await promise; + + expect(mockSignal.addEventListener).toHaveBeenCalledTimes(1); + expect(removeEventListener).toHaveBeenCalledTimes(1); + expect(removeEventListener.mock.calls[0][1]).toBe(mockSignal.__listener); + }); + + // Technically unnecessary due to `onceTrue` but good sanity check + it('cleans up signal listeners when aborted before completion', async () => { + const controller = new AbortController(); + const removeEventListenerSpy = vi.spyOn( + controller.signal, + 'removeEventListener', + ); + + const promise = delay(400, controller.signal); + + await vi.advanceTimersByTimeAsync(50); + controller.abort(); + + await expect(promise).rejects.toMatchObject({ + name: 'AbortError', + }); + expect(removeEventListenerSpy).toHaveBeenCalledTimes(1); + }); + + it('cleans up timeout when aborted before completion', async () => { + const clearTimeoutSpy = vi.spyOn(global, 'clearTimeout'); + const controller = new AbortController(); + const promise = delay(400, controller.signal); + + await vi.advanceTimersByTimeAsync(50); + controller.abort(); + + await expect(promise).rejects.toMatchObject({ + name: 'AbortError', + }); + expect(clearTimeoutSpy).toHaveBeenCalled(); + }); +}); diff --git a/packages/core/src/utils/delay.ts b/packages/core/src/utils/delay.ts new file mode 100644 index 0000000000..d48db4951b --- /dev/null +++ b/packages/core/src/utils/delay.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Factory to create a standard abort error for delay helpers. + */ +export function createAbortError(): Error { + const abortError = new Error('Aborted'); + abortError.name = 'AbortError'; + return abortError; +} + +/** + * Returns a promise that resolves after the provided duration unless aborted. + * + * @param ms Delay duration in milliseconds. + * @param signal Optional abort signal to cancel the wait early. + */ +export function delay(ms: number, signal?: AbortSignal): Promise { + // If no abort signal is provided, set simple delay + if (!signal) { + return new Promise((resolve) => setTimeout(resolve, ms)); + } + + // Immediately reject if signal has already been aborted + if (signal.aborted) { + return Promise.reject(createAbortError()); + } + + // remove abort and timeout listeners to prevent memory-leaks + return new Promise((resolve, reject) => { + const onAbort = () => { + clearTimeout(timeoutId); + signal.removeEventListener('abort', onAbort); + reject(createAbortError()); + }; + + const timeoutId = setTimeout(() => { + signal.removeEventListener('abort', onAbort); + resolve(); + }, ms); + + signal.addEventListener('abort', onAbort, { once: true }); + }); +} diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts index 9461b39b69..6a05223771 100644 --- a/packages/core/src/utils/retry.test.ts +++ b/packages/core/src/utils/retry.test.ts @@ -500,4 +500,25 @@ describe('retryWithBackoff', () => { expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal'); }); }); + it('should abort the retry loop when the signal is aborted', async () => { + const abortController = new AbortController(); + const mockFn = vi.fn().mockImplementation(async () => { + const error: HttpError = new Error('Server error'); + error.status = 500; + throw error; + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 5, + initialDelayMs: 100, + signal: abortController.signal, + }); + await vi.advanceTimersByTimeAsync(50); + abortController.abort(); + + await expect(promise).rejects.toThrow( + expect.objectContaining({ name: 'AbortError' }), + ); + expect(mockFn).toHaveBeenCalledTimes(1); + }); }); diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts index 007874f965..0f42003d1a 100644 --- a/packages/core/src/utils/retry.ts +++ b/packages/core/src/utils/retry.ts @@ -11,6 +11,7 @@ import { isProQuotaExceededError, isGenericQuotaExceededError, } from './quotaErrorDetection.js'; +import { delay, createAbortError } from './delay.js'; const FETCH_FAILED_MESSAGE = 'exception TypeError: fetch failed sending request'; @@ -31,6 +32,7 @@ export interface RetryOptions { ) => Promise; authType?: string; retryFetchErrors?: boolean; + signal?: AbortSignal; } const DEFAULT_RETRY_OPTIONS: RetryOptions = { @@ -75,15 +77,6 @@ function defaultShouldRetry( return false; } -/** - * Delays execution for a specified number of milliseconds. - * @param ms The number of milliseconds to delay. - * @returns A promise that resolves after the delay. - */ -function delay(ms: number): Promise { - return new Promise((resolve) => setTimeout(resolve, ms)); -} - /** * Retries a function with exponential backoff and jitter. * @param fn The asynchronous function to retry. @@ -95,6 +88,10 @@ export async function retryWithBackoff( fn: () => Promise, options?: Partial, ): Promise { + if (options?.signal?.aborted) { + throw createAbortError(); + } + if (options?.maxAttempts !== undefined && options.maxAttempts <= 0) { throw new Error('maxAttempts must be a positive number.'); } @@ -112,6 +109,7 @@ export async function retryWithBackoff( shouldRetryOnError, shouldRetryOnContent, retryFetchErrors, + signal, } = { ...DEFAULT_RETRY_OPTIONS, ...cleanOptions, @@ -122,6 +120,9 @@ export async function retryWithBackoff( let consecutive429Count = 0; while (attempt < maxAttempts) { + if (signal?.aborted) { + throw createAbortError(); + } attempt++; try { const result = await fn(); @@ -132,13 +133,17 @@ export async function retryWithBackoff( ) { const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1); const delayWithJitter = Math.max(0, currentDelay + jitter); - await delay(delayWithJitter); + await delay(delayWithJitter, signal); currentDelay = Math.min(maxDelayMs, currentDelay * 2); continue; } return result; } catch (error) { + if (error instanceof Error && error.name === 'AbortError') { + throw error; + } + const errorStatus = getErrorStatus(error); // Check for Pro quota exceeded error first - immediate fallback for OAuth users @@ -243,7 +248,7 @@ export async function retryWithBackoff( `Attempt ${attempt} failed with status ${delayErrorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`, error, ); - await delay(delayDurationMs); + await delay(delayDurationMs, signal); // Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time currentDelay = initialDelayMs; } else { @@ -252,7 +257,7 @@ export async function retryWithBackoff( // Add jitter: +/- 30% of currentDelay const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1); const delayWithJitter = Math.max(0, currentDelay + jitter); - await delay(delayWithJitter); + await delay(delayWithJitter, signal); currentDelay = Math.min(maxDelayMs, currentDelay * 2); } }