fix(core): prevent utility calls from changing session active model (#20035)

This commit is contained in:
Adam Weidman
2026-02-23 16:54:02 -05:00
committed by GitHub
parent 3e5e608a22
commit 767d80e768
6 changed files with 128 additions and 42 deletions

View File

@@ -47,7 +47,10 @@ describe('Fallback Integration', () => {
const requestedModel = PREVIEW_GEMINI_MODEL;
// 3. Apply model selection
const result = applyModelSelection(config, { model: requestedModel });
const result = applyModelSelection(config, {
model: requestedModel,
isChatModel: true,
});
// 4. Expect fallback to Flash
expect(result.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);

View File

@@ -222,7 +222,10 @@ describe('policyHelpers', () => {
selectedModel: 'gemini-pro',
});
const result = applyModelSelection(config, { model: 'gemini-pro' });
const result = applyModelSelection(config, {
model: 'gemini-pro',
isChatModel: true,
});
expect(result.model).toBe('gemini-pro');
expect(result.maxAttempts).toBeUndefined();
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
@@ -243,7 +246,10 @@ describe('policyHelpers', () => {
selectedModel: 'gemini-flash',
});
const result = applyModelSelection(config, { model: 'gemini-pro' });
const result = applyModelSelection(config, {
model: 'gemini-pro',
isChatModel: true,
});
expect(result.model).toBe('gemini-flash');
expect(result.config).toEqual({
@@ -253,14 +259,33 @@ describe('policyHelpers', () => {
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
model: 'gemini-pro',
isChatModel: true,
});
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
model: 'gemini-flash',
isChatModel: true,
});
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash');
});
it('consumes sticky attempt if indicated', () => {
it('does not call setActiveModel if isChatModel is false', () => {
const config = createExtendedMockConfig();
mockModelConfigService.getResolvedConfig.mockReturnValue({
model: 'gemini-pro',
generateContentConfig: {},
});
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-pro',
});
applyModelSelection(config, {
model: 'gemini-pro',
isChatModel: false,
});
expect(config.setActiveModel).not.toHaveBeenCalled();
});
it('consumes sticky attempt if indicated and isChatModel is true', () => {
const config = createExtendedMockConfig();
mockModelConfigService.getResolvedConfig.mockReturnValue({
model: 'gemini-pro',
@@ -271,10 +296,36 @@ describe('policyHelpers', () => {
attempts: 1,
});
const result = applyModelSelection(config, { model: 'gemini-pro' });
const result = applyModelSelection(config, {
model: 'gemini-pro',
isChatModel: true,
});
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
'gemini-pro',
);
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
expect(result.maxAttempts).toBe(1);
});
it('consumes sticky attempt if indicated but does not call setActiveModel if isChatModel is false', () => {
const config = createExtendedMockConfig();
mockModelConfigService.getResolvedConfig.mockReturnValue({
model: 'gemini-pro',
generateContentConfig: {},
});
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-pro',
attempts: 1,
});
const result = applyModelSelection(config, {
model: 'gemini-pro',
isChatModel: false,
});
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
'gemini-pro',
);
expect(config.setActiveModel).not.toHaveBeenCalled();
expect(result.maxAttempts).toBe(1);
});
@@ -291,7 +342,7 @@ describe('policyHelpers', () => {
const result = applyModelSelection(
config,
{ model: 'gemini-pro' },
{ model: 'gemini-pro', isChatModel: true },
{
consumeAttempt: false,
},
@@ -299,6 +350,7 @@ describe('policyHelpers', () => {
expect(
mockAvailabilityService.consumeStickyAttempt,
).not.toHaveBeenCalled();
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
expect(result.maxAttempts).toBe(1);
});
});

View File

@@ -214,7 +214,9 @@ export function applyModelSelection(
generateContentConfig = fallbackResolved.generateContentConfig;
}
config.setActiveModel(finalModel);
if (modelConfigKey.isChatModel) {
config.setActiveModel(finalModel);
}
if (selection.attempts && options.consumeAttempt !== false) {
config.getModelAvailabilityService().consumeStickyAttempt(finalModel);

View File

@@ -641,7 +641,7 @@ describe('BaseLlmClient', () => {
);
contentOptions = {
modelConfigKey: { model: 'test-model' },
modelConfigKey: { model: 'test-model', isChatModel: false },
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
abortSignal: abortController.signal,
promptId: 'content-prompt-id',
@@ -650,12 +650,17 @@ describe('BaseLlmClient', () => {
jsonOptions = {
...defaultOptions,
modelConfigKey: {
...defaultOptions.modelConfigKey,
isChatModel: true,
},
promptId: 'json-prompt-id',
};
});
it('should mark model as healthy on success', async () => {
const successfulModel = 'gemini-pro';
mockConfig.getActiveModel.mockReturnValue(successfulModel);
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
selectedModel: successfulModel,
skipped: [],
@@ -666,7 +671,7 @@ describe('BaseLlmClient', () => {
await client.generateContent({
...contentOptions,
modelConfigKey: { model: successfulModel },
modelConfigKey: { model: successfulModel, isChatModel: false },
role: LlmRole.UTILITY_TOOL,
});
@@ -678,44 +683,55 @@ describe('BaseLlmClient', () => {
it('marks the final attempted model healthy after a retry with availability enabled', async () => {
const firstModel = 'gemini-pro';
const fallbackModel = 'gemini-flash';
let activeModel = firstModel;
mockConfig.getActiveModel.mockImplementation(() => activeModel);
mockConfig.setActiveModel.mockImplementation((m) => {
activeModel = m;
});
vi.mocked(mockAvailabilityService.selectFirstAvailable)
.mockReturnValueOnce({ selectedModel: firstModel, skipped: [] })
.mockReturnValueOnce({ selectedModel: fallbackModel, skipped: [] });
// Mock generateContent to fail once and then succeed
mockGenerateContent
.mockResolvedValueOnce(createMockResponse('retry-me'))
.mockResolvedValueOnce(createMockResponse(''))
.mockResolvedValueOnce(createMockResponse('final-response'));
// Run the real retryWithBackoff (with fake timers) to exercise the retry path
vi.useFakeTimers();
// 1. First call starts. applyModelSelection(firstModel) -> currentModel = firstModel.
// 2. apiCall() runs. getActiveModel() === firstModel. call(firstModel). returns ''.
// 3. retry triggers.
// 4. Second call starts. applyModelSelection(firstModel).
// selectFirstAvailable -> fallbackModel.
// setActiveModel(fallbackModel) -> activeModel = fallbackModel.
// returns fallbackModel.
// 5. apiCall() runs. getActiveModel() === fallbackModel. call(fallbackModel). returns 'final-response'.
const retryPromise = client.generateContent({
vi.mocked(retryWithBackoff).mockImplementation(async (fn) => {
// First call
let res = (await fn()) as GenerateContentResponse;
if (res.candidates?.[0]?.content?.parts?.[0]?.text === '') {
// Second call
activeModel = fallbackModel;
mockConfig.setActiveModel(fallbackModel);
res = (await fn()) as GenerateContentResponse;
}
mockAvailabilityService.markHealthy(activeModel);
return res;
});
const result = await client.generateContent({
...contentOptions,
modelConfigKey: { model: firstModel },
modelConfigKey: { model: firstModel, isChatModel: true },
maxAttempts: 2,
role: LlmRole.UTILITY_TOOL,
});
await vi.runAllTimersAsync();
await retryPromise;
await client.generateContent({
...contentOptions,
modelConfigKey: { model: firstModel },
maxAttempts: 2,
role: LlmRole.UTILITY_TOOL,
});
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(firstModel);
expect(result).toEqual(createMockResponse('final-response'));
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(fallbackModel);
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
fallbackModel,
);
expect(mockGenerateContent).toHaveBeenLastCalledWith(
expect.objectContaining({ model: fallbackModel }),
expect.any(String),
LlmRole.UTILITY_TOOL,
);
});
it('should consume sticky attempt if selection has attempts', async () => {
@@ -754,6 +770,7 @@ describe('BaseLlmClient', () => {
it('should mark healthy and honor availability selection when using generateJson', async () => {
const availableModel = 'gemini-json-pro';
mockConfig.getActiveModel.mockReturnValue(availableModel);
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
selectedModel: availableModel,
skipped: [],
@@ -770,10 +787,15 @@ describe('BaseLlmClient', () => {
return result;
});
const result = await client.generateJson(jsonOptions);
const result = await client.generateJson({
...jsonOptions,
modelConfigKey: {
...jsonOptions.modelConfigKey,
isChatModel: false,
},
});
expect(result).toEqual({ color: 'violet' });
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(availableModel);
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
availableModel,
);

View File

@@ -280,19 +280,22 @@ export class BaseLlmClient {
() => currentModel,
);
let initialActiveModel = this.config.getActiveModel();
try {
const apiCall = () => {
// Ensure we use the current active model
// in case a fallback occurred in a previous attempt.
const activeModel = this.config.getActiveModel();
if (activeModel !== currentModel) {
currentModel = activeModel;
if (activeModel !== initialActiveModel) {
initialActiveModel = activeModel;
// Re-resolve config if model changed during retry
const { generateContentConfig } =
const { model: resolvedModel, generateContentConfig } =
this.config.modelConfigService.getResolvedConfig({
...modelConfigKey,
model: activeModel,
});
currentModel = resolvedModel;
currentGenerateContentConfig = generateContentConfig;
}
const finalConfig: GenerateContentConfig = {

View File

@@ -957,17 +957,21 @@ export class GeminiClient {
() => currentAttemptModel,
);
let initialActiveModel = this.config.getActiveModel();
const apiCall = () => {
// AvailabilityService
const active = this.config.getActiveModel();
if (active !== currentAttemptModel) {
currentAttemptModel = active;
if (active !== initialActiveModel) {
initialActiveModel = active;
// Re-resolve config if model changed
const newConfig = this.config.modelConfigService.getResolvedConfig({
...modelConfigKey,
model: currentAttemptModel,
});
currentAttemptGenerateContentConfig = newConfig.generateContentConfig;
const { model: resolvedModel, generateContentConfig } =
this.config.modelConfigService.getResolvedConfig({
...modelConfigKey,
model: active,
});
currentAttemptModel = resolvedModel;
currentAttemptGenerateContentConfig = generateContentConfig;
}
const requestConfig: GenerateContentConfig = {