mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 21:03:05 -07:00
feat(vscode-ide-companion): harden ide-server with CORS and host validation (#8512)
This commit is contained in:
@@ -9,6 +9,7 @@ import type * as vscode from 'vscode';
|
|||||||
import * as fs from 'node:fs/promises';
|
import * as fs from 'node:fs/promises';
|
||||||
import type * as os from 'node:os';
|
import type * as os from 'node:os';
|
||||||
import * as path from 'node:path';
|
import * as path from 'node:path';
|
||||||
|
import * as http from 'node:http';
|
||||||
import { IDEServer } from './ide-server.js';
|
import { IDEServer } from './ide-server.js';
|
||||||
import type { DiffManager } from './diff-manager.js';
|
import type { DiffManager } from './diff-manager.js';
|
||||||
|
|
||||||
@@ -62,16 +63,11 @@ vi.mock('./open-files-manager', () => {
|
|||||||
return { OpenFilesManager };
|
return { OpenFilesManager };
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('IDEServer', () => {
|
const getPortFromMock = (
|
||||||
let ideServer: IDEServer;
|
|
||||||
let mockContext: vscode.ExtensionContext;
|
|
||||||
let mockLog: (message: string) => void;
|
|
||||||
|
|
||||||
const getPortFromMock = (
|
|
||||||
replaceMock: ReturnType<
|
replaceMock: ReturnType<
|
||||||
() => vscode.ExtensionContext['environmentVariableCollection']['replace']
|
() => vscode.ExtensionContext['environmentVariableCollection']['replace']
|
||||||
>,
|
>,
|
||||||
) => {
|
) => {
|
||||||
const port = vi
|
const port = vi
|
||||||
.mocked(replaceMock)
|
.mocked(replaceMock)
|
||||||
.mock.calls.find((call) => call[0] === 'GEMINI_CLI_IDE_SERVER_PORT')?.[1];
|
.mock.calls.find((call) => call[0] === 'GEMINI_CLI_IDE_SERVER_PORT')?.[1];
|
||||||
@@ -80,7 +76,12 @@ describe('IDEServer', () => {
|
|||||||
expect.fail('Port was not set');
|
expect.fail('Port was not set');
|
||||||
}
|
}
|
||||||
return port;
|
return port;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
describe('IDEServer', () => {
|
||||||
|
let ideServer: IDEServer;
|
||||||
|
let mockContext: vscode.ExtensionContext;
|
||||||
|
let mockLog: (message: string) => void;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockLog = vi.fn();
|
mockLog = vi.fn();
|
||||||
@@ -456,3 +457,105 @@ describe('IDEServer', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const request = (
|
||||||
|
port: string,
|
||||||
|
options: http.RequestOptions,
|
||||||
|
body?: string,
|
||||||
|
): Promise<http.IncomingMessage> =>
|
||||||
|
new Promise((resolve, reject) => {
|
||||||
|
const req = http.request(
|
||||||
|
{
|
||||||
|
hostname: '127.0.0.1',
|
||||||
|
port,
|
||||||
|
...options,
|
||||||
|
},
|
||||||
|
(res) => {
|
||||||
|
res.resume(); // Consume response data to free up memory
|
||||||
|
resolve(res);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
req.on('error', reject);
|
||||||
|
if (body) {
|
||||||
|
req.write(body);
|
||||||
|
}
|
||||||
|
req.end();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('IDEServer HTTP endpoints', () => {
|
||||||
|
let ideServer: IDEServer;
|
||||||
|
let mockContext: vscode.ExtensionContext;
|
||||||
|
let mockLog: (message: string) => void;
|
||||||
|
let port: string;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
mockLog = vi.fn();
|
||||||
|
ideServer = new IDEServer(mockLog, mocks.diffManager);
|
||||||
|
mockContext = {
|
||||||
|
subscriptions: [],
|
||||||
|
environmentVariableCollection: {
|
||||||
|
replace: vi.fn(),
|
||||||
|
clear: vi.fn(),
|
||||||
|
},
|
||||||
|
} as unknown as vscode.ExtensionContext;
|
||||||
|
await ideServer.start(mockContext);
|
||||||
|
const replaceMock = mockContext.environmentVariableCollection.replace;
|
||||||
|
port = getPortFromMock(replaceMock);
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(async () => {
|
||||||
|
await ideServer.stop();
|
||||||
|
vi.restoreAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should deny requests with an origin header', async () => {
|
||||||
|
const response = await request(
|
||||||
|
port,
|
||||||
|
{
|
||||||
|
path: '/mcp',
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
Host: `localhost:${port}`,
|
||||||
|
Origin: 'https://evil.com',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
|
||||||
|
);
|
||||||
|
expect(response.statusCode).toBe(403);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should deny requests with an invalid host header', async () => {
|
||||||
|
const response = await request(
|
||||||
|
port,
|
||||||
|
{
|
||||||
|
path: '/mcp',
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
Host: 'evil.com',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
|
||||||
|
);
|
||||||
|
expect(response.statusCode).toBe(403);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow requests with a valid host header', async () => {
|
||||||
|
const response = await request(
|
||||||
|
port,
|
||||||
|
{
|
||||||
|
path: '/mcp',
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
Host: `localhost:${port}`,
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }),
|
||||||
|
);
|
||||||
|
// We expect a 400 here because we are not sending a valid MCP request,
|
||||||
|
// but it's not a host error, which is what we are testing.
|
||||||
|
expect(response.statusCode).toBe(400);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|||||||
@@ -13,7 +13,12 @@ import {
|
|||||||
import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js';
|
import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js';
|
||||||
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||||
import express, { type Request, type Response } from 'express';
|
import express, {
|
||||||
|
type Request,
|
||||||
|
type Response,
|
||||||
|
type NextFunction,
|
||||||
|
} from 'express';
|
||||||
|
import cors from 'cors';
|
||||||
import { randomUUID } from 'node:crypto';
|
import { randomUUID } from 'node:crypto';
|
||||||
import { type Server as HTTPServer } from 'node:http';
|
import { type Server as HTTPServer } from 'node:http';
|
||||||
import * as path from 'node:path';
|
import * as path from 'node:path';
|
||||||
@@ -23,6 +28,13 @@ import type { z } from 'zod';
|
|||||||
import type { DiffManager } from './diff-manager.js';
|
import type { DiffManager } from './diff-manager.js';
|
||||||
import { OpenFilesManager } from './open-files-manager.js';
|
import { OpenFilesManager } from './open-files-manager.js';
|
||||||
|
|
||||||
|
class CORSError extends Error {
|
||||||
|
constructor(message: string) {
|
||||||
|
super(message);
|
||||||
|
this.name = 'CORSError';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const MCP_SESSION_ID_HEADER = 'mcp-session-id';
|
const MCP_SESSION_ID_HEADER = 'mcp-session-id';
|
||||||
const IDE_SERVER_PORT_ENV_VAR = 'GEMINI_CLI_IDE_SERVER_PORT';
|
const IDE_SERVER_PORT_ENV_VAR = 'GEMINI_CLI_IDE_SERVER_PORT';
|
||||||
const IDE_WORKSPACE_PATH_ENV_VAR = 'GEMINI_CLI_IDE_WORKSPACE_PATH';
|
const IDE_WORKSPACE_PATH_ENV_VAR = 'GEMINI_CLI_IDE_WORKSPACE_PATH';
|
||||||
@@ -131,6 +143,34 @@ export class IDEServer {
|
|||||||
|
|
||||||
const app = express();
|
const app = express();
|
||||||
app.use(express.json({ limit: '10mb' }));
|
app.use(express.json({ limit: '10mb' }));
|
||||||
|
|
||||||
|
app.use(
|
||||||
|
cors({
|
||||||
|
origin: (origin, callback) => {
|
||||||
|
// Only allow non-browser requests with no origin.
|
||||||
|
if (!origin) {
|
||||||
|
return callback(null, true);
|
||||||
|
}
|
||||||
|
return callback(
|
||||||
|
new CORSError('Request denied by CORS policy.'),
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
app.use((req, res, next) => {
|
||||||
|
const host = req.headers.host || '';
|
||||||
|
const allowedHosts = [
|
||||||
|
`localhost:${this.port}`,
|
||||||
|
`127.0.0.1:${this.port}`,
|
||||||
|
];
|
||||||
|
if (!allowedHosts.includes(host)) {
|
||||||
|
return res.status(403).json({ error: 'Invalid Host header' });
|
||||||
|
}
|
||||||
|
next();
|
||||||
|
});
|
||||||
|
|
||||||
app.use((req, res, next) => {
|
app.use((req, res, next) => {
|
||||||
const authHeader = req.headers.authorization;
|
const authHeader = req.headers.authorization;
|
||||||
if (authHeader) {
|
if (authHeader) {
|
||||||
@@ -274,7 +314,15 @@ export class IDEServer {
|
|||||||
|
|
||||||
app.get('/mcp', handleSessionRequest);
|
app.get('/mcp', handleSessionRequest);
|
||||||
|
|
||||||
this.server = app.listen(0, async () => {
|
app.use((err: Error, req: Request, res: Response, next: NextFunction) => {
|
||||||
|
if (err instanceof CORSError) {
|
||||||
|
res.status(403).json({ error: 'Request denied by CORS policy.' });
|
||||||
|
} else {
|
||||||
|
next(err);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
this.server = app.listen(0, '127.0.0.1', async () => {
|
||||||
const address = (this.server as HTTPServer).address();
|
const address = (this.server as HTTPServer).address();
|
||||||
if (address && typeof address !== 'string') {
|
if (address && typeof address !== 'string') {
|
||||||
this.port = address.port;
|
this.port = address.port;
|
||||||
@@ -286,7 +334,7 @@ export class IDEServer {
|
|||||||
os.tmpdir(),
|
os.tmpdir(),
|
||||||
`gemini-ide-server-${process.ppid}.json`,
|
`gemini-ide-server-${process.ppid}.json`,
|
||||||
);
|
);
|
||||||
this.log(`IDE server listening on port ${this.port}`);
|
this.log(`IDE server listening on http://127.0.0.1:${this.port}`);
|
||||||
|
|
||||||
if (this.authToken) {
|
if (this.authToken) {
|
||||||
await writePortAndWorkspace({
|
await writePortAndWorkspace({
|
||||||
|
|||||||
Reference in New Issue
Block a user