From 3370644ffea707d918a3f7671e75ee813afe3634 Mon Sep 17 00:00:00 2001 From: Megha Bansal Date: Fri, 21 Nov 2025 21:08:33 +0530 Subject: [PATCH] Improved code coverage for cli/src/zed-integration (#13570) --- .../src/zed-integration/connection.test.ts | 216 +++++ .../zed-integration/fileSystemService.test.ts | 131 +++ .../zed-integration/zedIntegration.test.ts | 768 ++++++++++++++++++ .../cli/src/zed-integration/zedIntegration.ts | 8 +- 4 files changed, 1121 insertions(+), 2 deletions(-) create mode 100644 packages/cli/src/zed-integration/connection.test.ts create mode 100644 packages/cli/src/zed-integration/fileSystemService.test.ts create mode 100644 packages/cli/src/zed-integration/zedIntegration.test.ts diff --git a/packages/cli/src/zed-integration/connection.test.ts b/packages/cli/src/zed-integration/connection.test.ts new file mode 100644 index 0000000000..20bd709fca --- /dev/null +++ b/packages/cli/src/zed-integration/connection.test.ts @@ -0,0 +1,216 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { Connection, RequestError } from './connection.js'; +import { ReadableStream, WritableStream } from 'node:stream/web'; + +describe('Connection', () => { + let toPeer: WritableStream; + let fromPeer: ReadableStream; + let peerController: ReadableStreamDefaultController; + let receivedChunks: string[] = []; + let connection: Connection; + let handler: ReturnType; + + beforeEach(() => { + receivedChunks = []; + toPeer = new WritableStream({ + write(chunk) { + const str = new TextDecoder().decode(chunk); + receivedChunks.push(str); + }, + }); + + fromPeer = new ReadableStream({ + start(controller) { + peerController = controller; + }, + }); + + handler = vi.fn(); + connection = new Connection(handler, toPeer, fromPeer); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should send a request and receive a response', async () => { + const responsePromise = connection.sendRequest('testMethod', { + key: 'value', + }); + + // Verify request was sent + await vi.waitFor(() => { + expect(receivedChunks.length).toBeGreaterThan(0); + }); + const request = JSON.parse(receivedChunks[0]); + expect(request).toMatchObject({ + jsonrpc: '2.0', + method: 'testMethod', + params: { key: 'value' }, + }); + expect(request.id).toBeDefined(); + + // Simulate response + const response = { + jsonrpc: '2.0', + id: request.id, + result: { success: true }, + }; + peerController.enqueue( + new TextEncoder().encode(JSON.stringify(response) + '\n'), + ); + + const result = await responsePromise; + expect(result).toEqual({ success: true }); + }); + + it('should send a notification', async () => { + await connection.sendNotification('notifyMethod', { key: 'value' }); + + await vi.waitFor(() => { + expect(receivedChunks.length).toBeGreaterThan(0); + }); + const notification = JSON.parse(receivedChunks[0]); + expect(notification).toMatchObject({ + jsonrpc: '2.0', + method: 'notifyMethod', + params: { key: 'value' }, + }); + expect(notification.id).toBeUndefined(); + }); + + it('should handle incoming requests', async () => { + handler.mockResolvedValue({ result: 'ok' }); + + const request = { + jsonrpc: '2.0', + id: 1, + method: 'incomingMethod', + params: { foo: 'bar' }, + }; + peerController.enqueue( + new TextEncoder().encode(JSON.stringify(request) + '\n'), + ); + + // Wait for handler to be called and response to be written + await vi.waitFor(() => { + expect(handler).toHaveBeenCalledWith('incomingMethod', { foo: 'bar' }); + expect(receivedChunks.length).toBeGreaterThan(0); + }); + + const response = JSON.parse(receivedChunks[receivedChunks.length - 1]); + expect(response).toMatchObject({ + jsonrpc: '2.0', + id: 1, + result: { result: 'ok' }, + }); + }); + + it('should handle incoming notifications', async () => { + const notification = { + jsonrpc: '2.0', + method: 'incomingNotify', + params: { foo: 'bar' }, + }; + peerController.enqueue( + new TextEncoder().encode(JSON.stringify(notification) + '\n'), + ); + + // Wait for handler to be called + await vi.waitFor(() => { + expect(handler).toHaveBeenCalledWith('incomingNotify', { foo: 'bar' }); + }); + // Notifications don't send responses + expect(receivedChunks.length).toBe(0); + }); + + it('should handle request errors from handler', async () => { + handler.mockRejectedValue(new Error('Handler failed')); + + const request = { + jsonrpc: '2.0', + id: 2, + method: 'failMethod', + }; + peerController.enqueue( + new TextEncoder().encode(JSON.stringify(request) + '\n'), + ); + + await vi.waitFor(() => { + expect(receivedChunks.length).toBeGreaterThan(0); + }); + + const response = JSON.parse(receivedChunks[receivedChunks.length - 1]); + expect(response).toMatchObject({ + jsonrpc: '2.0', + id: 2, + error: { + code: -32603, + message: 'Internal error', + data: { details: 'Handler failed' }, + }, + }); + }); + + it('should handle RequestError from handler', async () => { + handler.mockRejectedValue(RequestError.methodNotFound('Unknown method')); + + const request = { + jsonrpc: '2.0', + id: 3, + method: 'unknown', + }; + peerController.enqueue( + new TextEncoder().encode(JSON.stringify(request) + '\n'), + ); + + await vi.waitFor(() => { + expect(receivedChunks.length).toBeGreaterThan(0); + }); + + const response = JSON.parse(receivedChunks[receivedChunks.length - 1]); + expect(response).toMatchObject({ + jsonrpc: '2.0', + id: 3, + error: { + code: -32601, + message: 'Method not found', + data: { details: 'Unknown method' }, + }, + }); + }); + + it('should handle response errors', async () => { + const responsePromise = connection.sendRequest('testMethod'); + + // Verify request was sent + await vi.waitFor(() => { + expect(receivedChunks.length).toBeGreaterThan(0); + }); + const request = JSON.parse(receivedChunks[0]); + + // Simulate error response + const response = { + jsonrpc: '2.0', + id: request.id, + error: { + code: -32000, + message: 'Custom error', + }, + }; + peerController.enqueue( + new TextEncoder().encode(JSON.stringify(response) + '\n'), + ); + + await expect(responsePromise).rejects.toMatchObject({ + code: -32000, + message: 'Custom error', + }); + }); +}); diff --git a/packages/cli/src/zed-integration/fileSystemService.test.ts b/packages/cli/src/zed-integration/fileSystemService.test.ts new file mode 100644 index 0000000000..e274df0618 --- /dev/null +++ b/packages/cli/src/zed-integration/fileSystemService.test.ts @@ -0,0 +1,131 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, type Mocked } from 'vitest'; +import { AcpFileSystemService } from './fileSystemService.js'; +import type { Client } from './acp.js'; +import type { FileSystemService } from '@google/gemini-cli-core'; + +describe('AcpFileSystemService', () => { + let mockClient: Mocked; + let mockFallback: Mocked; + let service: AcpFileSystemService; + + beforeEach(() => { + mockClient = { + requestPermission: vi.fn(), + sessionUpdate: vi.fn(), + writeTextFile: vi.fn(), + readTextFile: vi.fn(), + }; + mockFallback = { + readTextFile: vi.fn(), + writeTextFile: vi.fn(), + findFiles: vi.fn(), + }; + }); + + describe('readTextFile', () => { + it.each([ + { + capability: true, + desc: 'client if capability exists', + setup: () => { + mockClient.readTextFile.mockResolvedValue({ content: 'content' }); + }, + verify: () => { + expect(mockClient.readTextFile).toHaveBeenCalledWith({ + path: '/path/to/file', + sessionId: 'session-1', + line: null, + limit: null, + }); + expect(mockFallback.readTextFile).not.toHaveBeenCalled(); + }, + }, + { + capability: false, + desc: 'fallback if capability missing', + setup: () => { + mockFallback.readTextFile.mockResolvedValue('content'); + }, + verify: () => { + expect(mockFallback.readTextFile).toHaveBeenCalledWith( + '/path/to/file', + ); + expect(mockClient.readTextFile).not.toHaveBeenCalled(); + }, + }, + ])('should use $desc', async ({ capability, setup, verify }) => { + service = new AcpFileSystemService( + mockClient, + 'session-1', + { readTextFile: capability, writeTextFile: true }, + mockFallback, + ); + setup(); + + const result = await service.readTextFile('/path/to/file'); + + expect(result).toBe('content'); + verify(); + }); + }); + + describe('writeTextFile', () => { + it.each([ + { + capability: true, + desc: 'client if capability exists', + verify: () => { + expect(mockClient.writeTextFile).toHaveBeenCalledWith({ + path: '/path/to/file', + content: 'content', + sessionId: 'session-1', + }); + expect(mockFallback.writeTextFile).not.toHaveBeenCalled(); + }, + }, + { + capability: false, + desc: 'fallback if capability missing', + verify: () => { + expect(mockFallback.writeTextFile).toHaveBeenCalledWith( + '/path/to/file', + 'content', + ); + expect(mockClient.writeTextFile).not.toHaveBeenCalled(); + }, + }, + ])('should use $desc', async ({ capability, verify }) => { + service = new AcpFileSystemService( + mockClient, + 'session-1', + { writeTextFile: capability, readTextFile: true }, + mockFallback, + ); + + await service.writeTextFile('/path/to/file', 'content'); + + verify(); + }); + }); + + it('should always use fallback for findFiles', () => { + service = new AcpFileSystemService( + mockClient, + 'session-1', + { readTextFile: true, writeTextFile: true }, + mockFallback, + ); + mockFallback.findFiles.mockReturnValue(['file1', 'file2']); + + const result = service.findFiles('pattern', ['/path']); + + expect(mockFallback.findFiles).toHaveBeenCalledWith('pattern', ['/path']); + expect(result).toEqual(['file1', 'file2']); + }); +}); diff --git a/packages/cli/src/zed-integration/zedIntegration.test.ts b/packages/cli/src/zed-integration/zedIntegration.test.ts new file mode 100644 index 0000000000..48a741be20 --- /dev/null +++ b/packages/cli/src/zed-integration/zedIntegration.test.ts @@ -0,0 +1,768 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, + type Mocked, +} from 'vitest'; +import { GeminiAgent, Session } from './zedIntegration.js'; +import * as acp from './acp.js'; +import { + AuthType, + ToolConfirmationOutcome, + StreamEventType, + isWithinRoot, + ReadManyFilesTool, + type GeminiChat, + type Config, +} from '@google/gemini-cli-core'; +import { SettingScope, type LoadedSettings } from '../config/settings.js'; +import { loadCliConfig, type CliArgs } from '../config/config.js'; +import * as fs from 'node:fs/promises'; +import * as path from 'node:path'; + +vi.mock('../config/config.js', () => ({ + loadCliConfig: vi.fn(), +})); + +vi.mock('node:crypto', () => ({ + randomUUID: () => 'test-session-id', +})); + +vi.mock('node:fs/promises'); +vi.mock('node:path', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + resolve: vi.fn(), + }; +}); + +// Mock ReadManyFilesTool +vi.mock( + '@google/gemini-cli-core', + async ( + importOriginal: () => Promise, + ) => { + const actual = await importOriginal(); + return { + ...actual, + ReadManyFilesTool: vi.fn().mockImplementation(() => ({ + name: 'read_many_files', + kind: 'native', + build: vi.fn().mockReturnValue({ + getDescription: () => 'Read files', + toolLocations: () => [], + execute: vi.fn().mockResolvedValue({ + llmContent: ['--- file.txt ---\n\nFile content\n\n'], + }), + }), + })), + logToolCall: vi.fn(), + isWithinRoot: vi.fn().mockReturnValue(true), + }; + }, +); + +// Helper to create mock streams +// eslint-disable-next-line @typescript-eslint/no-explicit-any +async function* createMockStream(items: any[]) { + for (const item of items) { + yield item; + } +} + +describe('GeminiAgent', () => { + let mockConfig: Mocked>>; + let mockSettings: Mocked; + let mockArgv: CliArgs; + let mockClient: Mocked; + let agent: GeminiAgent; + + beforeEach(() => { + mockConfig = { + refreshAuth: vi.fn(), + initialize: vi.fn(), + getFileSystemService: vi.fn(), + setFileSystemService: vi.fn(), + getGeminiClient: vi.fn().mockReturnValue({ + startChat: vi.fn().mockResolvedValue({}), + }), + } as unknown as Mocked>>; + mockSettings = { + merged: { + security: { auth: { selectedType: 'login_with_google' } }, + mcpServers: {}, + }, + setValue: vi.fn(), + } as unknown as Mocked; + mockArgv = {} as unknown as CliArgs; + mockClient = { + sessionUpdate: vi.fn(), + } as unknown as Mocked; + + (loadCliConfig as unknown as Mock).mockResolvedValue(mockConfig); + + agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockClient); + }); + + it('should initialize correctly', async () => { + const response = await agent.initialize({ + clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } }, + protocolVersion: 1, + }); + + expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION); + expect(response.authMethods).toHaveLength(3); + expect(response.agentCapabilities.loadSession).toBe(false); + }); + + it('should authenticate correctly', async () => { + await agent.authenticate({ + methodId: AuthType.LOGIN_WITH_GOOGLE, + authMethod: { + id: AuthType.LOGIN_WITH_GOOGLE, + name: 'Log in with Google', + description: null, + }, + }); + + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.LOGIN_WITH_GOOGLE, + ); + expect(mockSettings.setValue).toHaveBeenCalledWith( + SettingScope.User, + 'security.auth.selectedType', + AuthType.LOGIN_WITH_GOOGLE, + ); + }); + + it('should create a new session', async () => { + const response = await agent.newSession({ + cwd: '/tmp', + mcpServers: [], + }); + + expect(response.sessionId).toBe('test-session-id'); + expect(loadCliConfig).toHaveBeenCalled(); + expect(mockConfig.initialize).toHaveBeenCalled(); + expect(mockConfig.getGeminiClient).toHaveBeenCalled(); + }); + + it('should create a new session with mcp servers', async () => { + const mcpServers = [ + { + name: 'test-server', + command: 'node', + args: ['server.js'], + env: [{ name: 'KEY', value: 'VALUE' }], + }, + ]; + + await agent.newSession({ + cwd: '/tmp', + mcpServers, + }); + + expect(loadCliConfig).toHaveBeenCalledWith( + expect.objectContaining({ + mcpServers: expect.objectContaining({ + 'test-server': expect.objectContaining({ + command: 'node', + args: ['server.js'], + env: { KEY: 'VALUE' }, + }), + }), + }), + 'test-session-id', + mockArgv, + '/tmp', + ); + }); + + it('should handle authentication failure gracefully', async () => { + mockConfig.refreshAuth.mockRejectedValue(new Error('Auth failed')); + const debugSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + + // Should throw RequestError.authRequired() + await expect( + agent.newSession({ + cwd: '/tmp', + mcpServers: [], + }), + ).rejects.toMatchObject({ + message: 'Authentication required', + }); + + debugSpy.mockRestore(); + }); + + it('should initialize file system service if client supports it', async () => { + agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockClient); + await agent.initialize({ + clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } }, + protocolVersion: 1, + }); + + await agent.newSession({ + cwd: '/tmp', + mcpServers: [], + }); + + expect(mockConfig.setFileSystemService).toHaveBeenCalled(); + }); + + it('should cancel a session', async () => { + await agent.newSession({ cwd: '/tmp', mcpServers: [] }); + // Mock the session's cancelPendingPrompt + const session = ( + agent as unknown as { sessions: Map } + ).sessions.get('test-session-id'); + if (!session) throw new Error('Session not found'); + session.cancelPendingPrompt = vi.fn(); + + await agent.cancel({ sessionId: 'test-session-id' }); + + expect(session.cancelPendingPrompt).toHaveBeenCalled(); + }); + + it('should throw error when cancelling non-existent session', async () => { + await expect(agent.cancel({ sessionId: 'unknown' })).rejects.toThrow( + 'Session not found', + ); + }); + + it('should delegate prompt to session', async () => { + await agent.newSession({ cwd: '/tmp', mcpServers: [] }); + const session = ( + agent as unknown as { sessions: Map } + ).sessions.get('test-session-id'); + if (!session) throw new Error('Session not found'); + session.prompt = vi.fn().mockResolvedValue({ stopReason: 'end_turn' }); + + const result = await agent.prompt({ + sessionId: 'test-session-id', + prompt: [], + }); + + expect(session.prompt).toHaveBeenCalled(); + expect(result).toEqual({ stopReason: 'end_turn' }); + }); +}); + +describe('Session', () => { + let mockChat: Mocked; + let mockConfig: Mocked; + let mockClient: Mocked; + let session: Session; + let mockToolRegistry: { getTool: Mock }; + let mockTool: { kind: string; build: Mock }; + + beforeEach(() => { + mockChat = { + sendMessageStream: vi.fn(), + addHistory: vi.fn(), + } as unknown as Mocked; + mockTool = { + kind: 'native', + build: vi.fn().mockReturnValue({ + getDescription: () => 'Test Tool', + toolLocations: () => [], + shouldConfirmExecute: vi.fn().mockResolvedValue(null), + execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), + }), + }; + mockToolRegistry = { + getTool: vi.fn().mockReturnValue(mockTool), + }; + mockConfig = { + isInFallbackMode: vi.fn().mockReturnValue(false), + getModel: vi.fn().mockReturnValue('gemini-pro'), + getPreviewFeatures: vi.fn().mockReturnValue({}), + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getFileService: vi.fn().mockReturnValue({ + shouldIgnoreFile: vi.fn().mockReturnValue(false), + }), + getFileFilteringOptions: vi.fn().mockReturnValue({}), + getTargetDir: vi.fn().mockReturnValue('/tmp'), + getEnableRecursiveFileSearch: vi.fn().mockReturnValue(false), + getDebugMode: vi.fn().mockReturnValue(false), + } as unknown as Mocked; + mockClient = { + sessionUpdate: vi.fn(), + requestPermission: vi.fn(), + sendNotification: vi.fn(), + } as unknown as Mocked; + + session = new Session('session-1', mockChat, mockConfig, mockClient); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should handle prompt with text response', async () => { + const stream = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'Hello' }] } }], + }, + }, + ]); + mockChat.sendMessageStream.mockResolvedValue(stream); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Hi' }], + }); + + expect(mockChat.sendMessageStream).toHaveBeenCalled(); + expect(mockClient.sessionUpdate).toHaveBeenCalledWith({ + sessionId: 'session-1', + update: { + sessionUpdate: 'agent_message_chunk', + content: { type: 'text', text: 'Hello' }, + }, + }); + expect(result).toEqual({ stopReason: 'end_turn' }); + }); + + it('should handle tool calls', async () => { + const stream1 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + functionCalls: [{ name: 'test_tool', args: { foo: 'bar' } }], + }, + }, + ]); + const stream2 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'Result' }] } }], + }, + }, + ]); + + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Call tool' }], + }); + + expect(mockToolRegistry.getTool).toHaveBeenCalledWith('test_tool'); + expect(mockTool.build).toHaveBeenCalledWith({ foo: 'bar' }); + expect(mockClient.sessionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'tool_call', + status: 'in_progress', + }), + }), + ); + expect(mockClient.sessionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'tool_call_update', + status: 'completed', + }), + }), + ); + expect(result).toEqual({ stopReason: 'end_turn' }); + }); + + it('should handle tool call permission request', async () => { + const confirmationDetails = { + type: 'info', + onConfirm: vi.fn(), + }; + mockTool.build.mockReturnValue({ + getDescription: () => 'Test Tool', + toolLocations: () => [], + shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), + execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), + }); + + mockClient.requestPermission.mockResolvedValue({ + outcome: { + outcome: 'selected', + optionId: ToolConfirmationOutcome.ProceedOnce, + }, + }); + + const stream1 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + functionCalls: [{ name: 'test_tool', args: {} }], + }, + }, + ]); + const stream2 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [] }, + }, + ]); + + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + + await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Call tool' }], + }); + + expect(mockClient.requestPermission).toHaveBeenCalled(); + expect(confirmationDetails.onConfirm).toHaveBeenCalledWith( + ToolConfirmationOutcome.ProceedOnce, + ); + }); + + it('should handle tool call cancellation by user', async () => { + const confirmationDetails = { + type: 'info', + onConfirm: vi.fn(), + }; + mockTool.build.mockReturnValue({ + getDescription: () => 'Test Tool', + toolLocations: () => [], + shouldConfirmExecute: vi.fn().mockResolvedValue(confirmationDetails), + execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }), + }); + + mockClient.requestPermission.mockResolvedValue({ + outcome: { outcome: 'cancelled' }, + }); + + const stream1 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + functionCalls: [{ name: 'test_tool', args: {} }], + }, + }, + ]); + const stream2 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [] }, + }, + ]); + + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + + await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Call tool' }], + }); + + // When cancelled, it sends an error response to the model + // We can verify that the second call to sendMessageStream contains the error + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + const secondCallArgs = mockChat.sendMessageStream.mock.calls[1]; + const parts = secondCallArgs[1]; // parts + expect(parts).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + functionResponse: expect.objectContaining({ + response: { + error: expect.stringContaining('canceled by the user'), + }, + }), + }), + ]), + ); + }); + + it('should handle @path resolution', async () => { + (path.resolve as unknown as Mock).mockReturnValue('/tmp/file.txt'); + (fs.stat as unknown as Mock).mockResolvedValue({ + isDirectory: () => false, + }); + (isWithinRoot as unknown as Mock).mockReturnValue(true); + + const stream = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [] }, + }, + ]); + mockChat.sendMessageStream.mockResolvedValue(stream); + + await session.prompt({ + sessionId: 'session-1', + prompt: [ + { type: 'text', text: 'Read' }, + { + type: 'resource_link', + uri: 'file://file.txt', + mimeType: 'text/plain', + name: 'file.txt', + }, + ], + }); + + expect(path.resolve).toHaveBeenCalled(); + expect(fs.stat).toHaveBeenCalled(); + + // Verify ReadManyFilesTool was used (implicitly by checking if sendMessageStream was called with resolved content) + // Since we mocked ReadManyFilesTool to return specific content, we can check the args passed to sendMessageStream + expect(mockChat.sendMessageStream).toHaveBeenCalledWith( + expect.anything(), + expect.arrayContaining([ + expect.objectContaining({ + text: expect.stringContaining('Content from @file.txt'), + }), + ]), + expect.anything(), + expect.anything(), + ); + }); + + it('should handle cancellation during prompt', async () => { + let streamController: ReadableStreamDefaultController; + const stream = new ReadableStream({ + start(controller) { + streamController = controller; + }, + }); + + let streamStarted: (value: unknown) => void; + const streamStartedPromise = new Promise((resolve) => { + streamStarted = resolve; + }); + + // Adapt web stream to async iterable + async function* asyncStream() { + process.stdout.write('TEST: asyncStream started\n'); + streamStarted(true); + const reader = stream.getReader(); + try { + while (true) { + process.stdout.write('TEST: waiting for read\n'); + const { done, value } = await reader.read(); + process.stdout.write(`TEST: read returned done=${done}\n`); + if (done) break; + yield value; + } + } finally { + process.stdout.write('TEST: releasing lock\n'); + reader.releaseLock(); + } + } + + mockChat.sendMessageStream.mockResolvedValue(asyncStream()); + + process.stdout.write('TEST: calling prompt\n'); + const promptPromise = session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Hi' }], + }); + + process.stdout.write('TEST: waiting for streamStarted\n'); + await streamStartedPromise; + process.stdout.write('TEST: streamStarted\n'); + await session.cancelPendingPrompt(); + process.stdout.write('TEST: cancelled\n'); + + // Close the stream to allow prompt loop to continue and check aborted signal + streamController!.close(); + process.stdout.write('TEST: stream closed\n'); + + const result = await promptPromise; + process.stdout.write(`TEST: result received ${JSON.stringify(result)}\n`); + expect(result).toEqual({ stopReason: 'cancelled' }); + }); + + it('should handle rate limit error', async () => { + const error = new Error('Rate limit'); + (error as unknown as { status: number }).status = 429; + mockChat.sendMessageStream.mockRejectedValue(error); + + await expect( + session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Hi' }], + }), + ).rejects.toMatchObject({ + code: 429, + message: 'Rate limit exceeded. Try again later.', + }); + }); + + it('should handle tool execution error', async () => { + mockTool.build.mockReturnValue({ + getDescription: () => 'Test Tool', + toolLocations: () => [], + shouldConfirmExecute: vi.fn().mockResolvedValue(null), + execute: vi.fn().mockRejectedValue(new Error('Tool failed')), + }); + + const stream1 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + functionCalls: [{ name: 'test_tool', args: {} }], + }, + }, + ]); + const stream2 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [] }, + }, + ]); + + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + + await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Call tool' }], + }); + + expect(mockClient.sessionUpdate).toHaveBeenCalledWith( + expect.objectContaining({ + update: expect.objectContaining({ + sessionUpdate: 'tool_call_update', + status: 'failed', + content: expect.arrayContaining([ + expect.objectContaining({ + content: expect.objectContaining({ text: 'Tool failed' }), + }), + ]), + }), + }), + ); + }); + + it('should handle missing tool', async () => { + mockToolRegistry.getTool.mockReturnValue(undefined); + + const stream1 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { + functionCalls: [{ name: 'unknown_tool', args: {} }], + }, + }, + ]); + const stream2 = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [] }, + }, + ]); + + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + + await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Call tool' }], + }); + + // Should send error response to model + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + const secondCallArgs = mockChat.sendMessageStream.mock.calls[1]; + const parts = secondCallArgs[1]; + expect(parts).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + functionResponse: expect.objectContaining({ + response: { + error: expect.stringContaining('not found in registry'), + }, + }), + }), + ]), + ); + }); + + it('should ignore files based on configuration', async () => { + ( + mockConfig.getFileService().shouldIgnoreFile as unknown as Mock + ).mockReturnValue(true); + const stream = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [] }, + }, + ]); + mockChat.sendMessageStream.mockResolvedValue(stream); + + await session.prompt({ + sessionId: 'session-1', + prompt: [ + { + type: 'resource_link', + uri: 'file://ignored.txt', + mimeType: 'text/plain', + name: 'ignored.txt', + }, + ], + }); + + // Should not read file + expect(mockToolRegistry.getTool).not.toHaveBeenCalledWith( + 'read_many_files', + ); + }); + + it('should handle directory resolution with glob', async () => { + (path.resolve as unknown as Mock).mockReturnValue('/tmp/dir'); + (fs.stat as unknown as Mock).mockResolvedValue({ + isDirectory: () => true, + }); + (isWithinRoot as unknown as Mock).mockReturnValue(true); + + const stream = createMockStream([ + { + type: StreamEventType.CHUNK, + value: { candidates: [] }, + }, + ]); + mockChat.sendMessageStream.mockResolvedValue(stream); + + await session.prompt({ + sessionId: 'session-1', + prompt: [ + { + type: 'resource_link', + uri: 'file://dir', + mimeType: 'text/plain', + name: 'dir', + }, + ], + }); + + // Should use glob + // ReadManyFilesTool is instantiated directly, so we check if the mock instance's build method was called + const MockReadManyFilesTool = ReadManyFilesTool as unknown as Mock; + const mockInstance = + MockReadManyFilesTool.mock.results[ + MockReadManyFilesTool.mock.results.length - 1 + ].value; + expect(mockInstance.build).toHaveBeenCalled(); + }); +}); diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 43f8bb8a73..2f314c5349 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -66,7 +66,7 @@ export async function runZedIntegration( ); } -class GeminiAgent { +export class GeminiAgent { private sessions: Map = new Map(); private clientCapabilities: acp.ClientCapabilities | undefined; @@ -209,7 +209,7 @@ class GeminiAgent { } } -class Session { +export class Session { private pendingPrompt: AbortController | null = null; constructor( @@ -296,6 +296,10 @@ class Session { functionCalls.push(...resp.value.functionCalls); } } + + if (pendingSend.signal.aborted) { + return { stopReason: 'cancelled' }; + } } catch (error) { if (getErrorStatus(error) === 429) { throw new acp.RequestError(