From f8198a25d8ecaf5e36ed248c60825253b83f6dd7 Mon Sep 17 00:00:00 2001 From: Daniel Weis Date: Mon, 11 May 2026 16:09:38 -0400 Subject: [PATCH] fix(routing): Refactor tool turn handling for the conversation history in NumericalClassifierStrategy to prevent 400 Bad Request (#26761) --- .../numericalClassifierStrategy.test.ts | 237 ++++++++++++++++-- .../strategies/numericalClassifierStrategy.ts | 23 +- 2 files changed, 234 insertions(+), 26 deletions(-) diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index dcfdff786b..f400dfc51b 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -5,7 +5,10 @@ */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { NumericalClassifierStrategy } from './numericalClassifierStrategy.js'; +import { + NumericalClassifierStrategy, + HISTORY_TURNS_FOR_CONTEXT, +} from './numericalClassifierStrategy.js'; import type { RoutingContext } from '../routingStrategy.js'; import type { Config } from '../../config/config.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; @@ -423,18 +426,21 @@ describe('NumericalClassifierStrategy', () => { expect(consoleWarnSpy).toHaveBeenCalled(); }); - it('should include tool-related history when sending to classifier', async () => { - mockContext.history = [ - { role: 'user', parts: [{ text: 'call a tool' }] }, - { role: 'model', parts: [{ functionCall: { name: 'test_tool' } }] }, + it('should strip leading tool turns when history starts with tool calls', async () => { + const history: Content[] = [ + { role: 'model', parts: [{ functionCall: { name: 'leading_tool' } }] }, { role: 'user', parts: [ - { functionResponse: { name: 'test_tool', response: { ok: true } } }, + { + functionResponse: { name: 'leading_tool', response: { ok: true } }, + }, ], }, - { role: 'user', parts: [{ text: 'another user turn' }] }, + { role: 'model', parts: [{ text: 'text response 1' }] }, + { role: 'user', parts: [{ text: 'text request 2' }] }, ]; + mockContext.history = history; const mockApiResponse = { complexity_reasoning: 'Simple.', complexity_score: 10, @@ -454,9 +460,9 @@ describe('NumericalClassifierStrategy', () => { .calls[0][0]; const contents = generateJsonCall.contents; + // Expect leading tool turns (index 0 and 1) to be stripped, keeping only text turns (index 2 and 3) const expectedContents = [ - ...mockContext.history, - // The last user turn is the request part + ...history.slice(2), { role: 'user', parts: [{ text: 'simple task' }], @@ -466,12 +472,25 @@ describe('NumericalClassifierStrategy', () => { expect(contents).toEqual(expectedContents); }); - it('should respect HISTORY_TURNS_FOR_CONTEXT', async () => { - const longHistory: Content[] = []; - for (let i = 0; i < 30; i++) { - longHistory.push({ role: 'user', parts: [{ text: `Message ${i}` }] }); - } - mockContext.history = longHistory; + it('should preserve tool turns when they appear after a non-tool turn in the middle of history', async () => { + const history: Content[] = [ + { role: 'user', parts: [{ text: 'turn 0 (before)' }] }, + { role: 'model', parts: [{ text: 'turn 1 (before)' }] }, + { role: 'user', parts: [{ text: 'turn 2 (before)' }] }, + { role: 'model', parts: [{ text: 'turn 3 (before)' }] }, + { role: 'model', parts: [{ functionCall: { name: 'middle_tool' } }] }, + { + role: 'user', + parts: [ + { functionResponse: { name: 'middle_tool', response: { ok: true } } }, + ], + }, + { role: 'model', parts: [{ text: 'turn 6 (after)' }] }, + { role: 'user', parts: [{ text: 'turn 7 (after)' }] }, + { role: 'model', parts: [{ text: 'turn 8 (after)' }] }, + { role: 'user', parts: [{ text: 'turn 9 (after)' }] }, + ]; + mockContext.history = history; const mockApiResponse = { complexity_reasoning: 'Simple.', complexity_score: 10, @@ -491,18 +510,188 @@ describe('NumericalClassifierStrategy', () => { .calls[0][0]; const contents = generateJsonCall.contents; - // Manually calculate what the history should be - const HISTORY_TURNS_FOR_CONTEXT = 8; - const finalHistory = longHistory.slice(-HISTORY_TURNS_FOR_CONTEXT); + // Expect all 8 sliced turns (starting from non-tool turn 2) to be preserved + const expectedContents = [ + ...history.slice(2), + { + role: 'user', + parts: [{ text: 'simple task' }], + }, + ]; - // Last part is the request - const requestPart = { - role: 'user', - parts: [{ text: 'simple task' }], + expect(contents).toEqual(expectedContents); + }); + + it('should preserve tool turns when they appear at the very end of history following a non-tool turn', async () => { + const history: Content[] = [ + { role: 'user', parts: [{ text: 'turn 0' }] }, + { role: 'model', parts: [{ text: 'turn 1' }] }, + { role: 'user', parts: [{ text: 'turn 2' }] }, + { role: 'model', parts: [{ text: 'turn 3' }] }, + { role: 'user', parts: [{ text: 'turn 4' }] }, + { role: 'model', parts: [{ text: 'turn 5' }] }, + { role: 'user', parts: [{ text: 'turn 6' }] }, + { role: 'model', parts: [{ text: 'turn 7' }] }, + { role: 'model', parts: [{ functionCall: { name: 'end_tool' } }] }, + { + role: 'user', + parts: [ + { functionResponse: { name: 'end_tool', response: { ok: true } } }, + ], + }, + ]; + mockContext.history = history; + const mockApiResponse = { + complexity_reasoning: 'Simple.', + complexity_score: 10, }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); - expect(contents).toEqual([...finalHistory, requestPart]); - expect(contents).toHaveLength(9); + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock + .calls[0][0]; + const contents = generateJsonCall.contents; + + // Expect all 8 sliced turns to be preserved because index 2 is a non-tool turn + const expectedContents = [ + ...history.slice(2), + { + role: 'user', + parts: [{ text: 'simple task' }], + }, + ]; + + expect(contents).toEqual(expectedContents); + }); + + it('should send only the new request prompt if the entire history consists of tool-related turns', async () => { + const history: Content[] = [ + { role: 'model', parts: [{ functionCall: { name: 'tool_A' } }] }, + { + role: 'user', + parts: [ + { functionResponse: { name: 'tool_A', response: { ok: true } } }, + ], + }, + { role: 'model', parts: [{ functionCall: { name: 'tool_B' } }] }, + { + role: 'user', + parts: [ + { functionResponse: { name: 'tool_B', response: { ok: true } } }, + ], + }, + ]; + mockContext.history = history; + const mockApiResponse = { + complexity_reasoning: 'Simple standalone task.', + complexity_score: 10, + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock + .calls[0][0]; + const contents = generateJsonCall.contents; + + // Expect all history turns to be filtered out, leaving exactly just the new request + const expectedContents = [ + { + role: 'user', + parts: [{ text: 'simple task' }], + }, + ]; + + expect(contents).toEqual(expectedContents); + }); + + it('should respect HISTORY_TURNS_FOR_CONTEXT correctly when history has only text turns', async () => { + const history: Content[] = []; + for (let i = 0; i < HISTORY_TURNS_FOR_CONTEXT + 2; i++) { + history.push({ role: 'user', parts: [{ text: `Message ${i}` }] }); + } + mockContext.history = history; + const mockApiResponse = { + complexity_reasoning: 'Simple.', + complexity_score: 10, + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock + .calls[0][0]; + const contents = generateJsonCall.contents; + + // Expect exactly the last 8 turns (history.slice(2)) + expect(contents).toEqual([ + ...history.slice(2), + { role: 'user', parts: [{ text: 'simple task' }] }, + ]); + expect(contents).toHaveLength(HISTORY_TURNS_FOR_CONTEXT + 1); + }); + + it('should respect HISTORY_TURNS_FOR_CONTEXT correctly when history starts with tool calls', async () => { + const history: Content[] = [ + { role: 'model', parts: [{ functionCall: { name: 'tool_0' } }] }, + { + role: 'user', + parts: [ + { functionResponse: { name: 'tool_0', response: { ok: true } } }, + ], + }, + ]; + for (let i = 0; i < HISTORY_TURNS_FOR_CONTEXT; i++) { + history.push({ role: 'user', parts: [{ text: `Message ${i}` }] }); + } + mockContext.history = history; + const mockApiResponse = { + complexity_reasoning: 'Simple.', + complexity_score: 10, + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + mockLocalLiteRtLmClient, + ); + + const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock + .calls[0][0]; + const contents = generateJsonCall.contents; + + // Expect exactly the last 8 text turns (history.slice(2)) + expect(contents).toEqual([ + ...history.slice(2), + { role: 'user', parts: [{ text: 'simple task' }] }, + ]); + expect(contents).toHaveLength(HISTORY_TURNS_FOR_CONTEXT + 1); }); it('should use a fallback promptId if not found in context', async () => { diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index 8bcfb3da67..0e2401c8f1 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -15,12 +15,16 @@ import type { import { resolveClassifierModel, isGemini3Model } from '../../config/models.js'; import { createUserContent, Type } from '@google/genai'; import type { Config } from '../../config/config.js'; +import { + isFunctionCall, + isFunctionResponse, +} from '../../utils/messageInspectors.js'; import { debugLogger } from '../../utils/debugLogger.js'; import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { LlmRole } from '../../telemetry/types.js'; // The number of recent history turns to provide to the router for context. -const HISTORY_TURNS_FOR_CONTEXT = 8; +export const HISTORY_TURNS_FOR_CONTEXT = 8; const FLASH_MODEL = 'flash'; const PRO_MODEL = 'pro'; @@ -115,7 +119,22 @@ export class NumericalClassifierStrategy implements RoutingStrategy { const promptId = getPromptIdWithFallback('classifier-router'); - const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT); + const candidateSlice = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT); + + // Find the first non-tool turn. The server cannot always handle tool-related + // turns in the first slots of the contents array, so we strip them if they appear at the start. + let firstTextIndex = -1; + for (let i = 0; i < candidateSlice.length; i++) { + if ( + !isFunctionCall(candidateSlice[i]) && + !isFunctionResponse(candidateSlice[i]) + ) { + firstTextIndex = i; + break; + } + } + const finalHistory = + firstTextIndex === -1 ? [] : candidateSlice.slice(firstTextIndex); // Wrap the user's request in tags to prevent prompt injection const requestParts = Array.isArray(context.request)