mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-12 07:01:09 -07:00
Use official ACP SDK and support HTTP/SSE based MCP servers (#13856)
This commit is contained in:
@@ -1,283 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, type Mock } from 'vitest';
|
||||
import {
|
||||
AgentSideConnection,
|
||||
RequestError,
|
||||
type Agent,
|
||||
type Client,
|
||||
} from './acp.js';
|
||||
import { type ErrorResponse } from './schema.js';
|
||||
import { type MethodHandler } from './connection.js';
|
||||
import { ReadableStream, WritableStream } from 'node:stream/web';
|
||||
|
||||
const mockConnectionConstructor = vi.hoisted(() =>
|
||||
vi.fn<
|
||||
(
|
||||
arg1: MethodHandler,
|
||||
arg2: WritableStream<Uint8Array>,
|
||||
arg3: ReadableStream<Uint8Array>,
|
||||
) => { sendRequest: Mock; sendNotification: Mock }
|
||||
>(() => ({
|
||||
sendRequest: vi.fn(),
|
||||
sendNotification: vi.fn(),
|
||||
})),
|
||||
);
|
||||
|
||||
vi.mock('./connection.js', async (importOriginal) => {
|
||||
const actual = await importOriginal();
|
||||
return {
|
||||
...(actual as object),
|
||||
Connection: mockConnectionConstructor,
|
||||
};
|
||||
});
|
||||
|
||||
describe('acp', () => {
|
||||
describe('RequestError', () => {
|
||||
it('should create a parse error', () => {
|
||||
const error = RequestError.parseError('details');
|
||||
expect(error.code).toBe(-32700);
|
||||
expect(error.message).toBe('Parse error');
|
||||
expect(error.data?.details).toBe('details');
|
||||
});
|
||||
|
||||
it('should create a method not found error', () => {
|
||||
const error = RequestError.methodNotFound('details');
|
||||
expect(error.code).toBe(-32601);
|
||||
expect(error.message).toBe('Method not found');
|
||||
expect(error.data?.details).toBe('details');
|
||||
});
|
||||
|
||||
it('should convert to a result', () => {
|
||||
const error = RequestError.internalError('details');
|
||||
const result = error.toResult() as { error: ErrorResponse };
|
||||
expect(result.error.code).toBe(-32603);
|
||||
expect(result.error.message).toBe('Internal error');
|
||||
expect(result.error.data).toEqual({ details: 'details' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('AgentSideConnection', () => {
|
||||
let mockAgent: Agent;
|
||||
|
||||
let toAgent: WritableStream<Uint8Array>;
|
||||
let fromAgent: ReadableStream<Uint8Array>;
|
||||
let agentSideConnection: AgentSideConnection;
|
||||
let connectionInstance: InstanceType<typeof mockConnectionConstructor>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
const initializeResponse = {
|
||||
agentCapabilities: { loadSession: true },
|
||||
authMethods: [],
|
||||
protocolVersion: 1,
|
||||
};
|
||||
const newSessionResponse = { sessionId: 'session-1' };
|
||||
const loadSessionResponse = { sessionId: 'session-1' };
|
||||
|
||||
mockAgent = {
|
||||
initialize: vi.fn().mockResolvedValue(initializeResponse),
|
||||
newSession: vi.fn().mockResolvedValue(newSessionResponse),
|
||||
loadSession: vi.fn().mockResolvedValue(loadSessionResponse),
|
||||
authenticate: vi.fn(),
|
||||
prompt: vi.fn(),
|
||||
cancel: vi.fn(),
|
||||
};
|
||||
|
||||
toAgent = new WritableStream<Uint8Array>();
|
||||
fromAgent = new ReadableStream<Uint8Array>();
|
||||
|
||||
agentSideConnection = new AgentSideConnection(
|
||||
(_client: Client) => mockAgent,
|
||||
toAgent,
|
||||
fromAgent,
|
||||
);
|
||||
|
||||
// Get the mocked Connection instance
|
||||
connectionInstance = mockConnectionConstructor.mock.results[0].value;
|
||||
});
|
||||
|
||||
it('should initialize Connection with the correct handler and streams', () => {
|
||||
expect(mockConnectionConstructor).toHaveBeenCalledTimes(1);
|
||||
expect(mockConnectionConstructor).toHaveBeenCalledWith(
|
||||
expect.any(Function),
|
||||
toAgent,
|
||||
fromAgent,
|
||||
);
|
||||
});
|
||||
|
||||
it('should call agent.initialize when Connection handler receives initialize method', async () => {
|
||||
const initializeParams = {
|
||||
clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } },
|
||||
protocolVersion: 1,
|
||||
};
|
||||
const initializeResponse = {
|
||||
agentCapabilities: { loadSession: true },
|
||||
authMethods: [],
|
||||
protocolVersion: 1,
|
||||
};
|
||||
const handler = mockConnectionConstructor.mock.calls[0][0];
|
||||
const result = await handler('initialize', initializeParams);
|
||||
|
||||
expect(mockAgent.initialize).toHaveBeenCalledWith(initializeParams);
|
||||
expect(result).toEqual(initializeResponse);
|
||||
});
|
||||
|
||||
it('should call agent.newSession when Connection handler receives session_new method', async () => {
|
||||
const newSessionParams = { cwd: '/tmp', mcpServers: [] };
|
||||
const newSessionResponse = { sessionId: 'session-1' };
|
||||
const handler = mockConnectionConstructor.mock.calls[0][0];
|
||||
const result = await handler('session/new', newSessionParams);
|
||||
|
||||
expect(mockAgent.newSession).toHaveBeenCalledWith(newSessionParams);
|
||||
expect(result).toEqual(newSessionResponse);
|
||||
});
|
||||
|
||||
it('should call agent.loadSession when Connection handler receives session_load method', async () => {
|
||||
const loadSessionParams = {
|
||||
cwd: '/tmp',
|
||||
mcpServers: [],
|
||||
sessionId: 'session-1',
|
||||
};
|
||||
const loadSessionResponse = { sessionId: 'session-1' };
|
||||
const handler = mockConnectionConstructor.mock.calls[0][0];
|
||||
const result = await handler('session/load', loadSessionParams);
|
||||
|
||||
expect(mockAgent.loadSession).toHaveBeenCalledWith(loadSessionParams);
|
||||
expect(result).toEqual(loadSessionResponse);
|
||||
});
|
||||
|
||||
it('should throw methodNotFound if agent.loadSession is not implemented', async () => {
|
||||
mockAgent.loadSession = undefined; // Simulate not implemented
|
||||
const loadSessionParams = {
|
||||
cwd: '/tmp',
|
||||
mcpServers: [],
|
||||
sessionId: 'session-1',
|
||||
};
|
||||
const handler = mockConnectionConstructor.mock.calls[0][0];
|
||||
await expect(handler('session/load', loadSessionParams)).rejects.toThrow(
|
||||
RequestError.methodNotFound().message,
|
||||
);
|
||||
});
|
||||
|
||||
it('should call agent.authenticate when Connection handler receives authenticate method', async () => {
|
||||
const authenticateParams = {
|
||||
methodId: 'test-auth-method',
|
||||
};
|
||||
const handler = mockConnectionConstructor.mock.calls[0][0];
|
||||
const result = await handler('authenticate', authenticateParams);
|
||||
|
||||
expect(mockAgent.authenticate).toHaveBeenCalledWith(authenticateParams);
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should call agent.prompt when Connection handler receives session_prompt method', async () => {
|
||||
const promptParams = {
|
||||
prompt: [{ type: 'text', text: 'hi' }],
|
||||
sessionId: 'session-1',
|
||||
};
|
||||
const promptResponse = {
|
||||
response: [{ type: 'text', text: 'hello' }],
|
||||
traceId: 'trace-1',
|
||||
};
|
||||
(mockAgent.prompt as Mock).mockResolvedValue(promptResponse);
|
||||
const handler = mockConnectionConstructor.mock.calls[0][0];
|
||||
const result = await handler('session/prompt', promptParams);
|
||||
|
||||
expect(mockAgent.prompt).toHaveBeenCalledWith(promptParams);
|
||||
expect(result).toEqual(promptResponse);
|
||||
});
|
||||
|
||||
it('should call agent.cancel when Connection handler receives session_cancel method', async () => {
|
||||
const cancelParams = { sessionId: 'session-1' };
|
||||
const handler = mockConnectionConstructor.mock.calls[0][0];
|
||||
const result = await handler('session/cancel', cancelParams);
|
||||
|
||||
expect(mockAgent.cancel).toHaveBeenCalledWith(cancelParams);
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw methodNotFound for unknown methods', async () => {
|
||||
const handler = mockConnectionConstructor.mock.calls[0][0];
|
||||
await expect(handler('unknown_method', {})).rejects.toThrow(
|
||||
RequestError.methodNotFound().message,
|
||||
);
|
||||
});
|
||||
|
||||
it('should send sessionUpdate notification via connection', async () => {
|
||||
const params = {
|
||||
sessionId: '123',
|
||||
update: {
|
||||
sessionUpdate: 'user_message_chunk' as const,
|
||||
content: { type: 'text' as const, text: 'hello' },
|
||||
},
|
||||
};
|
||||
await agentSideConnection.sessionUpdate(params);
|
||||
});
|
||||
|
||||
it('should send requestPermission request via connection', async () => {
|
||||
const params = {
|
||||
sessionId: '123',
|
||||
toolCall: {
|
||||
toolCallId: 'tool-1',
|
||||
title: 'Test Tool',
|
||||
kind: 'other' as const,
|
||||
status: 'pending' as const,
|
||||
},
|
||||
options: [
|
||||
{
|
||||
optionId: 'option-1',
|
||||
name: 'Allow',
|
||||
kind: 'allow_once' as const,
|
||||
},
|
||||
],
|
||||
};
|
||||
const response = {
|
||||
outcome: { outcome: 'selected', optionId: 'option-1' },
|
||||
};
|
||||
connectionInstance.sendRequest.mockResolvedValue(response);
|
||||
|
||||
const result = await agentSideConnection.requestPermission(params);
|
||||
expect(connectionInstance.sendRequest).toHaveBeenCalledWith(
|
||||
'session/request_permission',
|
||||
params,
|
||||
);
|
||||
expect(result).toEqual(response);
|
||||
});
|
||||
|
||||
it('should send readTextFile request via connection', async () => {
|
||||
const params = { path: '/a/b.txt', sessionId: 'session-1' };
|
||||
const response = { content: 'file content' };
|
||||
connectionInstance.sendRequest.mockResolvedValue(response);
|
||||
|
||||
const result = await agentSideConnection.readTextFile(params);
|
||||
expect(connectionInstance.sendRequest).toHaveBeenCalledWith(
|
||||
'fs/read_text_file',
|
||||
params,
|
||||
);
|
||||
expect(result).toEqual(response);
|
||||
});
|
||||
|
||||
it('should send writeTextFile request via connection', async () => {
|
||||
const params = {
|
||||
path: '/a/b.txt',
|
||||
content: 'new content',
|
||||
sessionId: 'session-1',
|
||||
};
|
||||
const response = { success: true };
|
||||
connectionInstance.sendRequest.mockResolvedValue(response);
|
||||
|
||||
const result = await agentSideConnection.writeTextFile(params);
|
||||
expect(connectionInstance.sendRequest).toHaveBeenCalledWith(
|
||||
'fs/write_text_file',
|
||||
params,
|
||||
);
|
||||
expect(result).toEqual(response);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,137 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/* ACP defines a schema for a simple (experimental) JSON-RPC protocol that allows GUI applications to interact with agents. */
|
||||
|
||||
import * as schema from './schema.js';
|
||||
export * from './schema.js';
|
||||
|
||||
import type { WritableStream, ReadableStream } from 'node:stream/web';
|
||||
import { Connection, RequestError } from './connection.js';
|
||||
export { RequestError };
|
||||
|
||||
export class AgentSideConnection implements Client {
|
||||
#connection: Connection;
|
||||
|
||||
constructor(
|
||||
toAgent: (conn: Client) => Agent,
|
||||
input: WritableStream<Uint8Array>,
|
||||
output: ReadableStream<Uint8Array>,
|
||||
) {
|
||||
const agent = toAgent(this);
|
||||
|
||||
const handler = async (
|
||||
method: string,
|
||||
params: unknown,
|
||||
): Promise<unknown> => {
|
||||
switch (method) {
|
||||
case schema.AGENT_METHODS.initialize: {
|
||||
const validatedParams = schema.initializeRequestSchema.parse(params);
|
||||
return agent.initialize(validatedParams);
|
||||
}
|
||||
case schema.AGENT_METHODS.session_new: {
|
||||
const validatedParams = schema.newSessionRequestSchema.parse(params);
|
||||
return agent.newSession(validatedParams);
|
||||
}
|
||||
case schema.AGENT_METHODS.session_load: {
|
||||
if (!agent.loadSession) {
|
||||
throw RequestError.methodNotFound();
|
||||
}
|
||||
const validatedParams = schema.loadSessionRequestSchema.parse(params);
|
||||
return agent.loadSession(validatedParams);
|
||||
}
|
||||
case schema.AGENT_METHODS.authenticate: {
|
||||
const validatedParams =
|
||||
schema.authenticateRequestSchema.parse(params);
|
||||
return agent.authenticate(validatedParams);
|
||||
}
|
||||
case schema.AGENT_METHODS.session_prompt: {
|
||||
const validatedParams = schema.promptRequestSchema.parse(params);
|
||||
return agent.prompt(validatedParams);
|
||||
}
|
||||
case schema.AGENT_METHODS.session_cancel: {
|
||||
const validatedParams = schema.cancelNotificationSchema.parse(params);
|
||||
return agent.cancel(validatedParams);
|
||||
}
|
||||
default:
|
||||
throw RequestError.methodNotFound(method);
|
||||
}
|
||||
};
|
||||
|
||||
this.#connection = new Connection(handler, input, output);
|
||||
}
|
||||
|
||||
/**
|
||||
* Streams new content to the client including text, tool calls, etc.
|
||||
*/
|
||||
async sessionUpdate(params: schema.SessionNotification): Promise<void> {
|
||||
return this.#connection.sendNotification(
|
||||
schema.CLIENT_METHODS.session_update,
|
||||
params,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Request permission before running a tool
|
||||
*
|
||||
* The agent specifies a series of permission options with different granularity,
|
||||
* and the client returns the chosen one.
|
||||
*/
|
||||
async requestPermission(
|
||||
params: schema.RequestPermissionRequest,
|
||||
): Promise<schema.RequestPermissionResponse> {
|
||||
return this.#connection.sendRequest(
|
||||
schema.CLIENT_METHODS.session_request_permission,
|
||||
params,
|
||||
);
|
||||
}
|
||||
|
||||
async readTextFile(
|
||||
params: schema.ReadTextFileRequest,
|
||||
): Promise<schema.ReadTextFileResponse> {
|
||||
return this.#connection.sendRequest(
|
||||
schema.CLIENT_METHODS.fs_read_text_file,
|
||||
params,
|
||||
);
|
||||
}
|
||||
|
||||
async writeTextFile(
|
||||
params: schema.WriteTextFileRequest,
|
||||
): Promise<schema.WriteTextFileResponse> {
|
||||
return this.#connection.sendRequest(
|
||||
schema.CLIENT_METHODS.fs_write_text_file,
|
||||
params,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export interface Client {
|
||||
requestPermission(
|
||||
params: schema.RequestPermissionRequest,
|
||||
): Promise<schema.RequestPermissionResponse>;
|
||||
sessionUpdate(params: schema.SessionNotification): Promise<void>;
|
||||
writeTextFile(
|
||||
params: schema.WriteTextFileRequest,
|
||||
): Promise<schema.WriteTextFileResponse>;
|
||||
readTextFile(
|
||||
params: schema.ReadTextFileRequest,
|
||||
): Promise<schema.ReadTextFileResponse>;
|
||||
}
|
||||
|
||||
export interface Agent {
|
||||
initialize(
|
||||
params: schema.InitializeRequest,
|
||||
): Promise<schema.InitializeResponse>;
|
||||
newSession(
|
||||
params: schema.NewSessionRequest,
|
||||
): Promise<schema.NewSessionResponse>;
|
||||
loadSession?(
|
||||
params: schema.LoadSessionRequest,
|
||||
): Promise<schema.LoadSessionResponse>;
|
||||
authenticate(params: schema.AuthenticateRequest): Promise<void>;
|
||||
prompt(params: schema.PromptRequest): Promise<schema.PromptResponse>;
|
||||
cancel(params: schema.CancelNotification): Promise<void>;
|
||||
}
|
||||
@@ -1,216 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { Connection, RequestError } from './connection.js';
|
||||
import { ReadableStream, WritableStream } from 'node:stream/web';
|
||||
|
||||
describe('Connection', () => {
|
||||
let toPeer: WritableStream<Uint8Array>;
|
||||
let fromPeer: ReadableStream<Uint8Array>;
|
||||
let peerController: ReadableStreamDefaultController<Uint8Array>;
|
||||
let receivedChunks: string[] = [];
|
||||
let connection: Connection;
|
||||
let handler: ReturnType<typeof vi.fn>;
|
||||
|
||||
beforeEach(() => {
|
||||
receivedChunks = [];
|
||||
toPeer = new WritableStream({
|
||||
write(chunk) {
|
||||
const str = new TextDecoder().decode(chunk);
|
||||
receivedChunks.push(str);
|
||||
},
|
||||
});
|
||||
|
||||
fromPeer = new ReadableStream({
|
||||
start(controller) {
|
||||
peerController = controller;
|
||||
},
|
||||
});
|
||||
|
||||
handler = vi.fn();
|
||||
connection = new Connection(handler, toPeer, fromPeer);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should send a request and receive a response', async () => {
|
||||
const responsePromise = connection.sendRequest('testMethod', {
|
||||
key: 'value',
|
||||
});
|
||||
|
||||
// Verify request was sent
|
||||
await vi.waitFor(() => {
|
||||
expect(receivedChunks.length).toBeGreaterThan(0);
|
||||
});
|
||||
const request = JSON.parse(receivedChunks[0]);
|
||||
expect(request).toMatchObject({
|
||||
jsonrpc: '2.0',
|
||||
method: 'testMethod',
|
||||
params: { key: 'value' },
|
||||
});
|
||||
expect(request.id).toBeDefined();
|
||||
|
||||
// Simulate response
|
||||
const response = {
|
||||
jsonrpc: '2.0',
|
||||
id: request.id,
|
||||
result: { success: true },
|
||||
};
|
||||
peerController.enqueue(
|
||||
new TextEncoder().encode(JSON.stringify(response) + '\n'),
|
||||
);
|
||||
|
||||
const result = await responsePromise;
|
||||
expect(result).toEqual({ success: true });
|
||||
});
|
||||
|
||||
it('should send a notification', async () => {
|
||||
await connection.sendNotification('notifyMethod', { key: 'value' });
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(receivedChunks.length).toBeGreaterThan(0);
|
||||
});
|
||||
const notification = JSON.parse(receivedChunks[0]);
|
||||
expect(notification).toMatchObject({
|
||||
jsonrpc: '2.0',
|
||||
method: 'notifyMethod',
|
||||
params: { key: 'value' },
|
||||
});
|
||||
expect(notification.id).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle incoming requests', async () => {
|
||||
handler.mockResolvedValue({ result: 'ok' });
|
||||
|
||||
const request = {
|
||||
jsonrpc: '2.0',
|
||||
id: 1,
|
||||
method: 'incomingMethod',
|
||||
params: { foo: 'bar' },
|
||||
};
|
||||
peerController.enqueue(
|
||||
new TextEncoder().encode(JSON.stringify(request) + '\n'),
|
||||
);
|
||||
|
||||
// Wait for handler to be called and response to be written
|
||||
await vi.waitFor(() => {
|
||||
expect(handler).toHaveBeenCalledWith('incomingMethod', { foo: 'bar' });
|
||||
expect(receivedChunks.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
const response = JSON.parse(receivedChunks[receivedChunks.length - 1]);
|
||||
expect(response).toMatchObject({
|
||||
jsonrpc: '2.0',
|
||||
id: 1,
|
||||
result: { result: 'ok' },
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle incoming notifications', async () => {
|
||||
const notification = {
|
||||
jsonrpc: '2.0',
|
||||
method: 'incomingNotify',
|
||||
params: { foo: 'bar' },
|
||||
};
|
||||
peerController.enqueue(
|
||||
new TextEncoder().encode(JSON.stringify(notification) + '\n'),
|
||||
);
|
||||
|
||||
// Wait for handler to be called
|
||||
await vi.waitFor(() => {
|
||||
expect(handler).toHaveBeenCalledWith('incomingNotify', { foo: 'bar' });
|
||||
});
|
||||
// Notifications don't send responses
|
||||
expect(receivedChunks.length).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle request errors from handler', async () => {
|
||||
handler.mockRejectedValue(new Error('Handler failed'));
|
||||
|
||||
const request = {
|
||||
jsonrpc: '2.0',
|
||||
id: 2,
|
||||
method: 'failMethod',
|
||||
};
|
||||
peerController.enqueue(
|
||||
new TextEncoder().encode(JSON.stringify(request) + '\n'),
|
||||
);
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(receivedChunks.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
const response = JSON.parse(receivedChunks[receivedChunks.length - 1]);
|
||||
expect(response).toMatchObject({
|
||||
jsonrpc: '2.0',
|
||||
id: 2,
|
||||
error: {
|
||||
code: -32603,
|
||||
message: 'Internal error',
|
||||
data: { details: 'Handler failed' },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle RequestError from handler', async () => {
|
||||
handler.mockRejectedValue(RequestError.methodNotFound('Unknown method'));
|
||||
|
||||
const request = {
|
||||
jsonrpc: '2.0',
|
||||
id: 3,
|
||||
method: 'unknown',
|
||||
};
|
||||
peerController.enqueue(
|
||||
new TextEncoder().encode(JSON.stringify(request) + '\n'),
|
||||
);
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(receivedChunks.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
const response = JSON.parse(receivedChunks[receivedChunks.length - 1]);
|
||||
expect(response).toMatchObject({
|
||||
jsonrpc: '2.0',
|
||||
id: 3,
|
||||
error: {
|
||||
code: -32601,
|
||||
message: 'Method not found',
|
||||
data: { details: 'Unknown method' },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle response errors', async () => {
|
||||
const responsePromise = connection.sendRequest('testMethod');
|
||||
|
||||
// Verify request was sent
|
||||
await vi.waitFor(() => {
|
||||
expect(receivedChunks.length).toBeGreaterThan(0);
|
||||
});
|
||||
const request = JSON.parse(receivedChunks[0]);
|
||||
|
||||
// Simulate error response
|
||||
const response = {
|
||||
jsonrpc: '2.0',
|
||||
id: request.id,
|
||||
error: {
|
||||
code: -32000,
|
||||
message: 'Custom error',
|
||||
},
|
||||
};
|
||||
peerController.enqueue(
|
||||
new TextEncoder().encode(JSON.stringify(response) + '\n'),
|
||||
);
|
||||
|
||||
await expect(responsePromise).rejects.toMatchObject({
|
||||
code: -32000,
|
||||
message: 'Custom error',
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,231 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { z } from 'zod';
|
||||
import { coreEvents } from '@google/gemini-cli-core';
|
||||
import { type Result, type ErrorResponse } from './schema.js';
|
||||
import type { WritableStream, ReadableStream } from 'node:stream/web';
|
||||
|
||||
export class RequestError extends Error {
|
||||
data?: { details?: string };
|
||||
|
||||
constructor(
|
||||
public code: number,
|
||||
message: string,
|
||||
details?: string,
|
||||
) {
|
||||
super(message);
|
||||
this.name = 'RequestError';
|
||||
if (details) {
|
||||
this.data = { details };
|
||||
}
|
||||
}
|
||||
|
||||
static parseError(details?: string): RequestError {
|
||||
return new RequestError(-32700, 'Parse error', details);
|
||||
}
|
||||
|
||||
static invalidRequest(details?: string): RequestError {
|
||||
return new RequestError(-32600, 'Invalid request', details);
|
||||
}
|
||||
|
||||
static methodNotFound(details?: string): RequestError {
|
||||
return new RequestError(-32601, 'Method not found', details);
|
||||
}
|
||||
|
||||
static invalidParams(details?: string): RequestError {
|
||||
return new RequestError(-32602, 'Invalid params', details);
|
||||
}
|
||||
|
||||
static internalError(details?: string): RequestError {
|
||||
return new RequestError(-32603, 'Internal error', details);
|
||||
}
|
||||
|
||||
static authRequired(details?: string): RequestError {
|
||||
return new RequestError(-32000, 'Authentication required', details);
|
||||
}
|
||||
|
||||
toResult<T>(): Result<T> {
|
||||
return {
|
||||
error: {
|
||||
code: this.code,
|
||||
message: this.message,
|
||||
data: this.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
type AnyMessage = AnyRequest | AnyResponse | AnyNotification;
|
||||
|
||||
type AnyRequest = {
|
||||
jsonrpc: '2.0';
|
||||
id: string | number;
|
||||
method: string;
|
||||
params?: unknown;
|
||||
};
|
||||
|
||||
type AnyResponse = {
|
||||
jsonrpc: '2.0';
|
||||
id: string | number;
|
||||
} & Result<unknown>;
|
||||
|
||||
type AnyNotification = {
|
||||
jsonrpc: '2.0';
|
||||
method: string;
|
||||
params?: unknown;
|
||||
};
|
||||
|
||||
type PendingResponse = {
|
||||
resolve: (response: unknown) => void;
|
||||
reject: (error: ErrorResponse) => void;
|
||||
};
|
||||
|
||||
export type MethodHandler = (
|
||||
method: string,
|
||||
params: unknown,
|
||||
) => Promise<unknown>;
|
||||
|
||||
export class Connection {
|
||||
#pendingResponses: Map<string | number, PendingResponse> = new Map();
|
||||
#nextRequestId: number = 0;
|
||||
#handler: MethodHandler;
|
||||
#peerInput: WritableStream<Uint8Array>;
|
||||
#writeQueue: Promise<void> = Promise.resolve();
|
||||
#textEncoder: TextEncoder;
|
||||
|
||||
constructor(
|
||||
handler: MethodHandler,
|
||||
peerInput: WritableStream<Uint8Array>,
|
||||
peerOutput: ReadableStream<Uint8Array>,
|
||||
) {
|
||||
this.#handler = handler;
|
||||
this.#peerInput = peerInput;
|
||||
this.#textEncoder = new TextEncoder();
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
this.#receive(peerOutput);
|
||||
}
|
||||
|
||||
async #receive(output: ReadableStream<Uint8Array>) {
|
||||
let content = '';
|
||||
const decoder = new TextDecoder();
|
||||
for await (const chunk of output) {
|
||||
content += decoder.decode(chunk, { stream: true });
|
||||
const lines = content.split('\n');
|
||||
content = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmedLine = line.trim();
|
||||
|
||||
if (trimmedLine) {
|
||||
const message = JSON.parse(trimmedLine);
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
this.#processMessage(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async #processMessage(message: AnyMessage) {
|
||||
if ('method' in message && 'id' in message) {
|
||||
// It's a request
|
||||
const response = await this.#tryCallHandler(
|
||||
message.method,
|
||||
message.params,
|
||||
);
|
||||
|
||||
await this.#sendMessage({
|
||||
jsonrpc: '2.0',
|
||||
id: message.id,
|
||||
...response,
|
||||
});
|
||||
} else if ('method' in message) {
|
||||
// It's a notification
|
||||
await this.#tryCallHandler(message.method, message.params);
|
||||
} else if ('id' in message) {
|
||||
// It's a response
|
||||
this.#handleResponse(message);
|
||||
}
|
||||
}
|
||||
|
||||
async #tryCallHandler(
|
||||
method: string,
|
||||
params?: unknown,
|
||||
): Promise<Result<unknown>> {
|
||||
try {
|
||||
const result = await this.#handler(method, params);
|
||||
return { result: result ?? null };
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof RequestError) {
|
||||
return error.toResult();
|
||||
}
|
||||
|
||||
if (error instanceof z.ZodError) {
|
||||
return RequestError.invalidParams(
|
||||
JSON.stringify(error.format(), undefined, 2),
|
||||
).toResult();
|
||||
}
|
||||
|
||||
let details;
|
||||
|
||||
if (error instanceof Error) {
|
||||
details = error.message;
|
||||
} else if (
|
||||
typeof error === 'object' &&
|
||||
error != null &&
|
||||
'message' in error &&
|
||||
typeof error.message === 'string'
|
||||
) {
|
||||
details = error.message;
|
||||
}
|
||||
|
||||
return RequestError.internalError(details).toResult();
|
||||
}
|
||||
}
|
||||
|
||||
#handleResponse(response: AnyResponse) {
|
||||
const pendingResponse = this.#pendingResponses.get(response.id);
|
||||
if (pendingResponse) {
|
||||
if ('result' in response) {
|
||||
pendingResponse.resolve(response.result);
|
||||
} else if ('error' in response) {
|
||||
pendingResponse.reject(response.error);
|
||||
}
|
||||
this.#pendingResponses.delete(response.id);
|
||||
}
|
||||
}
|
||||
|
||||
async sendRequest<Req, Resp>(method: string, params?: Req): Promise<Resp> {
|
||||
const id = this.#nextRequestId++;
|
||||
const responsePromise = new Promise((resolve, reject) => {
|
||||
this.#pendingResponses.set(id, { resolve, reject });
|
||||
});
|
||||
await this.#sendMessage({ jsonrpc: '2.0', id, method, params });
|
||||
return responsePromise as Promise<Resp>;
|
||||
}
|
||||
|
||||
async sendNotification<N>(method: string, params?: N): Promise<void> {
|
||||
await this.#sendMessage({ jsonrpc: '2.0', method, params });
|
||||
}
|
||||
|
||||
async #sendMessage(json: AnyMessage) {
|
||||
const content = JSON.stringify(json) + '\n';
|
||||
this.#writeQueue = this.#writeQueue
|
||||
.then(async () => {
|
||||
const writer = this.#peerInput.getWriter();
|
||||
try {
|
||||
await writer.write(this.#textEncoder.encode(content));
|
||||
} finally {
|
||||
writer.releaseLock();
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
// Continue processing writes on error
|
||||
coreEvents.emitFeedback('error', 'ACP write error.', error);
|
||||
});
|
||||
return this.#writeQueue;
|
||||
}
|
||||
}
|
||||
@@ -6,21 +6,21 @@
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, type Mocked } from 'vitest';
|
||||
import { AcpFileSystemService } from './fileSystemService.js';
|
||||
import type { Client } from './acp.js';
|
||||
import type { AgentSideConnection } from '@agentclientprotocol/sdk';
|
||||
import type { FileSystemService } from '@google/gemini-cli-core';
|
||||
|
||||
describe('AcpFileSystemService', () => {
|
||||
let mockClient: Mocked<Client>;
|
||||
let mockConnection: Mocked<AgentSideConnection>;
|
||||
let mockFallback: Mocked<FileSystemService>;
|
||||
let service: AcpFileSystemService;
|
||||
|
||||
beforeEach(() => {
|
||||
mockClient = {
|
||||
mockConnection = {
|
||||
requestPermission: vi.fn(),
|
||||
sessionUpdate: vi.fn(),
|
||||
writeTextFile: vi.fn(),
|
||||
readTextFile: vi.fn(),
|
||||
};
|
||||
} as unknown as Mocked<AgentSideConnection>;
|
||||
mockFallback = {
|
||||
readTextFile: vi.fn(),
|
||||
writeTextFile: vi.fn(),
|
||||
@@ -31,16 +31,14 @@ describe('AcpFileSystemService', () => {
|
||||
it.each([
|
||||
{
|
||||
capability: true,
|
||||
desc: 'client if capability exists',
|
||||
desc: 'connection if capability exists',
|
||||
setup: () => {
|
||||
mockClient.readTextFile.mockResolvedValue({ content: 'content' });
|
||||
mockConnection.readTextFile.mockResolvedValue({ content: 'content' });
|
||||
},
|
||||
verify: () => {
|
||||
expect(mockClient.readTextFile).toHaveBeenCalledWith({
|
||||
expect(mockConnection.readTextFile).toHaveBeenCalledWith({
|
||||
path: '/path/to/file',
|
||||
sessionId: 'session-1',
|
||||
line: null,
|
||||
limit: null,
|
||||
});
|
||||
expect(mockFallback.readTextFile).not.toHaveBeenCalled();
|
||||
},
|
||||
@@ -55,12 +53,12 @@ describe('AcpFileSystemService', () => {
|
||||
expect(mockFallback.readTextFile).toHaveBeenCalledWith(
|
||||
'/path/to/file',
|
||||
);
|
||||
expect(mockClient.readTextFile).not.toHaveBeenCalled();
|
||||
expect(mockConnection.readTextFile).not.toHaveBeenCalled();
|
||||
},
|
||||
},
|
||||
])('should use $desc', async ({ capability, setup, verify }) => {
|
||||
service = new AcpFileSystemService(
|
||||
mockClient,
|
||||
mockConnection,
|
||||
'session-1',
|
||||
{ readTextFile: capability, writeTextFile: true },
|
||||
mockFallback,
|
||||
@@ -78,9 +76,9 @@ describe('AcpFileSystemService', () => {
|
||||
it.each([
|
||||
{
|
||||
capability: true,
|
||||
desc: 'client if capability exists',
|
||||
desc: 'connection if capability exists',
|
||||
verify: () => {
|
||||
expect(mockClient.writeTextFile).toHaveBeenCalledWith({
|
||||
expect(mockConnection.writeTextFile).toHaveBeenCalledWith({
|
||||
path: '/path/to/file',
|
||||
content: 'content',
|
||||
sessionId: 'session-1',
|
||||
@@ -96,12 +94,12 @@ describe('AcpFileSystemService', () => {
|
||||
'/path/to/file',
|
||||
'content',
|
||||
);
|
||||
expect(mockClient.writeTextFile).not.toHaveBeenCalled();
|
||||
expect(mockConnection.writeTextFile).not.toHaveBeenCalled();
|
||||
},
|
||||
},
|
||||
])('should use $desc', async ({ capability, verify }) => {
|
||||
service = new AcpFileSystemService(
|
||||
mockClient,
|
||||
mockConnection,
|
||||
'session-1',
|
||||
{ writeTextFile: capability, readTextFile: true },
|
||||
mockFallback,
|
||||
|
||||
@@ -5,14 +5,14 @@
|
||||
*/
|
||||
|
||||
import type { FileSystemService } from '@google/gemini-cli-core';
|
||||
import type * as acp from './acp.js';
|
||||
import type * as acp from '@agentclientprotocol/sdk';
|
||||
|
||||
/**
|
||||
* ACP client-based implementation of FileSystemService
|
||||
*/
|
||||
export class AcpFileSystemService implements FileSystemService {
|
||||
constructor(
|
||||
private readonly client: acp.Client,
|
||||
private readonly connection: acp.AgentSideConnection,
|
||||
private readonly sessionId: string,
|
||||
private readonly capabilities: acp.FileSystemCapability,
|
||||
private readonly fallback: FileSystemService,
|
||||
@@ -23,11 +23,9 @@ export class AcpFileSystemService implements FileSystemService {
|
||||
return this.fallback.readTextFile(filePath);
|
||||
}
|
||||
|
||||
const response = await this.client.readTextFile({
|
||||
const response = await this.connection.readTextFile({
|
||||
path: filePath,
|
||||
sessionId: this.sessionId,
|
||||
line: null,
|
||||
limit: null,
|
||||
});
|
||||
|
||||
return response.content;
|
||||
@@ -38,7 +36,7 @@ export class AcpFileSystemService implements FileSystemService {
|
||||
return this.fallback.writeTextFile(filePath, content);
|
||||
}
|
||||
|
||||
await this.client.writeTextFile({
|
||||
await this.connection.writeTextFile({
|
||||
path: filePath,
|
||||
content,
|
||||
sessionId: this.sessionId,
|
||||
|
||||
@@ -1,480 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { z } from 'zod';
|
||||
|
||||
export const AGENT_METHODS = {
|
||||
authenticate: 'authenticate',
|
||||
initialize: 'initialize',
|
||||
session_cancel: 'session/cancel',
|
||||
session_load: 'session/load',
|
||||
session_new: 'session/new',
|
||||
session_prompt: 'session/prompt',
|
||||
};
|
||||
|
||||
export const CLIENT_METHODS = {
|
||||
fs_read_text_file: 'fs/read_text_file',
|
||||
fs_write_text_file: 'fs/write_text_file',
|
||||
session_request_permission: 'session/request_permission',
|
||||
session_update: 'session/update',
|
||||
};
|
||||
|
||||
export const PROTOCOL_VERSION = 1;
|
||||
|
||||
export const authMethodSchema = z.object({
|
||||
description: z.string().nullable(),
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
});
|
||||
|
||||
export type WriteTextFileRequest = z.infer<typeof writeTextFileRequestSchema>;
|
||||
|
||||
export type ReadTextFileRequest = z.infer<typeof readTextFileRequestSchema>;
|
||||
|
||||
export type PermissionOptionKind = z.infer<typeof permissionOptionKindSchema>;
|
||||
|
||||
export type Role = z.infer<typeof roleSchema>;
|
||||
|
||||
export type TextResourceContents = z.infer<typeof textResourceContentsSchema>;
|
||||
|
||||
export type BlobResourceContents = z.infer<typeof blobResourceContentsSchema>;
|
||||
|
||||
export type ToolKind = z.infer<typeof toolKindSchema>;
|
||||
|
||||
export type ToolCallStatus = z.infer<typeof toolCallStatusSchema>;
|
||||
|
||||
export type WriteTextFileResponse = z.infer<typeof writeTextFileResponseSchema>;
|
||||
|
||||
export type ReadTextFileResponse = z.infer<typeof readTextFileResponseSchema>;
|
||||
|
||||
export type RequestPermissionOutcome = z.infer<
|
||||
typeof requestPermissionOutcomeSchema
|
||||
>;
|
||||
|
||||
export type CancelNotification = z.infer<typeof cancelNotificationSchema>;
|
||||
|
||||
export type AuthenticateRequest = z.infer<typeof authenticateRequestSchema>;
|
||||
|
||||
export type AuthenticateResponse = z.infer<typeof authenticateResponseSchema>;
|
||||
|
||||
export type NewSessionResponse = z.infer<typeof newSessionResponseSchema>;
|
||||
|
||||
export type LoadSessionResponse = z.infer<typeof loadSessionResponseSchema>;
|
||||
|
||||
export type StopReason = z.infer<typeof stopReasonSchema>;
|
||||
|
||||
export type PromptResponse = z.infer<typeof promptResponseSchema>;
|
||||
|
||||
export type ToolCallLocation = z.infer<typeof toolCallLocationSchema>;
|
||||
|
||||
export type PlanEntry = z.infer<typeof planEntrySchema>;
|
||||
|
||||
export type PermissionOption = z.infer<typeof permissionOptionSchema>;
|
||||
|
||||
export type Annotations = z.infer<typeof annotationsSchema>;
|
||||
|
||||
export type RequestPermissionResponse = z.infer<
|
||||
typeof requestPermissionResponseSchema
|
||||
>;
|
||||
|
||||
export type FileSystemCapability = z.infer<typeof fileSystemCapabilitySchema>;
|
||||
|
||||
export type EnvVariable = z.infer<typeof envVariableSchema>;
|
||||
|
||||
export type McpServer = z.infer<typeof mcpServerSchema>;
|
||||
|
||||
export type AgentCapabilities = z.infer<typeof agentCapabilitiesSchema>;
|
||||
|
||||
export type AuthMethod = z.infer<typeof authMethodSchema>;
|
||||
|
||||
export type PromptCapabilities = z.infer<typeof promptCapabilitiesSchema>;
|
||||
|
||||
export type ClientResponse = z.infer<typeof clientResponseSchema>;
|
||||
|
||||
export type ClientNotification = z.infer<typeof clientNotificationSchema>;
|
||||
|
||||
export type EmbeddedResourceResource = z.infer<
|
||||
typeof embeddedResourceResourceSchema
|
||||
>;
|
||||
|
||||
export type NewSessionRequest = z.infer<typeof newSessionRequestSchema>;
|
||||
|
||||
export type LoadSessionRequest = z.infer<typeof loadSessionRequestSchema>;
|
||||
|
||||
export type InitializeResponse = z.infer<typeof initializeResponseSchema>;
|
||||
|
||||
export type ContentBlock = z.infer<typeof contentBlockSchema>;
|
||||
|
||||
export type ToolCallContent = z.infer<typeof toolCallContentSchema>;
|
||||
|
||||
export type ToolCall = z.infer<typeof toolCallSchema>;
|
||||
|
||||
export type ClientCapabilities = z.infer<typeof clientCapabilitiesSchema>;
|
||||
|
||||
export type PromptRequest = z.infer<typeof promptRequestSchema>;
|
||||
|
||||
export type SessionUpdate = z.infer<typeof sessionUpdateSchema>;
|
||||
|
||||
export type AgentResponse = z.infer<typeof agentResponseSchema>;
|
||||
|
||||
export type RequestPermissionRequest = z.infer<
|
||||
typeof requestPermissionRequestSchema
|
||||
>;
|
||||
|
||||
export type InitializeRequest = z.infer<typeof initializeRequestSchema>;
|
||||
|
||||
export type SessionNotification = z.infer<typeof sessionNotificationSchema>;
|
||||
|
||||
export type ClientRequest = z.infer<typeof clientRequestSchema>;
|
||||
|
||||
export type AgentRequest = z.infer<typeof agentRequestSchema>;
|
||||
|
||||
export type AgentNotification = z.infer<typeof agentNotificationSchema>;
|
||||
|
||||
export type Result<T> =
|
||||
| {
|
||||
result: T;
|
||||
}
|
||||
| {
|
||||
error: ErrorResponse;
|
||||
};
|
||||
|
||||
export type ErrorResponse = {
|
||||
code: number;
|
||||
message: string;
|
||||
data?: unknown;
|
||||
};
|
||||
|
||||
export const writeTextFileRequestSchema = z.object({
|
||||
content: z.string(),
|
||||
path: z.string(),
|
||||
sessionId: z.string(),
|
||||
});
|
||||
|
||||
export const readTextFileRequestSchema = z.object({
|
||||
limit: z.number().optional().nullable(),
|
||||
line: z.number().optional().nullable(),
|
||||
path: z.string(),
|
||||
sessionId: z.string(),
|
||||
});
|
||||
|
||||
export const permissionOptionKindSchema = z.union([
|
||||
z.literal('allow_once'),
|
||||
z.literal('allow_always'),
|
||||
z.literal('reject_once'),
|
||||
z.literal('reject_always'),
|
||||
]);
|
||||
|
||||
export const roleSchema = z.union([z.literal('assistant'), z.literal('user')]);
|
||||
|
||||
export const textResourceContentsSchema = z.object({
|
||||
mimeType: z.string().optional().nullable(),
|
||||
text: z.string(),
|
||||
uri: z.string(),
|
||||
});
|
||||
|
||||
export const blobResourceContentsSchema = z.object({
|
||||
blob: z.string(),
|
||||
mimeType: z.string().optional().nullable(),
|
||||
uri: z.string(),
|
||||
});
|
||||
|
||||
export const toolKindSchema = z.union([
|
||||
z.literal('read'),
|
||||
z.literal('edit'),
|
||||
z.literal('delete'),
|
||||
z.literal('move'),
|
||||
z.literal('search'),
|
||||
z.literal('execute'),
|
||||
z.literal('think'),
|
||||
z.literal('fetch'),
|
||||
z.literal('other'),
|
||||
]);
|
||||
|
||||
export const toolCallStatusSchema = z.union([
|
||||
z.literal('pending'),
|
||||
z.literal('in_progress'),
|
||||
z.literal('completed'),
|
||||
z.literal('failed'),
|
||||
]);
|
||||
|
||||
export const writeTextFileResponseSchema = z.null();
|
||||
|
||||
export const readTextFileResponseSchema = z.object({
|
||||
content: z.string(),
|
||||
});
|
||||
|
||||
export const requestPermissionOutcomeSchema = z.union([
|
||||
z.object({
|
||||
outcome: z.literal('cancelled'),
|
||||
}),
|
||||
z.object({
|
||||
optionId: z.string(),
|
||||
outcome: z.literal('selected'),
|
||||
}),
|
||||
]);
|
||||
|
||||
export const cancelNotificationSchema = z.object({
|
||||
sessionId: z.string(),
|
||||
});
|
||||
|
||||
export const authenticateRequestSchema = z.object({
|
||||
methodId: z.string(),
|
||||
});
|
||||
|
||||
export const authenticateResponseSchema = z.null();
|
||||
|
||||
export const newSessionResponseSchema = z.object({
|
||||
sessionId: z.string(),
|
||||
});
|
||||
|
||||
export const loadSessionResponseSchema = z.null();
|
||||
|
||||
export const stopReasonSchema = z.union([
|
||||
z.literal('end_turn'),
|
||||
z.literal('max_tokens'),
|
||||
z.literal('refusal'),
|
||||
z.literal('cancelled'),
|
||||
]);
|
||||
|
||||
export const promptResponseSchema = z.object({
|
||||
stopReason: stopReasonSchema,
|
||||
});
|
||||
|
||||
export const toolCallLocationSchema = z.object({
|
||||
line: z.number().optional().nullable(),
|
||||
path: z.string(),
|
||||
});
|
||||
|
||||
export const planEntrySchema = z.object({
|
||||
content: z.string(),
|
||||
priority: z.union([z.literal('high'), z.literal('medium'), z.literal('low')]),
|
||||
status: z.union([
|
||||
z.literal('pending'),
|
||||
z.literal('in_progress'),
|
||||
z.literal('completed'),
|
||||
]),
|
||||
});
|
||||
|
||||
export const permissionOptionSchema = z.object({
|
||||
kind: permissionOptionKindSchema,
|
||||
name: z.string(),
|
||||
optionId: z.string(),
|
||||
});
|
||||
|
||||
export const annotationsSchema = z.object({
|
||||
audience: z.array(roleSchema).optional().nullable(),
|
||||
lastModified: z.string().optional().nullable(),
|
||||
priority: z.number().optional().nullable(),
|
||||
});
|
||||
|
||||
export const requestPermissionResponseSchema = z.object({
|
||||
outcome: requestPermissionOutcomeSchema,
|
||||
});
|
||||
|
||||
export const fileSystemCapabilitySchema = z.object({
|
||||
readTextFile: z.boolean(),
|
||||
writeTextFile: z.boolean(),
|
||||
});
|
||||
|
||||
export const envVariableSchema = z.object({
|
||||
name: z.string(),
|
||||
value: z.string(),
|
||||
});
|
||||
|
||||
export const mcpServerSchema = z.object({
|
||||
args: z.array(z.string()),
|
||||
command: z.string(),
|
||||
env: z.array(envVariableSchema),
|
||||
name: z.string(),
|
||||
});
|
||||
|
||||
export const promptCapabilitiesSchema = z.object({
|
||||
audio: z.boolean().optional(),
|
||||
embeddedContext: z.boolean().optional(),
|
||||
image: z.boolean().optional(),
|
||||
});
|
||||
|
||||
export const agentCapabilitiesSchema = z.object({
|
||||
loadSession: z.boolean().optional(),
|
||||
promptCapabilities: promptCapabilitiesSchema.optional(),
|
||||
});
|
||||
|
||||
export const clientResponseSchema = z.union([
|
||||
writeTextFileResponseSchema,
|
||||
readTextFileResponseSchema,
|
||||
requestPermissionResponseSchema,
|
||||
]);
|
||||
|
||||
export const clientNotificationSchema = cancelNotificationSchema;
|
||||
|
||||
export const embeddedResourceResourceSchema = z.union([
|
||||
textResourceContentsSchema,
|
||||
blobResourceContentsSchema,
|
||||
]);
|
||||
|
||||
export const newSessionRequestSchema = z.object({
|
||||
cwd: z.string(),
|
||||
mcpServers: z.array(mcpServerSchema),
|
||||
});
|
||||
|
||||
export const loadSessionRequestSchema = z.object({
|
||||
cwd: z.string(),
|
||||
mcpServers: z.array(mcpServerSchema),
|
||||
sessionId: z.string(),
|
||||
});
|
||||
|
||||
export const initializeResponseSchema = z.object({
|
||||
agentCapabilities: agentCapabilitiesSchema,
|
||||
authMethods: z.array(authMethodSchema),
|
||||
protocolVersion: z.number(),
|
||||
});
|
||||
|
||||
export const contentBlockSchema = z.union([
|
||||
z.object({
|
||||
annotations: annotationsSchema.optional().nullable(),
|
||||
text: z.string(),
|
||||
type: z.literal('text'),
|
||||
}),
|
||||
z.object({
|
||||
annotations: annotationsSchema.optional().nullable(),
|
||||
data: z.string(),
|
||||
mimeType: z.string(),
|
||||
type: z.literal('image'),
|
||||
}),
|
||||
z.object({
|
||||
annotations: annotationsSchema.optional().nullable(),
|
||||
data: z.string(),
|
||||
mimeType: z.string(),
|
||||
type: z.literal('audio'),
|
||||
}),
|
||||
z.object({
|
||||
annotations: annotationsSchema.optional().nullable(),
|
||||
description: z.string().optional().nullable(),
|
||||
mimeType: z.string().optional().nullable(),
|
||||
name: z.string(),
|
||||
size: z.number().optional().nullable(),
|
||||
title: z.string().optional().nullable(),
|
||||
type: z.literal('resource_link'),
|
||||
uri: z.string(),
|
||||
}),
|
||||
z.object({
|
||||
annotations: annotationsSchema.optional().nullable(),
|
||||
resource: embeddedResourceResourceSchema,
|
||||
type: z.literal('resource'),
|
||||
}),
|
||||
]);
|
||||
|
||||
export const toolCallContentSchema = z.union([
|
||||
z.object({
|
||||
content: contentBlockSchema,
|
||||
type: z.literal('content'),
|
||||
}),
|
||||
z.object({
|
||||
newText: z.string(),
|
||||
oldText: z.string().nullable(),
|
||||
path: z.string(),
|
||||
type: z.literal('diff'),
|
||||
}),
|
||||
]);
|
||||
|
||||
export const toolCallSchema = z.object({
|
||||
content: z.array(toolCallContentSchema).optional(),
|
||||
kind: toolKindSchema,
|
||||
locations: z.array(toolCallLocationSchema).optional(),
|
||||
rawInput: z.unknown().optional(),
|
||||
status: toolCallStatusSchema,
|
||||
title: z.string(),
|
||||
toolCallId: z.string(),
|
||||
});
|
||||
|
||||
export const clientCapabilitiesSchema = z.object({
|
||||
fs: fileSystemCapabilitySchema,
|
||||
});
|
||||
|
||||
export const promptRequestSchema = z.object({
|
||||
prompt: z.array(contentBlockSchema),
|
||||
sessionId: z.string(),
|
||||
});
|
||||
|
||||
export const sessionUpdateSchema = z.union([
|
||||
z.object({
|
||||
content: contentBlockSchema,
|
||||
sessionUpdate: z.literal('user_message_chunk'),
|
||||
}),
|
||||
z.object({
|
||||
content: contentBlockSchema,
|
||||
sessionUpdate: z.literal('agent_message_chunk'),
|
||||
}),
|
||||
z.object({
|
||||
content: contentBlockSchema,
|
||||
sessionUpdate: z.literal('agent_thought_chunk'),
|
||||
}),
|
||||
z.object({
|
||||
content: z.array(toolCallContentSchema).optional(),
|
||||
kind: toolKindSchema,
|
||||
locations: z.array(toolCallLocationSchema).optional(),
|
||||
rawInput: z.unknown().optional(),
|
||||
sessionUpdate: z.literal('tool_call'),
|
||||
status: toolCallStatusSchema,
|
||||
title: z.string(),
|
||||
toolCallId: z.string(),
|
||||
}),
|
||||
z.object({
|
||||
content: z.array(toolCallContentSchema).optional().nullable(),
|
||||
kind: toolKindSchema.optional().nullable(),
|
||||
locations: z.array(toolCallLocationSchema).optional().nullable(),
|
||||
rawInput: z.unknown().optional(),
|
||||
sessionUpdate: z.literal('tool_call_update'),
|
||||
status: toolCallStatusSchema.optional().nullable(),
|
||||
title: z.string().optional().nullable(),
|
||||
toolCallId: z.string(),
|
||||
}),
|
||||
z.object({
|
||||
entries: z.array(planEntrySchema),
|
||||
sessionUpdate: z.literal('plan'),
|
||||
}),
|
||||
]);
|
||||
|
||||
export const agentResponseSchema = z.union([
|
||||
initializeResponseSchema,
|
||||
authenticateResponseSchema,
|
||||
newSessionResponseSchema,
|
||||
loadSessionResponseSchema,
|
||||
promptResponseSchema,
|
||||
]);
|
||||
|
||||
export const requestPermissionRequestSchema = z.object({
|
||||
options: z.array(permissionOptionSchema),
|
||||
sessionId: z.string(),
|
||||
toolCall: toolCallSchema,
|
||||
});
|
||||
|
||||
export const initializeRequestSchema = z.object({
|
||||
clientCapabilities: clientCapabilitiesSchema,
|
||||
protocolVersion: z.number(),
|
||||
});
|
||||
|
||||
export const sessionNotificationSchema = z.object({
|
||||
sessionId: z.string(),
|
||||
update: sessionUpdateSchema,
|
||||
});
|
||||
|
||||
export const clientRequestSchema = z.union([
|
||||
writeTextFileRequestSchema,
|
||||
readTextFileRequestSchema,
|
||||
requestPermissionRequestSchema,
|
||||
]);
|
||||
|
||||
export const agentRequestSchema = z.union([
|
||||
initializeRequestSchema,
|
||||
authenticateRequestSchema,
|
||||
newSessionRequestSchema,
|
||||
loadSessionRequestSchema,
|
||||
promptRequestSchema,
|
||||
]);
|
||||
|
||||
export const agentNotificationSchema = sessionNotificationSchema;
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
type Mocked,
|
||||
} from 'vitest';
|
||||
import { GeminiAgent, Session } from './zedIntegration.js';
|
||||
import * as acp from './acp.js';
|
||||
import * as acp from '@agentclientprotocol/sdk';
|
||||
import {
|
||||
AuthType,
|
||||
ToolConfirmationOutcome,
|
||||
@@ -85,7 +85,7 @@ describe('GeminiAgent', () => {
|
||||
let mockConfig: Mocked<Awaited<ReturnType<typeof loadCliConfig>>>;
|
||||
let mockSettings: Mocked<LoadedSettings>;
|
||||
let mockArgv: CliArgs;
|
||||
let mockClient: Mocked<acp.Client>;
|
||||
let mockConnection: Mocked<acp.AgentSideConnection>;
|
||||
let agent: GeminiAgent;
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -106,13 +106,13 @@ describe('GeminiAgent', () => {
|
||||
setValue: vi.fn(),
|
||||
} as unknown as Mocked<LoadedSettings>;
|
||||
mockArgv = {} as unknown as CliArgs;
|
||||
mockClient = {
|
||||
mockConnection = {
|
||||
sessionUpdate: vi.fn(),
|
||||
} as unknown as Mocked<acp.Client>;
|
||||
} as unknown as Mocked<acp.AgentSideConnection>;
|
||||
|
||||
(loadCliConfig as unknown as Mock).mockResolvedValue(mockConfig);
|
||||
|
||||
agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockClient);
|
||||
agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockConnection);
|
||||
});
|
||||
|
||||
it('should initialize correctly', async () => {
|
||||
@@ -123,7 +123,7 @@ describe('GeminiAgent', () => {
|
||||
|
||||
expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION);
|
||||
expect(response.authMethods).toHaveLength(3);
|
||||
expect(response.agentCapabilities.loadSession).toBe(false);
|
||||
expect(response.agentCapabilities?.loadSession).toBe(false);
|
||||
});
|
||||
|
||||
it('should authenticate correctly', async () => {
|
||||
@@ -202,7 +202,7 @@ describe('GeminiAgent', () => {
|
||||
});
|
||||
|
||||
it('should initialize file system service if client supports it', async () => {
|
||||
agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockClient);
|
||||
agent = new GeminiAgent(mockConfig, mockSettings, mockArgv, mockConnection);
|
||||
await agent.initialize({
|
||||
clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } },
|
||||
protocolVersion: 1,
|
||||
@@ -257,7 +257,7 @@ describe('GeminiAgent', () => {
|
||||
describe('Session', () => {
|
||||
let mockChat: Mocked<GeminiChat>;
|
||||
let mockConfig: Mocked<Config>;
|
||||
let mockClient: Mocked<acp.Client>;
|
||||
let mockConnection: Mocked<acp.AgentSideConnection>;
|
||||
let session: Session;
|
||||
let mockToolRegistry: { getTool: Mock };
|
||||
let mockTool: { kind: string; build: Mock };
|
||||
@@ -292,13 +292,13 @@ describe('Session', () => {
|
||||
getEnableRecursiveFileSearch: vi.fn().mockReturnValue(false),
|
||||
getDebugMode: vi.fn().mockReturnValue(false),
|
||||
} as unknown as Mocked<Config>;
|
||||
mockClient = {
|
||||
mockConnection = {
|
||||
sessionUpdate: vi.fn(),
|
||||
requestPermission: vi.fn(),
|
||||
sendNotification: vi.fn(),
|
||||
} as unknown as Mocked<acp.Client>;
|
||||
} as unknown as Mocked<acp.AgentSideConnection>;
|
||||
|
||||
session = new Session('session-1', mockChat, mockConfig, mockClient);
|
||||
session = new Session('session-1', mockChat, mockConfig, mockConnection);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -322,7 +322,7 @@ describe('Session', () => {
|
||||
});
|
||||
|
||||
expect(mockChat.sendMessageStream).toHaveBeenCalled();
|
||||
expect(mockClient.sessionUpdate).toHaveBeenCalledWith({
|
||||
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith({
|
||||
sessionId: 'session-1',
|
||||
update: {
|
||||
sessionUpdate: 'agent_message_chunk',
|
||||
@@ -361,7 +361,7 @@ describe('Session', () => {
|
||||
|
||||
expect(mockToolRegistry.getTool).toHaveBeenCalledWith('test_tool');
|
||||
expect(mockTool.build).toHaveBeenCalledWith({ foo: 'bar' });
|
||||
expect(mockClient.sessionUpdate).toHaveBeenCalledWith(
|
||||
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
update: expect.objectContaining({
|
||||
sessionUpdate: 'tool_call',
|
||||
@@ -369,7 +369,7 @@ describe('Session', () => {
|
||||
}),
|
||||
}),
|
||||
);
|
||||
expect(mockClient.sessionUpdate).toHaveBeenCalledWith(
|
||||
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
update: expect.objectContaining({
|
||||
sessionUpdate: 'tool_call_update',
|
||||
@@ -392,7 +392,7 @@ describe('Session', () => {
|
||||
execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }),
|
||||
});
|
||||
|
||||
mockClient.requestPermission.mockResolvedValue({
|
||||
mockConnection.requestPermission.mockResolvedValue({
|
||||
outcome: {
|
||||
outcome: 'selected',
|
||||
optionId: ToolConfirmationOutcome.ProceedOnce,
|
||||
@@ -423,7 +423,7 @@ describe('Session', () => {
|
||||
prompt: [{ type: 'text', text: 'Call tool' }],
|
||||
});
|
||||
|
||||
expect(mockClient.requestPermission).toHaveBeenCalled();
|
||||
expect(mockConnection.requestPermission).toHaveBeenCalled();
|
||||
expect(confirmationDetails.onConfirm).toHaveBeenCalledWith(
|
||||
ToolConfirmationOutcome.ProceedOnce,
|
||||
);
|
||||
@@ -441,7 +441,7 @@ describe('Session', () => {
|
||||
execute: vi.fn().mockResolvedValue({ llmContent: 'Tool Result' }),
|
||||
});
|
||||
|
||||
mockClient.requestPermission.mockResolvedValue({
|
||||
mockConnection.requestPermission.mockResolvedValue({
|
||||
outcome: { outcome: 'cancelled' },
|
||||
});
|
||||
|
||||
@@ -635,7 +635,7 @@ describe('Session', () => {
|
||||
prompt: [{ type: 'text', text: 'Call tool' }],
|
||||
});
|
||||
|
||||
expect(mockClient.sessionUpdate).toHaveBeenCalledWith(
|
||||
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
update: expect.objectContaining({
|
||||
sessionUpdate: 'tool_call_update',
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ReadableStream } from 'node:stream/web';
|
||||
|
||||
import type {
|
||||
Config,
|
||||
GeminiChat,
|
||||
@@ -33,7 +31,7 @@ import {
|
||||
createWorkingStdio,
|
||||
startupProfiler,
|
||||
} from '@google/gemini-cli-core';
|
||||
import * as acp from './acp.js';
|
||||
import * as acp from '@agentclientprotocol/sdk';
|
||||
import { AcpFileSystemService } from './fileSystemService.js';
|
||||
import { Readable, Writable } from 'node:stream';
|
||||
import type { Content, Part, FunctionCall } from '@google/genai';
|
||||
@@ -53,13 +51,13 @@ export async function runZedIntegration(
|
||||
argv: CliArgs,
|
||||
) {
|
||||
const { stdout: workingStdout } = createWorkingStdio();
|
||||
const stdout = Writable.toWeb(workingStdout);
|
||||
const stdout = Writable.toWeb(workingStdout) as WritableStream;
|
||||
const stdin = Readable.toWeb(process.stdin) as ReadableStream<Uint8Array>;
|
||||
|
||||
const stream = acp.ndJsonStream(stdout, stdin);
|
||||
new acp.AgentSideConnection(
|
||||
(client: acp.Client) => new GeminiAgent(config, settings, argv, client),
|
||||
stdout,
|
||||
stdin,
|
||||
(connection) => new GeminiAgent(config, settings, argv, connection),
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -71,7 +69,7 @@ export class GeminiAgent {
|
||||
private config: Config,
|
||||
private settings: LoadedSettings,
|
||||
private argv: CliArgs,
|
||||
private client: acp.Client,
|
||||
private connection: acp.AgentSideConnection,
|
||||
) {}
|
||||
|
||||
async initialize(
|
||||
@@ -107,6 +105,10 @@ export class GeminiAgent {
|
||||
audio: true,
|
||||
embeddedContext: true,
|
||||
},
|
||||
mcpCapabilities: {
|
||||
http: true,
|
||||
sse: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -156,7 +158,7 @@ export class GeminiAgent {
|
||||
|
||||
if (this.clientCapabilities?.fs) {
|
||||
const acpFileSystemService = new AcpFileSystemService(
|
||||
this.client,
|
||||
this.connection,
|
||||
sessionId,
|
||||
this.clientCapabilities.fs,
|
||||
config.getFileSystemService(),
|
||||
@@ -166,7 +168,7 @@ export class GeminiAgent {
|
||||
|
||||
const geminiClient = config.getGeminiClient();
|
||||
const chat = await geminiClient.startChat();
|
||||
const session = new Session(sessionId, chat, config, this.client);
|
||||
const session = new Session(sessionId, chat, config, this.connection);
|
||||
this.sessions.set(sessionId, session);
|
||||
|
||||
return {
|
||||
@@ -181,12 +183,37 @@ export class GeminiAgent {
|
||||
): Promise<Config> {
|
||||
const mergedMcpServers = { ...this.settings.merged.mcpServers };
|
||||
|
||||
for (const { command, args, env: rawEnv, name } of mcpServers) {
|
||||
const env: Record<string, string> = {};
|
||||
for (const { name: envName, value } of rawEnv) {
|
||||
env[envName] = value;
|
||||
for (const server of mcpServers) {
|
||||
if (
|
||||
'type' in server &&
|
||||
(server.type === 'sse' || server.type === 'http')
|
||||
) {
|
||||
// HTTP or SSE MCP server
|
||||
const headers = Object.fromEntries(
|
||||
server.headers.map(({ name, value }) => [name, value]),
|
||||
);
|
||||
mergedMcpServers[server.name] = new MCPServerConfig(
|
||||
undefined, // command
|
||||
undefined, // args
|
||||
undefined, // env
|
||||
undefined, // cwd
|
||||
server.type === 'sse' ? server.url : undefined, // url (sse)
|
||||
server.type === 'http' ? server.url : undefined, // httpUrl
|
||||
headers,
|
||||
);
|
||||
} else if ('command' in server) {
|
||||
// Stdio MCP server
|
||||
const env: Record<string, string> = {};
|
||||
for (const { name: envName, value } of server.env) {
|
||||
env[envName] = value;
|
||||
}
|
||||
mergedMcpServers[server.name] = new MCPServerConfig(
|
||||
server.command,
|
||||
server.args,
|
||||
env,
|
||||
cwd,
|
||||
);
|
||||
}
|
||||
mergedMcpServers[name] = new MCPServerConfig(command, args, env, cwd);
|
||||
}
|
||||
|
||||
const settings = { ...this.settings.merged, mcpServers: mergedMcpServers };
|
||||
@@ -222,7 +249,7 @@ export class Session {
|
||||
private readonly id: string,
|
||||
private readonly chat: GeminiChat,
|
||||
private readonly config: Config,
|
||||
private readonly client: acp.Client,
|
||||
private readonly connection: acp.AgentSideConnection,
|
||||
) {}
|
||||
|
||||
async cancelPendingPrompt(): Promise<void> {
|
||||
@@ -340,13 +367,15 @@ export class Session {
|
||||
return { stopReason: 'end_turn' };
|
||||
}
|
||||
|
||||
private async sendUpdate(update: acp.SessionUpdate): Promise<void> {
|
||||
private async sendUpdate(
|
||||
update: acp.SessionNotification['update'],
|
||||
): Promise<void> {
|
||||
const params: acp.SessionNotification = {
|
||||
sessionId: this.id,
|
||||
update,
|
||||
};
|
||||
|
||||
await this.client.sessionUpdate(params);
|
||||
await this.connection.sessionUpdate(params);
|
||||
}
|
||||
|
||||
private async runTool(
|
||||
@@ -432,7 +461,7 @@ export class Session {
|
||||
},
|
||||
};
|
||||
|
||||
const output = await this.client.requestPermission(params);
|
||||
const output = await this.connection.requestPermission(params);
|
||||
const outcome =
|
||||
output.outcome.outcome === 'cancelled'
|
||||
? ToolConfirmationOutcome.Cancel
|
||||
|
||||
Reference in New Issue
Block a user