merge main

This commit is contained in:
Christian Gunderman
2026-03-12 15:50:13 -07:00
26 changed files with 823 additions and 1169 deletions
+6 -3
View File
@@ -1,6 +1,6 @@
# Preview release: v0.34.0-preview.0
# Preview release: v0.34.0-preview.1
Released: March 11, 2026
Released: March 12, 2026
Our preview release includes the latest, new, and experimental features. This
release may not be as stable as our [latest weekly release](latest.md).
@@ -28,6 +28,9 @@ npm install -g @google/gemini-cli@preview
## What's Changed
- fix(patch): cherry-pick 45faf4d to release/v0.34.0-preview.0-pr-22148
[CONFLICTS] by @gemini-cli-robot in
[#22174](https://github.com/google-gemini/gemini-cli/pull/22174)
- feat(cli): add chat resume footer on session quit by @lordshashank in
[#20667](https://github.com/google-gemini/gemini-cli/pull/20667)
- Support bold and other styles in svg snapshots by @jacob314 in
@@ -465,4 +468,4 @@ npm install -g @google/gemini-cli@preview
[#21938](https://github.com/google-gemini/gemini-cli/pull/21938)
**Full Changelog**:
https://github.com/google-gemini/gemini-cli/compare/v0.33.0-preview.15...v0.34.0-preview.0
https://github.com/google-gemini/gemini-cli/compare/v0.33.0-preview.15...v0.34.0-preview.1
+18 -1
View File
@@ -26,6 +26,20 @@ policies.
the CLI will use an available fallback model for the current turn or the
remainder of the session.
### Local Model Routing (Experimental)
Gemini CLI supports using a local model for routing decisions. When configured,
Gemini CLI will use a locally-running **Gemma** model to make routing decisions
(instead of sending routing decisions to a hosted model). This feature can help
reduce costs associated with hosted model usage while offering similar routing
decision latency and quality.
In order to use this feature, the local Gemma model **must** be served behind a
Gemini API and accessible via HTTP at an endpoint configured in `settings.json`.
For more details on how to configure local model routing, see
[Local Model Routing](../core/local-model-routing.md).
### Model selection precedence
The model used by Gemini CLI is determined by the following order of precedence:
@@ -38,5 +52,8 @@ The model used by Gemini CLI is determined by the following order of precedence:
3. **`model.name` in `settings.json`:** If neither of the above are set, the
model specified in the `model.name` property of your `settings.json` file
will be used.
4. **Default model:** If none of the above are set, the default model will be
4. **Local model (experimental):** If the Gemma local model router is enabled
in your `settings.json` file, the CLI will use the local Gemma model
(instead of Gemini models) to route the request to an appropriate model.
5. **Default model:** If none of the above are set, the default model will be
used. The default model is `auto`
+2
View File
@@ -15,6 +15,8 @@ requests sent from `packages/cli`. For a general overview of Gemini CLI, see the
modular GEMINI.md import feature using @file.md syntax.
- **[Policy Engine](../reference/policy-engine.md):** Use the Policy Engine for
fine-grained control over tool execution.
- **[Local Model Routing (experimental)](./local-model-routing.md):** Learn how
to enable use of a local Gemma model for model routing decisions.
## Role of the core
+193
View File
@@ -0,0 +1,193 @@
# Local Model Routing (experimental)
Gemini CLI supports using a local model for
[routing decisions](../cli/model-routing.md). When configured, Gemini CLI will
use a locally-running **Gemma** model to make routing decisions (instead of
sending routing decisions to a hosted model).
This feature can help reduce costs associated with hosted model usage while
offering similar routing decision latency and quality.
> **Note: Local model routing is currently an experimental feature.**
## Setup
Using a Gemma model for routing decisions requires that an implementation of a
Gemma model be running locally on your machine, served behind an HTTP endpoint
and accessed via the Gemini API.
To serve the Gemma model, follow these steps:
### Download the LiteRT-LM runtime
The [LiteRT-LM](https://github.com/google-ai-edge/LiteRT-LM) runtime offers
pre-built binaries for locally-serving models. Download the binary appropriate
for your system.
#### Windows
1. Download
[lit.windows_x86_64.exe](https://github.com/google-ai-edge/LiteRT-LM/releases/download/v0.9.0-alpha03/lit.windows_x86_64.exe).
2. Using GPU on Windows requires the DirectXShaderCompiler. Download the
[dxc zip from the latest release](https://github.com/microsoft/DirectXShaderCompiler/releases/download/v1.8.2505.1/dxc_2025_07_14.zip).
Unzip the archive and from the architecture-appropriate `bin\` directory, and
copy the `dxil.dll` and `dxcompiler.dll` into the same location as you saved
`lit.windows_x86_64.exe`.
3. (Optional) Test starting the runtime:
`.\lit.windows_x86_64.exe serve --verbose`
#### Linux
1. Download
[lit.linux_x86_64](https://github.com/google-ai-edge/LiteRT-LM/releases/download/v0.9.0-alpha03/lit.linux_x86_64).
2. Ensure the binary is executable: `chmod a+x lit.linux_x86_64`
3. (Optional) Test starting the runtime: `./lit.linux_x86_64 serve --verbose`
#### MacOS
1. Download
[lit-macos-arm64](https://github.com/google-ai-edge/LiteRT-LM/releases/download/v0.9.0-alpha03/lit.macos_arm64).
2. Ensure the binary is executable: `chmod a+x lit.macos_arm64`
3. (Optional) Test starting the runtime: `./lit.macos_arm64 serve --verbose`
> **Note**: MacOS can be configured to only allows binaries from "App Store &
> Known Developers". If you encounter an error message when attempting to run
> the binary, you will need to allow the application. One option is to visit
> `System Settings -> Privacy & Security`, scroll to `Security`, and click
> `"Allow Anyway"` for `"lit.macos_arm64"`. Another option is to run
> `xattr -d com.apple.quarantine lit.macos_arm64` from the commandline.
### Download the Gemma Model
Before using Gemma, you will need to download the model (and agree to the Terms
of Service).
This can be done via the LiteRT-LM runtime.
#### Windows
```bash
$ .\lit.windows_x86_64.exe pull gemma3-1b-gpu-custom
[Legal] The model you are about to download is governed by
the Gemma Terms of Use and Prohibited Use Policy. Please review these terms and ensure you agree before continuing.
Full Terms: https://ai.google.dev/gemma/terms
Prohibited Use Policy: https://ai.google.dev/gemma/prohibited_use_policy
Do you accept these terms? (Y/N): Y
Terms accepted.
Downloading model 'gemma3-1b-gpu-custom' ...
Downloading... 968.6 MB
Download complete.
```
#### Linux
```bash
$ ./lit.linux_x86_64 pull gemma3-1b-gpu-custom
[Legal] The model you are about to download is governed by
the Gemma Terms of Use and Prohibited Use Policy. Please review these terms and ensure you agree before continuing.
Full Terms: https://ai.google.dev/gemma/terms
Prohibited Use Policy: https://ai.google.dev/gemma/prohibited_use_policy
Do you accept these terms? (Y/N): Y
Terms accepted.
Downloading model 'gemma3-1b-gpu-custom' ...
Downloading... 968.6 MB
Download complete.
```
#### MacOS
```bash
$ ./lit.lit.macos_arm64 pull gemma3-1b-gpu-custom
[Legal] The model you are about to download is governed by
the Gemma Terms of Use and Prohibited Use Policy. Please review these terms and ensure you agree before continuing.
Full Terms: https://ai.google.dev/gemma/terms
Prohibited Use Policy: https://ai.google.dev/gemma/prohibited_use_policy
Do you accept these terms? (Y/N): Y
Terms accepted.
Downloading model 'gemma3-1b-gpu-custom' ...
Downloading... 968.6 MB
Download complete.
```
### Start LiteRT-LM Runtime
Using the command appropriate to your system, start the LiteRT-LM runtime.
Configure the port that you want to use for your Gemma model. For the purposes
of this document, we will use port `9379`.
Example command for MacOS: `./lit.macos_arm64 serve --port=9379 --verbose`
### (Optional) Verify Model Serving
Send a quick prompt to the model via HTTP to validate successful model serving.
This will cause the runtime to download the model and run it once.
You should see a short joke in the server output as an indicator of success.
#### Windows
```
# Run this in PowerShell to send a request to the server
$uri = "http://localhost:9379/v1beta/models/gemma3-1b-gpu-custom:generateContent"
$body = @{contents = @( @{
role = "user"
parts = @( @{ text = "Tell me a joke." } )
})} | ConvertTo-Json -Depth 10
Invoke-RestMethod -Uri $uri -Method Post -Body $body -ContentType "application/json"
```
#### Linux/MacOS
```bash
$ curl "http://localhost:9379/v1beta/models/gemma3-1b-gpu-custom:generateContent" \
-H 'Content-Type: application/json' \
-X POST \
-d '{"contents":[{"role":"user","parts":[{"text":"Tell me a joke."}]}]}'
```
## Configuration
To use a local Gemma model for routing, you must explicitly enable it in your
`settings.json`:
```json
{
"experimental": {
"gemmaModelRouter": {
"enabled": true,
"classifier": {
"host": "http://localhost:9379",
"model": "gemma3-1b-gpu-custom"
}
}
}
}
```
> Use the port you started your LiteRT-LM runtime on in the setup steps.
### Configuration schema
| Field | Type | Required | Description |
| :----------------- | :------ | :------- | :----------------------------------------------------------------------------------------- |
| `enabled` | boolean | Yes | Must be `true` to enable the feature. |
| `classifier` | object | Yes | The configuration for the local model endpoint. It includes the host and model specifiers. |
| `classifier.host` | string | Yes | The URL to the local model server. Should be `http://localhost:<port>`. |
| `classifier.model` | string | Yes | The model name to use for decisions. Must be `"gemma3-1b-gpu-custom"`. |
> **Note: You will need to restart after configuration changes for local model
> routing to take effect.**
+9 -4
View File
@@ -82,11 +82,14 @@ const commonAliases = {
const cliConfig = {
...baseConfig,
banner: {
js: `const require = (await import('node:module')).createRequire(import.meta.url); globalThis.__filename = (await import('node:url')).fileURLToPath(import.meta.url); globalThis.__dirname = (await import('node:path')).dirname(globalThis.__filename);`,
js: `const require = (await import('node:module')).createRequire(import.meta.url); const __chunk_filename = (await import('node:url')).fileURLToPath(import.meta.url); const __chunk_dirname = (await import('node:path')).dirname(__chunk_filename);`,
},
entryPoints: ['packages/cli/index.ts'],
outfile: 'bundle/gemini.js',
entryPoints: { gemini: 'packages/cli/index.ts' },
outdir: 'bundle',
splitting: true,
define: {
__filename: '__chunk_filename',
__dirname: '__chunk_dirname',
'process.env.CLI_VERSION': JSON.stringify(pkg.version),
'process.env.GEMINI_SANDBOX_IMAGE_DEFAULT': JSON.stringify(
pkg.config?.sandboxImageUri,
@@ -103,11 +106,13 @@ const cliConfig = {
const a2aServerConfig = {
...baseConfig,
banner: {
js: `const require = (await import('node:module')).createRequire(import.meta.url); globalThis.__filename = (await import('node:url')).fileURLToPath(import.meta.url); globalThis.__dirname = (await import('node:path')).dirname(globalThis.__filename);`,
js: `const require = (await import('node:module')).createRequire(import.meta.url); const __chunk_filename = (await import('node:url')).fileURLToPath(import.meta.url); const __chunk_dirname = (await import('node:path')).dirname(__chunk_filename);`,
},
entryPoints: ['packages/a2a-server/src/http/server.ts'],
outfile: 'packages/a2a-server/dist/a2a-server.mjs',
define: {
__filename: '__chunk_filename',
__dirname: '__chunk_dirname',
'process.env.CLI_VERSION': JSON.stringify(pkg.version),
},
plugins: createWasmPlugins(),
-5
View File
@@ -35,11 +35,6 @@ const commonRestrictedSyntaxRules = [
message:
'Do not throw string literals or non-Error objects. Throw new Error("...") instead.',
},
{
selector: 'CallExpression[callee.name="fetch"]',
message:
'Use safeFetch() from "@/utils/fetch" instead of the global fetch() to ensure SSRF protection. If you are implementing a custom security layer, use an eslint-disable comment and explain why.',
},
];
export default tseslint.config(
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 90 KiB

+38 -217
View File
@@ -4,13 +4,38 @@
* SPDX-License-Identifier: Apache-2.0
*/
import React from 'react';
import { render } from 'ink';
import { AppContainer } from './ui/AppContainer.js';
import {
type StartupWarning,
WarningPriority,
type Config,
type ResumedSessionData,
type OutputPayload,
type ConsoleLogPayload,
type UserFeedbackPayload,
sessionId,
logUserPrompt,
AuthType,
UserPromptEvent,
coreEvents,
CoreEvent,
getOauthClient,
patchStdio,
writeToStdout,
writeToStderr,
shouldEnterAlternateScreen,
startupProfiler,
ExitCodes,
SessionStartSource,
SessionEndReason,
ValidationCancelledError,
ValidationRequiredError,
type AdminControlsSettings,
debugLogger,
} from '@google/gemini-cli-core';
import { loadCliConfig, parseArguments } from './config/config.js';
import * as cliConfig from './config/config.js';
import { readStdin } from './utils/readStdin.js';
import { basename } from 'node:path';
import { createHash } from 'node:crypto';
import v8 from 'node:v8';
import os from 'node:os';
@@ -37,47 +62,11 @@ import {
runExitCleanup,
registerTelemetryConfig,
setupSignalHandlers,
setupTtyCheck,
} from './utils/cleanup.js';
import {
cleanupToolOutputFiles,
cleanupExpiredSessions,
} from './utils/sessionCleanup.js';
import {
type StartupWarning,
WarningPriority,
type Config,
type ResumedSessionData,
type OutputPayload,
type ConsoleLogPayload,
type UserFeedbackPayload,
sessionId,
logUserPrompt,
AuthType,
getOauthClient,
UserPromptEvent,
debugLogger,
recordSlowRender,
coreEvents,
CoreEvent,
createWorkingStdio,
patchStdio,
writeToStdout,
writeToStderr,
disableMouseEvents,
enableMouseEvents,
disableLineWrapping,
enableLineWrapping,
shouldEnterAlternateScreen,
startupProfiler,
ExitCodes,
SessionStartSource,
SessionEndReason,
getVersion,
ValidationCancelledError,
ValidationRequiredError,
type AdminControlsSettings,
} from '@google/gemini-cli-core';
import {
initializeApp,
type InitializationResult,
@@ -85,21 +74,9 @@ import {
import { validateAuthMethod } from './config/auth.js';
import { runAcpClient } from './acp/acpClient.js';
import { validateNonInteractiveAuth } from './validateNonInterActiveAuth.js';
import { checkForUpdates } from './ui/utils/updateCheck.js';
import { handleAutoUpdate } from './utils/handleAutoUpdate.js';
import { appEvents, AppEvent } from './utils/events.js';
import { SessionError, SessionSelector } from './utils/sessionUtils.js';
import { SettingsContext } from './ui/contexts/SettingsContext.js';
import { MouseProvider } from './ui/contexts/MouseContext.js';
import { StreamingState } from './ui/types.js';
import { computeTerminalTitle } from './utils/windowTitle.js';
import { SessionStatsProvider } from './ui/contexts/SessionContext.js';
import { VimModeProvider } from './ui/contexts/VimModeContext.js';
import { KeyMatchersProvider } from './ui/hooks/useKeyMatchers.js';
import { loadKeyMatchers } from './ui/key/keyMatchers.js';
import { KeypressProvider } from './ui/contexts/KeypressContext.js';
import { useKittyKeyboardProtocol } from './ui/hooks/useKittyKeyboardProtocol.js';
import {
relaunchAppInChildProcess,
relaunchOnExitCode,
@@ -107,19 +84,13 @@ import {
import { loadSandboxConfig } from './config/sandboxConfig.js';
import { deleteSession, listSessions } from './utils/sessions.js';
import { createPolicyUpdater } from './config/policy.js';
import { ScrollProvider } from './ui/contexts/ScrollProvider.js';
import { TerminalProvider } from './ui/contexts/TerminalContext.js';
import { isAlternateBufferEnabled } from './ui/hooks/useAlternateBuffer.js';
import { OverflowProvider } from './ui/contexts/OverflowContext.js';
import { setupTerminalAndTheme } from './utils/terminalTheme.js';
import { profiler } from './ui/components/DebugProfiler.js';
import { runDeferredCommand } from './deferred.js';
import { cleanupBackgroundLogs } from './utils/logCleanup.js';
import { SlashCommandConflictHandler } from './services/SlashCommandConflictHandler.js';
const SLOW_RENDER_MS = 200;
export function validateDnsResolutionOrder(
order: string | undefined,
): DnsResolutionOrder {
@@ -198,147 +169,16 @@ export async function startInteractiveUI(
resumedSessionData: ResumedSessionData | undefined,
initializationResult: InitializationResult,
) {
// Never enter Ink alternate buffer mode when screen reader mode is enabled
// as there is no benefit of alternate buffer mode when using a screen reader
// and the Ink alternate buffer mode requires line wrapping harmful to
// screen readers.
const useAlternateBuffer = shouldEnterAlternateScreen(
isAlternateBufferEnabled(config),
config.getScreenReader(),
// Dynamically import the heavy UI module so React/Ink are only parsed when needed
const { startInteractiveUI: doStartUI } = await import('./interactiveCli.js');
await doStartUI(
config,
settings,
startupWarnings,
workspaceRoot,
resumedSessionData,
initializationResult,
);
const mouseEventsEnabled = useAlternateBuffer;
if (mouseEventsEnabled) {
enableMouseEvents();
registerCleanup(() => {
disableMouseEvents();
});
}
const { matchers, errors } = await loadKeyMatchers();
errors.forEach((error) => {
coreEvents.emitFeedback('warning', error);
});
const version = await getVersion();
setWindowTitle(basename(workspaceRoot), settings);
const consolePatcher = new ConsolePatcher({
onNewMessage: (msg) => {
coreEvents.emitConsoleLog(msg.type, msg.content);
},
debugMode: config.getDebugMode(),
});
consolePatcher.patch();
registerCleanup(consolePatcher.cleanup);
const { stdout: inkStdout, stderr: inkStderr } = createWorkingStdio();
const isShpool = !!process.env['SHPOOL_SESSION_NAME'];
// Create wrapper component to use hooks inside render
const AppWrapper = () => {
useKittyKeyboardProtocol();
return (
<SettingsContext.Provider value={settings}>
<KeyMatchersProvider value={matchers}>
<KeypressProvider
config={config}
debugKeystrokeLogging={
settings.merged.general.debugKeystrokeLogging
}
>
<MouseProvider
mouseEventsEnabled={mouseEventsEnabled}
debugKeystrokeLogging={
settings.merged.general.debugKeystrokeLogging
}
>
<TerminalProvider>
<ScrollProvider>
<OverflowProvider>
<SessionStatsProvider>
<VimModeProvider>
<AppContainer
config={config}
startupWarnings={startupWarnings}
version={version}
resumedSessionData={resumedSessionData}
initializationResult={initializationResult}
/>
</VimModeProvider>
</SessionStatsProvider>
</OverflowProvider>
</ScrollProvider>
</TerminalProvider>
</MouseProvider>
</KeypressProvider>
</KeyMatchersProvider>
</SettingsContext.Provider>
);
};
if (isShpool) {
// Wait a moment for shpool to stabilize terminal size and state.
// shpool is a persistence tool that restores terminal state by replaying it.
// This delay gives shpool time to finish its restoration replay and send
// the actual terminal size (often via an immediate SIGWINCH) before we
// render the first TUI frame. Without this, the first frame may be
// garbled or rendered at an incorrect size, which disabling incremental
// rendering alone cannot fix for the initial frame.
await new Promise((resolve) => setTimeout(resolve, 100));
}
const instance = render(
process.env['DEBUG'] ? (
<React.StrictMode>
<AppWrapper />
</React.StrictMode>
) : (
<AppWrapper />
),
{
stdout: inkStdout,
stderr: inkStderr,
stdin: process.stdin,
exitOnCtrlC: false,
isScreenReaderEnabled: config.getScreenReader(),
onRender: ({ renderTime }: { renderTime: number }) => {
if (renderTime > SLOW_RENDER_MS) {
recordSlowRender(config, renderTime);
}
profiler.reportFrameRendered();
},
patchConsole: false,
alternateBuffer: useAlternateBuffer,
incrementalRendering:
settings.merged.ui.incrementalRendering !== false &&
useAlternateBuffer &&
!isShpool,
},
);
if (useAlternateBuffer) {
disableLineWrapping();
registerCleanup(() => {
enableLineWrapping();
});
}
checkForUpdates(settings)
.then((info) => {
handleAutoUpdate(info, settings, config.getProjectRoot());
})
.catch((err) => {
// Silently ignore update check errors.
if (config.getDebugMode()) {
debugLogger.warn('Update check failed:', err);
}
});
registerCleanup(() => instance.unmount());
registerCleanup(setupTtyCheck());
}
export async function main() {
@@ -845,25 +685,6 @@ export async function main() {
}
}
function setWindowTitle(title: string, settings: LoadedSettings) {
if (!settings.merged.ui.hideWindowTitle) {
// Initial state before React loop starts
const windowTitle = computeTerminalTitle({
streamingState: StreamingState.Idle,
isConfirming: false,
isSilentWorking: false,
folderName: title,
showThoughts: !!settings.merged.ui.showStatusInTitle,
useDynamicTitle: settings.merged.ui.dynamicWindowTitle,
});
writeToStdout(`\x1b]0;${windowTitle}\x07`);
process.on('exit', () => {
writeToStdout(`\x1b]0;\x07`);
});
}
}
export function initializeOutputListenersAndFlush() {
// If there are no listeners for output, make sure we flush so output is not
// lost.
+214
View File
@@ -0,0 +1,214 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import React from 'react';
import { render } from 'ink';
import { basename } from 'node:path';
import { AppContainer } from './ui/AppContainer.js';
import { ConsolePatcher } from './ui/utils/ConsolePatcher.js';
import { registerCleanup, setupTtyCheck } from './utils/cleanup.js';
import {
type StartupWarning,
type Config,
type ResumedSessionData,
coreEvents,
createWorkingStdio,
disableMouseEvents,
enableMouseEvents,
disableLineWrapping,
enableLineWrapping,
shouldEnterAlternateScreen,
recordSlowRender,
writeToStdout,
getVersion,
debugLogger,
} from '@google/gemini-cli-core';
import type { InitializationResult } from './core/initializer.js';
import type { LoadedSettings } from './config/settings.js';
import { checkForUpdates } from './ui/utils/updateCheck.js';
import { handleAutoUpdate } from './utils/handleAutoUpdate.js';
import { SettingsContext } from './ui/contexts/SettingsContext.js';
import { MouseProvider } from './ui/contexts/MouseContext.js';
import { StreamingState } from './ui/types.js';
import { computeTerminalTitle } from './utils/windowTitle.js';
import { SessionStatsProvider } from './ui/contexts/SessionContext.js';
import { VimModeProvider } from './ui/contexts/VimModeContext.js';
import { KeyMatchersProvider } from './ui/hooks/useKeyMatchers.js';
import { loadKeyMatchers } from './ui/key/keyMatchers.js';
import { KeypressProvider } from './ui/contexts/KeypressContext.js';
import { useKittyKeyboardProtocol } from './ui/hooks/useKittyKeyboardProtocol.js';
import { ScrollProvider } from './ui/contexts/ScrollProvider.js';
import { TerminalProvider } from './ui/contexts/TerminalContext.js';
import { isAlternateBufferEnabled } from './ui/hooks/useAlternateBuffer.js';
import { OverflowProvider } from './ui/contexts/OverflowContext.js';
import { profiler } from './ui/components/DebugProfiler.js';
const SLOW_RENDER_MS = 200;
export async function startInteractiveUI(
config: Config,
settings: LoadedSettings,
startupWarnings: StartupWarning[],
workspaceRoot: string = process.cwd(),
resumedSessionData: ResumedSessionData | undefined,
initializationResult: InitializationResult,
) {
// Never enter Ink alternate buffer mode when screen reader mode is enabled
// as there is no benefit of alternate buffer mode when using a screen reader
// and the Ink alternate buffer mode requires line wrapping harmful to
// screen readers.
const useAlternateBuffer = shouldEnterAlternateScreen(
isAlternateBufferEnabled(config),
config.getScreenReader(),
);
const mouseEventsEnabled = useAlternateBuffer;
if (mouseEventsEnabled) {
enableMouseEvents();
registerCleanup(() => {
disableMouseEvents();
});
}
const { matchers, errors } = await loadKeyMatchers();
errors.forEach((error) => {
coreEvents.emitFeedback('warning', error);
});
const version = await getVersion();
setWindowTitle(basename(workspaceRoot), settings);
const consolePatcher = new ConsolePatcher({
onNewMessage: (msg) => {
coreEvents.emitConsoleLog(msg.type, msg.content);
},
debugMode: config.getDebugMode(),
});
consolePatcher.patch();
registerCleanup(consolePatcher.cleanup);
const { stdout: inkStdout, stderr: inkStderr } = createWorkingStdio();
const isShpool = !!process.env['SHPOOL_SESSION_NAME'];
// Create wrapper component to use hooks inside render
const AppWrapper = () => {
useKittyKeyboardProtocol();
return (
<SettingsContext.Provider value={settings}>
<KeyMatchersProvider value={matchers}>
<KeypressProvider
config={config}
debugKeystrokeLogging={
settings.merged.general.debugKeystrokeLogging
}
>
<MouseProvider
mouseEventsEnabled={mouseEventsEnabled}
debugKeystrokeLogging={
settings.merged.general.debugKeystrokeLogging
}
>
<TerminalProvider>
<ScrollProvider>
<OverflowProvider>
<SessionStatsProvider>
<VimModeProvider>
<AppContainer
config={config}
startupWarnings={startupWarnings}
version={version}
resumedSessionData={resumedSessionData}
initializationResult={initializationResult}
/>
</VimModeProvider>
</SessionStatsProvider>
</OverflowProvider>
</ScrollProvider>
</TerminalProvider>
</MouseProvider>
</KeypressProvider>
</KeyMatchersProvider>
</SettingsContext.Provider>
);
};
if (isShpool) {
// Wait a moment for shpool to stabilize terminal size and state.
await new Promise((resolve) => setTimeout(resolve, 100));
}
const instance = render(
process.env['DEBUG'] ? (
<React.StrictMode>
<AppWrapper />
</React.StrictMode>
) : (
<AppWrapper />
),
{
stdout: inkStdout,
stderr: inkStderr,
stdin: process.stdin,
exitOnCtrlC: false,
isScreenReaderEnabled: config.getScreenReader(),
onRender: ({ renderTime }: { renderTime: number }) => {
if (renderTime > SLOW_RENDER_MS) {
recordSlowRender(config, renderTime);
}
profiler.reportFrameRendered();
},
patchConsole: false,
alternateBuffer: useAlternateBuffer,
incrementalRendering:
settings.merged.ui.incrementalRendering !== false &&
useAlternateBuffer &&
!isShpool,
},
);
if (useAlternateBuffer) {
disableLineWrapping();
registerCleanup(() => {
enableLineWrapping();
});
}
checkForUpdates(settings)
.then((info) => {
handleAutoUpdate(info, settings, config.getProjectRoot());
})
.catch((err) => {
// Silently ignore update check errors.
if (config.getDebugMode()) {
debugLogger.warn('Update check failed:', err);
}
});
registerCleanup(() => instance.unmount());
registerCleanup(setupTtyCheck());
}
function setWindowTitle(title: string, settings: LoadedSettings) {
if (!settings.merged.ui.hideWindowTitle) {
// Initial state before React loop starts
const windowTitle = computeTerminalTitle({
streamingState: StreamingState.Idle,
isConfirming: false,
isSilentWorking: false,
folderName: title,
showThoughts: !!settings.merged.ui.showStatusInTitle,
useDynamicTitle: settings.merged.ui.dynamicWindowTitle,
});
writeToStdout(`\x1b]0;${windowTitle}\x07`);
process.on('exit', () => {
writeToStdout(`\x1b]0;\x07`);
});
}
}
@@ -123,7 +123,6 @@ async function downloadFiles({
downloads.push(
(async () => {
const endpoint = `${REPO_DOWNLOAD_URL}/refs/tags/${releaseTag}/${SOURCE_DIR}/${fileBasename}`;
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(endpoint, {
method: 'GET',
dispatcher: proxy ? new ProxyAgent(proxy) : undefined,
-1
View File
@@ -61,7 +61,6 @@ export const getLatestGitHubRelease = async (
const endpoint = `https://api.github.com/repos/google-github-actions/run-gemini-cli/releases/latest`;
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(endpoint, {
method: 'GET',
headers: {
@@ -5,11 +5,8 @@
*/
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import {
A2AClientManager,
type SendMessageResult,
} from './a2a-client-manager.js';
import type { AgentCard, Task } from '@a2a-js/sdk';
import { A2AClientManager } from './a2a-client-manager.js';
import type { AgentCard } from '@a2a-js/sdk';
import {
ClientFactory,
DefaultAgentCardResolver,
@@ -22,81 +19,95 @@ import type { Config } from '../config/config.js';
import { Agent as UndiciAgent, ProxyAgent } from 'undici';
import { debugLogger } from '../utils/debugLogger.js';
interface MockClient {
sendMessageStream: ReturnType<typeof vi.fn>;
getTask: ReturnType<typeof vi.fn>;
cancelTask: ReturnType<typeof vi.fn>;
}
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
const actual = await importOriginal();
return {
...(actual as Record<string, unknown>),
createAuthenticatingFetchWithRetry: vi.fn(),
ClientFactory: vi.fn(),
DefaultAgentCardResolver: vi.fn(),
ClientFactoryOptions: {
createFrom: vi.fn(),
default: {},
},
};
});
vi.mock('../utils/debugLogger.js', () => ({
debugLogger: {
debug: vi.fn(),
},
}));
vi.mock('@a2a-js/sdk/client', () => {
const ClientFactory = vi.fn();
const DefaultAgentCardResolver = vi.fn();
const RestTransportFactory = vi.fn();
const JsonRpcTransportFactory = vi.fn();
const ClientFactoryOptions = {
default: {},
createFrom: vi.fn(),
};
const createAuthenticatingFetchWithRetry = vi.fn();
DefaultAgentCardResolver.prototype.resolve = vi.fn();
ClientFactory.prototype.createFromUrl = vi.fn();
return {
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
createAuthenticatingFetchWithRetry,
};
});
describe('A2AClientManager', () => {
let manager: A2AClientManager;
const mockAgentCard: AgentCard = {
name: 'test-agent',
description: 'A test agent',
url: 'http://test.agent',
version: '1.0.0',
protocolVersion: '0.1.0',
capabilities: {},
skills: [],
defaultInputModes: [],
defaultOutputModes: [],
};
const mockClient: MockClient = {
sendMessageStream: vi.fn(),
getTask: vi.fn(),
cancelTask: vi.fn(),
};
// Stable mocks initialized once
const sendMessageStreamMock = vi.fn();
const getTaskMock = vi.fn();
const cancelTaskMock = vi.fn();
const getAgentCardMock = vi.fn();
const authFetchMock = vi.fn();
const mockClient = {
sendMessageStream: sendMessageStreamMock,
getTask: getTaskMock,
cancelTask: cancelTaskMock,
getAgentCard: getAgentCardMock,
} as unknown as Client;
const mockAgentCard: Partial<AgentCard> = { name: 'TestAgent' };
beforeEach(() => {
vi.clearAllMocks();
A2AClientManager.resetInstanceForTesting();
manager = A2AClientManager.getInstance();
// Default mock implementations
getAgentCardMock.mockResolvedValue({
// Re-create the instances as plain objects that can be spied on
const factoryInstance = {
createFromUrl: vi.fn(),
createFromAgentCard: vi.fn(),
};
const resolverInstance = {
resolve: vi.fn(),
};
vi.mocked(ClientFactory).mockReturnValue(
factoryInstance as unknown as ClientFactory,
);
vi.mocked(DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as DefaultAgentCardResolver,
);
vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue(
mockClient as unknown as Client,
);
vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue(
mockClient as unknown as Client,
);
vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue(
mockClient,
vi.spyOn(ClientFactoryOptions, 'createFrom').mockImplementation(
(_defaults, overrides) => overrides as unknown as ClientFactoryOptions,
);
vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactoryOptions.createFrom).mockImplementation(
(_defaults, overrides) => overrides as ClientFactoryOptions,
);
vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue(
authFetchMock,
vi.mocked(createAuthenticatingFetchWithRetry).mockImplementation(() =>
authFetchMock.mockResolvedValue({
ok: true,
json: async () => ({}),
} as Response),
);
vi.stubGlobal(
@@ -170,15 +181,19 @@ describe('A2AClientManager', () => {
'TestAgent',
'http://test.agent/card',
);
expect(agentCard).toMatchObject(mockAgentCard);
expect(manager.getAgentCard('TestAgent')).toBe(agentCard);
expect(manager.getClient('TestAgent')).toBeDefined();
});
it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(ClientFactoryOptions.createFrom).toHaveBeenCalled();
});
it('should throw an error if an agent with the same name is already loaded', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
await expect(
manager.loadAgent('TestAgent', 'http://another.agent/card'),
manager.loadAgent('TestAgent', 'http://test.agent/card'),
).rejects.toThrow("Agent with name 'TestAgent' is already loaded.");
});
@@ -193,20 +208,12 @@ describe('A2AClientManager', () => {
shouldRetryWithHeaders: vi.fn(),
};
await manager.loadAgent(
'CustomAuthAgent',
'http://custom.agent/card',
'TestAgent',
'http://test.agent/card',
customAuthHandler as unknown as AuthenticationHandler,
);
expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith(
expect.anything(),
customAuthHandler,
);
// Card resolver should NOT use the authenticated fetch by default.
const resolverInstance = vi.mocked(DefaultAgentCardResolver).mock
.instances[0];
expect(resolverInstance).toBeDefined();
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
.calls[0][0];
expect(resolverOptions?.fetchImpl).not.toBe(authFetchMock);
@@ -267,106 +274,163 @@ describe('A2AClientManager', () => {
it('should log a debug message upon loading an agent', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(debugLogger.debug).toHaveBeenCalledWith(
"[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card",
expect.stringContaining("Loaded agent 'TestAgent'"),
);
});
it('should clear the cache', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(manager.getAgentCard('TestAgent')).toBeDefined();
expect(manager.getClient('TestAgent')).toBeDefined();
manager.clearCache();
expect(manager.getAgentCard('TestAgent')).toBeUndefined();
expect(manager.getClient('TestAgent')).toBeUndefined();
expect(debugLogger.debug).toHaveBeenCalledWith(
'[A2AClientManager] Cache cleared.',
});
it('should throw if resolveAgentCard fails', async () => {
const resolverInstance = {
resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')),
};
vi.mocked(DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as DefaultAgentCardResolver,
);
await expect(
manager.loadAgent('FailAgent', 'http://fail.agent'),
).rejects.toThrow('Resolution failed');
});
it('should throw if factory.createFromAgentCard fails', async () => {
const factoryInstance = {
createFromAgentCard: vi
.fn()
.mockRejectedValue(new Error('Factory failed')),
};
vi.mocked(ClientFactory).mockReturnValue(
factoryInstance as unknown as ClientFactory,
);
await expect(
manager.loadAgent('FailAgent', 'http://fail.agent'),
).rejects.toThrow('Factory failed');
});
});
describe('getAgentCard and getClient', () => {
it('should return undefined if agent is not found', () => {
expect(manager.getAgentCard('Unknown')).toBeUndefined();
expect(manager.getClient('Unknown')).toBeUndefined();
});
});
describe('sendMessageStream', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should send a message and return a stream', async () => {
const mockResult = {
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
sendMessageStreamMock.mockReturnValue(
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield mockResult;
yield { kind: 'message' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = [];
for await (const res of stream) {
results.push(res);
for await (const result of stream) {
results.push(result);
}
expect(results).toEqual([mockResult]);
expect(sendMessageStreamMock).toHaveBeenCalledWith(
expect(results).toHaveLength(1);
expect(mockClient.sendMessageStream).toHaveBeenCalled();
});
it('should use contextId and taskId when provided', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield { kind: 'message' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: 'ctx123',
taskId: 'task456',
});
// trigger execution
for await (const _ of stream) {
break;
}
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.anything(),
message: expect.objectContaining({
contextId: 'ctx123',
taskId: 'task456',
}),
}),
expect.any(Object),
);
});
it('should use contextId and taskId when provided', async () => {
sendMessageStreamMock.mockReturnValue(
it('should correctly propagate AbortSignal to the stream', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield {
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
yield { kind: 'message' };
})(),
);
const expectedContextId = 'user-context-id';
const expectedTaskId = 'user-task-id';
const controller = new AbortController();
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: expectedContextId,
taskId: expectedTaskId,
signal: controller.signal,
});
// trigger execution
for await (const _ of stream) {
// consume stream
break;
}
const call = sendMessageStreamMock.mock.calls[0][0];
expect(call.message.contextId).toBe(expectedContextId);
expect(call.message.taskId).toBe(expectedTaskId);
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({ signal: controller.signal }),
);
});
it('should propagate the original error on failure', async () => {
sendMessageStreamMock.mockImplementationOnce(() => {
throw new Error('Network error');
it('should handle a multi-chunk stream with different event types', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield { kind: 'message', messageId: 'm1' };
yield { kind: 'status-update', taskId: 't1' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = [];
for await (const result of stream) {
results.push(result);
}
expect(results).toHaveLength(2);
expect(results[0].kind).toBe('message');
expect(results[1].kind).toBe('status-update');
});
it('should throw prefixed error on failure', async () => {
mockClient.sendMessageStream.mockImplementation(() => {
throw new Error('Network failure');
});
const stream = manager.sendMessageStream('TestAgent', 'Hello');
await expect(async () => {
for await (const _ of stream) {
// consume
// empty
}
}).rejects.toThrow('Network error');
}).rejects.toThrow(
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure',
);
});
it('should throw an error if the agent is not found', async () => {
const stream = manager.sendMessageStream('NonExistentAgent', 'Hello');
await expect(async () => {
for await (const _ of stream) {
// consume
// empty
}
}).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
@@ -374,28 +438,23 @@ describe('A2AClientManager', () => {
describe('getTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should get a task from the correct agent', async () => {
getTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'completed' },
} as Task);
const mockTask = { id: 'task123', kind: 'task' };
mockClient.getTask.mockResolvedValue(mockTask);
await manager.getTask('TestAgent', 'task123');
expect(getTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
const result = await manager.getTask('TestAgent', 'task123');
expect(result).toBe(mockTask);
expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' });
});
it('should throw prefixed error on failure', async () => {
getTaskMock.mockRejectedValueOnce(new Error('Network error'));
mockClient.getTask.mockRejectedValue(new Error('Not found'));
await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient getTask Error [TestAgent]: Network error',
'A2AClient getTask Error [TestAgent]: Not found',
);
});
@@ -408,28 +467,23 @@ describe('A2AClientManager', () => {
describe('cancelTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should cancel a task on the correct agent', async () => {
cancelTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'canceled' },
} as Task);
const mockTask = { id: 'task123', kind: 'task' };
mockClient.cancelTask.mockResolvedValue(mockTask);
await manager.cancelTask('TestAgent', 'task123');
expect(cancelTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
const result = await manager.cancelTask('TestAgent', 'task123');
expect(result).toBe(mockTask);
expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' });
});
it('should throw prefixed error on failure', async () => {
cancelTaskMock.mockRejectedValueOnce(new Error('Network error'));
mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel'));
await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient cancelTask Error [TestAgent]: Network error',
'A2AClient cancelTask Error [TestAgent]: Cannot cancel',
);
});
+48 -30
View File
@@ -12,36 +12,41 @@ import type {
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
} from '@a2a-js/sdk';
import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client';
import {
type Client,
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
type AuthenticationHandler,
RestTransportFactory,
createAuthenticatingFetchWithRetry,
} from '@a2a-js/sdk/client';
import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
import * as grpc from '@grpc/grpc-js';
import { v4 as uuidv4 } from 'uuid';
import { Agent as UndiciAgent, ProxyAgent } from 'undici';
import { normalizeAgentCard } from './a2aUtils.js';
import type { Config } from '../config/config.js';
import { debugLogger } from '../utils/debugLogger.js';
import { safeLookup } from '../utils/fetch.js';
import { classifyAgentError } from './a2a-errors.js';
// Remote agents can take 10+ minutes (e.g. Deep Research).
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
const A2A_TIMEOUT = 1800000; // 30 minutes
/**
* Result of sending a message, which can be a full message, a task,
* or an incremental status/artifact update.
*/
export type SendMessageResult =
| Message
| Task
| TaskStatusUpdateEvent
| TaskArtifactUpdateEvent;
// Remote agents can take 10+ minutes (e.g. Deep Research).
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
const A2A_TIMEOUT = 1800000; // 30 minutes
/**
* Manages A2A clients and caches loaded agent information.
* Follows a singleton pattern to ensure a single client instance.
* Orchestrates communication with remote A2A agents.
* Manages protocol negotiation, authentication, and transport selection.
*/
export class A2AClientManager {
private static instance: A2AClientManager;
@@ -58,9 +63,6 @@ export class A2AClientManager {
const agentOptions = {
headersTimeout: A2A_TIMEOUT,
bodyTimeout: A2A_TIMEOUT,
connect: {
lookup: safeLookup, // SSRF protection at connection level
},
};
if (proxyUrl) {
@@ -73,7 +75,6 @@ export class A2AClientManager {
}
this.a2aFetch = (input, init) =>
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
fetch(input, { ...init, dispatcher: this.a2aDispatcher } as RequestInit);
}
@@ -139,22 +140,35 @@ export class A2AClientManager {
};
const resolver = new DefaultAgentCardResolver({ fetchImpl: cardFetch });
const rawCard = await resolver.resolve(agentCardUrl, '');
// TODO: Remove normalizeAgentCard once @a2a-js/sdk handles
// proto field name aliases (supportedInterfaces → additionalInterfaces,
// protocolBinding → transport).
const agentCard = normalizeAgentCard(rawCard);
const options = ClientFactoryOptions.createFrom(
const grpcUrl =
agentCard.additionalInterfaces?.find((i) => i.transport === 'GRPC')
?.url ?? agentCard.url;
const clientOptions = ClientFactoryOptions.createFrom(
ClientFactoryOptions.default,
{
transports: [
new RestTransportFactory({ fetchImpl: authFetch }),
new JsonRpcTransportFactory({ fetchImpl: authFetch }),
new GrpcTransportFactory({
grpcChannelCredentials: grpcUrl.startsWith('https://')
? grpc.credentials.createSsl()
: grpc.credentials.createInsecure(),
}),
],
cardResolver: resolver,
},
);
try {
const factory = new ClientFactory(options);
const client = await factory.createFromUrl(agentCardUrl, '');
const agentCard = await client.getAgentCard();
const factory = new ClientFactory(clientOptions);
const client = await factory.createFromAgentCard(agentCard);
this.clients.set(name, client);
this.agentCards.set(name, agentCard);
@@ -192,9 +206,7 @@ export class A2AClientManager {
options?: { contextId?: string; taskId?: string; signal?: AbortSignal },
): AsyncIterable<SendMessageResult> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
const messageParams: MessageSendParams = {
message: {
@@ -207,9 +219,19 @@ export class A2AClientManager {
},
};
yield* client.sendMessageStream(messageParams, {
signal: options?.signal,
});
try {
yield* client.sendMessageStream(messageParams, {
signal: options?.signal,
});
} catch (error: unknown) {
const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`;
if (error instanceof Error) {
throw new Error(`${prefix}: ${error.message}`, { cause: error });
}
throw new Error(
`${prefix}: Unexpected error during sendMessageStream: ${String(error)}`,
);
}
}
/**
@@ -238,9 +260,7 @@ export class A2AClientManager {
*/
async getTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
try {
return await client.getTask({ id: taskId });
} catch (error: unknown) {
@@ -260,9 +280,7 @@ export class A2AClientManager {
*/
async cancelTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
try {
return await client.cancelTask({ id: taskId });
} catch (error: unknown) {
+20 -171
View File
@@ -12,9 +12,6 @@ import {
A2AResultReassembler,
AUTH_REQUIRED_MSG,
normalizeAgentCard,
getGrpcCredentials,
pinUrlToIp,
splitAgentCardUrl,
} from './a2aUtils.js';
import type { SendMessageResult } from './a2a-client-manager.js';
import type {
@@ -26,12 +23,6 @@ import type {
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
} from '@a2a-js/sdk';
import * as dnsPromises from 'node:dns/promises';
import type { LookupAddress } from 'node:dns';
vi.mock('node:dns/promises', () => ({
lookup: vi.fn(),
}));
describe('a2aUtils', () => {
beforeEach(() => {
@@ -42,89 +33,6 @@ describe('a2aUtils', () => {
vi.restoreAllMocks();
});
describe('getGrpcCredentials', () => {
it('should return secure credentials for https', () => {
const credentials = getGrpcCredentials('https://test.agent');
expect(credentials).toBeDefined();
});
it('should return insecure credentials for http', () => {
const credentials = getGrpcCredentials('http://test.agent');
expect(credentials).toBeDefined();
});
});
describe('pinUrlToIp', () => {
it('should resolve and pin hostname to IP', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp(
'http://example.com:9000',
'test-agent',
);
expect(hostname).toBe('example.com');
expect(pinnedUrl).toBe('http://93.184.216.34:9000/');
});
it('should handle raw host:port strings (standard for gRPC)', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp(
'example.com:9000',
'test-agent',
);
expect(hostname).toBe('example.com');
expect(pinnedUrl).toBe('93.184.216.34:9000');
});
it('should throw error if resolution fails (fail closed)', async () => {
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
await expect(
pinUrlToIp('http://unreachable.com', 'test-agent'),
).rejects.toThrow("Failed to resolve host for agent 'test-agent'");
});
it('should throw error if resolved to private IP', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '10.0.0.1', family: 4 }]);
await expect(
pinUrlToIp('http://malicious.com', 'test-agent'),
).rejects.toThrow('resolves to private IP range');
});
it('should allow localhost/127.0.0.1/::1 exceptions', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '127.0.0.1', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp(
'http://localhost:9000',
'test-agent',
);
expect(hostname).toBe('localhost');
expect(pinnedUrl).toBe('http://127.0.0.1:9000/');
});
});
describe('isTerminalState', () => {
it('should return true for completed, failed, canceled, and rejected', () => {
expect(isTerminalState('completed')).toBe(true);
@@ -365,12 +273,12 @@ describe('a2aUtils', () => {
expect(normalized.name).toBe('my-agent');
// @ts-expect-error - testing dynamic preservation
expect(normalized.customField).toBe('keep-me');
expect(normalized.description).toBe('');
expect(normalized.skills).toEqual([]);
expect(normalized.defaultInputModes).toEqual([]);
expect(normalized.description).toBeUndefined();
expect(normalized.skills).toBeUndefined();
expect(normalized.defaultInputModes).toBeUndefined();
});
it('should normalize and synchronize interfaces while preserving other fields', () => {
it('should map supportedInterfaces to additionalInterfaces with protocolBinding → transport', () => {
const raw = {
name: 'test',
supportedInterfaces: [
@@ -384,13 +292,7 @@ describe('a2aUtils', () => {
const normalized = normalizeAgentCard(raw);
// Should exist in both fields
expect(normalized.additionalInterfaces).toHaveLength(1);
expect(
(normalized as unknown as Record<string, unknown>)[
'supportedInterfaces'
],
).toHaveLength(1);
const intf = normalized.additionalInterfaces?.[0] as unknown as Record<
string,
@@ -399,43 +301,18 @@ describe('a2aUtils', () => {
expect(intf['transport']).toBe('GRPC');
expect(intf['url']).toBe('grpc://test');
// Should fallback top-level url
expect(normalized.url).toBe('grpc://test');
});
it('should preserve existing top-level url if present', () => {
it('should not overwrite additionalInterfaces if already present', () => {
const raw = {
name: 'test',
url: 'http://existing',
additionalInterfaces: [{ url: 'http://grpc', transport: 'GRPC' }],
supportedInterfaces: [{ url: 'http://other', transport: 'REST' }],
};
const normalized = normalizeAgentCard(raw);
expect(normalized.url).toBe('http://existing');
});
it('should NOT prepend http:// scheme to raw IP:port strings for gRPC interfaces', () => {
const raw = {
name: 'raw-ip-grpc',
supportedInterfaces: [{ url: '127.0.0.1:9000', transport: 'GRPC' }],
};
const normalized = normalizeAgentCard(raw);
expect(normalized.additionalInterfaces?.[0].url).toBe('127.0.0.1:9000');
expect(normalized.url).toBe('127.0.0.1:9000');
});
it('should prepend http:// scheme to raw IP:port strings for REST interfaces', () => {
const raw = {
name: 'raw-ip-rest',
supportedInterfaces: [{ url: '127.0.0.1:8080', transport: 'REST' }],
};
const normalized = normalizeAgentCard(raw);
expect(normalized.additionalInterfaces?.[0].url).toBe(
'http://127.0.0.1:8080',
);
expect(normalized.additionalInterfaces).toHaveLength(1);
expect(normalized.additionalInterfaces?.[0].url).toBe('http://grpc');
});
it('should NOT override existing transport if protocolBinding is also present', () => {
@@ -448,48 +325,20 @@ describe('a2aUtils', () => {
const normalized = normalizeAgentCard(raw);
expect(normalized.additionalInterfaces?.[0].transport).toBe('GRPC');
});
});
describe('splitAgentCardUrl', () => {
const standard = '.well-known/agent-card.json';
it('should not mutate the original card object', () => {
const raw = {
name: 'test',
supportedInterfaces: [{ url: 'grpc://test', protocolBinding: 'GRPC' }],
};
it('should return baseUrl as-is if it does not end with standard path', () => {
const url = 'http://localhost:9001/custom/path';
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
});
it('should split correctly if URL ends with standard path', () => {
const url = `http://localhost:9001/${standard}`;
expect(splitAgentCardUrl(url)).toEqual({
baseUrl: 'http://localhost:9001/',
path: undefined,
});
});
it('should handle trailing slash in baseUrl when splitting', () => {
const url = `http://example.com/api/${standard}`;
expect(splitAgentCardUrl(url)).toEqual({
baseUrl: 'http://example.com/api/',
path: undefined,
});
});
it('should ignore hashes and query params when splitting', () => {
const url = `http://localhost:9001/${standard}?foo=bar#baz`;
expect(splitAgentCardUrl(url)).toEqual({
baseUrl: 'http://localhost:9001/',
path: undefined,
});
});
it('should return original URL if parsing fails', () => {
const url = 'not-a-url';
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
});
it('should handle standard path appearing earlier in the path', () => {
const url = `http://localhost:9001/${standard}/something-else`;
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
const normalized = normalizeAgentCard(raw);
expect(normalized).not.toBe(raw);
expect(normalized.additionalInterfaces).toBeDefined();
// Original should not have additionalInterfaces added
expect(
(raw as Record<string, unknown>)['additionalInterfaces'],
).toBeUndefined();
});
});
+24 -234
View File
@@ -4,9 +4,6 @@
* SPDX-License-Identifier: Apache-2.0
*/
import * as grpc from '@grpc/grpc-js';
import { lookup } from 'node:dns/promises';
import { z } from 'zod';
import type {
Message,
Part,
@@ -18,37 +15,10 @@ import type {
AgentCard,
AgentInterface,
} from '@a2a-js/sdk';
import { isAddressPrivate } from '../utils/fetch.js';
import type { SendMessageResult } from './a2a-client-manager.js';
export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`;
const AgentInterfaceSchema = z
.object({
url: z.string().default(''),
transport: z.string().optional(),
protocolBinding: z.string().optional(),
})
.passthrough();
const AgentCardSchema = z
.object({
name: z.string().default('unknown'),
description: z.string().default(''),
url: z.string().default(''),
version: z.string().default(''),
protocolVersion: z.string().default(''),
capabilities: z.record(z.unknown()).default({}),
skills: z.array(z.union([z.string(), z.record(z.unknown())])).default([]),
defaultInputModes: z.array(z.string()).default([]),
defaultOutputModes: z.array(z.string()).default([]),
additionalInterfaces: z.array(AgentInterfaceSchema).optional(),
supportedInterfaces: z.array(AgentInterfaceSchema).optional(),
preferredTransport: z.string().optional(),
})
.passthrough();
/**
* Reassembles incremental A2A streaming updates into a coherent result.
* Shows sequential status/messages followed by all reassembled artifacts.
@@ -241,166 +211,45 @@ function extractPartText(part: Part): string {
}
/**
* Normalizes an agent card by ensuring it has the required properties
* and resolving any inconsistencies between protocol versions.
* Normalizes proto field name aliases that the SDK doesn't handle yet.
* The A2A proto spec uses `supported_interfaces` and `protocol_binding`,
* while the SDK expects `additionalInterfaces` and `transport`.
* TODO: Remove once @a2a-js/sdk handles these aliases natively.
*/
export function normalizeAgentCard(card: unknown): AgentCard {
if (!isObject(card)) {
throw new Error('Agent card is missing.');
}
// Use Zod to validate and parse the card, ensuring safe defaults and narrowing types.
const parsed = AgentCardSchema.parse(card);
// Narrowing to AgentCard interface after runtime validation.
// Shallow-copy to avoid mutating the SDK's cached object.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const result = parsed as unknown as AgentCard;
const result = { ...card } as unknown as AgentCard;
// Normalize interfaces and synchronize both interface fields.
const normalizedInterfaces = extractNormalizedInterfaces(parsed);
result.additionalInterfaces = normalizedInterfaces;
// Sync supportedInterfaces for backward compatibility.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const legacyResult = result as unknown as Record<string, AgentInterface[]>;
legacyResult['supportedInterfaces'] = normalizedInterfaces;
// Fallback preferredTransport: If not specified, default to GRPC if available.
if (
!result.preferredTransport &&
normalizedInterfaces.some((i) => i.transport === 'GRPC')
) {
result.preferredTransport = 'GRPC';
// Map supportedInterfaces → additionalInterfaces if needed
if (!result.additionalInterfaces) {
const raw = card;
if (Array.isArray(raw['supportedInterfaces'])) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
result.additionalInterfaces = raw[
'supportedInterfaces'
] as AgentInterface[];
}
}
// Fallback: If top-level URL is missing, use the first interface's URL.
if (result.url === '' && normalizedInterfaces.length > 0) {
result.url = normalizedInterfaces[0].url;
// Map protocolBinding → transport on each interface
for (const intf of result.additionalInterfaces ?? []) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const raw = intf as unknown as Record<string, unknown>;
const binding = raw['protocolBinding'];
if (!intf.transport && typeof binding === 'string') {
intf.transport = binding;
}
}
return result;
}
/**
* Returns gRPC channel credentials based on the URL scheme.
*/
export function getGrpcCredentials(url: string): grpc.ChannelCredentials {
return url.startsWith('https://')
? grpc.credentials.createSsl()
: grpc.credentials.createInsecure();
}
/**
* Returns gRPC channel options to ensure SSL/authority matches the original hostname
* when connecting via a pinned IP address.
*/
export function getGrpcChannelOptions(
hostname: string,
): Record<string, unknown> {
return {
'grpc.default_authority': hostname,
'grpc.ssl_target_name_override': hostname,
};
}
/**
* Resolves a hostname to its IP address and validates it against SSRF.
* Returns the pinned IP-based URL and the original hostname.
*/
export async function pinUrlToIp(
url: string,
agentName: string,
): Promise<{ pinnedUrl: string; hostname: string }> {
if (!url) return { pinnedUrl: url, hostname: '' };
// gRPC URLs in A2A can be 'host:port' or 'dns:///host:port' or have schemes.
// We normalize to host:port for resolution.
const hasScheme = url.includes('://');
const normalizedUrl = hasScheme ? url : `http://${url}`;
try {
const parsed = new URL(normalizedUrl);
const hostname = parsed.hostname;
const sanitizedHost =
hostname.startsWith('[') && hostname.endsWith(']')
? hostname.slice(1, -1)
: hostname;
// Resolve DNS to check the actual target IP and pin it
const addresses = await lookup(hostname, { all: true });
const publicAddresses = addresses.filter(
(addr) =>
!isAddressPrivate(addr.address) ||
sanitizedHost === 'localhost' ||
sanitizedHost === '127.0.0.1' ||
sanitizedHost === '::1',
);
if (publicAddresses.length === 0) {
if (addresses.length > 0) {
throw new Error(
`Refusing to load agent '${agentName}': transport URL '${url}' resolves to private IP range.`,
);
}
throw new Error(
`Failed to resolve any public IP addresses for host: ${hostname}`,
);
}
const pinnedIp = publicAddresses[0].address;
const pinnedHostname = pinnedIp.includes(':') ? `[${pinnedIp}]` : pinnedIp;
// Reconstruct URL with IP
parsed.hostname = pinnedHostname;
let pinnedUrl = parsed.toString();
// If original didn't have scheme, remove it (standard for gRPC targets)
if (!hasScheme) {
pinnedUrl = pinnedUrl.replace(/^http:\/\//, '');
// URL.toString() might append a trailing slash
if (pinnedUrl.endsWith('/') && !url.endsWith('/')) {
pinnedUrl = pinnedUrl.slice(0, -1);
}
}
return { pinnedUrl, hostname };
} catch (e) {
if (e instanceof Error && e.message.includes('Refusing')) throw e;
throw new Error(`Failed to resolve host for agent '${agentName}': ${url}`, {
cause: e,
});
}
}
/**
* Splts an agent card URL into a baseUrl and a standard path if it already
* contains '.well-known/agent-card.json'.
*/
export function splitAgentCardUrl(url: string): {
baseUrl: string;
path?: string;
} {
const standardPath = '.well-known/agent-card.json';
try {
const parsedUrl = new URL(url);
if (parsedUrl.pathname.endsWith(standardPath)) {
// Reconstruct baseUrl from parsed components to avoid issues with hashes or query params.
parsedUrl.pathname = parsedUrl.pathname.substring(
0,
parsedUrl.pathname.lastIndexOf(standardPath),
);
parsedUrl.search = '';
parsedUrl.hash = '';
// We return undefined for path if it's the standard one,
// because the SDK's DefaultAgentCardResolver appends it automatically.
return { baseUrl: parsedUrl.toString(), path: undefined };
}
} catch (_e) {
// Ignore URL parsing errors here, let the resolver handle them.
}
return { baseUrl: url };
}
/**
* Extracts contextId and taskId from a Message, Task, or Update response.
* Follows the pattern from the A2A CLI sample to maintain conversational continuity.
@@ -446,65 +295,6 @@ export function extractIdsFromResponse(result: SendMessageResult): {
return { contextId, taskId, clearTaskId };
}
/**
* Extracts and normalizes interfaces from the card, handling protocol version fallbacks.
* Preserves all original fields to maintain SDK compatibility.
*/
function extractNormalizedInterfaces(
card: Record<string, unknown>,
): AgentInterface[] {
const rawInterfaces =
getArray(card, 'additionalInterfaces') ||
getArray(card, 'supportedInterfaces');
if (!rawInterfaces) {
return [];
}
const mapped: AgentInterface[] = [];
for (const i of rawInterfaces) {
if (isObject(i)) {
// Use schema to validate interface object.
const parsed = AgentInterfaceSchema.parse(i);
// Narrowing to AgentInterface after runtime validation.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const normalized = parsed as unknown as AgentInterface & {
protocolBinding?: string;
};
// Normalize 'transport' from 'protocolBinding' if missing.
if (!normalized.transport && normalized.protocolBinding) {
normalized.transport = normalized.protocolBinding;
}
// Robust URL: Ensure the URL has a scheme (except for gRPC).
if (
normalized.url &&
!normalized.url.includes('://') &&
!normalized.url.startsWith('/') &&
normalized.transport !== 'GRPC'
) {
// Default to http:// for insecure REST/JSON-RPC if scheme is missing.
normalized.url = `http://${normalized.url}`;
}
mapped.push(normalized as AgentInterface);
}
}
return mapped;
}
/**
* Safely extracts an array property from an object.
*/
function getArray(
obj: Record<string, unknown>,
key: string,
): unknown[] | undefined {
const val = obj[key];
return Array.isArray(val) ? val : undefined;
}
// Type Guards
function isTextPart(part: Part): part is TextPart {
-1
View File
@@ -700,7 +700,6 @@ async function fetchAndCacheUserInfo(client: OAuth2Client): Promise<void> {
return;
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(
'https://www.googleapis.com/oauth2/v2/userinfo',
{
-2
View File
@@ -111,7 +111,6 @@ export class MCPOAuthProvider {
scope: config.scopes?.join(' ') || '',
};
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(registrationUrl, {
method: 'POST',
headers: {
@@ -301,7 +300,6 @@ export class MCPOAuthProvider {
? { Accept: 'text/event-stream' }
: { Accept: 'application/json' };
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(mcpServerUrl, {
method: 'HEAD',
headers,
-2
View File
@@ -97,7 +97,6 @@ export class OAuthUtils {
resourceMetadataUrl: string,
): Promise<OAuthProtectedResourceMetadata | null> {
try {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(resourceMetadataUrl);
if (!response.ok) {
return null;
@@ -122,7 +121,6 @@ export class OAuthUtils {
authServerMetadataUrl: string,
): Promise<OAuthAuthorizationServerMetadata | null> {
try {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(authServerMetadataUrl);
if (!response.ok) {
return null;
@@ -546,7 +546,6 @@ export class ClearcutLogger {
let result: LogResponse = {};
try {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(CLEARCUT_URL, {
method: 'POST',
body: safeJsonStringify(request),
-1
View File
@@ -1903,7 +1903,6 @@ export async function connectToMcpServer(
acceptHeader = 'application/json';
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(urlToFetch, {
method: 'HEAD',
headers: {
+1 -168
View File
@@ -5,27 +5,12 @@
*/
import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest';
import {
isPrivateIp,
isPrivateIpAsync,
isAddressPrivate,
safeLookup,
safeFetch,
fetchWithTimeout,
PrivateIpError,
} from './fetch.js';
import * as dnsPromises from 'node:dns/promises';
import * as dns from 'node:dns';
import { isPrivateIp, isAddressPrivate, fetchWithTimeout } from './fetch.js';
vi.mock('node:dns/promises', () => ({
lookup: vi.fn(),
}));
// We need to mock node:dns for safeLookup since it uses the callback API
vi.mock('node:dns', () => ({
lookup: vi.fn(),
}));
// Mock global fetch
const originalFetch = global.fetch;
global.fetch = vi.fn();
@@ -114,150 +99,6 @@ describe('fetch utils', () => {
});
});
describe('isPrivateIpAsync', () => {
it('should identify private IPs directly', async () => {
expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true);
});
it('should identify domains resolving to private IPs', async () => {
vi.mocked(dnsPromises.lookup).mockImplementation(
async () =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[{ address: '10.0.0.1', family: 4 }] as any,
);
expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true);
});
it('should identify domains resolving to public IPs as non-private', async () => {
vi.mocked(dnsPromises.lookup).mockImplementation(
async () =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[{ address: '8.8.8.8', family: 4 }] as any,
);
expect(await isPrivateIpAsync('http://google.com/')).toBe(false);
});
it('should throw error if DNS resolution fails (fail closed)', async () => {
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow(
'Failed to verify if URL resolves to private IP',
);
});
it('should return false for invalid URLs instead of throwing verification error', async () => {
expect(await isPrivateIpAsync('not-a-url')).toBe(false);
});
});
describe('safeLookup', () => {
it('should filter out private IPs', async () => {
const addresses = [
{ address: '8.8.8.8', family: 4 },
{ address: '10.0.0.1', family: 4 },
];
vi.mocked(dns.lookup).mockImplementation(((
_h: string,
_o: dns.LookupOptions,
cb: (
err: Error | null,
addr: Array<{ address: string; family: number }>,
) => void,
) => {
cb(null, addresses);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}) as any);
const result = await new Promise<
Array<{ address: string; family: number }>
>((resolve, reject) => {
safeLookup('example.com', { all: true }, (err, filtered) => {
if (err) reject(err);
else resolve(filtered);
});
});
expect(result).toHaveLength(1);
expect(result[0].address).toBe('8.8.8.8');
});
it('should allow explicit localhost', async () => {
const addresses = [{ address: '127.0.0.1', family: 4 }];
vi.mocked(dns.lookup).mockImplementation(((
_h: string,
_o: dns.LookupOptions,
cb: (
err: Error | null,
addr: Array<{ address: string; family: number }>,
) => void,
) => {
cb(null, addresses);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}) as any);
const result = await new Promise<
Array<{ address: string; family: number }>
>((resolve, reject) => {
safeLookup('localhost', { all: true }, (err, filtered) => {
if (err) reject(err);
else resolve(filtered);
});
});
expect(result).toHaveLength(1);
expect(result[0].address).toBe('127.0.0.1');
});
it('should error if all resolved IPs are private', async () => {
const addresses = [{ address: '10.0.0.1', family: 4 }];
vi.mocked(dns.lookup).mockImplementation(((
_h: string,
_o: dns.LookupOptions,
cb: (
err: Error | null,
addr: Array<{ address: string; family: number }>,
) => void,
) => {
cb(null, addresses);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}) as any);
await expect(
new Promise((resolve, reject) => {
safeLookup('malicious.com', { all: true }, (err, filtered) => {
if (err) reject(err);
else resolve(filtered);
});
}),
).rejects.toThrow(PrivateIpError);
});
});
describe('safeFetch', () => {
it('should forward to fetch with dispatcher', async () => {
vi.mocked(global.fetch).mockResolvedValue(new Response('ok'));
const response = await safeFetch('https://example.com');
expect(response.status).toBe(200);
expect(global.fetch).toHaveBeenCalledWith(
'https://example.com',
expect.objectContaining({
dispatcher: expect.any(Object),
}),
);
});
it('should handle Refusing to connect errors', async () => {
vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError());
await expect(safeFetch('http://10.0.0.1')).rejects.toThrow(
'Access to private network is blocked',
);
});
});
describe('fetchWithTimeout', () => {
it('should handle timeouts', async () => {
vi.mocked(global.fetch).mockImplementation(
@@ -279,13 +120,5 @@ describe('fetch utils', () => {
'Request timed out after 50ms',
);
});
it('should handle private IP errors via handleFetchError', async () => {
vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError());
await expect(fetchWithTimeout('http://10.0.0.1', 1000)).rejects.toThrow(
'Access to private network is blocked: http://10.0.0.1',
);
});
});
});
+13 -169
View File
@@ -6,37 +6,12 @@
import { getErrorMessage, isNodeError } from './errors.js';
import { URL } from 'node:url';
import * as dns from 'node:dns';
import { lookup } from 'node:dns/promises';
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
import ipaddr from 'ipaddr.js';
const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes
const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes
// Configure default global dispatcher with higher timeouts
setGlobalDispatcher(
new Agent({
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
bodyTimeout: DEFAULT_BODY_TIMEOUT,
}),
);
// Local extension of RequestInit to support Node.js/undici dispatcher
interface NodeFetchInit extends RequestInit {
dispatcher?: Agent | ProxyAgent;
}
/**
* Error thrown when a connection to a private IP address is blocked for security reasons.
*/
export class PrivateIpError extends Error {
constructor(message = 'Refusing to connect to private IP address') {
super(message);
this.name = 'PrivateIpError';
}
}
export class FetchError extends Error {
constructor(
message: string,
@@ -48,6 +23,14 @@ export class FetchError extends Error {
}
}
// Configure default global dispatcher with higher timeouts
setGlobalDispatcher(
new Agent({
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
bodyTimeout: DEFAULT_BODY_TIMEOUT,
}),
);
/**
* Sanitizes a hostname by stripping IPv6 brackets if present.
*/
@@ -69,53 +52,6 @@ export function isLoopbackHost(hostname: string): boolean {
);
}
/**
* A custom DNS lookup implementation for undici agents that prevents
* connection to private IP ranges (SSRF protection).
*/
export function safeLookup(
hostname: string,
options: dns.LookupOptions | number | null | undefined,
callback: (
err: Error | null,
addresses: Array<{ address: string; family: number }>,
) => void,
): void {
// Use the callback-based dns.lookup to match undici's expected signature.
// We explicitly handle the 'all' option to ensure we get an array of addresses.
const lookupOptions =
typeof options === 'number' ? { family: options } : { ...options };
const finalOptions = { ...lookupOptions, all: true };
dns.lookup(hostname, finalOptions, (err, addresses) => {
if (err) {
callback(err, []);
return;
}
const addressArray = Array.isArray(addresses) ? addresses : [];
const filtered = addressArray.filter(
(addr) => !isAddressPrivate(addr.address) || isLoopbackHost(hostname),
);
if (filtered.length === 0 && addressArray.length > 0) {
callback(new PrivateIpError(), []);
return;
}
callback(null, filtered);
});
}
// Dedicated dispatcher with connection-level SSRF protection (safeLookup)
const safeDispatcher = new Agent({
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
bodyTimeout: DEFAULT_BODY_TIMEOUT,
connect: {
lookup: safeLookup,
},
});
export function isPrivateIp(url: string): boolean {
try {
const hostname = new URL(url).hostname;
@@ -125,37 +61,6 @@ export function isPrivateIp(url: string): boolean {
}
}
/**
* Checks if a URL resolves to a private IP address.
* Performs DNS resolution to prevent DNS rebinding/SSRF bypasses.
*/
export async function isPrivateIpAsync(url: string): Promise<boolean> {
try {
const parsed = new URL(url);
const hostname = parsed.hostname;
// Fast check for literal IPs or localhost
if (isAddressPrivate(hostname)) {
return true;
}
// Resolve DNS to check the actual target IP
const addresses = await lookup(hostname, { all: true });
return addresses.some((addr) => isAddressPrivate(addr.address));
} catch (e) {
if (
e instanceof Error &&
e.name === 'TypeError' &&
e.message.includes('Invalid URL')
) {
return false;
}
throw new Error(`Failed to verify if URL resolves to private IP: ${url}`, {
cause: e,
});
}
}
/**
* IANA Benchmark Testing Range (198.18.0.0/15).
* Classified as 'unicast' by ipaddr.js but is reserved and should not be
@@ -210,72 +115,15 @@ export function isAddressPrivate(address: string): boolean {
}
}
/**
* Internal helper to map varied fetch errors to a standardized FetchError.
* Centralizes security-related error mapping (e.g. PrivateIpError).
*/
function handleFetchError(error: unknown, url: string): never {
if (error instanceof PrivateIpError) {
throw new FetchError(
`Access to private network is blocked: ${url}`,
'ERR_PRIVATE_NETWORK',
{ cause: error },
);
}
if (error instanceof FetchError) {
throw error;
}
throw new FetchError(
getErrorMessage(error),
isNodeError(error) ? error.code : undefined,
{ cause: error },
);
}
/**
* Enhanced fetch with SSRF protection.
* Prevents access to private/internal networks at the connection level.
*/
export async function safeFetch(
input: RequestInfo | URL,
init?: RequestInit,
): Promise<Response> {
const nodeInit: NodeFetchInit = {
...init,
dispatcher: safeDispatcher,
};
try {
// eslint-disable-next-line no-restricted-syntax
return await fetch(input, nodeInit);
} catch (error) {
const url =
input instanceof Request
? input.url
: typeof input === 'string'
? input
: input.toString();
handleFetchError(error, url);
}
}
/**
* Creates an undici ProxyAgent that incorporates safe DNS lookup.
*/
export function createSafeProxyAgent(proxyUrl: string): ProxyAgent {
return new ProxyAgent({
uri: proxyUrl,
connect: {
lookup: safeLookup,
},
});
}
/**
* Performs a fetch with a specified timeout and connection-level SSRF protection.
*/
export async function fetchWithTimeout(
url: string,
timeout: number,
@@ -294,21 +142,17 @@ export async function fetchWithTimeout(
}
}
const nodeInit: NodeFetchInit = {
...options,
signal: controller.signal,
dispatcher: safeDispatcher,
};
try {
// eslint-disable-next-line no-restricted-syntax
const response = await fetch(url, nodeInit);
const response = await fetch(url, {
...options,
signal: controller.signal,
});
return response;
} catch (error) {
if (isNodeError(error) && error.code === 'ABORT_ERR') {
throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT');
}
handleFetchError(error, url.toString());
throw new FetchError(getErrorMessage(error), undefined, { cause: error });
} finally {
clearTimeout(timeoutId);
}
-2
View File
@@ -454,7 +454,6 @@ export async function exchangeCodeForToken(
params.append('resource', resource);
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(config.tokenUrl, {
method: 'POST',
headers: {
@@ -508,7 +507,6 @@ export async function refreshAccessToken(
params.append('resource', resource);
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(tokenUrl, {
method: 'POST',
headers: {
@@ -42,7 +42,6 @@ async function checkForUpdates(
const currentVersion = context.extension.packageJSON.version;
// Fetch extension details from the VSCode Marketplace.
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(
'https://marketplace.visualstudio.com/_apis/public/gallery/extensionquery',
{
@@ -356,7 +356,6 @@ describe('IDEServer', () => {
});
it('should reject request without auth token', async () => {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(`http://localhost:${port}/mcp`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
@@ -371,7 +370,6 @@ describe('IDEServer', () => {
});
it('should allow request with valid auth token', async () => {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(`http://localhost:${port}/mcp`, {
method: 'POST',
headers: {
@@ -389,7 +387,6 @@ describe('IDEServer', () => {
});
it('should reject request with invalid auth token', async () => {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(`http://localhost:${port}/mcp`, {
method: 'POST',
headers: {
@@ -416,7 +413,6 @@ describe('IDEServer', () => {
];
for (const header of malformedHeaders) {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(`http://localhost:${port}/mcp`, {
method: 'POST',
headers: {
+41 -9
View File
@@ -228,23 +228,35 @@ const packageJson = JSON.parse(
// Helper to calc hash
const sha256 = (content) => createHash('sha256').update(content).digest('hex');
// Read Main Bundle
const geminiBundlePath = join(root, 'bundle/gemini.js');
const geminiContent = readFileSync(geminiBundlePath);
const geminiHash = sha256(geminiContent);
const assets = {
'gemini.mjs': geminiBundlePath, // Use .js source but map to .mjs for runtime ESM
'manifest.json': 'bundle/manifest.json',
};
const manifest = {
main: 'gemini.mjs',
mainHash: geminiHash,
mainHash: '',
version: packageJson.version,
files: [],
};
// Add all javascript chunks from the bundle directory
const jsFiles = globSync('*.js', { cwd: bundleDir });
for (const jsFile of jsFiles) {
const fsPath = join(bundleDir, jsFile);
const content = readFileSync(fsPath);
const hash = sha256(content);
// Node SEA requires the main entry point to be explicitly mapped
if (jsFile === 'gemini.js') {
assets['gemini.mjs'] = fsPath;
manifest.mainHash = hash;
} else {
// Other chunks need to be mapped exactly as they are named so dynamic imports find them
assets[jsFile] = fsPath;
manifest.files.push({ key: jsFile, path: jsFile, hash: hash });
}
}
// Helper to recursively find files from STAGING
function addAssetsFromDir(baseDir, runtimePrefix) {
const fullDir = join(stagingDir, baseDir);
@@ -346,6 +358,22 @@ const targetBinaryPath = join(targetDir, binaryName);
console.log(`Copying node binary from ${nodeBinary} to ${targetBinaryPath}...`);
copyFileSync(nodeBinary, targetBinaryPath);
if (platform === 'darwin') {
console.log(`Thinning universal binary for ${arch}...`);
try {
// Attempt to thin the binary. Will fail safely if it's not a fat binary.
runCommand('lipo', [
targetBinaryPath,
'-thin',
arch,
'-output',
targetBinaryPath,
]);
} catch (e) {
console.log(`Skipping lipo thinning: ${e.message}`);
}
}
// Remove existing signature using helper
removeSignature(targetBinaryPath);
@@ -357,9 +385,7 @@ if (existsSync(bundleDir)) {
// Clean up source JS files from output (we only want embedded)
const filesToRemove = [
'gemini.js',
'gemini.mjs',
'gemini.js.map',
'gemini.mjs.map',
'gemini-sea.cjs',
'sea-launch.cjs',
@@ -373,6 +399,12 @@ filesToRemove.forEach((f) => {
if (existsSync(p)) rmSync(p, { recursive: true, force: true });
});
// Remove all chunk and entry .js/.js.map files
const jsFilesToRemove = globSync('*.{js,js.map}', { cwd: targetDir });
for (const f of jsFilesToRemove) {
rmSync(join(targetDir, f));
}
// Remove .sb files from targetDir
const sbFilesToRemove = globSync('sandbox-macos-*.sb', { cwd: targetDir });
for (const f of sbFilesToRemove) {