mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-17 09:30:58 -07:00
fix(core): prevent utility calls from changing session active model (#20035)
This commit is contained in:
@@ -47,7 +47,10 @@ describe('Fallback Integration', () => {
|
||||
const requestedModel = PREVIEW_GEMINI_MODEL;
|
||||
|
||||
// 3. Apply model selection
|
||||
const result = applyModelSelection(config, { model: requestedModel });
|
||||
const result = applyModelSelection(config, {
|
||||
model: requestedModel,
|
||||
isChatModel: true,
|
||||
});
|
||||
|
||||
// 4. Expect fallback to Flash
|
||||
expect(result.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
|
||||
|
||||
@@ -222,7 +222,10 @@ describe('policyHelpers', () => {
|
||||
selectedModel: 'gemini-pro',
|
||||
});
|
||||
|
||||
const result = applyModelSelection(config, { model: 'gemini-pro' });
|
||||
const result = applyModelSelection(config, {
|
||||
model: 'gemini-pro',
|
||||
isChatModel: true,
|
||||
});
|
||||
expect(result.model).toBe('gemini-pro');
|
||||
expect(result.maxAttempts).toBeUndefined();
|
||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
||||
@@ -243,7 +246,10 @@ describe('policyHelpers', () => {
|
||||
selectedModel: 'gemini-flash',
|
||||
});
|
||||
|
||||
const result = applyModelSelection(config, { model: 'gemini-pro' });
|
||||
const result = applyModelSelection(config, {
|
||||
model: 'gemini-pro',
|
||||
isChatModel: true,
|
||||
});
|
||||
|
||||
expect(result.model).toBe('gemini-flash');
|
||||
expect(result.config).toEqual({
|
||||
@@ -253,14 +259,33 @@ describe('policyHelpers', () => {
|
||||
|
||||
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
||||
model: 'gemini-pro',
|
||||
isChatModel: true,
|
||||
});
|
||||
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
||||
model: 'gemini-flash',
|
||||
isChatModel: true,
|
||||
});
|
||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash');
|
||||
});
|
||||
|
||||
it('consumes sticky attempt if indicated', () => {
|
||||
it('does not call setActiveModel if isChatModel is false', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||
model: 'gemini-pro',
|
||||
generateContentConfig: {},
|
||||
});
|
||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||
selectedModel: 'gemini-pro',
|
||||
});
|
||||
|
||||
applyModelSelection(config, {
|
||||
model: 'gemini-pro',
|
||||
isChatModel: false,
|
||||
});
|
||||
expect(config.setActiveModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('consumes sticky attempt if indicated and isChatModel is true', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||
model: 'gemini-pro',
|
||||
@@ -271,10 +296,36 @@ describe('policyHelpers', () => {
|
||||
attempts: 1,
|
||||
});
|
||||
|
||||
const result = applyModelSelection(config, { model: 'gemini-pro' });
|
||||
const result = applyModelSelection(config, {
|
||||
model: 'gemini-pro',
|
||||
isChatModel: true,
|
||||
});
|
||||
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
|
||||
'gemini-pro',
|
||||
);
|
||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
||||
expect(result.maxAttempts).toBe(1);
|
||||
});
|
||||
|
||||
it('consumes sticky attempt if indicated but does not call setActiveModel if isChatModel is false', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||
model: 'gemini-pro',
|
||||
generateContentConfig: {},
|
||||
});
|
||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||
selectedModel: 'gemini-pro',
|
||||
attempts: 1,
|
||||
});
|
||||
|
||||
const result = applyModelSelection(config, {
|
||||
model: 'gemini-pro',
|
||||
isChatModel: false,
|
||||
});
|
||||
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
|
||||
'gemini-pro',
|
||||
);
|
||||
expect(config.setActiveModel).not.toHaveBeenCalled();
|
||||
expect(result.maxAttempts).toBe(1);
|
||||
});
|
||||
|
||||
@@ -291,7 +342,7 @@ describe('policyHelpers', () => {
|
||||
|
||||
const result = applyModelSelection(
|
||||
config,
|
||||
{ model: 'gemini-pro' },
|
||||
{ model: 'gemini-pro', isChatModel: true },
|
||||
{
|
||||
consumeAttempt: false,
|
||||
},
|
||||
@@ -299,6 +350,7 @@ describe('policyHelpers', () => {
|
||||
expect(
|
||||
mockAvailabilityService.consumeStickyAttempt,
|
||||
).not.toHaveBeenCalled();
|
||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
||||
expect(result.maxAttempts).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -214,7 +214,9 @@ export function applyModelSelection(
|
||||
generateContentConfig = fallbackResolved.generateContentConfig;
|
||||
}
|
||||
|
||||
config.setActiveModel(finalModel);
|
||||
if (modelConfigKey.isChatModel) {
|
||||
config.setActiveModel(finalModel);
|
||||
}
|
||||
|
||||
if (selection.attempts && options.consumeAttempt !== false) {
|
||||
config.getModelAvailabilityService().consumeStickyAttempt(finalModel);
|
||||
|
||||
@@ -641,7 +641,7 @@ describe('BaseLlmClient', () => {
|
||||
);
|
||||
|
||||
contentOptions = {
|
||||
modelConfigKey: { model: 'test-model' },
|
||||
modelConfigKey: { model: 'test-model', isChatModel: false },
|
||||
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
|
||||
abortSignal: abortController.signal,
|
||||
promptId: 'content-prompt-id',
|
||||
@@ -650,12 +650,17 @@ describe('BaseLlmClient', () => {
|
||||
|
||||
jsonOptions = {
|
||||
...defaultOptions,
|
||||
modelConfigKey: {
|
||||
...defaultOptions.modelConfigKey,
|
||||
isChatModel: true,
|
||||
},
|
||||
promptId: 'json-prompt-id',
|
||||
};
|
||||
});
|
||||
|
||||
it('should mark model as healthy on success', async () => {
|
||||
const successfulModel = 'gemini-pro';
|
||||
mockConfig.getActiveModel.mockReturnValue(successfulModel);
|
||||
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
||||
selectedModel: successfulModel,
|
||||
skipped: [],
|
||||
@@ -666,7 +671,7 @@ describe('BaseLlmClient', () => {
|
||||
|
||||
await client.generateContent({
|
||||
...contentOptions,
|
||||
modelConfigKey: { model: successfulModel },
|
||||
modelConfigKey: { model: successfulModel, isChatModel: false },
|
||||
role: LlmRole.UTILITY_TOOL,
|
||||
});
|
||||
|
||||
@@ -678,44 +683,55 @@ describe('BaseLlmClient', () => {
|
||||
it('marks the final attempted model healthy after a retry with availability enabled', async () => {
|
||||
const firstModel = 'gemini-pro';
|
||||
const fallbackModel = 'gemini-flash';
|
||||
let activeModel = firstModel;
|
||||
mockConfig.getActiveModel.mockImplementation(() => activeModel);
|
||||
mockConfig.setActiveModel.mockImplementation((m) => {
|
||||
activeModel = m;
|
||||
});
|
||||
|
||||
vi.mocked(mockAvailabilityService.selectFirstAvailable)
|
||||
.mockReturnValueOnce({ selectedModel: firstModel, skipped: [] })
|
||||
.mockReturnValueOnce({ selectedModel: fallbackModel, skipped: [] });
|
||||
|
||||
// Mock generateContent to fail once and then succeed
|
||||
mockGenerateContent
|
||||
.mockResolvedValueOnce(createMockResponse('retry-me'))
|
||||
.mockResolvedValueOnce(createMockResponse(''))
|
||||
.mockResolvedValueOnce(createMockResponse('final-response'));
|
||||
|
||||
// Run the real retryWithBackoff (with fake timers) to exercise the retry path
|
||||
vi.useFakeTimers();
|
||||
// 1. First call starts. applyModelSelection(firstModel) -> currentModel = firstModel.
|
||||
// 2. apiCall() runs. getActiveModel() === firstModel. call(firstModel). returns ''.
|
||||
// 3. retry triggers.
|
||||
// 4. Second call starts. applyModelSelection(firstModel).
|
||||
// selectFirstAvailable -> fallbackModel.
|
||||
// setActiveModel(fallbackModel) -> activeModel = fallbackModel.
|
||||
// returns fallbackModel.
|
||||
// 5. apiCall() runs. getActiveModel() === fallbackModel. call(fallbackModel). returns 'final-response'.
|
||||
|
||||
const retryPromise = client.generateContent({
|
||||
vi.mocked(retryWithBackoff).mockImplementation(async (fn) => {
|
||||
// First call
|
||||
let res = (await fn()) as GenerateContentResponse;
|
||||
if (res.candidates?.[0]?.content?.parts?.[0]?.text === '') {
|
||||
// Second call
|
||||
activeModel = fallbackModel;
|
||||
mockConfig.setActiveModel(fallbackModel);
|
||||
res = (await fn()) as GenerateContentResponse;
|
||||
}
|
||||
mockAvailabilityService.markHealthy(activeModel);
|
||||
return res;
|
||||
});
|
||||
|
||||
const result = await client.generateContent({
|
||||
...contentOptions,
|
||||
modelConfigKey: { model: firstModel },
|
||||
modelConfigKey: { model: firstModel, isChatModel: true },
|
||||
maxAttempts: 2,
|
||||
role: LlmRole.UTILITY_TOOL,
|
||||
});
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
await retryPromise;
|
||||
|
||||
await client.generateContent({
|
||||
...contentOptions,
|
||||
modelConfigKey: { model: firstModel },
|
||||
maxAttempts: 2,
|
||||
role: LlmRole.UTILITY_TOOL,
|
||||
});
|
||||
|
||||
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(firstModel);
|
||||
expect(result).toEqual(createMockResponse('final-response'));
|
||||
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(fallbackModel);
|
||||
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
|
||||
fallbackModel,
|
||||
);
|
||||
expect(mockGenerateContent).toHaveBeenLastCalledWith(
|
||||
expect.objectContaining({ model: fallbackModel }),
|
||||
expect.any(String),
|
||||
LlmRole.UTILITY_TOOL,
|
||||
);
|
||||
});
|
||||
|
||||
it('should consume sticky attempt if selection has attempts', async () => {
|
||||
@@ -754,6 +770,7 @@ describe('BaseLlmClient', () => {
|
||||
|
||||
it('should mark healthy and honor availability selection when using generateJson', async () => {
|
||||
const availableModel = 'gemini-json-pro';
|
||||
mockConfig.getActiveModel.mockReturnValue(availableModel);
|
||||
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
||||
selectedModel: availableModel,
|
||||
skipped: [],
|
||||
@@ -770,10 +787,15 @@ describe('BaseLlmClient', () => {
|
||||
return result;
|
||||
});
|
||||
|
||||
const result = await client.generateJson(jsonOptions);
|
||||
const result = await client.generateJson({
|
||||
...jsonOptions,
|
||||
modelConfigKey: {
|
||||
...jsonOptions.modelConfigKey,
|
||||
isChatModel: false,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result).toEqual({ color: 'violet' });
|
||||
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(availableModel);
|
||||
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
|
||||
availableModel,
|
||||
);
|
||||
|
||||
@@ -280,19 +280,22 @@ export class BaseLlmClient {
|
||||
() => currentModel,
|
||||
);
|
||||
|
||||
let initialActiveModel = this.config.getActiveModel();
|
||||
|
||||
try {
|
||||
const apiCall = () => {
|
||||
// Ensure we use the current active model
|
||||
// in case a fallback occurred in a previous attempt.
|
||||
const activeModel = this.config.getActiveModel();
|
||||
if (activeModel !== currentModel) {
|
||||
currentModel = activeModel;
|
||||
if (activeModel !== initialActiveModel) {
|
||||
initialActiveModel = activeModel;
|
||||
// Re-resolve config if model changed during retry
|
||||
const { generateContentConfig } =
|
||||
const { model: resolvedModel, generateContentConfig } =
|
||||
this.config.modelConfigService.getResolvedConfig({
|
||||
...modelConfigKey,
|
||||
model: activeModel,
|
||||
});
|
||||
currentModel = resolvedModel;
|
||||
currentGenerateContentConfig = generateContentConfig;
|
||||
}
|
||||
const finalConfig: GenerateContentConfig = {
|
||||
|
||||
@@ -957,17 +957,21 @@ export class GeminiClient {
|
||||
() => currentAttemptModel,
|
||||
);
|
||||
|
||||
let initialActiveModel = this.config.getActiveModel();
|
||||
|
||||
const apiCall = () => {
|
||||
// AvailabilityService
|
||||
const active = this.config.getActiveModel();
|
||||
if (active !== currentAttemptModel) {
|
||||
currentAttemptModel = active;
|
||||
if (active !== initialActiveModel) {
|
||||
initialActiveModel = active;
|
||||
// Re-resolve config if model changed
|
||||
const newConfig = this.config.modelConfigService.getResolvedConfig({
|
||||
...modelConfigKey,
|
||||
model: currentAttemptModel,
|
||||
});
|
||||
currentAttemptGenerateContentConfig = newConfig.generateContentConfig;
|
||||
const { model: resolvedModel, generateContentConfig } =
|
||||
this.config.modelConfigService.getResolvedConfig({
|
||||
...modelConfigKey,
|
||||
model: active,
|
||||
});
|
||||
currentAttemptModel = resolvedModel;
|
||||
currentAttemptGenerateContentConfig = generateContentConfig;
|
||||
}
|
||||
|
||||
const requestConfig: GenerateContentConfig = {
|
||||
|
||||
Reference in New Issue
Block a user