fix(routing): Refactor tool turn handling for the conversation history in NumericalClassifierStrategy to prevent 400 Bad Request (#26761)

This commit is contained in:
Daniel Weis
2026-05-11 16:09:38 -04:00
committed by GitHub
parent 36a7fa089c
commit f8198a25d8
2 changed files with 234 additions and 26 deletions
@@ -5,7 +5,10 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { NumericalClassifierStrategy } from './numericalClassifierStrategy.js';
import {
NumericalClassifierStrategy,
HISTORY_TURNS_FOR_CONTEXT,
} from './numericalClassifierStrategy.js';
import type { RoutingContext } from '../routingStrategy.js';
import type { Config } from '../../config/config.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
@@ -423,18 +426,21 @@ describe('NumericalClassifierStrategy', () => {
expect(consoleWarnSpy).toHaveBeenCalled();
});
it('should include tool-related history when sending to classifier', async () => {
mockContext.history = [
{ role: 'user', parts: [{ text: 'call a tool' }] },
{ role: 'model', parts: [{ functionCall: { name: 'test_tool' } }] },
it('should strip leading tool turns when history starts with tool calls', async () => {
const history: Content[] = [
{ role: 'model', parts: [{ functionCall: { name: 'leading_tool' } }] },
{
role: 'user',
parts: [
{ functionResponse: { name: 'test_tool', response: { ok: true } } },
{
functionResponse: { name: 'leading_tool', response: { ok: true } },
},
],
},
{ role: 'user', parts: [{ text: 'another user turn' }] },
{ role: 'model', parts: [{ text: 'text response 1' }] },
{ role: 'user', parts: [{ text: 'text request 2' }] },
];
mockContext.history = history;
const mockApiResponse = {
complexity_reasoning: 'Simple.',
complexity_score: 10,
@@ -454,9 +460,9 @@ describe('NumericalClassifierStrategy', () => {
.calls[0][0];
const contents = generateJsonCall.contents;
// Expect leading tool turns (index 0 and 1) to be stripped, keeping only text turns (index 2 and 3)
const expectedContents = [
...mockContext.history,
// The last user turn is the request part
...history.slice(2),
{
role: 'user',
parts: [{ text: 'simple task' }],
@@ -466,12 +472,25 @@ describe('NumericalClassifierStrategy', () => {
expect(contents).toEqual(expectedContents);
});
it('should respect HISTORY_TURNS_FOR_CONTEXT', async () => {
const longHistory: Content[] = [];
for (let i = 0; i < 30; i++) {
longHistory.push({ role: 'user', parts: [{ text: `Message ${i}` }] });
}
mockContext.history = longHistory;
it('should preserve tool turns when they appear after a non-tool turn in the middle of history', async () => {
const history: Content[] = [
{ role: 'user', parts: [{ text: 'turn 0 (before)' }] },
{ role: 'model', parts: [{ text: 'turn 1 (before)' }] },
{ role: 'user', parts: [{ text: 'turn 2 (before)' }] },
{ role: 'model', parts: [{ text: 'turn 3 (before)' }] },
{ role: 'model', parts: [{ functionCall: { name: 'middle_tool' } }] },
{
role: 'user',
parts: [
{ functionResponse: { name: 'middle_tool', response: { ok: true } } },
],
},
{ role: 'model', parts: [{ text: 'turn 6 (after)' }] },
{ role: 'user', parts: [{ text: 'turn 7 (after)' }] },
{ role: 'model', parts: [{ text: 'turn 8 (after)' }] },
{ role: 'user', parts: [{ text: 'turn 9 (after)' }] },
];
mockContext.history = history;
const mockApiResponse = {
complexity_reasoning: 'Simple.',
complexity_score: 10,
@@ -491,18 +510,188 @@ describe('NumericalClassifierStrategy', () => {
.calls[0][0];
const contents = generateJsonCall.contents;
// Manually calculate what the history should be
const HISTORY_TURNS_FOR_CONTEXT = 8;
const finalHistory = longHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);
// Expect all 8 sliced turns (starting from non-tool turn 2) to be preserved
const expectedContents = [
...history.slice(2),
{
role: 'user',
parts: [{ text: 'simple task' }],
},
];
// Last part is the request
const requestPart = {
role: 'user',
parts: [{ text: 'simple task' }],
expect(contents).toEqual(expectedContents);
});
// Verifies that tool-related turns are NOT stripped when the sliced window
// begins with a plain text turn: stripping only applies to *leading* tool
// turns, so a tool call/response pair at the very end must survive.
it('should preserve tool turns when they appear at the very end of history following a non-tool turn', async () => {
  const history: Content[] = [
    { role: 'user', parts: [{ text: 'turn 0' }] },
    { role: 'model', parts: [{ text: 'turn 1' }] },
    { role: 'user', parts: [{ text: 'turn 2' }] },
    { role: 'model', parts: [{ text: 'turn 3' }] },
    { role: 'user', parts: [{ text: 'turn 4' }] },
    { role: 'model', parts: [{ text: 'turn 5' }] },
    { role: 'user', parts: [{ text: 'turn 6' }] },
    { role: 'model', parts: [{ text: 'turn 7' }] },
    { role: 'model', parts: [{ functionCall: { name: 'end_tool' } }] },
    {
      role: 'user',
      parts: [
        { functionResponse: { name: 'end_tool', response: { ok: true } } },
      ],
    },
  ];
  mockContext.history = history;
  const mockApiResponse = {
    complexity_reasoning: 'Simple.',
    complexity_score: 10,
  };
  vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
    mockApiResponse,
  );
  await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );
  const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
    .calls[0][0];
  const contents = generateJsonCall.contents;
  // Expect all 8 sliced turns to be preserved because index 2 is a non-tool turn
  const expectedContents = [
    ...history.slice(2),
    {
      role: 'user',
      parts: [{ text: 'simple task' }],
    },
  ];
  expect(contents).toEqual(expectedContents);
});
// When every history turn in the window is tool-related, the strategy has no
// safe leading turn to anchor on, so the whole slice is discarded and the
// classifier sees the new request prompt by itself.
it('should send only the new request prompt if the entire history consists of tool-related turns', async () => {
  const makeToolPair = (toolName: string): Content[] => [
    { role: 'model', parts: [{ functionCall: { name: toolName } }] },
    {
      role: 'user',
      parts: [
        { functionResponse: { name: toolName, response: { ok: true } } },
      ],
    },
  ];
  mockContext.history = [
    ...makeToolPair('tool_A'),
    ...makeToolPair('tool_B'),
  ];
  vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({
    complexity_reasoning: 'Simple standalone task.',
    complexity_score: 10,
  });
  await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );
  const firstCallArgs = vi.mocked(mockBaseLlmClient.generateJson).mock
    .calls[0][0];
  // Every history turn is filtered out, leaving exactly the new request.
  expect(firstCallArgs.contents).toEqual([
    {
      role: 'user',
      parts: [{ text: 'simple task' }],
    },
  ]);
});
// Baseline windowing check: with a purely textual history two turns longer
// than the window, the classifier receives exactly the trailing
// HISTORY_TURNS_FOR_CONTEXT turns plus the wrapped request.
it('should respect HISTORY_TURNS_FOR_CONTEXT correctly when history has only text turns', async () => {
  const totalTurns = HISTORY_TURNS_FOR_CONTEXT + 2;
  const history: Content[] = Array.from({ length: totalTurns }, (_, i) => ({
    role: 'user',
    parts: [{ text: `Message ${i}` }],
  }));
  mockContext.history = history;
  vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({
    complexity_reasoning: 'Simple.',
    complexity_score: 10,
  });
  await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );
  const contents = vi.mocked(mockBaseLlmClient.generateJson).mock.calls[0][0]
    .contents;
  // The first two turns fall outside the window; the request is appended last.
  expect(contents).toEqual([
    ...history.slice(2),
    { role: 'user', parts: [{ text: 'simple task' }] },
  ]);
  expect(contents).toHaveLength(HISTORY_TURNS_FOR_CONTEXT + 1);
});
// When two tool turns precede exactly HISTORY_TURNS_FOR_CONTEXT text turns,
// the trailing window lands entirely on the text turns, so the tool turns
// fall away via slicing rather than stripping and nothing else is dropped.
it('should respect HISTORY_TURNS_FOR_CONTEXT correctly when history starts with tool calls', async () => {
  const history: Content[] = [
    { role: 'model', parts: [{ functionCall: { name: 'tool_0' } }] },
    {
      role: 'user',
      parts: [
        { functionResponse: { name: 'tool_0', response: { ok: true } } },
      ],
    },
  ];
  for (let turn = 0; turn < HISTORY_TURNS_FOR_CONTEXT; turn++) {
    history.push({ role: 'user', parts: [{ text: `Message ${turn}` }] });
  }
  mockContext.history = history;
  vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({
    complexity_reasoning: 'Simple.',
    complexity_score: 10,
  });
  await strategy.route(
    mockContext,
    mockConfig,
    mockBaseLlmClient,
    mockLocalLiteRtLmClient,
  );
  const contents = vi.mocked(mockBaseLlmClient.generateJson).mock.calls[0][0]
    .contents;
  // Only the 8 text turns (history.slice(2)) survive, plus the new request.
  expect(contents).toEqual([
    ...history.slice(2),
    { role: 'user', parts: [{ text: 'simple task' }] },
  ]);
  expect(contents).toHaveLength(HISTORY_TURNS_FOR_CONTEXT + 1);
});
it('should use a fallback promptId if not found in context', async () => {
@@ -15,12 +15,16 @@ import type {
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js';
import {
isFunctionCall,
isFunctionResponse,
} from '../../utils/messageInspectors.js';
import { debugLogger } from '../../utils/debugLogger.js';
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
import { LlmRole } from '../../telemetry/types.js';
// The number of recent history turns to provide to the router for context.
const HISTORY_TURNS_FOR_CONTEXT = 8;
export const HISTORY_TURNS_FOR_CONTEXT = 8;
const FLASH_MODEL = 'flash';
const PRO_MODEL = 'pro';
@@ -115,7 +119,22 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
const promptId = getPromptIdWithFallback('classifier-router');
const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
const candidateSlice = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
// Find the first non-tool turn. The server cannot always handle tool-related
// turns in the first slots of the contents array, so we strip them if they appear at the start.
let firstTextIndex = -1;
for (let i = 0; i < candidateSlice.length; i++) {
if (
!isFunctionCall(candidateSlice[i]) &&
!isFunctionResponse(candidateSlice[i])
) {
firstTextIndex = i;
break;
}
}
const finalHistory =
firstTextIndex === -1 ? [] : candidateSlice.slice(firstTextIndex);
// Wrap the user's request in tags to prevent prompt injection
const requestParts = Array.isArray(context.request)