feat(acp): add set models interface (#20991)

This commit is contained in:
Shreya Keshive
2026-03-03 17:29:42 -05:00
committed by GitHub
parent c70c95ead3
commit 34f0c1538b
3 changed files with 221 additions and 2 deletions
@@ -173,6 +173,8 @@ describe('GeminiAgent', () => {
}),
getApprovalMode: vi.fn().mockReturnValue('default'),
isPlanEnabled: vi.fn().mockReturnValue(false),
getGemini31LaunchedSync: vi.fn().mockReturnValue(false),
getHasAccessToPreviewModel: vi.fn().mockReturnValue(false),
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
} as unknown as Mocked<Awaited<ReturnType<typeof loadCliConfig>>>;
mockSettings = {
@@ -304,6 +306,38 @@ describe('GeminiAgent', () => {
],
currentModeId: 'default',
});
expect(response.models).toEqual({
availableModels: expect.arrayContaining([
expect.objectContaining({
modelId: 'auto-gemini-2.5',
name: 'Auto (Gemini 2.5)',
}),
]),
currentModelId: 'gemini-pro',
});
});
it('should include preview models when user has access', async () => {
mockConfig.getHasAccessToPreviewModel = vi.fn().mockReturnValue(true);
mockConfig.getGemini31LaunchedSync = vi.fn().mockReturnValue(true);
const response = await agent.newSession({
cwd: '/tmp',
mcpServers: [],
});
expect(response.models?.availableModels).toEqual(
expect.arrayContaining([
expect.objectContaining({
modelId: 'auto-gemini-3',
name: expect.stringContaining('Auto'),
}),
expect.objectContaining({
modelId: 'gemini-3.1-pro-preview',
name: 'gemini-3.1-pro-preview',
}),
]),
);
});
it('should return modes with plan mode when plan is enabled', async () => {
@@ -331,6 +365,15 @@ describe('GeminiAgent', () => {
],
currentModeId: 'plan',
});
expect(response.models).toEqual({
availableModels: expect.arrayContaining([
expect.objectContaining({
modelId: 'auto-gemini-2.5',
name: 'Auto (Gemini 2.5)',
}),
]),
currentModelId: 'gemini-pro',
});
});
it('should fail session creation if Gemini API key is missing', async () => {
@@ -480,6 +523,32 @@ describe('GeminiAgent', () => {
}),
).rejects.toThrow('Session not found: unknown');
});
it('should delegate setModel to session (unstable)', async () => {
await agent.newSession({ cwd: '/tmp', mcpServers: [] });
const session = (
agent as unknown as { sessions: Map<string, Session> }
).sessions.get('test-session-id');
if (!session) throw new Error('Session not found');
session.setModel = vi.fn().mockReturnValue({});
const result = await agent.unstable_setSessionModel({
sessionId: 'test-session-id',
modelId: 'gemini-2.0-pro-exp',
});
expect(session.setModel).toHaveBeenCalledWith('gemini-2.0-pro-exp');
expect(result).toEqual({});
});
it('should throw error when setting model on non-existent session (unstable)', async () => {
await expect(
agent.unstable_setSessionModel({
sessionId: 'unknown',
modelId: 'gemini-2.0-pro-exp',
}),
).rejects.toThrow('Session not found: unknown');
});
});
describe('Session', () => {
@@ -528,6 +597,7 @@ describe('Session', () => {
getDebugMode: vi.fn().mockReturnValue(false),
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
setApprovalMode: vi.fn(),
setModel: vi.fn(),
isPlanEnabled: vi.fn().mockReturnValue(false),
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
getGitService: vi.fn().mockResolvedValue({} as GitService),
@@ -1383,6 +1453,12 @@ describe('Session', () => {
'Invalid or unavailable mode: invalid-mode',
);
});
it('should set model on config', () => {
session.setModel('gemini-2.0-flash-exp');
expect(mockConfig.setModel).toHaveBeenCalledWith('gemini-2.0-flash-exp');
});
it('should handle unquoted commands from autocomplete (with empty leading parts)', async () => {
// Mock handleCommand to verify it gets called
const handleCommandSpy = vi