Improve test code coverage for cli/command/extensions package (#12994)

This commit is contained in:
Megha Bansal
2025-11-13 21:28:01 -08:00
committed by GitHub
parent 555e25e633
commit 638dd2f6c0
12 changed files with 1904 additions and 251 deletions
@@ -0,0 +1,296 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi, describe, it, expect, beforeEach, type Mock } from 'vitest';
import {
AgentSideConnection,
RequestError,
type Agent,
type Client,
} from './acp.js';
import { type ErrorResponse } from './schema.js';
import { type MethodHandler } from './connection.js';
import { ReadableStream, WritableStream } from 'node:stream/web';
const mockConnectionConstructor = vi.hoisted(() =>
vi.fn<
(
arg1: MethodHandler,
arg2: WritableStream<Uint8Array>,
arg3: ReadableStream<Uint8Array>,
) => { sendRequest: Mock; sendNotification: Mock }
>(() => ({
sendRequest: vi.fn(),
sendNotification: vi.fn(),
})),
);
vi.mock('./connection.js', async (importOriginal) => {
const actual = await importOriginal();
return {
...(actual as object),
Connection: mockConnectionConstructor,
};
});
describe('acp', () => {
describe('RequestError', () => {
it('should create a parse error', () => {
const error = RequestError.parseError('details');
expect(error.code).toBe(-32700);
expect(error.message).toBe('Parse error');
expect(error.data?.details).toBe('details');
});
it('should create a method not found error', () => {
const error = RequestError.methodNotFound('details');
expect(error.code).toBe(-32601);
expect(error.message).toBe('Method not found');
expect(error.data?.details).toBe('details');
});
it('should convert to a result', () => {
const error = RequestError.internalError('details');
const result = error.toResult() as { error: ErrorResponse };
expect(result.error.code).toBe(-32603);
expect(result.error.message).toBe('Internal error');
expect(result.error.data).toEqual({ details: 'details' });
});
});
describe('AgentSideConnection', () => {
let mockAgent: Agent;
let toAgent: WritableStream<Uint8Array>;
let fromAgent: ReadableStream<Uint8Array>;
let agentSideConnection: AgentSideConnection;
let connectionInstance: InstanceType<typeof mockConnectionConstructor>;
beforeEach(() => {
vi.clearAllMocks();
const initializeResponse = {
agentCapabilities: { loadSession: true },
authMethods: [],
protocolVersion: 1,
};
const newSessionResponse = { sessionId: 'session-1' };
const loadSessionResponse = { sessionId: 'session-1' };
mockAgent = {
initialize: vi.fn().mockResolvedValue(initializeResponse),
newSession: vi.fn().mockResolvedValue(newSessionResponse),
loadSession: vi.fn().mockResolvedValue(loadSessionResponse),
authenticate: vi.fn(),
prompt: vi.fn(),
cancel: vi.fn(),
};
toAgent = new WritableStream<Uint8Array>();
fromAgent = new ReadableStream<Uint8Array>();
agentSideConnection = new AgentSideConnection(
(_client: Client) => mockAgent,
toAgent,
fromAgent,
);
// Get the mocked Connection instance
connectionInstance = mockConnectionConstructor.mock.results[0].value;
});
it('should initialize Connection with the correct handler and streams', () => {
expect(mockConnectionConstructor).toHaveBeenCalledTimes(1);
expect(mockConnectionConstructor).toHaveBeenCalledWith(
expect.any(Function),
toAgent,
fromAgent,
);
});
it('should call agent.initialize when Connection handler receives initialize method', async () => {
const initializeParams = {
clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } },
protocolVersion: 1,
};
const initializeResponse = {
agentCapabilities: { loadSession: true },
authMethods: [],
protocolVersion: 1,
};
const handler = mockConnectionConstructor.mock
.calls[0][0]! as MethodHandler;
const result = await handler('initialize', initializeParams);
expect(mockAgent.initialize).toHaveBeenCalledWith(initializeParams);
expect(result).toEqual(initializeResponse);
});
it('should call agent.newSession when Connection handler receives session_new method', async () => {
const newSessionParams = { cwd: '/tmp', mcpServers: [] };
const newSessionResponse = { sessionId: 'session-1' };
const handler = mockConnectionConstructor.mock
.calls[0][0]! as MethodHandler;
const result = await handler('session/new', newSessionParams);
expect(mockAgent.newSession).toHaveBeenCalledWith(newSessionParams);
expect(result).toEqual(newSessionResponse);
});
it('should call agent.loadSession when Connection handler receives session_load method', async () => {
const loadSessionParams = {
cwd: '/tmp',
mcpServers: [],
sessionId: 'session-1',
};
const loadSessionResponse = { sessionId: 'session-1' };
const handler = mockConnectionConstructor.mock
.calls[0][0]! as MethodHandler;
const result = await handler('session/load', loadSessionParams);
expect(mockAgent.loadSession).toHaveBeenCalledWith(loadSessionParams);
expect(result).toEqual(loadSessionResponse);
});
it('should throw methodNotFound if agent.loadSession is not implemented', async () => {
mockAgent.loadSession = undefined; // Simulate not implemented
const loadSessionParams = {
cwd: '/tmp',
mcpServers: [],
sessionId: 'session-1',
};
const handler = mockConnectionConstructor.mock
.calls[0][0]! as MethodHandler;
await expect(handler('session/load', loadSessionParams)).rejects.toThrow(
RequestError.methodNotFound().message,
);
});
it('should call agent.authenticate when Connection handler receives authenticate method', async () => {
const authenticateParams = {
methodId: 'test-auth-method',
authMethod: {
id: 'test-auth',
name: 'Test Auth Method',
description: 'A test authentication method',
},
};
const handler = mockConnectionConstructor.mock
.calls[0][0]! as MethodHandler;
const result = await handler('authenticate', authenticateParams);
expect(mockAgent.authenticate).toHaveBeenCalledWith(authenticateParams);
expect(result).toBeUndefined();
});
it('should call agent.prompt when Connection handler receives session_prompt method', async () => {
const promptParams = {
prompt: [{ type: 'text', text: 'hi' }],
sessionId: 'session-1',
};
const promptResponse = {
response: [{ type: 'text', text: 'hello' }],
traceId: 'trace-1',
};
(mockAgent.prompt as Mock).mockResolvedValue(promptResponse);
const handler = mockConnectionConstructor.mock
.calls[0][0]! as MethodHandler;
const result = await handler('session/prompt', promptParams);
expect(mockAgent.prompt).toHaveBeenCalledWith(promptParams);
expect(result).toEqual(promptResponse);
});
it('should call agent.cancel when Connection handler receives session_cancel method', async () => {
const cancelParams = { sessionId: 'session-1' };
const handler = mockConnectionConstructor.mock
.calls[0][0]! as MethodHandler;
const result = await handler('session/cancel', cancelParams);
expect(mockAgent.cancel).toHaveBeenCalledWith(cancelParams);
expect(result).toBeUndefined();
});
it('should throw methodNotFound for unknown methods', async () => {
const handler = mockConnectionConstructor.mock
.calls[0][0]! as MethodHandler;
await expect(handler('unknown_method', {})).rejects.toThrow(
RequestError.methodNotFound().message,
);
});
it('should send sessionUpdate notification via connection', async () => {
const params = {
sessionId: '123',
update: {
sessionUpdate: 'user_message_chunk' as const,
content: { type: 'text' as const, text: 'hello' },
},
};
await agentSideConnection.sessionUpdate(params);
});
it('should send requestPermission request via connection', async () => {
const params = {
sessionId: '123',
toolCall: {
toolCallId: 'tool-1',
title: 'Test Tool',
kind: 'other' as const,
status: 'pending' as const,
},
options: [
{
optionId: 'option-1',
name: 'Allow',
kind: 'allow_once' as const,
},
],
};
const response = {
outcome: { outcome: 'selected', optionId: 'option-1' },
};
(connectionInstance.sendRequest as Mock).mockResolvedValue(response);
const result = await agentSideConnection.requestPermission(params);
expect(connectionInstance.sendRequest).toHaveBeenCalledWith(
'session/request_permission',
params,
);
expect(result).toEqual(response);
});
it('should send readTextFile request via connection', async () => {
const params = { path: '/a/b.txt', sessionId: 'session-1' };
const response = { content: 'file content' };
(connectionInstance.sendRequest as Mock).mockResolvedValue(response);
const result = await agentSideConnection.readTextFile(params);
expect(connectionInstance.sendRequest).toHaveBeenCalledWith(
'fs/read_text_file',
params,
);
expect(result).toEqual(response);
});
it('should send writeTextFile request via connection', async () => {
const params = {
path: '/a/b.txt',
content: 'new content',
sessionId: 'session-1',
};
const response = { success: true };
(connectionInstance.sendRequest as Mock).mockResolvedValue(response);
const result = await agentSideConnection.writeTextFile(params);
expect(connectionInstance.sendRequest).toHaveBeenCalledWith(
'fs/write_text_file',
params,
);
expect(result).toEqual(response);
});
});
});
+2 -232
View File
@@ -6,12 +6,12 @@
/* ACP defines a schema for a simple (experimental) JSON-RPC protocol that allows GUI applications to interact with agents. */
import { z } from 'zod';
import * as schema from './schema.js';
export * from './schema.js';
import type { WritableStream, ReadableStream } from 'node:stream/web';
import { coreEvents } from '@google/gemini-cli-core';
import { Connection, RequestError } from './connection.js';
export { RequestError };
export class AgentSideConnection implements Client {
#connection: Connection;
@@ -108,236 +108,6 @@ export class AgentSideConnection implements Client {
}
}
type AnyMessage = AnyRequest | AnyResponse | AnyNotification;
type AnyRequest = {
jsonrpc: '2.0';
id: string | number;
method: string;
params?: unknown;
};
type AnyResponse = {
jsonrpc: '2.0';
id: string | number;
} & Result<unknown>;
type AnyNotification = {
jsonrpc: '2.0';
method: string;
params?: unknown;
};
type Result<T> =
| {
result: T;
}
| {
error: ErrorResponse;
};
type ErrorResponse = {
code: number;
message: string;
data?: unknown;
};
type PendingResponse = {
resolve: (response: unknown) => void;
reject: (error: ErrorResponse) => void;
};
type MethodHandler = (method: string, params: unknown) => Promise<unknown>;
class Connection {
#pendingResponses: Map<string | number, PendingResponse> = new Map();
#nextRequestId: number = 0;
#handler: MethodHandler;
#peerInput: WritableStream<Uint8Array>;
#writeQueue: Promise<void> = Promise.resolve();
#textEncoder: TextEncoder;
constructor(
handler: MethodHandler,
peerInput: WritableStream<Uint8Array>,
peerOutput: ReadableStream<Uint8Array>,
) {
this.#handler = handler;
this.#peerInput = peerInput;
this.#textEncoder = new TextEncoder();
this.#receive(peerOutput);
}
async #receive(output: ReadableStream<Uint8Array>) {
let content = '';
const decoder = new TextDecoder();
for await (const chunk of output) {
content += decoder.decode(chunk, { stream: true });
const lines = content.split('\n');
content = lines.pop() || '';
for (const line of lines) {
const trimmedLine = line.trim();
if (trimmedLine) {
const message = JSON.parse(trimmedLine);
this.#processMessage(message);
}
}
}
}
async #processMessage(message: AnyMessage) {
if ('method' in message && 'id' in message) {
// It's a request
const response = await this.#tryCallHandler(
message.method,
message.params,
);
await this.#sendMessage({
jsonrpc: '2.0',
id: message.id,
...response,
});
} else if ('method' in message) {
// It's a notification
await this.#tryCallHandler(message.method, message.params);
} else if ('id' in message) {
// It's a response
this.#handleResponse(message as AnyResponse);
}
}
async #tryCallHandler(
method: string,
params?: unknown,
): Promise<Result<unknown>> {
try {
const result = await this.#handler(method, params);
return { result: result ?? null };
} catch (error: unknown) {
if (error instanceof RequestError) {
return error.toResult();
}
if (error instanceof z.ZodError) {
return RequestError.invalidParams(
JSON.stringify(error.format(), undefined, 2),
).toResult();
}
let details;
if (error instanceof Error) {
details = error.message;
} else if (
typeof error === 'object' &&
error != null &&
'message' in error &&
typeof error.message === 'string'
) {
details = error.message;
}
return RequestError.internalError(details).toResult();
}
}
#handleResponse(response: AnyResponse) {
const pendingResponse = this.#pendingResponses.get(response.id);
if (pendingResponse) {
if ('result' in response) {
pendingResponse.resolve(response.result);
} else if ('error' in response) {
pendingResponse.reject(response.error);
}
this.#pendingResponses.delete(response.id);
}
}
async sendRequest<Req, Resp>(method: string, params?: Req): Promise<Resp> {
const id = this.#nextRequestId++;
const responsePromise = new Promise((resolve, reject) => {
this.#pendingResponses.set(id, { resolve, reject });
});
await this.#sendMessage({ jsonrpc: '2.0', id, method, params });
return responsePromise as Promise<Resp>;
}
async sendNotification<N>(method: string, params?: N): Promise<void> {
await this.#sendMessage({ jsonrpc: '2.0', method, params });
}
async #sendMessage(json: AnyMessage) {
const content = JSON.stringify(json) + '\n';
this.#writeQueue = this.#writeQueue
.then(async () => {
const writer = this.#peerInput.getWriter();
try {
await writer.write(this.#textEncoder.encode(content));
} finally {
writer.releaseLock();
}
})
.catch((error) => {
// Continue processing writes on error
coreEvents.emitFeedback('error', 'ACP write error.', error);
});
return this.#writeQueue;
}
}
export class RequestError extends Error {
data?: { details?: string };
constructor(
public code: number,
message: string,
details?: string,
) {
super(message);
this.name = 'RequestError';
if (details) {
this.data = { details };
}
}
static parseError(details?: string): RequestError {
return new RequestError(-32700, 'Parse error', details);
}
static invalidRequest(details?: string): RequestError {
return new RequestError(-32600, 'Invalid request', details);
}
static methodNotFound(details?: string): RequestError {
return new RequestError(-32601, 'Method not found', details);
}
static invalidParams(details?: string): RequestError {
return new RequestError(-32602, 'Invalid params', details);
}
static internalError(details?: string): RequestError {
return new RequestError(-32603, 'Internal error', details);
}
static authRequired(details?: string): RequestError {
return new RequestError(-32000, 'Authentication required', details);
}
toResult<T>(): Result<T> {
return {
error: {
code: this.code,
message: this.message,
data: this.data,
},
};
}
}
export interface Client {
requestPermission(
params: schema.RequestPermissionRequest,
@@ -0,0 +1,229 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { z } from 'zod';
import { coreEvents } from '@google/gemini-cli-core';
import { type Result, type ErrorResponse } from './schema.js';
import type { WritableStream, ReadableStream } from 'node:stream/web';
export class RequestError extends Error {
data?: { details?: string };
constructor(
public code: number,
message: string,
details?: string,
) {
super(message);
this.name = 'RequestError';
if (details) {
this.data = { details };
}
}
static parseError(details?: string): RequestError {
return new RequestError(-32700, 'Parse error', details);
}
static invalidRequest(details?: string): RequestError {
return new RequestError(-32600, 'Invalid request', details);
}
static methodNotFound(details?: string): RequestError {
return new RequestError(-32601, 'Method not found', details);
}
static invalidParams(details?: string): RequestError {
return new RequestError(-32602, 'Invalid params', details);
}
static internalError(details?: string): RequestError {
return new RequestError(-32603, 'Internal error', details);
}
static authRequired(details?: string): RequestError {
return new RequestError(-32000, 'Authentication required', details);
}
toResult<T>(): Result<T> {
return {
error: {
code: this.code,
message: this.message,
data: this.data,
},
};
}
}
type AnyMessage = AnyRequest | AnyResponse | AnyNotification;
type AnyRequest = {
jsonrpc: '2.0';
id: string | number;
method: string;
params?: unknown;
};
type AnyResponse = {
jsonrpc: '2.0';
id: string | number;
} & Result<unknown>;
type AnyNotification = {
jsonrpc: '2.0';
method: string;
params?: unknown;
};
type PendingResponse = {
resolve: (response: unknown) => void;
reject: (error: ErrorResponse) => void;
};
export type MethodHandler = (
method: string,
params: unknown,
) => Promise<unknown>;
export class Connection {
#pendingResponses: Map<string | number, PendingResponse> = new Map();
#nextRequestId: number = 0;
#handler: MethodHandler;
#peerInput: WritableStream<Uint8Array>;
#writeQueue: Promise<void> = Promise.resolve();
#textEncoder: TextEncoder;
constructor(
handler: MethodHandler,
peerInput: WritableStream<Uint8Array>,
peerOutput: ReadableStream<Uint8Array>,
) {
this.#handler = handler;
this.#peerInput = peerInput;
this.#textEncoder = new TextEncoder();
this.#receive(peerOutput);
}
async #receive(output: ReadableStream<Uint8Array>) {
let content = '';
const decoder = new TextDecoder();
for await (const chunk of output) {
content += decoder.decode(chunk, { stream: true });
const lines = content.split('\n');
content = lines.pop() || '';
for (const line of lines) {
const trimmedLine = line.trim();
if (trimmedLine) {
const message = JSON.parse(trimmedLine);
this.#processMessage(message);
}
}
}
}
async #processMessage(message: AnyMessage) {
if ('method' in message && 'id' in message) {
// It's a request
const response = await this.#tryCallHandler(
message.method,
message.params,
);
await this.#sendMessage({
jsonrpc: '2.0',
id: message.id,
...response,
});
} else if ('method' in message) {
// It's a notification
await this.#tryCallHandler(message.method, message.params);
} else if ('id' in message) {
// It's a response
this.#handleResponse(message as AnyResponse);
}
}
async #tryCallHandler(
method: string,
params?: unknown,
): Promise<Result<unknown>> {
try {
const result = await this.#handler(method, params);
return { result: result ?? null };
} catch (error: unknown) {
if (error instanceof RequestError) {
return error.toResult();
}
if (error instanceof z.ZodError) {
return RequestError.invalidParams(
JSON.stringify(error.format(), undefined, 2),
).toResult();
}
let details;
if (error instanceof Error) {
details = error.message;
} else if (
typeof error === 'object' &&
error != null &&
'message' in error &&
typeof error.message === 'string'
) {
details = error.message;
}
return RequestError.internalError(details).toResult();
}
}
#handleResponse(response: AnyResponse) {
const pendingResponse = this.#pendingResponses.get(response.id);
if (pendingResponse) {
if ('result' in response) {
pendingResponse.resolve(response.result);
} else if ('error' in response) {
pendingResponse.reject(response.error);
}
this.#pendingResponses.delete(response.id);
}
}
async sendRequest<Req, Resp>(method: string, params?: Req): Promise<Resp> {
const id = this.#nextRequestId++;
const responsePromise = new Promise((resolve, reject) => {
this.#pendingResponses.set(id, { resolve, reject });
});
await this.#sendMessage({ jsonrpc: '2.0', id, method, params });
return responsePromise as Promise<Resp>;
}
async sendNotification<N>(method: string, params?: N): Promise<void> {
await this.#sendMessage({ jsonrpc: '2.0', method, params });
}
async #sendMessage(json: AnyMessage) {
const content = JSON.stringify(json) + '\n';
this.#writeQueue = this.#writeQueue
.then(async () => {
const writer = this.#peerInput.getWriter();
try {
await writer.write(this.#textEncoder.encode(content));
} finally {
writer.releaseLock();
}
})
.catch((error) => {
// Continue processing writes on error
coreEvents.emitFeedback('error', 'ACP write error.', error);
});
return this.#writeQueue;
}
}
+21 -6
View File
@@ -24,6 +24,12 @@ export const CLIENT_METHODS = {
export const PROTOCOL_VERSION = 1;
export const authMethodSchema = z.object({
description: z.string().nullable(),
id: z.string(),
name: z.string(),
});
export type WriteTextFileRequest = z.infer<typeof writeTextFileRequestSchema>;
export type ReadTextFileRequest = z.infer<typeof readTextFileRequestSchema>;
@@ -128,6 +134,20 @@ export type AgentRequest = z.infer<typeof agentRequestSchema>;
export type AgentNotification = z.infer<typeof agentNotificationSchema>;
export type Result<T> =
| {
result: T;
}
| {
error: ErrorResponse;
};
export type ErrorResponse = {
code: number;
message: string;
data?: unknown;
};
export const writeTextFileRequestSchema = z.object({
content: z.string(),
path: z.string(),
@@ -203,6 +223,7 @@ export const cancelNotificationSchema = z.object({
export const authenticateRequestSchema = z.object({
methodId: z.string(),
authMethod: authMethodSchema,
});
export const authenticateResponseSchema = z.null();
@@ -283,12 +304,6 @@ export const agentCapabilitiesSchema = z.object({
promptCapabilities: promptCapabilitiesSchema.optional(),
});
export const authMethodSchema = z.object({
description: z.string().nullable(),
id: z.string(),
name: z.string(),
});
export const clientResponseSchema = z.union([
writeTextFileResponseSchema,
readTextFileResponseSchema,