feat(core): Add content-based retries for JSON generation (#9264)

This commit is contained in:
Sandy Tao
2025-09-29 12:27:15 -07:00
committed by GitHub
parent 042288e72c
commit ac4a79223a
7 changed files with 145 additions and 69 deletions
+87 -22
View File
@@ -36,7 +36,29 @@ vi.mock('../utils/errors.js', async (importOriginal) => {
}); });
// Test double for retryWithBackoff. It invokes the wrapped function exactly once and
// never loops. When the caller-supplied shouldRetryOnContent flags the result, it
// throws a simulated "exhausted" error for missing/blank text or for text containing
// the truncated '{"color": "blue"' fixture; otherwise the result is returned unchanged.
vi.mock('../utils/retry.js', () => ({ vi.mock('../utils/retry.js', () => ({
retryWithBackoff: vi.fn(async (fn) => await fn()), retryWithBackoff: vi.fn(async (fn, options) => {
// Default implementation - just call the function
const result = await fn();
// If shouldRetryOnContent is provided, test it but don't actually retry
// (unless we want to simulate retry exhaustion for testing)
if (options?.shouldRetryOnContent) {
const shouldRetry = options.shouldRetryOnContent(result);
if (shouldRetry) {
// Check if we need to simulate retry exhaustion (for error testing)
// Reads the text of the first part of the first candidate; any missing link yields undefined.
const responseText = result?.candidates?.[0]?.content?.parts?.[0]?.text;
// The substring below also occurs inside the valid '{"color": "blue"}' fixture, but this
// branch is only reached after shouldRetryOnContent has already returned true.
if (
!responseText ||
responseText.trim() === '' ||
responseText.includes('{"color": "blue"')
) {
throw new Error('Retry attempts exhausted for invalid content');
}
}
}
// Reached when no predicate was supplied, the predicate accepted the result, or the
// flagged text matched none of the simulated-exhaustion conditions above.
return result;
}),
})); }));
const mockGenerateContent = vi.fn(); const mockGenerateContent = vi.fn();
@@ -96,8 +118,14 @@ describe('BaseLlmClient', () => {
expect(result).toEqual({ color: 'blue' }); expect(result).toEqual({ color: 'blue' });
// Ensure the retry mechanism was engaged // Ensure the retry mechanism was engaged with shouldRetryOnContent
expect(retryWithBackoff).toHaveBeenCalledTimes(1); expect(retryWithBackoff).toHaveBeenCalledTimes(1);
expect(retryWithBackoff).toHaveBeenCalledWith(
expect.any(Function),
expect.objectContaining({
shouldRetryOnContent: expect.any(Function),
}),
);
// Validate the parameters passed to the underlying generator // Validate the parameters passed to the underlying generator
expect(mockGenerateContent).toHaveBeenCalledTimes(1); expect(mockGenerateContent).toHaveBeenCalledTimes(1);
@@ -194,9 +222,12 @@ describe('BaseLlmClient', () => {
await client.generateJson(options); await client.generateJson(options);
expect(retryWithBackoff).toHaveBeenCalledTimes(1); expect(retryWithBackoff).toHaveBeenCalledTimes(1);
expect(retryWithBackoff).toHaveBeenCalledWith(expect.any(Function), { expect(retryWithBackoff).toHaveBeenCalledWith(
maxAttempts: customMaxAttempts, expect.any(Function),
}); expect.objectContaining({
maxAttempts: customMaxAttempts,
}),
);
}); });
it('should call retryWithBackoff without maxAttempts when not provided', async () => { it('should call retryWithBackoff without maxAttempts when not provided', async () => {
@@ -206,9 +237,44 @@ describe('BaseLlmClient', () => {
// No maxAttempts in defaultOptions // No maxAttempts in defaultOptions
await client.generateJson(defaultOptions); await client.generateJson(defaultOptions);
expect(retryWithBackoff).toHaveBeenCalledWith(expect.any(Function), { expect(retryWithBackoff).toHaveBeenCalledWith(
maxAttempts: 5, expect.any(Function),
}); expect.objectContaining({
maxAttempts: 5,
}),
);
});
});
describe('generateJson - Content Validation and Retries', () => {
// Checks two things: that generateJson hands a shouldRetryOnContent predicate to
// retryWithBackoff, and that the predicate itself classifies responses as expected.
it('should validate content using shouldRetryOnContent function', async () => {
const mockResponse = createMockResponse('{"color": "blue"}');
mockGenerateContent.mockResolvedValue(mockResponse);
await client.generateJson(defaultOptions);
// Verify that retryWithBackoff was called with shouldRetryOnContent
expect(retryWithBackoff).toHaveBeenCalledWith(
expect.any(Function),
expect.objectContaining({
shouldRetryOnContent: expect.any(Function),
}),
);
// Test the shouldRetryOnContent function behavior
// The predicate is pulled from the options object (second argument) of the first recorded call.
const retryCall = vi.mocked(retryWithBackoff).mock.calls[0];
const shouldRetryOnContent = retryCall[1]?.shouldRetryOnContent;
// Valid JSON should not trigger retry
expect(shouldRetryOnContent!(mockResponse)).toBe(false);
// Empty response should trigger retry
expect(shouldRetryOnContent!(createMockResponse(''))).toBe(true);
// Invalid JSON should trigger retry
// (the fixture is missing its closing brace)
expect(
shouldRetryOnContent!(createMockResponse('{"color": "blue"')),
).toBe(true);
}); });
}); });
@@ -222,14 +288,14 @@ describe('BaseLlmClient', () => {
const result = await client.generateJson(defaultOptions); const result = await client.generateJson(defaultOptions);
expect(result).toEqual({ color: 'purple' }); expect(result).toEqual({ color: 'purple' });
expect(logMalformedJsonResponse).toHaveBeenCalledTimes(1);
expect(logMalformedJsonResponse).toHaveBeenCalledWith( expect(logMalformedJsonResponse).toHaveBeenCalledWith(
mockConfig, mockConfig,
expect.any(MalformedJsonResponseEvent), expect.any(MalformedJsonResponseEvent),
); );
// Validate the telemetry event content // Validate the telemetry event content - find the most recent call
const event = vi.mocked(logMalformedJsonResponse).mock const calls = vi.mocked(logMalformedJsonResponse).mock.calls;
.calls[0][1] as MalformedJsonResponseEvent; const lastCall = calls[calls.length - 1];
const event = lastCall[1] as MalformedJsonResponseEvent;
expect(event.model).toBe('test-model'); expect(event.model).toBe('test-model');
}); });
@@ -247,38 +313,37 @@ describe('BaseLlmClient', () => {
}); });
describe('generateJson - Error Handling', () => { describe('generateJson - Error Handling', () => {
it('should throw and report error for empty response', async () => { it('should throw and report error for empty response after retry exhaustion', async () => {
mockGenerateContent.mockResolvedValue(createMockResponse('')); mockGenerateContent.mockResolvedValue(createMockResponse(''));
// The final error message includes the prefix added by the client's outer catch block.
await expect(client.generateJson(defaultOptions)).rejects.toThrow( await expect(client.generateJson(defaultOptions)).rejects.toThrow(
'Failed to generate JSON content: API returned an empty response for generateJson.', 'Failed to generate JSON content: Retry attempts exhausted for invalid content',
); );
// Verify error reporting details // Verify error reporting details
expect(reportError).toHaveBeenCalledTimes(1); expect(reportError).toHaveBeenCalledTimes(1);
expect(reportError).toHaveBeenCalledWith( expect(reportError).toHaveBeenCalledWith(
expect.any(Error), expect.any(Error),
'Error in generateJson: API returned an empty response.', 'API returned invalid content (empty or unparsable JSON) after all retries.',
defaultOptions.contents, defaultOptions.contents,
'generateJson-empty-response', 'generateJson-invalid-content',
); );
}); });
it('should throw and report error for invalid JSON syntax', async () => { it('should throw and report error for invalid JSON syntax after retry exhaustion', async () => {
const invalidJson = '{"color": "blue"'; // missing closing brace const invalidJson = '{"color": "blue"'; // missing closing brace
mockGenerateContent.mockResolvedValue(createMockResponse(invalidJson)); mockGenerateContent.mockResolvedValue(createMockResponse(invalidJson));
await expect(client.generateJson(defaultOptions)).rejects.toThrow( await expect(client.generateJson(defaultOptions)).rejects.toThrow(
/^Failed to generate JSON content: Failed to parse API response as JSON:/, 'Failed to generate JSON content: Retry attempts exhausted for invalid content',
); );
expect(reportError).toHaveBeenCalledTimes(1); expect(reportError).toHaveBeenCalledTimes(1);
expect(reportError).toHaveBeenCalledWith( expect(reportError).toHaveBeenCalledWith(
expect.any(Error), expect.any(Error),
'Failed to parse JSON response from generateJson.', 'API returned invalid content (empty or unparsable JSON) after all retries.',
expect.objectContaining({ responseTextFailedToParse: invalidJson }), defaultOptions.contents,
'generateJson-parse', 'generateJson-invalid-content',
); );
}); });
+27 -36
View File
@@ -9,6 +9,7 @@ import type {
GenerateContentConfig, GenerateContentConfig,
Part, Part,
EmbedContentParameters, EmbedContentParameters,
GenerateContentResponse,
} from '@google/genai'; } from '@google/genai';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import type { ContentGenerator } from './contentGenerator.js'; import type { ContentGenerator } from './contentGenerator.js';
@@ -107,54 +108,44 @@ export class BaseLlmClient {
promptId, promptId,
); );
// Content predicate for the retry helper: answers "should this response be
// retried?". A response is retried when it carries no text, or when its cleaned
// text cannot be parsed as JSON. The parse is for validation only; its value is
// discarded.
const shouldRetryOnContent = (response: GenerateContentResponse) => {
  const candidateText = getResponseText(response)?.trim();
  if (!candidateText) {
    // Missing or blank text: retry on empty response.
    return true;
  }

  let needsRetry = false;
  try {
    JSON.parse(this.cleanJsonResponse(candidateText, model));
  } catch (_e) {
    // Any failure while cleaning or parsing marks the content as unusable.
    needsRetry = true;
  }
  return needsRetry;
};
const result = await retryWithBackoff(apiCall, { const result = await retryWithBackoff(apiCall, {
shouldRetryOnContent,
maxAttempts: maxAttempts ?? DEFAULT_MAX_ATTEMPTS, maxAttempts: maxAttempts ?? DEFAULT_MAX_ATTEMPTS,
}); });
let text = getResponseText(result)?.trim(); // If we are here, the content is valid (not empty and parsable).
if (!text) { return JSON.parse(
const error = new Error( this.cleanJsonResponse(getResponseText(result)!.trim(), model),
'API returned an empty response for generateJson.', );
);
await reportError(
error,
'Error in generateJson: API returned an empty response.',
contents,
'generateJson-empty-response',
);
throw error;
}
text = this.cleanJsonResponse(text, model);
try {
return JSON.parse(text);
} catch (parseError) {
const error = new Error(
`Failed to parse API response as JSON: ${getErrorMessage(parseError)}`,
);
await reportError(
parseError,
'Failed to parse JSON response from generateJson.',
{
responseTextFailedToParse: text,
originalRequestContents: contents,
},
'generateJson-parse',
);
throw error;
}
} catch (error) { } catch (error) {
if (abortSignal.aborted) { if (abortSignal.aborted) {
throw error; throw error;
} }
// Check if the error is from exhausting retries, and report accordingly.
if ( if (
error instanceof Error && error instanceof Error &&
(error.message === 'API returned an empty response for generateJson.' || error.message.includes('Retry attempts exhausted')
error.message.startsWith('Failed to parse API response as JSON:'))
) { ) {
// We perform this check so that we don't report these again. await reportError(
error,
'API returned invalid content (empty or unparsable JSON) after all retries.',
contents,
'generateJson-invalid-content',
);
} else { } else {
await reportError( await reportError(
error, error,
+4 -1
View File
@@ -911,7 +911,10 @@ describe('GeminiChat', () => {
try { try {
return await apiCall(); return await apiCall();
} catch (error) { } catch (error) {
if (options?.shouldRetry && options.shouldRetry(error)) { if (
options?.shouldRetryOnError &&
options.shouldRetryOnError(error)
) {
// Try again // Try again
return await apiCall(); return await apiCall();
} }
+1 -1
View File
@@ -376,7 +376,7 @@ export class GeminiChat {
) => await handleFallback(this.config, model, authType, error); ) => await handleFallback(this.config, model, authType, error);
const streamResponse = await retryWithBackoff(apiCall, { const streamResponse = await retryWithBackoff(apiCall, {
shouldRetry: (error: unknown) => { shouldRetryOnError: (error: unknown) => {
if (error instanceof ApiError && error.message) { if (error instanceof ApiError && error.message) {
if (error.status === 400) return false; if (error.status === 400) return false;
if (isSchemaDepthError(error.message)) return false; if (isSchemaDepthError(error.message)) return false;
@@ -86,7 +86,7 @@ describe('Retry Utility Fallback Integration', () => {
maxAttempts: 2, maxAttempts: 2,
initialDelayMs: 1, initialDelayMs: 1,
maxDelayMs: 10, maxDelayMs: 10,
shouldRetry: (error: Error) => { shouldRetryOnError: (error: Error) => {
const status = (error as Error & { status?: number }).status; const status = (error as Error & { status?: number }).status;
return status === 429; return status === 429;
}, },
@@ -123,7 +123,7 @@ describe('Retry Utility Fallback Integration', () => {
maxAttempts: 5, maxAttempts: 5,
initialDelayMs: 10, initialDelayMs: 10,
maxDelayMs: 100, maxDelayMs: 100,
shouldRetry: (error: Error) => { shouldRetryOnError: (error: Error) => {
const status = (error as Error & { status?: number }).status; const status = (error as Error & { status?: number }).status;
return status === 429; return status === 429;
}, },
+3 -2
View File
@@ -137,10 +137,11 @@ describe('retryWithBackoff', () => {
const mockFn = vi.fn(async () => { const mockFn = vi.fn(async () => {
throw new NonRetryableError('Non-retryable error'); throw new NonRetryableError('Non-retryable error');
}); });
const shouldRetry = (error: Error) => !(error instanceof NonRetryableError); const shouldRetryOnError = (error: Error) =>
!(error instanceof NonRetryableError);
const promise = retryWithBackoff(mockFn, { const promise = retryWithBackoff(mockFn, {
shouldRetry, shouldRetryOnError,
initialDelayMs: 10, initialDelayMs: 10,
}); });
+21 -5
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { GenerateContentResponse } from '@google/genai';
import { AuthType } from '../core/contentGenerator.js'; import { AuthType } from '../core/contentGenerator.js';
import { import {
isProQuotaExceededError, isProQuotaExceededError,
@@ -18,7 +19,8 @@ export interface RetryOptions {
maxAttempts: number; maxAttempts: number;
initialDelayMs: number; initialDelayMs: number;
maxDelayMs: number; maxDelayMs: number;
shouldRetry: (error: Error) => boolean; shouldRetryOnError: (error: Error) => boolean;
shouldRetryOnContent?: (content: GenerateContentResponse) => boolean;
onPersistent429?: ( onPersistent429?: (
authType?: string, authType?: string,
error?: unknown, error?: unknown,
@@ -30,7 +32,7 @@ const DEFAULT_RETRY_OPTIONS: RetryOptions = {
maxAttempts: 5, maxAttempts: 5,
initialDelayMs: 5000, initialDelayMs: 5000,
maxDelayMs: 30000, // 30 seconds maxDelayMs: 30000, // 30 seconds
shouldRetry: defaultShouldRetry, shouldRetryOnError: defaultShouldRetry,
}; };
/** /**
@@ -88,7 +90,8 @@ export async function retryWithBackoff<T>(
maxDelayMs, maxDelayMs,
onPersistent429, onPersistent429,
authType, authType,
shouldRetry, shouldRetryOnError,
shouldRetryOnContent,
} = { } = {
...DEFAULT_RETRY_OPTIONS, ...DEFAULT_RETRY_OPTIONS,
...cleanOptions, ...cleanOptions,
@@ -101,7 +104,20 @@ export async function retryWithBackoff<T>(
while (attempt < maxAttempts) { while (attempt < maxAttempts) {
attempt++; attempt++;
try { try {
return await fn(); const result = await fn();
if (
shouldRetryOnContent &&
shouldRetryOnContent(result as GenerateContentResponse)
) {
const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1);
const delayWithJitter = Math.max(0, currentDelay + jitter);
await delay(delayWithJitter);
currentDelay = Math.min(maxDelayMs, currentDelay * 2);
continue;
}
return result;
} catch (error) { } catch (error) {
const errorStatus = getErrorStatus(error); const errorStatus = getErrorStatus(error);
@@ -191,7 +207,7 @@ export async function retryWithBackoff<T>(
} }
// Check if we've exhausted retries or shouldn't retry // Check if we've exhausted retries or shouldn't retry
if (attempt >= maxAttempts || !shouldRetry(error as Error)) { if (attempt >= maxAttempts || !shouldRetryOnError(error as Error)) {
throw error; throw error;
} }