mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
feat(vscode-ide-companion): harden ide-server with CORS and host validation (#8512)
This commit is contained in:
@@ -9,6 +9,7 @@ import type * as vscode from 'vscode';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import type * as os from 'node:os';
|
||||
import * as path from 'node:path';
|
||||
import * as http from 'node:http';
|
||||
import { IDEServer } from './ide-server.js';
|
||||
import type { DiffManager } from './diff-manager.js';
|
||||
|
||||
@@ -62,26 +63,26 @@ vi.mock('./open-files-manager', () => {
|
||||
return { OpenFilesManager };
|
||||
});
|
||||
|
||||
const getPortFromMock = (
|
||||
replaceMock: ReturnType<
|
||||
() => vscode.ExtensionContext['environmentVariableCollection']['replace']
|
||||
>,
|
||||
) => {
|
||||
const port = vi
|
||||
.mocked(replaceMock)
|
||||
.mock.calls.find((call) => call[0] === 'GEMINI_CLI_IDE_SERVER_PORT')?.[1];
|
||||
|
||||
if (port === undefined) {
|
||||
expect.fail('Port was not set');
|
||||
}
|
||||
return port;
|
||||
};
|
||||
|
||||
describe('IDEServer', () => {
|
||||
let ideServer: IDEServer;
|
||||
let mockContext: vscode.ExtensionContext;
|
||||
let mockLog: (message: string) => void;
|
||||
|
||||
const getPortFromMock = (
|
||||
replaceMock: ReturnType<
|
||||
() => vscode.ExtensionContext['environmentVariableCollection']['replace']
|
||||
>,
|
||||
) => {
|
||||
const port = vi
|
||||
.mocked(replaceMock)
|
||||
.mock.calls.find((call) => call[0] === 'GEMINI_CLI_IDE_SERVER_PORT')?.[1];
|
||||
|
||||
if (port === undefined) {
|
||||
expect.fail('Port was not set');
|
||||
}
|
||||
return port;
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockLog = vi.fn();
|
||||
ideServer = new IDEServer(mockLog, mocks.diffManager);
|
||||
@@ -456,3 +457,105 @@ describe('IDEServer', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
const request = (
|
||||
port: string,
|
||||
options: http.RequestOptions,
|
||||
body?: string,
|
||||
): Promise<http.IncomingMessage> =>
|
||||
new Promise((resolve, reject) => {
|
||||
const req = http.request(
|
||||
{
|
||||
hostname: '127.0.0.1',
|
||||
port,
|
||||
...options,
|
||||
},
|
||||
(res) => {
|
||||
res.resume(); // Consume response data to free up memory
|
||||
resolve(res);
|
||||
},
|
||||
);
|
||||
req.on('error', reject);
|
||||
if (body) {
|
||||
req.write(body);
|
||||
}
|
||||
req.end();
|
||||
});
|
||||
|
||||
describe('IDEServer HTTP endpoints', () => {
|
||||
let ideServer: IDEServer;
|
||||
let mockContext: vscode.ExtensionContext;
|
||||
let mockLog: (message: string) => void;
|
||||
let port: string;
|
||||
|
||||
beforeEach(async () => {
|
||||
mockLog = vi.fn();
|
||||
ideServer = new IDEServer(mockLog, mocks.diffManager);
|
||||
mockContext = {
|
||||
subscriptions: [],
|
||||
environmentVariableCollection: {
|
||||
replace: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
},
|
||||
} as unknown as vscode.ExtensionContext;
|
||||
await ideServer.start(mockContext);
|
||||
const replaceMock = mockContext.environmentVariableCollection.replace;
|
||||
port = getPortFromMock(replaceMock);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await ideServer.stop();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should deny requests with an origin header', async () => {
|
||||
const response = await request(
|
||||
port,
|
||||
{
|
||||
path: '/mcp',
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Host: `localhost:${port}`,
|
||||
Origin: 'https://evil.com',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
},
|
||||
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
|
||||
);
|
||||
expect(response.statusCode).toBe(403);
|
||||
});
|
||||
|
||||
it('should deny requests with an invalid host header', async () => {
|
||||
const response = await request(
|
||||
port,
|
||||
{
|
||||
path: '/mcp',
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Host: 'evil.com',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
},
|
||||
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
|
||||
);
|
||||
expect(response.statusCode).toBe(403);
|
||||
});
|
||||
|
||||
it('should allow requests with a valid host header', async () => {
|
||||
const response = await request(
|
||||
port,
|
||||
{
|
||||
path: '/mcp',
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Host: `localhost:${port}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
},
|
||||
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
|
||||
);
|
||||
// We expect a 400 here because we are not sending a valid MCP request,
|
||||
// but it's not a host error, which is what we are testing.
|
||||
expect(response.statusCode).toBe(400);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -13,7 +13,12 @@ import {
|
||||
import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||
import express, { type Request, type Response } from 'express';
|
||||
import express, {
|
||||
type Request,
|
||||
type Response,
|
||||
type NextFunction,
|
||||
} from 'express';
|
||||
import cors from 'cors';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import { type Server as HTTPServer } from 'node:http';
|
||||
import * as path from 'node:path';
|
||||
@@ -23,6 +28,13 @@ import type { z } from 'zod';
|
||||
import type { DiffManager } from './diff-manager.js';
|
||||
import { OpenFilesManager } from './open-files-manager.js';
|
||||
|
||||
class CORSError extends Error {
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = 'CORSError';
|
||||
}
|
||||
}
|
||||
|
||||
const MCP_SESSION_ID_HEADER = 'mcp-session-id';
|
||||
const IDE_SERVER_PORT_ENV_VAR = 'GEMINI_CLI_IDE_SERVER_PORT';
|
||||
const IDE_WORKSPACE_PATH_ENV_VAR = 'GEMINI_CLI_IDE_WORKSPACE_PATH';
|
||||
@@ -131,6 +143,34 @@ export class IDEServer {
|
||||
|
||||
const app = express();
|
||||
app.use(express.json({ limit: '10mb' }));
|
||||
|
||||
app.use(
|
||||
cors({
|
||||
origin: (origin, callback) => {
|
||||
// Only allow non-browser requests with no origin.
|
||||
if (!origin) {
|
||||
return callback(null, true);
|
||||
}
|
||||
return callback(
|
||||
new CORSError('Request denied by CORS policy.'),
|
||||
false,
|
||||
);
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
app.use((req, res, next) => {
|
||||
const host = req.headers.host || '';
|
||||
const allowedHosts = [
|
||||
`localhost:${this.port}`,
|
||||
`127.0.0.1:${this.port}`,
|
||||
];
|
||||
if (!allowedHosts.includes(host)) {
|
||||
return res.status(403).json({ error: 'Invalid Host header' });
|
||||
}
|
||||
next();
|
||||
});
|
||||
|
||||
app.use((req, res, next) => {
|
||||
const authHeader = req.headers.authorization;
|
||||
if (authHeader) {
|
||||
@@ -274,7 +314,15 @@ export class IDEServer {
|
||||
|
||||
app.get('/mcp', handleSessionRequest);
|
||||
|
||||
this.server = app.listen(0, async () => {
|
||||
app.use((err: Error, req: Request, res: Response, next: NextFunction) => {
|
||||
if (err instanceof CORSError) {
|
||||
res.status(403).json({ error: 'Request denied by CORS policy.' });
|
||||
} else {
|
||||
next(err);
|
||||
}
|
||||
});
|
||||
|
||||
this.server = app.listen(0, '127.0.0.1', async () => {
|
||||
const address = (this.server as HTTPServer).address();
|
||||
if (address && typeof address !== 'string') {
|
||||
this.port = address.port;
|
||||
@@ -286,7 +334,7 @@ export class IDEServer {
|
||||
os.tmpdir(),
|
||||
`gemini-ide-server-${process.ppid}.json`,
|
||||
);
|
||||
this.log(`IDE server listening on port ${this.port}`);
|
||||
this.log(`IDE server listening on http://127.0.0.1:${this.port}`);
|
||||
|
||||
if (this.authToken) {
|
||||
await writePortAndWorkspace({
|
||||
|
||||
Reference in New Issue
Block a user