feat(vscode-ide-companion): harden ide-server with CORS and host validation (#8512)

This commit is contained in:
Shreya Keshive
2025-09-18 10:06:16 -04:00
committed by GitHub
parent 930f39a0cd
commit d746eb7b22
2 changed files with 169 additions and 18 deletions

View File

@@ -9,6 +9,7 @@ import type * as vscode from 'vscode';
import * as fs from 'node:fs/promises';
import type * as os from 'node:os';
import * as path from 'node:path';
import * as http from 'node:http';
import { IDEServer } from './ide-server.js';
import type { DiffManager } from './diff-manager.js';
@@ -62,26 +63,26 @@ vi.mock('./open-files-manager', () => {
return { OpenFilesManager };
});
const getPortFromMock = (
replaceMock: ReturnType<
() => vscode.ExtensionContext['environmentVariableCollection']['replace']
>,
) => {
const port = vi
.mocked(replaceMock)
.mock.calls.find((call) => call[0] === 'GEMINI_CLI_IDE_SERVER_PORT')?.[1];
if (port === undefined) {
expect.fail('Port was not set');
}
return port;
};
describe('IDEServer', () => {
let ideServer: IDEServer;
let mockContext: vscode.ExtensionContext;
let mockLog: (message: string) => void;
const getPortFromMock = (
replaceMock: ReturnType<
() => vscode.ExtensionContext['environmentVariableCollection']['replace']
>,
) => {
const port = vi
.mocked(replaceMock)
.mock.calls.find((call) => call[0] === 'GEMINI_CLI_IDE_SERVER_PORT')?.[1];
if (port === undefined) {
expect.fail('Port was not set');
}
return port;
};
beforeEach(() => {
mockLog = vi.fn();
ideServer = new IDEServer(mockLog, mocks.diffManager);
@@ -456,3 +457,105 @@ describe('IDEServer', () => {
});
});
});
const request = (
port: string,
options: http.RequestOptions,
body?: string,
): Promise<http.IncomingMessage> =>
new Promise((resolve, reject) => {
const req = http.request(
{
hostname: '127.0.0.1',
port,
...options,
},
(res) => {
res.resume(); // Consume response data to free up memory
resolve(res);
},
);
req.on('error', reject);
if (body) {
req.write(body);
}
req.end();
});
describe('IDEServer HTTP endpoints', () => {
let ideServer: IDEServer;
let mockContext: vscode.ExtensionContext;
let mockLog: (message: string) => void;
let port: string;
beforeEach(async () => {
mockLog = vi.fn();
ideServer = new IDEServer(mockLog, mocks.diffManager);
mockContext = {
subscriptions: [],
environmentVariableCollection: {
replace: vi.fn(),
clear: vi.fn(),
},
} as unknown as vscode.ExtensionContext;
await ideServer.start(mockContext);
const replaceMock = mockContext.environmentVariableCollection.replace;
port = getPortFromMock(replaceMock);
});
afterEach(async () => {
await ideServer.stop();
vi.restoreAllMocks();
});
it('should deny requests with an origin header', async () => {
const response = await request(
port,
{
path: '/mcp',
method: 'POST',
headers: {
Host: `localhost:${port}`,
Origin: 'https://evil.com',
'Content-Type': 'application/json',
},
},
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
);
expect(response.statusCode).toBe(403);
});
it('should deny requests with an invalid host header', async () => {
const response = await request(
port,
{
path: '/mcp',
method: 'POST',
headers: {
Host: 'evil.com',
'Content-Type': 'application/json',
},
},
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
);
expect(response.statusCode).toBe(403);
});
it('should allow requests with a valid host header', async () => {
const response = await request(
port,
{
path: '/mcp',
method: 'POST',
headers: {
Host: `localhost:${port}`,
'Content-Type': 'application/json',
},
},
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
);
// We expect a 400 here because we are not sending a valid MCP request,
// but it's not a host error, which is what we are testing.
expect(response.statusCode).toBe(400);
});
});

View File

@@ -13,7 +13,12 @@ import {
import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js';
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import express, { type Request, type Response } from 'express';
import express, {
type Request,
type Response,
type NextFunction,
} from 'express';
import cors from 'cors';
import { randomUUID } from 'node:crypto';
import { type Server as HTTPServer } from 'node:http';
import * as path from 'node:path';
@@ -23,6 +28,13 @@ import type { z } from 'zod';
import type { DiffManager } from './diff-manager.js';
import { OpenFilesManager } from './open-files-manager.js';
class CORSError extends Error {
constructor(message: string) {
super(message);
this.name = 'CORSError';
}
}
const MCP_SESSION_ID_HEADER = 'mcp-session-id';
const IDE_SERVER_PORT_ENV_VAR = 'GEMINI_CLI_IDE_SERVER_PORT';
const IDE_WORKSPACE_PATH_ENV_VAR = 'GEMINI_CLI_IDE_WORKSPACE_PATH';
@@ -131,6 +143,34 @@ export class IDEServer {
const app = express();
app.use(express.json({ limit: '10mb' }));
app.use(
cors({
origin: (origin, callback) => {
// Only allow non-browser requests with no origin.
if (!origin) {
return callback(null, true);
}
return callback(
new CORSError('Request denied by CORS policy.'),
false,
);
},
}),
);
app.use((req, res, next) => {
const host = req.headers.host || '';
const allowedHosts = [
`localhost:${this.port}`,
`127.0.0.1:${this.port}`,
];
if (!allowedHosts.includes(host)) {
return res.status(403).json({ error: 'Invalid Host header' });
}
next();
});
app.use((req, res, next) => {
const authHeader = req.headers.authorization;
if (authHeader) {
@@ -274,7 +314,15 @@ export class IDEServer {
app.get('/mcp', handleSessionRequest);
this.server = app.listen(0, async () => {
app.use((err: Error, req: Request, res: Response, next: NextFunction) => {
if (err instanceof CORSError) {
res.status(403).json({ error: 'Request denied by CORS policy.' });
} else {
next(err);
}
});
this.server = app.listen(0, '127.0.0.1', async () => {
const address = (this.server as HTTPServer).address();
if (address && typeof address !== 'string') {
this.port = address.port;
@@ -286,7 +334,7 @@ export class IDEServer {
os.tmpdir(),
`gemini-ide-server-${process.ppid}.json`,
);
this.log(`IDE server listening on port ${this.port}`);
this.log(`IDE server listening on http://127.0.0.1:${this.port}`);
if (this.authToken) {
await writePortAndWorkspace({