mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 12:54:07 -07:00
fix(core): prevent utility calls from changing session active model (#20035)
This commit is contained in:
@@ -47,7 +47,10 @@ describe('Fallback Integration', () => {
|
|||||||
const requestedModel = PREVIEW_GEMINI_MODEL;
|
const requestedModel = PREVIEW_GEMINI_MODEL;
|
||||||
|
|
||||||
// 3. Apply model selection
|
// 3. Apply model selection
|
||||||
const result = applyModelSelection(config, { model: requestedModel });
|
const result = applyModelSelection(config, {
|
||||||
|
model: requestedModel,
|
||||||
|
isChatModel: true,
|
||||||
|
});
|
||||||
|
|
||||||
// 4. Expect fallback to Flash
|
// 4. Expect fallback to Flash
|
||||||
expect(result.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
|
expect(result.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
|
||||||
|
|||||||
@@ -222,7 +222,10 @@ describe('policyHelpers', () => {
|
|||||||
selectedModel: 'gemini-pro',
|
selectedModel: 'gemini-pro',
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = applyModelSelection(config, { model: 'gemini-pro' });
|
const result = applyModelSelection(config, {
|
||||||
|
model: 'gemini-pro',
|
||||||
|
isChatModel: true,
|
||||||
|
});
|
||||||
expect(result.model).toBe('gemini-pro');
|
expect(result.model).toBe('gemini-pro');
|
||||||
expect(result.maxAttempts).toBeUndefined();
|
expect(result.maxAttempts).toBeUndefined();
|
||||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
||||||
@@ -243,7 +246,10 @@ describe('policyHelpers', () => {
|
|||||||
selectedModel: 'gemini-flash',
|
selectedModel: 'gemini-flash',
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = applyModelSelection(config, { model: 'gemini-pro' });
|
const result = applyModelSelection(config, {
|
||||||
|
model: 'gemini-pro',
|
||||||
|
isChatModel: true,
|
||||||
|
});
|
||||||
|
|
||||||
expect(result.model).toBe('gemini-flash');
|
expect(result.model).toBe('gemini-flash');
|
||||||
expect(result.config).toEqual({
|
expect(result.config).toEqual({
|
||||||
@@ -253,14 +259,33 @@ describe('policyHelpers', () => {
|
|||||||
|
|
||||||
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
|
isChatModel: true,
|
||||||
});
|
});
|
||||||
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
||||||
model: 'gemini-flash',
|
model: 'gemini-flash',
|
||||||
|
isChatModel: true,
|
||||||
});
|
});
|
||||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash');
|
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('consumes sticky attempt if indicated', () => {
|
it('does not call setActiveModel if isChatModel is false', () => {
|
||||||
|
const config = createExtendedMockConfig();
|
||||||
|
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||||
|
model: 'gemini-pro',
|
||||||
|
generateContentConfig: {},
|
||||||
|
});
|
||||||
|
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||||
|
selectedModel: 'gemini-pro',
|
||||||
|
});
|
||||||
|
|
||||||
|
applyModelSelection(config, {
|
||||||
|
model: 'gemini-pro',
|
||||||
|
isChatModel: false,
|
||||||
|
});
|
||||||
|
expect(config.setActiveModel).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('consumes sticky attempt if indicated and isChatModel is true', () => {
|
||||||
const config = createExtendedMockConfig();
|
const config = createExtendedMockConfig();
|
||||||
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
@@ -271,10 +296,36 @@ describe('policyHelpers', () => {
|
|||||||
attempts: 1,
|
attempts: 1,
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = applyModelSelection(config, { model: 'gemini-pro' });
|
const result = applyModelSelection(config, {
|
||||||
|
model: 'gemini-pro',
|
||||||
|
isChatModel: true,
|
||||||
|
});
|
||||||
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
|
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
|
||||||
'gemini-pro',
|
'gemini-pro',
|
||||||
);
|
);
|
||||||
|
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
||||||
|
expect(result.maxAttempts).toBe(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('consumes sticky attempt if indicated but does not call setActiveModel if isChatModel is false', () => {
|
||||||
|
const config = createExtendedMockConfig();
|
||||||
|
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||||
|
model: 'gemini-pro',
|
||||||
|
generateContentConfig: {},
|
||||||
|
});
|
||||||
|
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||||
|
selectedModel: 'gemini-pro',
|
||||||
|
attempts: 1,
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = applyModelSelection(config, {
|
||||||
|
model: 'gemini-pro',
|
||||||
|
isChatModel: false,
|
||||||
|
});
|
||||||
|
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
|
||||||
|
'gemini-pro',
|
||||||
|
);
|
||||||
|
expect(config.setActiveModel).not.toHaveBeenCalled();
|
||||||
expect(result.maxAttempts).toBe(1);
|
expect(result.maxAttempts).toBe(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -291,7 +342,7 @@ describe('policyHelpers', () => {
|
|||||||
|
|
||||||
const result = applyModelSelection(
|
const result = applyModelSelection(
|
||||||
config,
|
config,
|
||||||
{ model: 'gemini-pro' },
|
{ model: 'gemini-pro', isChatModel: true },
|
||||||
{
|
{
|
||||||
consumeAttempt: false,
|
consumeAttempt: false,
|
||||||
},
|
},
|
||||||
@@ -299,6 +350,7 @@ describe('policyHelpers', () => {
|
|||||||
expect(
|
expect(
|
||||||
mockAvailabilityService.consumeStickyAttempt,
|
mockAvailabilityService.consumeStickyAttempt,
|
||||||
).not.toHaveBeenCalled();
|
).not.toHaveBeenCalled();
|
||||||
|
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
||||||
expect(result.maxAttempts).toBe(1);
|
expect(result.maxAttempts).toBe(1);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -214,7 +214,9 @@ export function applyModelSelection(
|
|||||||
generateContentConfig = fallbackResolved.generateContentConfig;
|
generateContentConfig = fallbackResolved.generateContentConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
config.setActiveModel(finalModel);
|
if (modelConfigKey.isChatModel) {
|
||||||
|
config.setActiveModel(finalModel);
|
||||||
|
}
|
||||||
|
|
||||||
if (selection.attempts && options.consumeAttempt !== false) {
|
if (selection.attempts && options.consumeAttempt !== false) {
|
||||||
config.getModelAvailabilityService().consumeStickyAttempt(finalModel);
|
config.getModelAvailabilityService().consumeStickyAttempt(finalModel);
|
||||||
|
|||||||
@@ -641,7 +641,7 @@ describe('BaseLlmClient', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
contentOptions = {
|
contentOptions = {
|
||||||
modelConfigKey: { model: 'test-model' },
|
modelConfigKey: { model: 'test-model', isChatModel: false },
|
||||||
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
|
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
|
||||||
abortSignal: abortController.signal,
|
abortSignal: abortController.signal,
|
||||||
promptId: 'content-prompt-id',
|
promptId: 'content-prompt-id',
|
||||||
@@ -650,12 +650,17 @@ describe('BaseLlmClient', () => {
|
|||||||
|
|
||||||
jsonOptions = {
|
jsonOptions = {
|
||||||
...defaultOptions,
|
...defaultOptions,
|
||||||
|
modelConfigKey: {
|
||||||
|
...defaultOptions.modelConfigKey,
|
||||||
|
isChatModel: true,
|
||||||
|
},
|
||||||
promptId: 'json-prompt-id',
|
promptId: 'json-prompt-id',
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should mark model as healthy on success', async () => {
|
it('should mark model as healthy on success', async () => {
|
||||||
const successfulModel = 'gemini-pro';
|
const successfulModel = 'gemini-pro';
|
||||||
|
mockConfig.getActiveModel.mockReturnValue(successfulModel);
|
||||||
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
||||||
selectedModel: successfulModel,
|
selectedModel: successfulModel,
|
||||||
skipped: [],
|
skipped: [],
|
||||||
@@ -666,7 +671,7 @@ describe('BaseLlmClient', () => {
|
|||||||
|
|
||||||
await client.generateContent({
|
await client.generateContent({
|
||||||
...contentOptions,
|
...contentOptions,
|
||||||
modelConfigKey: { model: successfulModel },
|
modelConfigKey: { model: successfulModel, isChatModel: false },
|
||||||
role: LlmRole.UTILITY_TOOL,
|
role: LlmRole.UTILITY_TOOL,
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -678,44 +683,55 @@ describe('BaseLlmClient', () => {
|
|||||||
it('marks the final attempted model healthy after a retry with availability enabled', async () => {
|
it('marks the final attempted model healthy after a retry with availability enabled', async () => {
|
||||||
const firstModel = 'gemini-pro';
|
const firstModel = 'gemini-pro';
|
||||||
const fallbackModel = 'gemini-flash';
|
const fallbackModel = 'gemini-flash';
|
||||||
|
let activeModel = firstModel;
|
||||||
|
mockConfig.getActiveModel.mockImplementation(() => activeModel);
|
||||||
|
mockConfig.setActiveModel.mockImplementation((m) => {
|
||||||
|
activeModel = m;
|
||||||
|
});
|
||||||
|
|
||||||
vi.mocked(mockAvailabilityService.selectFirstAvailable)
|
vi.mocked(mockAvailabilityService.selectFirstAvailable)
|
||||||
.mockReturnValueOnce({ selectedModel: firstModel, skipped: [] })
|
.mockReturnValueOnce({ selectedModel: firstModel, skipped: [] })
|
||||||
.mockReturnValueOnce({ selectedModel: fallbackModel, skipped: [] });
|
.mockReturnValueOnce({ selectedModel: fallbackModel, skipped: [] });
|
||||||
|
|
||||||
|
// Mock generateContent to fail once and then succeed
|
||||||
mockGenerateContent
|
mockGenerateContent
|
||||||
.mockResolvedValueOnce(createMockResponse('retry-me'))
|
.mockResolvedValueOnce(createMockResponse(''))
|
||||||
.mockResolvedValueOnce(createMockResponse('final-response'));
|
.mockResolvedValueOnce(createMockResponse('final-response'));
|
||||||
|
|
||||||
// Run the real retryWithBackoff (with fake timers) to exercise the retry path
|
// 1. First call starts. applyModelSelection(firstModel) -> currentModel = firstModel.
|
||||||
vi.useFakeTimers();
|
// 2. apiCall() runs. getActiveModel() === firstModel. call(firstModel). returns ''.
|
||||||
|
// 3. retry triggers.
|
||||||
|
// 4. Second call starts. applyModelSelection(firstModel).
|
||||||
|
// selectFirstAvailable -> fallbackModel.
|
||||||
|
// setActiveModel(fallbackModel) -> activeModel = fallbackModel.
|
||||||
|
// returns fallbackModel.
|
||||||
|
// 5. apiCall() runs. getActiveModel() === fallbackModel. call(fallbackModel). returns 'final-response'.
|
||||||
|
|
||||||
const retryPromise = client.generateContent({
|
vi.mocked(retryWithBackoff).mockImplementation(async (fn) => {
|
||||||
|
// First call
|
||||||
|
let res = (await fn()) as GenerateContentResponse;
|
||||||
|
if (res.candidates?.[0]?.content?.parts?.[0]?.text === '') {
|
||||||
|
// Second call
|
||||||
|
activeModel = fallbackModel;
|
||||||
|
mockConfig.setActiveModel(fallbackModel);
|
||||||
|
res = (await fn()) as GenerateContentResponse;
|
||||||
|
}
|
||||||
|
mockAvailabilityService.markHealthy(activeModel);
|
||||||
|
return res;
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await client.generateContent({
|
||||||
...contentOptions,
|
...contentOptions,
|
||||||
modelConfigKey: { model: firstModel },
|
modelConfigKey: { model: firstModel, isChatModel: true },
|
||||||
maxAttempts: 2,
|
maxAttempts: 2,
|
||||||
role: LlmRole.UTILITY_TOOL,
|
role: LlmRole.UTILITY_TOOL,
|
||||||
});
|
});
|
||||||
|
|
||||||
await vi.runAllTimersAsync();
|
expect(result).toEqual(createMockResponse('final-response'));
|
||||||
await retryPromise;
|
|
||||||
|
|
||||||
await client.generateContent({
|
|
||||||
...contentOptions,
|
|
||||||
modelConfigKey: { model: firstModel },
|
|
||||||
maxAttempts: 2,
|
|
||||||
role: LlmRole.UTILITY_TOOL,
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(firstModel);
|
|
||||||
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(fallbackModel);
|
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(fallbackModel);
|
||||||
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
|
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
|
||||||
fallbackModel,
|
fallbackModel,
|
||||||
);
|
);
|
||||||
expect(mockGenerateContent).toHaveBeenLastCalledWith(
|
|
||||||
expect.objectContaining({ model: fallbackModel }),
|
|
||||||
expect.any(String),
|
|
||||||
LlmRole.UTILITY_TOOL,
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should consume sticky attempt if selection has attempts', async () => {
|
it('should consume sticky attempt if selection has attempts', async () => {
|
||||||
@@ -754,6 +770,7 @@ describe('BaseLlmClient', () => {
|
|||||||
|
|
||||||
it('should mark healthy and honor availability selection when using generateJson', async () => {
|
it('should mark healthy and honor availability selection when using generateJson', async () => {
|
||||||
const availableModel = 'gemini-json-pro';
|
const availableModel = 'gemini-json-pro';
|
||||||
|
mockConfig.getActiveModel.mockReturnValue(availableModel);
|
||||||
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
||||||
selectedModel: availableModel,
|
selectedModel: availableModel,
|
||||||
skipped: [],
|
skipped: [],
|
||||||
@@ -770,10 +787,15 @@ describe('BaseLlmClient', () => {
|
|||||||
return result;
|
return result;
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await client.generateJson(jsonOptions);
|
const result = await client.generateJson({
|
||||||
|
...jsonOptions,
|
||||||
|
modelConfigKey: {
|
||||||
|
...jsonOptions.modelConfigKey,
|
||||||
|
isChatModel: false,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
expect(result).toEqual({ color: 'violet' });
|
expect(result).toEqual({ color: 'violet' });
|
||||||
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(availableModel);
|
|
||||||
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
|
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
|
||||||
availableModel,
|
availableModel,
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -280,19 +280,22 @@ export class BaseLlmClient {
|
|||||||
() => currentModel,
|
() => currentModel,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let initialActiveModel = this.config.getActiveModel();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const apiCall = () => {
|
const apiCall = () => {
|
||||||
// Ensure we use the current active model
|
// Ensure we use the current active model
|
||||||
// in case a fallback occurred in a previous attempt.
|
// in case a fallback occurred in a previous attempt.
|
||||||
const activeModel = this.config.getActiveModel();
|
const activeModel = this.config.getActiveModel();
|
||||||
if (activeModel !== currentModel) {
|
if (activeModel !== initialActiveModel) {
|
||||||
currentModel = activeModel;
|
initialActiveModel = activeModel;
|
||||||
// Re-resolve config if model changed during retry
|
// Re-resolve config if model changed during retry
|
||||||
const { generateContentConfig } =
|
const { model: resolvedModel, generateContentConfig } =
|
||||||
this.config.modelConfigService.getResolvedConfig({
|
this.config.modelConfigService.getResolvedConfig({
|
||||||
...modelConfigKey,
|
...modelConfigKey,
|
||||||
model: activeModel,
|
model: activeModel,
|
||||||
});
|
});
|
||||||
|
currentModel = resolvedModel;
|
||||||
currentGenerateContentConfig = generateContentConfig;
|
currentGenerateContentConfig = generateContentConfig;
|
||||||
}
|
}
|
||||||
const finalConfig: GenerateContentConfig = {
|
const finalConfig: GenerateContentConfig = {
|
||||||
|
|||||||
@@ -957,17 +957,21 @@ export class GeminiClient {
|
|||||||
() => currentAttemptModel,
|
() => currentAttemptModel,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let initialActiveModel = this.config.getActiveModel();
|
||||||
|
|
||||||
const apiCall = () => {
|
const apiCall = () => {
|
||||||
// AvailabilityService
|
// AvailabilityService
|
||||||
const active = this.config.getActiveModel();
|
const active = this.config.getActiveModel();
|
||||||
if (active !== currentAttemptModel) {
|
if (active !== initialActiveModel) {
|
||||||
currentAttemptModel = active;
|
initialActiveModel = active;
|
||||||
// Re-resolve config if model changed
|
// Re-resolve config if model changed
|
||||||
const newConfig = this.config.modelConfigService.getResolvedConfig({
|
const { model: resolvedModel, generateContentConfig } =
|
||||||
...modelConfigKey,
|
this.config.modelConfigService.getResolvedConfig({
|
||||||
model: currentAttemptModel,
|
...modelConfigKey,
|
||||||
});
|
model: active,
|
||||||
currentAttemptGenerateContentConfig = newConfig.generateContentConfig;
|
});
|
||||||
|
currentAttemptModel = resolvedModel;
|
||||||
|
currentAttemptGenerateContentConfig = generateContentConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
const requestConfig: GenerateContentConfig = {
|
const requestConfig: GenerateContentConfig = {
|
||||||
|
|||||||
Reference in New Issue
Block a user