mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 05:12:55 -07:00
fix(routing): Refactor tool turn handling for the conversation history in NumericalClassifierStrategy to prevent 400 Bad Request (#26761)
This commit is contained in:
@@ -5,7 +5,10 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { NumericalClassifierStrategy } from './numericalClassifierStrategy.js';
|
||||
import {
|
||||
NumericalClassifierStrategy,
|
||||
HISTORY_TURNS_FOR_CONTEXT,
|
||||
} from './numericalClassifierStrategy.js';
|
||||
import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
@@ -423,18 +426,21 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(consoleWarnSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should include tool-related history when sending to classifier', async () => {
|
||||
mockContext.history = [
|
||||
{ role: 'user', parts: [{ text: 'call a tool' }] },
|
||||
{ role: 'model', parts: [{ functionCall: { name: 'test_tool' } }] },
|
||||
it('should strip leading tool turns when history starts with tool calls', async () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'model', parts: [{ functionCall: { name: 'leading_tool' } }] },
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{ functionResponse: { name: 'test_tool', response: { ok: true } } },
|
||||
{
|
||||
functionResponse: { name: 'leading_tool', response: { ok: true } },
|
||||
},
|
||||
],
|
||||
},
|
||||
{ role: 'user', parts: [{ text: 'another user turn' }] },
|
||||
{ role: 'model', parts: [{ text: 'text response 1' }] },
|
||||
{ role: 'user', parts: [{ text: 'text request 2' }] },
|
||||
];
|
||||
mockContext.history = history;
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Simple.',
|
||||
complexity_score: 10,
|
||||
@@ -454,9 +460,9 @@ describe('NumericalClassifierStrategy', () => {
|
||||
.calls[0][0];
|
||||
const contents = generateJsonCall.contents;
|
||||
|
||||
// Expect leading tool turns (index 0 and 1) to be stripped, keeping only text turns (index 2 and 3)
|
||||
const expectedContents = [
|
||||
...mockContext.history,
|
||||
// The last user turn is the request part
|
||||
...history.slice(2),
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'simple task' }],
|
||||
@@ -466,12 +472,25 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(contents).toEqual(expectedContents);
|
||||
});
|
||||
|
||||
it('should respect HISTORY_TURNS_FOR_CONTEXT', async () => {
|
||||
const longHistory: Content[] = [];
|
||||
for (let i = 0; i < 30; i++) {
|
||||
longHistory.push({ role: 'user', parts: [{ text: `Message ${i}` }] });
|
||||
}
|
||||
mockContext.history = longHistory;
|
||||
it('should preserve tool turns when they appear after a non-tool turn in the middle of history', async () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'turn 0 (before)' }] },
|
||||
{ role: 'model', parts: [{ text: 'turn 1 (before)' }] },
|
||||
{ role: 'user', parts: [{ text: 'turn 2 (before)' }] },
|
||||
{ role: 'model', parts: [{ text: 'turn 3 (before)' }] },
|
||||
{ role: 'model', parts: [{ functionCall: { name: 'middle_tool' } }] },
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{ functionResponse: { name: 'middle_tool', response: { ok: true } } },
|
||||
],
|
||||
},
|
||||
{ role: 'model', parts: [{ text: 'turn 6 (after)' }] },
|
||||
{ role: 'user', parts: [{ text: 'turn 7 (after)' }] },
|
||||
{ role: 'model', parts: [{ text: 'turn 8 (after)' }] },
|
||||
{ role: 'user', parts: [{ text: 'turn 9 (after)' }] },
|
||||
];
|
||||
mockContext.history = history;
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Simple.',
|
||||
complexity_score: 10,
|
||||
@@ -491,18 +510,188 @@ describe('NumericalClassifierStrategy', () => {
|
||||
.calls[0][0];
|
||||
const contents = generateJsonCall.contents;
|
||||
|
||||
// Manually calculate what the history should be
|
||||
const HISTORY_TURNS_FOR_CONTEXT = 8;
|
||||
const finalHistory = longHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);
|
||||
// Expect all 8 sliced turns (starting from non-tool turn 2) to be preserved
|
||||
const expectedContents = [
|
||||
...history.slice(2),
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'simple task' }],
|
||||
},
|
||||
];
|
||||
|
||||
// Last part is the request
|
||||
const requestPart = {
|
||||
role: 'user',
|
||||
parts: [{ text: 'simple task' }],
|
||||
expect(contents).toEqual(expectedContents);
|
||||
});
|
||||
|
||||
it('should preserve tool turns when they appear at the very end of history following a non-tool turn', async () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'turn 0' }] },
|
||||
{ role: 'model', parts: [{ text: 'turn 1' }] },
|
||||
{ role: 'user', parts: [{ text: 'turn 2' }] },
|
||||
{ role: 'model', parts: [{ text: 'turn 3' }] },
|
||||
{ role: 'user', parts: [{ text: 'turn 4' }] },
|
||||
{ role: 'model', parts: [{ text: 'turn 5' }] },
|
||||
{ role: 'user', parts: [{ text: 'turn 6' }] },
|
||||
{ role: 'model', parts: [{ text: 'turn 7' }] },
|
||||
{ role: 'model', parts: [{ functionCall: { name: 'end_tool' } }] },
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{ functionResponse: { name: 'end_tool', response: { ok: true } } },
|
||||
],
|
||||
},
|
||||
];
|
||||
mockContext.history = history;
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Simple.',
|
||||
complexity_score: 10,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
expect(contents).toEqual([...finalHistory, requestPart]);
|
||||
expect(contents).toHaveLength(9);
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
const contents = generateJsonCall.contents;
|
||||
|
||||
// Expect all 8 sliced turns to be preserved because index 2 is a non-tool turn
|
||||
const expectedContents = [
|
||||
...history.slice(2),
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'simple task' }],
|
||||
},
|
||||
];
|
||||
|
||||
expect(contents).toEqual(expectedContents);
|
||||
});
|
||||
|
||||
it('should send only the new request prompt if the entire history consists of tool-related turns', async () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'model', parts: [{ functionCall: { name: 'tool_A' } }] },
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{ functionResponse: { name: 'tool_A', response: { ok: true } } },
|
||||
],
|
||||
},
|
||||
{ role: 'model', parts: [{ functionCall: { name: 'tool_B' } }] },
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{ functionResponse: { name: 'tool_B', response: { ok: true } } },
|
||||
],
|
||||
},
|
||||
];
|
||||
mockContext.history = history;
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Simple standalone task.',
|
||||
complexity_score: 10,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
const contents = generateJsonCall.contents;
|
||||
|
||||
// Expect every history turn to be filtered out, leaving only the new request
|
||||
const expectedContents = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'simple task' }],
|
||||
},
|
||||
];
|
||||
|
||||
expect(contents).toEqual(expectedContents);
|
||||
});
|
||||
|
||||
it('should respect HISTORY_TURNS_FOR_CONTEXT correctly when history has only text turns', async () => {
|
||||
const history: Content[] = [];
|
||||
for (let i = 0; i < HISTORY_TURNS_FOR_CONTEXT + 2; i++) {
|
||||
history.push({ role: 'user', parts: [{ text: `Message ${i}` }] });
|
||||
}
|
||||
mockContext.history = history;
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Simple.',
|
||||
complexity_score: 10,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
const contents = generateJsonCall.contents;
|
||||
|
||||
// Expect exactly the last HISTORY_TURNS_FOR_CONTEXT turns (history.slice(2))
|
||||
expect(contents).toEqual([
|
||||
...history.slice(2),
|
||||
{ role: 'user', parts: [{ text: 'simple task' }] },
|
||||
]);
|
||||
expect(contents).toHaveLength(HISTORY_TURNS_FOR_CONTEXT + 1);
|
||||
});
|
||||
|
||||
it('should respect HISTORY_TURNS_FOR_CONTEXT correctly when history starts with tool calls', async () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'model', parts: [{ functionCall: { name: 'tool_0' } }] },
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{ functionResponse: { name: 'tool_0', response: { ok: true } } },
|
||||
],
|
||||
},
|
||||
];
|
||||
for (let i = 0; i < HISTORY_TURNS_FOR_CONTEXT; i++) {
|
||||
history.push({ role: 'user', parts: [{ text: `Message ${i}` }] });
|
||||
}
|
||||
mockContext.history = history;
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Simple.',
|
||||
complexity_score: 10,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
const contents = generateJsonCall.contents;
|
||||
|
||||
// Expect exactly the last HISTORY_TURNS_FOR_CONTEXT text turns (history.slice(2)); the two leading tool turns fall outside the window
|
||||
expect(contents).toEqual([
|
||||
...history.slice(2),
|
||||
{ role: 'user', parts: [{ text: 'simple task' }] },
|
||||
]);
|
||||
expect(contents).toHaveLength(HISTORY_TURNS_FOR_CONTEXT + 1);
|
||||
});
|
||||
|
||||
it('should use a fallback promptId if not found in context', async () => {
|
||||
|
||||
@@ -15,12 +15,16 @@ import type {
|
||||
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
|
||||
import { createUserContent, Type } from '@google/genai';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import {
|
||||
isFunctionCall,
|
||||
isFunctionResponse,
|
||||
} from '../../utils/messageInspectors.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js';
|
||||
import { LlmRole } from '../../telemetry/types.js';
|
||||
|
||||
// The number of recent history turns to provide to the router for context.
|
||||
const HISTORY_TURNS_FOR_CONTEXT = 8;
|
||||
export const HISTORY_TURNS_FOR_CONTEXT = 8;
|
||||
|
||||
const FLASH_MODEL = 'flash';
|
||||
const PRO_MODEL = 'pro';
|
||||
@@ -115,7 +119,22 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
|
||||
const promptId = getPromptIdWithFallback('classifier-router');
|
||||
|
||||
const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
|
||||
const candidateSlice = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
|
||||
|
||||
// Find the first non-tool turn. The server cannot always handle tool-related
|
||||
// turns in the first slots of the contents array, so we strip them if they appear at the start.
|
||||
let firstTextIndex = -1;
|
||||
for (let i = 0; i < candidateSlice.length; i++) {
|
||||
if (
|
||||
!isFunctionCall(candidateSlice[i]) &&
|
||||
!isFunctionResponse(candidateSlice[i])
|
||||
) {
|
||||
firstTextIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const finalHistory =
|
||||
firstTextIndex === -1 ? [] : candidateSlice.slice(firstTextIndex);
|
||||
|
||||
// Wrap the user's request in tags to prevent prompt injection
|
||||
const requestParts = Array.isArray(context.request)
|
||||
|
||||
Reference in New Issue
Block a user