diff --git a/packages/core/src/ide/ide-client.ts b/packages/core/src/ide/ide-client.ts index 7c4b50d8ba..c8242c02cc 100644 --- a/packages/core/src/ide/ide-client.ts +++ b/packages/core/src/ide/ide-client.ts @@ -7,18 +7,18 @@ import * as fs from 'node:fs'; import { isSubpath } from '../utils/paths.js'; import { detectIde, type DetectedIde, getIdeInfo } from '../ide/detect-ide.js'; +import { ideContextStore } from './ideContext.js'; import { - ideContextStore, + IdeContextNotificationSchema, IdeDiffAcceptedNotificationSchema, IdeDiffClosedNotificationSchema, - CloseDiffResponseSchema, - type DiffUpdateResult, -} from './ideContext.js'; -import { IdeContextNotificationSchema } from './types.js'; + IdeDiffRejectedNotificationSchema, +} from './types.js'; import { getIdeProcessInfo } from './process-utils.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; +import { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; import * as os from 'node:os'; import * as path from 'node:path'; import { EnvHttpProxyAgent } from 'undici'; @@ -31,6 +31,16 @@ const logger = { error: (...args: any[]) => console.error('[ERROR] [IDEClient]', ...args), }; +export type DiffUpdateResult = + | { + status: 'accepted'; + content?: string; + } + | { + status: 'rejected'; + content: undefined; + }; + export type IDEConnectionState = { status: IDEConnectionStatus; details?: string; // User-facing @@ -193,22 +203,26 @@ export class IdeClient { } /** - * A diff is accepted with any modifications if the user performs one of the - * following actions: - * - Clicks the checkbox icon in the IDE to accept - * - Runs `command+shift+p` > "Gemini CLI: Accept Diff in IDE" to accept - * - Selects "accept" in the CLI UI - * - Saves the file via `ctrl/command+s` + * Opens a diff view in the IDE, allowing the user to review and accept or + * reject changes. * - * A diff is rejected if the user performs one of the following actions: - * - Clicks the "x" icon in the IDE - * - Runs "Gemini CLI: Close Diff in IDE" - * - Selects "no" in the CLI UI - * - Closes the file + * This method sends a request to the IDE to display a diff between the + * current content of a file and the new content provided. It then waits for + * a notification from the IDE indicating that the user has either accepted + * (potentially with manual edits) or rejected the diff. + * + * A mutex ensures that only one diff view can be open at a time to prevent + * race conditions. + * + * @param filePath The absolute path to the file to be diffed. + * @param newContent The proposed new content for the file. + * @returns A promise that resolves with a `DiffUpdateResult`, indicating + * whether the diff was 'accepted' or 'rejected' and including the final + * content if accepted. */ async openDiff( filePath: string, - newContent?: string, + newContent: string, ): Promise { const release = await this.acquireMutex(); @@ -226,6 +240,30 @@ export class IdeClient { newContent, }, }) + .then((result) => { + const parsedResult = CallToolResultSchema.safeParse(result); + if (!parsedResult.success) { + const err = new Error('Failed to parse tool result from IDE'); + logger.debug(err, parsedResult.error); + this.diffResponses.delete(filePath); + reject(err); + return; + } + + if (parsedResult.data.isError) { + const textPart = parsedResult.data.content.find( + (part) => part.type === 'text', + ); + const errorMessage = + textPart?.text ?? `Tool 'openDiff' reported an error.`; + logger.debug( + `callTool for ${filePath} failed with isError:`, + errorMessage, + ); + this.diffResponses.delete(filePath); + reject(new Error(errorMessage)); + } + }) .catch((err) => { logger.debug(`callTool for ${filePath} failed:`, err); this.diffResponses.delete(filePath); @@ -279,14 +317,56 @@ export class IdeClient { }, }); - if (result) { - const parsed = CloseDiffResponseSchema.parse(result); - return parsed.content; + if (!result) { + return undefined; + } + + const parsedResult = CallToolResultSchema.safeParse(result); + if (!parsedResult.success) { + logger.debug( + `Failed to parse tool result from IDE for closeDiff:`, + parsedResult.error, + ); + return undefined; + } + + if (parsedResult.data.isError) { + const textPart = parsedResult.data.content.find( + (part) => part.type === 'text', + ); + const errorMessage = + textPart?.text ?? `Tool 'closeDiff' reported an error.`; + logger.debug( + `callTool for closeDiff ${filePath} failed with isError:`, + errorMessage, + ); + return undefined; + } + + const textPart = parsedResult.data.content.find( + (part) => part.type === 'text', + ); + + if (textPart?.text) { + try { + const parsedJson = JSON.parse(textPart.text); + if (parsedJson && typeof parsedJson.content === 'string') { + return parsedJson.content; + } + if (parsedJson && parsedJson.content === null) { + return undefined; + } + } catch (_e) { + logger.debug( + `Invalid JSON in closeDiff response for ${filePath}:`, + textPart.text, + ); + } } } catch (err) { - logger.debug(`callTool for ${filePath} failed:`, err); + logger.debug(`callTool for closeDiff ${filePath} failed:`, err); } - return; + return undefined; } // Closes the diff. Instead of waiting for a notification, @@ -648,6 +728,22 @@ export class IdeClient { }, ); + this.client.setNotificationHandler( + IdeDiffRejectedNotificationSchema, + (notification) => { + const { filePath } = notification.params; + const resolver = this.diffResponses.get(filePath); + if (resolver) { + resolver({ status: 'rejected', content: undefined }); + this.diffResponses.delete(filePath); + } else { + logger.debug(`No resolver found for ${filePath}`); + } + }, + ); + + // For backwards compatability. Newer extension versions will only send + // IdeDiffRejectedNotificationSchema. this.client.setNotificationHandler( IdeDiffClosedNotificationSchema, (notification) => { diff --git a/packages/core/src/ide/ideContext.ts b/packages/core/src/ide/ideContext.ts index cb57b99f4b..527d4135cc 100644 --- a/packages/core/src/ide/ideContext.ts +++ b/packages/core/src/ide/ideContext.ts @@ -4,71 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { z } from 'zod'; import { IDE_MAX_OPEN_FILES, IDE_MAX_SELECTED_TEXT_LENGTH, } from './constants.js'; import type { IdeContext } from './types.js'; -export const IdeDiffAcceptedNotificationSchema = z.object({ - jsonrpc: z.literal('2.0'), - method: z.literal('ide/diffAccepted'), - params: z.object({ - filePath: z.string(), - content: z.string(), - }), -}); - -export const IdeDiffClosedNotificationSchema = z.object({ - jsonrpc: z.literal('2.0'), - method: z.literal('ide/diffClosed'), - params: z.object({ - filePath: z.string(), - content: z.string().optional(), - }), -}); - -export const CloseDiffResponseSchema = z - .object({ - content: z - .array( - z.object({ - text: z.string(), - type: z.literal('text'), - }), - ) - .min(1), - }) - .transform((val, ctx) => { - try { - const parsed = JSON.parse(val.content[0].text); - const innerSchema = z.object({ content: z.string().optional() }); - const validationResult = innerSchema.safeParse(parsed); - if (!validationResult.success) { - validationResult.error.issues.forEach((issue) => ctx.addIssue(issue)); - return z.NEVER; - } - return validationResult.data; - } catch (_) { - ctx.addIssue({ - code: z.ZodIssueCode.custom, - message: 'Invalid JSON in text content', - }); - return z.NEVER; - } - }); - -export type DiffUpdateResult = - | { - status: 'accepted'; - content?: string; - } - | { - status: 'rejected'; - content: undefined; - }; - type IdeContextSubscriber = (ideContext?: IdeContext) => void; export class IdeContextStore { diff --git a/packages/core/src/ide/types.ts b/packages/core/src/ide/types.ts index 69310eea65..f94132db59 100644 --- a/packages/core/src/ide/types.ts +++ b/packages/core/src/ide/types.ts @@ -63,8 +63,86 @@ export const IdeContextSchema = z.object({ }); export type IdeContext = z.infer; +/** + * A notification that the IDE context has been updated. + */ export const IdeContextNotificationSchema = z.object({ jsonrpc: z.literal('2.0'), method: z.literal('ide/contextUpdate'), params: IdeContextSchema, }); + +/** + * A notification that a diff has been accepted in the IDE. + */ +export const IdeDiffAcceptedNotificationSchema = z.object({ + jsonrpc: z.literal('2.0'), + method: z.literal('ide/diffAccepted'), + params: z.object({ + /** + * The absolute path to the file that was diffed. + */ + filePath: z.string(), + /** + * The full content of the file after the diff was accepted, which includes any manual edits the user may have made. + */ + content: z.string(), + }), +}); + +/** + * A notification that a diff has been rejected in the IDE. + */ +export const IdeDiffRejectedNotificationSchema = z.object({ + jsonrpc: z.literal('2.0'), + method: z.literal('ide/diffRejected'), + params: z.object({ + /** + * The absolute path to the file that was diffed. + */ + filePath: z.string(), + }), +}); + +/** + * This is defineded for backwards compatability only. Newer extension versions + * will only send IdeDiffRejectedNotificationSchema. + * + * A notification that a diff has been closed in the IDE. + */ +export const IdeDiffClosedNotificationSchema = z.object({ + jsonrpc: z.literal('2.0'), + method: z.literal('ide/diffClosed'), + params: z.object({ + filePath: z.string(), + content: z.string().optional(), + }), +}); + +/** + * The request to open a diff view in the IDE. + */ +export const OpenDiffRequestSchema = z.object({ + /** + * The absolute path to the file to be diffed. + */ + filePath: z.string(), + /** + * The proposed new content for the file. + */ + newContent: z.string(), +}); + +/** + * The request to close a diff view in the IDE. + */ +export const CloseDiffRequestSchema = z.object({ + /** + * The absolute path to the file to be diffed. + */ + filePath: z.string(), + /** + * @deprecated + */ + suppressNotification: z.boolean().optional(), +}); diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 6029b9f8d2..5c51dba755 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -6,7 +6,7 @@ import type { FunctionDeclaration, PartListUnion } from '@google/genai'; import { ToolErrorType } from './tool-error.js'; -import type { DiffUpdateResult } from '../ide/ideContext.js'; +import type { DiffUpdateResult } from '../ide/ide-client.js'; import type { ShellExecutionConfig } from '../services/shellExecutionService.js'; import { SchemaValidator } from '../utils/schemaValidator.js'; import type { AnsiOutput } from '../utils/terminalSerializer.js'; diff --git a/packages/core/src/tools/write-file.test.ts b/packages/core/src/tools/write-file.test.ts index adb6ebf751..5693978971 100644 --- a/packages/core/src/tools/write-file.test.ts +++ b/packages/core/src/tools/write-file.test.ts @@ -33,8 +33,8 @@ import { } from '../utils/editCorrector.js'; import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; import { StandardFileSystemService } from '../services/fileSystemService.js'; +import type { DiffUpdateResult } from '../ide/ide-client.js'; import { IdeClient } from '../ide/ide-client.js'; -import type { DiffUpdateResult } from '../ide/ideContext.js'; const rootDir = path.resolve(os.tmpdir(), 'gemini-cli-test-root'); diff --git a/packages/vscode-ide-companion/src/ide-server.ts b/packages/vscode-ide-companion/src/ide-server.ts index 32d56a893a..fe86b20f5d 100644 --- a/packages/vscode-ide-companion/src/ide-server.ts +++ b/packages/vscode-ide-companion/src/ide-server.ts @@ -5,7 +5,11 @@ */ import * as vscode from 'vscode'; -import { IdeContextNotificationSchema } from '@google/gemini-cli-core'; +import { + CloseDiffRequestSchema, + IdeContextNotificationSchema, + OpenDiffRequestSchema, +} from '@google/gemini-cli-core'; import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js'; import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; @@ -15,7 +19,7 @@ import { type Server as HTTPServer } from 'node:http'; import * as path from 'node:path'; import * as fs from 'node:fs/promises'; import * as os from 'node:os'; -import { z } from 'zod'; +import type { z } from 'zod'; import type { DiffManager } from './diff-manager.js'; import { OpenFilesManager } from './open-files-manager.js'; @@ -339,46 +343,23 @@ const createMcpServer = (diffManager: DiffManager) => { { description: '(IDE Tool) Open a diff view to create or modify a file. Returns a notification once the diff has been accepted or rejcted.', - inputSchema: z.object({ - filePath: z.string(), - // TODO(chrstn): determine if this should be required or not. - newContent: z.string().optional(), - }).shape, + inputSchema: OpenDiffRequestSchema.shape, }, - async ({ - filePath, - newContent, - }: { - filePath: string; - newContent?: string; - }) => { - await diffManager.showDiff(filePath, newContent ?? ''); - return { - content: [ - { - type: 'text', - text: `Showing diff for ${filePath}`, - }, - ], - }; + async ({ filePath, newContent }: z.infer) => { + await diffManager.showDiff(filePath, newContent); + return { content: [] }; }, ); server.registerTool( 'closeDiff', { description: '(IDE Tool) Close an open diff view for a specific file.', - inputSchema: z.object({ - filePath: z.string(), - suppressNotification: z.boolean().optional(), - }).shape, + inputSchema: CloseDiffRequestSchema.shape, }, async ({ filePath, suppressNotification, - }: { - filePath: string; - suppressNotification?: boolean; - }) => { + }: z.infer) => { const content = await diffManager.closeDiff( filePath, suppressNotification,