Use IdeClient directly instead of config.ideClient (#7627)

This commit is contained in:
Tommaso Sciortino
2025-09-04 09:32:09 -07:00
committed by GitHub
parent 45d494a8d8
commit cb43bb9ca4
24 changed files with 288 additions and 217 deletions
@@ -24,7 +24,6 @@ export function createMockConfig(
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
getIdeMode: vi.fn().mockReturnValue(false), getIdeMode: vi.fn().mockReturnValue(false),
getAllowedTools: vi.fn().mockReturnValue([]), getAllowedTools: vi.fn().mockReturnValue([]),
getIdeClient: vi.fn(),
getWorkspaceContext: vi.fn().mockReturnValue({ getWorkspaceContext: vi.fn().mockReturnValue({
isPathWithinWorkspace: () => true, isPathWithinWorkspace: () => true,
}), }),
-8
View File
@@ -36,9 +36,6 @@ import {
logUserPrompt, logUserPrompt,
AuthType, AuthType,
getOauthClient, getOauthClient,
logIdeConnection,
IdeConnectionEvent,
IdeConnectionType,
uiTelemetryService, uiTelemetryService,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { validateAuthMethod } from './config/auth.js'; import { validateAuthMethod } from './config/auth.js';
@@ -289,11 +286,6 @@ export async function main() {
spinnerInstance.unmount(); spinnerInstance.unmount();
} }
if (config.getIdeMode()) {
await config.getIdeClient().connect();
logIdeConnection(config, new IdeConnectionEvent(IdeConnectionType.START));
}
// Load custom themes from settings // Load custom themes from settings
themeManager.loadCustomThemes(settings.merged.ui?.customThemes); themeManager.loadCustomThemes(settings.merged.ui?.customThemes);
@@ -64,7 +64,7 @@ describe('BuiltinCommandLoader', () => {
vi.clearAllMocks(); vi.clearAllMocks();
mockConfig = { some: 'config' } as unknown as Config; mockConfig = { some: 'config' } as unknown as Config;
ideCommandMock.mockReturnValue({ ideCommandMock.mockResolvedValue({
name: 'ide', name: 'ide',
description: 'IDE command', description: 'IDE command',
kind: CommandKind.BUILT_IN, kind: CommandKind.BUILT_IN,
@@ -81,7 +81,7 @@ describe('BuiltinCommandLoader', () => {
await loader.loadCommands(new AbortController().signal); await loader.loadCommands(new AbortController().signal);
expect(ideCommandMock).toHaveBeenCalledTimes(1); expect(ideCommandMock).toHaveBeenCalledTimes(1);
expect(ideCommandMock).toHaveBeenCalledWith(mockConfig); expect(ideCommandMock).toHaveBeenCalledWith();
expect(restoreCommandMock).toHaveBeenCalledTimes(1); expect(restoreCommandMock).toHaveBeenCalledTimes(1);
expect(restoreCommandMock).toHaveBeenCalledWith(mockConfig); expect(restoreCommandMock).toHaveBeenCalledWith(mockConfig);
}); });
@@ -105,7 +105,7 @@ describe('BuiltinCommandLoader', () => {
const loader = new BuiltinCommandLoader(null); const loader = new BuiltinCommandLoader(null);
await loader.loadCommands(new AbortController().signal); await loader.loadCommands(new AbortController().signal);
expect(ideCommandMock).toHaveBeenCalledTimes(1); expect(ideCommandMock).toHaveBeenCalledTimes(1);
expect(ideCommandMock).toHaveBeenCalledWith(null); expect(ideCommandMock).toHaveBeenCalledWith();
expect(restoreCommandMock).toHaveBeenCalledTimes(1); expect(restoreCommandMock).toHaveBeenCalledTimes(1);
expect(restoreCommandMock).toHaveBeenCalledWith(null); expect(restoreCommandMock).toHaveBeenCalledWith(null);
}); });
@@ -64,7 +64,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
editorCommand, editorCommand,
extensionsCommand, extensionsCommand,
helpCommand, helpCommand,
ideCommand(this.config), await ideCommand(),
initCommand, initCommand,
mcpCommand, mcpCommand,
memoryCommand, memoryCommand,
+9 -8
View File
@@ -93,7 +93,6 @@ interface MockServerConfig {
getAllGeminiMdFilenames: Mock<() => string[]>; getAllGeminiMdFilenames: Mock<() => string[]>;
getGeminiClient: Mock<() => GeminiClient | undefined>; getGeminiClient: Mock<() => GeminiClient | undefined>;
getUserTier: Mock<() => Promise<string | undefined>>; getUserTier: Mock<() => Promise<string | undefined>>;
getIdeClient: Mock<() => { getCurrentIde: Mock<() => string | undefined> }>;
getScreenReader: Mock<() => boolean>; getScreenReader: Mock<() => boolean>;
} }
@@ -183,13 +182,6 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
getWorkspaceContext: vi.fn(() => ({ getWorkspaceContext: vi.fn(() => ({
getDirectories: vi.fn(() => []), getDirectories: vi.fn(() => []),
})), })),
getIdeClient: vi.fn(() => ({
getCurrentIde: vi.fn(() => 'vscode'),
getDetectedIdeDisplayName: vi.fn(() => 'VSCode'),
addStatusChangeListener: vi.fn(),
removeStatusChangeListener: vi.fn(),
getConnectionStatus: vi.fn(() => 'connected'),
})),
isTrustedFolder: vi.fn(() => true), isTrustedFolder: vi.fn(() => true),
getScreenReader: vi.fn(() => false), getScreenReader: vi.fn(() => false),
getFolderTrustFeature: vi.fn(() => false), getFolderTrustFeature: vi.fn(() => false),
@@ -208,6 +200,15 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
MCPServerConfig: actualCore.MCPServerConfig, MCPServerConfig: actualCore.MCPServerConfig,
getAllGeminiMdFilenames: vi.fn(() => ['GEMINI.md']), getAllGeminiMdFilenames: vi.fn(() => ['GEMINI.md']),
ideContext: ideContextMock, ideContext: ideContextMock,
IdeClient: {
getInstance: vi.fn().mockResolvedValue({
getCurrentIde: vi.fn(() => 'vscode'),
getDetectedIdeDisplayName: vi.fn(() => 'VSCode'),
addStatusChangeListener: vi.fn(),
removeStatusChangeListener: vi.fn(),
getConnectionStatus: vi.fn(() => 'connected'),
}),
},
isGitRepository: vi.fn(), isGitRepository: vi.fn(),
}; };
}); });
+19 -4
View File
@@ -59,7 +59,12 @@ import { ContextSummaryDisplay } from './components/ContextSummaryDisplay.js';
import { useHistory } from './hooks/useHistoryManager.js'; import { useHistory } from './hooks/useHistoryManager.js';
import { useInputHistoryStore } from './hooks/useInputHistoryStore.js'; import { useInputHistoryStore } from './hooks/useInputHistoryStore.js';
import process from 'node:process'; import process from 'node:process';
import type { EditorType, Config, IdeContext } from '@google/gemini-cli-core'; import type {
EditorType,
Config,
IdeContext,
DetectedIde,
} from '@google/gemini-cli-core';
import { import {
ApprovalMode, ApprovalMode,
getAllGeminiMdFilenames, getAllGeminiMdFilenames,
@@ -73,6 +78,7 @@ import {
isGenericQuotaExceededError, isGenericQuotaExceededError,
UserTierId, UserTierId,
DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_MODEL,
IdeClient,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import type { IdeIntegrationNudgeResult } from './IdeIntegrationNudge.js'; import type { IdeIntegrationNudgeResult } from './IdeIntegrationNudge.js';
import { IdeIntegrationNudge } from './IdeIntegrationNudge.js'; import { IdeIntegrationNudge } from './IdeIntegrationNudge.js';
@@ -161,10 +167,19 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
const { history, addItem, clearItems, loadHistory } = useHistory(); const { history, addItem, clearItems, loadHistory } = useHistory();
const [idePromptAnswered, setIdePromptAnswered] = useState(false); const [idePromptAnswered, setIdePromptAnswered] = useState(false);
const currentIDE = config.getIdeClient().getCurrentIde(); const [currentIDE, setCurrentIDE] = useState<DetectedIde | undefined>();
useEffect(() => { useEffect(() => {
registerCleanup(() => config.getIdeClient().disconnect()); (async () => {
const ideClient = await IdeClient.getInstance();
setCurrentIDE(ideClient.getCurrentIde());
})();
registerCleanup(async () => {
const ideClient = await IdeClient.getInstance();
ideClient.disconnect();
});
}, [config]); }, [config]);
const shouldShowIdePrompt = const shouldShowIdePrompt =
currentIDE && currentIDE &&
!config.getIdeMode() && !config.getIdeMode() &&
@@ -306,7 +321,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
const { isFolderTrustDialogOpen, handleFolderTrustSelect, isRestarting } = const { isFolderTrustDialogOpen, handleFolderTrustSelect, isRestarting } =
useFolderTrust(settings, setIsTrustedFolder); useFolderTrust(settings, setIsTrustedFolder);
const { needsRestart: ideNeedsRestart } = useIdeTrustListener(config); const { needsRestart: ideNeedsRestart } = useIdeTrustListener();
useEffect(() => { useEffect(() => {
if (ideNeedsRestart) { if (ideNeedsRestart) {
// IDE trust changed, force a restart. // IDE trust changed, force a restart.
@@ -10,8 +10,20 @@ import { type CommandContext } from './types.js';
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
import * as versionUtils from '../../utils/version.js'; import * as versionUtils from '../../utils/version.js';
import { MessageType } from '../types.js'; import { MessageType } from '../types.js';
import { IdeClient } from '@google/gemini-cli-core';
import type { IdeClient } from '../../../../core/src/ide/ide-client.js'; vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actual =
await importOriginal<typeof import('@google/gemini-cli-core')>();
return {
...actual,
IdeClient: {
getInstance: vi.fn().mockResolvedValue({
getDetectedIdeDisplayName: vi.fn().mockReturnValue('test-ide'),
}),
},
};
});
vi.mock('../../utils/version.js', () => ({ vi.mock('../../utils/version.js', () => ({
getCliVersion: vi.fn(), getCliVersion: vi.fn(),
@@ -27,7 +39,6 @@ describe('aboutCommand', () => {
services: { services: {
config: { config: {
getModel: vi.fn(), getModel: vi.fn(),
getIdeClient: vi.fn(),
getIdeMode: vi.fn().mockReturnValue(true), getIdeMode: vi.fn().mockReturnValue(true),
}, },
settings: { settings: {
@@ -53,9 +64,6 @@ describe('aboutCommand', () => {
Object.defineProperty(process, 'platform', { Object.defineProperty(process, 'platform', {
value: 'test-os', value: 'test-os',
}); });
vi.spyOn(mockContext.services.config!, 'getIdeClient').mockReturnValue({
getDetectedIdeDisplayName: vi.fn().mockReturnValue('test-ide'),
} as Partial<IdeClient> as IdeClient);
}); });
afterEach(() => { afterEach(() => {
@@ -129,11 +137,11 @@ describe('aboutCommand', () => {
}); });
it('should not show ide client when it is not detected', async () => { it('should not show ide client when it is not detected', async () => {
vi.spyOn(mockContext.services.config!, 'getIdeClient').mockReturnValue({ vi.mocked(IdeClient.getInstance).mockResolvedValue({
getDetectedIdeDisplayName: vi.fn().mockReturnValue(undefined), getDetectedIdeDisplayName: vi.fn().mockReturnValue(undefined),
} as Partial<IdeClient> as IdeClient); } as unknown as IdeClient);
process.env.SANDBOX = ''; process.env['SANDBOX'] = '';
if (!aboutCommand.action) { if (!aboutCommand.action) {
throw new Error('The about command must have an action.'); throw new Error('The about command must have an action.');
} }
+11 -5
View File
@@ -5,10 +5,11 @@
*/ */
import { getCliVersion } from '../../utils/version.js'; import { getCliVersion } from '../../utils/version.js';
import type { SlashCommand } from './types.js'; import type { CommandContext, SlashCommand } from './types.js';
import { CommandKind } from './types.js'; import { CommandKind } from './types.js';
import process from 'node:process'; import process from 'node:process';
import { MessageType, type HistoryItemAbout } from '../types.js'; import { MessageType, type HistoryItemAbout } from '../types.js';
import { IdeClient } from '@google/gemini-cli-core';
export const aboutCommand: SlashCommand = { export const aboutCommand: SlashCommand = {
name: 'about', name: 'about',
@@ -29,10 +30,7 @@ export const aboutCommand: SlashCommand = {
const selectedAuthType = const selectedAuthType =
context.services.settings.merged.security?.auth?.selectedType || ''; context.services.settings.merged.security?.auth?.selectedType || '';
const gcpProject = process.env['GOOGLE_CLOUD_PROJECT'] || ''; const gcpProject = process.env['GOOGLE_CLOUD_PROJECT'] || '';
const ideClient = const ideClient = await getIdeClientName(context);
(context.services.config?.getIdeMode() &&
context.services.config?.getIdeClient()?.getDetectedIdeDisplayName()) ||
'';
const aboutItem: Omit<HistoryItemAbout, 'id'> = { const aboutItem: Omit<HistoryItemAbout, 'id'> = {
type: MessageType.ABOUT, type: MessageType.ABOUT,
@@ -48,3 +46,11 @@ export const aboutCommand: SlashCommand = {
context.ui.addItem(aboutItem, Date.now()); context.ui.addItem(aboutItem, Date.now());
}, },
}; };
async function getIdeClientName(context: CommandContext) {
if (!context.services.config?.getIdeMode()) {
return '';
}
const ideClient = await IdeClient.getInstance();
return ideClient?.getDetectedIdeDisplayName() ?? '';
}
+13 -10
View File
@@ -16,7 +16,19 @@ import { formatMemoryUsage } from '../utils/formatters.js';
vi.mock('open'); vi.mock('open');
vi.mock('../../utils/version.js'); vi.mock('../../utils/version.js');
vi.mock('../utils/formatters.js'); vi.mock('../utils/formatters.js');
vi.mock('@google/gemini-cli-core'); vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actual =
await importOriginal<typeof import('@google/gemini-cli-core')>();
return {
...actual,
IdeClient: {
getInstance: () => ({
getDetectedIdeDisplayName: vi.fn().mockReturnValue('VSCode'),
}),
},
sessionId: 'test-session-id',
};
});
vi.mock('node:process', () => ({ vi.mock('node:process', () => ({
default: { default: {
platform: 'test-platform', platform: 'test-platform',
@@ -31,9 +43,6 @@ describe('bugCommand', () => {
beforeEach(() => { beforeEach(() => {
vi.mocked(getCliVersion).mockResolvedValue('0.1.0'); vi.mocked(getCliVersion).mockResolvedValue('0.1.0');
vi.mocked(formatMemoryUsage).mockReturnValue('100 MB'); vi.mocked(formatMemoryUsage).mockReturnValue('100 MB');
vi.mock('@google/gemini-cli-core', () => ({
sessionId: 'test-session-id',
}));
vi.stubEnv('SANDBOX', 'gemini-test'); vi.stubEnv('SANDBOX', 'gemini-test');
}); });
@@ -48,9 +57,6 @@ describe('bugCommand', () => {
config: { config: {
getModel: () => 'gemini-pro', getModel: () => 'gemini-pro',
getBugCommand: () => undefined, getBugCommand: () => undefined,
getIdeClient: () => ({
getDetectedIdeDisplayName: () => 'VSCode',
}),
getIdeMode: () => true, getIdeMode: () => true,
}, },
}, },
@@ -84,9 +90,6 @@ describe('bugCommand', () => {
config: { config: {
getModel: () => 'gemini-pro', getModel: () => 'gemini-pro',
getBugCommand: () => ({ urlTemplate: customTemplate }), getBugCommand: () => ({ urlTemplate: customTemplate }),
getIdeClient: () => ({
getDetectedIdeDisplayName: () => 'VSCode',
}),
getIdeMode: () => true, getIdeMode: () => true,
}, },
}, },
+10 -5
View File
@@ -15,7 +15,7 @@ import { MessageType } from '../types.js';
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
import { formatMemoryUsage } from '../utils/formatters.js'; import { formatMemoryUsage } from '../utils/formatters.js';
import { getCliVersion } from '../../utils/version.js'; import { getCliVersion } from '../../utils/version.js';
import { sessionId } from '@google/gemini-cli-core'; import { IdeClient, sessionId } from '@google/gemini-cli-core';
export const bugCommand: SlashCommand = { export const bugCommand: SlashCommand = {
name: 'bug', name: 'bug',
@@ -37,10 +37,7 @@ export const bugCommand: SlashCommand = {
const modelVersion = config?.getModel() || 'Unknown'; const modelVersion = config?.getModel() || 'Unknown';
const cliVersion = await getCliVersion(); const cliVersion = await getCliVersion();
const memoryUsage = formatMemoryUsage(process.memoryUsage().rss); const memoryUsage = formatMemoryUsage(process.memoryUsage().rss);
const ideClient = const ideClient = await getIdeClientName(context);
(context.services.config?.getIdeMode() &&
context.services.config?.getIdeClient()?.getDetectedIdeDisplayName()) ||
'';
let info = ` let info = `
* **CLI Version:** ${cliVersion} * **CLI Version:** ${cliVersion}
@@ -90,3 +87,11 @@ export const bugCommand: SlashCommand = {
} }
}, },
}; };
async function getIdeClientName(context: CommandContext) {
if (!context.services.config?.getIdeMode()) {
return '';
}
const ideClient = await IdeClient.getInstance();
return ideClient.getDetectedIdeDisplayName() ?? '';
}
+96 -78
View File
@@ -8,26 +8,43 @@ import type { MockInstance } from 'vitest';
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import { ideCommand } from './ideCommand.js'; import { ideCommand } from './ideCommand.js';
import { type CommandContext } from './types.js'; import { type CommandContext } from './types.js';
import { type Config, DetectedIde } from '@google/gemini-cli-core'; import { DetectedIde } from '@google/gemini-cli-core';
import * as core from '@google/gemini-cli-core'; import * as core from '@google/gemini-cli-core';
vi.mock('child_process');
vi.mock('glob');
vi.mock('@google/gemini-cli-core', async (importOriginal) => { vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const original = await importOriginal<typeof core>(); const original = await importOriginal<typeof core>();
return { return {
...original, ...original,
getOauthClient: vi.fn(original.getOauthClient), getOauthClient: vi.fn(original.getOauthClient),
getIdeInstaller: vi.fn(original.getIdeInstaller), getIdeInstaller: vi.fn(original.getIdeInstaller),
IdeClient: {
getInstance: vi.fn(),
},
ideContext: {
getIdeContext: vi.fn(),
},
}; };
}); });
describe('ideCommand', () => { describe('ideCommand', () => {
let mockContext: CommandContext; let mockContext: CommandContext;
let mockConfig: Config; let mockIdeClient: core.IdeClient;
let platformSpy: MockInstance; let platformSpy: MockInstance;
beforeEach(() => { beforeEach(() => {
vi.resetAllMocks();
mockIdeClient = {
reconnect: vi.fn(),
disconnect: vi.fn(),
connect: vi.fn(),
getCurrentIde: vi.fn(),
getDetectedIdeDisplayName: vi.fn(),
getConnectionStatus: vi.fn(),
} as unknown as core.IdeClient;
vi.mocked(core.IdeClient.getInstance).mockResolvedValue(mockIdeClient);
mockContext = { mockContext = {
ui: { ui: {
addItem: vi.fn(), addItem: vi.fn(),
@@ -36,22 +53,14 @@ describe('ideCommand', () => {
settings: { settings: {
setValue: vi.fn(), setValue: vi.fn(),
}, },
config: {
getIdeMode: vi.fn(),
setIdeMode: vi.fn(),
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
},
}, },
} as unknown as CommandContext; } as unknown as CommandContext;
mockConfig = {
getIdeMode: vi.fn(),
getIdeClient: vi.fn(() => ({
reconnect: vi.fn(),
disconnect: vi.fn(),
getCurrentIde: vi.fn(),
getDetectedIdeDisplayName: vi.fn(),
getConnectionStatus: vi.fn(),
})),
setIdeModeAndSyncConnection: vi.fn(),
setIdeMode: vi.fn(),
} as unknown as Config;
platformSpy = vi.spyOn(process, 'platform', 'get'); platformSpy = vi.spyOn(process, 'platform', 'get');
}); });
@@ -59,64 +68,57 @@ describe('ideCommand', () => {
vi.restoreAllMocks(); vi.restoreAllMocks();
}); });
it('should return null if config is not provided', () => { it('should return the ide command', async () => {
const command = ideCommand(null); vi.mocked(mockIdeClient.getCurrentIde).mockReturnValue(DetectedIde.VSCode);
expect(command).toBeNull(); vi.mocked(mockIdeClient.getDetectedIdeDisplayName).mockReturnValue(
}); 'VS Code',
);
it('should return the ide command', () => { vi.mocked(mockIdeClient.getConnectionStatus).mockReturnValue({
vi.mocked(mockConfig.getIdeMode).mockReturnValue(true);
vi.mocked(mockConfig.getIdeClient).mockReturnValue({
getCurrentIde: () => DetectedIde.VSCode,
getDetectedIdeDisplayName: () => 'VS Code',
getConnectionStatus: () => ({
status: core.IDEConnectionStatus.Disconnected, status: core.IDEConnectionStatus.Disconnected,
}), });
} as ReturnType<Config['getIdeClient']>); const command = await ideCommand();
const command = ideCommand(mockConfig);
expect(command).not.toBeNull(); expect(command).not.toBeNull();
expect(command?.name).toBe('ide'); expect(command.name).toBe('ide');
expect(command?.subCommands).toHaveLength(3); expect(command.subCommands).toHaveLength(3);
expect(command?.subCommands?.[0].name).toBe('enable'); expect(command.subCommands?.[0].name).toBe('enable');
expect(command?.subCommands?.[1].name).toBe('status'); expect(command.subCommands?.[1].name).toBe('status');
expect(command?.subCommands?.[2].name).toBe('install'); expect(command.subCommands?.[2].name).toBe('install');
}); });
it('should show disable command when connected', () => { it('should show disable command when connected', async () => {
vi.mocked(mockConfig.getIdeMode).mockReturnValue(true); vi.mocked(mockIdeClient.getCurrentIde).mockReturnValue(DetectedIde.VSCode);
vi.mocked(mockConfig.getIdeClient).mockReturnValue({ vi.mocked(mockIdeClient.getDetectedIdeDisplayName).mockReturnValue(
getCurrentIde: () => DetectedIde.VSCode, 'VS Code',
getDetectedIdeDisplayName: () => 'VS Code', );
getConnectionStatus: () => ({ vi.mocked(mockIdeClient.getConnectionStatus).mockReturnValue({
status: core.IDEConnectionStatus.Connected, status: core.IDEConnectionStatus.Connected,
}), });
} as ReturnType<Config['getIdeClient']>); const command = await ideCommand();
const command = ideCommand(mockConfig);
expect(command).not.toBeNull(); expect(command).not.toBeNull();
const subCommandNames = command?.subCommands?.map((cmd) => cmd.name); const subCommandNames = command.subCommands?.map((cmd) => cmd.name);
expect(subCommandNames).toContain('disable'); expect(subCommandNames).toContain('disable');
expect(subCommandNames).not.toContain('enable'); expect(subCommandNames).not.toContain('enable');
}); });
describe('status subcommand', () => { describe('status subcommand', () => {
const mockGetConnectionStatus = vi.fn();
beforeEach(() => { beforeEach(() => {
vi.mocked(mockConfig.getIdeClient).mockReturnValue({ vi.mocked(mockIdeClient.getCurrentIde).mockReturnValue(
getConnectionStatus: mockGetConnectionStatus, DetectedIde.VSCode,
getCurrentIde: () => DetectedIde.VSCode, );
getDetectedIdeDisplayName: () => 'VS Code', vi.mocked(mockIdeClient.getDetectedIdeDisplayName).mockReturnValue(
} as unknown as ReturnType<Config['getIdeClient']>); 'VS Code',
);
}); });
it('should show connected status', async () => { it('should show connected status', async () => {
mockGetConnectionStatus.mockReturnValue({ vi.mocked(mockIdeClient.getConnectionStatus).mockReturnValue({
status: core.IDEConnectionStatus.Connected, status: core.IDEConnectionStatus.Connected,
}); });
const command = ideCommand(mockConfig); const command = await ideCommand();
const result = await command!.subCommands!.find( const result = await command!.subCommands!.find(
(c) => c.name === 'status', (c) => c.name === 'status',
)!.action!(mockContext, ''); )!.action!(mockContext, '');
expect(mockGetConnectionStatus).toHaveBeenCalled(); expect(vi.mocked(mockIdeClient.getConnectionStatus)).toHaveBeenCalled();
expect(result).toEqual({ expect(result).toEqual({
type: 'message', type: 'message',
messageType: 'info', messageType: 'info',
@@ -125,14 +127,14 @@ describe('ideCommand', () => {
}); });
it('should show connecting status', async () => { it('should show connecting status', async () => {
mockGetConnectionStatus.mockReturnValue({ vi.mocked(mockIdeClient.getConnectionStatus).mockReturnValue({
status: core.IDEConnectionStatus.Connecting, status: core.IDEConnectionStatus.Connecting,
}); });
const command = ideCommand(mockConfig); const command = await ideCommand();
const result = await command!.subCommands!.find( const result = await command!.subCommands!.find(
(c) => c.name === 'status', (c) => c.name === 'status',
)!.action!(mockContext, ''); )!.action!(mockContext, '');
expect(mockGetConnectionStatus).toHaveBeenCalled(); expect(vi.mocked(mockIdeClient.getConnectionStatus)).toHaveBeenCalled();
expect(result).toEqual({ expect(result).toEqual({
type: 'message', type: 'message',
messageType: 'info', messageType: 'info',
@@ -140,14 +142,14 @@ describe('ideCommand', () => {
}); });
}); });
it('should show disconnected status', async () => { it('should show disconnected status', async () => {
mockGetConnectionStatus.mockReturnValue({ vi.mocked(mockIdeClient.getConnectionStatus).mockReturnValue({
status: core.IDEConnectionStatus.Disconnected, status: core.IDEConnectionStatus.Disconnected,
}); });
const command = ideCommand(mockConfig); const command = await ideCommand();
const result = await command!.subCommands!.find( const result = await command!.subCommands!.find(
(c) => c.name === 'status', (c) => c.name === 'status',
)!.action!(mockContext, ''); )!.action!(mockContext, '');
expect(mockGetConnectionStatus).toHaveBeenCalled(); expect(vi.mocked(mockIdeClient.getConnectionStatus)).toHaveBeenCalled();
expect(result).toEqual({ expect(result).toEqual({
type: 'message', type: 'message',
messageType: 'error', messageType: 'error',
@@ -157,15 +159,15 @@ describe('ideCommand', () => {
it('should show disconnected status with details', async () => { it('should show disconnected status with details', async () => {
const details = 'Something went wrong'; const details = 'Something went wrong';
mockGetConnectionStatus.mockReturnValue({ vi.mocked(mockIdeClient.getConnectionStatus).mockReturnValue({
status: core.IDEConnectionStatus.Disconnected, status: core.IDEConnectionStatus.Disconnected,
details, details,
}); });
const command = ideCommand(mockConfig); const command = await ideCommand();
const result = await command!.subCommands!.find( const result = await command!.subCommands!.find(
(c) => c.name === 'status', (c) => c.name === 'status',
)!.action!(mockContext, ''); )!.action!(mockContext, '');
expect(mockGetConnectionStatus).toHaveBeenCalled(); expect(vi.mocked(mockIdeClient.getConnectionStatus)).toHaveBeenCalled();
expect(result).toEqual({ expect(result).toEqual({
type: 'message', type: 'message',
messageType: 'error', messageType: 'error',
@@ -177,32 +179,40 @@ describe('ideCommand', () => {
describe('install subcommand', () => { describe('install subcommand', () => {
const mockInstall = vi.fn(); const mockInstall = vi.fn();
beforeEach(() => { beforeEach(() => {
vi.mocked(mockConfig.getIdeMode).mockReturnValue(true); vi.mocked(mockIdeClient.getCurrentIde).mockReturnValue(
vi.mocked(mockConfig.getIdeClient).mockReturnValue({ DetectedIde.VSCode,
getCurrentIde: () => DetectedIde.VSCode, );
getConnectionStatus: () => ({ vi.mocked(mockIdeClient.getDetectedIdeDisplayName).mockReturnValue(
'VS Code',
);
vi.mocked(mockIdeClient.getConnectionStatus).mockReturnValue({
status: core.IDEConnectionStatus.Disconnected, status: core.IDEConnectionStatus.Disconnected,
}), });
getDetectedIdeDisplayName: () => 'VS Code',
} as unknown as ReturnType<Config['getIdeClient']>);
vi.mocked(core.getIdeInstaller).mockReturnValue({ vi.mocked(core.getIdeInstaller).mockReturnValue({
install: mockInstall, install: mockInstall,
isInstalled: vi.fn(),
}); });
platformSpy.mockReturnValue('linux'); platformSpy.mockReturnValue('linux');
}); });
it('should install the extension', async () => { it('should install the extension', async () => {
vi.useFakeTimers();
mockInstall.mockResolvedValue({ mockInstall.mockResolvedValue({
success: true, success: true,
message: 'Successfully installed.', message: 'Successfully installed.',
}); });
const command = ideCommand(mockConfig); const command = await ideCommand();
await command!.subCommands!.find((c) => c.name === 'install')!.action!(
mockContext, // For the polling loop inside the action.
'', vi.mocked(mockIdeClient.getConnectionStatus).mockReturnValue({
); status: core.IDEConnectionStatus.Connected,
});
const actionPromise = command!.subCommands!.find(
(c) => c.name === 'install',
)!.action!(mockContext, '');
await vi.runAllTimersAsync();
await actionPromise;
expect(core.getIdeInstaller).toHaveBeenCalledWith('vscode'); expect(core.getIdeInstaller).toHaveBeenCalledWith('vscode');
expect(mockInstall).toHaveBeenCalled(); expect(mockInstall).toHaveBeenCalled();
@@ -220,6 +230,14 @@ describe('ideCommand', () => {
}), }),
expect.any(Number), expect.any(Number),
); );
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
expect.objectContaining({
type: 'info',
text: '🟢 Connected to VS Code',
}),
expect.any(Number),
);
vi.useRealTimers();
}, 10000); }, 10000);
it('should show an error if installation fails', async () => { it('should show an error if installation fails', async () => {
@@ -228,7 +246,7 @@ describe('ideCommand', () => {
message: 'Installation failed.', message: 'Installation failed.',
}); });
const command = ideCommand(mockConfig); const command = await ideCommand();
await command!.subCommands!.find((c) => c.name === 'install')!.action!( await command!.subCommands!.find((c) => c.name === 'install')!.action!(
mockContext, mockContext,
'', '',
+26 -8
View File
@@ -4,7 +4,14 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { Config, IdeClient, File } from '@google/gemini-cli-core'; import {
type Config,
IdeClient,
type File,
logIdeConnection,
IdeConnectionEvent,
IdeConnectionType,
} from '@google/gemini-cli-core';
import { import {
getIdeInstaller, getIdeInstaller,
IDEConnectionStatus, IDEConnectionStatus,
@@ -111,11 +118,22 @@ async function getIdeStatusMessageWithFiles(ideClient: IdeClient): Promise<{
} }
} }
export const ideCommand = (config: Config | null): SlashCommand | null => { async function setIdeModeAndSyncConnection(
if (!config) { config: Config,
return null; value: boolean,
): Promise<void> {
config.setIdeMode(value);
const ideClient = await IdeClient.getInstance();
if (value) {
await ideClient.connect();
logIdeConnection(config, new IdeConnectionEvent(IdeConnectionType.SESSION));
} else {
await ideClient.disconnect();
} }
const ideClient = config.getIdeClient(); }
export const ideCommand = async (): Promise<SlashCommand> => {
const ideClient = await IdeClient.getInstance();
const currentIDE = ideClient.getCurrentIde(); const currentIDE = ideClient.getCurrentIde();
if (!currentIDE || !ideClient.getDetectedIdeDisplayName()) { if (!currentIDE || !ideClient.getDetectedIdeDisplayName()) {
return { return {
@@ -194,7 +212,7 @@ export const ideCommand = (config: Config | null): SlashCommand | null => {
); );
// Poll for up to 5 seconds for the extension to activate. // Poll for up to 5 seconds for the extension to activate.
for (let i = 0; i < 10; i++) { for (let i = 0; i < 10; i++) {
await config.setIdeModeAndSyncConnection(true); await setIdeModeAndSyncConnection(context.services.config!, true);
if ( if (
ideClient.getConnectionStatus().status === ideClient.getConnectionStatus().status ===
IDEConnectionStatus.Connected IDEConnectionStatus.Connected
@@ -236,7 +254,7 @@ export const ideCommand = (config: Config | null): SlashCommand | null => {
'ide.enabled', 'ide.enabled',
true, true,
); );
await config.setIdeModeAndSyncConnection(true); await setIdeModeAndSyncConnection(context.services.config!, true);
const { messageType, content } = getIdeStatusMessage(ideClient); const { messageType, content } = getIdeStatusMessage(ideClient);
context.ui.addItem( context.ui.addItem(
{ {
@@ -258,7 +276,7 @@ export const ideCommand = (config: Config | null): SlashCommand | null => {
'ide.enabled', 'ide.enabled',
false, false,
); );
await config.setIdeModeAndSyncConnection(false); await setIdeModeAndSyncConnection(context.services.config!, false);
const { messageType, content } = getIdeStatusMessage(ideClient); const { messageType, content } = getIdeStatusMessage(ideClient);
context.ui.addItem( context.ui.addItem(
{ {
@@ -15,7 +15,7 @@ import type {
ToolMcpConfirmationDetails, ToolMcpConfirmationDetails,
Config, Config,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { ToolConfirmationOutcome } from '@google/gemini-cli-core'; import { IdeClient, ToolConfirmationOutcome } from '@google/gemini-cli-core';
import type { RadioSelectItem } from '../shared/RadioButtonSelect.js'; import type { RadioSelectItem } from '../shared/RadioButtonSelect.js';
import { RadioButtonSelect } from '../shared/RadioButtonSelect.js'; import { RadioButtonSelect } from '../shared/RadioButtonSelect.js';
import { MaxSizedBox } from '../shared/MaxSizedBox.js'; import { MaxSizedBox } from '../shared/MaxSizedBox.js';
@@ -43,10 +43,10 @@ export const ToolConfirmationMessage: React.FC<
const handleConfirm = async (outcome: ToolConfirmationOutcome) => { const handleConfirm = async (outcome: ToolConfirmationOutcome) => {
if (confirmationDetails.type === 'edit') { if (confirmationDetails.type === 'edit') {
const ideClient = config.getIdeClient();
if (config.getIdeMode()) { if (config.getIdeMode()) {
const cliOutcome = const cliOutcome =
outcome === ToolConfirmationOutcome.Cancel ? 'rejected' : 'accepted'; outcome === ToolConfirmationOutcome.Cancel ? 'rejected' : 'accepted';
const ideClient = await IdeClient.getInstance();
await ideClient?.resolveDiffFromCli( await ideClient?.resolveDiffFromCli(
confirmationDetails.filePath, confirmationDetails.filePath,
cliOutcome, cliOutcome,
@@ -86,8 +86,6 @@ import {
SlashCommandStatus, SlashCommandStatus,
ToolConfirmationOutcome, ToolConfirmationOutcome,
makeFakeConfig, makeFakeConfig,
ToolConfirmationOutcome,
type IdeClient,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
function createTestCommand( function createTestCommand(
@@ -111,11 +109,6 @@ describe('useSlashCommandProcessor', () => {
const mockSetQuittingMessages = vi.fn(); const mockSetQuittingMessages = vi.fn();
const mockConfig = makeFakeConfig({}); const mockConfig = makeFakeConfig({});
vi.spyOn(mockConfig, 'getIdeClient').mockReturnValue({
addStatusChangeListener: vi.fn(),
removeStatusChangeListener: vi.fn(),
} as unknown as IdeClient);
const mockSettings = {} as LoadedSettings; const mockSettings = {} as LoadedSettings;
beforeEach(() => { beforeEach(() => {
@@ -17,6 +17,7 @@ import {
SlashCommandStatus, SlashCommandStatus,
ToolConfirmationOutcome, ToolConfirmationOutcome,
Storage, Storage,
IdeClient,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { useSessionStats } from '../contexts/SessionContext.js'; import { useSessionStats } from '../contexts/SessionContext.js';
import { runExitCleanup } from '../../utils/cleanup.js'; import { runExitCleanup } from '../../utils/cleanup.js';
@@ -215,15 +216,20 @@ export const useSlashCommandProcessor = (
return; return;
} }
const ideClient = config.getIdeClient();
const listener = () => { const listener = () => {
reloadCommands(); reloadCommands();
}; };
(async () => {
const ideClient = await IdeClient.getInstance();
ideClient.addStatusChangeListener(listener); ideClient.addStatusChangeListener(listener);
})();
return () => { return () => {
(async () => {
const ideClient = await IdeClient.getInstance();
ideClient.removeStatusChangeListener(listener); ideClient.removeStatusChangeListener(listener);
})();
}; };
}, [config, reloadCommands]); }, [config, reloadCommands]);
@@ -5,25 +5,26 @@
*/ */
import { useCallback, useEffect, useState, useSyncExternalStore } from 'react'; import { useCallback, useEffect, useState, useSyncExternalStore } from 'react';
import type { Config } from '@google/gemini-cli-core'; import { IdeClient, ideContext } from '@google/gemini-cli-core';
import { ideContext } from '@google/gemini-cli-core';
/** /**
* This hook listens for trust status updates from the IDE companion extension. * This hook listens for trust status updates from the IDE companion extension.
* It provides the current trust status from the IDE and a flag indicating * It provides the current trust status from the IDE and a flag indicating
* if a restart is needed because the trust state has changed. * if a restart is needed because the trust state has changed.
*/ */
export function useIdeTrustListener(config: Config) { export function useIdeTrustListener() {
const subscribe = useCallback( const subscribe = useCallback((onStoreChange: () => void) => {
(onStoreChange: () => void) => { (async () => {
const ideClient = config.getIdeClient(); const ideClient = await IdeClient.getInstance();
ideClient.addTrustChangeListener(onStoreChange); ideClient.addTrustChangeListener(onStoreChange);
})();
return () => { return () => {
(async () => {
const ideClient = await IdeClient.getInstance();
ideClient.removeTrustChangeListener(onStoreChange); ideClient.removeTrustChangeListener(onStoreChange);
})();
}; };
}, }, []);
[config],
);
const getSnapshot = () => const getSnapshot = () =>
ideContext.getIdeContext()?.workspaceState?.isTrusted; ideContext.getIdeContext()?.workspaceState?.isTrusted;
+7 -16
View File
@@ -260,7 +260,7 @@ export class Config {
private readonly folderTrustFeature: boolean; private readonly folderTrustFeature: boolean;
private readonly folderTrust: boolean; private readonly folderTrust: boolean;
private ideMode: boolean; private ideMode: boolean;
private ideClient!: IdeClient;
private inFallbackMode = false; private inFallbackMode = false;
private readonly maxSessionTurns: number; private readonly maxSessionTurns: number;
private readonly listExtensions: boolean; private readonly listExtensions: boolean;
@@ -383,7 +383,12 @@ export class Config {
throw Error('Config was already initialized'); throw Error('Config was already initialized');
} }
this.initialized = true; this.initialized = true;
this.ideClient = await IdeClient.getInstance();
if (this.getIdeMode()) {
await (await IdeClient.getInstance()).connect();
logIdeConnection(this, new IdeConnectionEvent(IdeConnectionType.START));
}
// Initialize centralized FileDiscoveryService // Initialize centralized FileDiscoveryService
this.getFileService(); this.getFileService();
if (this.getCheckpointingEnabled()) { if (this.getCheckpointingEnabled()) {
@@ -762,20 +767,6 @@ export class Config {
this.ideMode = value; this.ideMode = value;
} }
async setIdeModeAndSyncConnection(value: boolean): Promise<void> {
this.ideMode = value;
if (value) {
await this.ideClient.connect();
logIdeConnection(this, new IdeConnectionEvent(IdeConnectionType.SESSION));
} else {
await this.ideClient.disconnect();
}
}
getIdeClient(): IdeClient {
return this.ideClient;
}
/** /**
* Get the current FileSystemService * Get the current FileSystemService
*/ */
+7 -5
View File
@@ -65,7 +65,7 @@ function getRealPath(path: string): string {
* Manages the connection to and interaction with the IDE server. * Manages the connection to and interaction with the IDE server.
*/ */
export class IdeClient { export class IdeClient {
private static instance: IdeClient; private static instancePromise: Promise<IdeClient> | null = null;
private client: Client | undefined = undefined; private client: Client | undefined = undefined;
private state: IDEConnectionState = { private state: IDEConnectionState = {
status: IDEConnectionStatus.Disconnected, status: IDEConnectionStatus.Disconnected,
@@ -81,8 +81,9 @@ export class IdeClient {
private constructor() {} private constructor() {}
static async getInstance(): Promise<IdeClient> { static getInstance(): Promise<IdeClient> {
if (!IdeClient.instance) { if (!IdeClient.instancePromise) {
IdeClient.instancePromise = (async () => {
const client = new IdeClient(); const client = new IdeClient();
client.ideProcessInfo = await getIdeProcessInfo(); client.ideProcessInfo = await getIdeProcessInfo();
client.currentIde = detectIde(client.ideProcessInfo); client.currentIde = detectIde(client.ideProcessInfo);
@@ -91,9 +92,10 @@ export class IdeClient {
client.currentIde, client.currentIde,
).displayName; ).displayName;
} }
IdeClient.instance = client; return client;
})();
} }
return IdeClient.instance; return IdeClient.instancePromise;
} }
addStatusChangeListener(listener: (state: IDEConnectionState) => void) { addStatusChangeListener(listener: (state: IDEConnectionState) => void) {
+12 -3
View File
@@ -10,7 +10,17 @@ const mockEnsureCorrectEdit = vi.hoisted(() => vi.fn());
const mockGenerateJson = vi.hoisted(() => vi.fn()); const mockGenerateJson = vi.hoisted(() => vi.fn());
const mockOpenDiff = vi.hoisted(() => vi.fn()); const mockOpenDiff = vi.hoisted(() => vi.fn());
import { IDEConnectionStatus } from '../ide/ide-client.js'; import { IdeClient, IDEConnectionStatus } from '../ide/ide-client.js';
vi.mock('../ide/ide-client.js', () => ({
IdeClient: {
getInstance: vi.fn(),
},
IDEConnectionStatus: {
Connected: 'connected',
Disconnected: 'disconnected',
},
}));
vi.mock('../utils/editCorrector.js', () => ({ vi.mock('../utils/editCorrector.js', () => ({
ensureCorrectEdit: mockEnsureCorrectEdit, ensureCorrectEdit: mockEnsureCorrectEdit,
@@ -70,7 +80,6 @@ describe('EditTool', () => {
setApprovalMode: vi.fn(), setApprovalMode: vi.fn(),
getWorkspaceContext: () => createMockWorkspaceContext(rootDir), getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
getFileSystemService: () => new StandardFileSystemService(), getFileSystemService: () => new StandardFileSystemService(),
getIdeClient: () => undefined,
getIdeMode: () => false, getIdeMode: () => false,
// getGeminiConfig: () => ({ apiKey: 'test-api-key' }), // This was not a real Config method // getGeminiConfig: () => ({ apiKey: 'test-api-key' }), // This was not a real Config method
// Add other properties/methods of Config if EditTool uses them // Add other properties/methods of Config if EditTool uses them
@@ -878,8 +887,8 @@ describe('EditTool', () => {
status: IDEConnectionStatus.Connected, status: IDEConnectionStatus.Connected,
}), }),
}; };
vi.mocked(IdeClient.getInstance).mockResolvedValue(ideClient);
(mockConfig as any).getIdeMode = () => true; (mockConfig as any).getIdeMode = () => true;
(mockConfig as any).getIdeClient = () => ideClient;
}); });
it('should call ideClient.openDiff and update params on confirmation', async () => { it('should call ideClient.openDiff and update params on confirmation', async () => {
+2 -2
View File
@@ -33,7 +33,7 @@ import type {
ModifiableDeclarativeTool, ModifiableDeclarativeTool,
ModifyContext, ModifyContext,
} from './modifiable-tool.js'; } from './modifiable-tool.js';
import { IDEConnectionStatus } from '../ide/ide-client.js'; import { IdeClient, IDEConnectionStatus } from '../ide/ide-client.js';
export function applyReplacement( export function applyReplacement(
currentContent: string | null, currentContent: string | null,
@@ -267,7 +267,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
'Proposed', 'Proposed',
DEFAULT_DIFF_OPTIONS, DEFAULT_DIFF_OPTIONS,
); );
const ideClient = this.config.getIdeClient(); const ideClient = await IdeClient.getInstance();
const ideConfirmation = const ideConfirmation =
this.config.getIdeMode() && this.config.getIdeMode() &&
ideClient?.getConnectionStatus().status === IDEConnectionStatus.Connected ideClient?.getConnectionStatus().status === IDEConnectionStatus.Connected
+12 -3
View File
@@ -10,7 +10,17 @@ const mockFixLLMEditWithInstruction = vi.hoisted(() => vi.fn());
const mockGenerateJson = vi.hoisted(() => vi.fn()); const mockGenerateJson = vi.hoisted(() => vi.fn());
const mockOpenDiff = vi.hoisted(() => vi.fn()); const mockOpenDiff = vi.hoisted(() => vi.fn());
import { IDEConnectionStatus } from '../ide/ide-client.js'; import { IdeClient, IDEConnectionStatus } from '../ide/ide-client.js';
vi.mock('../ide/ide-client.js', () => ({
IdeClient: {
getInstance: vi.fn(),
},
IDEConnectionStatus: {
Connected: 'connected',
Disconnected: 'disconnected',
},
}));
vi.mock('../utils/llm-edit-fixer.js', () => ({ vi.mock('../utils/llm-edit-fixer.js', () => ({
FixLLMEditWithInstruction: mockFixLLMEditWithInstruction, FixLLMEditWithInstruction: mockFixLLMEditWithInstruction,
@@ -75,7 +85,6 @@ describe('SmartEditTool', () => {
setApprovalMode: vi.fn(), setApprovalMode: vi.fn(),
getWorkspaceContext: () => createMockWorkspaceContext(rootDir), getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
getFileSystemService: () => new StandardFileSystemService(), getFileSystemService: () => new StandardFileSystemService(),
getIdeClient: () => undefined,
getIdeMode: () => false, getIdeMode: () => false,
getApiKey: () => 'test-api-key', getApiKey: () => 'test-api-key',
getModel: () => 'test-model', getModel: () => 'test-model',
@@ -449,8 +458,8 @@ describe('SmartEditTool', () => {
status: IDEConnectionStatus.Connected, status: IDEConnectionStatus.Connected,
}), }),
}; };
vi.mocked(IdeClient.getInstance).mockResolvedValue(ideClient);
(mockConfig as any).getIdeMode = () => true; (mockConfig as any).getIdeMode = () => true;
(mockConfig as any).getIdeClient = () => ideClient;
}); });
it('should call ideClient.openDiff and update params on confirmation', async () => { it('should call ideClient.openDiff and update params on confirmation', async () => {
+2 -2
View File
@@ -28,7 +28,7 @@ import {
type ModifiableDeclarativeTool, type ModifiableDeclarativeTool,
type ModifyContext, type ModifyContext,
} from './modifiable-tool.js'; } from './modifiable-tool.js';
import { IDEConnectionStatus } from '../ide/ide-client.js'; import { IdeClient, IDEConnectionStatus } from '../ide/ide-client.js';
import { FixLLMEditWithInstruction } from '../utils/llm-edit-fixer.js'; import { FixLLMEditWithInstruction } from '../utils/llm-edit-fixer.js';
export function applyReplacement( export function applyReplacement(
@@ -526,7 +526,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
'Proposed', 'Proposed',
DEFAULT_DIFF_OPTIONS, DEFAULT_DIFF_OPTIONS,
); );
const ideClient = this.config.getIdeClient(); const ideClient = await IdeClient.getInstance();
const ideConfirmation = const ideConfirmation =
this.config.getIdeMode() && this.config.getIdeMode() &&
ideClient?.getConnectionStatus().status === IDEConnectionStatus.Connected ideClient?.getConnectionStatus().status === IDEConnectionStatus.Connected
+5 -10
View File
@@ -39,7 +39,11 @@ const rootDir = path.resolve(os.tmpdir(), 'gemini-cli-test-root');
// --- MOCKS --- // --- MOCKS ---
vi.mock('../core/client.js'); vi.mock('../core/client.js');
vi.mock('../utils/editCorrector.js'); vi.mock('../utils/editCorrector.js');
vi.mock('../ide/ide-client.js', () => ({
IdeClient: {
getInstance: vi.fn(),
},
}));
let mockGeminiClientInstance: Mocked<GeminiClient>; let mockGeminiClientInstance: Mocked<GeminiClient>;
const mockEnsureCorrectEdit = vi.fn<typeof ensureCorrectEdit>(); const mockEnsureCorrectEdit = vi.fn<typeof ensureCorrectEdit>();
const mockEnsureCorrectFileContent = vi.fn<typeof ensureCorrectFileContent>(); const mockEnsureCorrectFileContent = vi.fn<typeof ensureCorrectFileContent>();
@@ -58,7 +62,6 @@ const mockConfigInternal = {
setApprovalMode: vi.fn(), setApprovalMode: vi.fn(),
getGeminiClient: vi.fn(), // Initialize as a plain mock function getGeminiClient: vi.fn(), // Initialize as a plain mock function
getFileSystemService: () => fsService, getFileSystemService: () => fsService,
getIdeClient: vi.fn(),
getIdeMode: vi.fn(() => false), getIdeMode: vi.fn(() => false),
getWorkspaceContext: () => createMockWorkspaceContext(rootDir), getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
getApiKey: () => 'test-key', getApiKey: () => 'test-key',
@@ -120,14 +123,6 @@ describe('WriteFileTool', () => {
mockConfigInternal.getGeminiClient.mockReturnValue( mockConfigInternal.getGeminiClient.mockReturnValue(
mockGeminiClientInstance, mockGeminiClientInstance,
); );
mockConfigInternal.getIdeClient.mockReturnValue({
openDiff: vi.fn(),
closeDiff: vi.fn(),
getIdeContext: vi.fn(),
subscribeToIdeContext: vi.fn(),
isCodeTrackerEnabled: vi.fn(),
getTrackedCode: vi.fn(),
});
tool = new WriteFileTool(mockConfig); tool = new WriteFileTool(mockConfig);
+2 -2
View File
@@ -35,7 +35,7 @@ import type {
ModifiableDeclarativeTool, ModifiableDeclarativeTool,
ModifyContext, ModifyContext,
} from './modifiable-tool.js'; } from './modifiable-tool.js';
import { IDEConnectionStatus } from '../ide/ide-client.js'; import { IdeClient, IDEConnectionStatus } from '../ide/ide-client.js';
import { logFileOperation } from '../telemetry/loggers.js'; import { logFileOperation } from '../telemetry/loggers.js';
import { FileOperationEvent } from '../telemetry/types.js'; import { FileOperationEvent } from '../telemetry/types.js';
import { FileOperation } from '../telemetry/metrics.js'; import { FileOperation } from '../telemetry/metrics.js';
@@ -193,7 +193,7 @@ class WriteFileToolInvocation extends BaseToolInvocation<
DEFAULT_DIFF_OPTIONS, DEFAULT_DIFF_OPTIONS,
); );
const ideClient = this.config.getIdeClient(); const ideClient = await IdeClient.getInstance();
const ideConfirmation = const ideConfirmation =
this.config.getIdeMode() && this.config.getIdeMode() &&
ideClient.getConnectionStatus().status === IDEConnectionStatus.Connected ideClient.getConnectionStatus().status === IDEConnectionStatus.Connected