diff --git a/packages/core/src/availability/fallbackIntegration.test.ts b/packages/core/src/availability/fallbackIntegration.test.ts index f9de1f3b2b..62174c9abb 100644 --- a/packages/core/src/availability/fallbackIntegration.test.ts +++ b/packages/core/src/availability/fallbackIntegration.test.ts @@ -47,7 +47,10 @@ describe('Fallback Integration', () => { const requestedModel = PREVIEW_GEMINI_MODEL; // 3. Apply model selection - const result = applyModelSelection(config, { model: requestedModel }); + const result = applyModelSelection(config, { + model: requestedModel, + isChatModel: true, + }); // 4. Expect fallback to Flash expect(result.model).toBe(PREVIEW_GEMINI_FLASH_MODEL); diff --git a/packages/core/src/availability/policyHelpers.test.ts b/packages/core/src/availability/policyHelpers.test.ts index 22b8a62700..2eb6129f61 100644 --- a/packages/core/src/availability/policyHelpers.test.ts +++ b/packages/core/src/availability/policyHelpers.test.ts @@ -222,7 +222,10 @@ describe('policyHelpers', () => { selectedModel: 'gemini-pro', }); - const result = applyModelSelection(config, { model: 'gemini-pro' }); + const result = applyModelSelection(config, { + model: 'gemini-pro', + isChatModel: true, + }); expect(result.model).toBe('gemini-pro'); expect(result.maxAttempts).toBeUndefined(); expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro'); @@ -243,7 +246,10 @@ describe('policyHelpers', () => { selectedModel: 'gemini-flash', }); - const result = applyModelSelection(config, { model: 'gemini-pro' }); + const result = applyModelSelection(config, { + model: 'gemini-pro', + isChatModel: true, + }); expect(result.model).toBe('gemini-flash'); expect(result.config).toEqual({ @@ -253,14 +259,33 @@ describe('policyHelpers', () => { expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({ model: 'gemini-pro', + isChatModel: true, }); expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({ model: 'gemini-flash', + isChatModel: true, }); expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash'); }); - it('consumes sticky attempt if indicated', () => { + it('does not call setActiveModel if isChatModel is false', () => { + const config = createExtendedMockConfig(); + mockModelConfigService.getResolvedConfig.mockReturnValue({ + model: 'gemini-pro', + generateContentConfig: {}, + }); + mockAvailabilityService.selectFirstAvailable.mockReturnValue({ + selectedModel: 'gemini-pro', + }); + + applyModelSelection(config, { + model: 'gemini-pro', + isChatModel: false, + }); + expect(config.setActiveModel).not.toHaveBeenCalled(); + }); + + it('consumes sticky attempt if indicated and isChatModel is true', () => { const config = createExtendedMockConfig(); mockModelConfigService.getResolvedConfig.mockReturnValue({ model: 'gemini-pro', @@ -271,10 +296,36 @@ describe('policyHelpers', () => { attempts: 1, }); - const result = applyModelSelection(config, { model: 'gemini-pro' }); + const result = applyModelSelection(config, { + model: 'gemini-pro', + isChatModel: true, + }); expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith( 'gemini-pro', ); + expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro'); + expect(result.maxAttempts).toBe(1); + }); + + it('consumes sticky attempt if indicated but does not call setActiveModel if isChatModel is false', () => { + const config = createExtendedMockConfig(); + mockModelConfigService.getResolvedConfig.mockReturnValue({ + model: 'gemini-pro', + generateContentConfig: {}, + }); + mockAvailabilityService.selectFirstAvailable.mockReturnValue({ + selectedModel: 'gemini-pro', + attempts: 1, + }); + + const result = applyModelSelection(config, { + model: 'gemini-pro', + isChatModel: false, + }); + expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith( + 'gemini-pro', + ); + expect(config.setActiveModel).not.toHaveBeenCalled(); expect(result.maxAttempts).toBe(1); }); @@ -291,7 +342,7 @@ describe('policyHelpers', () => { const result = applyModelSelection( config, - { model: 'gemini-pro' }, + { model: 'gemini-pro', isChatModel: true }, { consumeAttempt: false, }, @@ -299,6 +350,7 @@ describe('policyHelpers', () => { expect( mockAvailabilityService.consumeStickyAttempt, ).not.toHaveBeenCalled(); + expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro'); expect(result.maxAttempts).toBe(1); }); }); diff --git a/packages/core/src/availability/policyHelpers.ts b/packages/core/src/availability/policyHelpers.ts index 456c8a855f..05c1dd19f9 100644 --- a/packages/core/src/availability/policyHelpers.ts +++ b/packages/core/src/availability/policyHelpers.ts @@ -214,7 +214,9 @@ export function applyModelSelection( generateContentConfig = fallbackResolved.generateContentConfig; } - config.setActiveModel(finalModel); + if (modelConfigKey.isChatModel) { + config.setActiveModel(finalModel); + } if (selection.attempts && options.consumeAttempt !== false) { config.getModelAvailabilityService().consumeStickyAttempt(finalModel); diff --git a/packages/core/src/core/baseLlmClient.test.ts b/packages/core/src/core/baseLlmClient.test.ts index d067ec49ef..db1086fe81 100644 --- a/packages/core/src/core/baseLlmClient.test.ts +++ b/packages/core/src/core/baseLlmClient.test.ts @@ -641,7 +641,7 @@ describe('BaseLlmClient', () => { ); contentOptions = { - modelConfigKey: { model: 'test-model' }, + modelConfigKey: { model: 'test-model', isChatModel: false }, contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }], abortSignal: abortController.signal, promptId: 'content-prompt-id', @@ -650,12 +650,17 @@ describe('BaseLlmClient', () => { jsonOptions = { ...defaultOptions, + modelConfigKey: { + ...defaultOptions.modelConfigKey, + isChatModel: true, + }, promptId: 'json-prompt-id', }; }); it('should mark model as healthy on success', async () => { const successfulModel = 'gemini-pro'; + mockConfig.getActiveModel.mockReturnValue(successfulModel); vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ selectedModel: successfulModel, skipped: [], @@ -666,7 +671,7 @@ describe('BaseLlmClient', () => { await client.generateContent({ ...contentOptions, - modelConfigKey: { model: successfulModel }, + modelConfigKey: { model: successfulModel, isChatModel: false }, role: LlmRole.UTILITY_TOOL, }); @@ -678,44 +683,55 @@ describe('BaseLlmClient', () => { it('marks the final attempted model healthy after a retry with availability enabled', async () => { const firstModel = 'gemini-pro'; const fallbackModel = 'gemini-flash'; + let activeModel = firstModel; + mockConfig.getActiveModel.mockImplementation(() => activeModel); + mockConfig.setActiveModel.mockImplementation((m) => { + activeModel = m; + }); + vi.mocked(mockAvailabilityService.selectFirstAvailable) .mockReturnValueOnce({ selectedModel: firstModel, skipped: [] }) .mockReturnValueOnce({ selectedModel: fallbackModel, skipped: [] }); + // Mock generateContent to fail once and then succeed mockGenerateContent - .mockResolvedValueOnce(createMockResponse('retry-me')) + .mockResolvedValueOnce(createMockResponse('')) .mockResolvedValueOnce(createMockResponse('final-response')); - // Run the real retryWithBackoff (with fake timers) to exercise the retry path - vi.useFakeTimers(); + // 1. First call starts. applyModelSelection(firstModel) -> currentModel = firstModel. + // 2. apiCall() runs. getActiveModel() === firstModel. call(firstModel). returns ''. + // 3. retry triggers. + // 4. Second call starts. applyModelSelection(firstModel). + // selectFirstAvailable -> fallbackModel. + // setActiveModel(fallbackModel) -> activeModel = fallbackModel. + // returns fallbackModel. + // 5. apiCall() runs. getActiveModel() === fallbackModel. call(fallbackModel). returns 'final-response'. - const retryPromise = client.generateContent({ + vi.mocked(retryWithBackoff).mockImplementation(async (fn) => { + // First call + let res = (await fn()) as GenerateContentResponse; + if (res.candidates?.[0]?.content?.parts?.[0]?.text === '') { + // Second call + activeModel = fallbackModel; + mockConfig.setActiveModel(fallbackModel); + res = (await fn()) as GenerateContentResponse; + } + mockAvailabilityService.markHealthy(activeModel); + return res; + }); + + const result = await client.generateContent({ ...contentOptions, - modelConfigKey: { model: firstModel }, + modelConfigKey: { model: firstModel, isChatModel: true }, maxAttempts: 2, role: LlmRole.UTILITY_TOOL, }); - await vi.runAllTimersAsync(); - await retryPromise; - - await client.generateContent({ - ...contentOptions, - modelConfigKey: { model: firstModel }, - maxAttempts: 2, - role: LlmRole.UTILITY_TOOL, - }); - - expect(mockConfig.setActiveModel).toHaveBeenCalledWith(firstModel); + expect(result).toEqual(createMockResponse('final-response')); expect(mockConfig.setActiveModel).toHaveBeenCalledWith(fallbackModel); expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith( fallbackModel, ); - expect(mockGenerateContent).toHaveBeenLastCalledWith( - expect.objectContaining({ model: fallbackModel }), - expect.any(String), - LlmRole.UTILITY_TOOL, - ); }); it('should consume sticky attempt if selection has attempts', async () => { @@ -754,6 +770,7 @@ describe('BaseLlmClient', () => { it('should mark healthy and honor availability selection when using generateJson', async () => { const availableModel = 'gemini-json-pro'; + mockConfig.getActiveModel.mockReturnValue(availableModel); vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ selectedModel: availableModel, skipped: [], @@ -770,10 +787,15 @@ describe('BaseLlmClient', () => { return result; }); - const result = await client.generateJson(jsonOptions); + const result = await client.generateJson({ + ...jsonOptions, + modelConfigKey: { + ...jsonOptions.modelConfigKey, + isChatModel: false, + }, + }); expect(result).toEqual({ color: 'violet' }); - expect(mockConfig.setActiveModel).toHaveBeenCalledWith(availableModel); expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith( availableModel, ); diff --git a/packages/core/src/core/baseLlmClient.ts b/packages/core/src/core/baseLlmClient.ts index 64442ac86e..0de4dd1e20 100644 --- a/packages/core/src/core/baseLlmClient.ts +++ b/packages/core/src/core/baseLlmClient.ts @@ -280,19 +280,22 @@ export class BaseLlmClient { () => currentModel, ); + let initialActiveModel = this.config.getActiveModel(); + try { const apiCall = () => { // Ensure we use the current active model // in case a fallback occurred in a previous attempt. const activeModel = this.config.getActiveModel(); - if (activeModel !== currentModel) { - currentModel = activeModel; + if (activeModel !== initialActiveModel) { + initialActiveModel = activeModel; // Re-resolve config if model changed during retry - const { generateContentConfig } = + const { model: resolvedModel, generateContentConfig } = this.config.modelConfigService.getResolvedConfig({ ...modelConfigKey, model: activeModel, }); + currentModel = resolvedModel; currentGenerateContentConfig = generateContentConfig; } const finalConfig: GenerateContentConfig = { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 56447468bd..c94dd5c04d 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -957,17 +957,21 @@ export class GeminiClient { () => currentAttemptModel, ); + let initialActiveModel = this.config.getActiveModel(); + const apiCall = () => { // AvailabilityService const active = this.config.getActiveModel(); - if (active !== currentAttemptModel) { - currentAttemptModel = active; + if (active !== initialActiveModel) { + initialActiveModel = active; // Re-resolve config if model changed - const newConfig = this.config.modelConfigService.getResolvedConfig({ - ...modelConfigKey, - model: currentAttemptModel, - }); - currentAttemptGenerateContentConfig = newConfig.generateContentConfig; + const { model: resolvedModel, generateContentConfig } = + this.config.modelConfigService.getResolvedConfig({ + ...modelConfigKey, + model: active, + }); + currentAttemptModel = resolvedModel; + currentAttemptGenerateContentConfig = generateContentConfig; } const requestConfig: GenerateContentConfig = {