From 34f0c1538be7ec801cccc960ec5f44de994e2ec3 Mon Sep 17 00:00:00 2001 From: Shreya Keshive Date: Tue, 3 Mar 2026 17:29:42 -0500 Subject: [PATCH] feat(acp): add set models interface (#20991) --- .../cli/src/zed-integration/acpResume.test.ts | 7 + .../zed-integration/zedIntegration.test.ts | 76 ++++++++++ .../cli/src/zed-integration/zedIntegration.ts | 140 +++++++++++++++++- 3 files changed, 221 insertions(+), 2 deletions(-) diff --git a/packages/cli/src/zed-integration/acpResume.test.ts b/packages/cli/src/zed-integration/acpResume.test.ts index 9addafd369..cda47c17b4 100644 --- a/packages/cli/src/zed-integration/acpResume.test.ts +++ b/packages/cli/src/zed-integration/acpResume.test.ts @@ -93,6 +93,9 @@ describe('GeminiAgent Session Resume', () => { }, getApprovalMode: vi.fn().mockReturnValue('default'), isPlanEnabled: vi.fn().mockReturnValue(false), + getModel: vi.fn().mockReturnValue('gemini-pro'), + getHasAccessToPreviewModel: vi.fn().mockReturnValue(false), + getGemini31LaunchedSync: vi.fn().mockReturnValue(false), getCheckpointingEnabled: vi.fn().mockReturnValue(false), } as unknown as Mocked; mockSettings = { @@ -204,6 +207,10 @@ describe('GeminiAgent Session Resume', () => { ], currentModeId: ApprovalMode.DEFAULT, }, + models: { + availableModels: expect.any(Array) as unknown, + currentModelId: 'gemini-pro', + }, }); // Verify resumeChat received the correct arguments diff --git a/packages/cli/src/zed-integration/zedIntegration.test.ts b/packages/cli/src/zed-integration/zedIntegration.test.ts index 23ba8b8ab8..810cb9a1de 100644 --- a/packages/cli/src/zed-integration/zedIntegration.test.ts +++ b/packages/cli/src/zed-integration/zedIntegration.test.ts @@ -173,6 +173,8 @@ describe('GeminiAgent', () => { }), getApprovalMode: vi.fn().mockReturnValue('default'), isPlanEnabled: vi.fn().mockReturnValue(false), + getGemini31LaunchedSync: vi.fn().mockReturnValue(false), + getHasAccessToPreviewModel: vi.fn().mockReturnValue(false), getCheckpointingEnabled: vi.fn().mockReturnValue(false), } as unknown as Mocked>>; mockSettings = { @@ -304,6 +306,38 @@ describe('GeminiAgent', () => { ], currentModeId: 'default', }); + expect(response.models).toEqual({ + availableModels: expect.arrayContaining([ + expect.objectContaining({ + modelId: 'auto-gemini-2.5', + name: 'Auto (Gemini 2.5)', + }), + ]), + currentModelId: 'gemini-pro', + }); + }); + + it('should include preview models when user has access', async () => { + mockConfig.getHasAccessToPreviewModel = vi.fn().mockReturnValue(true); + mockConfig.getGemini31LaunchedSync = vi.fn().mockReturnValue(true); + + const response = await agent.newSession({ + cwd: '/tmp', + mcpServers: [], + }); + + expect(response.models?.availableModels).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + modelId: 'auto-gemini-3', + name: expect.stringContaining('Auto'), + }), + expect.objectContaining({ + modelId: 'gemini-3.1-pro-preview', + name: 'gemini-3.1-pro-preview', + }), + ]), + ); }); it('should return modes with plan mode when plan is enabled', async () => { @@ -331,6 +365,15 @@ describe('GeminiAgent', () => { ], currentModeId: 'plan', }); + expect(response.models).toEqual({ + availableModels: expect.arrayContaining([ + expect.objectContaining({ + modelId: 'auto-gemini-2.5', + name: 'Auto (Gemini 2.5)', + }), + ]), + currentModelId: 'gemini-pro', + }); }); it('should fail session creation if Gemini API key is missing', async () => { @@ -480,6 +523,32 @@ describe('GeminiAgent', () => { }), ).rejects.toThrow('Session not found: unknown'); }); + + it('should delegate setModel to session (unstable)', async () => { + await agent.newSession({ cwd: '/tmp', mcpServers: [] }); + const session = ( + agent as unknown as { sessions: Map } + ).sessions.get('test-session-id'); + if (!session) throw new Error('Session not found'); + session.setModel = vi.fn().mockReturnValue({}); + + const result = await agent.unstable_setSessionModel({ + sessionId: 'test-session-id', + modelId: 'gemini-2.0-pro-exp', + }); + + expect(session.setModel).toHaveBeenCalledWith('gemini-2.0-pro-exp'); + expect(result).toEqual({}); + }); + + it('should throw error when setting model on non-existent session (unstable)', async () => { + await expect( + agent.unstable_setSessionModel({ + sessionId: 'unknown', + modelId: 'gemini-2.0-pro-exp', + }), + ).rejects.toThrow('Session not found: unknown'); + }); }); describe('Session', () => { @@ -528,6 +597,7 @@ describe('Session', () => { getDebugMode: vi.fn().mockReturnValue(false), getMessageBus: vi.fn().mockReturnValue(mockMessageBus), setApprovalMode: vi.fn(), + setModel: vi.fn(), isPlanEnabled: vi.fn().mockReturnValue(false), getCheckpointingEnabled: vi.fn().mockReturnValue(false), getGitService: vi.fn().mockResolvedValue({} as GitService), @@ -1383,6 +1453,12 @@ describe('Session', () => { 'Invalid or unavailable mode: invalid-mode', ); }); + + it('should set model on config', () => { + session.setModel('gemini-2.0-flash-exp'); + expect(mockConfig.setModel).toHaveBeenCalledWith('gemini-2.0-flash-exp'); + }); + it('should handle unquoted commands from autocomplete (with empty leading parts)', async () => { // Mock handleCommand to verify it gets called const handleCommandSpy = vi diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 30bf8551f0..dc07502f7f 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -37,6 +37,16 @@ import { ApprovalMode, getVersion, convertSessionToClientHistory, + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_FLASH_LITE_MODEL, + PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, + PREVIEW_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_MODEL_AUTO, + PREVIEW_GEMINI_MODEL_AUTO, + getDisplayString, } from '@google/gemini-cli-core'; import * as acp from '@agentclientprotocol/sdk'; import { AcpFileSystemService } from './fileSystemService.js'; @@ -255,13 +265,23 @@ export class GeminiAgent { session.sendAvailableCommands(); }, 0); - return { + const { availableModels, currentModelId } = buildAvailableModels( + config, + loadedSettings, + ); + + const response = { sessionId, modes: { availableModes: buildAvailableModes(config.isPlanEnabled()), currentModeId: config.getApprovalMode(), }, + models: { + availableModels, + currentModelId, + }, }; + return response; } async loadSession({ @@ -316,12 +336,22 @@ export class GeminiAgent { session.sendAvailableCommands(); }, 0); - return { + const { availableModels, currentModelId } = buildAvailableModels( + config, + this.settings, + ); + + const response = { modes: { availableModes: buildAvailableModes(config.isPlanEnabled()), currentModeId: config.getApprovalMode(), }, + models: { + availableModels, + currentModelId, + }, }; + return response; } private async initializeSessionConfig( @@ -432,6 +462,16 @@ export class GeminiAgent { } return session.setMode(params.modeId); } + + async unstable_setSessionModel( + params: acp.SetSessionModelRequest, + ): Promise { + const session = this.sessions.get(params.sessionId); + if (!session) { + throw new Error(`Session not found: ${params.sessionId}`); + } + return session.setModel(params.modelId); + } } export class Session { @@ -482,6 +522,11 @@ export class Session { }); } + setModel(modelId: acp.ModelId): acp.SetSessionModelResponse { + this.config.setModel(modelId); + return {}; + } + async streamHistory(messages: ConversationRecord['messages']): Promise { for (const msg of messages) { const contentString = partListUnionToString(msg.content); @@ -1467,3 +1512,94 @@ function buildAvailableModes(isPlanEnabled: boolean): acp.SessionMode[] { return modes; } + +function buildAvailableModels( + config: Config, + settings: LoadedSettings, +): { + availableModels: Array<{ + modelId: string; + name: string; + description?: string; + }>; + currentModelId: string; +} { + const preferredModel = config.getModel() || DEFAULT_GEMINI_MODEL_AUTO; + const shouldShowPreviewModels = config.getHasAccessToPreviewModel(); + const useGemini31 = config.getGemini31LaunchedSync?.() ?? false; + const selectedAuthType = settings.merged.security.auth.selectedType; + const useCustomToolModel = + useGemini31 && selectedAuthType === AuthType.USE_GEMINI; + + const mainOptions = [ + { + value: DEFAULT_GEMINI_MODEL_AUTO, + title: getDisplayString(DEFAULT_GEMINI_MODEL_AUTO), + description: + 'Let Gemini CLI decide the best model for the task: gemini-2.5-pro, gemini-2.5-flash', + }, + ]; + + if (shouldShowPreviewModels) { + mainOptions.unshift({ + value: PREVIEW_GEMINI_MODEL_AUTO, + title: getDisplayString(PREVIEW_GEMINI_MODEL_AUTO), + description: useGemini31 + ? 'Let Gemini CLI decide the best model for the task: gemini-3.1-pro, gemini-3-flash' + : 'Let Gemini CLI decide the best model for the task: gemini-3-pro, gemini-3-flash', + }); + } + + const manualOptions = [ + { + value: DEFAULT_GEMINI_MODEL, + title: getDisplayString(DEFAULT_GEMINI_MODEL), + }, + { + value: DEFAULT_GEMINI_FLASH_MODEL, + title: getDisplayString(DEFAULT_GEMINI_FLASH_MODEL), + }, + { + value: DEFAULT_GEMINI_FLASH_LITE_MODEL, + title: getDisplayString(DEFAULT_GEMINI_FLASH_LITE_MODEL), + }, + ]; + + if (shouldShowPreviewModels) { + const previewProModel = useGemini31 + ? PREVIEW_GEMINI_3_1_MODEL + : PREVIEW_GEMINI_MODEL; + + const previewProValue = useCustomToolModel + ? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL + : previewProModel; + + manualOptions.unshift( + { + value: previewProValue, + title: getDisplayString(previewProModel), + }, + { + value: PREVIEW_GEMINI_FLASH_MODEL, + title: getDisplayString(PREVIEW_GEMINI_FLASH_MODEL), + }, + ); + } + + const scaleOptions = ( + options: Array<{ value: string; title: string; description?: string }>, + ) => + options.map((o) => ({ + modelId: o.value, + name: o.title, + description: o.description, + })); + + return { + availableModels: [ + ...scaleOptions(mainOptions), + ...scaleOptions(manualOptions), + ], + currentModelId: preferredModel, + }; +}