diff --git a/docs/changelogs/preview.md b/docs/changelogs/preview.md index da20f5d441..19ff7f8210 100644 --- a/docs/changelogs/preview.md +++ b/docs/changelogs/preview.md @@ -1,6 +1,6 @@ -# Preview release: v0.34.0-preview.0 +# Preview release: v0.34.0-preview.1 -Released: March 11, 2026 +Released: March 12, 2026 Our preview release includes the latest, new, and experimental features. This release may not be as stable as our [latest weekly release](latest.md). @@ -28,6 +28,9 @@ npm install -g @google/gemini-cli@preview ## What's Changed +- fix(patch): cherry-pick 45faf4d to release/v0.34.0-preview.0-pr-22148 + [CONFLICTS] by @gemini-cli-robot in + [#22174](https://github.com/google-gemini/gemini-cli/pull/22174) - feat(cli): add chat resume footer on session quit by @lordshashank in [#20667](https://github.com/google-gemini/gemini-cli/pull/20667) - Support bold and other styles in svg snapshots by @jacob314 in @@ -465,4 +468,4 @@ npm install -g @google/gemini-cli@preview [#21938](https://github.com/google-gemini/gemini-cli/pull/21938) **Full Changelog**: -https://github.com/google-gemini/gemini-cli/compare/v0.33.0-preview.15...v0.34.0-preview.0 +https://github.com/google-gemini/gemini-cli/compare/v0.33.0-preview.15...v0.34.0-preview.1 diff --git a/docs/cli/model-routing.md b/docs/cli/model-routing.md index 1f7ba5da09..3c7bd65bc5 100644 --- a/docs/cli/model-routing.md +++ b/docs/cli/model-routing.md @@ -26,6 +26,20 @@ policies. the CLI will use an available fallback model for the current turn or the remainder of the session. +### Local Model Routing (Experimental) + +Gemini CLI supports using a local model for routing decisions. When configured, +Gemini CLI will use a locally-running **Gemma** model to make routing decisions +(instead of sending routing decisions to a hosted model). This feature can help +reduce costs associated with hosted model usage while offering similar routing +decision latency and quality. + +In order to use this feature, the local Gemma model **must** be served behind a +Gemini API and accessible via HTTP at an endpoint configured in `settings.json`. + +For more details on how to configure local model routing, see +[Local Model Routing](../core/local-model-routing.md). + ### Model selection precedence The model used by Gemini CLI is determined by the following order of precedence: @@ -38,5 +52,8 @@ The model used by Gemini CLI is determined by the following order of precedence: 3. **`model.name` in `settings.json`:** If neither of the above are set, the model specified in the `model.name` property of your `settings.json` file will be used. -4. **Default model:** If none of the above are set, the default model will be +4. **Local model (experimental):** If the Gemma local model router is enabled + in your `settings.json` file, the CLI will use the local Gemma model + (instead of Gemini models) to route the request to an appropriate model. +5. **Default model:** If none of the above are set, the default model will be used. The default model is `auto` diff --git a/docs/core/index.md b/docs/core/index.md index adf186116f..afa13787b8 100644 --- a/docs/core/index.md +++ b/docs/core/index.md @@ -15,6 +15,8 @@ requests sent from `packages/cli`. For a general overview of Gemini CLI, see the modular GEMINI.md import feature using @file.md syntax. - **[Policy Engine](../reference/policy-engine.md):** Use the Policy Engine for fine-grained control over tool execution. +- **[Local Model Routing (experimental)](./local-model-routing.md):** Learn how + to enable use of a local Gemma model for model routing decisions. ## Role of the core diff --git a/docs/core/local-model-routing.md b/docs/core/local-model-routing.md new file mode 100644 index 0000000000..99f52511b0 --- /dev/null +++ b/docs/core/local-model-routing.md @@ -0,0 +1,193 @@ +# Local Model Routing (experimental) + +Gemini CLI supports using a local model for +[routing decisions](../cli/model-routing.md). When configured, Gemini CLI will +use a locally-running **Gemma** model to make routing decisions (instead of +sending routing decisions to a hosted model). + +This feature can help reduce costs associated with hosted model usage while +offering similar routing decision latency and quality. + +> **Note: Local model routing is currently an experimental feature.** + +## Setup + +Using a Gemma model for routing decisions requires that an implementation of a +Gemma model be running locally on your machine, served behind an HTTP endpoint +and accessed via the Gemini API. + +To serve the Gemma model, follow these steps: + +### Download the LiteRT-LM runtime + +The [LiteRT-LM](https://github.com/google-ai-edge/LiteRT-LM) runtime offers +pre-built binaries for locally-serving models. Download the binary appropriate +for your system. + +#### Windows + +1. Download + [lit.windows_x86_64.exe](https://github.com/google-ai-edge/LiteRT-LM/releases/download/v0.9.0-alpha03/lit.windows_x86_64.exe). +2. Using GPU on Windows requires the DirectXShaderCompiler. Download the + [dxc zip from the latest release](https://github.com/microsoft/DirectXShaderCompiler/releases/download/v1.8.2505.1/dxc_2025_07_14.zip). + Unzip the archive and from the architecture-appropriate `bin\` directory, and + copy the `dxil.dll` and `dxcompiler.dll` into the same location as you saved + `lit.windows_x86_64.exe`. +3. (Optional) Test starting the runtime: + `.\lit.windows_x86_64.exe serve --verbose` + +#### Linux + +1. Download + [lit.linux_x86_64](https://github.com/google-ai-edge/LiteRT-LM/releases/download/v0.9.0-alpha03/lit.linux_x86_64). +2. Ensure the binary is executable: `chmod a+x lit.linux_x86_64` +3. (Optional) Test starting the runtime: `./lit.linux_x86_64 serve --verbose` + +#### MacOS + +1. Download + [lit-macos-arm64](https://github.com/google-ai-edge/LiteRT-LM/releases/download/v0.9.0-alpha03/lit.macos_arm64). +2. Ensure the binary is executable: `chmod a+x lit.macos_arm64` +3. (Optional) Test starting the runtime: `./lit.macos_arm64 serve --verbose` + +> **Note**: MacOS can be configured to only allows binaries from "App Store & +> Known Developers". If you encounter an error message when attempting to run +> the binary, you will need to allow the application. One option is to visit +> `System Settings -> Privacy & Security`, scroll to `Security`, and click +> `"Allow Anyway"` for `"lit.macos_arm64"`. Another option is to run +> `xattr -d com.apple.quarantine lit.macos_arm64` from the commandline. + +### Download the Gemma Model + +Before using Gemma, you will need to download the model (and agree to the Terms +of Service). + +This can be done via the LiteRT-LM runtime. + +#### Windows + +```bash +$ .\lit.windows_x86_64.exe pull gemma3-1b-gpu-custom + +[Legal] The model you are about to download is governed by +the Gemma Terms of Use and Prohibited Use Policy. Please review these terms and ensure you agree before continuing. + +Full Terms: https://ai.google.dev/gemma/terms +Prohibited Use Policy: https://ai.google.dev/gemma/prohibited_use_policy + +Do you accept these terms? (Y/N): Y + +Terms accepted. +Downloading model 'gemma3-1b-gpu-custom' ... +Downloading... 968.6 MB +Download complete. +``` + +#### Linux + +```bash +$ ./lit.linux_x86_64 pull gemma3-1b-gpu-custom + +[Legal] The model you are about to download is governed by +the Gemma Terms of Use and Prohibited Use Policy. Please review these terms and ensure you agree before continuing. + +Full Terms: https://ai.google.dev/gemma/terms +Prohibited Use Policy: https://ai.google.dev/gemma/prohibited_use_policy + +Do you accept these terms? (Y/N): Y + +Terms accepted. +Downloading model 'gemma3-1b-gpu-custom' ... +Downloading... 968.6 MB +Download complete. +``` + +#### MacOS + +```bash +$ ./lit.lit.macos_arm64 pull gemma3-1b-gpu-custom + +[Legal] The model you are about to download is governed by +the Gemma Terms of Use and Prohibited Use Policy. Please review these terms and ensure you agree before continuing. + +Full Terms: https://ai.google.dev/gemma/terms +Prohibited Use Policy: https://ai.google.dev/gemma/prohibited_use_policy + +Do you accept these terms? (Y/N): Y + +Terms accepted. +Downloading model 'gemma3-1b-gpu-custom' ... +Downloading... 968.6 MB +Download complete. +``` + +### Start LiteRT-LM Runtime + +Using the command appropriate to your system, start the LiteRT-LM runtime. +Configure the port that you want to use for your Gemma model. For the purposes +of this document, we will use port `9379`. + +Example command for MacOS: `./lit.macos_arm64 serve --port=9379 --verbose` + +### (Optional) Verify Model Serving + +Send a quick prompt to the model via HTTP to validate successful model serving. +This will cause the runtime to download the model and run it once. + +You should see a short joke in the server output as an indicator of success. + +#### Windows + +``` +# Run this in PowerShell to send a request to the server + +$uri = "http://localhost:9379/v1beta/models/gemma3-1b-gpu-custom:generateContent" +$body = @{contents = @( @{ + role = "user" + parts = @( @{ text = "Tell me a joke." } ) +})} | ConvertTo-Json -Depth 10 + +Invoke-RestMethod -Uri $uri -Method Post -Body $body -ContentType "application/json" +``` + +#### Linux/MacOS + +```bash +$ curl "http://localhost:9379/v1beta/models/gemma3-1b-gpu-custom:generateContent" \ + -H 'Content-Type: application/json' \ + -X POST \ + -d '{"contents":[{"role":"user","parts":[{"text":"Tell me a joke."}]}]}' +``` + +## Configuration + +To use a local Gemma model for routing, you must explicitly enable it in your +`settings.json`: + +```json +{ + "experimental": { + "gemmaModelRouter": { + "enabled": true, + "classifier": { + "host": "http://localhost:9379", + "model": "gemma3-1b-gpu-custom" + } + } + } +} +``` + +> Use the port you started your LiteRT-LM runtime on in the setup steps. + +### Configuration schema + +| Field | Type | Required | Description | +| :----------------- | :------ | :------- | :----------------------------------------------------------------------------------------- | +| `enabled` | boolean | Yes | Must be `true` to enable the feature. | +| `classifier` | object | Yes | The configuration for the local model endpoint. It includes the host and model specifiers. | +| `classifier.host` | string | Yes | The URL to the local model server. Should be `http://localhost:`. | +| `classifier.model` | string | Yes | The model name to use for decisions. Must be `"gemma3-1b-gpu-custom"`. | + +> **Note: You will need to restart after configuration changes for local model +> routing to take effect.** diff --git a/esbuild.config.js b/esbuild.config.js index 49d158ec36..f0d55e3ca6 100644 --- a/esbuild.config.js +++ b/esbuild.config.js @@ -82,11 +82,14 @@ const commonAliases = { const cliConfig = { ...baseConfig, banner: { - js: `const require = (await import('node:module')).createRequire(import.meta.url); globalThis.__filename = (await import('node:url')).fileURLToPath(import.meta.url); globalThis.__dirname = (await import('node:path')).dirname(globalThis.__filename);`, + js: `const require = (await import('node:module')).createRequire(import.meta.url); const __chunk_filename = (await import('node:url')).fileURLToPath(import.meta.url); const __chunk_dirname = (await import('node:path')).dirname(__chunk_filename);`, }, - entryPoints: ['packages/cli/index.ts'], - outfile: 'bundle/gemini.js', + entryPoints: { gemini: 'packages/cli/index.ts' }, + outdir: 'bundle', + splitting: true, define: { + __filename: '__chunk_filename', + __dirname: '__chunk_dirname', 'process.env.CLI_VERSION': JSON.stringify(pkg.version), 'process.env.GEMINI_SANDBOX_IMAGE_DEFAULT': JSON.stringify( pkg.config?.sandboxImageUri, @@ -103,11 +106,13 @@ const cliConfig = { const a2aServerConfig = { ...baseConfig, banner: { - js: `const require = (await import('node:module')).createRequire(import.meta.url); globalThis.__filename = (await import('node:url')).fileURLToPath(import.meta.url); globalThis.__dirname = (await import('node:path')).dirname(globalThis.__filename);`, + js: `const require = (await import('node:module')).createRequire(import.meta.url); const __chunk_filename = (await import('node:url')).fileURLToPath(import.meta.url); const __chunk_dirname = (await import('node:path')).dirname(__chunk_filename);`, }, entryPoints: ['packages/a2a-server/src/http/server.ts'], outfile: 'packages/a2a-server/dist/a2a-server.mjs', define: { + __filename: '__chunk_filename', + __dirname: '__chunk_dirname', 'process.env.CLI_VERSION': JSON.stringify(pkg.version), }, plugins: createWasmPlugins(), diff --git a/eslint.config.js b/eslint.config.js index a0a0429119..d3a267f30a 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -35,11 +35,6 @@ const commonRestrictedSyntaxRules = [ message: 'Do not throw string literals or non-Error objects. Throw new Error("...") instead.', }, - { - selector: 'CallExpression[callee.name="fetch"]', - message: - 'Use safeFetch() from "@/utils/fetch" instead of the global fetch() to ensure SSRF protection. If you are implementing a custom security layer, use an eslint-disable comment and explain why.', - }, ]; export default tseslint.config( diff --git a/img.png b/img.png deleted file mode 100644 index ab9f0bafcd..0000000000 Binary files a/img.png and /dev/null differ diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 2985e20358..04a370d7e9 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -4,13 +4,38 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React from 'react'; -import { render } from 'ink'; -import { AppContainer } from './ui/AppContainer.js'; +import { + type StartupWarning, + WarningPriority, + type Config, + type ResumedSessionData, + type OutputPayload, + type ConsoleLogPayload, + type UserFeedbackPayload, + sessionId, + logUserPrompt, + AuthType, + UserPromptEvent, + coreEvents, + CoreEvent, + getOauthClient, + patchStdio, + writeToStdout, + writeToStderr, + shouldEnterAlternateScreen, + startupProfiler, + ExitCodes, + SessionStartSource, + SessionEndReason, + ValidationCancelledError, + ValidationRequiredError, + type AdminControlsSettings, + debugLogger, +} from '@google/gemini-cli-core'; + import { loadCliConfig, parseArguments } from './config/config.js'; import * as cliConfig from './config/config.js'; import { readStdin } from './utils/readStdin.js'; -import { basename } from 'node:path'; import { createHash } from 'node:crypto'; import v8 from 'node:v8'; import os from 'node:os'; @@ -37,47 +62,11 @@ import { runExitCleanup, registerTelemetryConfig, setupSignalHandlers, - setupTtyCheck, } from './utils/cleanup.js'; import { cleanupToolOutputFiles, cleanupExpiredSessions, } from './utils/sessionCleanup.js'; -import { - type StartupWarning, - WarningPriority, - type Config, - type ResumedSessionData, - type OutputPayload, - type ConsoleLogPayload, - type UserFeedbackPayload, - sessionId, - logUserPrompt, - AuthType, - getOauthClient, - UserPromptEvent, - debugLogger, - recordSlowRender, - coreEvents, - CoreEvent, - createWorkingStdio, - patchStdio, - writeToStdout, - writeToStderr, - disableMouseEvents, - enableMouseEvents, - disableLineWrapping, - enableLineWrapping, - shouldEnterAlternateScreen, - startupProfiler, - ExitCodes, - SessionStartSource, - SessionEndReason, - getVersion, - ValidationCancelledError, - ValidationRequiredError, - type AdminControlsSettings, -} from '@google/gemini-cli-core'; import { initializeApp, type InitializationResult, @@ -85,21 +74,9 @@ import { import { validateAuthMethod } from './config/auth.js'; import { runAcpClient } from './acp/acpClient.js'; import { validateNonInteractiveAuth } from './validateNonInterActiveAuth.js'; -import { checkForUpdates } from './ui/utils/updateCheck.js'; -import { handleAutoUpdate } from './utils/handleAutoUpdate.js'; import { appEvents, AppEvent } from './utils/events.js'; import { SessionError, SessionSelector } from './utils/sessionUtils.js'; -import { SettingsContext } from './ui/contexts/SettingsContext.js'; -import { MouseProvider } from './ui/contexts/MouseContext.js'; -import { StreamingState } from './ui/types.js'; -import { computeTerminalTitle } from './utils/windowTitle.js'; -import { SessionStatsProvider } from './ui/contexts/SessionContext.js'; -import { VimModeProvider } from './ui/contexts/VimModeContext.js'; -import { KeyMatchersProvider } from './ui/hooks/useKeyMatchers.js'; -import { loadKeyMatchers } from './ui/key/keyMatchers.js'; -import { KeypressProvider } from './ui/contexts/KeypressContext.js'; -import { useKittyKeyboardProtocol } from './ui/hooks/useKittyKeyboardProtocol.js'; import { relaunchAppInChildProcess, relaunchOnExitCode, @@ -107,19 +84,13 @@ import { import { loadSandboxConfig } from './config/sandboxConfig.js'; import { deleteSession, listSessions } from './utils/sessions.js'; import { createPolicyUpdater } from './config/policy.js'; -import { ScrollProvider } from './ui/contexts/ScrollProvider.js'; -import { TerminalProvider } from './ui/contexts/TerminalContext.js'; import { isAlternateBufferEnabled } from './ui/hooks/useAlternateBuffer.js'; -import { OverflowProvider } from './ui/contexts/OverflowContext.js'; import { setupTerminalAndTheme } from './utils/terminalTheme.js'; -import { profiler } from './ui/components/DebugProfiler.js'; import { runDeferredCommand } from './deferred.js'; import { cleanupBackgroundLogs } from './utils/logCleanup.js'; import { SlashCommandConflictHandler } from './services/SlashCommandConflictHandler.js'; -const SLOW_RENDER_MS = 200; - export function validateDnsResolutionOrder( order: string | undefined, ): DnsResolutionOrder { @@ -198,147 +169,16 @@ export async function startInteractiveUI( resumedSessionData: ResumedSessionData | undefined, initializationResult: InitializationResult, ) { - // Never enter Ink alternate buffer mode when screen reader mode is enabled - // as there is no benefit of alternate buffer mode when using a screen reader - // and the Ink alternate buffer mode requires line wrapping harmful to - // screen readers. - const useAlternateBuffer = shouldEnterAlternateScreen( - isAlternateBufferEnabled(config), - config.getScreenReader(), + // Dynamically import the heavy UI module so React/Ink are only parsed when needed + const { startInteractiveUI: doStartUI } = await import('./interactiveCli.js'); + await doStartUI( + config, + settings, + startupWarnings, + workspaceRoot, + resumedSessionData, + initializationResult, ); - const mouseEventsEnabled = useAlternateBuffer; - if (mouseEventsEnabled) { - enableMouseEvents(); - registerCleanup(() => { - disableMouseEvents(); - }); - } - - const { matchers, errors } = await loadKeyMatchers(); - errors.forEach((error) => { - coreEvents.emitFeedback('warning', error); - }); - - const version = await getVersion(); - setWindowTitle(basename(workspaceRoot), settings); - - const consolePatcher = new ConsolePatcher({ - onNewMessage: (msg) => { - coreEvents.emitConsoleLog(msg.type, msg.content); - }, - debugMode: config.getDebugMode(), - }); - consolePatcher.patch(); - registerCleanup(consolePatcher.cleanup); - - const { stdout: inkStdout, stderr: inkStderr } = createWorkingStdio(); - - const isShpool = !!process.env['SHPOOL_SESSION_NAME']; - - // Create wrapper component to use hooks inside render - const AppWrapper = () => { - useKittyKeyboardProtocol(); - - return ( - - - - - - - - - - - - - - - - - - - - ); - }; - - if (isShpool) { - // Wait a moment for shpool to stabilize terminal size and state. - // shpool is a persistence tool that restores terminal state by replaying it. - // This delay gives shpool time to finish its restoration replay and send - // the actual terminal size (often via an immediate SIGWINCH) before we - // render the first TUI frame. Without this, the first frame may be - // garbled or rendered at an incorrect size, which disabling incremental - // rendering alone cannot fix for the initial frame. - await new Promise((resolve) => setTimeout(resolve, 100)); - } - - const instance = render( - process.env['DEBUG'] ? ( - - - - ) : ( - - ), - { - stdout: inkStdout, - stderr: inkStderr, - stdin: process.stdin, - exitOnCtrlC: false, - isScreenReaderEnabled: config.getScreenReader(), - onRender: ({ renderTime }: { renderTime: number }) => { - if (renderTime > SLOW_RENDER_MS) { - recordSlowRender(config, renderTime); - } - profiler.reportFrameRendered(); - }, - patchConsole: false, - alternateBuffer: useAlternateBuffer, - incrementalRendering: - settings.merged.ui.incrementalRendering !== false && - useAlternateBuffer && - !isShpool, - }, - ); - - if (useAlternateBuffer) { - disableLineWrapping(); - registerCleanup(() => { - enableLineWrapping(); - }); - } - - checkForUpdates(settings) - .then((info) => { - handleAutoUpdate(info, settings, config.getProjectRoot()); - }) - .catch((err) => { - // Silently ignore update check errors. - if (config.getDebugMode()) { - debugLogger.warn('Update check failed:', err); - } - }); - - registerCleanup(() => instance.unmount()); - - registerCleanup(setupTtyCheck()); } export async function main() { @@ -845,25 +685,6 @@ export async function main() { } } -function setWindowTitle(title: string, settings: LoadedSettings) { - if (!settings.merged.ui.hideWindowTitle) { - // Initial state before React loop starts - const windowTitle = computeTerminalTitle({ - streamingState: StreamingState.Idle, - isConfirming: false, - isSilentWorking: false, - folderName: title, - showThoughts: !!settings.merged.ui.showStatusInTitle, - useDynamicTitle: settings.merged.ui.dynamicWindowTitle, - }); - writeToStdout(`\x1b]0;${windowTitle}\x07`); - - process.on('exit', () => { - writeToStdout(`\x1b]0;\x07`); - }); - } -} - export function initializeOutputListenersAndFlush() { // If there are no listeners for output, make sure we flush so output is not // lost. diff --git a/packages/cli/src/interactiveCli.tsx b/packages/cli/src/interactiveCli.tsx new file mode 100644 index 0000000000..a27cdbbb78 --- /dev/null +++ b/packages/cli/src/interactiveCli.tsx @@ -0,0 +1,214 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { render } from 'ink'; +import { basename } from 'node:path'; +import { AppContainer } from './ui/AppContainer.js'; +import { ConsolePatcher } from './ui/utils/ConsolePatcher.js'; +import { registerCleanup, setupTtyCheck } from './utils/cleanup.js'; +import { + type StartupWarning, + type Config, + type ResumedSessionData, + coreEvents, + createWorkingStdio, + disableMouseEvents, + enableMouseEvents, + disableLineWrapping, + enableLineWrapping, + shouldEnterAlternateScreen, + recordSlowRender, + writeToStdout, + getVersion, + debugLogger, +} from '@google/gemini-cli-core'; +import type { InitializationResult } from './core/initializer.js'; +import type { LoadedSettings } from './config/settings.js'; +import { checkForUpdates } from './ui/utils/updateCheck.js'; +import { handleAutoUpdate } from './utils/handleAutoUpdate.js'; +import { SettingsContext } from './ui/contexts/SettingsContext.js'; +import { MouseProvider } from './ui/contexts/MouseContext.js'; +import { StreamingState } from './ui/types.js'; +import { computeTerminalTitle } from './utils/windowTitle.js'; + +import { SessionStatsProvider } from './ui/contexts/SessionContext.js'; +import { VimModeProvider } from './ui/contexts/VimModeContext.js'; +import { KeyMatchersProvider } from './ui/hooks/useKeyMatchers.js'; +import { loadKeyMatchers } from './ui/key/keyMatchers.js'; +import { KeypressProvider } from './ui/contexts/KeypressContext.js'; +import { useKittyKeyboardProtocol } from './ui/hooks/useKittyKeyboardProtocol.js'; +import { ScrollProvider } from './ui/contexts/ScrollProvider.js'; +import { TerminalProvider } from './ui/contexts/TerminalContext.js'; +import { isAlternateBufferEnabled } from './ui/hooks/useAlternateBuffer.js'; +import { OverflowProvider } from './ui/contexts/OverflowContext.js'; +import { profiler } from './ui/components/DebugProfiler.js'; + +const SLOW_RENDER_MS = 200; + +export async function startInteractiveUI( + config: Config, + settings: LoadedSettings, + startupWarnings: StartupWarning[], + workspaceRoot: string = process.cwd(), + resumedSessionData: ResumedSessionData | undefined, + initializationResult: InitializationResult, +) { + // Never enter Ink alternate buffer mode when screen reader mode is enabled + // as there is no benefit of alternate buffer mode when using a screen reader + // and the Ink alternate buffer mode requires line wrapping harmful to + // screen readers. + const useAlternateBuffer = shouldEnterAlternateScreen( + isAlternateBufferEnabled(config), + config.getScreenReader(), + ); + const mouseEventsEnabled = useAlternateBuffer; + if (mouseEventsEnabled) { + enableMouseEvents(); + registerCleanup(() => { + disableMouseEvents(); + }); + } + + const { matchers, errors } = await loadKeyMatchers(); + errors.forEach((error) => { + coreEvents.emitFeedback('warning', error); + }); + + const version = await getVersion(); + setWindowTitle(basename(workspaceRoot), settings); + + const consolePatcher = new ConsolePatcher({ + onNewMessage: (msg) => { + coreEvents.emitConsoleLog(msg.type, msg.content); + }, + debugMode: config.getDebugMode(), + }); + consolePatcher.patch(); + registerCleanup(consolePatcher.cleanup); + + const { stdout: inkStdout, stderr: inkStderr } = createWorkingStdio(); + + const isShpool = !!process.env['SHPOOL_SESSION_NAME']; + + // Create wrapper component to use hooks inside render + const AppWrapper = () => { + useKittyKeyboardProtocol(); + + return ( + + + + + + + + + + + + + + + + + + + + ); + }; + + if (isShpool) { + // Wait a moment for shpool to stabilize terminal size and state. + await new Promise((resolve) => setTimeout(resolve, 100)); + } + + const instance = render( + process.env['DEBUG'] ? ( + + + + ) : ( + + ), + { + stdout: inkStdout, + stderr: inkStderr, + stdin: process.stdin, + exitOnCtrlC: false, + isScreenReaderEnabled: config.getScreenReader(), + onRender: ({ renderTime }: { renderTime: number }) => { + if (renderTime > SLOW_RENDER_MS) { + recordSlowRender(config, renderTime); + } + profiler.reportFrameRendered(); + }, + patchConsole: false, + alternateBuffer: useAlternateBuffer, + incrementalRendering: + settings.merged.ui.incrementalRendering !== false && + useAlternateBuffer && + !isShpool, + }, + ); + + if (useAlternateBuffer) { + disableLineWrapping(); + registerCleanup(() => { + enableLineWrapping(); + }); + } + + checkForUpdates(settings) + .then((info) => { + handleAutoUpdate(info, settings, config.getProjectRoot()); + }) + .catch((err) => { + // Silently ignore update check errors. + if (config.getDebugMode()) { + debugLogger.warn('Update check failed:', err); + } + }); + + registerCleanup(() => instance.unmount()); + + registerCleanup(setupTtyCheck()); +} + +function setWindowTitle(title: string, settings: LoadedSettings) { + if (!settings.merged.ui.hideWindowTitle) { + // Initial state before React loop starts + const windowTitle = computeTerminalTitle({ + streamingState: StreamingState.Idle, + isConfirming: false, + isSilentWorking: false, + folderName: title, + showThoughts: !!settings.merged.ui.showStatusInTitle, + useDynamicTitle: settings.merged.ui.dynamicWindowTitle, + }); + writeToStdout(`\x1b]0;${windowTitle}\x07`); + + process.on('exit', () => { + writeToStdout(`\x1b]0;\x07`); + }); + } +} diff --git a/packages/cli/src/ui/commands/setupGithubCommand.ts b/packages/cli/src/ui/commands/setupGithubCommand.ts index 2554ebaa60..c68dd5cb88 100644 --- a/packages/cli/src/ui/commands/setupGithubCommand.ts +++ b/packages/cli/src/ui/commands/setupGithubCommand.ts @@ -123,7 +123,6 @@ async function downloadFiles({ downloads.push( (async () => { const endpoint = `${REPO_DOWNLOAD_URL}/refs/tags/${releaseTag}/${SOURCE_DIR}/${fileBasename}`; - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(endpoint, { method: 'GET', dispatcher: proxy ? new ProxyAgent(proxy) : undefined, diff --git a/packages/cli/src/utils/gitUtils.ts b/packages/cli/src/utils/gitUtils.ts index 83d89ad164..e27673f0fe 100644 --- a/packages/cli/src/utils/gitUtils.ts +++ b/packages/cli/src/utils/gitUtils.ts @@ -61,7 +61,6 @@ export const getLatestGitHubRelease = async ( const endpoint = `https://api.github.com/repos/google-github-actions/run-gemini-cli/releases/latest`; - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(endpoint, { method: 'GET', headers: { diff --git a/packages/core/src/agents/a2a-client-manager.test.ts b/packages/core/src/agents/a2a-client-manager.test.ts index aab0de5506..0a0aa4d956 100644 --- a/packages/core/src/agents/a2a-client-manager.test.ts +++ b/packages/core/src/agents/a2a-client-manager.test.ts @@ -5,11 +5,8 @@ */ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { - A2AClientManager, - type SendMessageResult, -} from './a2a-client-manager.js'; -import type { AgentCard, Task } from '@a2a-js/sdk'; +import { A2AClientManager } from './a2a-client-manager.js'; +import type { AgentCard } from '@a2a-js/sdk'; import { ClientFactory, DefaultAgentCardResolver, @@ -22,81 +19,95 @@ import type { Config } from '../config/config.js'; import { Agent as UndiciAgent, ProxyAgent } from 'undici'; import { debugLogger } from '../utils/debugLogger.js'; +interface MockClient { + sendMessageStream: ReturnType; + getTask: ReturnType; + cancelTask: ReturnType; +} + +vi.mock('@a2a-js/sdk/client', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...(actual as Record), + createAuthenticatingFetchWithRetry: vi.fn(), + ClientFactory: vi.fn(), + DefaultAgentCardResolver: vi.fn(), + ClientFactoryOptions: { + createFrom: vi.fn(), + default: {}, + }, + }; +}); + vi.mock('../utils/debugLogger.js', () => ({ debugLogger: { debug: vi.fn(), }, })); -vi.mock('@a2a-js/sdk/client', () => { - const ClientFactory = vi.fn(); - const DefaultAgentCardResolver = vi.fn(); - const RestTransportFactory = vi.fn(); - const JsonRpcTransportFactory = vi.fn(); - const ClientFactoryOptions = { - default: {}, - createFrom: vi.fn(), - }; - const createAuthenticatingFetchWithRetry = vi.fn(); - - DefaultAgentCardResolver.prototype.resolve = vi.fn(); - ClientFactory.prototype.createFromUrl = vi.fn(); - - return { - ClientFactory, - ClientFactoryOptions, - DefaultAgentCardResolver, - RestTransportFactory, - JsonRpcTransportFactory, - createAuthenticatingFetchWithRetry, - }; -}); - describe('A2AClientManager', () => { let manager: A2AClientManager; + const mockAgentCard: AgentCard = { + name: 'test-agent', + description: 'A test agent', + url: 'http://test.agent', + version: '1.0.0', + protocolVersion: '0.1.0', + capabilities: {}, + skills: [], + defaultInputModes: [], + defaultOutputModes: [], + }; + + const mockClient: MockClient = { + sendMessageStream: vi.fn(), + getTask: vi.fn(), + cancelTask: vi.fn(), + }; - // Stable mocks initialized once - const sendMessageStreamMock = vi.fn(); - const getTaskMock = vi.fn(); - const cancelTaskMock = vi.fn(); - const getAgentCardMock = vi.fn(); const authFetchMock = vi.fn(); - const mockClient = { - sendMessageStream: sendMessageStreamMock, - getTask: getTaskMock, - cancelTask: cancelTaskMock, - getAgentCard: getAgentCardMock, - } as unknown as Client; - - const mockAgentCard: Partial = { name: 'TestAgent' }; - beforeEach(() => { vi.clearAllMocks(); A2AClientManager.resetInstanceForTesting(); manager = A2AClientManager.getInstance(); - // Default mock implementations - getAgentCardMock.mockResolvedValue({ + // Re-create the instances as plain objects that can be spied on + const factoryInstance = { + createFromUrl: vi.fn(), + createFromAgentCard: vi.fn(), + }; + const resolverInstance = { + resolve: vi.fn(), + }; + + vi.mocked(ClientFactory).mockReturnValue( + factoryInstance as unknown as ClientFactory, + ); + vi.mocked(DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as DefaultAgentCardResolver, + ); + + vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue( + mockClient as unknown as Client, + ); + vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue( + mockClient as unknown as Client, + ); + vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({ ...mockAgentCard, url: 'http://test.agent/real/endpoint', } as AgentCard); - vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue( - mockClient, + vi.spyOn(ClientFactoryOptions, 'createFrom').mockImplementation( + (_defaults, overrides) => overrides as unknown as ClientFactoryOptions, ); - vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({ - ...mockAgentCard, - url: 'http://test.agent/real/endpoint', - } as AgentCard); - - vi.mocked(ClientFactoryOptions.createFrom).mockImplementation( - (_defaults, overrides) => overrides as ClientFactoryOptions, - ); - - vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue( - authFetchMock, + vi.mocked(createAuthenticatingFetchWithRetry).mockImplementation(() => + authFetchMock.mockResolvedValue({ + ok: true, + json: async () => ({}), + } as Response), ); vi.stubGlobal( @@ -170,15 +181,19 @@ describe('A2AClientManager', () => { 'TestAgent', 'http://test.agent/card', ); - expect(agentCard).toMatchObject(mockAgentCard); expect(manager.getAgentCard('TestAgent')).toBe(agentCard); expect(manager.getClient('TestAgent')).toBeDefined(); }); + it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => { + await manager.loadAgent('TestAgent', 'http://test.agent/card'); + expect(ClientFactoryOptions.createFrom).toHaveBeenCalled(); + }); + it('should throw an error if an agent with the same name is already loaded', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); await expect( - manager.loadAgent('TestAgent', 'http://another.agent/card'), + manager.loadAgent('TestAgent', 'http://test.agent/card'), ).rejects.toThrow("Agent with name 'TestAgent' is already loaded."); }); @@ -193,20 +208,12 @@ describe('A2AClientManager', () => { shouldRetryWithHeaders: vi.fn(), }; await manager.loadAgent( - 'CustomAuthAgent', - 'http://custom.agent/card', + 'TestAgent', + 'http://test.agent/card', customAuthHandler as unknown as AuthenticationHandler, ); - expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith( - expect.anything(), - customAuthHandler, - ); - // Card resolver should NOT use the authenticated fetch by default. - const resolverInstance = vi.mocked(DefaultAgentCardResolver).mock - .instances[0]; - expect(resolverInstance).toBeDefined(); const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock .calls[0][0]; expect(resolverOptions?.fetchImpl).not.toBe(authFetchMock); @@ -267,106 +274,163 @@ describe('A2AClientManager', () => { it('should log a debug message upon loading an agent', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); expect(debugLogger.debug).toHaveBeenCalledWith( - "[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card", + expect.stringContaining("Loaded agent 'TestAgent'"), ); }); it('should clear the cache', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); - expect(manager.getAgentCard('TestAgent')).toBeDefined(); - expect(manager.getClient('TestAgent')).toBeDefined(); - manager.clearCache(); - expect(manager.getAgentCard('TestAgent')).toBeUndefined(); expect(manager.getClient('TestAgent')).toBeUndefined(); - expect(debugLogger.debug).toHaveBeenCalledWith( - '[A2AClientManager] Cache cleared.', + }); + + it('should throw if resolveAgentCard fails', async () => { + const resolverInstance = { + resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')), + }; + vi.mocked(DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as DefaultAgentCardResolver, ); + + await expect( + manager.loadAgent('FailAgent', 'http://fail.agent'), + ).rejects.toThrow('Resolution failed'); + }); + + it('should throw if factory.createFromAgentCard fails', async () => { + const factoryInstance = { + createFromAgentCard: vi + .fn() + .mockRejectedValue(new Error('Factory failed')), + }; + vi.mocked(ClientFactory).mockReturnValue( + factoryInstance as unknown as ClientFactory, + ); + + await expect( + manager.loadAgent('FailAgent', 'http://fail.agent'), + ).rejects.toThrow('Factory failed'); + }); + }); + + describe('getAgentCard and getClient', () => { + it('should return undefined if agent is not found', () => { + expect(manager.getAgentCard('Unknown')).toBeUndefined(); + expect(manager.getClient('Unknown')).toBeUndefined(); }); }); describe('sendMessageStream', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should send a message and return a stream', async () => { - const mockResult = { - kind: 'message', - messageId: 'a', - parts: [], - role: 'agent', - } as SendMessageResult; - - sendMessageStreamMock.mockReturnValue( + mockClient.sendMessageStream.mockReturnValue( (async function* () { - yield mockResult; + yield { kind: 'message' }; })(), ); const stream = manager.sendMessageStream('TestAgent', 'Hello'); const results = []; - for await (const res of stream) { - results.push(res); + for await (const result of stream) { + results.push(result); } - expect(results).toEqual([mockResult]); - expect(sendMessageStreamMock).toHaveBeenCalledWith( + expect(results).toHaveLength(1); + expect(mockClient.sendMessageStream).toHaveBeenCalled(); + }); + + it('should use contextId and taskId when provided', async () => { + mockClient.sendMessageStream.mockReturnValue( + (async function* () { + yield { kind: 'message' }; + })(), + ); + + const stream = manager.sendMessageStream('TestAgent', 'Hello', { + contextId: 'ctx123', + taskId: 'task456', + }); + // trigger execution + for await (const _ of stream) { + break; + } + + expect(mockClient.sendMessageStream).toHaveBeenCalledWith( expect.objectContaining({ - message: expect.anything(), + message: expect.objectContaining({ + contextId: 'ctx123', + taskId: 'task456', + }), }), expect.any(Object), ); }); - it('should use contextId and taskId when provided', async () => { - sendMessageStreamMock.mockReturnValue( + it('should correctly propagate AbortSignal to the stream', async () => { + mockClient.sendMessageStream.mockReturnValue( (async function* () { - yield { - kind: 'message', - messageId: 'a', - parts: [], - role: 'agent', - } as SendMessageResult; + yield { kind: 'message' }; })(), ); - const expectedContextId = 'user-context-id'; - const expectedTaskId = 'user-task-id'; - + const controller = new AbortController(); const stream = manager.sendMessageStream('TestAgent', 'Hello', { - contextId: expectedContextId, - taskId: expectedTaskId, + signal: controller.signal, }); - + // trigger execution for await (const _ of stream) { - // consume stream + break; } - const call = sendMessageStreamMock.mock.calls[0][0]; - expect(call.message.contextId).toBe(expectedContextId); - expect(call.message.taskId).toBe(expectedTaskId); + expect(mockClient.sendMessageStream).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ signal: controller.signal }), + ); }); - it('should propagate the original error on failure', async () => { - sendMessageStreamMock.mockImplementationOnce(() => { - throw new Error('Network error'); + it('should handle a multi-chunk stream with different event types', async () => { + mockClient.sendMessageStream.mockReturnValue( + (async function* () { + yield { kind: 'message', messageId: 'm1' }; + yield { kind: 'status-update', taskId: 't1' }; + })(), + ); + + const stream = manager.sendMessageStream('TestAgent', 'Hello'); + const results = []; + for await (const result of stream) { + results.push(result); + } + + expect(results).toHaveLength(2); + expect(results[0].kind).toBe('message'); + expect(results[1].kind).toBe('status-update'); + }); + + it('should throw prefixed error on failure', async () => { + mockClient.sendMessageStream.mockImplementation(() => { + throw new Error('Network failure'); }); const stream = manager.sendMessageStream('TestAgent', 'Hello'); await expect(async () => { for await (const _ of stream) { - // consume + // empty } - }).rejects.toThrow('Network error'); + }).rejects.toThrow( + '[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure', + ); }); it('should throw an error if the agent is not found', async () => { const stream = manager.sendMessageStream('NonExistentAgent', 'Hello'); await expect(async () => { for await (const _ of stream) { - // consume + // empty } }).rejects.toThrow("Agent 'NonExistentAgent' not found."); }); @@ -374,28 +438,23 @@ describe('A2AClientManager', () => { describe('getTask', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should get a task from the correct agent', async () => { - getTaskMock.mockResolvedValue({ - id: 'task123', - contextId: 'a', - kind: 'task', - status: { state: 'completed' }, - } as Task); + const mockTask = { id: 'task123', kind: 'task' }; + mockClient.getTask.mockResolvedValue(mockTask); - await manager.getTask('TestAgent', 'task123'); - expect(getTaskMock).toHaveBeenCalledWith({ - id: 'task123', - }); + const result = await manager.getTask('TestAgent', 'task123'); + expect(result).toBe(mockTask); + expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' }); }); it('should throw prefixed error on failure', async () => { - getTaskMock.mockRejectedValueOnce(new Error('Network error')); + mockClient.getTask.mockRejectedValue(new Error('Not found')); await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow( - 'A2AClient getTask Error [TestAgent]: Network error', + 'A2AClient getTask Error [TestAgent]: Not found', ); }); @@ -408,28 +467,23 @@ describe('A2AClientManager', () => { describe('cancelTask', () => { beforeEach(async () => { - await manager.loadAgent('TestAgent', 'http://test.agent'); + await manager.loadAgent('TestAgent', 'http://test.agent/card'); }); it('should cancel a task on the correct agent', async () => { - cancelTaskMock.mockResolvedValue({ - id: 'task123', - contextId: 'a', - kind: 'task', - status: { state: 'canceled' }, - } as Task); + const mockTask = { id: 'task123', kind: 'task' }; + mockClient.cancelTask.mockResolvedValue(mockTask); - await manager.cancelTask('TestAgent', 'task123'); - expect(cancelTaskMock).toHaveBeenCalledWith({ - id: 'task123', - }); + const result = await manager.cancelTask('TestAgent', 'task123'); + expect(result).toBe(mockTask); + expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' }); }); it('should throw prefixed error on failure', async () => { - cancelTaskMock.mockRejectedValueOnce(new Error('Network error')); + mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel')); await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow( - 'A2AClient cancelTask Error [TestAgent]: Network error', + 'A2AClient cancelTask Error [TestAgent]: Cannot cancel', ); }); diff --git a/packages/core/src/agents/a2a-client-manager.ts b/packages/core/src/agents/a2a-client-manager.ts index 7d558e7dbe..3a03c033d8 100644 --- a/packages/core/src/agents/a2a-client-manager.ts +++ b/packages/core/src/agents/a2a-client-manager.ts @@ -12,36 +12,41 @@ import type { TaskStatusUpdateEvent, TaskArtifactUpdateEvent, } from '@a2a-js/sdk'; +import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client'; import { - type Client, ClientFactory, ClientFactoryOptions, DefaultAgentCardResolver, - RestTransportFactory, JsonRpcTransportFactory, - type AuthenticationHandler, + RestTransportFactory, createAuthenticatingFetchWithRetry, } from '@a2a-js/sdk/client'; +import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc'; +import * as grpc from '@grpc/grpc-js'; import { v4 as uuidv4 } from 'uuid'; import { Agent as UndiciAgent, ProxyAgent } from 'undici'; +import { normalizeAgentCard } from './a2aUtils.js'; import type { Config } from '../config/config.js'; import { debugLogger } from '../utils/debugLogger.js'; -import { safeLookup } from '../utils/fetch.js'; import { classifyAgentError } from './a2a-errors.js'; -// Remote agents can take 10+ minutes (e.g. Deep Research). -// Use a dedicated dispatcher so the global 5-min timeout isn't affected. -const A2A_TIMEOUT = 1800000; // 30 minutes - +/** + * Result of sending a message, which can be a full message, a task, + * or an incremental status/artifact update. + */ export type SendMessageResult = | Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent; +// Remote agents can take 10+ minutes (e.g. Deep Research). +// Use a dedicated dispatcher so the global 5-min timeout isn't affected. +const A2A_TIMEOUT = 1800000; // 30 minutes + /** - * Manages A2A clients and caches loaded agent information. - * Follows a singleton pattern to ensure a single client instance. + * Orchestrates communication with remote A2A agents. + * Manages protocol negotiation, authentication, and transport selection. */ export class A2AClientManager { private static instance: A2AClientManager; @@ -58,9 +63,6 @@ export class A2AClientManager { const agentOptions = { headersTimeout: A2A_TIMEOUT, bodyTimeout: A2A_TIMEOUT, - connect: { - lookup: safeLookup, // SSRF protection at connection level - }, }; if (proxyUrl) { @@ -73,7 +75,6 @@ export class A2AClientManager { } this.a2aFetch = (input, init) => - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection fetch(input, { ...init, dispatcher: this.a2aDispatcher } as RequestInit); } @@ -139,22 +140,35 @@ export class A2AClientManager { }; const resolver = new DefaultAgentCardResolver({ fetchImpl: cardFetch }); + const rawCard = await resolver.resolve(agentCardUrl, ''); + // TODO: Remove normalizeAgentCard once @a2a-js/sdk handles + // proto field name aliases (supportedInterfaces → additionalInterfaces, + // protocolBinding → transport). + const agentCard = normalizeAgentCard(rawCard); - const options = ClientFactoryOptions.createFrom( + const grpcUrl = + agentCard.additionalInterfaces?.find((i) => i.transport === 'GRPC') + ?.url ?? agentCard.url; + + const clientOptions = ClientFactoryOptions.createFrom( ClientFactoryOptions.default, { transports: [ new RestTransportFactory({ fetchImpl: authFetch }), new JsonRpcTransportFactory({ fetchImpl: authFetch }), + new GrpcTransportFactory({ + grpcChannelCredentials: grpcUrl.startsWith('https://') + ? grpc.credentials.createSsl() + : grpc.credentials.createInsecure(), + }), ], cardResolver: resolver, }, ); try { - const factory = new ClientFactory(options); - const client = await factory.createFromUrl(agentCardUrl, ''); - const agentCard = await client.getAgentCard(); + const factory = new ClientFactory(clientOptions); + const client = await factory.createFromAgentCard(agentCard); this.clients.set(name, client); this.agentCards.set(name, agentCard); @@ -192,9 +206,7 @@ export class A2AClientManager { options?: { contextId?: string; taskId?: string; signal?: AbortSignal }, ): AsyncIterable { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); const messageParams: MessageSendParams = { message: { @@ -207,9 +219,19 @@ export class A2AClientManager { }, }; - yield* client.sendMessageStream(messageParams, { - signal: options?.signal, - }); + try { + yield* client.sendMessageStream(messageParams, { + signal: options?.signal, + }); + } catch (error: unknown) { + const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`; + if (error instanceof Error) { + throw new Error(`${prefix}: ${error.message}`, { cause: error }); + } + throw new Error( + `${prefix}: Unexpected error during sendMessageStream: ${String(error)}`, + ); + } } /** @@ -238,9 +260,7 @@ export class A2AClientManager { */ async getTask(agentName: string, taskId: string): Promise { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); try { return await client.getTask({ id: taskId }); } catch (error: unknown) { @@ -260,9 +280,7 @@ export class A2AClientManager { */ async cancelTask(agentName: string, taskId: string): Promise { const client = this.clients.get(agentName); - if (!client) { - throw new Error(`Agent '${agentName}' not found.`); - } + if (!client) throw new Error(`Agent '${agentName}' not found.`); try { return await client.cancelTask({ id: taskId }); } catch (error: unknown) { diff --git a/packages/core/src/agents/a2aUtils.test.ts b/packages/core/src/agents/a2aUtils.test.ts index c3fe170aa5..0dce551be4 100644 --- a/packages/core/src/agents/a2aUtils.test.ts +++ b/packages/core/src/agents/a2aUtils.test.ts @@ -12,9 +12,6 @@ import { A2AResultReassembler, AUTH_REQUIRED_MSG, normalizeAgentCard, - getGrpcCredentials, - pinUrlToIp, - splitAgentCardUrl, } from './a2aUtils.js'; import type { SendMessageResult } from './a2a-client-manager.js'; import type { @@ -26,12 +23,6 @@ import type { TaskStatusUpdateEvent, TaskArtifactUpdateEvent, } from '@a2a-js/sdk'; -import * as dnsPromises from 'node:dns/promises'; -import type { LookupAddress } from 'node:dns'; - -vi.mock('node:dns/promises', () => ({ - lookup: vi.fn(), -})); describe('a2aUtils', () => { beforeEach(() => { @@ -42,89 +33,6 @@ describe('a2aUtils', () => { vi.restoreAllMocks(); }); - describe('getGrpcCredentials', () => { - it('should return secure credentials for https', () => { - const credentials = getGrpcCredentials('https://test.agent'); - expect(credentials).toBeDefined(); - }); - - it('should return insecure credentials for http', () => { - const credentials = getGrpcCredentials('http://test.agent'); - expect(credentials).toBeDefined(); - }); - }); - - describe('pinUrlToIp', () => { - it('should resolve and pin hostname to IP', async () => { - vi.mocked( - dnsPromises.lookup as unknown as ( - hostname: string, - options: { all: true }, - ) => Promise, - ).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]); - - const { pinnedUrl, hostname } = await pinUrlToIp( - 'http://example.com:9000', - 'test-agent', - ); - expect(hostname).toBe('example.com'); - expect(pinnedUrl).toBe('http://93.184.216.34:9000/'); - }); - - it('should handle raw host:port strings (standard for gRPC)', async () => { - vi.mocked( - dnsPromises.lookup as unknown as ( - hostname: string, - options: { all: true }, - ) => Promise, - ).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]); - - const { pinnedUrl, hostname } = await pinUrlToIp( - 'example.com:9000', - 'test-agent', - ); - expect(hostname).toBe('example.com'); - expect(pinnedUrl).toBe('93.184.216.34:9000'); - }); - - it('should throw error if resolution fails (fail closed)', async () => { - vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error')); - - await expect( - pinUrlToIp('http://unreachable.com', 'test-agent'), - ).rejects.toThrow("Failed to resolve host for agent 'test-agent'"); - }); - - it('should throw error if resolved to private IP', async () => { - vi.mocked( - dnsPromises.lookup as unknown as ( - hostname: string, - options: { all: true }, - ) => Promise, - ).mockResolvedValue([{ address: '10.0.0.1', family: 4 }]); - - await expect( - pinUrlToIp('http://malicious.com', 'test-agent'), - ).rejects.toThrow('resolves to private IP range'); - }); - - it('should allow localhost/127.0.0.1/::1 exceptions', async () => { - vi.mocked( - dnsPromises.lookup as unknown as ( - hostname: string, - options: { all: true }, - ) => Promise, - ).mockResolvedValue([{ address: '127.0.0.1', family: 4 }]); - - const { pinnedUrl, hostname } = await pinUrlToIp( - 'http://localhost:9000', - 'test-agent', - ); - expect(hostname).toBe('localhost'); - expect(pinnedUrl).toBe('http://127.0.0.1:9000/'); - }); - }); - describe('isTerminalState', () => { it('should return true for completed, failed, canceled, and rejected', () => { expect(isTerminalState('completed')).toBe(true); @@ -365,12 +273,12 @@ describe('a2aUtils', () => { expect(normalized.name).toBe('my-agent'); // @ts-expect-error - testing dynamic preservation expect(normalized.customField).toBe('keep-me'); - expect(normalized.description).toBe(''); - expect(normalized.skills).toEqual([]); - expect(normalized.defaultInputModes).toEqual([]); + expect(normalized.description).toBeUndefined(); + expect(normalized.skills).toBeUndefined(); + expect(normalized.defaultInputModes).toBeUndefined(); }); - it('should normalize and synchronize interfaces while preserving other fields', () => { + it('should map supportedInterfaces to additionalInterfaces with protocolBinding → transport', () => { const raw = { name: 'test', supportedInterfaces: [ @@ -384,13 +292,7 @@ describe('a2aUtils', () => { const normalized = normalizeAgentCard(raw); - // Should exist in both fields expect(normalized.additionalInterfaces).toHaveLength(1); - expect( - (normalized as unknown as Record)[ - 'supportedInterfaces' - ], - ).toHaveLength(1); const intf = normalized.additionalInterfaces?.[0] as unknown as Record< string, @@ -399,43 +301,18 @@ describe('a2aUtils', () => { expect(intf['transport']).toBe('GRPC'); expect(intf['url']).toBe('grpc://test'); - - // Should fallback top-level url - expect(normalized.url).toBe('grpc://test'); }); - it('should preserve existing top-level url if present', () => { + it('should not overwrite additionalInterfaces if already present', () => { const raw = { name: 'test', - url: 'http://existing', + additionalInterfaces: [{ url: 'http://grpc', transport: 'GRPC' }], supportedInterfaces: [{ url: 'http://other', transport: 'REST' }], }; const normalized = normalizeAgentCard(raw); - expect(normalized.url).toBe('http://existing'); - }); - - it('should NOT prepend http:// scheme to raw IP:port strings for gRPC interfaces', () => { - const raw = { - name: 'raw-ip-grpc', - supportedInterfaces: [{ url: '127.0.0.1:9000', transport: 'GRPC' }], - }; - - const normalized = normalizeAgentCard(raw); - expect(normalized.additionalInterfaces?.[0].url).toBe('127.0.0.1:9000'); - expect(normalized.url).toBe('127.0.0.1:9000'); - }); - - it('should prepend http:// scheme to raw IP:port strings for REST interfaces', () => { - const raw = { - name: 'raw-ip-rest', - supportedInterfaces: [{ url: '127.0.0.1:8080', transport: 'REST' }], - }; - - const normalized = normalizeAgentCard(raw); - expect(normalized.additionalInterfaces?.[0].url).toBe( - 'http://127.0.0.1:8080', - ); + expect(normalized.additionalInterfaces).toHaveLength(1); + expect(normalized.additionalInterfaces?.[0].url).toBe('http://grpc'); }); it('should NOT override existing transport if protocolBinding is also present', () => { @@ -448,48 +325,20 @@ describe('a2aUtils', () => { const normalized = normalizeAgentCard(raw); expect(normalized.additionalInterfaces?.[0].transport).toBe('GRPC'); }); - }); - describe('splitAgentCardUrl', () => { - const standard = '.well-known/agent-card.json'; + it('should not mutate the original card object', () => { + const raw = { + name: 'test', + supportedInterfaces: [{ url: 'grpc://test', protocolBinding: 'GRPC' }], + }; - it('should return baseUrl as-is if it does not end with standard path', () => { - const url = 'http://localhost:9001/custom/path'; - expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url }); - }); - - it('should split correctly if URL ends with standard path', () => { - const url = `http://localhost:9001/${standard}`; - expect(splitAgentCardUrl(url)).toEqual({ - baseUrl: 'http://localhost:9001/', - path: undefined, - }); - }); - - it('should handle trailing slash in baseUrl when splitting', () => { - const url = `http://example.com/api/${standard}`; - expect(splitAgentCardUrl(url)).toEqual({ - baseUrl: 'http://example.com/api/', - path: undefined, - }); - }); - - it('should ignore hashes and query params when splitting', () => { - const url = `http://localhost:9001/${standard}?foo=bar#baz`; - expect(splitAgentCardUrl(url)).toEqual({ - baseUrl: 'http://localhost:9001/', - path: undefined, - }); - }); - - it('should return original URL if parsing fails', () => { - const url = 'not-a-url'; - expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url }); - }); - - it('should handle standard path appearing earlier in the path', () => { - const url = `http://localhost:9001/${standard}/something-else`; - expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url }); + const normalized = normalizeAgentCard(raw); + expect(normalized).not.toBe(raw); + expect(normalized.additionalInterfaces).toBeDefined(); + // Original should not have additionalInterfaces added + expect( + (raw as Record)['additionalInterfaces'], + ).toBeUndefined(); }); }); diff --git a/packages/core/src/agents/a2aUtils.ts b/packages/core/src/agents/a2aUtils.ts index ec8b36bba1..70fc9cf557 100644 --- a/packages/core/src/agents/a2aUtils.ts +++ b/packages/core/src/agents/a2aUtils.ts @@ -4,9 +4,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -import * as grpc from '@grpc/grpc-js'; -import { lookup } from 'node:dns/promises'; -import { z } from 'zod'; import type { Message, Part, @@ -18,37 +15,10 @@ import type { AgentCard, AgentInterface, } from '@a2a-js/sdk'; -import { isAddressPrivate } from '../utils/fetch.js'; import type { SendMessageResult } from './a2a-client-manager.js'; export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`; -const AgentInterfaceSchema = z - .object({ - url: z.string().default(''), - transport: z.string().optional(), - protocolBinding: z.string().optional(), - }) - .passthrough(); - -const AgentCardSchema = z - .object({ - name: z.string().default('unknown'), - description: z.string().default(''), - url: z.string().default(''), - version: z.string().default(''), - protocolVersion: z.string().default(''), - capabilities: z.record(z.unknown()).default({}), - skills: z.array(z.union([z.string(), z.record(z.unknown())])).default([]), - defaultInputModes: z.array(z.string()).default([]), - defaultOutputModes: z.array(z.string()).default([]), - - additionalInterfaces: z.array(AgentInterfaceSchema).optional(), - supportedInterfaces: z.array(AgentInterfaceSchema).optional(), - preferredTransport: z.string().optional(), - }) - .passthrough(); - /** * Reassembles incremental A2A streaming updates into a coherent result. * Shows sequential status/messages followed by all reassembled artifacts. @@ -241,166 +211,45 @@ function extractPartText(part: Part): string { } /** - * Normalizes an agent card by ensuring it has the required properties - * and resolving any inconsistencies between protocol versions. + * Normalizes proto field name aliases that the SDK doesn't handle yet. + * The A2A proto spec uses `supported_interfaces` and `protocol_binding`, + * while the SDK expects `additionalInterfaces` and `transport`. + * TODO: Remove once @a2a-js/sdk handles these aliases natively. */ export function normalizeAgentCard(card: unknown): AgentCard { if (!isObject(card)) { throw new Error('Agent card is missing.'); } - // Use Zod to validate and parse the card, ensuring safe defaults and narrowing types. - const parsed = AgentCardSchema.parse(card); - // Narrowing to AgentCard interface after runtime validation. + // Shallow-copy to avoid mutating the SDK's cached object. // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const result = parsed as unknown as AgentCard; + const result = { ...card } as unknown as AgentCard; - // Normalize interfaces and synchronize both interface fields. - const normalizedInterfaces = extractNormalizedInterfaces(parsed); - result.additionalInterfaces = normalizedInterfaces; - - // Sync supportedInterfaces for backward compatibility. - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const legacyResult = result as unknown as Record; - legacyResult['supportedInterfaces'] = normalizedInterfaces; - - // Fallback preferredTransport: If not specified, default to GRPC if available. - if ( - !result.preferredTransport && - normalizedInterfaces.some((i) => i.transport === 'GRPC') - ) { - result.preferredTransport = 'GRPC'; + // Map supportedInterfaces → additionalInterfaces if needed + if (!result.additionalInterfaces) { + const raw = card; + if (Array.isArray(raw['supportedInterfaces'])) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + result.additionalInterfaces = raw[ + 'supportedInterfaces' + ] as AgentInterface[]; + } } - // Fallback: If top-level URL is missing, use the first interface's URL. - if (result.url === '' && normalizedInterfaces.length > 0) { - result.url = normalizedInterfaces[0].url; + // Map protocolBinding → transport on each interface + for (const intf of result.additionalInterfaces ?? []) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const raw = intf as unknown as Record; + const binding = raw['protocolBinding']; + + if (!intf.transport && typeof binding === 'string') { + intf.transport = binding; + } } return result; } -/** - * Returns gRPC channel credentials based on the URL scheme. - */ -export function getGrpcCredentials(url: string): grpc.ChannelCredentials { - return url.startsWith('https://') - ? grpc.credentials.createSsl() - : grpc.credentials.createInsecure(); -} - -/** - * Returns gRPC channel options to ensure SSL/authority matches the original hostname - * when connecting via a pinned IP address. - */ -export function getGrpcChannelOptions( - hostname: string, -): Record { - return { - 'grpc.default_authority': hostname, - 'grpc.ssl_target_name_override': hostname, - }; -} - -/** - * Resolves a hostname to its IP address and validates it against SSRF. - * Returns the pinned IP-based URL and the original hostname. - */ -export async function pinUrlToIp( - url: string, - agentName: string, -): Promise<{ pinnedUrl: string; hostname: string }> { - if (!url) return { pinnedUrl: url, hostname: '' }; - - // gRPC URLs in A2A can be 'host:port' or 'dns:///host:port' or have schemes. - // We normalize to host:port for resolution. - const hasScheme = url.includes('://'); - const normalizedUrl = hasScheme ? url : `http://${url}`; - - try { - const parsed = new URL(normalizedUrl); - const hostname = parsed.hostname; - - const sanitizedHost = - hostname.startsWith('[') && hostname.endsWith(']') - ? hostname.slice(1, -1) - : hostname; - - // Resolve DNS to check the actual target IP and pin it - const addresses = await lookup(hostname, { all: true }); - const publicAddresses = addresses.filter( - (addr) => - !isAddressPrivate(addr.address) || - sanitizedHost === 'localhost' || - sanitizedHost === '127.0.0.1' || - sanitizedHost === '::1', - ); - - if (publicAddresses.length === 0) { - if (addresses.length > 0) { - throw new Error( - `Refusing to load agent '${agentName}': transport URL '${url}' resolves to private IP range.`, - ); - } - throw new Error( - `Failed to resolve any public IP addresses for host: ${hostname}`, - ); - } - - const pinnedIp = publicAddresses[0].address; - const pinnedHostname = pinnedIp.includes(':') ? `[${pinnedIp}]` : pinnedIp; - - // Reconstruct URL with IP - parsed.hostname = pinnedHostname; - let pinnedUrl = parsed.toString(); - - // If original didn't have scheme, remove it (standard for gRPC targets) - if (!hasScheme) { - pinnedUrl = pinnedUrl.replace(/^http:\/\//, ''); - // URL.toString() might append a trailing slash - if (pinnedUrl.endsWith('/') && !url.endsWith('/')) { - pinnedUrl = pinnedUrl.slice(0, -1); - } - } - - return { pinnedUrl, hostname }; - } catch (e) { - if (e instanceof Error && e.message.includes('Refusing')) throw e; - throw new Error(`Failed to resolve host for agent '${agentName}': ${url}`, { - cause: e, - }); - } -} - -/** - * Splts an agent card URL into a baseUrl and a standard path if it already - * contains '.well-known/agent-card.json'. - */ -export function splitAgentCardUrl(url: string): { - baseUrl: string; - path?: string; -} { - const standardPath = '.well-known/agent-card.json'; - try { - const parsedUrl = new URL(url); - if (parsedUrl.pathname.endsWith(standardPath)) { - // Reconstruct baseUrl from parsed components to avoid issues with hashes or query params. - parsedUrl.pathname = parsedUrl.pathname.substring( - 0, - parsedUrl.pathname.lastIndexOf(standardPath), - ); - parsedUrl.search = ''; - parsedUrl.hash = ''; - // We return undefined for path if it's the standard one, - // because the SDK's DefaultAgentCardResolver appends it automatically. - return { baseUrl: parsedUrl.toString(), path: undefined }; - } - } catch (_e) { - // Ignore URL parsing errors here, let the resolver handle them. - } - return { baseUrl: url }; -} - /** * Extracts contextId and taskId from a Message, Task, or Update response. * Follows the pattern from the A2A CLI sample to maintain conversational continuity. @@ -446,65 +295,6 @@ export function extractIdsFromResponse(result: SendMessageResult): { return { contextId, taskId, clearTaskId }; } -/** - * Extracts and normalizes interfaces from the card, handling protocol version fallbacks. - * Preserves all original fields to maintain SDK compatibility. - */ -function extractNormalizedInterfaces( - card: Record, -): AgentInterface[] { - const rawInterfaces = - getArray(card, 'additionalInterfaces') || - getArray(card, 'supportedInterfaces'); - - if (!rawInterfaces) { - return []; - } - - const mapped: AgentInterface[] = []; - for (const i of rawInterfaces) { - if (isObject(i)) { - // Use schema to validate interface object. - const parsed = AgentInterfaceSchema.parse(i); - // Narrowing to AgentInterface after runtime validation. - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const normalized = parsed as unknown as AgentInterface & { - protocolBinding?: string; - }; - - // Normalize 'transport' from 'protocolBinding' if missing. - if (!normalized.transport && normalized.protocolBinding) { - normalized.transport = normalized.protocolBinding; - } - - // Robust URL: Ensure the URL has a scheme (except for gRPC). - if ( - normalized.url && - !normalized.url.includes('://') && - !normalized.url.startsWith('/') && - normalized.transport !== 'GRPC' - ) { - // Default to http:// for insecure REST/JSON-RPC if scheme is missing. - normalized.url = `http://${normalized.url}`; - } - - mapped.push(normalized as AgentInterface); - } - } - return mapped; -} - -/** - * Safely extracts an array property from an object. - */ -function getArray( - obj: Record, - key: string, -): unknown[] | undefined { - const val = obj[key]; - return Array.isArray(val) ? val : undefined; -} - // Type Guards function isTextPart(part: Part): part is TextPart { diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index 654ba0e10a..e238a4a860 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -700,7 +700,6 @@ async function fetchAndCacheUserInfo(client: OAuth2Client): Promise { return; } - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch( 'https://www.googleapis.com/oauth2/v2/userinfo', { diff --git a/packages/core/src/mcp/oauth-provider.ts b/packages/core/src/mcp/oauth-provider.ts index 01934d9019..6aaafa6054 100644 --- a/packages/core/src/mcp/oauth-provider.ts +++ b/packages/core/src/mcp/oauth-provider.ts @@ -111,7 +111,6 @@ export class MCPOAuthProvider { scope: config.scopes?.join(' ') || '', }; - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(registrationUrl, { method: 'POST', headers: { @@ -301,7 +300,6 @@ export class MCPOAuthProvider { ? { Accept: 'text/event-stream' } : { Accept: 'application/json' }; - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(mcpServerUrl, { method: 'HEAD', headers, diff --git a/packages/core/src/mcp/oauth-utils.ts b/packages/core/src/mcp/oauth-utils.ts index 207b694181..320c3b9685 100644 --- a/packages/core/src/mcp/oauth-utils.ts +++ b/packages/core/src/mcp/oauth-utils.ts @@ -97,7 +97,6 @@ export class OAuthUtils { resourceMetadataUrl: string, ): Promise { try { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(resourceMetadataUrl); if (!response.ok) { return null; @@ -122,7 +121,6 @@ export class OAuthUtils { authServerMetadataUrl: string, ): Promise { try { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(authServerMetadataUrl); if (!response.ok) { return null; diff --git a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts index 5953578eae..2f059030ca 100644 --- a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts +++ b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.ts @@ -546,7 +546,6 @@ export class ClearcutLogger { let result: LogResponse = {}; try { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(CLEARCUT_URL, { method: 'POST', body: safeJsonStringify(request), diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 7932e35f38..6dbae6dcde 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -1903,7 +1903,6 @@ export async function connectToMcpServer( acceptHeader = 'application/json'; } - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(urlToFetch, { method: 'HEAD', headers: { diff --git a/packages/core/src/utils/fetch.test.ts b/packages/core/src/utils/fetch.test.ts index 3eddefaf3d..4ac0c7b344 100644 --- a/packages/core/src/utils/fetch.test.ts +++ b/packages/core/src/utils/fetch.test.ts @@ -5,27 +5,12 @@ */ import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest'; -import { - isPrivateIp, - isPrivateIpAsync, - isAddressPrivate, - safeLookup, - safeFetch, - fetchWithTimeout, - PrivateIpError, -} from './fetch.js'; -import * as dnsPromises from 'node:dns/promises'; -import * as dns from 'node:dns'; +import { isPrivateIp, isAddressPrivate, fetchWithTimeout } from './fetch.js'; vi.mock('node:dns/promises', () => ({ lookup: vi.fn(), })); -// We need to mock node:dns for safeLookup since it uses the callback API -vi.mock('node:dns', () => ({ - lookup: vi.fn(), -})); - // Mock global fetch const originalFetch = global.fetch; global.fetch = vi.fn(); @@ -114,150 +99,6 @@ describe('fetch utils', () => { }); }); - describe('isPrivateIpAsync', () => { - it('should identify private IPs directly', async () => { - expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true); - }); - - it('should identify domains resolving to private IPs', async () => { - vi.mocked(dnsPromises.lookup).mockImplementation( - async () => - // eslint-disable-next-line @typescript-eslint/no-explicit-any - [{ address: '10.0.0.1', family: 4 }] as any, - ); - expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true); - }); - - it('should identify domains resolving to public IPs as non-private', async () => { - vi.mocked(dnsPromises.lookup).mockImplementation( - async () => - // eslint-disable-next-line @typescript-eslint/no-explicit-any - [{ address: '8.8.8.8', family: 4 }] as any, - ); - expect(await isPrivateIpAsync('http://google.com/')).toBe(false); - }); - - it('should throw error if DNS resolution fails (fail closed)', async () => { - vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error')); - await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow( - 'Failed to verify if URL resolves to private IP', - ); - }); - - it('should return false for invalid URLs instead of throwing verification error', async () => { - expect(await isPrivateIpAsync('not-a-url')).toBe(false); - }); - }); - - describe('safeLookup', () => { - it('should filter out private IPs', async () => { - const addresses = [ - { address: '8.8.8.8', family: 4 }, - { address: '10.0.0.1', family: 4 }, - ]; - - vi.mocked(dns.lookup).mockImplementation((( - _h: string, - _o: dns.LookupOptions, - cb: ( - err: Error | null, - addr: Array<{ address: string; family: number }>, - ) => void, - ) => { - cb(null, addresses); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - }) as any); - - const result = await new Promise< - Array<{ address: string; family: number }> - >((resolve, reject) => { - safeLookup('example.com', { all: true }, (err, filtered) => { - if (err) reject(err); - else resolve(filtered); - }); - }); - - expect(result).toHaveLength(1); - expect(result[0].address).toBe('8.8.8.8'); - }); - - it('should allow explicit localhost', async () => { - const addresses = [{ address: '127.0.0.1', family: 4 }]; - - vi.mocked(dns.lookup).mockImplementation((( - _h: string, - _o: dns.LookupOptions, - cb: ( - err: Error | null, - addr: Array<{ address: string; family: number }>, - ) => void, - ) => { - cb(null, addresses); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - }) as any); - - const result = await new Promise< - Array<{ address: string; family: number }> - >((resolve, reject) => { - safeLookup('localhost', { all: true }, (err, filtered) => { - if (err) reject(err); - else resolve(filtered); - }); - }); - - expect(result).toHaveLength(1); - expect(result[0].address).toBe('127.0.0.1'); - }); - - it('should error if all resolved IPs are private', async () => { - const addresses = [{ address: '10.0.0.1', family: 4 }]; - - vi.mocked(dns.lookup).mockImplementation((( - _h: string, - _o: dns.LookupOptions, - cb: ( - err: Error | null, - addr: Array<{ address: string; family: number }>, - ) => void, - ) => { - cb(null, addresses); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - }) as any); - - await expect( - new Promise((resolve, reject) => { - safeLookup('malicious.com', { all: true }, (err, filtered) => { - if (err) reject(err); - else resolve(filtered); - }); - }), - ).rejects.toThrow(PrivateIpError); - }); - }); - - describe('safeFetch', () => { - it('should forward to fetch with dispatcher', async () => { - vi.mocked(global.fetch).mockResolvedValue(new Response('ok')); - - const response = await safeFetch('https://example.com'); - expect(response.status).toBe(200); - expect(global.fetch).toHaveBeenCalledWith( - 'https://example.com', - expect.objectContaining({ - dispatcher: expect.any(Object), - }), - ); - }); - - it('should handle Refusing to connect errors', async () => { - vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError()); - - await expect(safeFetch('http://10.0.0.1')).rejects.toThrow( - 'Access to private network is blocked', - ); - }); - }); - describe('fetchWithTimeout', () => { it('should handle timeouts', async () => { vi.mocked(global.fetch).mockImplementation( @@ -279,13 +120,5 @@ describe('fetch utils', () => { 'Request timed out after 50ms', ); }); - - it('should handle private IP errors via handleFetchError', async () => { - vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError()); - - await expect(fetchWithTimeout('http://10.0.0.1', 1000)).rejects.toThrow( - 'Access to private network is blocked: http://10.0.0.1', - ); - }); }); }); diff --git a/packages/core/src/utils/fetch.ts b/packages/core/src/utils/fetch.ts index a324172d94..e339ea7fed 100644 --- a/packages/core/src/utils/fetch.ts +++ b/packages/core/src/utils/fetch.ts @@ -6,37 +6,12 @@ import { getErrorMessage, isNodeError } from './errors.js'; import { URL } from 'node:url'; -import * as dns from 'node:dns'; -import { lookup } from 'node:dns/promises'; import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici'; import ipaddr from 'ipaddr.js'; const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes -// Configure default global dispatcher with higher timeouts -setGlobalDispatcher( - new Agent({ - headersTimeout: DEFAULT_HEADERS_TIMEOUT, - bodyTimeout: DEFAULT_BODY_TIMEOUT, - }), -); - -// Local extension of RequestInit to support Node.js/undici dispatcher -interface NodeFetchInit extends RequestInit { - dispatcher?: Agent | ProxyAgent; -} - -/** - * Error thrown when a connection to a private IP address is blocked for security reasons. - */ -export class PrivateIpError extends Error { - constructor(message = 'Refusing to connect to private IP address') { - super(message); - this.name = 'PrivateIpError'; - } -} - export class FetchError extends Error { constructor( message: string, @@ -48,6 +23,14 @@ export class FetchError extends Error { } } +// Configure default global dispatcher with higher timeouts +setGlobalDispatcher( + new Agent({ + headersTimeout: DEFAULT_HEADERS_TIMEOUT, + bodyTimeout: DEFAULT_BODY_TIMEOUT, + }), +); + /** * Sanitizes a hostname by stripping IPv6 brackets if present. */ @@ -69,53 +52,6 @@ export function isLoopbackHost(hostname: string): boolean { ); } -/** - * A custom DNS lookup implementation for undici agents that prevents - * connection to private IP ranges (SSRF protection). - */ -export function safeLookup( - hostname: string, - options: dns.LookupOptions | number | null | undefined, - callback: ( - err: Error | null, - addresses: Array<{ address: string; family: number }>, - ) => void, -): void { - // Use the callback-based dns.lookup to match undici's expected signature. - // We explicitly handle the 'all' option to ensure we get an array of addresses. - const lookupOptions = - typeof options === 'number' ? { family: options } : { ...options }; - const finalOptions = { ...lookupOptions, all: true }; - - dns.lookup(hostname, finalOptions, (err, addresses) => { - if (err) { - callback(err, []); - return; - } - - const addressArray = Array.isArray(addresses) ? addresses : []; - const filtered = addressArray.filter( - (addr) => !isAddressPrivate(addr.address) || isLoopbackHost(hostname), - ); - - if (filtered.length === 0 && addressArray.length > 0) { - callback(new PrivateIpError(), []); - return; - } - - callback(null, filtered); - }); -} - -// Dedicated dispatcher with connection-level SSRF protection (safeLookup) -const safeDispatcher = new Agent({ - headersTimeout: DEFAULT_HEADERS_TIMEOUT, - bodyTimeout: DEFAULT_BODY_TIMEOUT, - connect: { - lookup: safeLookup, - }, -}); - export function isPrivateIp(url: string): boolean { try { const hostname = new URL(url).hostname; @@ -125,37 +61,6 @@ export function isPrivateIp(url: string): boolean { } } -/** - * Checks if a URL resolves to a private IP address. - * Performs DNS resolution to prevent DNS rebinding/SSRF bypasses. - */ -export async function isPrivateIpAsync(url: string): Promise { - try { - const parsed = new URL(url); - const hostname = parsed.hostname; - - // Fast check for literal IPs or localhost - if (isAddressPrivate(hostname)) { - return true; - } - - // Resolve DNS to check the actual target IP - const addresses = await lookup(hostname, { all: true }); - return addresses.some((addr) => isAddressPrivate(addr.address)); - } catch (e) { - if ( - e instanceof Error && - e.name === 'TypeError' && - e.message.includes('Invalid URL') - ) { - return false; - } - throw new Error(`Failed to verify if URL resolves to private IP: ${url}`, { - cause: e, - }); - } -} - /** * IANA Benchmark Testing Range (198.18.0.0/15). * Classified as 'unicast' by ipaddr.js but is reserved and should not be @@ -210,72 +115,15 @@ export function isAddressPrivate(address: string): boolean { } } -/** - * Internal helper to map varied fetch errors to a standardized FetchError. - * Centralizes security-related error mapping (e.g. PrivateIpError). - */ -function handleFetchError(error: unknown, url: string): never { - if (error instanceof PrivateIpError) { - throw new FetchError( - `Access to private network is blocked: ${url}`, - 'ERR_PRIVATE_NETWORK', - { cause: error }, - ); - } - - if (error instanceof FetchError) { - throw error; - } - - throw new FetchError( - getErrorMessage(error), - isNodeError(error) ? error.code : undefined, - { cause: error }, - ); -} - -/** - * Enhanced fetch with SSRF protection. - * Prevents access to private/internal networks at the connection level. - */ -export async function safeFetch( - input: RequestInfo | URL, - init?: RequestInit, -): Promise { - const nodeInit: NodeFetchInit = { - ...init, - dispatcher: safeDispatcher, - }; - - try { - // eslint-disable-next-line no-restricted-syntax - return await fetch(input, nodeInit); - } catch (error) { - const url = - input instanceof Request - ? input.url - : typeof input === 'string' - ? input - : input.toString(); - handleFetchError(error, url); - } -} - /** * Creates an undici ProxyAgent that incorporates safe DNS lookup. */ export function createSafeProxyAgent(proxyUrl: string): ProxyAgent { return new ProxyAgent({ uri: proxyUrl, - connect: { - lookup: safeLookup, - }, }); } -/** - * Performs a fetch with a specified timeout and connection-level SSRF protection. - */ export async function fetchWithTimeout( url: string, timeout: number, @@ -294,21 +142,17 @@ export async function fetchWithTimeout( } } - const nodeInit: NodeFetchInit = { - ...options, - signal: controller.signal, - dispatcher: safeDispatcher, - }; - try { - // eslint-disable-next-line no-restricted-syntax - const response = await fetch(url, nodeInit); + const response = await fetch(url, { + ...options, + signal: controller.signal, + }); return response; } catch (error) { if (isNodeError(error) && error.code === 'ABORT_ERR') { throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT'); } - handleFetchError(error, url.toString()); + throw new FetchError(getErrorMessage(error), undefined, { cause: error }); } finally { clearTimeout(timeoutId); } diff --git a/packages/core/src/utils/oauth-flow.ts b/packages/core/src/utils/oauth-flow.ts index 45318efdb5..e13fd37837 100644 --- a/packages/core/src/utils/oauth-flow.ts +++ b/packages/core/src/utils/oauth-flow.ts @@ -454,7 +454,6 @@ export async function exchangeCodeForToken( params.append('resource', resource); } - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(config.tokenUrl, { method: 'POST', headers: { @@ -508,7 +507,6 @@ export async function refreshAccessToken( params.append('resource', resource); } - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(tokenUrl, { method: 'POST', headers: { diff --git a/packages/vscode-ide-companion/src/extension.ts b/packages/vscode-ide-companion/src/extension.ts index e8cef91c2b..456ec6e872 100644 --- a/packages/vscode-ide-companion/src/extension.ts +++ b/packages/vscode-ide-companion/src/extension.ts @@ -42,7 +42,6 @@ async function checkForUpdates( const currentVersion = context.extension.packageJSON.version; // Fetch extension details from the VSCode Marketplace. - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch( 'https://marketplace.visualstudio.com/_apis/public/gallery/extensionquery', { diff --git a/packages/vscode-ide-companion/src/ide-server.test.ts b/packages/vscode-ide-companion/src/ide-server.test.ts index b3d39bf832..eb28638a78 100644 --- a/packages/vscode-ide-companion/src/ide-server.test.ts +++ b/packages/vscode-ide-companion/src/ide-server.test.ts @@ -356,7 +356,6 @@ describe('IDEServer', () => { }); it('should reject request without auth token', async () => { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(`http://localhost:${port}/mcp`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -371,7 +370,6 @@ describe('IDEServer', () => { }); it('should allow request with valid auth token', async () => { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(`http://localhost:${port}/mcp`, { method: 'POST', headers: { @@ -389,7 +387,6 @@ describe('IDEServer', () => { }); it('should reject request with invalid auth token', async () => { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(`http://localhost:${port}/mcp`, { method: 'POST', headers: { @@ -416,7 +413,6 @@ describe('IDEServer', () => { ]; for (const header of malformedHeaders) { - // eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const response = await fetch(`http://localhost:${port}/mcp`, { method: 'POST', headers: { diff --git a/scripts/build_binary.js b/scripts/build_binary.js index d4aa578925..7d0fd815c1 100644 --- a/scripts/build_binary.js +++ b/scripts/build_binary.js @@ -228,23 +228,35 @@ const packageJson = JSON.parse( // Helper to calc hash const sha256 = (content) => createHash('sha256').update(content).digest('hex'); -// Read Main Bundle -const geminiBundlePath = join(root, 'bundle/gemini.js'); -const geminiContent = readFileSync(geminiBundlePath); -const geminiHash = sha256(geminiContent); - const assets = { - 'gemini.mjs': geminiBundlePath, // Use .js source but map to .mjs for runtime ESM 'manifest.json': 'bundle/manifest.json', }; const manifest = { main: 'gemini.mjs', - mainHash: geminiHash, + mainHash: '', version: packageJson.version, files: [], }; +// Add all javascript chunks from the bundle directory +const jsFiles = globSync('*.js', { cwd: bundleDir }); +for (const jsFile of jsFiles) { + const fsPath = join(bundleDir, jsFile); + const content = readFileSync(fsPath); + const hash = sha256(content); + + // Node SEA requires the main entry point to be explicitly mapped + if (jsFile === 'gemini.js') { + assets['gemini.mjs'] = fsPath; + manifest.mainHash = hash; + } else { + // Other chunks need to be mapped exactly as they are named so dynamic imports find them + assets[jsFile] = fsPath; + manifest.files.push({ key: jsFile, path: jsFile, hash: hash }); + } +} + // Helper to recursively find files from STAGING function addAssetsFromDir(baseDir, runtimePrefix) { const fullDir = join(stagingDir, baseDir); @@ -346,6 +358,22 @@ const targetBinaryPath = join(targetDir, binaryName); console.log(`Copying node binary from ${nodeBinary} to ${targetBinaryPath}...`); copyFileSync(nodeBinary, targetBinaryPath); +if (platform === 'darwin') { + console.log(`Thinning universal binary for ${arch}...`); + try { + // Attempt to thin the binary. Will fail safely if it's not a fat binary. + runCommand('lipo', [ + targetBinaryPath, + '-thin', + arch, + '-output', + targetBinaryPath, + ]); + } catch (e) { + console.log(`Skipping lipo thinning: ${e.message}`); + } +} + // Remove existing signature using helper removeSignature(targetBinaryPath); @@ -357,9 +385,7 @@ if (existsSync(bundleDir)) { // Clean up source JS files from output (we only want embedded) const filesToRemove = [ - 'gemini.js', 'gemini.mjs', - 'gemini.js.map', 'gemini.mjs.map', 'gemini-sea.cjs', 'sea-launch.cjs', @@ -373,6 +399,12 @@ filesToRemove.forEach((f) => { if (existsSync(p)) rmSync(p, { recursive: true, force: true }); }); +// Remove all chunk and entry .js/.js.map files +const jsFilesToRemove = globSync('*.{js,js.map}', { cwd: targetDir }); +for (const f of jsFilesToRemove) { + rmSync(join(targetDir, f)); +} + // Remove .sb files from targetDir const sbFilesToRemove = globSync('sandbox-macos-*.sb', { cwd: targetDir }); for (const f of sbFilesToRemove) {