feat(acp): add set models interface (#20991)

This commit is contained in:
Shreya Keshive
2026-03-03 17:29:42 -05:00
committed by GitHub
parent c70c95ead3
commit 34f0c1538b
3 changed files with 221 additions and 2 deletions

View File

@@ -93,6 +93,9 @@ describe('GeminiAgent Session Resume', () => {
},
getApprovalMode: vi.fn().mockReturnValue('default'),
isPlanEnabled: vi.fn().mockReturnValue(false),
getModel: vi.fn().mockReturnValue('gemini-pro'),
getHasAccessToPreviewModel: vi.fn().mockReturnValue(false),
getGemini31LaunchedSync: vi.fn().mockReturnValue(false),
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
} as unknown as Mocked<Config>;
mockSettings = {
@@ -204,6 +207,10 @@ describe('GeminiAgent Session Resume', () => {
],
currentModeId: ApprovalMode.DEFAULT,
},
models: {
availableModels: expect.any(Array) as unknown,
currentModelId: 'gemini-pro',
},
});
// Verify resumeChat received the correct arguments

View File

@@ -173,6 +173,8 @@ describe('GeminiAgent', () => {
}),
getApprovalMode: vi.fn().mockReturnValue('default'),
isPlanEnabled: vi.fn().mockReturnValue(false),
getGemini31LaunchedSync: vi.fn().mockReturnValue(false),
getHasAccessToPreviewModel: vi.fn().mockReturnValue(false),
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
} as unknown as Mocked<Awaited<ReturnType<typeof loadCliConfig>>>;
mockSettings = {
@@ -304,6 +306,38 @@ describe('GeminiAgent', () => {
],
currentModeId: 'default',
});
expect(response.models).toEqual({
availableModels: expect.arrayContaining([
expect.objectContaining({
modelId: 'auto-gemini-2.5',
name: 'Auto (Gemini 2.5)',
}),
]),
currentModelId: 'gemini-pro',
});
});
it('should include preview models when user has access', async () => {
mockConfig.getHasAccessToPreviewModel = vi.fn().mockReturnValue(true);
mockConfig.getGemini31LaunchedSync = vi.fn().mockReturnValue(true);
const response = await agent.newSession({
cwd: '/tmp',
mcpServers: [],
});
expect(response.models?.availableModels).toEqual(
expect.arrayContaining([
expect.objectContaining({
modelId: 'auto-gemini-3',
name: expect.stringContaining('Auto'),
}),
expect.objectContaining({
modelId: 'gemini-3.1-pro-preview',
name: 'gemini-3.1-pro-preview',
}),
]),
);
});
it('should return modes with plan mode when plan is enabled', async () => {
@@ -331,6 +365,15 @@ describe('GeminiAgent', () => {
],
currentModeId: 'plan',
});
expect(response.models).toEqual({
availableModels: expect.arrayContaining([
expect.objectContaining({
modelId: 'auto-gemini-2.5',
name: 'Auto (Gemini 2.5)',
}),
]),
currentModelId: 'gemini-pro',
});
});
it('should fail session creation if Gemini API key is missing', async () => {
@@ -480,6 +523,32 @@ describe('GeminiAgent', () => {
}),
).rejects.toThrow('Session not found: unknown');
});
it('should delegate setModel to session (unstable)', async () => {
await agent.newSession({ cwd: '/tmp', mcpServers: [] });
const session = (
agent as unknown as { sessions: Map<string, Session> }
).sessions.get('test-session-id');
if (!session) throw new Error('Session not found');
session.setModel = vi.fn().mockReturnValue({});
const result = await agent.unstable_setSessionModel({
sessionId: 'test-session-id',
modelId: 'gemini-2.0-pro-exp',
});
expect(session.setModel).toHaveBeenCalledWith('gemini-2.0-pro-exp');
expect(result).toEqual({});
});
it('should throw error when setting model on non-existent session (unstable)', async () => {
await expect(
agent.unstable_setSessionModel({
sessionId: 'unknown',
modelId: 'gemini-2.0-pro-exp',
}),
).rejects.toThrow('Session not found: unknown');
});
});
describe('Session', () => {
@@ -528,6 +597,7 @@ describe('Session', () => {
getDebugMode: vi.fn().mockReturnValue(false),
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
setApprovalMode: vi.fn(),
setModel: vi.fn(),
isPlanEnabled: vi.fn().mockReturnValue(false),
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
getGitService: vi.fn().mockResolvedValue({} as GitService),
@@ -1383,6 +1453,12 @@ describe('Session', () => {
'Invalid or unavailable mode: invalid-mode',
);
});
it('should set model on config', () => {
session.setModel('gemini-2.0-flash-exp');
expect(mockConfig.setModel).toHaveBeenCalledWith('gemini-2.0-flash-exp');
});
it('should handle unquoted commands from autocomplete (with empty leading parts)', async () => {
// Mock handleCommand to verify it gets called
const handleCommandSpy = vi

View File

@@ -37,6 +37,16 @@ import {
ApprovalMode,
getVersion,
convertSessionToClientHistory,
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_FLASH_LITE_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
getDisplayString,
} from '@google/gemini-cli-core';
import * as acp from '@agentclientprotocol/sdk';
import { AcpFileSystemService } from './fileSystemService.js';
@@ -255,13 +265,23 @@ export class GeminiAgent {
session.sendAvailableCommands();
}, 0);
return {
const { availableModels, currentModelId } = buildAvailableModels(
config,
loadedSettings,
);
const response = {
sessionId,
modes: {
availableModes: buildAvailableModes(config.isPlanEnabled()),
currentModeId: config.getApprovalMode(),
},
models: {
availableModels,
currentModelId,
},
};
return response;
}
async loadSession({
@@ -316,12 +336,22 @@ export class GeminiAgent {
session.sendAvailableCommands();
}, 0);
return {
const { availableModels, currentModelId } = buildAvailableModels(
config,
this.settings,
);
const response = {
modes: {
availableModes: buildAvailableModes(config.isPlanEnabled()),
currentModeId: config.getApprovalMode(),
},
models: {
availableModels,
currentModelId,
},
};
return response;
}
private async initializeSessionConfig(
@@ -432,6 +462,16 @@ export class GeminiAgent {
}
return session.setMode(params.modeId);
}
async unstable_setSessionModel(
params: acp.SetSessionModelRequest,
): Promise<acp.SetSessionModelResponse> {
const session = this.sessions.get(params.sessionId);
if (!session) {
throw new Error(`Session not found: ${params.sessionId}`);
}
return session.setModel(params.modelId);
}
}
export class Session {
@@ -482,6 +522,11 @@ export class Session {
});
}
setModel(modelId: acp.ModelId): acp.SetSessionModelResponse {
this.config.setModel(modelId);
return {};
}
async streamHistory(messages: ConversationRecord['messages']): Promise<void> {
for (const msg of messages) {
const contentString = partListUnionToString(msg.content);
@@ -1467,3 +1512,94 @@ function buildAvailableModes(isPlanEnabled: boolean): acp.SessionMode[] {
return modes;
}
function buildAvailableModels(
config: Config,
settings: LoadedSettings,
): {
availableModels: Array<{
modelId: string;
name: string;
description?: string;
}>;
currentModelId: string;
} {
const preferredModel = config.getModel() || DEFAULT_GEMINI_MODEL_AUTO;
const shouldShowPreviewModels = config.getHasAccessToPreviewModel();
const useGemini31 = config.getGemini31LaunchedSync?.() ?? false;
const selectedAuthType = settings.merged.security.auth.selectedType;
const useCustomToolModel =
useGemini31 && selectedAuthType === AuthType.USE_GEMINI;
const mainOptions = [
{
value: DEFAULT_GEMINI_MODEL_AUTO,
title: getDisplayString(DEFAULT_GEMINI_MODEL_AUTO),
description:
'Let Gemini CLI decide the best model for the task: gemini-2.5-pro, gemini-2.5-flash',
},
];
if (shouldShowPreviewModels) {
mainOptions.unshift({
value: PREVIEW_GEMINI_MODEL_AUTO,
title: getDisplayString(PREVIEW_GEMINI_MODEL_AUTO),
description: useGemini31
? 'Let Gemini CLI decide the best model for the task: gemini-3.1-pro, gemini-3-flash'
: 'Let Gemini CLI decide the best model for the task: gemini-3-pro, gemini-3-flash',
});
}
const manualOptions = [
{
value: DEFAULT_GEMINI_MODEL,
title: getDisplayString(DEFAULT_GEMINI_MODEL),
},
{
value: DEFAULT_GEMINI_FLASH_MODEL,
title: getDisplayString(DEFAULT_GEMINI_FLASH_MODEL),
},
{
value: DEFAULT_GEMINI_FLASH_LITE_MODEL,
title: getDisplayString(DEFAULT_GEMINI_FLASH_LITE_MODEL),
},
];
if (shouldShowPreviewModels) {
const previewProModel = useGemini31
? PREVIEW_GEMINI_3_1_MODEL
: PREVIEW_GEMINI_MODEL;
const previewProValue = useCustomToolModel
? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL
: previewProModel;
manualOptions.unshift(
{
value: previewProValue,
title: getDisplayString(previewProModel),
},
{
value: PREVIEW_GEMINI_FLASH_MODEL,
title: getDisplayString(PREVIEW_GEMINI_FLASH_MODEL),
},
);
}
const scaleOptions = (
options: Array<{ value: string; title: string; description?: string }>,
) =>
options.map((o) => ({
modelId: o.value,
name: o.title,
description: o.description,
}));
return {
availableModels: [
...scaleOptions(mainOptions),
...scaleOptions(manualOptions),
],
currentModelId: preferredModel,
};
}