From c2e5b28e94917ebb8bb0d4a55acfb614493b426f Mon Sep 17 00:00:00 2001 From: Sri Pasumarthi <111310667+sripasg@users.noreply.github.com> Date: Wed, 29 Apr 2026 07:51:01 -0700 Subject: [PATCH] refactor(acp): modularize monolithic acpClient into specialized files (#26143) --- packages/cli/src/acp/README.md | 81 + packages/cli/src/acp/acpClient.test.ts | 2294 ----------------- ...dler.test.ts => acpCommandHandler.test.ts} | 2 +- ...commandHandler.ts => acpCommandHandler.ts} | 2 +- packages/cli/src/acp/acpErrors.test.ts | 2 +- packages/cli/src/acp/acpErrors.ts | 2 +- ...e.test.ts => acpFileSystemService.test.ts} | 4 +- ...stemService.ts => acpFileSystemService.ts} | 2 +- packages/cli/src/acp/acpResume.test.ts | 15 +- packages/cli/src/acp/acpRpcDispatcher.test.ts | 338 +++ packages/cli/src/acp/acpRpcDispatcher.ts | 232 ++ packages/cli/src/acp/acpSession.test.ts | 463 ++++ .../src/acp/{acpClient.ts => acpSession.ts} | 899 +------ .../cli/src/acp/acpSessionManager.test.ts | 386 +++ packages/cli/src/acp/acpSessionManager.ts | 322 +++ packages/cli/src/acp/acpStdioTransport.ts | 35 + packages/cli/src/acp/acpUtils.ts | 373 +++ .../cli/src/acp/commands/commandRegistry.ts | 2 +- packages/cli/src/acp/commands/extensions.ts | 2 +- packages/cli/src/acp/commands/init.ts | 2 +- packages/cli/src/acp/commands/memory.ts | 2 +- packages/cli/src/acp/commands/restore.test.ts | 23 +- packages/cli/src/acp/commands/restore.ts | 2 +- packages/cli/src/acp/commands/types.ts | 2 +- packages/cli/src/gemini.tsx | 2 +- packages/cli/src/gemini_cleanup.test.tsx | 2 +- packages/core/src/utils/partUtils.ts | 12 + 27 files changed, 2301 insertions(+), 3202 deletions(-) create mode 100644 packages/cli/src/acp/README.md delete mode 100644 packages/cli/src/acp/acpClient.test.ts rename packages/cli/src/acp/{commandHandler.test.ts => acpCommandHandler.test.ts} (94%) rename packages/cli/src/acp/{commandHandler.ts => acpCommandHandler.ts} (99%) rename packages/cli/src/acp/{fileSystemService.test.ts => acpFileSystemService.test.ts} (98%) rename packages/cli/src/acp/{fileSystemService.ts => acpFileSystemService.ts} (98%) create mode 100644 packages/cli/src/acp/acpRpcDispatcher.test.ts create mode 100644 packages/cli/src/acp/acpRpcDispatcher.ts create mode 100644 packages/cli/src/acp/acpSession.test.ts rename packages/cli/src/acp/{acpClient.ts => acpSession.ts} (62%) create mode 100644 packages/cli/src/acp/acpSessionManager.test.ts create mode 100644 packages/cli/src/acp/acpSessionManager.ts create mode 100644 packages/cli/src/acp/acpStdioTransport.ts create mode 100644 packages/cli/src/acp/acpUtils.ts diff --git a/packages/cli/src/acp/README.md b/packages/cli/src/acp/README.md new file mode 100644 index 0000000000..e8b4b4be70 --- /dev/null +++ b/packages/cli/src/acp/README.md @@ -0,0 +1,81 @@ +# Agent Client Protocol (ACP) Implementation + +This directory contains the implementation of the Agent Client Protocol (ACP) +for the Gemini CLI. The ACP allows external clients (like IDE extensions) to +communicate with the Gemini CLI agent over a structured JSON-RPC based protocol. + +## Directory Structure + +Following Phase 1 of the modularization refactor, the ACP client is organized +into the following specialized modules, all sharing the `acp` prefix for +consistency: + +- **[acpStdioTransport.ts](./acpStdioTransport.ts)**: Handles raw I/O. It sets + up the Web streams for standard input/output and creates the + `AgentSideConnection` using line-delimited JSON (ndjson). +- **[acpRpcDispatcher.ts](./acpRpcDispatcher.ts)**: Contains the `GeminiAgent` + class. This is the main entry point for incoming JSON-RPC messages. It + implements the protocol methods and delegates session-specific work to the + manager and individual sessions. +- **[acpSessionManager.ts](./acpSessionManager.ts)**: Manages multi-session + state. It handles session creation (`newSession`), loading (`loadSession`), + and configuration, isolating session state from the RPC routing. +- **[acpSession.ts](./acpSession.ts)**: Manages individual active chat sessions. + It handles prompt execution, `@path` file resolution, tool execution, command + interception, and streaming updates back to the client. +- **[acpUtils.ts](./acpUtils.ts)**: Contains shared helper functions, type + mappers (e.g., mapping internal tool kinds to ACP kinds), and Zod schemas used + across the modules. +- **[acpErrors.ts](./acpErrors.ts)**: Centralized error handling and mapping to + ACP-compliant error codes. +- **[acpCommandHandler.ts](./acpCommandHandler.ts)**: Handles interception and + execution of slash commands (e.g., `/memory`, `/init`) sent via ACP prompts. +- **[acpFileSystemService.ts](./acpFileSystemService.ts)**: Provides access to + the file system restricted by the workspace boundaries and permissions. + +## Development Instructions + +### Running Tests + +Tests are co-located with the source files: + +- `acpRpcDispatcher.test.ts`: Tests for initialization, authentication, and + handler delegation. +- `acpSessionManager.test.ts`: Tests for session lifecycle and configuration. +- `acpSession.test.ts`: Tests for prompt loops, tool execution, and @path + resolution. +- `acpResume.test.ts`: Integration tests for loading/resuming sessions. + +To run specific tests, use Vitest with the workspace filter: + +```bash +# General pattern +npm test -w @google/gemini-cli -- src/acp/.ts + +# Example +npm test -w @google/gemini-cli -- src/acp/acpRpcDispatcher.test.ts +``` + +Note: You may need to ensure your environment has Node available. If running in +a restricted environment, try sourcing NVM first: + +```bash +source ~/.nvm/nvm.sh && nvm use default && npm test -w @google/gemini-cli -- src/acp/acpSession.test.ts +``` + +### Adding New Features + +- **New RPC Method**: Add the method to `GeminiAgent` in `acpRpcDispatcher.ts` + and register it in the `AgentSideConnection` setup if necessary. +- **Session State**: If a feature requires storing state across turns within a + session, add it to the `Session` class in `acpSession.ts`. +- **Protocol Helpers**: Add any new mapping or serialization logic to + `acpUtils.ts`. + +### Coding Conventions + +- **Imports**: Use specific imports and do not import across package boundaries + using relative paths. +- **License Headers**: All new files must include the Apache-2.0 license header. +- **Type Safety**: Avoid using `any` assertions. Use Zod schemas to validate + untrusted input from the protocol. diff --git a/packages/cli/src/acp/acpClient.test.ts b/packages/cli/src/acp/acpClient.test.ts deleted file mode 100644 index 10c90824f9..0000000000 --- a/packages/cli/src/acp/acpClient.test.ts +++ /dev/null @@ -1,2294 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - describe, - it, - expect, - vi, - beforeEach, - afterEach, - type Mock, - type Mocked, -} from 'vitest'; -import { GeminiAgent, Session } from './acpClient.js'; -import type { CommandHandler } from './commandHandler.js'; -import * as acp from '@agentclientprotocol/sdk'; -import { - AuthType, - ToolConfirmationOutcome, - StreamEventType, - ReadManyFilesTool, - type GeminiChat, - type Config, - type MessageBus, - LlmRole, - type GitService, - type ModelRouterService, - processSingleFileContent, - InvalidStreamError, -} from '@google/gemini-cli-core'; -import { - SettingScope, - type LoadedSettings, - loadSettings, -} from '../config/settings.js'; -import { loadCliConfig, type CliArgs } from '../config/config.js'; -import * as fs from 'node:fs/promises'; -import * as path from 'node:path'; -import { ApprovalMode } from '@google/gemini-cli-core/src/policy/types.js'; - -const startMemoryServiceMock = vi.hoisted(() => vi.fn()); - -vi.mock('../config/config.js', () => ({ - loadCliConfig: vi.fn(), -})); - -vi.mock('../config/settings.js', async (importOriginal) => { - const actual = await importOriginal(); - return { - ...actual, - loadSettings: vi.fn(), - }; -}); - -vi.mock('node:crypto', () => ({ - randomUUID: () => 'test-session-id', -})); - -vi.mock('node:fs/promises'); -vi.mock('node:path', async (importOriginal) => { - const actual = await importOriginal(); - return { - ...actual, - resolve: vi.fn(), - }; -}); - -vi.mock('../ui/commands/memoryCommand.js', () => ({ - memoryCommand: { - name: 'memory', - action: vi.fn(), - }, -})); - -vi.mock('../ui/commands/extensionsCommand.js', () => ({ - extensionsCommand: vi.fn().mockReturnValue({ - name: 'extensions', - action: vi.fn(), - }), -})); - -vi.mock('../ui/commands/restoreCommand.js', () => ({ - restoreCommand: vi.fn().mockReturnValue({ - name: 'restore', - action: vi.fn(), - }), -})); - -vi.mock('../ui/commands/initCommand.js', () => ({ - initCommand: { - name: 'init', - action: vi.fn(), - }, -})); -vi.mock( - '@google/gemini-cli-core', - async ( - importOriginal: () => Promise, - ) => { - const actual = await importOriginal(); - return { - ...actual, - startMemoryService: startMemoryServiceMock, - updatePolicy: vi.fn(), - createPolicyUpdater: vi.fn(), - ReadManyFilesTool: vi.fn(), - logToolCall: vi.fn(), - LlmRole: { - MAIN: 'main', - SUBAGENT: 'subagent', - UTILITY_TOOL: 'utility_tool', - UTILITY_COMPRESSOR: 'utility_compressor', - UTILITY_SUMMARIZER: 'utility_summarizer', - UTILITY_ROUTER: 'utility_router', - UTILITY_LOOP_DETECTOR: 'utility_loop_detector', - UTILITY_NEXT_SPEAKER: 'utility_next_speaker', - UTILITY_EDIT_CORRECTOR: 'utility_edit_corrector', - UTILITY_AUTOCOMPLETE: 'utility_autocomplete', - UTILITY_FAST_ACK_HELPER: 'utility_fast_ack_helper', - }, - CoreToolCallStatus: { - Validating: 'validating', - Scheduled: 'scheduled', - Error: 'error', - Success: 'success', - Executing: 'executing', - Cancelled: 'cancelled', - AwaitingApproval: 'awaiting_approval', - }, - processSingleFileContent: vi.fn(), - }; - }, -); - -// Helper to create mock streams -// eslint-disable-next-line @typescript-eslint/no-explicit-any -async function* createMockStream(items: any[]) { - for (const item of items) { - yield item; - } -} - -describe('GeminiAgent', () => { - let mockConfig: Mocked>>; - let mockSettings: Mocked; - let mockArgv: CliArgs; - let mockConnection: Mocked; - let agent: GeminiAgent; - - beforeEach(() => { - vi.clearAllMocks(); - startMemoryServiceMock.mockResolvedValue(undefined); - mockConfig = { - refreshAuth: vi.fn(), - initialize: vi.fn(), - waitForMcpInit: vi.fn(), - getFileSystemService: vi.fn(), - setFileSystemService: vi.fn(), - getContentGeneratorConfig: vi.fn(), - isAutoMemoryEnabled: vi.fn().mockReturnValue(false), - getActiveModel: vi.fn().mockReturnValue('gemini-pro'), - getModel: vi.fn().mockReturnValue('gemini-pro'), - getGeminiClient: vi.fn().mockReturnValue({ - startChat: vi.fn().mockResolvedValue({}), - }), - getMessageBus: vi.fn().mockReturnValue({ - publish: vi.fn(), - subscribe: vi.fn(), - unsubscribe: vi.fn(), - }), - getApprovalMode: vi.fn().mockReturnValue('default'), - isPlanEnabled: vi.fn().mockReturnValue(true), - getGemini31LaunchedSync: vi.fn().mockReturnValue(false), - getHasAccessToPreviewModel: vi.fn().mockReturnValue(false), - getCheckpointingEnabled: vi.fn().mockReturnValue(false), - getDisableAlwaysAllow: vi.fn().mockReturnValue(false), - validatePathAccess: vi.fn().mockReturnValue(null), - getWorkspaceContext: vi.fn().mockReturnValue({ - addReadOnlyPath: vi.fn(), - }), - getPolicyEngine: vi.fn().mockReturnValue({ - addRule: vi.fn(), - }), - messageBus: { - publish: vi.fn(), - subscribe: vi.fn(), - unsubscribe: vi.fn(), - }, - storage: { - getWorkspaceAutoSavedPolicyPath: vi.fn(), - getAutoSavedPolicyPath: vi.fn(), - setClientName: vi.fn(), - }, - setClientName: vi.fn(), - get config() { - return this; - }, - } as unknown as Mocked>>; - mockSettings = { - merged: { - security: { auth: { selectedType: 'login_with_google' } }, - mcpServers: {}, - }, - setValue: vi.fn(), - } as unknown as Mocked; - mockArgv = {} as unknown as CliArgs; - mockConnection = { - sessionUpdate: vi.fn(), - requestPermission: vi.fn(), - } as unknown as Mocked; - - (loadCliConfig as unknown as Mock).mockResolvedValue(mockConfig); - (loadSettings as unknown as Mock).mockImplementation(() => ({ - merged: { - security: { - auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE }, - enablePermanentToolApproval: true, - }, - mcpServers: {}, - }, - setValue: vi.fn(), - })); - - agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockConnection); - }); - - it('should initialize correctly', async () => { - const response = await agent.initialize({ - clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } }, - protocolVersion: 1, - }); - - expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION); - expect(response.authMethods).toHaveLength(4); - const gatewayAuth = response.authMethods?.find( - (m) => m.id === AuthType.GATEWAY, - ); - expect(gatewayAuth?._meta).toEqual({ - gateway: { - protocol: 'google', - restartRequired: 'false', - }, - }); - const geminiAuth = response.authMethods?.find( - (m) => m.id === AuthType.USE_GEMINI, - ); - expect(geminiAuth?._meta).toEqual({ - 'api-key': { - provider: 'google', - }, - }); - expect(response.agentCapabilities?.loadSession).toBe(true); - }); - - it('should authenticate correctly', async () => { - await agent.authenticate({ - methodId: AuthType.LOGIN_WITH_GOOGLE, - }); - - expect(mockConfig.refreshAuth).toHaveBeenCalledWith( - AuthType.LOGIN_WITH_GOOGLE, - undefined, - undefined, - undefined, - ); - expect(mockSettings.setValue).toHaveBeenCalledWith( - SettingScope.User, - 'security.auth.selectedType', - AuthType.LOGIN_WITH_GOOGLE, - ); - }); - - it('should authenticate correctly with api-key in _meta', async () => { - await agent.authenticate({ - methodId: AuthType.USE_GEMINI, - _meta: { - 'api-key': 'test-api-key', - }, - } as unknown as acp.AuthenticateRequest); - - expect(mockConfig.refreshAuth).toHaveBeenCalledWith( - AuthType.USE_GEMINI, - 'test-api-key', - undefined, - undefined, - ); - expect(mockSettings.setValue).toHaveBeenCalledWith( - SettingScope.User, - 'security.auth.selectedType', - AuthType.USE_GEMINI, - ); - }); - - it('should authenticate correctly with gateway method', async () => { - await agent.authenticate({ - methodId: AuthType.GATEWAY, - _meta: { - gateway: { - baseUrl: 'https://example.com', - headers: { Authorization: 'Bearer token' }, - }, - }, - } as unknown as acp.AuthenticateRequest); - - expect(mockConfig.refreshAuth).toHaveBeenCalledWith( - AuthType.GATEWAY, - undefined, - 'https://example.com', - { Authorization: 'Bearer token' }, - ); - expect(mockSettings.setValue).toHaveBeenCalledWith( - SettingScope.User, - 'security.auth.selectedType', - AuthType.GATEWAY, - ); - }); - - it('should throw acp.RequestError when gateway payload is malformed', async () => { - await expect( - agent.authenticate({ - methodId: AuthType.GATEWAY, - _meta: { - gateway: { - // Invalid baseUrl - baseUrl: 123, - headers: { Authorization: 'Bearer token' }, - }, - }, - } as unknown as acp.AuthenticateRequest), - ).rejects.toThrow(/Malformed gateway payload/); - }); - - it('should create a new session', async () => { - vi.useFakeTimers(); - mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ - apiKey: 'test-key', - }); - const response = await agent.newSession({ - cwd: '/tmp', - mcpServers: [], - }); - - expect(response.sessionId).toBe('test-session-id'); - expect(loadCliConfig).toHaveBeenCalled(); - expect(mockConfig.initialize).toHaveBeenCalled(); - expect(mockConfig.getGeminiClient).toHaveBeenCalled(); - - // Verify deferred call - await vi.runAllTimersAsync(); - expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( - expect.objectContaining({ - update: expect.objectContaining({ - sessionUpdate: 'available_commands_update', - }), - }), - ); - vi.useRealTimers(); - }); - - it('should start auto memory for new ACP sessions when enabled', async () => { - mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ - apiKey: 'test-key', - }); - mockConfig.isAutoMemoryEnabled = vi.fn().mockReturnValue(true); - - await agent.newSession({ - cwd: '/tmp', - mcpServers: [], - }); - - expect(startMemoryServiceMock).toHaveBeenCalledWith(mockConfig); - }); - - it('should not start auto memory for new ACP sessions when disabled', async () => { - mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ - apiKey: 'test-key', - }); - mockConfig.isAutoMemoryEnabled = vi.fn().mockReturnValue(false); - - await agent.newSession({ - cwd: '/tmp', - mcpServers: [], - }); - - expect(startMemoryServiceMock).not.toHaveBeenCalled(); - }); - - it('should return modes without plan mode when plan is disabled', async () => { - mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ - apiKey: 'test-key', - }); - mockConfig.isPlanEnabled = vi.fn().mockReturnValue(false); - mockConfig.getApprovalMode = vi.fn().mockReturnValue('default'); - - const response = await agent.newSession({ - cwd: '/tmp', - mcpServers: [], - }); - - expect(response.modes).toEqual({ - availableModes: [ - { id: 'default', name: 'Default', description: 'Prompts for approval' }, - { - id: 'autoEdit', - name: 'Auto Edit', - description: 'Auto-approves edit tools', - }, - { id: 'yolo', name: 'YOLO', description: 'Auto-approves all tools' }, - ], - currentModeId: 'default', - }); - expect(response.models).toEqual({ - availableModels: expect.arrayContaining([ - expect.objectContaining({ - modelId: 'auto-gemini-2.5', - name: 'Auto (Gemini 2.5)', - }), - ]), - currentModelId: 'gemini-pro', - }); - }); - - it('should include preview models when user has access', async () => { - mockConfig.getHasAccessToPreviewModel = vi.fn().mockReturnValue(true); - mockConfig.getGemini31LaunchedSync = vi.fn().mockReturnValue(true); - - const response = await agent.newSession({ - cwd: '/tmp', - mcpServers: [], - }); - - expect(response.models?.availableModels).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - modelId: 'auto-gemini-3', - name: expect.stringContaining('Auto'), - }), - expect.objectContaining({ - modelId: 'gemini-3.1-pro-preview', - name: 'gemini-3.1-pro-preview', - }), - ]), - ); - }); - - it('should include gemini-3.1-flash-lite when useGemini31FlashLite is true', async () => { - mockConfig.getHasAccessToPreviewModel = vi.fn().mockReturnValue(true); - mockConfig.getGemini31LaunchedSync = vi.fn().mockReturnValue(true); - mockConfig.getGemini31FlashLiteLaunchedSync = vi.fn().mockReturnValue(true); - - const response = await agent.newSession({ - cwd: '/tmp', - mcpServers: [], - }); - - expect(response.models?.availableModels).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - modelId: 'gemini-3.1-flash-lite-preview', - name: 'gemini-3.1-flash-lite-preview', - }), - ]), - ); - }); - - it('should return modes with plan mode when plan is enabled', async () => { - mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ - apiKey: 'test-key', - }); - mockConfig.isPlanEnabled = vi.fn().mockReturnValue(true); - mockConfig.getApprovalMode = vi.fn().mockReturnValue('plan'); - - const response = await agent.newSession({ - cwd: '/tmp', - mcpServers: [], - }); - - expect(response.modes).toEqual({ - availableModes: [ - { id: 'default', name: 'Default', description: 'Prompts for approval' }, - { - id: 'autoEdit', - name: 'Auto Edit', - description: 'Auto-approves edit tools', - }, - { id: 'yolo', name: 'YOLO', description: 'Auto-approves all tools' }, - { id: 'plan', name: 'Plan', description: 'Read-only mode' }, - ], - currentModeId: 'plan', - }); - expect(response.models).toEqual({ - availableModels: expect.arrayContaining([ - expect.objectContaining({ - modelId: 'auto-gemini-2.5', - name: 'Auto (Gemini 2.5)', - }), - ]), - currentModelId: 'gemini-pro', - }); - }); - - it('should fail session creation if Gemini API key is missing', async () => { - (loadSettings as unknown as Mock).mockImplementation(() => ({ - merged: { - security: { auth: { selectedType: AuthType.USE_GEMINI } }, - mcpServers: {}, - }, - setValue: vi.fn(), - })); - mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ - apiKey: undefined, - }); - - await expect( - agent.newSession({ - cwd: '/tmp', - mcpServers: [], - }), - ).rejects.toMatchObject({ - message: 'Gemini API key is missing or not configured.', - }); - }); - - it('should create a new session with mcp servers', async () => { - const mcpServers = [ - { - name: 'test-server', - command: 'node', - args: ['server.js'], - env: [{ name: 'KEY', value: 'VALUE' }], - }, - ]; - - await agent.newSession({ - cwd: '/tmp', - mcpServers, - }); - - expect(loadCliConfig).toHaveBeenCalledWith( - expect.objectContaining({ - mcpServers: expect.objectContaining({ - 'test-server': expect.objectContaining({ - command: 'node', - args: ['server.js'], - env: { KEY: 'VALUE' }, - }), - }), - }), - 'test-session-id', - mockArgv, - { cwd: '/tmp' }, - ); - }); - - it('should handle authentication failure gracefully', async () => { - mockConfig.refreshAuth.mockRejectedValue(new Error('Auth failed')); - const debugSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); - - // Should throw RequestError with custom message - await expect( - agent.newSession({ - cwd: '/tmp', - mcpServers: [], - }), - ).rejects.toMatchObject({ - message: 'Auth failed', - }); - - debugSpy.mockRestore(); - }); - - it('should initialize file system service if client supports it', async () => { - agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockConnection); - await agent.initialize({ - clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } }, - protocolVersion: 1, - }); - - await agent.newSession({ - cwd: '/tmp', - mcpServers: [], - }); - - expect(mockConfig.setFileSystemService).toHaveBeenCalled(); - }); - - it('should cancel a session', async () => { - await agent.newSession({ cwd: '/tmp', mcpServers: [] }); - // Mock the session's cancelPendingPrompt - const session = ( - agent as unknown as { sessions: Map } - ).sessions.get('test-session-id'); - if (!session) throw new Error('Session not found'); - session.cancelPendingPrompt = vi.fn(); - - await agent.cancel({ sessionId: 'test-session-id' }); - - expect(session.cancelPendingPrompt).toHaveBeenCalled(); - }); - - it('should throw error when cancelling non-existent session', async () => { - await expect(agent.cancel({ sessionId: 'unknown' })).rejects.toThrow( - 'Session not found', - ); - }); - - it('should delegate prompt to session', async () => { - await agent.newSession({ cwd: '/tmp', mcpServers: [] }); - const session = ( - agent as unknown as { sessions: Map } - ).sessions.get('test-session-id'); - if (!session) throw new Error('Session not found'); - session.prompt = vi.fn().mockResolvedValue({ stopReason: 'end_turn' }); - - const result = await agent.prompt({ - sessionId: 'test-session-id', - prompt: [], - }); - - expect(session.prompt).toHaveBeenCalled(); - expect(result).toMatchObject({ stopReason: 'end_turn' }); - }); - - it('should delegate setMode to session', async () => { - await agent.newSession({ cwd: '/tmp', mcpServers: [] }); - const session = ( - agent as unknown as { sessions: Map } - ).sessions.get('test-session-id'); - if (!session) throw new Error('Session not found'); - session.setMode = vi.fn().mockReturnValue({}); - - const result = await agent.setSessionMode({ - sessionId: 'test-session-id', - modeId: 'plan', - }); - - expect(session.setMode).toHaveBeenCalledWith('plan'); - expect(result).toEqual({}); - }); - - it('should throw error when setting mode on non-existent session', async () => { - await expect( - agent.setSessionMode({ - sessionId: 'unknown', - modeId: 'plan', - }), - ).rejects.toThrow('Session not found: unknown'); - }); - - it('should delegate setModel to session (unstable)', async () => { - await agent.newSession({ cwd: '/tmp', mcpServers: [] }); - const session = ( - agent as unknown as { sessions: Map } - ).sessions.get('test-session-id'); - if (!session) throw new Error('Session not found'); - session.setModel = vi.fn().mockReturnValue({}); - - const result = await agent.unstable_setSessionModel({ - sessionId: 'test-session-id', - modelId: 'gemini-2.0-pro-exp', - }); - - expect(session.setModel).toHaveBeenCalledWith('gemini-2.0-pro-exp'); - expect(result).toEqual({}); - }); - - it('should throw error when setting model on non-existent session (unstable)', async () => { - await expect( - agent.unstable_setSessionModel({ - sessionId: 'unknown', - modelId: 'gemini-2.0-pro-exp', - }), - ).rejects.toThrow('Session not found: unknown'); - }); -}); - -describe('Session', () => { - let mockChat: Mocked; - let mockConfig: Mocked; - let mockConnection: Mocked; - let session: Session; - let mockToolRegistry: { getTool: Mock }; - let mockTool: { kind: string; build: Mock }; - let mockMessageBus: Mocked; - - beforeEach(() => { - mockChat = { - sendMessageStream: vi.fn(), - addHistory: vi.fn(), - recordCompletedToolCalls: vi.fn(), - getHistory: vi.fn().mockReturnValue([]), - } as unknown as Mocked; - mockTool = { - kind: 'read', - build: vi.fn().mockReturnValue({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(null), - execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), - }), - }; - mockToolRegistry = { - getTool: vi.fn().mockReturnValue(mockTool), - }; - mockMessageBus = { - publish: vi.fn(), - subscribe: vi.fn(), - unsubscribe: vi.fn(), - } as unknown as Mocked; - mockConfig = { - getModel: vi.fn().mockReturnValue('gemini-pro'), - getActiveModel: vi.fn().mockReturnValue('gemini-pro'), - getModelRouterService: vi.fn().mockReturnValue({ - route: vi.fn().mockResolvedValue({ model: 'resolved-model' }), - }), - getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), - getMcpServers: vi.fn(), - getFileService: vi.fn().mockReturnValue({ - shouldIgnoreFile: vi.fn().mockReturnValue(false), - }), - getFileFilteringOptions: vi.fn().mockReturnValue({}), - getFileSystemService: vi.fn().mockReturnValue({}), - getTargetDir: vi.fn().mockReturnValue('/tmp'), - getEnableRecursiveFileSearch: vi.fn().mockReturnValue(false), - getDebugMode: vi.fn().mockReturnValue(false), - getMessageBus: vi.fn().mockReturnValue(mockMessageBus), - setApprovalMode: vi.fn(), - setModel: vi.fn(), - isPlanEnabled: vi.fn().mockReturnValue(true), - getCheckpointingEnabled: vi.fn().mockReturnValue(false), - getGitService: vi.fn().mockResolvedValue({} as GitService), - validatePathAccess: vi.fn().mockReturnValue(null), - getWorkspaceContext: vi.fn().mockReturnValue({ - addReadOnlyPath: vi.fn(), - }), - waitForMcpInit: vi.fn(), - getDisableAlwaysAllow: vi.fn().mockReturnValue(false), - get config() { - return this; - }, - get toolRegistry() { - return mockToolRegistry; - }, - } as unknown as Mocked; - mockConnection = { - sessionUpdate: vi.fn(), - requestPermission: vi.fn(), - sendNotification: vi.fn(), - } as unknown as Mocked; - - session = new Session('session-1', mockChat, mockConfig, mockConnection, { - system: { settings: {} }, - systemDefaults: { settings: {} }, - user: { settings: {} }, - workspace: { settings: {} }, - merged: { - security: { enablePermanentToolApproval: true }, - mcpServers: {}, - }, - errors: [], - } as unknown as LoadedSettings); - - (ReadManyFilesTool as unknown as Mock).mockImplementation(() => ({ - name: 'read_many_files', - kind: 'read', - build: vi.fn().mockReturnValue({ - getDescription: () => 'Read files', - toolLocations: () => [], - execute: vi.fn().mockResolvedValue({ - llmContent: ['--- file.txt ---\n\nFile content\n\n'], - }), - }), - })); - }); - - afterEach(() => { - vi.restoreAllMocks(); - }); - - it('should send available commands', async () => { - await session.sendAvailableCommands(); - - expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( - expect.objectContaining({ - update: expect.objectContaining({ - sessionUpdate: 'available_commands_update', - availableCommands: expect.arrayContaining([ - expect.objectContaining({ name: 'memory' }), - expect.objectContaining({ name: 'extensions' }), - expect.objectContaining({ name: 'restore' }), - expect.objectContaining({ name: 'init' }), - ]), - }), - }), - ); - }); - - it('should await MCP initialization before processing a prompt', async () => { - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [{ content: { parts: [{ text: 'Hi' }] } }] }, - }, - ]); - mockChat.sendMessageStream.mockResolvedValue(stream); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'test' }], - }); - - expect(mockConfig.waitForMcpInit).toHaveBeenCalledOnce(); - const waitOrder = (mockConfig.waitForMcpInit as Mock).mock - .invocationCallOrder[0]; - const sendOrder = (mockChat.sendMessageStream as Mock).mock - .invocationCallOrder[0]; - expect(waitOrder).toBeLessThan(sendOrder); - }); - - it('should handle prompt with text response', async () => { - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - candidates: [{ content: { parts: [{ text: 'Hello' }] } }], - }, - }, - ]); - mockChat.sendMessageStream.mockResolvedValue(stream); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Hi' }], - }); - - expect(mockChat.sendMessageStream).toHaveBeenCalled(); - expect(mockConnection.sessionUpdate).toHaveBeenCalledWith({ - sessionId: 'session-1', - update: { - sessionUpdate: 'agent_message_chunk', - content: { type: 'text', text: 'Hello' }, - }, - }); - expect(result).toMatchObject({ stopReason: 'end_turn' }); - }); - - it('should use model router to determine model', async () => { - const mockRouter = { - route: vi.fn().mockResolvedValue({ model: 'routed-model' }), - } as unknown as ModelRouterService; - mockConfig.getModelRouterService.mockReturnValue(mockRouter); - - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - candidates: [{ content: { parts: [{ text: 'Hello' }] } }], - }, - }, - ]); - mockChat.sendMessageStream.mockResolvedValue(stream); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Hi' }], - }); - - expect(mockRouter.route).toHaveBeenCalledWith( - expect.objectContaining({ - requestedModel: 'gemini-pro', - request: [{ text: 'Hi' }], - }), - ); - expect(mockChat.sendMessageStream).toHaveBeenCalledWith( - expect.objectContaining({ model: 'routed-model' }), - expect.any(Array), - expect.any(String), - expect.any(Object), - expect.any(String), - ); - }); - - it('should handle prompt with empty response (InvalidStreamError)', async () => { - mockChat.sendMessageStream.mockRejectedValue( - new InvalidStreamError('Empty response', 'NO_RESPONSE_TEXT'), - ); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Hi' }], - }); - - expect(mockChat.sendMessageStream).toHaveBeenCalled(); - expect(result).toMatchObject({ stopReason: 'end_turn' }); - }); - - it('should handle prompt with empty response (NO_RESPONSE_TEXT anomaly)', async () => { - mockChat.sendMessageStream.mockRejectedValue({ type: 'NO_RESPONSE_TEXT' }); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Hi' }], - }); - - expect(mockChat.sendMessageStream).toHaveBeenCalled(); - expect(result).toMatchObject({ stopReason: 'end_turn' }); - }); - - it('should handle prompt with no finish reason (InvalidStreamError)', async () => { - mockChat.sendMessageStream.mockRejectedValue( - new InvalidStreamError('No finish reason', 'NO_FINISH_REASON'), - ); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Hi' }], - }); - - expect(mockChat.sendMessageStream).toHaveBeenCalled(); - expect(result).toMatchObject({ stopReason: 'end_turn' }); - }); - - it('should handle prompt with no finish reason (NO_FINISH_REASON anomaly)', async () => { - mockChat.sendMessageStream.mockRejectedValue({ type: 'NO_FINISH_REASON' }); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Hi' }], - }); - - expect(mockChat.sendMessageStream).toHaveBeenCalled(); - expect(result).toMatchObject({ stopReason: 'end_turn' }); - }); - - it('should handle /memory command', async () => { - const handleCommandSpy = vi - .spyOn( - (session as unknown as { commandHandler: CommandHandler }) - .commandHandler, - 'handleCommand', - ) - .mockResolvedValue(true); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: '/memory view' }], - }); - - expect(result).toMatchObject({ stopReason: 'end_turn' }); - expect(handleCommandSpy).toHaveBeenCalledWith( - '/memory view', - expect.any(Object), - ); - expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); - }); - - it('should handle /extensions command', async () => { - const handleCommandSpy = vi - .spyOn( - (session as unknown as { commandHandler: CommandHandler }) - .commandHandler, - 'handleCommand', - ) - .mockResolvedValue(true); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: '/extensions list' }], - }); - - expect(result).toMatchObject({ stopReason: 'end_turn' }); - expect(handleCommandSpy).toHaveBeenCalledWith( - '/extensions list', - expect.any(Object), - ); - expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); - }); - - it('should handle /extensions explore command', async () => { - const handleCommandSpy = vi - .spyOn( - (session as unknown as { commandHandler: CommandHandler }) - .commandHandler, - 'handleCommand', - ) - .mockResolvedValue(true); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: '/extensions explore' }], - }); - - expect(result).toMatchObject({ stopReason: 'end_turn' }); - expect(handleCommandSpy).toHaveBeenCalledWith( - '/extensions explore', - expect.any(Object), - ); - expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); - }); - - it('should handle /restore command', async () => { - const handleCommandSpy = vi - .spyOn( - (session as unknown as { commandHandler: CommandHandler }) - .commandHandler, - 'handleCommand', - ) - .mockResolvedValue(true); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: '/restore' }], - }); - - expect(result).toMatchObject({ stopReason: 'end_turn' }); - expect(handleCommandSpy).toHaveBeenCalledWith( - '/restore', - expect.any(Object), - ); - expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); - }); - - it('should handle /init command', async () => { - const handleCommandSpy = vi - .spyOn( - (session as unknown as { commandHandler: CommandHandler }) - .commandHandler, - 'handleCommand', - ) - .mockResolvedValue(true); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: '/init' }], - }); - - expect(result).toMatchObject({ stopReason: 'end_turn' }); - expect(handleCommandSpy).toHaveBeenCalledWith('/init', expect.any(Object)); - expect(mockChat.sendMessageStream).not.toHaveBeenCalled(); - }); - - it('should handle tool calls', async () => { - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: { foo: 'bar' } }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - candidates: [{ content: { parts: [{ text: 'Result' }] } }], - }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - const result = await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - expect(mockToolRegistry.getTool).toHaveBeenCalledWith('test_tool'); - expect(mockTool.build).toHaveBeenCalledWith({ foo: 'bar' }); - expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( - expect.objectContaining({ - update: expect.objectContaining({ - sessionUpdate: 'tool_call', - status: 'in_progress', - kind: 'read', - }), - }), - ); - expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( - expect.objectContaining({ - update: expect.objectContaining({ - sessionUpdate: 'tool_call_update', - status: 'completed', - title: 'Test Tool', - locations: [], - kind: 'read', - }), - }), - ); - expect(result).toMatchObject({ stopReason: 'end_turn' }); - }); - - it('should handle tool call permission request', async () => { - const confirmationDetails = { - type: 'info', - onConfirm: vi.fn(), - }; - mockTool.build.mockReturnValue({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), - execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), - }); - - mockConnection.requestPermission.mockResolvedValue({ - outcome: { - outcome: 'selected', - optionId: ToolConfirmationOutcome.ProceedOnce, - }, - }); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - expect(mockConnection.requestPermission).toHaveBeenCalled(); - expect(confirmationDetails.onConfirm).toHaveBeenCalledWith( - ToolConfirmationOutcome.ProceedOnce, - ); - }); - - it('should exclude always allow options when disableAlwaysAllow is true', async () => { - mockConfig.getDisableAlwaysAllow = vi.fn().mockReturnValue(true); - const confirmationDetails = { - type: 'info', - onConfirm: vi.fn(), - }; - mockTool.build.mockReturnValue({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), - execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), - }); - - mockConnection.requestPermission.mockResolvedValue({ - outcome: { - outcome: 'selected', - optionId: ToolConfirmationOutcome.ProceedOnce, - }, - }); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - expect(mockConnection.requestPermission).toHaveBeenCalledWith( - expect.objectContaining({ - options: expect.not.arrayContaining([ - expect.objectContaining({ - optionId: ToolConfirmationOutcome.ProceedAlways, - }), - ]), - }), - ); - }); - - it('should exclude always allow and save permanent option when enablePermanentToolApproval is false', async () => { - mockConfig.getDisableAlwaysAllow = vi.fn().mockReturnValue(false); - const confirmationDetails = { - type: 'edit', - onConfirm: vi.fn(), - }; - mockTool.build.mockReturnValue({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), - execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), - }); - - const customSettings = { - system: { settings: {} }, - systemDefaults: { settings: {} }, - user: { settings: {} }, - workspace: { settings: {} }, - merged: { - security: { enablePermanentToolApproval: false }, - mcpServers: {}, - }, - errors: [], - } as unknown as LoadedSettings; - - const localSession = new Session( - 'session-2', - mockChat, - mockConfig, - mockConnection, - customSettings, - ); - - mockConnection.requestPermission.mockResolvedValueOnce({ - outcome: { - outcome: 'selected', - optionId: ToolConfirmationOutcome.ProceedOnce, - }, - }); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - await localSession.prompt({ - sessionId: 'session-2', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - expect(mockConnection.requestPermission).toHaveBeenCalledWith( - expect.objectContaining({ - options: expect.not.arrayContaining([ - expect.objectContaining({ - optionId: ToolConfirmationOutcome.ProceedAlwaysAndSave, - }), - ]), - }), - ); - expect(mockConnection.requestPermission).toHaveBeenCalledWith( - expect.objectContaining({ - options: expect.arrayContaining([ - expect.objectContaining({ - optionId: ToolConfirmationOutcome.ProceedAlways, - }), - ]), - }), - ); - }); - - it('should include always allow and save permanent option when enablePermanentToolApproval is true', async () => { - mockConfig.getDisableAlwaysAllow = vi.fn().mockReturnValue(false); - const confirmationDetails = { - type: 'edit', - onConfirm: vi.fn(), - }; - mockTool.build.mockReturnValue({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), - execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), - }); - - const customSettings = { - system: { settings: {} }, - systemDefaults: { settings: {} }, - user: { settings: {} }, - workspace: { settings: {} }, - merged: { - security: { enablePermanentToolApproval: true }, - mcpServers: {}, - }, - errors: [], - } as unknown as LoadedSettings; - - const localSession = new Session( - 'session-2', - mockChat, - mockConfig, - mockConnection, - customSettings, - ); - - mockConnection.requestPermission.mockResolvedValueOnce({ - outcome: { - outcome: 'selected', - optionId: ToolConfirmationOutcome.ProceedOnce, - }, - }); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - await localSession.prompt({ - sessionId: 'session-2', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - expect(mockConnection.requestPermission).toHaveBeenCalledWith( - expect.objectContaining({ - options: expect.arrayContaining([ - expect.objectContaining({ - optionId: ToolConfirmationOutcome.ProceedAlwaysAndSave, - name: 'Allow for this file in all future sessions', - }), - ]), - }), - ); - }); - - it('should use filePath for ACP diff content in permission request', async () => { - const confirmationDetails = { - type: 'edit', - title: 'Confirm Write: test.txt', - fileName: 'test.txt', - filePath: '/tmp/test.txt', - originalContent: 'old', - newContent: 'new', - onConfirm: vi.fn(), - }; - mockTool.build.mockReturnValue({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), - execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), - }); - - mockConnection.requestPermission.mockResolvedValue({ - outcome: { - outcome: 'selected', - optionId: ToolConfirmationOutcome.ProceedOnce, - }, - }); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - expect(mockConnection.requestPermission).toHaveBeenCalledWith( - expect.objectContaining({ - toolCall: expect.objectContaining({ - content: expect.arrayContaining([ - expect.objectContaining({ - type: 'diff', - path: '/tmp/test.txt', - oldText: 'old', - newText: 'new', - }), - ]), - }), - }), - ); - }); - - it('should split getDisplayTitle and getExplanation for title and content in permission request', async () => { - const confirmationDetails = { - type: 'info', - onConfirm: vi.fn(), - }; - mockTool.build.mockReturnValue({ - getDescription: () => 'Original Description', - getDisplayTitle: () => 'Display Title Only', - getExplanation: () => 'A detailed explanation text', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), - execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), - }); - - mockConnection.requestPermission.mockResolvedValue({ - outcome: { - outcome: 'selected', - optionId: ToolConfirmationOutcome.ProceedOnce, - }, - }); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - expect(mockConnection.requestPermission).toHaveBeenCalledWith( - expect.objectContaining({ - toolCall: expect.objectContaining({ - title: 'Display Title Only', - content: [], - }), - }), - ); - - expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( - expect.objectContaining({ - update: expect.objectContaining({ - sessionUpdate: 'agent_thought_chunk', - content: { type: 'text', text: 'A detailed explanation text' }, - }), - }), - ); - }); - - it('should call updatePolicy when tool permission triggers always allow', async () => { - const confirmationDetails = { - type: 'info', - onConfirm: vi.fn(), - }; - mockTool.build.mockReturnValue({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), - execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), - }); - - mockConnection.requestPermission.mockResolvedValue({ - outcome: { - outcome: 'selected', - optionId: ToolConfirmationOutcome.ProceedAlways, - }, - }); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - const { updatePolicy } = await import('@google/gemini-cli-core'); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - expect(confirmationDetails.onConfirm).toHaveBeenCalled(); - - expect(updatePolicy).toHaveBeenCalled(); - }); - - it('should use filePath for ACP diff content in tool result', async () => { - mockTool.build.mockReturnValue({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(null), - execute: vi.fn().mockResolvedValue({ - llmContent: 'Tool Result', - returnDisplay: { - fileName: 'test.txt', - filePath: '/tmp/test.txt', - originalContent: 'old', - newContent: 'new', - }, - }), - }); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - const updateCalls = mockConnection.sessionUpdate.mock.calls.map( - (call) => call[0], - ); - const toolCallUpdate = updateCalls.find( - (call) => call.update?.sessionUpdate === 'tool_call_update', - ); - - expect(toolCallUpdate).toEqual( - expect.objectContaining({ - update: expect.objectContaining({ - content: expect.arrayContaining([ - expect.objectContaining({ - type: 'diff', - path: '/tmp/test.txt', - oldText: 'old', - newText: 'new', - }), - ]), - }), - }), - ); - }); - - it('should handle tool call cancellation by user', async () => { - const confirmationDetails = { - type: 'info', - onConfirm: vi.fn(), - }; - mockTool.build.mockReturnValue({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), - execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), - }); - - mockConnection.requestPermission.mockResolvedValue({ - outcome: { outcome: 'cancelled' }, - }); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - // When cancelled, it sends an error response to the model - // We can verify that the second call to sendMessageStream contains the error - expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); - const secondCallArgs = mockChat.sendMessageStream.mock.calls[1]; - const parts = secondCallArgs[1]; // parts - expect(parts).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - functionResponse: expect.objectContaining({ - response: { - error: expect.stringContaining('canceled by the user'), - }, - }), - }), - ]), - ); - }); - - it('should include _meta.kind in diff tool calls', async () => { - // Test 'add' (no original content) - const addConfirmation = { - type: 'edit', - fileName: 'new.txt', - originalContent: null, - newContent: 'New content', - onConfirm: vi.fn(), - }; - - // Test 'modify' (original and new content) - const modifyConfirmation = { - type: 'edit', - fileName: 'existing.txt', - originalContent: 'Old content', - newContent: 'New content', - onConfirm: vi.fn(), - }; - - // Test 'delete' (original content, no new content) - const deleteConfirmation = { - type: 'edit', - fileName: 'deleted.txt', - originalContent: 'Old content', - newContent: '', - onConfirm: vi.fn(), - }; - - const mockBuild = vi.fn(); - mockTool.build = mockBuild; - - // Helper to simulate tool call and check permission request - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const checkDiffKind = async (confirmation: any, expectedKind: string) => { - mockBuild.mockReturnValueOnce({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(confirmation), - execute: vi.fn().mockResolvedValue({ llmContent: 'Result' }), - }); - - mockConnection.requestPermission.mockResolvedValueOnce({ - outcome: { - outcome: 'selected', - optionId: ToolConfirmationOutcome.ProceedOnce, - }, - }); - - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const emptyStream = createMockStream([]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream) - .mockResolvedValueOnce(emptyStream); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - expect(mockConnection.requestPermission).toHaveBeenCalledWith( - expect.objectContaining({ - toolCall: expect.objectContaining({ - content: expect.arrayContaining([ - expect.objectContaining({ - type: 'diff', - _meta: { kind: expectedKind }, - }), - ]), - }), - }), - ); - }; - - await checkDiffKind(addConfirmation, 'add'); - await checkDiffKind(modifyConfirmation, 'modify'); - await checkDiffKind(deleteConfirmation, 'delete'); - }); - - it('should handle @path resolution', async () => { - (path.resolve as unknown as Mock).mockReturnValue('/tmp/file.txt'); - (fs.stat as unknown as Mock).mockResolvedValue({ - isDirectory: () => false, - }); - - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - mockChat.sendMessageStream.mockResolvedValue(stream); - - await session.prompt({ - sessionId: 'session-1', - prompt: [ - { type: 'text', text: 'Read' }, - { - type: 'resource_link', - uri: 'file://file.txt', - mimeType: 'text/plain', - name: 'file.txt', - }, - ], - }); - - expect(path.resolve).toHaveBeenCalled(); - expect(fs.stat).toHaveBeenCalled(); - - expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( - expect.objectContaining({ - update: expect.objectContaining({ - sessionUpdate: 'tool_call_update', - status: 'completed', - title: 'Read files', - locations: [], - kind: 'read', - }), - }), - ); - - // Verify ReadManyFilesTool was used (implicitly by checking if sendMessageStream was called with resolved content) - // Since we mocked ReadManyFilesTool to return specific content, we can check the args passed to sendMessageStream - expect(mockChat.sendMessageStream).toHaveBeenCalledWith( - expect.anything(), - expect.arrayContaining([ - expect.objectContaining({ - text: expect.stringContaining('Content from @file.txt'), - }), - ]), - expect.anything(), - expect.any(AbortSignal), - LlmRole.MAIN, - ); - }); - - it('should handle @path resolution error', async () => { - (path.resolve as unknown as Mock).mockReturnValue('/tmp/error.txt'); - (fs.stat as unknown as Mock).mockResolvedValue({ - isDirectory: () => false, - }); - - const MockReadManyFilesTool = ReadManyFilesTool as unknown as Mock; - MockReadManyFilesTool.mockImplementationOnce(() => ({ - name: 'read_many_files', - kind: 'read', - build: vi.fn().mockReturnValue({ - getDescription: () => 'Read files', - toolLocations: () => [], - execute: vi.fn().mockRejectedValue(new Error('File read failed')), - }), - })); - - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - mockChat.sendMessageStream.mockResolvedValue(stream); - - await expect( - session.prompt({ - sessionId: 'session-1', - prompt: [ - { type: 'text', text: 'Read' }, - { - type: 'resource_link', - uri: 'file://error.txt', - mimeType: 'text/plain', - name: 'error.txt', - }, - ], - }), - ).rejects.toThrow('File read failed'); - - expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( - expect.objectContaining({ - update: expect.objectContaining({ - sessionUpdate: 'tool_call_update', - status: 'failed', - content: expect.arrayContaining([ - expect.objectContaining({ - content: expect.objectContaining({ - text: expect.stringMatching(/File read failed/), - }), - }), - ]), - kind: 'read', - }), - }), - ); - }); - - it('should handle @path validation error and bubble it to user', async () => { - mockConfig.getTargetDir.mockReturnValue('/workspace'); - (path.resolve as unknown as Mock).mockReturnValue('/tmp/disallowed.txt'); - mockConfig.validatePathAccess.mockReturnValue('Path is outside workspace'); - - // Force fs.stat to fail to skip direct reading and triggers the warning - (fs.stat as unknown as Mock).mockRejectedValue(new Error('File not found')); - - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - mockChat.sendMessageStream.mockResolvedValue(stream); - - await session.prompt({ - sessionId: 'session-1', - prompt: [ - { - type: 'resource_link', - uri: 'file://disallowed.txt', - mimeType: 'text/plain', - name: 'disallowed.txt', - }, - ], - }); - - // Verify warning sent via sendUpdate - expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( - expect.objectContaining({ - update: expect.objectContaining({ - sessionUpdate: 'agent_thought_chunk', - content: expect.objectContaining({ - text: expect.stringContaining( - 'Warning: skipping access to `disallowed.txt`. Reason: Path is outside workspace', - ), - }), - }), - }), - ); - }); - - it('should read absolute file directly if outside workspace', async () => { - mockConfig.getTargetDir.mockReturnValue('/workspace'); - const testFilePath = '/tmp/custom.txt'; - (path.resolve as unknown as Mock).mockReturnValue(testFilePath); - mockConfig.validatePathAccess.mockReturnValue('Path is outside workspace'); - - mockConnection.requestPermission.mockResolvedValue({ - outcome: { - outcome: 'selected', - optionId: ToolConfirmationOutcome.ProceedOnce, - }, - } as unknown as acp.RequestPermissionResponse); - - const mockStats = { - isFile: () => true, - isDirectory: () => false, - }; - (fs.stat as unknown as Mock).mockResolvedValue(mockStats); - (processSingleFileContent as unknown as Mock).mockResolvedValue({ - llmContent: 'Absolute File Content', - }); - - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - mockChat.sendMessageStream.mockResolvedValue(stream); - - await session.prompt({ - sessionId: 'session-1', - prompt: [ - { - type: 'resource_link', - uri: `file://${testFilePath}`, - mimeType: 'text/plain', - name: 'custom.txt', - }, - ], - }); - - expect(processSingleFileContent).toHaveBeenCalledWith( - testFilePath, - expect.anything(), - expect.anything(), - ); - - // Verify content appended to sendMessageStream parts - expect(mockChat.sendMessageStream).toHaveBeenCalledWith( - expect.anything(), - expect.arrayContaining([ - expect.objectContaining({ - text: 'Absolute File Content', - }), - ]), - expect.anything(), - expect.any(AbortSignal), - expect.anything(), - ); - }); - - it('should read escaping relative file directly if outside workspace', async () => { - mockConfig.getTargetDir.mockReturnValue('/workspace'); - const testFilePath = '../../custom.txt'; - (path.resolve as unknown as Mock).mockReturnValue('/custom.txt'); - mockConfig.validatePathAccess.mockReturnValue('Path is outside workspace'); - - mockConnection.requestPermission.mockResolvedValue({ - outcome: { - outcome: 'selected', - optionId: ToolConfirmationOutcome.ProceedOnce, - }, - } as unknown as acp.RequestPermissionResponse); - - const mockStats = { - isFile: () => true, - isDirectory: () => false, - }; - (fs.stat as unknown as Mock).mockResolvedValue(mockStats); - (processSingleFileContent as unknown as Mock).mockResolvedValue({ - llmContent: 'Escaping Relative File Content', - }); - - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - mockChat.sendMessageStream.mockResolvedValue(stream); - - await session.prompt({ - sessionId: 'session-1', - prompt: [ - { - type: 'resource_link', - uri: `file://${testFilePath}`, - mimeType: 'text/plain', - name: 'custom.txt', - }, - ], - }); - - expect(processSingleFileContent).toHaveBeenCalledWith( - '/custom.txt', - expect.any(String), - expect.anything(), - ); - - expect(mockChat.sendMessageStream).toHaveBeenCalledWith( - expect.anything(), - expect.arrayContaining([ - expect.objectContaining({ - text: 'Escaping Relative File Content', - }), - ]), - expect.anything(), - expect.any(AbortSignal), - expect.anything(), - ); - }); - - it('should handle cancellation during prompt', async () => { - let streamController: ReadableStreamDefaultController; - const stream = new ReadableStream({ - start(controller) { - streamController = controller; - }, - }); - - let streamStarted: (value: unknown) => void; - const streamStartedPromise = new Promise((resolve) => { - streamStarted = resolve; - }); - - // Adapt web stream to async iterable - async function* asyncStream() { - process.stdout.write('TEST: asyncStream started\n'); - streamStarted(true); - const reader = stream.getReader(); - try { - while (true) { - process.stdout.write('TEST: waiting for read\n'); - const { done, value } = await reader.read(); - process.stdout.write(`TEST: read returned done=${done}\n`); - if (done) break; - yield value; - } - } finally { - process.stdout.write('TEST: releasing lock\n'); - reader.releaseLock(); - } - } - - mockChat.sendMessageStream.mockResolvedValue(asyncStream()); - - process.stdout.write('TEST: calling prompt\n'); - const promptPromise = session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Hi' }], - }); - - process.stdout.write('TEST: waiting for streamStarted\n'); - await streamStartedPromise; - process.stdout.write('TEST: streamStarted\n'); - await session.cancelPendingPrompt(); - process.stdout.write('TEST: cancelled\n'); - - // Close the stream to allow prompt loop to continue and check aborted signal - streamController!.close(); - process.stdout.write('TEST: stream closed\n'); - - const result = await promptPromise; - process.stdout.write(`TEST: result received ${JSON.stringify(result)}\n`); - expect(result).toEqual({ stopReason: 'cancelled' }); - }); - - it('should handle rate limit error', async () => { - const error = new Error('Rate limit'); - (error as unknown as { status: number }).status = 429; - mockChat.sendMessageStream.mockRejectedValue(error); - - await expect( - session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Hi' }], - }), - ).rejects.toMatchObject({ - code: 429, - message: 'Rate limit exceeded. Try again later.', - }); - }); - - it('should handle tool execution error', async () => { - mockTool.build.mockReturnValue({ - getDescription: () => 'Test Tool', - toolLocations: () => [], - shouldConfirmExecute: vi.fn().mockResolvedValue(null), - execute: vi.fn().mockRejectedValue(new Error('Tool failed')), - }); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'test_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( - expect.objectContaining({ - update: expect.objectContaining({ - sessionUpdate: 'tool_call_update', - status: 'failed', - content: expect.arrayContaining([ - expect.objectContaining({ - content: expect.objectContaining({ text: 'Tool failed' }), - }), - ]), - kind: 'read', - }), - }), - ); - }); - - it('should handle missing tool', async () => { - mockToolRegistry.getTool.mockReturnValue(undefined); - - const stream1 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { - functionCalls: [{ name: 'unknown_tool', args: {} }], - }, - }, - ]); - const stream2 = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - - await session.prompt({ - sessionId: 'session-1', - prompt: [{ type: 'text', text: 'Call tool' }], - }); - - // Should send error response to model - expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); - const secondCallArgs = mockChat.sendMessageStream.mock.calls[1]; - const parts = secondCallArgs[1]; - expect(parts).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - functionResponse: expect.objectContaining({ - response: { - error: expect.stringContaining('not found in registry'), - }, - }), - }), - ]), - ); - }); - - it('should ignore files based on configuration', async () => { - ( - mockConfig.getFileService().shouldIgnoreFile as unknown as Mock - ).mockReturnValue(true); - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - mockChat.sendMessageStream.mockResolvedValue(stream); - - await session.prompt({ - sessionId: 'session-1', - prompt: [ - { - type: 'resource_link', - uri: 'file://ignored.txt', - mimeType: 'text/plain', - name: 'ignored.txt', - }, - ], - }); - - // Should not read file - expect(mockToolRegistry.getTool).not.toHaveBeenCalledWith( - 'read_many_files', - ); - }); - - it('should handle directory resolution with glob', async () => { - (path.resolve as unknown as Mock).mockReturnValue('/tmp/dir'); - (fs.stat as unknown as Mock).mockResolvedValue({ - isDirectory: () => true, - }); - - const stream = createMockStream([ - { - type: StreamEventType.CHUNK, - value: { candidates: [] }, - }, - ]); - mockChat.sendMessageStream.mockResolvedValue(stream); - - await session.prompt({ - sessionId: 'session-1', - prompt: [ - { - type: 'resource_link', - uri: 'file://dir', - mimeType: 'text/plain', - name: 'dir', - }, - ], - }); - - // Should use glob - // ReadManyFilesTool is instantiated directly, so we check if the mock instance's build method was called - const MockReadManyFilesTool = ReadManyFilesTool as unknown as Mock; - const mockInstance = - MockReadManyFilesTool.mock.results[ - MockReadManyFilesTool.mock.results.length - 1 - ].value; - expect(mockInstance.build).toHaveBeenCalled(); - }); - - it('should set mode on config', () => { - session.setMode(ApprovalMode.AUTO_EDIT); - expect(mockConfig.setApprovalMode).toHaveBeenCalledWith( - ApprovalMode.AUTO_EDIT, - ); - }); - - it('should throw error for invalid mode', () => { - expect(() => session.setMode('invalid-mode')).toThrow( - 'Invalid or unavailable mode: invalid-mode', - ); - }); - - it('should set model on config', () => { - session.setModel('gemini-2.0-flash-exp'); - expect(mockConfig.setModel).toHaveBeenCalledWith('gemini-2.0-flash-exp'); - }); - - it('should handle unquoted commands from autocomplete (with empty leading parts)', async () => { - // Mock handleCommand to verify it gets called - const handleCommandSpy = vi - .spyOn( - (session as unknown as { commandHandler: CommandHandler }) - .commandHandler, - 'handleCommand', - ) - .mockResolvedValue(true); - - await session.prompt({ - sessionId: 'session-1', - prompt: [ - { type: 'text', text: '' }, - { type: 'text', text: '/memory' }, - ], - }); - - expect(handleCommandSpy).toHaveBeenCalledWith('/memory', expect.anything()); - }); -}); diff --git a/packages/cli/src/acp/commandHandler.test.ts b/packages/cli/src/acp/acpCommandHandler.test.ts similarity index 94% rename from packages/cli/src/acp/commandHandler.test.ts rename to packages/cli/src/acp/acpCommandHandler.test.ts index 4a1ce6d2e5..7cc1670688 100644 --- a/packages/cli/src/acp/commandHandler.test.ts +++ b/packages/cli/src/acp/acpCommandHandler.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { CommandHandler } from './commandHandler.js'; +import { CommandHandler } from './acpCommandHandler.js'; import { describe, it, expect } from 'vitest'; describe('CommandHandler', () => { diff --git a/packages/cli/src/acp/commandHandler.ts b/packages/cli/src/acp/acpCommandHandler.ts similarity index 99% rename from packages/cli/src/acp/commandHandler.ts rename to packages/cli/src/acp/acpCommandHandler.ts index b35512adb2..6b171b5532 100644 --- a/packages/cli/src/acp/commandHandler.ts +++ b/packages/cli/src/acp/acpCommandHandler.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ diff --git a/packages/cli/src/acp/acpErrors.test.ts b/packages/cli/src/acp/acpErrors.test.ts index 2ea4d528d0..1eeb78d59d 100644 --- a/packages/cli/src/acp/acpErrors.test.ts +++ b/packages/cli/src/acp/acpErrors.test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ diff --git a/packages/cli/src/acp/acpErrors.ts b/packages/cli/src/acp/acpErrors.ts index 57067115bf..c2988b8c4f 100644 --- a/packages/cli/src/acp/acpErrors.ts +++ b/packages/cli/src/acp/acpErrors.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ diff --git a/packages/cli/src/acp/fileSystemService.test.ts b/packages/cli/src/acp/acpFileSystemService.test.ts similarity index 98% rename from packages/cli/src/acp/fileSystemService.test.ts rename to packages/cli/src/acp/acpFileSystemService.test.ts index 188aadbc09..7ddbb21537 100644 --- a/packages/cli/src/acp/fileSystemService.test.ts +++ b/packages/cli/src/acp/acpFileSystemService.test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ @@ -13,7 +13,7 @@ import { afterEach, type Mocked, } from 'vitest'; -import { AcpFileSystemService } from './fileSystemService.js'; +import { AcpFileSystemService } from './acpFileSystemService.js'; import type { AgentSideConnection } from '@agentclientprotocol/sdk'; import type { FileSystemService } from '@google/gemini-cli-core'; import os from 'node:os'; diff --git a/packages/cli/src/acp/fileSystemService.ts b/packages/cli/src/acp/acpFileSystemService.ts similarity index 98% rename from packages/cli/src/acp/fileSystemService.ts rename to packages/cli/src/acp/acpFileSystemService.ts index b020cd27f2..c11dc7f6cf 100644 --- a/packages/cli/src/acp/fileSystemService.ts +++ b/packages/cli/src/acp/acpFileSystemService.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ diff --git a/packages/cli/src/acp/acpResume.test.ts b/packages/cli/src/acp/acpResume.test.ts index 6a92d68814..d8bbe7e5db 100644 --- a/packages/cli/src/acp/acpResume.test.ts +++ b/packages/cli/src/acp/acpResume.test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ @@ -13,7 +13,7 @@ import { type Mocked, type Mock, } from 'vitest'; -import { GeminiAgent } from './acpClient.js'; +import { GeminiAgent } from './acpRpcDispatcher.js'; import * as acp from '@agentclientprotocol/sdk'; import { ApprovalMode, @@ -28,6 +28,7 @@ import { } from '../utils/sessionUtils.js'; import { convertSessionToClientHistory } from '@google/gemini-cli-core'; import type { LoadedSettings } from '../config/settings.js'; +import { waitFor } from '../test-utils/async.js'; vi.mock('../config/config.js', () => ({ loadCliConfig: vi.fn(), @@ -106,6 +107,9 @@ describe('GeminiAgent Session Resume', () => { getHasAccessToPreviewModel: vi.fn().mockReturnValue(false), getGemini31LaunchedSync: vi.fn().mockReturnValue(false), getCheckpointingEnabled: vi.fn().mockReturnValue(false), + toolRegistry: { + getTool: vi.fn().mockReturnValue({ kind: 'read' }), + }, get config() { return this; }, @@ -170,11 +174,6 @@ describe('GeminiAgent Session Resume', () => { ], }; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (mockConfig as any).toolRegistry = { - getTool: vi.fn().mockReturnValue({ kind: 'read' }), - }; - (SessionSelector as unknown as Mock).mockImplementation(() => ({ resolveSession: vi.fn().mockResolvedValue({ sessionData, @@ -240,7 +239,7 @@ describe('GeminiAgent Session Resume', () => { }), ); - await vi.waitFor(() => { + await waitFor(() => { // User message expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( expect.objectContaining({ diff --git a/packages/cli/src/acp/acpRpcDispatcher.test.ts b/packages/cli/src/acp/acpRpcDispatcher.test.ts new file mode 100644 index 0000000000..a677c5631b --- /dev/null +++ b/packages/cli/src/acp/acpRpcDispatcher.test.ts @@ -0,0 +1,338 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + type Mock, + type Mocked, +} from 'vitest'; +import { GeminiAgent } from './acpRpcDispatcher.js'; +import * as acp from '@agentclientprotocol/sdk'; +import { + AuthType, + type Config, + type MessageBus, + type Storage, +} from '@google/gemini-cli-core'; +import type { LoadedSettings } from '../config/settings.js'; +import { loadCliConfig, type CliArgs } from '../config/config.js'; +import { loadSettings, SettingScope } from '../config/settings.js'; + +vi.mock('../config/config.js', () => ({ + loadCliConfig: vi.fn(), +})); + +vi.mock('../config/settings.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + loadSettings: vi.fn(), + }; +}); + +describe('GeminiAgent - RPC Dispatcher', () => { + let mockConfig: Mocked; + let mockSettings: Mocked; + let mockArgv: CliArgs; + let mockConnection: Mocked; + let agent: GeminiAgent; + + beforeEach(() => { + mockConfig = { + refreshAuth: vi.fn(), + initialize: vi.fn(), + waitForMcpInit: vi.fn(), + getFileSystemService: vi.fn(), + setFileSystemService: vi.fn(), + getContentGeneratorConfig: vi.fn(), + getActiveModel: vi.fn().mockReturnValue('gemini-pro'), + getModel: vi.fn().mockReturnValue('gemini-pro'), + getGeminiClient: vi.fn().mockReturnValue({ + startChat: vi.fn().mockResolvedValue({}), + }), + getMessageBus: vi.fn().mockReturnValue({ + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + }), + getApprovalMode: vi.fn().mockReturnValue('default'), + isPlanEnabled: vi.fn().mockReturnValue(true), + getGemini31LaunchedSync: vi.fn().mockReturnValue(false), + getHasAccessToPreviewModel: vi.fn().mockReturnValue(false), + getCheckpointingEnabled: vi.fn().mockReturnValue(false), + getDisableAlwaysAllow: vi.fn().mockReturnValue(false), + validatePathAccess: vi.fn().mockReturnValue(null), + getWorkspaceContext: vi.fn().mockReturnValue({ + addReadOnlyPath: vi.fn(), + }), + getPolicyEngine: vi.fn().mockReturnValue({ + addRule: vi.fn(), + }), + messageBus: { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + } as unknown as MessageBus, + storage: { + getWorkspaceAutoSavedPolicyPath: vi.fn(), + getAutoSavedPolicyPath: vi.fn(), + } as unknown as Storage, + + get config() { + return this; + }, + } as unknown as Mocked; + mockSettings = { + merged: { + security: { auth: { selectedType: 'login_with_google' } }, + mcpServers: {}, + }, + setValue: vi.fn(), + } as unknown as Mocked; + mockArgv = {} as unknown as CliArgs; + mockConnection = { + sessionUpdate: vi.fn(), + requestPermission: vi.fn(), + } as unknown as Mocked; + + (loadCliConfig as unknown as Mock).mockResolvedValue(mockConfig); + (loadSettings as unknown as Mock).mockImplementation(() => ({ + merged: { + security: { + auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE }, + enablePermanentToolApproval: true, + }, + mcpServers: {}, + }, + setValue: vi.fn(), + })); + + agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockConnection); + }); + + it('should initialize correctly', async () => { + const response = await agent.initialize({ + clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } }, + protocolVersion: 1, + }); + + expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION); + expect(response.authMethods).toHaveLength(4); + const gatewayAuth = response.authMethods?.find( + (m) => m.id === AuthType.GATEWAY, + ); + expect(gatewayAuth?._meta).toEqual({ + gateway: { + protocol: 'google', + restartRequired: 'false', + }, + }); + const geminiAuth = response.authMethods?.find( + (m) => m.id === AuthType.USE_GEMINI, + ); + expect(geminiAuth?._meta).toEqual({ + 'api-key': { + provider: 'google', + }, + }); + expect(response.agentCapabilities?.loadSession).toBe(true); + }); + + it('should authenticate correctly', async () => { + await agent.authenticate({ + methodId: AuthType.LOGIN_WITH_GOOGLE, + }); + + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.LOGIN_WITH_GOOGLE, + undefined, + undefined, + undefined, + ); + expect(mockSettings.setValue).toHaveBeenCalledWith( + SettingScope.User, + 'security.auth.selectedType', + AuthType.LOGIN_WITH_GOOGLE, + ); + }); + + it('should authenticate correctly with api-key in _meta', async () => { + await agent.authenticate({ + methodId: AuthType.USE_GEMINI, + _meta: { + 'api-key': 'test-api-key', + }, + } as unknown as acp.AuthenticateRequest); + + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.USE_GEMINI, + 'test-api-key', + undefined, + undefined, + ); + expect(mockSettings.setValue).toHaveBeenCalledWith( + SettingScope.User, + 'security.auth.selectedType', + AuthType.USE_GEMINI, + ); + }); + + it('should authenticate correctly with gateway method', async () => { + await agent.authenticate({ + methodId: AuthType.GATEWAY, + _meta: { + gateway: { + baseUrl: 'https://example.com', + headers: { Authorization: 'Bearer token' }, + }, + }, + } as unknown as acp.AuthenticateRequest); + + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.GATEWAY, + undefined, + 'https://example.com', + { Authorization: 'Bearer token' }, + ); + expect(mockSettings.setValue).toHaveBeenCalledWith( + SettingScope.User, + 'security.auth.selectedType', + AuthType.GATEWAY, + ); + }); + + it('should throw acp.RequestError when gateway payload is malformed', async () => { + await expect( + agent.authenticate({ + methodId: AuthType.GATEWAY, + _meta: { + gateway: { + baseUrl: 123, + headers: { Authorization: 'Bearer token' }, + }, + }, + } as unknown as acp.AuthenticateRequest), + ).rejects.toThrow(/Malformed gateway payload/); + }); + + it('should cancel a session', async () => { + const mockSession = { + cancelPendingPrompt: vi.fn(), + }; + ( + agent as unknown as { sessionManager: { getSession: Mock } } + ).sessionManager = { + getSession: vi.fn().mockReturnValue(mockSession), + }; + + await agent.cancel({ sessionId: 'test-session-id' }); + + expect(mockSession.cancelPendingPrompt).toHaveBeenCalled(); + }); + + it('should throw error when cancelling non-existent session', async () => { + ( + agent as unknown as { sessionManager: { getSession: Mock } } + ).sessionManager = { + getSession: vi.fn().mockReturnValue(undefined), + }; + + await expect(agent.cancel({ sessionId: 'unknown' })).rejects.toThrow( + 'Session not found', + ); + }); + + it('should delegate prompt to session', async () => { + const mockSession = { + prompt: vi.fn().mockResolvedValue({ stopReason: 'end_turn' }), + }; + ( + agent as unknown as { sessionManager: { getSession: Mock } } + ).sessionManager = { + getSession: vi.fn().mockReturnValue(mockSession), + }; + + const result = await agent.prompt({ + sessionId: 'test-session-id', + prompt: [], + }); + + expect(mockSession.prompt).toHaveBeenCalled(); + expect(result).toMatchObject({ stopReason: 'end_turn' }); + }); + + it('should delegate setMode to session', async () => { + const mockSession = { + setMode: vi.fn().mockReturnValue({}), + }; + ( + agent as unknown as { sessionManager: { getSession: Mock } } + ).sessionManager = { + getSession: vi.fn().mockReturnValue(mockSession), + }; + + const result = await agent.setSessionMode({ + sessionId: 'test-session-id', + modeId: 'plan', + }); + + expect(mockSession.setMode).toHaveBeenCalledWith('plan'); + expect(result).toEqual({}); + }); + + it('should throw error when setting mode on non-existent session', async () => { + ( + agent as unknown as { sessionManager: { getSession: Mock } } + ).sessionManager = { + getSession: vi.fn().mockReturnValue(undefined), + }; + + await expect( + agent.setSessionMode({ + sessionId: 'unknown', + modeId: 'plan', + }), + ).rejects.toThrow('Session not found: unknown'); + }); + + it('should delegate setModel to session (unstable)', async () => { + const mockSession = { + setModel: vi.fn().mockReturnValue({}), + }; + ( + agent as unknown as { sessionManager: { getSession: Mock } } + ).sessionManager = { + getSession: vi.fn().mockReturnValue(mockSession), + }; + + const result = await agent.unstable_setSessionModel({ + sessionId: 'test-session-id', + modelId: 'gemini-2.0-pro-exp', + }); + + expect(mockSession.setModel).toHaveBeenCalledWith('gemini-2.0-pro-exp'); + expect(result).toEqual({}); + }); + + it('should throw error when setting model on non-existent session (unstable)', async () => { + ( + agent as unknown as { sessionManager: { getSession: Mock } } + ).sessionManager = { + getSession: vi.fn().mockReturnValue(undefined), + }; + + await expect( + agent.unstable_setSessionModel({ + sessionId: 'unknown', + modelId: 'gemini-2.0-pro-exp', + }), + ).rejects.toThrow('Session not found: unknown'); + }); +}); diff --git a/packages/cli/src/acp/acpRpcDispatcher.ts b/packages/cli/src/acp/acpRpcDispatcher.ts new file mode 100644 index 0000000000..97fb0d4011 --- /dev/null +++ b/packages/cli/src/acp/acpRpcDispatcher.ts @@ -0,0 +1,232 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + type AgentLoopContext, + AuthType, + clearCachedCredentialFile, + getVersion, +} from '@google/gemini-cli-core'; +import * as acp from '@agentclientprotocol/sdk'; +import { z } from 'zod'; +import { SettingScope, type LoadedSettings } from '../config/settings.js'; +import type { CliArgs } from '../config/config.js'; +import { getAcpErrorMessage } from './acpErrors.js'; +import { AcpSessionManager, type AuthDetails } from './acpSessionManager.js'; +import { hasMeta } from './acpUtils.js'; + +export class GeminiAgent { + private apiKey: string | undefined; + private baseUrl: string | undefined; + private customHeaders: Record | undefined; + private sessionManager: AcpSessionManager; + + constructor( + private context: AgentLoopContext, + private settings: LoadedSettings, + argv: CliArgs, + connection: acp.AgentSideConnection, + ) { + this.sessionManager = new AcpSessionManager(settings, argv, connection); + } + + async initialize( + args: acp.InitializeRequest, + ): Promise { + if (args.clientCapabilities) { + this.sessionManager.setClientCapabilities(args.clientCapabilities); + } + + const authMethods = [ + { + id: AuthType.LOGIN_WITH_GOOGLE, + name: 'Log in with Google', + description: 'Log in with your Google account', + }, + { + id: AuthType.USE_GEMINI, + name: 'Gemini API key', + description: 'Use an API key with Gemini Developer API', + _meta: { + 'api-key': { + provider: 'google', + }, + }, + }, + { + id: AuthType.USE_VERTEX_AI, + name: 'Vertex AI', + description: 'Use an API key with Vertex AI GenAI API', + }, + { + id: AuthType.GATEWAY, + name: 'AI API Gateway', + description: 'Use a custom AI API Gateway', + _meta: { + gateway: { + protocol: 'google', + restartRequired: 'false', + }, + }, + }, + ]; + + await this.context.config.initialize(); + const version = await getVersion(); + return { + protocolVersion: acp.PROTOCOL_VERSION, + authMethods, + agentInfo: { + name: 'gemini-cli', + title: 'Gemini CLI', + version, + }, + agentCapabilities: { + loadSession: true, + promptCapabilities: { + image: true, + audio: true, + embeddedContext: true, + }, + mcpCapabilities: { + http: true, + sse: true, + }, + }, + }; + } + + async authenticate(req: acp.AuthenticateRequest): Promise { + const { methodId } = req; + const method = z.nativeEnum(AuthType).parse(methodId); + const selectedAuthType = this.settings.merged.security.auth.selectedType; + + // Only clear credentials when switching to a different auth method + if (selectedAuthType && selectedAuthType !== method) { + await clearCachedCredentialFile(); + } + // Check for api-key in _meta + const meta = hasMeta(req) ? req._meta : undefined; + const apiKey = + typeof meta?.['api-key'] === 'string' ? meta['api-key'] : undefined; + + // Refresh auth with the requested method + // This will reuse existing credentials if they're valid, + // or perform new authentication if needed + try { + if (apiKey) { + this.apiKey = apiKey; + } + + // Extract gateway details if present + const gatewaySchema = z.object({ + baseUrl: z.string().optional(), + headers: z.record(z.string()).optional(), + }); + + let baseUrl: string | undefined; + let headers: Record | undefined; + + if (meta?.['gateway']) { + const result = gatewaySchema.safeParse(meta['gateway']); + if (result.success) { + baseUrl = result.data.baseUrl; + headers = result.data.headers; + } else { + throw new acp.RequestError( + -32602, + `Malformed gateway payload: ${result.error.message}`, + ); + } + } + + this.baseUrl = baseUrl; + this.customHeaders = headers; + + await this.context.config.refreshAuth( + method, + apiKey ?? this.apiKey, + baseUrl, + headers, + ); + } catch (e) { + throw new acp.RequestError(-32000, getAcpErrorMessage(e)); + } + this.settings.setValue( + SettingScope.User, + 'security.auth.selectedType', + method, + ); + } + + private getAuthDetails(): AuthDetails { + return { + apiKey: this.apiKey, + baseUrl: this.baseUrl, + customHeaders: this.customHeaders, + }; + } + + async newSession( + params: acp.NewSessionRequest, + ): Promise { + return this.sessionManager.newSession(params, this.getAuthDetails()); + } + + async loadSession( + params: acp.LoadSessionRequest, + ): Promise { + return this.sessionManager.loadSession(params, this.getAuthDetails()); + } + + async cancel(params: acp.CancelNotification): Promise { + const session = this.sessionManager.getSession(params.sessionId); + if (!session) { + throw new acp.RequestError( + -32602, + `Session not found: ${params.sessionId}`, + ); + } + await session.cancelPendingPrompt(); + } + + async prompt(params: acp.PromptRequest): Promise { + const session = this.sessionManager.getSession(params.sessionId); + if (!session) { + throw new acp.RequestError( + -32602, + `Session not found: ${params.sessionId}`, + ); + } + return session.prompt(params); + } + + async setSessionMode( + params: acp.SetSessionModeRequest, + ): Promise { + const session = this.sessionManager.getSession(params.sessionId); + if (!session) { + throw new acp.RequestError( + -32602, + `Session not found: ${params.sessionId}`, + ); + } + return session.setMode(params.modeId); + } + + async unstable_setSessionModel( + params: acp.SetSessionModelRequest, + ): Promise { + const session = this.sessionManager.getSession(params.sessionId); + if (!session) { + throw new acp.RequestError( + -32602, + `Session not found: ${params.sessionId}`, + ); + } + return session.setModel(params.modelId); + } +} diff --git a/packages/cli/src/acp/acpSession.test.ts b/packages/cli/src/acp/acpSession.test.ts new file mode 100644 index 0000000000..07639108ce --- /dev/null +++ b/packages/cli/src/acp/acpSession.test.ts @@ -0,0 +1,463 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, + type Mocked, +} from 'vitest'; +import { Session } from './acpSession.js'; +import type * as acp from '@agentclientprotocol/sdk'; +import { + StreamEventType, + ReadManyFilesTool, + type GeminiChat, + type Config, + type MessageBus, + LlmRole, + type GitService, + type ModelRouterService, + InvalidStreamError, +} from '@google/gemini-cli-core'; +import type { LoadedSettings } from '../config/settings.js'; +import * as fs from 'node:fs/promises'; +import * as path from 'node:path'; +import type { CommandHandler } from './acpCommandHandler.js'; + +vi.mock('node:fs/promises'); +vi.mock('node:path', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + resolve: vi.fn(), + }; +}); + +vi.mock( + '@google/gemini-cli-core', + async ( + importOriginal: () => Promise, + ) => { + const actual = await importOriginal(); + return { + ...actual, + updatePolicy: vi.fn(), + ReadManyFilesTool: vi.fn(), + logToolCall: vi.fn(), + processSingleFileContent: vi.fn(), + }; + }, +); + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +async function* createMockStream(items: any[]) { + for (const item of items) { + yield item; + } +} + +describe('Session', () => { + let mockChat: Mocked; + let mockConfig: Mocked; + let mockConnection: Mocked; + let session: Session; + let mockToolRegistry: { getTool: Mock }; + let mockTool: { kind: string; build: Mock }; + let mockMessageBus: Mocked; + + beforeEach(() => { + mockChat = { + sendMessageStream: vi.fn(), + addHistory: vi.fn(), + recordCompletedToolCalls: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + } as unknown as Mocked; + mockTool = { + kind: 'read', + build: vi.fn().mockReturnValue({ + getDescription: () => 'Test Tool', + toolLocations: () => [], + shouldConfirmExecute: vi.fn().mockResolvedValue(null), + execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), + }), + }; + mockToolRegistry = { + getTool: vi.fn().mockReturnValue(mockTool), + }; + mockMessageBus = { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + } as unknown as Mocked; + mockConfig = { + getModel: vi.fn().mockReturnValue('gemini-pro'), + getActiveModel: vi.fn().mockReturnValue('gemini-pro'), + getModelRouterService: vi.fn().mockReturnValue({ + route: vi.fn().mockResolvedValue({ model: 'resolved-model' }), + }), + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getFileService: vi.fn().mockReturnValue({ + shouldIgnoreFile: vi.fn().mockReturnValue(false), + }), + getFileFilteringOptions: vi.fn().mockReturnValue({}), + getFileSystemService: vi.fn().mockReturnValue({}), + getTargetDir: vi.fn().mockReturnValue('/tmp'), + getEnableRecursiveFileSearch: vi.fn().mockReturnValue(false), + getDebugMode: vi.fn().mockReturnValue(false), + getMessageBus: vi.fn().mockReturnValue(mockMessageBus), + setApprovalMode: vi.fn(), + setModel: vi.fn(), + isPlanEnabled: vi.fn().mockReturnValue(true), + getCheckpointingEnabled: vi.fn().mockReturnValue(false), + getGitService: vi.fn().mockResolvedValue({} as GitService), + validatePathAccess: vi.fn().mockReturnValue(null), + getWorkspaceContext: vi.fn().mockReturnValue({ + addReadOnlyPath: vi.fn(), + }), + waitForMcpInit: vi.fn(), + getDisableAlwaysAllow: vi.fn().mockReturnValue(false), + get config() { + return this; + }, + get toolRegistry() { + return mockToolRegistry; + }, + } as unknown as Mocked; + mockConnection = { + sessionUpdate: vi.fn(), + requestPermission: vi.fn(), + } as unknown as Mocked; + + session = new Session('session-1', mockChat, mockConfig, mockConnection, { + merged: { + security: { enablePermanentToolApproval: true }, + mcpServers: {}, + }, + errors: [], + } as unknown as LoadedSettings); + + (ReadManyFilesTool as unknown as Mock).mockImplementation(() => ({ + name: 'read_many_files', + kind: 'read', + build: vi.fn().mockReturnValue({ + getDescription: () => 'Read files', + toolLocations: () => [], + execute: vi.fn().mockResolvedValue({ + llmContent: ['--- file.txt ---\n\nFile content\n\n'], + }), + }), + })); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should send available commands', async () => { + await session.sendAvailableCommands(); + + expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'available_commands_update', + }), + }), + ); + }); + + it('should await MCP initialization before processing a prompt', async () => { + const stream = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [{ content: { parts: [{ text: 'Hi' }] } }] }, + }, + ]); + mockChat.sendMessageStream.mockResolvedValue(stream); + + await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'test' }], + }); + + expect(mockConfig.waitForMcpInit).toHaveBeenCalledOnce(); + }); + + it('should handle prompt with text response', async () => { + const stream = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'Hello' }] } }], + }, + }, + ]); + mockChat.sendMessageStream.mockResolvedValue(stream); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Hi' }], + }); + + expect(mockChat.sendMessageStream).toHaveBeenCalled(); + expect(mockConnection.sessionUpdate).toHaveBeenCalledWith({ + sessionId: 'session-1', + update: { + sessionUpdate: 'agent_message_chunk', + content: { type: 'text', text: 'Hello' }, + }, + }); + expect(result).toMatchObject({ stopReason: 'end_turn' }); + }); + + it('should use model router to determine model', async () => { + const mockRouter = { + route: vi.fn().mockResolvedValue({ model: 'routed-model' }), + } as unknown as ModelRouterService; + mockConfig.getModelRouterService.mockReturnValue(mockRouter); + + const stream = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'Hello' }] } }], + }, + }, + ]); + mockChat.sendMessageStream.mockResolvedValue(stream); + + await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Hi' }], + }); + + expect(mockRouter.route).toHaveBeenCalled(); + expect(mockChat.sendMessageStream).toHaveBeenCalledWith( + expect.objectContaining({ model: 'routed-model' }), + expect.any(Array), + expect.any(String), + expect.any(Object), + expect.any(String), + ); + }); + + it('should handle prompt with empty response (InvalidStreamError)', async () => { + mockChat.sendMessageStream.mockRejectedValue( + new InvalidStreamError('Empty response', 'NO_RESPONSE_TEXT'), + ); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Hi' }], + }); + + expect(result).toMatchObject({ stopReason: 'end_turn' }); + }); + + it('should handle prompt with no finish reason (InvalidStreamError)', async () => { + mockChat.sendMessageStream.mockRejectedValue( + new InvalidStreamError('No finish reason', 'NO_FINISH_REASON'), + ); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Hi' }], + }); + + expect(result).toMatchObject({ stopReason: 'end_turn' }); + }); + + it('should handle /memory command', async () => { + const handleCommandSpy = vi + .spyOn( + (session as unknown as { commandHandler: CommandHandler }) + .commandHandler, + 'handleCommand', + ) + .mockResolvedValue(true); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: '/memory view' }], + }); + + expect(result).toMatchObject({ stopReason: 'end_turn' }); + expect(handleCommandSpy).toHaveBeenCalledWith( + '/memory view', + expect.any(Object), + ); + }); + + it('should handle tool calls', async () => { + const stream1 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + functionCalls: [{ name: 'test_tool', args: { foo: 'bar' } }], + }, + }, + ]); + const stream2 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'Result' }] } }], + }, + }, + ]); + + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Call tool' }], + }); + + expect(mockToolRegistry.getTool).toHaveBeenCalledWith('test_tool'); + expect(result).toMatchObject({ stopReason: 'end_turn' }); + }); + + it('should handle tool call permission request', async () => { + const confirmationDetails = { + type: 'info', + onConfirm: vi.fn(), + }; + mockTool.build.mockReturnValue({ + getDescription: () => 'Test Tool', + toolLocations: () => [], + shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), + execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), + }); + + mockConnection.requestPermission.mockResolvedValue({ + outcome: { + outcome: 'selected', + optionId: 'proceed_once', + }, + }); + + const stream1 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + functionCalls: [{ name: 'test_tool', args: {} }], + }, + }, + ]); + const stream2 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [] }, + }, + ]); + + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + + await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Call tool' }], + }); + + expect(mockConnection.requestPermission).toHaveBeenCalled(); + expect(confirmationDetails.onConfirm).toHaveBeenCalled(); + }); + + it('should handle @path resolution', async () => { + (path.resolve as unknown as Mock).mockReturnValue('/tmp/file.txt'); + (fs.stat as unknown as Mock).mockResolvedValue({ + isDirectory: () => false, + }); + + const stream = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [] }, + }, + ]); + mockChat.sendMessageStream.mockResolvedValue(stream); + + await session.prompt({ + sessionId: 'session-1', + prompt: [ + { type: 'text', text: 'Read' }, + { + type: 'resource_link', + uri: 'file://file.txt', + mimeType: 'text/plain', + name: 'file.txt', + }, + ], + }); + + expect(path.resolve).toHaveBeenCalled(); + expect(fs.stat).toHaveBeenCalled(); + expect(mockChat.sendMessageStream).toHaveBeenCalledWith( + expect.anything(), + expect.arrayContaining([ + expect.objectContaining({ + text: expect.stringContaining('Content from @file.txt'), + }), + ]), + expect.anything(), + expect.any(AbortSignal), + LlmRole.MAIN, + ); + }); + + it('should handle rate limit error', async () => { + const error = new Error('Rate limit'); + (error as unknown as { status: number }).status = 429; + mockChat.sendMessageStream.mockRejectedValue(error); + + await expect( + session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Hi' }], + }), + ).rejects.toMatchObject({ + code: 429, + message: 'Rate limit exceeded. Try again later.', + }); + }); + + it('should handle missing tool', async () => { + mockToolRegistry.getTool.mockReturnValue(undefined); + + const stream1 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + functionCalls: [{ name: 'unknown_tool', args: {} }], + }, + }, + ]); + const stream2 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [] }, + }, + ]); + + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + + await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Call tool' }], + }); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + }); +}); diff --git a/packages/cli/src/acp/acpClient.ts b/packages/cli/src/acp/acpSession.ts similarity index 62% rename from packages/cli/src/acp/acpClient.ts rename to packages/cli/src/acp/acpSession.ts index 57c7790b05..db0c185007 100644 --- a/packages/cli/src/acp/acpClient.ts +++ b/packages/cli/src/acp/acpSession.ts @@ -1,27 +1,20 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ import { - type Config, + type ApprovalMode, type GeminiChat, type ToolResult, - type ToolCallConfirmationDetails, - type FilterFilesOptions, type ConversationRecord, CoreToolCallStatus, - AuthType, logToolCall, convertToFunctionResponse, ToolConfirmationOutcome, - clearCachedCredentialFile, - isNodeError, - getErrorMessage, isWithinRoot, getErrorStatus, - MCPServerConfig, DiscoveredMCPTool, StreamEventType, ToolCallEvent, @@ -29,547 +22,42 @@ import { ReadManyFilesTool, REFERENCE_CONTENT_START, type RoutingContext, - createWorkingStdio, - startupProfiler, - Kind, partListUnionToString, LlmRole, - ApprovalMode, - getVersion, - convertSessionToClientHistory, - DEFAULT_GEMINI_MODEL, - DEFAULT_GEMINI_FLASH_MODEL, - DEFAULT_GEMINI_FLASH_LITE_MODEL, - PREVIEW_GEMINI_MODEL, - PREVIEW_GEMINI_3_1_MODEL, - PREVIEW_GEMINI_3_1_FLASH_LITE_MODEL, - PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, - PREVIEW_GEMINI_FLASH_MODEL, - DEFAULT_GEMINI_MODEL_AUTO, - PREVIEW_GEMINI_MODEL_AUTO, - getDisplayString, processSingleFileContent, InvalidStreamError, type AgentLoopContext, updatePolicy, + isNodeError, + getErrorMessage, + type FilterFilesOptions, + isTextPart, } from '@google/gemini-cli-core'; import * as acp from '@agentclientprotocol/sdk'; -import { AcpFileSystemService } from './fileSystemService.js'; -import { getAcpErrorMessage } from './acpErrors.js'; -import { Readable, Writable } from 'node:stream'; - -function hasMeta(obj: unknown): obj is { _meta?: Record } { - return typeof obj === 'object' && obj !== null && '_meta' in obj; -} import type { Content, Part, FunctionCall } from '@google/genai'; -import { - SettingScope, - loadSettings, - type LoadedSettings, -} from '../config/settings.js'; -import { createPolicyUpdater } from '../config/policy.js'; +import type { LoadedSettings } from '../config/settings.js'; import * as fs from 'node:fs/promises'; import * as path from 'node:path'; -import { z } from 'zod'; - import { randomUUID } from 'node:crypto'; -import { loadCliConfig, type CliArgs } from '../config/config.js'; -import { runExitCleanup } from '../utils/cleanup.js'; -import { SessionSelector } from '../utils/sessionUtils.js'; -import { startAutoMemoryIfEnabled } from '../utils/autoMemory.js'; - -import { CommandHandler } from './commandHandler.js'; - -const RequestPermissionResponseSchema = z.object({ - outcome: z.discriminatedUnion('outcome', [ - z.object({ outcome: z.literal('cancelled') }), - z.object({ - outcome: z.literal('selected'), - optionId: z.string(), - }), - ]), -}); - -export async function runAcpClient( - config: Config, - settings: LoadedSettings, - argv: CliArgs, -) { - // ... (skip unchanged lines) ... - - const { stdout: workingStdout } = createWorkingStdio(); - const stdout = Writable.toWeb(workingStdout) as WritableStream; - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const stdin = Readable.toWeb(process.stdin) as ReadableStream; - - const stream = acp.ndJsonStream(stdout, stdin); - const connection = new acp.AgentSideConnection( - (connection) => new GeminiAgent(config, settings, argv, connection), - stream, - ); - - // SIGTERM/SIGINT handlers (in sdk.ts) don't fire when stdin closes. - // We must explicitly await the connection close to flush telemetry. - // Use finally() to ensure cleanup runs even on stream errors. - await connection.closed.finally(runExitCleanup); -} - -export class GeminiAgent { - private static callIdCounter = 0; - - static generateCallId(name: string): string { - return `${name}-${Date.now()}-${++GeminiAgent.callIdCounter}`; - } - - private sessions: Map = new Map(); - private clientCapabilities: acp.ClientCapabilities | undefined; - private apiKey: string | undefined; - private baseUrl: string | undefined; - private customHeaders: Record | undefined; - - constructor( - private context: AgentLoopContext, - private settings: LoadedSettings, - private argv: CliArgs, - private connection: acp.AgentSideConnection, - ) {} - - async initialize( - args: acp.InitializeRequest, - ): Promise { - this.clientCapabilities = args.clientCapabilities; - - const authMethods = [ - { - id: AuthType.LOGIN_WITH_GOOGLE, - name: 'Log in with Google', - description: 'Log in with your Google account', - }, - { - id: AuthType.USE_GEMINI, - name: 'Gemini API key', - description: 'Use an API key with Gemini Developer API', - _meta: { - 'api-key': { - provider: 'google', - }, - }, - }, - { - id: AuthType.USE_VERTEX_AI, - name: 'Vertex AI', - description: 'Use an API key with Vertex AI GenAI API', - }, - { - id: AuthType.GATEWAY, - name: 'AI API Gateway', - description: 'Use a custom AI API Gateway', - _meta: { - gateway: { - protocol: 'google', - restartRequired: 'false', - }, - }, - }, - ]; - - await this.context.config.initialize(); - const version = await getVersion(); - return { - protocolVersion: acp.PROTOCOL_VERSION, - authMethods, - agentInfo: { - name: 'gemini-cli', - title: 'Gemini CLI', - version, - }, - agentCapabilities: { - loadSession: true, - promptCapabilities: { - image: true, - audio: true, - embeddedContext: true, - }, - mcpCapabilities: { - http: true, - sse: true, - }, - }, - }; - } - - async authenticate(req: acp.AuthenticateRequest): Promise { - const { methodId } = req; - const method = z.nativeEnum(AuthType).parse(methodId); - const selectedAuthType = this.settings.merged.security.auth.selectedType; - - // Only clear credentials when switching to a different auth method - if (selectedAuthType && selectedAuthType !== method) { - await clearCachedCredentialFile(); - } - // Check for api-key in _meta - const meta = hasMeta(req) ? req._meta : undefined; - const apiKey = - typeof meta?.['api-key'] === 'string' ? meta['api-key'] : undefined; - - // Refresh auth with the requested method - // This will reuse existing credentials if they're valid, - // or perform new authentication if needed - try { - if (apiKey) { - this.apiKey = apiKey; - } - - // Extract gateway details if present - const gatewaySchema = z.object({ - baseUrl: z.string().optional(), - headers: z.record(z.string()).optional(), - }); - - let baseUrl: string | undefined; - let headers: Record | undefined; - - if (meta?.['gateway']) { - const result = gatewaySchema.safeParse(meta['gateway']); - if (result.success) { - baseUrl = result.data.baseUrl; - headers = result.data.headers; - } else { - throw new acp.RequestError( - -32602, - `Malformed gateway payload: ${result.error.message}`, - ); - } - } - - this.baseUrl = baseUrl; - this.customHeaders = headers; - - await this.context.config.refreshAuth( - method, - apiKey ?? this.apiKey, - baseUrl, - headers, - ); - } catch (e) { - throw new acp.RequestError(-32000, getAcpErrorMessage(e)); - } - this.settings.setValue( - SettingScope.User, - 'security.auth.selectedType', - method, - ); - } - - async newSession({ - cwd, - mcpServers, - }: acp.NewSessionRequest): Promise { - const sessionId = randomUUID(); - const loadedSettings = loadSettings(cwd); - const config = await this.newSessionConfig( - sessionId, - cwd, - mcpServers, - loadedSettings, - ); - - const authType = - loadedSettings.merged.security.auth.selectedType || AuthType.USE_GEMINI; - - let isAuthenticated = false; - let authErrorMessage = ''; - try { - await config.refreshAuth( - authType, - this.apiKey, - this.baseUrl, - this.customHeaders, - ); - isAuthenticated = true; - - // Extra validation for Gemini API key - const contentGeneratorConfig = config.getContentGeneratorConfig(); - if ( - authType === AuthType.USE_GEMINI && - (!contentGeneratorConfig || !contentGeneratorConfig.apiKey) - ) { - isAuthenticated = false; - authErrorMessage = 'Gemini API key is missing or not configured.'; - } - } catch (e) { - isAuthenticated = false; - authErrorMessage = getAcpErrorMessage(e); - debugLogger.error( - `Authentication failed: ${e instanceof Error ? e.stack : e}`, - ); - } - - if (!isAuthenticated) { - throw new acp.RequestError( - -32000, - authErrorMessage || 'Authentication required.', - ); - } - - if (this.clientCapabilities?.fs) { - const acpFileSystemService = new AcpFileSystemService( - this.connection, - sessionId, - this.clientCapabilities.fs, - config.getFileSystemService(), - cwd, - ); - config.setFileSystemService(acpFileSystemService); - } - - await config.initialize(); - startupProfiler.flush(config); - startAutoMemoryIfEnabled(config); - - const geminiClient = config.getGeminiClient(); - const chat = await geminiClient.startChat(); - - const session = new Session( - sessionId, - chat, - config, - this.connection, - this.settings, - ); - this.sessions.set(sessionId, session); - - setTimeout(() => { - // eslint-disable-next-line @typescript-eslint/no-floating-promises - session.sendAvailableCommands(); - }, 0); - - const { availableModels, currentModelId } = buildAvailableModels( - config, - loadedSettings, - ); - - const response = { - sessionId, - modes: { - availableModes: buildAvailableModes(config.isPlanEnabled()), - currentModeId: config.getApprovalMode(), - }, - models: { - availableModels, - currentModelId, - }, - }; - return response; - } - - async loadSession({ - sessionId, - cwd, - mcpServers, - }: acp.LoadSessionRequest): Promise { - const config = await this.initializeSessionConfig( - sessionId, - cwd, - mcpServers, - ); - - const sessionSelector = new SessionSelector(config.storage); - const { sessionData, sessionPath } = - await sessionSelector.resolveSession(sessionId); - - const clientHistory = convertSessionToClientHistory(sessionData.messages); - - const geminiClient = config.getGeminiClient(); - await geminiClient.initialize(); - await geminiClient.resumeChat(clientHistory, { - conversation: sessionData, - filePath: sessionPath, - }); - - const session = new Session( - sessionId, - geminiClient.getChat(), - config, - this.connection, - this.settings, - ); - this.sessions.set(sessionId, session); - - // Stream history back to client - // eslint-disable-next-line @typescript-eslint/no-floating-promises - session.streamHistory(sessionData.messages); - - setTimeout(() => { - // eslint-disable-next-line @typescript-eslint/no-floating-promises - session.sendAvailableCommands(); - }, 0); - - const { availableModels, currentModelId } = buildAvailableModels( - config, - this.settings, - ); - - const response = { - modes: { - availableModes: buildAvailableModes(config.isPlanEnabled()), - currentModeId: config.getApprovalMode(), - }, - models: { - availableModels, - currentModelId, - }, - }; - return response; - } - - private async initializeSessionConfig( - sessionId: string, - cwd: string, - mcpServers: acp.McpServer[], - ): Promise { - const selectedAuthType = this.settings.merged.security.auth.selectedType; - if (!selectedAuthType) { - throw acp.RequestError.authRequired(); - } - - // 1. Create config WITHOUT initializing it (no MCP servers started yet) - const config = await this.newSessionConfig(sessionId, cwd, mcpServers); - - // 2. Authenticate BEFORE initializing configuration or starting MCP servers. - // This satisfies the security requirement to verify the user before executing - // potentially unsafe server definitions. - try { - await config.refreshAuth( - selectedAuthType, - this.apiKey, - this.baseUrl, - this.customHeaders, - ); - } catch (e) { - debugLogger.error(`Authentication failed: ${e}`); - throw acp.RequestError.authRequired(); - } - - // 3. Set the ACP FileSystemService (if supported) before config initialization - if (this.clientCapabilities?.fs) { - const acpFileSystemService = new AcpFileSystemService( - this.connection, - sessionId, - this.clientCapabilities.fs, - config.getFileSystemService(), - cwd, - ); - config.setFileSystemService(acpFileSystemService); - } - - // 4. Now that we are authenticated, it is safe to initialize the config - // which starts the MCP servers and other heavy resources. - await config.initialize(); - startupProfiler.flush(config); - startAutoMemoryIfEnabled(config); - - return config; - } - - async newSessionConfig( - sessionId: string, - cwd: string, - mcpServers: acp.McpServer[], - loadedSettings?: LoadedSettings, - ): Promise { - const currentSettings = loadedSettings || this.settings; - const mergedMcpServers = { ...currentSettings.merged.mcpServers }; - - for (const server of mcpServers) { - if ( - 'type' in server && - (server.type === 'sse' || server.type === 'http') - ) { - // HTTP or SSE MCP server - const headers = Object.fromEntries( - server.headers.map(({ name, value }) => [name, value]), - ); - mergedMcpServers[server.name] = new MCPServerConfig( - undefined, // command - undefined, // args - undefined, // env - undefined, // cwd - server.type === 'sse' ? server.url : undefined, // url (sse) - server.type === 'http' ? server.url : undefined, // httpUrl - headers, - ); - } else if ('command' in server) { - // Stdio MCP server - const env: Record = {}; - for (const { name: envName, value } of server.env) { - env[envName] = value; - } - mergedMcpServers[server.name] = new MCPServerConfig( - server.command, - server.args, - env, - cwd, - ); - } - } - - const settings = { - ...currentSettings.merged, - mcpServers: mergedMcpServers, - }; - - const config = await loadCliConfig(settings, sessionId, this.argv, { cwd }); - - createPolicyUpdater( - config.getPolicyEngine(), - config.messageBus, - config.storage, - ); - - return config; - } - - async cancel(params: acp.CancelNotification): Promise { - const session = this.sessions.get(params.sessionId); - if (!session) { - throw new Error(`Session not found: ${params.sessionId}`); - } - await session.cancelPendingPrompt(); - } - - async prompt(params: acp.PromptRequest): Promise { - const session = this.sessions.get(params.sessionId); - if (!session) { - throw new Error(`Session not found: ${params.sessionId}`); - } - return session.prompt(params); - } - - async setSessionMode( - params: acp.SetSessionModeRequest, - ): Promise { - const session = this.sessions.get(params.sessionId); - if (!session) { - throw new Error(`Session not found: ${params.sessionId}`); - } - return session.setMode(params.modeId); - } - - async unstable_setSessionModel( - params: acp.SetSessionModelRequest, - ): Promise { - const session = this.sessions.get(params.sessionId); - if (!session) { - throw new Error(`Session not found: ${params.sessionId}`); - } - return session.setModel(params.modelId); - } -} +import { CommandHandler } from './acpCommandHandler.js'; +import { + toToolCallContent, + toPermissionOptions, + toAcpToolKind, + buildAvailableModes, + RequestPermissionResponseSchema, +} from './acpUtils.js'; +import { z } from 'zod'; +import { getAcpErrorMessage } from './acpErrors.js'; export class Session { private pendingPrompt: AbortController | null = null; private commandHandler = new CommandHandler(); + private callIdCounter = 0; + + private generateCallId(name: string): string { + return `${name}-${Date.now()}-${++this.callIdCounter}`; + } constructor( private readonly id: string, @@ -709,13 +197,10 @@ export class Session { for (const part of parts) { if (typeof part === 'object' && part !== null) { - if ('text' in part) { + if (isTextPart(part)) { // It is a text part - // eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-unsafe-type-assertion - const text = (part as any).text; - if (typeof text === 'string') { - commandText += text; - } + const text = part.text; + commandText += text; } else { // Non-text part (image, embedded resource) // Stop looking for command @@ -972,7 +457,7 @@ export class Session { promptId: string, fc: FunctionCall, ): Promise { - const callId = fc.id ?? GeminiAgent.generateCallId(fc.name || 'unknown'); + const callId = fc.id ?? this.generateCallId(fc.name || 'unknown'); const args = fc.args ?? {}; const startTime = Date.now(); @@ -1661,7 +1146,7 @@ export class Session { include: pathSpecsToRead, }; - const callId = GeminiAgent.generateCallId(readManyFilesTool.name); + const callId = this.generateCallId(readManyFilesTool.name); try { const invocation = readManyFilesTool.build(toolArgs); @@ -1799,333 +1284,3 @@ export class Session { } } } - -function toToolCallContent(toolResult: ToolResult): acp.ToolCallContent | null { - if (toolResult.error?.message) { - throw new Error(toolResult.error.message); - } - - if (toolResult.returnDisplay) { - if (typeof toolResult.returnDisplay === 'string') { - return { - type: 'content', - content: { type: 'text', text: toolResult.returnDisplay }, - }; - } else { - if ('fileName' in toolResult.returnDisplay) { - return { - type: 'diff', - path: - toolResult.returnDisplay.filePath ?? - toolResult.returnDisplay.fileName, - oldText: toolResult.returnDisplay.originalContent, - newText: toolResult.returnDisplay.newContent, - _meta: { - kind: !toolResult.returnDisplay.originalContent - ? 'add' - : toolResult.returnDisplay.newContent === '' - ? 'delete' - : 'modify', - }, - }; - } - return null; - } - } else { - return null; - } -} - -const basicPermissionOptions = [ - { - optionId: ToolConfirmationOutcome.ProceedOnce, - name: 'Allow', - kind: 'allow_once', - }, - { - optionId: ToolConfirmationOutcome.Cancel, - name: 'Reject', - kind: 'reject_once', - }, -] as const; - -function toPermissionOptions( - confirmation: ToolCallConfirmationDetails, - config: Config, - enablePermanentToolApproval: boolean = false, -): acp.PermissionOption[] { - const disableAlwaysAllow = config.getDisableAlwaysAllow(); - const options: acp.PermissionOption[] = []; - - if (!disableAlwaysAllow) { - switch (confirmation.type) { - case 'edit': - options.push({ - optionId: ToolConfirmationOutcome.ProceedAlways, - name: 'Allow for this session', - kind: 'allow_always', - }); - if (enablePermanentToolApproval) { - options.push({ - optionId: ToolConfirmationOutcome.ProceedAlwaysAndSave, - name: 'Allow for this file in all future sessions', - kind: 'allow_always', - }); - } - break; - case 'exec': - options.push({ - optionId: ToolConfirmationOutcome.ProceedAlways, - name: 'Allow for this session', - kind: 'allow_always', - }); - if (enablePermanentToolApproval) { - options.push({ - optionId: ToolConfirmationOutcome.ProceedAlwaysAndSave, - name: 'Allow this command for all future sessions', - kind: 'allow_always', - }); - } - break; - case 'mcp': - options.push( - { - optionId: ToolConfirmationOutcome.ProceedAlwaysServer, - name: 'Allow all server tools for this session', - kind: 'allow_always', - }, - { - optionId: ToolConfirmationOutcome.ProceedAlwaysTool, - name: 'Allow tool for this session', - kind: 'allow_always', - }, - ); - if (enablePermanentToolApproval) { - options.push({ - optionId: ToolConfirmationOutcome.ProceedAlwaysAndSave, - name: 'Allow tool for all future sessions', - kind: 'allow_always', - }); - } - break; - case 'info': - options.push({ - optionId: ToolConfirmationOutcome.ProceedAlways, - name: 'Allow for this session', - kind: 'allow_always', - }); - if (enablePermanentToolApproval) { - options.push({ - optionId: ToolConfirmationOutcome.ProceedAlwaysAndSave, - name: 'Allow for all future sessions', - kind: 'allow_always', - }); - } - break; - case 'ask_user': - case 'exit_plan_mode': - // askuser and exit_plan_mode don't need "always allow" options - break; - default: - // No "always allow" options for other types - break; - } - } - - options.push(...basicPermissionOptions); - - // Exhaustive check - switch (confirmation.type) { - case 'edit': - case 'exec': - case 'mcp': - case 'info': - case 'ask_user': - case 'exit_plan_mode': - case 'sandbox_expansion': - break; - default: { - const unreachable: never = confirmation; - throw new Error(`Unexpected: ${unreachable}`); - } - } - - return options; -} - -/** - * Maps our internal tool kind to the ACP ToolKind. - * Fallback to 'other' for kinds that are not supported by the ACP protocol. - */ -function toAcpToolKind(kind: Kind): acp.ToolKind { - switch (kind) { - case Kind.Read: - case Kind.Edit: - case Kind.Execute: - case Kind.Search: - case Kind.Delete: - case Kind.Move: - case Kind.Think: - case Kind.Fetch: - case Kind.SwitchMode: - case Kind.Other: - return kind as acp.ToolKind; - case Kind.Agent: - return 'think'; - case Kind.Plan: - case Kind.Communicate: - default: - return 'other'; - } -} - -function buildAvailableModes(isPlanEnabled: boolean): acp.SessionMode[] { - const modes: acp.SessionMode[] = [ - { - id: ApprovalMode.DEFAULT, - name: 'Default', - description: 'Prompts for approval', - }, - { - id: ApprovalMode.AUTO_EDIT, - name: 'Auto Edit', - description: 'Auto-approves edit tools', - }, - { - id: ApprovalMode.YOLO, - name: 'YOLO', - description: 'Auto-approves all tools', - }, - ]; - - if (isPlanEnabled) { - modes.push({ - id: ApprovalMode.PLAN, - name: 'Plan', - description: 'Read-only mode', - }); - } - - return modes; -} - -function buildAvailableModels( - config: Config, - settings: LoadedSettings, -): { - availableModels: Array<{ - modelId: string; - name: string; - description?: string; - }>; - currentModelId: string; -} { - const preferredModel = config.getModel() || DEFAULT_GEMINI_MODEL_AUTO; - const shouldShowPreviewModels = config.getHasAccessToPreviewModel(); - const useGemini31 = config.getGemini31LaunchedSync?.() ?? false; - const useGemini31FlashLite = - config.getGemini31FlashLiteLaunchedSync?.() ?? false; - const selectedAuthType = settings.merged.security.auth.selectedType; - const useCustomToolModel = - useGemini31 && selectedAuthType === AuthType.USE_GEMINI; - - // --- DYNAMIC PATH --- - if ( - config.getExperimentalDynamicModelConfiguration?.() === true && - config.getModelConfigService - ) { - const options = config.getModelConfigService().getAvailableModelOptions({ - useGemini3_1: useGemini31, - useGemini3_1FlashLite: useGemini31FlashLite, - useCustomTools: useCustomToolModel, - hasAccessToPreview: shouldShowPreviewModels, - }); - - return { - availableModels: options, - currentModelId: preferredModel, - }; - } - - // --- LEGACY PATH --- - const mainOptions = [ - { - value: DEFAULT_GEMINI_MODEL_AUTO, - title: getDisplayString(DEFAULT_GEMINI_MODEL_AUTO), - description: - 'Let Gemini CLI decide the best model for the task: gemini-2.5-pro, gemini-2.5-flash', - }, - ]; - - if (shouldShowPreviewModels) { - mainOptions.unshift({ - value: PREVIEW_GEMINI_MODEL_AUTO, - title: getDisplayString(PREVIEW_GEMINI_MODEL_AUTO), - description: useGemini31 - ? 'Let Gemini CLI decide the best model for the task: gemini-3.1-pro, gemini-3-flash' - : 'Let Gemini CLI decide the best model for the task: gemini-3-pro, gemini-3-flash', - }); - } - - const manualOptions = [ - { - value: DEFAULT_GEMINI_MODEL, - title: getDisplayString(DEFAULT_GEMINI_MODEL), - }, - { - value: DEFAULT_GEMINI_FLASH_MODEL, - title: getDisplayString(DEFAULT_GEMINI_FLASH_MODEL), - }, - { - value: DEFAULT_GEMINI_FLASH_LITE_MODEL, - title: getDisplayString(DEFAULT_GEMINI_FLASH_LITE_MODEL), - }, - ]; - - if (shouldShowPreviewModels) { - const previewProModel = useGemini31 - ? PREVIEW_GEMINI_3_1_MODEL - : PREVIEW_GEMINI_MODEL; - - const previewProValue = useCustomToolModel - ? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL - : previewProModel; - - const previewOptions = [ - { - value: previewProValue, - title: getDisplayString(previewProModel), - }, - { - value: PREVIEW_GEMINI_FLASH_MODEL, - title: getDisplayString(PREVIEW_GEMINI_FLASH_MODEL), - }, - ]; - - if (useGemini31FlashLite) { - previewOptions.push({ - value: PREVIEW_GEMINI_3_1_FLASH_LITE_MODEL, - title: getDisplayString(PREVIEW_GEMINI_3_1_FLASH_LITE_MODEL), - }); - } - - manualOptions.unshift(...previewOptions); - } - - const scaleOptions = ( - options: Array<{ value: string; title: string; description?: string }>, - ) => - options.map((o) => ({ - modelId: o.value, - name: o.title, - description: o.description, - })); - - return { - availableModels: [ - ...scaleOptions(mainOptions), - ...scaleOptions(manualOptions), - ], - currentModelId: preferredModel, - }; -} diff --git a/packages/cli/src/acp/acpSessionManager.test.ts b/packages/cli/src/acp/acpSessionManager.test.ts new file mode 100644 index 0000000000..81a556a952 --- /dev/null +++ b/packages/cli/src/acp/acpSessionManager.test.ts @@ -0,0 +1,386 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, + type Mocked, +} from 'vitest'; +import { AcpSessionManager } from './acpSessionManager.js'; +import type * as acp from '@agentclientprotocol/sdk'; +import { + AuthType, + type Config, + type MessageBus, + type Storage, +} from '@google/gemini-cli-core'; +import type { LoadedSettings } from '../config/settings.js'; +import { loadCliConfig, type CliArgs } from '../config/config.js'; +import { loadSettings } from '../config/settings.js'; + +vi.mock('../config/config.js', () => ({ + loadCliConfig: vi.fn(), +})); + +vi.mock('../config/settings.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + loadSettings: vi.fn(), + }; +}); + +const startAutoMemoryIfEnabledMock = vi.fn(); +vi.mock('../utils/autoMemory.js', () => ({ + startAutoMemoryIfEnabled: (config: Config) => + startAutoMemoryIfEnabledMock(config), +})); + +describe('AcpSessionManager', () => { + let mockConfig: Mocked; + let mockSettings: Mocked; + let mockArgv: CliArgs; + let mockConnection: Mocked; + let manager: AcpSessionManager; + + beforeEach(() => { + mockConfig = { + refreshAuth: vi.fn(), + initialize: vi.fn(), + waitForMcpInit: vi.fn(), + getFileSystemService: vi.fn(), + setFileSystemService: vi.fn(), + getContentGeneratorConfig: vi.fn(), + getActiveModel: vi.fn().mockReturnValue('gemini-pro'), + getModel: vi.fn().mockReturnValue('gemini-pro'), + getGeminiClient: vi.fn().mockReturnValue({ + startChat: vi.fn().mockResolvedValue({}), + }), + getMessageBus: vi.fn().mockReturnValue({ + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + }), + getApprovalMode: vi.fn().mockReturnValue('default'), + isPlanEnabled: vi.fn().mockReturnValue(true), + getGemini31LaunchedSync: vi.fn().mockReturnValue(false), + getHasAccessToPreviewModel: vi.fn().mockReturnValue(false), + getCheckpointingEnabled: vi.fn().mockReturnValue(false), + getDisableAlwaysAllow: vi.fn().mockReturnValue(false), + validatePathAccess: vi.fn().mockReturnValue(null), + getWorkspaceContext: vi.fn().mockReturnValue({ + addReadOnlyPath: vi.fn(), + }), + getPolicyEngine: vi.fn().mockReturnValue({ + addRule: vi.fn(), + }), + messageBus: { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + } as unknown as MessageBus, + storage: { + getWorkspaceAutoSavedPolicyPath: vi.fn(), + getAutoSavedPolicyPath: vi.fn(), + } as unknown as Storage, + + get config() { + return this; + }, + } as unknown as Mocked; + mockSettings = { + merged: { + security: { auth: { selectedType: 'login_with_google' } }, + mcpServers: {}, + }, + setValue: vi.fn(), + } as unknown as Mocked; + mockArgv = {} as unknown as CliArgs; + mockConnection = { + sessionUpdate: vi.fn(), + requestPermission: vi.fn(), + } as unknown as Mocked; + + (loadCliConfig as unknown as Mock).mockResolvedValue(mockConfig); + (loadSettings as unknown as Mock).mockImplementation(() => ({ + merged: { + security: { + auth: { selectedType: AuthType.LOGIN_WITH_GOOGLE }, + enablePermanentToolApproval: true, + }, + mcpServers: {}, + }, + setValue: vi.fn(), + })); + + manager = new AcpSessionManager(mockSettings, mockArgv, mockConnection); + vi.mock('node:crypto', () => ({ + randomUUID: () => 'test-session-id', + })); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should create a new session', async () => { + vi.useFakeTimers(); + mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ + apiKey: 'test-key', + }); + const response = await manager.newSession( + { + cwd: '/tmp', + mcpServers: [], + }, + {}, + ); + + expect(response.sessionId).toBe('test-session-id'); + expect(loadCliConfig).toHaveBeenCalled(); + expect(mockConfig.initialize).toHaveBeenCalled(); + expect(mockConfig.getGeminiClient).toHaveBeenCalled(); + + // Verify deferred call (sendAvailableCommands) + await vi.runAllTimersAsync(); + expect(mockConnection.sessionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'available_commands_update', + }), + }), + ); + vi.useRealTimers(); + }); + + it('should return modes without plan mode when plan is disabled', async () => { + mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ + apiKey: 'test-key', + }); + mockConfig.isPlanEnabled = vi.fn().mockReturnValue(false); + mockConfig.getApprovalMode = vi.fn().mockReturnValue('default'); + + const response = await manager.newSession( + { + cwd: '/tmp', + mcpServers: [], + }, + {}, + ); + + expect(response.modes).toEqual({ + availableModes: [ + { id: 'default', name: 'Default', description: 'Prompts for approval' }, + { + id: 'autoEdit', + name: 'Auto Edit', + description: 'Auto-approves edit tools', + }, + { id: 'yolo', name: 'YOLO', description: 'Auto-approves all tools' }, + ], + currentModeId: 'default', + }); + }); + + it('should include preview models when user has access', async () => { + mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ + apiKey: 'test-key', + }); + mockConfig.getHasAccessToPreviewModel = vi.fn().mockReturnValue(true); + mockConfig.getGemini31LaunchedSync = vi.fn().mockReturnValue(true); + + const response = await manager.newSession( + { + cwd: '/tmp', + mcpServers: [], + }, + {}, + ); + + expect(response.models?.availableModels).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + modelId: 'auto-gemini-3', + name: expect.stringContaining('Auto'), + }), + ]), + ); + }); + + it('should include gemini-3.1-flash-lite when useGemini31FlashLite is true', async () => { + mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ + apiKey: 'test-key', + }); + mockConfig.getHasAccessToPreviewModel = vi.fn().mockReturnValue(true); + mockConfig.getGemini31LaunchedSync = vi.fn().mockReturnValue(true); + mockConfig.getGemini31FlashLiteLaunchedSync = vi.fn().mockReturnValue(true); + + const response = await manager.newSession( + { + cwd: '/tmp', + mcpServers: [], + }, + {}, + ); + + expect(response.models?.availableModels).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + modelId: 'gemini-3.1-flash-lite-preview', + name: 'gemini-3.1-flash-lite-preview', + }), + ]), + ); + }); + + it('should return modes with plan mode when plan is enabled', async () => { + mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ + apiKey: 'test-key', + }); + mockConfig.isPlanEnabled = vi.fn().mockReturnValue(true); + mockConfig.getApprovalMode = vi.fn().mockReturnValue('plan'); + + const response = await manager.newSession( + { + cwd: '/tmp', + mcpServers: [], + }, + {}, + ); + + expect(response.modes).toEqual({ + availableModes: [ + { id: 'default', name: 'Default', description: 'Prompts for approval' }, + { + id: 'autoEdit', + name: 'Auto Edit', + description: 'Auto-approves edit tools', + }, + { id: 'yolo', name: 'YOLO', description: 'Auto-approves all tools' }, + { id: 'plan', name: 'Plan', description: 'Read-only mode' }, + ], + currentModeId: 'plan', + }); + }); + + it('should fail session creation if Gemini API key is missing', async () => { + (loadSettings as unknown as Mock).mockImplementation(() => ({ + merged: { + security: { auth: { selectedType: AuthType.USE_GEMINI } }, + mcpServers: {}, + }, + setValue: vi.fn(), + })); + mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ + apiKey: undefined, + }); + + await expect( + manager.newSession( + { + cwd: '/tmp', + mcpServers: [], + }, + {}, + ), + ).rejects.toMatchObject({ + message: 'Gemini API key is missing or not configured.', + }); + }); + + it('should create a new session with mcp servers', async () => { + mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ + apiKey: 'test-key', + }); + const mcpServers = [ + { + name: 'test-server', + command: 'node', + args: ['server.js'], + env: [{ name: 'KEY', value: 'VALUE' }], + }, + ]; + + await manager.newSession( + { + cwd: '/tmp', + mcpServers, + }, + {}, + ); + + expect(loadCliConfig).toHaveBeenCalledWith( + expect.objectContaining({ + mcpServers: expect.objectContaining({ + 'test-server': expect.objectContaining({ + command: 'node', + args: ['server.js'], + env: { KEY: 'VALUE' }, + }), + }), + }), + 'test-session-id', + mockArgv, + { cwd: '/tmp' }, + ); + }); + + it('should handle authentication failure gracefully', async () => { + mockConfig.refreshAuth.mockRejectedValue(new Error('Auth failed')); + + await expect( + manager.newSession( + { + cwd: '/tmp', + mcpServers: [], + }, + {}, + ), + ).rejects.toMatchObject({ + message: 'Auth failed', + }); + }); + + it('should initialize file system service if client supports it', async () => { + mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ + apiKey: 'test-key', + }); + manager.setClientCapabilities({ + fs: { readTextFile: true, writeTextFile: true }, + }); + + await manager.newSession( + { + cwd: '/tmp', + mcpServers: [], + }, + {}, + ); + + expect(mockConfig.setFileSystemService).toHaveBeenCalled(); + }); + + it('should start auto memory for new ACP sessions', async () => { + mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ + apiKey: 'test-key', + }); + + await manager.newSession( + { + cwd: '/tmp', + mcpServers: [], + }, + {}, + ); + + expect(startAutoMemoryIfEnabledMock).toHaveBeenCalledWith(mockConfig); + }); +}); diff --git a/packages/cli/src/acp/acpSessionManager.ts b/packages/cli/src/acp/acpSessionManager.ts new file mode 100644 index 0000000000..828dae9b14 --- /dev/null +++ b/packages/cli/src/acp/acpSessionManager.ts @@ -0,0 +1,322 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + type Config, + AuthType, + MCPServerConfig, + debugLogger, + startupProfiler, + convertSessionToClientHistory, + createPolicyUpdater, +} from '@google/gemini-cli-core'; +import * as acp from '@agentclientprotocol/sdk'; +import { randomUUID } from 'node:crypto'; +import { loadSettings, type LoadedSettings } from '../config/settings.js'; +import { SessionSelector } from '../utils/sessionUtils.js'; +import { Session } from './acpSession.js'; +import { AcpFileSystemService } from './acpFileSystemService.js'; +import { getAcpErrorMessage } from './acpErrors.js'; +import { buildAvailableModels, buildAvailableModes } from './acpUtils.js'; +import { loadCliConfig, type CliArgs } from '../config/config.js'; +import { startAutoMemoryIfEnabled } from '../utils/autoMemory.js'; + +export interface AuthDetails { + apiKey?: string; + baseUrl?: string; + customHeaders?: Record; +} + +export class AcpSessionManager { + private sessions: Map = new Map(); + private clientCapabilities: acp.ClientCapabilities | undefined; + + constructor( + private settings: LoadedSettings, + private argv: CliArgs, + private connection: acp.AgentSideConnection, + ) {} + + setClientCapabilities(capabilities: acp.ClientCapabilities) { + this.clientCapabilities = capabilities; + } + + getSession(sessionId: string): Session | undefined { + return this.sessions.get(sessionId); + } + + async newSession( + { cwd, mcpServers }: acp.NewSessionRequest, + authDetails: AuthDetails, + ): Promise { + const sessionId = randomUUID(); + const loadedSettings = loadSettings(cwd); + const config = await this.newSessionConfig( + sessionId, + cwd, + mcpServers, + loadedSettings, + ); + + const authType = + loadedSettings.merged.security.auth.selectedType || AuthType.USE_GEMINI; + + let isAuthenticated = false; + let authErrorMessage = ''; + try { + await config.refreshAuth( + authType, + authDetails.apiKey, + authDetails.baseUrl, + authDetails.customHeaders, + ); + isAuthenticated = true; + + // Extra validation for Gemini API key + const contentGeneratorConfig = config.getContentGeneratorConfig(); + if ( + authType === AuthType.USE_GEMINI && + (!contentGeneratorConfig || !contentGeneratorConfig.apiKey) + ) { + isAuthenticated = false; + authErrorMessage = 'Gemini API key is missing or not configured.'; + } + } catch (e) { + isAuthenticated = false; + authErrorMessage = getAcpErrorMessage(e); + debugLogger.error( + `Authentication failed: ${e instanceof Error ? e.stack : e}`, + ); + } + + if (!isAuthenticated) { + throw new acp.RequestError( + -32000, + authErrorMessage || 'Authentication required.', + ); + } + + if (this.clientCapabilities?.fs) { + const acpFileSystemService = new AcpFileSystemService( + this.connection, + sessionId, + this.clientCapabilities.fs, + config.getFileSystemService(), + cwd, + ); + config.setFileSystemService(acpFileSystemService); + } + + await config.initialize(); + startupProfiler.flush(config); + startAutoMemoryIfEnabled(config); + + const geminiClient = config.getGeminiClient(); + + const chat = await geminiClient.startChat(); + + const session = new Session( + sessionId, + chat, + config, + this.connection, + this.settings, + ); + this.sessions.set(sessionId, session); + + setTimeout(() => { + // eslint-disable-next-line @typescript-eslint/no-floating-promises + session.sendAvailableCommands(); + }, 0); + + const { availableModels, currentModelId } = buildAvailableModels( + config, + loadedSettings, + ); + + const response = { + sessionId, + modes: { + availableModes: buildAvailableModes(config.isPlanEnabled()), + currentModeId: config.getApprovalMode(), + }, + models: { + availableModels, + currentModelId, + }, + }; + return response; + } + + async loadSession( + { sessionId, cwd, mcpServers }: acp.LoadSessionRequest, + authDetails: AuthDetails, + ): Promise { + const config = await this.initializeSessionConfig( + sessionId, + cwd, + mcpServers, + authDetails, + ); + + const sessionSelector = new SessionSelector(config.storage); + + const { sessionData, sessionPath } = + await sessionSelector.resolveSession(sessionId); + + const clientHistory = convertSessionToClientHistory(sessionData.messages); + + const geminiClient = config.getGeminiClient(); + await geminiClient.initialize(); + await geminiClient.resumeChat(clientHistory, { + conversation: sessionData, + filePath: sessionPath, + }); + + const session = new Session( + sessionId, + geminiClient.getChat(), + config, + this.connection, + this.settings, + ); + this.sessions.set(sessionId, session); + + // Stream history back to client + // eslint-disable-next-line @typescript-eslint/no-floating-promises + session.streamHistory(sessionData.messages); + + setTimeout(() => { + // eslint-disable-next-line @typescript-eslint/no-floating-promises + session.sendAvailableCommands(); + }, 0); + + const { availableModels, currentModelId } = buildAvailableModels( + config, + this.settings, + ); + + const response = { + modes: { + availableModes: buildAvailableModes(config.isPlanEnabled()), + currentModeId: config.getApprovalMode(), + }, + models: { + availableModels, + currentModelId, + }, + }; + return response; + } + + private async initializeSessionConfig( + sessionId: string, + cwd: string, + mcpServers: acp.McpServer[], + authDetails: AuthDetails, + ): Promise { + const selectedAuthType = this.settings.merged.security.auth.selectedType; + if (!selectedAuthType) { + throw acp.RequestError.authRequired(); + } + + // 1. Create config WITHOUT initializing it (no MCP servers started yet) + const config = await this.newSessionConfig(sessionId, cwd, mcpServers); + + // 2. Authenticate BEFORE initializing configuration or starting MCP servers. + // This satisfies the security requirement to verify the user before executing + // potentially unsafe server definitions. + try { + await config.refreshAuth( + selectedAuthType, + authDetails.apiKey, + authDetails.baseUrl, + authDetails.customHeaders, + ); + } catch (e) { + debugLogger.error(`Authentication failed: ${e}`); + throw acp.RequestError.authRequired(); + } + + // 3. Set the ACP FileSystemService (if supported) before config initialization + if (this.clientCapabilities?.fs) { + const acpFileSystemService = new AcpFileSystemService( + this.connection, + sessionId, + this.clientCapabilities.fs, + config.getFileSystemService(), + cwd, + ); + config.setFileSystemService(acpFileSystemService); + } + + // 4. Now that we are authenticated, it is safe to initialize the config + // which starts the MCP servers and other heavy resources. + await config.initialize(); + startupProfiler.flush(config); + startAutoMemoryIfEnabled(config); + + return config; + } + + async newSessionConfig( + sessionId: string, + cwd: string, + mcpServers: acp.McpServer[], + loadedSettings?: LoadedSettings, + ): Promise { + const currentSettings = loadedSettings || this.settings; + const mergedMcpServers = { ...currentSettings.merged.mcpServers }; + + for (const server of mcpServers) { + if ( + 'type' in server && + (server.type === 'sse' || server.type === 'http') + ) { + // HTTP or SSE MCP server + const headers = Object.fromEntries( + server.headers.map(({ name, value }) => [name, value]), + ); + mergedMcpServers[server.name] = new MCPServerConfig( + undefined, // command + undefined, // args + undefined, // env + undefined, // cwd + server.type === 'sse' ? server.url : undefined, // url (sse) + server.type === 'http' ? server.url : undefined, // httpUrl + headers, + ); + } else if ('command' in server) { + // Stdio MCP server + const env: Record = {}; + for (const { name: envName, value } of server.env) { + env[envName] = value; + } + mergedMcpServers[server.name] = new MCPServerConfig( + server.command, + server.args, + env, + cwd, + ); + } + } + + const settings = { + ...currentSettings.merged, + mcpServers: mergedMcpServers, + }; + + const config = await loadCliConfig(settings, sessionId, this.argv, { cwd }); + + createPolicyUpdater( + config.getPolicyEngine(), + config.messageBus, + config.storage, + ); + + return config; + } +} diff --git a/packages/cli/src/acp/acpStdioTransport.ts b/packages/cli/src/acp/acpStdioTransport.ts new file mode 100644 index 0000000000..59198dee62 --- /dev/null +++ b/packages/cli/src/acp/acpStdioTransport.ts @@ -0,0 +1,35 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type Config, createWorkingStdio } from '@google/gemini-cli-core'; +import { runExitCleanup } from '../utils/cleanup.js'; +import * as acp from '@agentclientprotocol/sdk'; +import { Readable, Writable } from 'node:stream'; +import type { LoadedSettings } from '../config/settings.js'; +import type { CliArgs } from '../config/config.js'; +import { GeminiAgent } from './acpRpcDispatcher.js'; + +export async function runAcpClient( + config: Config, + settings: LoadedSettings, + argv: CliArgs, +) { + const { stdout: workingStdout } = createWorkingStdio(); + const stdout = Writable.toWeb(workingStdout) as WritableStream; + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const stdin = Readable.toWeb(process.stdin) as ReadableStream; + + const stream = acp.ndJsonStream(stdout, stdin); + const connection = new acp.AgentSideConnection( + (connection) => new GeminiAgent(config, settings, argv, connection), + stream, + ); + + // SIGTERM/SIGINT handlers (in sdk.ts) don't fire when stdin closes. + // We must explicitly await the connection close to flush telemetry. + // Use finally() to ensure cleanup runs even on stream errors. + await connection.closed.finally(runExitCleanup); +} diff --git a/packages/cli/src/acp/acpUtils.ts b/packages/cli/src/acp/acpUtils.ts new file mode 100644 index 0000000000..403227628e --- /dev/null +++ b/packages/cli/src/acp/acpUtils.ts @@ -0,0 +1,373 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + type Config, + type ToolResult, + type ToolCallConfirmationDetails, + Kind, + ApprovalMode, + DEFAULT_GEMINI_MODEL_AUTO, + PREVIEW_GEMINI_MODEL_AUTO, + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_FLASH_LITE_MODEL, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, + PREVIEW_GEMINI_FLASH_MODEL, + PREVIEW_GEMINI_3_1_FLASH_LITE_MODEL, + getDisplayString, + AuthType, + ToolConfirmationOutcome, +} from '@google/gemini-cli-core'; +import type * as acp from '@agentclientprotocol/sdk'; +import { z } from 'zod'; +import type { LoadedSettings } from '../config/settings.js'; + +export function hasMeta( + obj: unknown, +): obj is { _meta?: Record } { + return typeof obj === 'object' && obj !== null && '_meta' in obj; +} + +export const RequestPermissionResponseSchema = z.object({ + outcome: z.discriminatedUnion('outcome', [ + z.object({ outcome: z.literal('cancelled') }), + z.object({ + outcome: z.literal('selected'), + optionId: z.string(), + }), + ]), +}); + +export function toToolCallContent( + toolResult: ToolResult, +): acp.ToolCallContent | null { + if (toolResult.error?.message) { + throw new Error(toolResult.error.message); + } + + if (toolResult.returnDisplay) { + if (typeof toolResult.returnDisplay === 'string') { + return { + type: 'content', + content: { type: 'text', text: toolResult.returnDisplay }, + }; + } else { + if ('fileName' in toolResult.returnDisplay) { + return { + type: 'diff', + path: + toolResult.returnDisplay.filePath ?? + toolResult.returnDisplay.fileName, + oldText: toolResult.returnDisplay.originalContent, + newText: toolResult.returnDisplay.newContent, + _meta: { + kind: !toolResult.returnDisplay.originalContent + ? 'add' + : toolResult.returnDisplay.newContent === '' + ? 'delete' + : 'modify', + }, + }; + } + return null; + } + } else { + return null; + } +} + +const basicPermissionOptions = [ + { + optionId: ToolConfirmationOutcome.ProceedOnce, + name: 'Allow', + kind: 'allow_once', + }, + { + optionId: ToolConfirmationOutcome.Cancel, + name: 'Reject', + kind: 'reject_once', + }, +] as const; + +export function toPermissionOptions( + confirmation: ToolCallConfirmationDetails, + config: Config, + enablePermanentToolApproval: boolean = false, +): acp.PermissionOption[] { + const disableAlwaysAllow = config.getDisableAlwaysAllow(); + const options: acp.PermissionOption[] = []; + + if (!disableAlwaysAllow) { + switch (confirmation.type) { + case 'edit': + options.push({ + optionId: ToolConfirmationOutcome.ProceedAlways, + name: 'Allow for this session', + kind: 'allow_always', + }); + if (enablePermanentToolApproval) { + options.push({ + optionId: ToolConfirmationOutcome.ProceedAlwaysAndSave, + name: 'Allow for this file in all future sessions', + kind: 'allow_always', + }); + } + break; + case 'exec': + options.push({ + optionId: ToolConfirmationOutcome.ProceedAlways, + name: 'Allow for this session', + kind: 'allow_always', + }); + if (enablePermanentToolApproval) { + options.push({ + optionId: ToolConfirmationOutcome.ProceedAlwaysAndSave, + name: 'Allow this command for all future sessions', + kind: 'allow_always', + }); + } + break; + case 'mcp': + options.push( + { + optionId: ToolConfirmationOutcome.ProceedAlwaysServer, + name: 'Allow all server tools for this session', + kind: 'allow_always', + }, + { + optionId: ToolConfirmationOutcome.ProceedAlwaysTool, + name: 'Allow tool for this session', + kind: 'allow_always', + }, + ); + if (enablePermanentToolApproval) { + options.push({ + optionId: ToolConfirmationOutcome.ProceedAlwaysAndSave, + name: 'Allow tool for all future sessions', + kind: 'allow_always', + }); + } + break; + case 'info': + options.push({ + optionId: ToolConfirmationOutcome.ProceedAlways, + name: 'Allow for this session', + kind: 'allow_always', + }); + if (enablePermanentToolApproval) { + options.push({ + optionId: ToolConfirmationOutcome.ProceedAlwaysAndSave, + name: 'Allow for all future sessions', + kind: 'allow_always', + }); + } + break; + case 'ask_user': + case 'exit_plan_mode': + // askuser and exit_plan_mode don't need "always allow" options + break; + default: + // No "always allow" options for other types + break; + } + } + + options.push(...basicPermissionOptions); + + // Exhaustive check + switch (confirmation.type) { + case 'edit': + case 'exec': + case 'mcp': + case 'info': + case 'ask_user': + case 'exit_plan_mode': + case 'sandbox_expansion': + break; + default: { + const unreachable: never = confirmation; + throw new Error(`Unexpected: ${unreachable}`); + } + } + + return options; +} + +export function toAcpToolKind(kind: Kind): acp.ToolKind { + switch (kind) { + case Kind.Read: + case Kind.Edit: + case Kind.Execute: + case Kind.Search: + case Kind.Delete: + case Kind.Move: + case Kind.Think: + case Kind.Fetch: + case Kind.SwitchMode: + case Kind.Other: + return kind as acp.ToolKind; + case Kind.Agent: + return 'think'; + case Kind.Plan: + case Kind.Communicate: + default: + return 'other'; + } +} + +export function buildAvailableModes(isPlanEnabled: boolean): acp.SessionMode[] { + const modes: acp.SessionMode[] = [ + { + id: ApprovalMode.DEFAULT, + name: 'Default', + description: 'Prompts for approval', + }, + { + id: ApprovalMode.AUTO_EDIT, + name: 'Auto Edit', + description: 'Auto-approves edit tools', + }, + { + id: ApprovalMode.YOLO, + name: 'YOLO', + description: 'Auto-approves all tools', + }, + ]; + + if (isPlanEnabled) { + modes.push({ + id: ApprovalMode.PLAN, + name: 'Plan', + description: 'Read-only mode', + }); + } + + return modes; +} + +export function buildAvailableModels( + config: Config, + settings: LoadedSettings, +): { + availableModels: Array<{ + modelId: string; + name: string; + description?: string; + }>; + currentModelId: string; +} { + const preferredModel = config.getModel() || DEFAULT_GEMINI_MODEL_AUTO; + const shouldShowPreviewModels = config.getHasAccessToPreviewModel(); + const useGemini31 = config.getGemini31LaunchedSync?.() ?? false; + const useGemini31FlashLite = + config.getGemini31FlashLiteLaunchedSync?.() ?? false; + const selectedAuthType = settings.merged.security.auth.selectedType; + const useCustomToolModel = + useGemini31 && selectedAuthType === AuthType.USE_GEMINI; + + // --- DYNAMIC PATH --- + if ( + config.getExperimentalDynamicModelConfiguration?.() === true && + config.getModelConfigService + ) { + const options = config.getModelConfigService().getAvailableModelOptions({ + useGemini3_1: useGemini31, + useGemini3_1FlashLite: useGemini31FlashLite, + useCustomTools: useCustomToolModel, + hasAccessToPreview: shouldShowPreviewModels, + }); + + return { + availableModels: options, + currentModelId: preferredModel, + }; + } + + // --- LEGACY PATH --- + const mainOptions = [ + { + value: DEFAULT_GEMINI_MODEL_AUTO, + title: getDisplayString(DEFAULT_GEMINI_MODEL_AUTO), + description: + 'Let Gemini CLI decide the best model for the task: gemini-2.5-pro, gemini-2.5-flash', + }, + ]; + + if (shouldShowPreviewModels) { + mainOptions.unshift({ + value: PREVIEW_GEMINI_MODEL_AUTO, + title: getDisplayString(PREVIEW_GEMINI_MODEL_AUTO), + description: useGemini31 + ? 'Let Gemini CLI decide the best model for the task: gemini-3.1-pro, gemini-3-flash' + : 'Let Gemini CLI decide the best model for the task: gemini-3-pro, gemini-3-flash', + }); + } + + const manualOptions = [ + { + value: DEFAULT_GEMINI_MODEL, + title: getDisplayString(DEFAULT_GEMINI_MODEL), + }, + { + value: DEFAULT_GEMINI_FLASH_MODEL, + title: getDisplayString(DEFAULT_GEMINI_FLASH_MODEL), + }, + { + value: DEFAULT_GEMINI_FLASH_LITE_MODEL, + title: getDisplayString(DEFAULT_GEMINI_FLASH_LITE_MODEL), + }, + ]; + + if (shouldShowPreviewModels) { + const previewProModel = useGemini31 + ? PREVIEW_GEMINI_3_1_MODEL + : PREVIEW_GEMINI_MODEL; + + const previewProValue = useCustomToolModel + ? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL + : previewProModel; + + const previewOptions = [ + { + value: previewProValue, + title: getDisplayString(previewProModel), + }, + { + value: PREVIEW_GEMINI_FLASH_MODEL, + title: getDisplayString(PREVIEW_GEMINI_FLASH_MODEL), + }, + ]; + + if (useGemini31FlashLite) { + previewOptions.push({ + value: PREVIEW_GEMINI_3_1_FLASH_LITE_MODEL, + title: getDisplayString(PREVIEW_GEMINI_3_1_FLASH_LITE_MODEL), + }); + } + + manualOptions.unshift(...previewOptions); + } + + const scaleOptions = ( + options: Array<{ value: string; title: string; description?: string }>, + ) => + options.map((o) => ({ + modelId: o.value, + name: o.title, + description: o.description, + })); + + return { + availableModels: [ + ...scaleOptions(mainOptions), + ...scaleOptions(manualOptions), + ], + currentModelId: preferredModel, + }; +} diff --git a/packages/cli/src/acp/commands/commandRegistry.ts b/packages/cli/src/acp/commands/commandRegistry.ts index b689d5d602..e8af6c4048 100644 --- a/packages/cli/src/acp/commands/commandRegistry.ts +++ b/packages/cli/src/acp/commands/commandRegistry.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ diff --git a/packages/cli/src/acp/commands/extensions.ts b/packages/cli/src/acp/commands/extensions.ts index 7ebe922402..52f17c48b3 100644 --- a/packages/cli/src/acp/commands/extensions.ts +++ b/packages/cli/src/acp/commands/extensions.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ diff --git a/packages/cli/src/acp/commands/init.ts b/packages/cli/src/acp/commands/init.ts index a9104aa84f..954be123d0 100644 --- a/packages/cli/src/acp/commands/init.ts +++ b/packages/cli/src/acp/commands/init.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ diff --git a/packages/cli/src/acp/commands/memory.ts b/packages/cli/src/acp/commands/memory.ts index 8e990e12a7..bb91e5dbdd 100644 --- a/packages/cli/src/acp/commands/memory.ts +++ b/packages/cli/src/acp/commands/memory.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ diff --git a/packages/cli/src/acp/commands/restore.test.ts b/packages/cli/src/acp/commands/restore.test.ts index 681e42a491..c9bfaf5063 100644 --- a/packages/cli/src/acp/commands/restore.test.ts +++ b/packages/cli/src/acp/commands/restore.test.ts @@ -157,10 +157,15 @@ describe('RestoreCommand', () => { describe('ListCheckpointsCommand', () => { let context: CommandContext; let listCommand: ListCheckpointsCommand; + let mockReaddir: Mock<(path: string) => Promise>; beforeEach(() => { vi.resetAllMocks(); listCommand = new ListCheckpointsCommand(); + mockReaddir = vi.mocked(fs.readdir) as unknown as Mock< + (path: string) => Promise + >; + context = { agentContext: { config: { @@ -186,10 +191,7 @@ describe('ListCheckpointsCommand', () => { }); it('returns "No checkpoints found." when no .json checkpoints exist', async () => { - vi.mocked(fs.readdir).mockResolvedValue([ - 'not-a-checkpoint.txt', - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ] as any); + mockReaddir.mockResolvedValue(['not-a-checkpoint.txt']); const response = await listCommand.execute(context); @@ -198,7 +200,7 @@ describe('ListCheckpointsCommand', () => { it('ignores error when mkdir fails', async () => { vi.mocked(fs.mkdir).mockRejectedValue(new Error('mkdir fail')); - vi.mocked(fs.readdir).mockResolvedValue([]); + mockReaddir.mockResolvedValue([]); const response = await listCommand.execute(context); @@ -207,11 +209,7 @@ describe('ListCheckpointsCommand', () => { }); it('formats checkpoint summary output from checkpoint metadata', async () => { - vi.mocked(fs.readdir).mockResolvedValue([ - 'cp1.json', - 'cp2.json', - // eslint-disable-next-line @typescript-eslint/no-explicit-any - ] as any); + mockReaddir.mockResolvedValue(['cp1.json', 'cp2.json']); vi.mocked(getCheckpointInfoList).mockReturnValue([ { messageId: 'id1', checkpoint: 'cp1' }, { messageId: 'id2', checkpoint: 'cp2' }, @@ -226,8 +224,7 @@ describe('ListCheckpointsCommand', () => { }); it('handles empty checkpoint info list', async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - vi.mocked(fs.readdir).mockResolvedValue(['some.json'] as any); + mockReaddir.mockResolvedValue(['some.json']); vi.mocked(getCheckpointInfoList).mockReturnValue([]); const response = await listCommand.execute(context); @@ -236,7 +233,7 @@ describe('ListCheckpointsCommand', () => { }); it('returns generic unexpected error message on failures', async () => { - vi.mocked(fs.readdir).mockRejectedValue(new Error('Readdir fail')); + mockReaddir.mockRejectedValue(new Error('Readdir fail')); const response = await listCommand.execute(context); diff --git a/packages/cli/src/acp/commands/restore.ts b/packages/cli/src/acp/commands/restore.ts index 4ffc5dfba2..e45dec67e2 100644 --- a/packages/cli/src/acp/commands/restore.ts +++ b/packages/cli/src/acp/commands/restore.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ diff --git a/packages/cli/src/acp/commands/types.ts b/packages/cli/src/acp/commands/types.ts index 6f5656bd89..a175d5fc82 100644 --- a/packages/cli/src/acp/commands/types.ts +++ b/packages/cli/src/acp/commands/types.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2025 Google LLC + * Copyright 2026 Google LLC * SPDX-License-Identifier: Apache-2.0 */ diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 727a0af9c7..846be5890b 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -76,7 +76,7 @@ import { type InitializationResult, } from './core/initializer.js'; import { validateAuthMethod } from './config/auth.js'; -import { runAcpClient } from './acp/acpClient.js'; +import { runAcpClient } from './acp/acpStdioTransport.js'; import { validateNonInteractiveAuth } from './validateNonInterActiveAuth.js'; import { appEvents, AppEvent } from './utils/events.js'; import { SessionError, SessionSelector } from './utils/sessionUtils.js'; diff --git a/packages/cli/src/gemini_cleanup.test.tsx b/packages/cli/src/gemini_cleanup.test.tsx index b76a132234..75e9cf3959 100644 --- a/packages/cli/src/gemini_cleanup.test.tsx +++ b/packages/cli/src/gemini_cleanup.test.tsx @@ -157,7 +157,7 @@ vi.mock('./utils/cleanup.js', async (importOriginal) => { }; }); -vi.mock('./acp/acpClient.js', () => ({ +vi.mock('./acp/acpStdioTransport.js', () => ({ runAcpClient: vi.fn().mockResolvedValue(undefined), })); diff --git a/packages/core/src/utils/partUtils.ts b/packages/core/src/utils/partUtils.ts index b176d2ed21..e7a124eed6 100644 --- a/packages/core/src/utils/partUtils.ts +++ b/packages/core/src/utils/partUtils.ts @@ -179,3 +179,15 @@ export function appendToLastTextPart( return newPrompt; } + +/** + * Type guard to determine if a Part is a TextPart. + */ +export function isTextPart(part: unknown): part is { text: string } { + return ( + typeof part === 'object' && + part !== null && + 'text' in part && + typeof (part as { text: unknown }).text === 'string' + ); +}