diff --git a/packages/vscode-ide-companion/src/ide-server.test.ts b/packages/vscode-ide-companion/src/ide-server.test.ts index 2367ac1551..614c338063 100644 --- a/packages/vscode-ide-companion/src/ide-server.test.ts +++ b/packages/vscode-ide-companion/src/ide-server.test.ts @@ -9,6 +9,7 @@ import type * as vscode from 'vscode'; import * as fs from 'node:fs/promises'; import type * as os from 'node:os'; import * as path from 'node:path'; +import * as http from 'node:http'; import { IDEServer } from './ide-server.js'; import type { DiffManager } from './diff-manager.js'; @@ -62,26 +63,26 @@ vi.mock('./open-files-manager', () => { return { OpenFilesManager }; }); +const getPortFromMock = ( + replaceMock: ReturnType< + () => vscode.ExtensionContext['environmentVariableCollection']['replace'] + >, +) => { + const port = vi + .mocked(replaceMock) + .mock.calls.find((call) => call[0] === 'GEMINI_CLI_IDE_SERVER_PORT')?.[1]; + + if (port === undefined) { + expect.fail('Port was not set'); + } + return port; +}; + describe('IDEServer', () => { let ideServer: IDEServer; let mockContext: vscode.ExtensionContext; let mockLog: (message: string) => void; - const getPortFromMock = ( - replaceMock: ReturnType< - () => vscode.ExtensionContext['environmentVariableCollection']['replace'] - >, - ) => { - const port = vi - .mocked(replaceMock) - .mock.calls.find((call) => call[0] === 'GEMINI_CLI_IDE_SERVER_PORT')?.[1]; - - if (port === undefined) { - expect.fail('Port was not set'); - } - return port; - }; - beforeEach(() => { mockLog = vi.fn(); ideServer = new IDEServer(mockLog, mocks.diffManager); @@ -456,3 +457,105 @@ describe('IDEServer', () => { }); }); }); + +const request = ( + port: string, + options: http.RequestOptions, + body?: string, +): Promise => + new Promise((resolve, reject) => { + const req = http.request( + { + hostname: '127.0.0.1', + port, + ...options, + }, + (res) => { + res.resume(); // Consume response data to free up memory + resolve(res); + }, + ); + req.on('error', reject); + if (body) { + req.write(body); + } + req.end(); + }); + +describe('IDEServer HTTP endpoints', () => { + let ideServer: IDEServer; + let mockContext: vscode.ExtensionContext; + let mockLog: (message: string) => void; + let port: string; + + beforeEach(async () => { + mockLog = vi.fn(); + ideServer = new IDEServer(mockLog, mocks.diffManager); + mockContext = { + subscriptions: [], + environmentVariableCollection: { + replace: vi.fn(), + clear: vi.fn(), + }, + } as unknown as vscode.ExtensionContext; + await ideServer.start(mockContext); + const replaceMock = mockContext.environmentVariableCollection.replace; + port = getPortFromMock(replaceMock); + }); + + afterEach(async () => { + await ideServer.stop(); + vi.restoreAllMocks(); + }); + + it('should deny requests with an origin header', async () => { + const response = await request( + port, + { + path: '/mcp', + method: 'POST', + headers: { + Host: `localhost:${port}`, + Origin: 'https://evil.com', + 'Content-Type': 'application/json', + }, + }, + JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }), + ); + expect(response.statusCode).toBe(403); + }); + + it('should deny requests with an invalid host header', async () => { + const response = await request( + port, + { + path: '/mcp', + method: 'POST', + headers: { + Host: 'evil.com', + 'Content-Type': 'application/json', + }, + }, + JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }), + ); + expect(response.statusCode).toBe(403); + }); + + it('should allow requests with a valid host header', async () => { + const response = await request( + port, + { + path: '/mcp', + method: 'POST', + headers: { + Host: `localhost:${port}`, + 'Content-Type': 'application/json', + }, + }, + JSON.stringify({ jsonrpc: '2.0', method: 'initialize' }), + ); + // We expect a 400 here because we are not sending a valid MCP request, + // but it's not a host error, which is what we are testing. + expect(response.statusCode).toBe(400); + }); +}); diff --git a/packages/vscode-ide-companion/src/ide-server.ts b/packages/vscode-ide-companion/src/ide-server.ts index c1bb85d3fe..12dbaac6bd 100644 --- a/packages/vscode-ide-companion/src/ide-server.ts +++ b/packages/vscode-ide-companion/src/ide-server.ts @@ -13,7 +13,12 @@ import { import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js'; import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; -import express, { type Request, type Response } from 'express'; +import express, { + type Request, + type Response, + type NextFunction, +} from 'express'; +import cors from 'cors'; import { randomUUID } from 'node:crypto'; import { type Server as HTTPServer } from 'node:http'; import * as path from 'node:path'; @@ -23,6 +28,13 @@ import type { z } from 'zod'; import type { DiffManager } from './diff-manager.js'; import { OpenFilesManager } from './open-files-manager.js'; +class CORSError extends Error { + constructor(message: string) { + super(message); + this.name = 'CORSError'; + } +} + const MCP_SESSION_ID_HEADER = 'mcp-session-id'; const IDE_SERVER_PORT_ENV_VAR = 'GEMINI_CLI_IDE_SERVER_PORT'; const IDE_WORKSPACE_PATH_ENV_VAR = 'GEMINI_CLI_IDE_WORKSPACE_PATH'; @@ -131,6 +143,34 @@ export class IDEServer { const app = express(); app.use(express.json({ limit: '10mb' })); + + app.use( + cors({ + origin: (origin, callback) => { + // Only allow non-browser requests with no origin. + if (!origin) { + return callback(null, true); + } + return callback( + new CORSError('Request denied by CORS policy.'), + false, + ); + }, + }), + ); + + app.use((req, res, next) => { + const host = req.headers.host || ''; + const allowedHosts = [ + `localhost:${this.port}`, + `127.0.0.1:${this.port}`, + ]; + if (!allowedHosts.includes(host)) { + return res.status(403).json({ error: 'Invalid Host header' }); + } + next(); + }); + app.use((req, res, next) => { const authHeader = req.headers.authorization; if (authHeader) { @@ -274,7 +314,15 @@ export class IDEServer { app.get('/mcp', handleSessionRequest); - this.server = app.listen(0, async () => { + app.use((err: Error, req: Request, res: Response, next: NextFunction) => { + if (err instanceof CORSError) { + res.status(403).json({ error: 'Request denied by CORS policy.' }); + } else { + next(err); + } + }); + + this.server = app.listen(0, '127.0.0.1', async () => { const address = (this.server as HTTPServer).address(); if (address && typeof address !== 'string') { this.port = address.port; @@ -286,7 +334,7 @@ export class IDEServer { os.tmpdir(), `gemini-ide-server-${process.ppid}.json`, ); - this.log(`IDE server listening on port ${this.port}`); + this.log(`IDE server listening on http://127.0.0.1:${this.port}`); if (this.authToken) { await writePortAndWorkspace({