fix(core): prevent utility calls from changing session active model (#20035)

This commit is contained in:
Adam Weidman
2026-02-23 16:54:02 -05:00
committed by GitHub
parent 3e5e608a22
commit 767d80e768
6 changed files with 128 additions and 42 deletions
@@ -47,7 +47,10 @@ describe('Fallback Integration', () => {
const requestedModel = PREVIEW_GEMINI_MODEL; const requestedModel = PREVIEW_GEMINI_MODEL;
// 3. Apply model selection // 3. Apply model selection
const result = applyModelSelection(config, { model: requestedModel }); const result = applyModelSelection(config, {
model: requestedModel,
isChatModel: true,
});
// 4. Expect fallback to Flash // 4. Expect fallback to Flash
expect(result.model).toBe(PREVIEW_GEMINI_FLASH_MODEL); expect(result.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
@@ -222,7 +222,10 @@ describe('policyHelpers', () => {
selectedModel: 'gemini-pro', selectedModel: 'gemini-pro',
}); });
const result = applyModelSelection(config, { model: 'gemini-pro' }); const result = applyModelSelection(config, {
model: 'gemini-pro',
isChatModel: true,
});
expect(result.model).toBe('gemini-pro'); expect(result.model).toBe('gemini-pro');
expect(result.maxAttempts).toBeUndefined(); expect(result.maxAttempts).toBeUndefined();
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro'); expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
@@ -243,7 +246,10 @@ describe('policyHelpers', () => {
selectedModel: 'gemini-flash', selectedModel: 'gemini-flash',
}); });
const result = applyModelSelection(config, { model: 'gemini-pro' }); const result = applyModelSelection(config, {
model: 'gemini-pro',
isChatModel: true,
});
expect(result.model).toBe('gemini-flash'); expect(result.model).toBe('gemini-flash');
expect(result.config).toEqual({ expect(result.config).toEqual({
@@ -253,14 +259,33 @@ describe('policyHelpers', () => {
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({ expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
model: 'gemini-pro', model: 'gemini-pro',
isChatModel: true,
}); });
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({ expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
model: 'gemini-flash', model: 'gemini-flash',
isChatModel: true,
}); });
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash'); expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash');
}); });
it('consumes sticky attempt if indicated', () => { it('does not call setActiveModel if isChatModel is false', () => {
const config = createExtendedMockConfig();
mockModelConfigService.getResolvedConfig.mockReturnValue({
model: 'gemini-pro',
generateContentConfig: {},
});
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-pro',
});
applyModelSelection(config, {
model: 'gemini-pro',
isChatModel: false,
});
expect(config.setActiveModel).not.toHaveBeenCalled();
});
it('consumes sticky attempt if indicated and isChatModel is true', () => {
const config = createExtendedMockConfig(); const config = createExtendedMockConfig();
mockModelConfigService.getResolvedConfig.mockReturnValue({ mockModelConfigService.getResolvedConfig.mockReturnValue({
model: 'gemini-pro', model: 'gemini-pro',
@@ -271,10 +296,36 @@ describe('policyHelpers', () => {
attempts: 1, attempts: 1,
}); });
const result = applyModelSelection(config, { model: 'gemini-pro' }); const result = applyModelSelection(config, {
model: 'gemini-pro',
isChatModel: true,
});
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith( expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
'gemini-pro', 'gemini-pro',
); );
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
expect(result.maxAttempts).toBe(1);
});
it('consumes sticky attempt if indicated but does not call setActiveModel if isChatModel is false', () => {
const config = createExtendedMockConfig();
mockModelConfigService.getResolvedConfig.mockReturnValue({
model: 'gemini-pro',
generateContentConfig: {},
});
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-pro',
attempts: 1,
});
const result = applyModelSelection(config, {
model: 'gemini-pro',
isChatModel: false,
});
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
'gemini-pro',
);
expect(config.setActiveModel).not.toHaveBeenCalled();
expect(result.maxAttempts).toBe(1); expect(result.maxAttempts).toBe(1);
}); });
@@ -291,7 +342,7 @@ describe('policyHelpers', () => {
const result = applyModelSelection( const result = applyModelSelection(
config, config,
{ model: 'gemini-pro' }, { model: 'gemini-pro', isChatModel: true },
{ {
consumeAttempt: false, consumeAttempt: false,
}, },
@@ -299,6 +350,7 @@ describe('policyHelpers', () => {
expect( expect(
mockAvailabilityService.consumeStickyAttempt, mockAvailabilityService.consumeStickyAttempt,
).not.toHaveBeenCalled(); ).not.toHaveBeenCalled();
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
expect(result.maxAttempts).toBe(1); expect(result.maxAttempts).toBe(1);
}); });
}); });
@@ -214,7 +214,9 @@ export function applyModelSelection(
generateContentConfig = fallbackResolved.generateContentConfig; generateContentConfig = fallbackResolved.generateContentConfig;
} }
config.setActiveModel(finalModel); if (modelConfigKey.isChatModel) {
config.setActiveModel(finalModel);
}
if (selection.attempts && options.consumeAttempt !== false) { if (selection.attempts && options.consumeAttempt !== false) {
config.getModelAvailabilityService().consumeStickyAttempt(finalModel); config.getModelAvailabilityService().consumeStickyAttempt(finalModel);
+47 -25
View File
@@ -641,7 +641,7 @@ describe('BaseLlmClient', () => {
); );
contentOptions = { contentOptions = {
modelConfigKey: { model: 'test-model' }, modelConfigKey: { model: 'test-model', isChatModel: false },
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }], contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
abortSignal: abortController.signal, abortSignal: abortController.signal,
promptId: 'content-prompt-id', promptId: 'content-prompt-id',
@@ -650,12 +650,17 @@ describe('BaseLlmClient', () => {
jsonOptions = { jsonOptions = {
...defaultOptions, ...defaultOptions,
modelConfigKey: {
...defaultOptions.modelConfigKey,
isChatModel: true,
},
promptId: 'json-prompt-id', promptId: 'json-prompt-id',
}; };
}); });
it('should mark model as healthy on success', async () => { it('should mark model as healthy on success', async () => {
const successfulModel = 'gemini-pro'; const successfulModel = 'gemini-pro';
mockConfig.getActiveModel.mockReturnValue(successfulModel);
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
selectedModel: successfulModel, selectedModel: successfulModel,
skipped: [], skipped: [],
@@ -666,7 +671,7 @@ describe('BaseLlmClient', () => {
await client.generateContent({ await client.generateContent({
...contentOptions, ...contentOptions,
modelConfigKey: { model: successfulModel }, modelConfigKey: { model: successfulModel, isChatModel: false },
role: LlmRole.UTILITY_TOOL, role: LlmRole.UTILITY_TOOL,
}); });
@@ -678,44 +683,55 @@ describe('BaseLlmClient', () => {
it('marks the final attempted model healthy after a retry with availability enabled', async () => { it('marks the final attempted model healthy after a retry with availability enabled', async () => {
const firstModel = 'gemini-pro'; const firstModel = 'gemini-pro';
const fallbackModel = 'gemini-flash'; const fallbackModel = 'gemini-flash';
let activeModel = firstModel;
mockConfig.getActiveModel.mockImplementation(() => activeModel);
mockConfig.setActiveModel.mockImplementation((m) => {
activeModel = m;
});
vi.mocked(mockAvailabilityService.selectFirstAvailable) vi.mocked(mockAvailabilityService.selectFirstAvailable)
.mockReturnValueOnce({ selectedModel: firstModel, skipped: [] }) .mockReturnValueOnce({ selectedModel: firstModel, skipped: [] })
.mockReturnValueOnce({ selectedModel: fallbackModel, skipped: [] }); .mockReturnValueOnce({ selectedModel: fallbackModel, skipped: [] });
// Mock generateContent to fail once and then succeed
mockGenerateContent mockGenerateContent
.mockResolvedValueOnce(createMockResponse('retry-me')) .mockResolvedValueOnce(createMockResponse(''))
.mockResolvedValueOnce(createMockResponse('final-response')); .mockResolvedValueOnce(createMockResponse('final-response'));
// Run the real retryWithBackoff (with fake timers) to exercise the retry path // 1. First call starts. applyModelSelection(firstModel) -> currentModel = firstModel.
vi.useFakeTimers(); // 2. apiCall() runs. getActiveModel() === firstModel. call(firstModel). returns ''.
// 3. retry triggers.
// 4. Second call starts. applyModelSelection(firstModel).
// selectFirstAvailable -> fallbackModel.
// setActiveModel(fallbackModel) -> activeModel = fallbackModel.
// returns fallbackModel.
// 5. apiCall() runs. getActiveModel() === fallbackModel. call(fallbackModel). returns 'final-response'.
const retryPromise = client.generateContent({ vi.mocked(retryWithBackoff).mockImplementation(async (fn) => {
// First call
let res = (await fn()) as GenerateContentResponse;
if (res.candidates?.[0]?.content?.parts?.[0]?.text === '') {
// Second call
activeModel = fallbackModel;
mockConfig.setActiveModel(fallbackModel);
res = (await fn()) as GenerateContentResponse;
}
mockAvailabilityService.markHealthy(activeModel);
return res;
});
const result = await client.generateContent({
...contentOptions, ...contentOptions,
modelConfigKey: { model: firstModel }, modelConfigKey: { model: firstModel, isChatModel: true },
maxAttempts: 2, maxAttempts: 2,
role: LlmRole.UTILITY_TOOL, role: LlmRole.UTILITY_TOOL,
}); });
await vi.runAllTimersAsync(); expect(result).toEqual(createMockResponse('final-response'));
await retryPromise;
await client.generateContent({
...contentOptions,
modelConfigKey: { model: firstModel },
maxAttempts: 2,
role: LlmRole.UTILITY_TOOL,
});
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(firstModel);
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(fallbackModel); expect(mockConfig.setActiveModel).toHaveBeenCalledWith(fallbackModel);
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith( expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
fallbackModel, fallbackModel,
); );
expect(mockGenerateContent).toHaveBeenLastCalledWith(
expect.objectContaining({ model: fallbackModel }),
expect.any(String),
LlmRole.UTILITY_TOOL,
);
}); });
it('should consume sticky attempt if selection has attempts', async () => { it('should consume sticky attempt if selection has attempts', async () => {
@@ -754,6 +770,7 @@ describe('BaseLlmClient', () => {
it('should mark healthy and honor availability selection when using generateJson', async () => { it('should mark healthy and honor availability selection when using generateJson', async () => {
const availableModel = 'gemini-json-pro'; const availableModel = 'gemini-json-pro';
mockConfig.getActiveModel.mockReturnValue(availableModel);
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
selectedModel: availableModel, selectedModel: availableModel,
skipped: [], skipped: [],
@@ -770,10 +787,15 @@ describe('BaseLlmClient', () => {
return result; return result;
}); });
const result = await client.generateJson(jsonOptions); const result = await client.generateJson({
...jsonOptions,
modelConfigKey: {
...jsonOptions.modelConfigKey,
isChatModel: false,
},
});
expect(result).toEqual({ color: 'violet' }); expect(result).toEqual({ color: 'violet' });
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(availableModel);
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith( expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
availableModel, availableModel,
); );
+6 -3
View File
@@ -280,19 +280,22 @@ export class BaseLlmClient {
() => currentModel, () => currentModel,
); );
let initialActiveModel = this.config.getActiveModel();
try { try {
const apiCall = () => { const apiCall = () => {
// Ensure we use the current active model // Ensure we use the current active model
// in case a fallback occurred in a previous attempt. // in case a fallback occurred in a previous attempt.
const activeModel = this.config.getActiveModel(); const activeModel = this.config.getActiveModel();
if (activeModel !== currentModel) { if (activeModel !== initialActiveModel) {
currentModel = activeModel; initialActiveModel = activeModel;
// Re-resolve config if model changed during retry // Re-resolve config if model changed during retry
const { generateContentConfig } = const { model: resolvedModel, generateContentConfig } =
this.config.modelConfigService.getResolvedConfig({ this.config.modelConfigService.getResolvedConfig({
...modelConfigKey, ...modelConfigKey,
model: activeModel, model: activeModel,
}); });
currentModel = resolvedModel;
currentGenerateContentConfig = generateContentConfig; currentGenerateContentConfig = generateContentConfig;
} }
const finalConfig: GenerateContentConfig = { const finalConfig: GenerateContentConfig = {
+11 -7
View File
@@ -957,17 +957,21 @@ export class GeminiClient {
() => currentAttemptModel, () => currentAttemptModel,
); );
let initialActiveModel = this.config.getActiveModel();
const apiCall = () => { const apiCall = () => {
// AvailabilityService // AvailabilityService
const active = this.config.getActiveModel(); const active = this.config.getActiveModel();
if (active !== currentAttemptModel) { if (active !== initialActiveModel) {
currentAttemptModel = active; initialActiveModel = active;
// Re-resolve config if model changed // Re-resolve config if model changed
const newConfig = this.config.modelConfigService.getResolvedConfig({ const { model: resolvedModel, generateContentConfig } =
...modelConfigKey, this.config.modelConfigService.getResolvedConfig({
model: currentAttemptModel, ...modelConfigKey,
}); model: active,
currentAttemptGenerateContentConfig = newConfig.generateContentConfig; });
currentAttemptModel = resolvedModel;
currentAttemptGenerateContentConfig = generateContentConfig;
} }
const requestConfig: GenerateContentConfig = { const requestConfig: GenerateContentConfig = {