Override Gemini CLI trust with VScode workspace trust when in IDE (#7433)

This commit is contained in:
shrutip90
2025-09-03 11:44:26 -07:00
committed by GitHub
parent 5ccf46b5a0
commit 7c667e100e
16 changed files with 248 additions and 30 deletions
@@ -32,6 +32,7 @@ vi.mock('vscode', () => ({
onDidCloseTextDocument: vi.fn(),
registerTextDocumentContentProvider: vi.fn(),
onDidChangeWorkspaceFolders: vi.fn(),
onDidGrantWorkspaceTrust: vi.fn(),
},
commands: {
registerCommand: vi.fn(),
@@ -91,6 +92,11 @@ describe('activate', () => {
expect(vscode.window.showInformationMessage).not.toHaveBeenCalled();
});
it('should register a handler for onDidGrantWorkspaceTrust', async () => {
await activate(context);
expect(vscode.workspace.onDidGrantWorkspaceTrust).toHaveBeenCalled();
});
it('should launch the Gemini CLI when the user clicks the button', async () => {
const showInformationMessageMock = vi
.mocked(vscode.window.showInformationMessage)
@@ -72,7 +72,10 @@ export async function activate(context: vscode.ExtensionContext) {
context.subscriptions.push(
vscode.workspace.onDidChangeWorkspaceFolders(() => {
ideServer.updateWorkspacePath();
ideServer.syncEnvVars();
}),
vscode.workspace.onDidGrantWorkspaceTrust(() => {
ideServer.syncEnvVars();
}),
vscode.commands.registerCommand('gemini-cli.runGeminiCLI', async () => {
const workspaceFolders = vscode.workspace.workspaceFolders;
@@ -45,6 +45,7 @@ const vscodeMock = vi.hoisted(() => ({
},
},
],
isTrusted: true,
},
}));
@@ -229,7 +230,7 @@ describe('IDEServer', () => {
{ uri: { fsPath: '/foo/bar' } },
{ uri: { fsPath: '/baz/qux' } },
];
await ideServer.updateWorkspacePath();
await ideServer.syncEnvVars();
const expectedWorkspacePaths = ['/foo/bar', '/baz/qux'].join(
path.delimiter,
@@ -264,7 +265,7 @@ describe('IDEServer', () => {
// Simulate removing a folder
vscodeMock.workspace.workspaceFolders = [{ uri: { fsPath: '/baz/qux' } }];
await ideServer.updateWorkspacePath();
await ideServer.syncEnvVars();
expect(replaceMock).toHaveBeenCalledWith(
'GEMINI_CLI_IDE_WORKSPACE_PATH',
+33 -21
View File
@@ -91,6 +91,9 @@ export class IDEServer {
private portFile: string | undefined;
private ppidPortFile: string | undefined;
private port: number | undefined;
private transports: { [sessionId: string]: StreamableHTTPServerTransport } =
{};
private openFilesManager: OpenFilesManager | undefined;
diffManager: DiffManager;
constructor(log: (message: string) => void, diffManager: DiffManager) {
@@ -102,27 +105,19 @@ export class IDEServer {
return new Promise((resolve) => {
this.context = context;
const sessionsWithInitialNotification = new Set<string>();
const transports: { [sessionId: string]: StreamableHTTPServerTransport } =
{};
const app = express();
app.use(express.json());
const mcpServer = createMcpServer(this.diffManager);
const openFilesManager = new OpenFilesManager(context);
const onDidChangeSubscription = openFilesManager.onDidChange(() => {
for (const transport of Object.values(transports)) {
sendIdeContextUpdateNotification(
transport,
this.log.bind(this),
openFilesManager,
);
}
this.openFilesManager = new OpenFilesManager(context);
const onDidChangeSubscription = this.openFilesManager.onDidChange(() => {
this.broadcastIdeContextUpdate();
});
context.subscriptions.push(onDidChangeSubscription);
const onDidChangeDiffSubscription = this.diffManager.onDidChange(
(notification) => {
for (const transport of Object.values(transports)) {
for (const transport of Object.values(this.transports)) {
transport.send(notification);
}
},
@@ -135,14 +130,14 @@ export class IDEServer {
| undefined;
let transport: StreamableHTTPServerTransport;
if (sessionId && transports[sessionId]) {
transport = transports[sessionId];
if (sessionId && this.transports[sessionId]) {
transport = this.transports[sessionId];
} else if (!sessionId && isInitializeRequest(req.body)) {
transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
onsessioninitialized: (newSessionId) => {
this.log(`New session initialized: ${newSessionId}`);
transports[newSessionId] = transport;
this.transports[newSessionId] = transport;
},
});
const keepAlive = setInterval(() => {
@@ -161,7 +156,7 @@ export class IDEServer {
if (transport.sessionId) {
this.log(`Session closed: ${transport.sessionId}`);
sessionsWithInitialNotification.delete(transport.sessionId);
delete transports[transport.sessionId];
delete this.transports[transport.sessionId];
}
};
mcpServer.connect(transport);
@@ -204,13 +199,13 @@ export class IDEServer {
const sessionId = req.headers[MCP_SESSION_ID_HEADER] as
| string
| undefined;
if (!sessionId || !transports[sessionId]) {
if (!sessionId || !this.transports[sessionId]) {
this.log('Invalid or missing session ID');
res.status(400).send('Invalid or missing session ID');
return;
}
const transport = transports[sessionId];
const transport = this.transports[sessionId];
try {
await transport.handleRequest(req, res);
} catch (error) {
@@ -222,11 +217,14 @@ export class IDEServer {
}
}
if (!sessionsWithInitialNotification.has(sessionId)) {
if (
this.openFilesManager &&
!sessionsWithInitialNotification.has(sessionId)
) {
sendIdeContextUpdateNotification(
transport,
this.log.bind(this),
openFilesManager,
this.openFilesManager,
);
sessionsWithInitialNotification.add(sessionId);
}
@@ -260,7 +258,20 @@ export class IDEServer {
});
}
async updateWorkspacePath(): Promise<void> {
broadcastIdeContextUpdate() {
if (!this.openFilesManager) {
return;
}
for (const transport of Object.values(this.transports)) {
sendIdeContextUpdateNotification(
transport,
this.log.bind(this),
this.openFilesManager,
);
}
}
async syncEnvVars(): Promise<void> {
if (
this.context &&
this.server &&
@@ -275,6 +286,7 @@ export class IDEServer {
this.ppidPortFile,
this.log,
);
this.broadcastIdeContextUpdate();
}
}
@@ -172,6 +172,7 @@ export class OpenFilesManager {
return {
workspaceState: {
openFiles: [...this.openFiles],
isTrusted: vscode.workspace.isTrusted,
},
};
}