Improve MALFORMED_FUNCTION_CALL handling (#12965)

This commit is contained in:
Tommaso Sciortino
2025-11-13 08:07:49 -08:00
committed by GitHub
parent 54c1e13853
commit fb03242950
5 changed files with 194 additions and 28 deletions
+96
View File
@@ -680,6 +680,102 @@ describe('GeminiChat', () => {
).resolves.not.toThrow(); ).resolves.not.toThrow();
}); });
it('should throw InvalidStreamError when finishReason is MALFORMED_FUNCTION_CALL', async () => {
  // Arrange: a single chunk whose only candidate has no parts and ends with
  // the MALFORMED_FUNCTION_CALL finish reason.
  const malformedChunk = {
    candidates: [
      {
        content: {
          role: 'model',
          parts: [], // Empty parts
        },
        finishReason: 'MALFORMED_FUNCTION_CALL',
      },
    ],
  } as unknown as GenerateContentResponse;

  async function* malformedStream() {
    yield malformedChunk;
  }

  // The generator is created once here, so every call to the mock resolves to
  // this same stream object.
  vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
    malformedStream(),
  );

  // Act: start the streaming request.
  const responseStream = await chat.sendMessageStream(
    'test-model',
    { message: 'test' },
    'prompt-id-malformed',
  );

  // Assert: draining the stream must reject with InvalidStreamError.
  const drain = async () => {
    for await (const _ of responseStream) {
      // consume stream
    }
  };
  await expect(drain()).rejects.toThrow(InvalidStreamError);
});
it('should retry when finishReason is MALFORMED_FUNCTION_CALL', async () => {
  // First attempt: an empty model turn that ends with MALFORMED_FUNCTION_CALL.
  const malformedAttempt = async function* () {
    yield {
      candidates: [
        {
          content: { parts: [], role: 'model' },
          finishReason: 'MALFORMED_FUNCTION_CALL',
        },
      ],
    } as unknown as GenerateContentResponse;
  };

  // Second attempt: a normal text response that finishes with STOP.
  const successfulAttempt = async function* () {
    yield {
      candidates: [
        {
          content: { parts: [{ text: 'Success after retry' }] },
          finishReason: 'STOP',
        },
      ],
    } as unknown as GenerateContentResponse;
  };

  // Queue the two attempts in order: fail once, then succeed.
  vi.mocked(mockContentGenerator.generateContentStream)
    .mockImplementationOnce(async () => malformedAttempt())
    .mockImplementationOnce(async () => successfulAttempt());

  // Send a message and collect every event the stream emits.
  const stream = await chat.sendMessageStream(
    'test-model',
    { message: 'test retry' },
    'prompt-id-retry-malformed',
  );
  const events: StreamEvent[] = [];
  for await (const event of stream) {
    events.push(event);
  }

  // The generator must have been invoked twice (initial call + one retry).
  expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
    2,
  );

  // A RETRY event must have been surfaced to the consumer.
  const sawRetry = events.some((e) => e.type === StreamEventType.RETRY);
  expect(sawRetry).toBe(true);

  // The text chunk from the successful second attempt must be present.
  const sawSuccessChunk = events.some(
    (e) =>
      e.type === StreamEventType.CHUNK &&
      e.value.candidates?.[0]?.content?.parts?.[0]?.text ===
        'Success after retry',
  );
  expect(sawSuccessChunk).toBe(true);
});
it('should call generateContentStream with the correct parameters', async () => { it('should call generateContentStream with the correct parameters', async () => {
const response = (async function* () { const response = (async function* () {
yield { yield {
+27 -17
View File
@@ -16,7 +16,7 @@ import type {
Tool, Tool,
} from '@google/genai'; } from '@google/genai';
import { toParts } from '../code_assist/converter.js'; import { toParts } from '../code_assist/converter.js';
import { createUserContent } from '@google/genai'; import { createUserContent, FinishReason } from '@google/genai';
import { retryWithBackoff } from '../utils/retry.js'; import { retryWithBackoff } from '../utils/retry.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import { import {
@@ -167,9 +167,15 @@ function extractCuratedHistory(comprehensiveHistory: Content[]): Content[] {
* which should trigger a retry. * which should trigger a retry.
*/ */
export class InvalidStreamError extends Error { export class InvalidStreamError extends Error {
readonly type: 'NO_FINISH_REASON' | 'NO_RESPONSE_TEXT'; readonly type:
| 'NO_FINISH_REASON'
| 'NO_RESPONSE_TEXT'
| 'MALFORMED_FUNCTION_CALL';
constructor(message: string, type: 'NO_FINISH_REASON' | 'NO_RESPONSE_TEXT') { constructor(
message: string,
type: 'NO_FINISH_REASON' | 'NO_RESPONSE_TEXT' | 'MALFORMED_FUNCTION_CALL',
) {
super(message); super(message);
this.name = 'InvalidStreamError'; this.name = 'InvalidStreamError';
this.type = type; this.type = type;
@@ -502,11 +508,16 @@ export class GeminiChat {
const modelResponseParts: Part[] = []; const modelResponseParts: Part[] = [];
let hasToolCall = false; let hasToolCall = false;
let hasFinishReason = false; let finishReason: FinishReason | undefined;
for await (const chunk of streamResponse) { for await (const chunk of streamResponse) {
hasFinishReason = const candidateWithReason = chunk?.candidates?.find(
chunk?.candidates?.some((candidate) => candidate.finishReason) ?? false; (candidate) => candidate.finishReason,
);
if (candidateWithReason) {
finishReason = candidateWithReason.finishReason as FinishReason;
}
if (isValidResponse(chunk)) { if (isValidResponse(chunk)) {
const content = chunk.candidates?.[0]?.content; const content = chunk.candidates?.[0]?.content;
if (content?.parts) { if (content?.parts) {
@@ -564,21 +575,20 @@ export class GeminiChat {
content: responseText, content: responseText,
}); });
} }
if (!hasToolCall) {
// Stream validation logic: A stream is considered successful if: if (!finishReason) {
// 1. There's a tool call (tool calls can end without explicit finish reasons), OR
// 2. There's a finish reason AND we have non-empty response text
//
// We throw an error only when there's no tool call AND:
// - No finish reason, OR
// - Empty response text (e.g., only thoughts with no actual content)
if (!hasToolCall && (!hasFinishReason || !responseText)) {
if (!hasFinishReason) {
throw new InvalidStreamError( throw new InvalidStreamError(
'Model stream ended without a finish reason.', 'Model stream ended without a finish reason.',
'NO_FINISH_REASON', 'NO_FINISH_REASON',
); );
} else { }
if (finishReason === FinishReason.MALFORMED_FUNCTION_CALL) {
throw new InvalidStreamError(
'Model stream ended with malformed function call.',
'MALFORMED_FUNCTION_CALL',
);
}
if (!responseText) {
throw new InvalidStreamError( throw new InvalidStreamError(
'Model stream ended with empty response text.', 'Model stream ended with empty response text.',
'NO_RESPONSE_TEXT', 'NO_RESPONSE_TEXT',
+2 -2
View File
@@ -442,7 +442,7 @@ export function logContentRetry(
attributes: event.toOpenTelemetryAttributes(config), attributes: event.toOpenTelemetryAttributes(config),
}; };
logger.emit(logRecord); logger.emit(logRecord);
recordContentRetry(config); recordContentRetry(config, event.error_type);
} }
export function logContentRetryFailure( export function logContentRetryFailure(
@@ -458,7 +458,7 @@ export function logContentRetryFailure(
attributes: event.toOpenTelemetryAttributes(config), attributes: event.toOpenTelemetryAttributes(config),
}; };
logger.emit(logRecord); logger.emit(logRecord);
recordContentRetryFailure(config); recordContentRetryFailure(config, event.final_error_type);
} }
export function logModelRouting( export function logModelRouting(
@@ -96,6 +96,8 @@ describe('Telemetry Metrics', () => {
let recordAgentRunMetricsModule: typeof import('./metrics.js').recordAgentRunMetrics; let recordAgentRunMetricsModule: typeof import('./metrics.js').recordAgentRunMetrics;
let recordLinesChangedModule: typeof import('./metrics.js').recordLinesChanged; let recordLinesChangedModule: typeof import('./metrics.js').recordLinesChanged;
let recordSlowRenderModule: typeof import('./metrics.js').recordSlowRender; let recordSlowRenderModule: typeof import('./metrics.js').recordSlowRender;
let recordContentRetryModule: typeof import('./metrics.js').recordContentRetry;
let recordContentRetryFailureModule: typeof import('./metrics.js').recordContentRetryFailure;
beforeEach(async () => { beforeEach(async () => {
vi.resetModules(); vi.resetModules();
@@ -140,6 +142,8 @@ describe('Telemetry Metrics', () => {
recordAgentRunMetricsModule = metricsJsModule.recordAgentRunMetrics; recordAgentRunMetricsModule = metricsJsModule.recordAgentRunMetrics;
recordLinesChangedModule = metricsJsModule.recordLinesChanged; recordLinesChangedModule = metricsJsModule.recordLinesChanged;
recordSlowRenderModule = metricsJsModule.recordSlowRender; recordSlowRenderModule = metricsJsModule.recordSlowRender;
recordContentRetryModule = metricsJsModule.recordContentRetry;
recordContentRetryFailureModule = metricsJsModule.recordContentRetryFailure;
const otelApiModule = await import('@opentelemetry/api'); const otelApiModule = await import('@opentelemetry/api');
@@ -1343,4 +1347,50 @@ describe('Telemetry Metrics', () => {
}); });
}); });
}); });
describe('recordContentRetry', () => {
  it('does not record metrics if not initialized', () => {
    // Metrics were never initialized in this test, so the counter must stay untouched.
    const fakeConfig = makeFakeConfig({});

    recordContentRetryModule(fakeConfig, 'NO_FINISH_REASON');

    expect(mockCounterAddFn).not.toHaveBeenCalled();
  });

  it('records a content retry event with error type when initialized', () => {
    const fakeConfig = makeFakeConfig({});
    initializeMetricsModule(fakeConfig);
    // Discard counter calls made during initialization (the session start call).
    mockCounterAddFn.mockClear();

    recordContentRetryModule(fakeConfig, 'MALFORMED_FUNCTION_CALL');

    // The counter is bumped by one, with the error type included in the attributes.
    const expectedAttributes = {
      'session.id': 'test-session-id',
      'installation.id': 'test-installation-id',
      'user.email': 'test@example.com',
      error_type: 'MALFORMED_FUNCTION_CALL',
    };
    expect(mockCounterAddFn).toHaveBeenCalledWith(1, expectedAttributes);
  });
});
describe('recordContentRetryFailure', () => {
  it('does not record metrics if not initialized', () => {
    // Metrics were never initialized in this test, so the counter must stay untouched.
    const fakeConfig = makeFakeConfig({});

    recordContentRetryFailureModule(fakeConfig, 'NO_RESPONSE_TEXT');

    expect(mockCounterAddFn).not.toHaveBeenCalled();
  });

  it('records a content retry failure event with error type when initialized', () => {
    const fakeConfig = makeFakeConfig({});
    initializeMetricsModule(fakeConfig);
    // Discard counter calls made during initialization (the session start call).
    mockCounterAddFn.mockClear();

    recordContentRetryFailureModule(fakeConfig, 'MALFORMED_FUNCTION_CALL');

    // The counter is bumped by one, with the error type included in the attributes.
    const expectedAttributes = {
      'session.id': 'test-session-id',
      'installation.id': 'test-installation-id',
      'user.email': 'test@example.com',
      error_type: 'MALFORMED_FUNCTION_CALL',
    };
    expect(mockCounterAddFn).toHaveBeenCalledWith(1, expectedAttributes);
  });
});
}); });
+19 -9
View File
@@ -136,13 +136,17 @@ const COUNTER_DEFINITIONS = {
description: 'Counts retries due to content errors (e.g., empty stream).', description: 'Counts retries due to content errors (e.g., empty stream).',
valueType: ValueType.INT, valueType: ValueType.INT,
assign: (c: Counter) => (contentRetryCounter = c), assign: (c: Counter) => (contentRetryCounter = c),
attributes: {} as Record<string, never>, attributes: {} as {
error_type: string;
},
}, },
[CONTENT_RETRY_FAILURE_COUNT]: { [CONTENT_RETRY_FAILURE_COUNT]: {
description: 'Counts occurrences of all content retries failing.', description: 'Counts occurrences of all content retries failing.',
valueType: ValueType.INT, valueType: ValueType.INT,
assign: (c: Counter) => (contentRetryFailureCounter = c), assign: (c: Counter) => (contentRetryFailureCounter = c),
attributes: {} as Record<string, never>, attributes: {} as {
error_type: string;
},
}, },
[MODEL_ROUTING_FAILURE_COUNT]: { [MODEL_ROUTING_FAILURE_COUNT]: {
description: 'Counts model routing failures.', description: 'Counts model routing failures.',
@@ -715,20 +719,26 @@ export function recordInvalidChunk(config: Config): void {
/** /**
* Records a metric for when a retry is triggered due to a content error. * Records a metric for when a retry is triggered due to a content error.
*/ */
export function recordContentRetry(config: Config): void { export function recordContentRetry(config: Config, errorType: string): void {
if (!contentRetryCounter || !isMetricsInitialized) return; if (!contentRetryCounter || !isMetricsInitialized) return;
contentRetryCounter.add(1, baseMetricDefinition.getCommonAttributes(config)); contentRetryCounter.add(1, {
...baseMetricDefinition.getCommonAttributes(config),
error_type: errorType,
});
} }
/** /**
* Records a metric for when all content error retries have failed for a request. * Records a metric for when all content error retries have failed for a request.
*/ */
export function recordContentRetryFailure(config: Config): void { export function recordContentRetryFailure(
config: Config,
errorType: string,
): void {
if (!contentRetryFailureCounter || !isMetricsInitialized) return; if (!contentRetryFailureCounter || !isMetricsInitialized) return;
contentRetryFailureCounter.add( contentRetryFailureCounter.add(1, {
1, ...baseMetricDefinition.getCommonAttributes(config),
baseMetricDefinition.getCommonAttributes(config), error_type: errorType,
); });
} }
export function recordModelSlashCommand( export function recordModelSlashCommand(