mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-13 12:57:12 -07:00
merge main
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# Preview release: v0.34.0-preview.0
|
||||
# Preview release: v0.34.0-preview.1
|
||||
|
||||
Released: March 11, 2026
|
||||
Released: March 12, 2026
|
||||
|
||||
Our preview release includes the latest, new, and experimental features. This
|
||||
release may not be as stable as our [latest weekly release](latest.md).
|
||||
@@ -28,6 +28,9 @@ npm install -g @google/gemini-cli@preview
|
||||
|
||||
## What's Changed
|
||||
|
||||
- fix(patch): cherry-pick 45faf4d to release/v0.34.0-preview.0-pr-22148
|
||||
[CONFLICTS] by @gemini-cli-robot in
|
||||
[#22174](https://github.com/google-gemini/gemini-cli/pull/22174)
|
||||
- feat(cli): add chat resume footer on session quit by @lordshashank in
|
||||
[#20667](https://github.com/google-gemini/gemini-cli/pull/20667)
|
||||
- Support bold and other styles in svg snapshots by @jacob314 in
|
||||
@@ -465,4 +468,4 @@ npm install -g @google/gemini-cli@preview
|
||||
[#21938](https://github.com/google-gemini/gemini-cli/pull/21938)
|
||||
|
||||
**Full Changelog**:
|
||||
https://github.com/google-gemini/gemini-cli/compare/v0.33.0-preview.15...v0.34.0-preview.0
|
||||
https://github.com/google-gemini/gemini-cli/compare/v0.33.0-preview.15...v0.34.0-preview.1
|
||||
|
||||
@@ -26,6 +26,20 @@ policies.
|
||||
the CLI will use an available fallback model for the current turn or the
|
||||
remainder of the session.
|
||||
|
||||
### Local Model Routing (Experimental)
|
||||
|
||||
Gemini CLI supports using a local model for routing decisions. When configured,
|
||||
Gemini CLI will use a locally-running **Gemma** model to make routing decisions
|
||||
(instead of sending routing decisions to a hosted model). This feature can help
|
||||
reduce costs associated with hosted model usage while offering similar routing
|
||||
decision latency and quality.
|
||||
|
||||
In order to use this feature, the local Gemma model **must** be served behind a
|
||||
Gemini API and accessible via HTTP at an endpoint configured in `settings.json`.
|
||||
|
||||
For more details on how to configure local model routing, see
|
||||
[Local Model Routing](../core/local-model-routing.md).
|
||||
|
||||
### Model selection precedence
|
||||
|
||||
The model used by Gemini CLI is determined by the following order of precedence:
|
||||
@@ -38,5 +52,8 @@ The model used by Gemini CLI is determined by the following order of precedence:
|
||||
3. **`model.name` in `settings.json`:** If neither of the above are set, the
|
||||
model specified in the `model.name` property of your `settings.json` file
|
||||
will be used.
|
||||
4. **Default model:** If none of the above are set, the default model will be
|
||||
4. **Local model (experimental):** If the Gemma local model router is enabled
|
||||
in your `settings.json` file, the CLI will use the local Gemma model
|
||||
(instead of Gemini models) to route the request to an appropriate model.
|
||||
5. **Default model:** If none of the above are set, the default model will be
|
||||
used. The default model is `auto`
|
||||
|
||||
@@ -15,6 +15,8 @@ requests sent from `packages/cli`. For a general overview of Gemini CLI, see the
|
||||
modular GEMINI.md import feature using @file.md syntax.
|
||||
- **[Policy Engine](../reference/policy-engine.md):** Use the Policy Engine for
|
||||
fine-grained control over tool execution.
|
||||
- **[Local Model Routing (experimental)](./local-model-routing.md):** Learn how
|
||||
to enable use of a local Gemma model for model routing decisions.
|
||||
|
||||
## Role of the core
|
||||
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
# Local Model Routing (experimental)
|
||||
|
||||
Gemini CLI supports using a local model for
|
||||
[routing decisions](../cli/model-routing.md). When configured, Gemini CLI will
|
||||
use a locally-running **Gemma** model to make routing decisions (instead of
|
||||
sending routing decisions to a hosted model).
|
||||
|
||||
This feature can help reduce costs associated with hosted model usage while
|
||||
offering similar routing decision latency and quality.
|
||||
|
||||
> **Note: Local model routing is currently an experimental feature.**
|
||||
|
||||
## Setup
|
||||
|
||||
Using a Gemma model for routing decisions requires that an implementation of a
|
||||
Gemma model be running locally on your machine, served behind an HTTP endpoint
|
||||
and accessed via the Gemini API.
|
||||
|
||||
To serve the Gemma model, follow these steps:
|
||||
|
||||
### Download the LiteRT-LM runtime
|
||||
|
||||
The [LiteRT-LM](https://github.com/google-ai-edge/LiteRT-LM) runtime offers
|
||||
pre-built binaries for locally-serving models. Download the binary appropriate
|
||||
for your system.
|
||||
|
||||
#### Windows
|
||||
|
||||
1. Download
|
||||
[lit.windows_x86_64.exe](https://github.com/google-ai-edge/LiteRT-LM/releases/download/v0.9.0-alpha03/lit.windows_x86_64.exe).
|
||||
2. Using GPU on Windows requires the DirectXShaderCompiler. Download the
|
||||
[dxc zip from the latest release](https://github.com/microsoft/DirectXShaderCompiler/releases/download/v1.8.2505.1/dxc_2025_07_14.zip).
|
||||
Unzip the archive and from the architecture-appropriate `bin\` directory, and
|
||||
copy the `dxil.dll` and `dxcompiler.dll` into the same location as you saved
|
||||
`lit.windows_x86_64.exe`.
|
||||
3. (Optional) Test starting the runtime:
|
||||
`.\lit.windows_x86_64.exe serve --verbose`
|
||||
|
||||
#### Linux
|
||||
|
||||
1. Download
|
||||
[lit.linux_x86_64](https://github.com/google-ai-edge/LiteRT-LM/releases/download/v0.9.0-alpha03/lit.linux_x86_64).
|
||||
2. Ensure the binary is executable: `chmod a+x lit.linux_x86_64`
|
||||
3. (Optional) Test starting the runtime: `./lit.linux_x86_64 serve --verbose`
|
||||
|
||||
#### MacOS
|
||||
|
||||
1. Download
|
||||
[lit-macos-arm64](https://github.com/google-ai-edge/LiteRT-LM/releases/download/v0.9.0-alpha03/lit.macos_arm64).
|
||||
2. Ensure the binary is executable: `chmod a+x lit.macos_arm64`
|
||||
3. (Optional) Test starting the runtime: `./lit.macos_arm64 serve --verbose`
|
||||
|
||||
> **Note**: MacOS can be configured to only allows binaries from "App Store &
|
||||
> Known Developers". If you encounter an error message when attempting to run
|
||||
> the binary, you will need to allow the application. One option is to visit
|
||||
> `System Settings -> Privacy & Security`, scroll to `Security`, and click
|
||||
> `"Allow Anyway"` for `"lit.macos_arm64"`. Another option is to run
|
||||
> `xattr -d com.apple.quarantine lit.macos_arm64` from the commandline.
|
||||
|
||||
### Download the Gemma Model
|
||||
|
||||
Before using Gemma, you will need to download the model (and agree to the Terms
|
||||
of Service).
|
||||
|
||||
This can be done via the LiteRT-LM runtime.
|
||||
|
||||
#### Windows
|
||||
|
||||
```bash
|
||||
$ .\lit.windows_x86_64.exe pull gemma3-1b-gpu-custom
|
||||
|
||||
[Legal] The model you are about to download is governed by
|
||||
the Gemma Terms of Use and Prohibited Use Policy. Please review these terms and ensure you agree before continuing.
|
||||
|
||||
Full Terms: https://ai.google.dev/gemma/terms
|
||||
Prohibited Use Policy: https://ai.google.dev/gemma/prohibited_use_policy
|
||||
|
||||
Do you accept these terms? (Y/N): Y
|
||||
|
||||
Terms accepted.
|
||||
Downloading model 'gemma3-1b-gpu-custom' ...
|
||||
Downloading... 968.6 MB
|
||||
Download complete.
|
||||
```
|
||||
|
||||
#### Linux
|
||||
|
||||
```bash
|
||||
$ ./lit.linux_x86_64 pull gemma3-1b-gpu-custom
|
||||
|
||||
[Legal] The model you are about to download is governed by
|
||||
the Gemma Terms of Use and Prohibited Use Policy. Please review these terms and ensure you agree before continuing.
|
||||
|
||||
Full Terms: https://ai.google.dev/gemma/terms
|
||||
Prohibited Use Policy: https://ai.google.dev/gemma/prohibited_use_policy
|
||||
|
||||
Do you accept these terms? (Y/N): Y
|
||||
|
||||
Terms accepted.
|
||||
Downloading model 'gemma3-1b-gpu-custom' ...
|
||||
Downloading... 968.6 MB
|
||||
Download complete.
|
||||
```
|
||||
|
||||
#### MacOS
|
||||
|
||||
```bash
|
||||
$ ./lit.lit.macos_arm64 pull gemma3-1b-gpu-custom
|
||||
|
||||
[Legal] The model you are about to download is governed by
|
||||
the Gemma Terms of Use and Prohibited Use Policy. Please review these terms and ensure you agree before continuing.
|
||||
|
||||
Full Terms: https://ai.google.dev/gemma/terms
|
||||
Prohibited Use Policy: https://ai.google.dev/gemma/prohibited_use_policy
|
||||
|
||||
Do you accept these terms? (Y/N): Y
|
||||
|
||||
Terms accepted.
|
||||
Downloading model 'gemma3-1b-gpu-custom' ...
|
||||
Downloading... 968.6 MB
|
||||
Download complete.
|
||||
```
|
||||
|
||||
### Start LiteRT-LM Runtime
|
||||
|
||||
Using the command appropriate to your system, start the LiteRT-LM runtime.
|
||||
Configure the port that you want to use for your Gemma model. For the purposes
|
||||
of this document, we will use port `9379`.
|
||||
|
||||
Example command for MacOS: `./lit.macos_arm64 serve --port=9379 --verbose`
|
||||
|
||||
### (Optional) Verify Model Serving
|
||||
|
||||
Send a quick prompt to the model via HTTP to validate successful model serving.
|
||||
This will cause the runtime to download the model and run it once.
|
||||
|
||||
You should see a short joke in the server output as an indicator of success.
|
||||
|
||||
#### Windows
|
||||
|
||||
```
|
||||
# Run this in PowerShell to send a request to the server
|
||||
|
||||
$uri = "http://localhost:9379/v1beta/models/gemma3-1b-gpu-custom:generateContent"
|
||||
$body = @{contents = @( @{
|
||||
role = "user"
|
||||
parts = @( @{ text = "Tell me a joke." } )
|
||||
})} | ConvertTo-Json -Depth 10
|
||||
|
||||
Invoke-RestMethod -Uri $uri -Method Post -Body $body -ContentType "application/json"
|
||||
```
|
||||
|
||||
#### Linux/MacOS
|
||||
|
||||
```bash
|
||||
$ curl "http://localhost:9379/v1beta/models/gemma3-1b-gpu-custom:generateContent" \
|
||||
-H 'Content-Type: application/json' \
|
||||
-X POST \
|
||||
-d '{"contents":[{"role":"user","parts":[{"text":"Tell me a joke."}]}]}'
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
To use a local Gemma model for routing, you must explicitly enable it in your
|
||||
`settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"experimental": {
|
||||
"gemmaModelRouter": {
|
||||
"enabled": true,
|
||||
"classifier": {
|
||||
"host": "http://localhost:9379",
|
||||
"model": "gemma3-1b-gpu-custom"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> Use the port you started your LiteRT-LM runtime on in the setup steps.
|
||||
|
||||
### Configuration schema
|
||||
|
||||
| Field | Type | Required | Description |
|
||||
| :----------------- | :------ | :------- | :----------------------------------------------------------------------------------------- |
|
||||
| `enabled` | boolean | Yes | Must be `true` to enable the feature. |
|
||||
| `classifier` | object | Yes | The configuration for the local model endpoint. It includes the host and model specifiers. |
|
||||
| `classifier.host` | string | Yes | The URL to the local model server. Should be `http://localhost:<port>`. |
|
||||
| `classifier.model` | string | Yes | The model name to use for decisions. Must be `"gemma3-1b-gpu-custom"`. |
|
||||
|
||||
> **Note: You will need to restart after configuration changes for local model
|
||||
> routing to take effect.**
|
||||
+9
-4
@@ -82,11 +82,14 @@ const commonAliases = {
|
||||
const cliConfig = {
|
||||
...baseConfig,
|
||||
banner: {
|
||||
js: `const require = (await import('node:module')).createRequire(import.meta.url); globalThis.__filename = (await import('node:url')).fileURLToPath(import.meta.url); globalThis.__dirname = (await import('node:path')).dirname(globalThis.__filename);`,
|
||||
js: `const require = (await import('node:module')).createRequire(import.meta.url); const __chunk_filename = (await import('node:url')).fileURLToPath(import.meta.url); const __chunk_dirname = (await import('node:path')).dirname(__chunk_filename);`,
|
||||
},
|
||||
entryPoints: ['packages/cli/index.ts'],
|
||||
outfile: 'bundle/gemini.js',
|
||||
entryPoints: { gemini: 'packages/cli/index.ts' },
|
||||
outdir: 'bundle',
|
||||
splitting: true,
|
||||
define: {
|
||||
__filename: '__chunk_filename',
|
||||
__dirname: '__chunk_dirname',
|
||||
'process.env.CLI_VERSION': JSON.stringify(pkg.version),
|
||||
'process.env.GEMINI_SANDBOX_IMAGE_DEFAULT': JSON.stringify(
|
||||
pkg.config?.sandboxImageUri,
|
||||
@@ -103,11 +106,13 @@ const cliConfig = {
|
||||
const a2aServerConfig = {
|
||||
...baseConfig,
|
||||
banner: {
|
||||
js: `const require = (await import('node:module')).createRequire(import.meta.url); globalThis.__filename = (await import('node:url')).fileURLToPath(import.meta.url); globalThis.__dirname = (await import('node:path')).dirname(globalThis.__filename);`,
|
||||
js: `const require = (await import('node:module')).createRequire(import.meta.url); const __chunk_filename = (await import('node:url')).fileURLToPath(import.meta.url); const __chunk_dirname = (await import('node:path')).dirname(__chunk_filename);`,
|
||||
},
|
||||
entryPoints: ['packages/a2a-server/src/http/server.ts'],
|
||||
outfile: 'packages/a2a-server/dist/a2a-server.mjs',
|
||||
define: {
|
||||
__filename: '__chunk_filename',
|
||||
__dirname: '__chunk_dirname',
|
||||
'process.env.CLI_VERSION': JSON.stringify(pkg.version),
|
||||
},
|
||||
plugins: createWasmPlugins(),
|
||||
|
||||
@@ -35,11 +35,6 @@ const commonRestrictedSyntaxRules = [
|
||||
message:
|
||||
'Do not throw string literals or non-Error objects. Throw new Error("...") instead.',
|
||||
},
|
||||
{
|
||||
selector: 'CallExpression[callee.name="fetch"]',
|
||||
message:
|
||||
'Use safeFetch() from "@/utils/fetch" instead of the global fetch() to ensure SSRF protection. If you are implementing a custom security layer, use an eslint-disable comment and explain why.',
|
||||
},
|
||||
];
|
||||
|
||||
export default tseslint.config(
|
||||
|
||||
+38
-217
@@ -4,13 +4,38 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { render } from 'ink';
|
||||
import { AppContainer } from './ui/AppContainer.js';
|
||||
import {
|
||||
type StartupWarning,
|
||||
WarningPriority,
|
||||
type Config,
|
||||
type ResumedSessionData,
|
||||
type OutputPayload,
|
||||
type ConsoleLogPayload,
|
||||
type UserFeedbackPayload,
|
||||
sessionId,
|
||||
logUserPrompt,
|
||||
AuthType,
|
||||
UserPromptEvent,
|
||||
coreEvents,
|
||||
CoreEvent,
|
||||
getOauthClient,
|
||||
patchStdio,
|
||||
writeToStdout,
|
||||
writeToStderr,
|
||||
shouldEnterAlternateScreen,
|
||||
startupProfiler,
|
||||
ExitCodes,
|
||||
SessionStartSource,
|
||||
SessionEndReason,
|
||||
ValidationCancelledError,
|
||||
ValidationRequiredError,
|
||||
type AdminControlsSettings,
|
||||
debugLogger,
|
||||
} from '@google/gemini-cli-core';
|
||||
|
||||
import { loadCliConfig, parseArguments } from './config/config.js';
|
||||
import * as cliConfig from './config/config.js';
|
||||
import { readStdin } from './utils/readStdin.js';
|
||||
import { basename } from 'node:path';
|
||||
import { createHash } from 'node:crypto';
|
||||
import v8 from 'node:v8';
|
||||
import os from 'node:os';
|
||||
@@ -37,47 +62,11 @@ import {
|
||||
runExitCleanup,
|
||||
registerTelemetryConfig,
|
||||
setupSignalHandlers,
|
||||
setupTtyCheck,
|
||||
} from './utils/cleanup.js';
|
||||
import {
|
||||
cleanupToolOutputFiles,
|
||||
cleanupExpiredSessions,
|
||||
} from './utils/sessionCleanup.js';
|
||||
import {
|
||||
type StartupWarning,
|
||||
WarningPriority,
|
||||
type Config,
|
||||
type ResumedSessionData,
|
||||
type OutputPayload,
|
||||
type ConsoleLogPayload,
|
||||
type UserFeedbackPayload,
|
||||
sessionId,
|
||||
logUserPrompt,
|
||||
AuthType,
|
||||
getOauthClient,
|
||||
UserPromptEvent,
|
||||
debugLogger,
|
||||
recordSlowRender,
|
||||
coreEvents,
|
||||
CoreEvent,
|
||||
createWorkingStdio,
|
||||
patchStdio,
|
||||
writeToStdout,
|
||||
writeToStderr,
|
||||
disableMouseEvents,
|
||||
enableMouseEvents,
|
||||
disableLineWrapping,
|
||||
enableLineWrapping,
|
||||
shouldEnterAlternateScreen,
|
||||
startupProfiler,
|
||||
ExitCodes,
|
||||
SessionStartSource,
|
||||
SessionEndReason,
|
||||
getVersion,
|
||||
ValidationCancelledError,
|
||||
ValidationRequiredError,
|
||||
type AdminControlsSettings,
|
||||
} from '@google/gemini-cli-core';
|
||||
import {
|
||||
initializeApp,
|
||||
type InitializationResult,
|
||||
@@ -85,21 +74,9 @@ import {
|
||||
import { validateAuthMethod } from './config/auth.js';
|
||||
import { runAcpClient } from './acp/acpClient.js';
|
||||
import { validateNonInteractiveAuth } from './validateNonInterActiveAuth.js';
|
||||
import { checkForUpdates } from './ui/utils/updateCheck.js';
|
||||
import { handleAutoUpdate } from './utils/handleAutoUpdate.js';
|
||||
import { appEvents, AppEvent } from './utils/events.js';
|
||||
import { SessionError, SessionSelector } from './utils/sessionUtils.js';
|
||||
import { SettingsContext } from './ui/contexts/SettingsContext.js';
|
||||
import { MouseProvider } from './ui/contexts/MouseContext.js';
|
||||
import { StreamingState } from './ui/types.js';
|
||||
import { computeTerminalTitle } from './utils/windowTitle.js';
|
||||
|
||||
import { SessionStatsProvider } from './ui/contexts/SessionContext.js';
|
||||
import { VimModeProvider } from './ui/contexts/VimModeContext.js';
|
||||
import { KeyMatchersProvider } from './ui/hooks/useKeyMatchers.js';
|
||||
import { loadKeyMatchers } from './ui/key/keyMatchers.js';
|
||||
import { KeypressProvider } from './ui/contexts/KeypressContext.js';
|
||||
import { useKittyKeyboardProtocol } from './ui/hooks/useKittyKeyboardProtocol.js';
|
||||
import {
|
||||
relaunchAppInChildProcess,
|
||||
relaunchOnExitCode,
|
||||
@@ -107,19 +84,13 @@ import {
|
||||
import { loadSandboxConfig } from './config/sandboxConfig.js';
|
||||
import { deleteSession, listSessions } from './utils/sessions.js';
|
||||
import { createPolicyUpdater } from './config/policy.js';
|
||||
import { ScrollProvider } from './ui/contexts/ScrollProvider.js';
|
||||
import { TerminalProvider } from './ui/contexts/TerminalContext.js';
|
||||
import { isAlternateBufferEnabled } from './ui/hooks/useAlternateBuffer.js';
|
||||
import { OverflowProvider } from './ui/contexts/OverflowContext.js';
|
||||
|
||||
import { setupTerminalAndTheme } from './utils/terminalTheme.js';
|
||||
import { profiler } from './ui/components/DebugProfiler.js';
|
||||
import { runDeferredCommand } from './deferred.js';
|
||||
import { cleanupBackgroundLogs } from './utils/logCleanup.js';
|
||||
import { SlashCommandConflictHandler } from './services/SlashCommandConflictHandler.js';
|
||||
|
||||
const SLOW_RENDER_MS = 200;
|
||||
|
||||
export function validateDnsResolutionOrder(
|
||||
order: string | undefined,
|
||||
): DnsResolutionOrder {
|
||||
@@ -198,147 +169,16 @@ export async function startInteractiveUI(
|
||||
resumedSessionData: ResumedSessionData | undefined,
|
||||
initializationResult: InitializationResult,
|
||||
) {
|
||||
// Never enter Ink alternate buffer mode when screen reader mode is enabled
|
||||
// as there is no benefit of alternate buffer mode when using a screen reader
|
||||
// and the Ink alternate buffer mode requires line wrapping harmful to
|
||||
// screen readers.
|
||||
const useAlternateBuffer = shouldEnterAlternateScreen(
|
||||
isAlternateBufferEnabled(config),
|
||||
config.getScreenReader(),
|
||||
// Dynamically import the heavy UI module so React/Ink are only parsed when needed
|
||||
const { startInteractiveUI: doStartUI } = await import('./interactiveCli.js');
|
||||
await doStartUI(
|
||||
config,
|
||||
settings,
|
||||
startupWarnings,
|
||||
workspaceRoot,
|
||||
resumedSessionData,
|
||||
initializationResult,
|
||||
);
|
||||
const mouseEventsEnabled = useAlternateBuffer;
|
||||
if (mouseEventsEnabled) {
|
||||
enableMouseEvents();
|
||||
registerCleanup(() => {
|
||||
disableMouseEvents();
|
||||
});
|
||||
}
|
||||
|
||||
const { matchers, errors } = await loadKeyMatchers();
|
||||
errors.forEach((error) => {
|
||||
coreEvents.emitFeedback('warning', error);
|
||||
});
|
||||
|
||||
const version = await getVersion();
|
||||
setWindowTitle(basename(workspaceRoot), settings);
|
||||
|
||||
const consolePatcher = new ConsolePatcher({
|
||||
onNewMessage: (msg) => {
|
||||
coreEvents.emitConsoleLog(msg.type, msg.content);
|
||||
},
|
||||
debugMode: config.getDebugMode(),
|
||||
});
|
||||
consolePatcher.patch();
|
||||
registerCleanup(consolePatcher.cleanup);
|
||||
|
||||
const { stdout: inkStdout, stderr: inkStderr } = createWorkingStdio();
|
||||
|
||||
const isShpool = !!process.env['SHPOOL_SESSION_NAME'];
|
||||
|
||||
// Create wrapper component to use hooks inside render
|
||||
const AppWrapper = () => {
|
||||
useKittyKeyboardProtocol();
|
||||
|
||||
return (
|
||||
<SettingsContext.Provider value={settings}>
|
||||
<KeyMatchersProvider value={matchers}>
|
||||
<KeypressProvider
|
||||
config={config}
|
||||
debugKeystrokeLogging={
|
||||
settings.merged.general.debugKeystrokeLogging
|
||||
}
|
||||
>
|
||||
<MouseProvider
|
||||
mouseEventsEnabled={mouseEventsEnabled}
|
||||
debugKeystrokeLogging={
|
||||
settings.merged.general.debugKeystrokeLogging
|
||||
}
|
||||
>
|
||||
<TerminalProvider>
|
||||
<ScrollProvider>
|
||||
<OverflowProvider>
|
||||
<SessionStatsProvider>
|
||||
<VimModeProvider>
|
||||
<AppContainer
|
||||
config={config}
|
||||
startupWarnings={startupWarnings}
|
||||
version={version}
|
||||
resumedSessionData={resumedSessionData}
|
||||
initializationResult={initializationResult}
|
||||
/>
|
||||
</VimModeProvider>
|
||||
</SessionStatsProvider>
|
||||
</OverflowProvider>
|
||||
</ScrollProvider>
|
||||
</TerminalProvider>
|
||||
</MouseProvider>
|
||||
</KeypressProvider>
|
||||
</KeyMatchersProvider>
|
||||
</SettingsContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
if (isShpool) {
|
||||
// Wait a moment for shpool to stabilize terminal size and state.
|
||||
// shpool is a persistence tool that restores terminal state by replaying it.
|
||||
// This delay gives shpool time to finish its restoration replay and send
|
||||
// the actual terminal size (often via an immediate SIGWINCH) before we
|
||||
// render the first TUI frame. Without this, the first frame may be
|
||||
// garbled or rendered at an incorrect size, which disabling incremental
|
||||
// rendering alone cannot fix for the initial frame.
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
}
|
||||
|
||||
const instance = render(
|
||||
process.env['DEBUG'] ? (
|
||||
<React.StrictMode>
|
||||
<AppWrapper />
|
||||
</React.StrictMode>
|
||||
) : (
|
||||
<AppWrapper />
|
||||
),
|
||||
{
|
||||
stdout: inkStdout,
|
||||
stderr: inkStderr,
|
||||
stdin: process.stdin,
|
||||
exitOnCtrlC: false,
|
||||
isScreenReaderEnabled: config.getScreenReader(),
|
||||
onRender: ({ renderTime }: { renderTime: number }) => {
|
||||
if (renderTime > SLOW_RENDER_MS) {
|
||||
recordSlowRender(config, renderTime);
|
||||
}
|
||||
profiler.reportFrameRendered();
|
||||
},
|
||||
patchConsole: false,
|
||||
alternateBuffer: useAlternateBuffer,
|
||||
incrementalRendering:
|
||||
settings.merged.ui.incrementalRendering !== false &&
|
||||
useAlternateBuffer &&
|
||||
!isShpool,
|
||||
},
|
||||
);
|
||||
|
||||
if (useAlternateBuffer) {
|
||||
disableLineWrapping();
|
||||
registerCleanup(() => {
|
||||
enableLineWrapping();
|
||||
});
|
||||
}
|
||||
|
||||
checkForUpdates(settings)
|
||||
.then((info) => {
|
||||
handleAutoUpdate(info, settings, config.getProjectRoot());
|
||||
})
|
||||
.catch((err) => {
|
||||
// Silently ignore update check errors.
|
||||
if (config.getDebugMode()) {
|
||||
debugLogger.warn('Update check failed:', err);
|
||||
}
|
||||
});
|
||||
|
||||
registerCleanup(() => instance.unmount());
|
||||
|
||||
registerCleanup(setupTtyCheck());
|
||||
}
|
||||
|
||||
export async function main() {
|
||||
@@ -845,25 +685,6 @@ export async function main() {
|
||||
}
|
||||
}
|
||||
|
||||
function setWindowTitle(title: string, settings: LoadedSettings) {
|
||||
if (!settings.merged.ui.hideWindowTitle) {
|
||||
// Initial state before React loop starts
|
||||
const windowTitle = computeTerminalTitle({
|
||||
streamingState: StreamingState.Idle,
|
||||
isConfirming: false,
|
||||
isSilentWorking: false,
|
||||
folderName: title,
|
||||
showThoughts: !!settings.merged.ui.showStatusInTitle,
|
||||
useDynamicTitle: settings.merged.ui.dynamicWindowTitle,
|
||||
});
|
||||
writeToStdout(`\x1b]0;${windowTitle}\x07`);
|
||||
|
||||
process.on('exit', () => {
|
||||
writeToStdout(`\x1b]0;\x07`);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export function initializeOutputListenersAndFlush() {
|
||||
// If there are no listeners for output, make sure we flush so output is not
|
||||
// lost.
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { render } from 'ink';
|
||||
import { basename } from 'node:path';
|
||||
import { AppContainer } from './ui/AppContainer.js';
|
||||
import { ConsolePatcher } from './ui/utils/ConsolePatcher.js';
|
||||
import { registerCleanup, setupTtyCheck } from './utils/cleanup.js';
|
||||
import {
|
||||
type StartupWarning,
|
||||
type Config,
|
||||
type ResumedSessionData,
|
||||
coreEvents,
|
||||
createWorkingStdio,
|
||||
disableMouseEvents,
|
||||
enableMouseEvents,
|
||||
disableLineWrapping,
|
||||
enableLineWrapping,
|
||||
shouldEnterAlternateScreen,
|
||||
recordSlowRender,
|
||||
writeToStdout,
|
||||
getVersion,
|
||||
debugLogger,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { InitializationResult } from './core/initializer.js';
|
||||
import type { LoadedSettings } from './config/settings.js';
|
||||
import { checkForUpdates } from './ui/utils/updateCheck.js';
|
||||
import { handleAutoUpdate } from './utils/handleAutoUpdate.js';
|
||||
import { SettingsContext } from './ui/contexts/SettingsContext.js';
|
||||
import { MouseProvider } from './ui/contexts/MouseContext.js';
|
||||
import { StreamingState } from './ui/types.js';
|
||||
import { computeTerminalTitle } from './utils/windowTitle.js';
|
||||
|
||||
import { SessionStatsProvider } from './ui/contexts/SessionContext.js';
|
||||
import { VimModeProvider } from './ui/contexts/VimModeContext.js';
|
||||
import { KeyMatchersProvider } from './ui/hooks/useKeyMatchers.js';
|
||||
import { loadKeyMatchers } from './ui/key/keyMatchers.js';
|
||||
import { KeypressProvider } from './ui/contexts/KeypressContext.js';
|
||||
import { useKittyKeyboardProtocol } from './ui/hooks/useKittyKeyboardProtocol.js';
|
||||
import { ScrollProvider } from './ui/contexts/ScrollProvider.js';
|
||||
import { TerminalProvider } from './ui/contexts/TerminalContext.js';
|
||||
import { isAlternateBufferEnabled } from './ui/hooks/useAlternateBuffer.js';
|
||||
import { OverflowProvider } from './ui/contexts/OverflowContext.js';
|
||||
import { profiler } from './ui/components/DebugProfiler.js';
|
||||
|
||||
const SLOW_RENDER_MS = 200;
|
||||
|
||||
export async function startInteractiveUI(
|
||||
config: Config,
|
||||
settings: LoadedSettings,
|
||||
startupWarnings: StartupWarning[],
|
||||
workspaceRoot: string = process.cwd(),
|
||||
resumedSessionData: ResumedSessionData | undefined,
|
||||
initializationResult: InitializationResult,
|
||||
) {
|
||||
// Never enter Ink alternate buffer mode when screen reader mode is enabled
|
||||
// as there is no benefit of alternate buffer mode when using a screen reader
|
||||
// and the Ink alternate buffer mode requires line wrapping harmful to
|
||||
// screen readers.
|
||||
const useAlternateBuffer = shouldEnterAlternateScreen(
|
||||
isAlternateBufferEnabled(config),
|
||||
config.getScreenReader(),
|
||||
);
|
||||
const mouseEventsEnabled = useAlternateBuffer;
|
||||
if (mouseEventsEnabled) {
|
||||
enableMouseEvents();
|
||||
registerCleanup(() => {
|
||||
disableMouseEvents();
|
||||
});
|
||||
}
|
||||
|
||||
const { matchers, errors } = await loadKeyMatchers();
|
||||
errors.forEach((error) => {
|
||||
coreEvents.emitFeedback('warning', error);
|
||||
});
|
||||
|
||||
const version = await getVersion();
|
||||
setWindowTitle(basename(workspaceRoot), settings);
|
||||
|
||||
const consolePatcher = new ConsolePatcher({
|
||||
onNewMessage: (msg) => {
|
||||
coreEvents.emitConsoleLog(msg.type, msg.content);
|
||||
},
|
||||
debugMode: config.getDebugMode(),
|
||||
});
|
||||
consolePatcher.patch();
|
||||
registerCleanup(consolePatcher.cleanup);
|
||||
|
||||
const { stdout: inkStdout, stderr: inkStderr } = createWorkingStdio();
|
||||
|
||||
const isShpool = !!process.env['SHPOOL_SESSION_NAME'];
|
||||
|
||||
// Create wrapper component to use hooks inside render
|
||||
const AppWrapper = () => {
|
||||
useKittyKeyboardProtocol();
|
||||
|
||||
return (
|
||||
<SettingsContext.Provider value={settings}>
|
||||
<KeyMatchersProvider value={matchers}>
|
||||
<KeypressProvider
|
||||
config={config}
|
||||
debugKeystrokeLogging={
|
||||
settings.merged.general.debugKeystrokeLogging
|
||||
}
|
||||
>
|
||||
<MouseProvider
|
||||
mouseEventsEnabled={mouseEventsEnabled}
|
||||
debugKeystrokeLogging={
|
||||
settings.merged.general.debugKeystrokeLogging
|
||||
}
|
||||
>
|
||||
<TerminalProvider>
|
||||
<ScrollProvider>
|
||||
<OverflowProvider>
|
||||
<SessionStatsProvider>
|
||||
<VimModeProvider>
|
||||
<AppContainer
|
||||
config={config}
|
||||
startupWarnings={startupWarnings}
|
||||
version={version}
|
||||
resumedSessionData={resumedSessionData}
|
||||
initializationResult={initializationResult}
|
||||
/>
|
||||
</VimModeProvider>
|
||||
</SessionStatsProvider>
|
||||
</OverflowProvider>
|
||||
</ScrollProvider>
|
||||
</TerminalProvider>
|
||||
</MouseProvider>
|
||||
</KeypressProvider>
|
||||
</KeyMatchersProvider>
|
||||
</SettingsContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
if (isShpool) {
|
||||
// Wait a moment for shpool to stabilize terminal size and state.
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
}
|
||||
|
||||
const instance = render(
|
||||
process.env['DEBUG'] ? (
|
||||
<React.StrictMode>
|
||||
<AppWrapper />
|
||||
</React.StrictMode>
|
||||
) : (
|
||||
<AppWrapper />
|
||||
),
|
||||
{
|
||||
stdout: inkStdout,
|
||||
stderr: inkStderr,
|
||||
stdin: process.stdin,
|
||||
exitOnCtrlC: false,
|
||||
isScreenReaderEnabled: config.getScreenReader(),
|
||||
onRender: ({ renderTime }: { renderTime: number }) => {
|
||||
if (renderTime > SLOW_RENDER_MS) {
|
||||
recordSlowRender(config, renderTime);
|
||||
}
|
||||
profiler.reportFrameRendered();
|
||||
},
|
||||
patchConsole: false,
|
||||
alternateBuffer: useAlternateBuffer,
|
||||
incrementalRendering:
|
||||
settings.merged.ui.incrementalRendering !== false &&
|
||||
useAlternateBuffer &&
|
||||
!isShpool,
|
||||
},
|
||||
);
|
||||
|
||||
if (useAlternateBuffer) {
|
||||
disableLineWrapping();
|
||||
registerCleanup(() => {
|
||||
enableLineWrapping();
|
||||
});
|
||||
}
|
||||
|
||||
checkForUpdates(settings)
|
||||
.then((info) => {
|
||||
handleAutoUpdate(info, settings, config.getProjectRoot());
|
||||
})
|
||||
.catch((err) => {
|
||||
// Silently ignore update check errors.
|
||||
if (config.getDebugMode()) {
|
||||
debugLogger.warn('Update check failed:', err);
|
||||
}
|
||||
});
|
||||
|
||||
registerCleanup(() => instance.unmount());
|
||||
|
||||
registerCleanup(setupTtyCheck());
|
||||
}
|
||||
|
||||
function setWindowTitle(title: string, settings: LoadedSettings) {
|
||||
if (!settings.merged.ui.hideWindowTitle) {
|
||||
// Initial state before React loop starts
|
||||
const windowTitle = computeTerminalTitle({
|
||||
streamingState: StreamingState.Idle,
|
||||
isConfirming: false,
|
||||
isSilentWorking: false,
|
||||
folderName: title,
|
||||
showThoughts: !!settings.merged.ui.showStatusInTitle,
|
||||
useDynamicTitle: settings.merged.ui.dynamicWindowTitle,
|
||||
});
|
||||
writeToStdout(`\x1b]0;${windowTitle}\x07`);
|
||||
|
||||
process.on('exit', () => {
|
||||
writeToStdout(`\x1b]0;\x07`);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -123,7 +123,6 @@ async function downloadFiles({
|
||||
downloads.push(
|
||||
(async () => {
|
||||
const endpoint = `${REPO_DOWNLOAD_URL}/refs/tags/${releaseTag}/${SOURCE_DIR}/${fileBasename}`;
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'GET',
|
||||
dispatcher: proxy ? new ProxyAgent(proxy) : undefined,
|
||||
|
||||
@@ -61,7 +61,6 @@ export const getLatestGitHubRelease = async (
|
||||
|
||||
const endpoint = `https://api.github.com/repos/google-github-actions/run-gemini-cli/releases/latest`;
|
||||
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
|
||||
@@ -5,11 +5,8 @@
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
A2AClientManager,
|
||||
type SendMessageResult,
|
||||
} from './a2a-client-manager.js';
|
||||
import type { AgentCard, Task } from '@a2a-js/sdk';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import type { AgentCard } from '@a2a-js/sdk';
|
||||
import {
|
||||
ClientFactory,
|
||||
DefaultAgentCardResolver,
|
||||
@@ -22,81 +19,95 @@ import type { Config } from '../config/config.js';
|
||||
import { Agent as UndiciAgent, ProxyAgent } from 'undici';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
interface MockClient {
|
||||
sendMessageStream: ReturnType<typeof vi.fn>;
|
||||
getTask: ReturnType<typeof vi.fn>;
|
||||
cancelTask: ReturnType<typeof vi.fn>;
|
||||
}
|
||||
|
||||
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
|
||||
const actual = await importOriginal();
|
||||
return {
|
||||
...(actual as Record<string, unknown>),
|
||||
createAuthenticatingFetchWithRetry: vi.fn(),
|
||||
ClientFactory: vi.fn(),
|
||||
DefaultAgentCardResolver: vi.fn(),
|
||||
ClientFactoryOptions: {
|
||||
createFrom: vi.fn(),
|
||||
default: {},
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../utils/debugLogger.js', () => ({
|
||||
debugLogger: {
|
||||
debug: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('@a2a-js/sdk/client', () => {
|
||||
const ClientFactory = vi.fn();
|
||||
const DefaultAgentCardResolver = vi.fn();
|
||||
const RestTransportFactory = vi.fn();
|
||||
const JsonRpcTransportFactory = vi.fn();
|
||||
const ClientFactoryOptions = {
|
||||
default: {},
|
||||
createFrom: vi.fn(),
|
||||
};
|
||||
const createAuthenticatingFetchWithRetry = vi.fn();
|
||||
|
||||
DefaultAgentCardResolver.prototype.resolve = vi.fn();
|
||||
ClientFactory.prototype.createFromUrl = vi.fn();
|
||||
|
||||
return {
|
||||
ClientFactory,
|
||||
ClientFactoryOptions,
|
||||
DefaultAgentCardResolver,
|
||||
RestTransportFactory,
|
||||
JsonRpcTransportFactory,
|
||||
createAuthenticatingFetchWithRetry,
|
||||
};
|
||||
});
|
||||
|
||||
describe('A2AClientManager', () => {
|
||||
let manager: A2AClientManager;
|
||||
const mockAgentCard: AgentCard = {
|
||||
name: 'test-agent',
|
||||
description: 'A test agent',
|
||||
url: 'http://test.agent',
|
||||
version: '1.0.0',
|
||||
protocolVersion: '0.1.0',
|
||||
capabilities: {},
|
||||
skills: [],
|
||||
defaultInputModes: [],
|
||||
defaultOutputModes: [],
|
||||
};
|
||||
|
||||
const mockClient: MockClient = {
|
||||
sendMessageStream: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
cancelTask: vi.fn(),
|
||||
};
|
||||
|
||||
// Stable mocks initialized once
|
||||
const sendMessageStreamMock = vi.fn();
|
||||
const getTaskMock = vi.fn();
|
||||
const cancelTaskMock = vi.fn();
|
||||
const getAgentCardMock = vi.fn();
|
||||
const authFetchMock = vi.fn();
|
||||
|
||||
const mockClient = {
|
||||
sendMessageStream: sendMessageStreamMock,
|
||||
getTask: getTaskMock,
|
||||
cancelTask: cancelTaskMock,
|
||||
getAgentCard: getAgentCardMock,
|
||||
} as unknown as Client;
|
||||
|
||||
const mockAgentCard: Partial<AgentCard> = { name: 'TestAgent' };
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
A2AClientManager.resetInstanceForTesting();
|
||||
manager = A2AClientManager.getInstance();
|
||||
|
||||
// Default mock implementations
|
||||
getAgentCardMock.mockResolvedValue({
|
||||
// Re-create the instances as plain objects that can be spied on
|
||||
const factoryInstance = {
|
||||
createFromUrl: vi.fn(),
|
||||
createFromAgentCard: vi.fn(),
|
||||
};
|
||||
const resolverInstance = {
|
||||
resolve: vi.fn(),
|
||||
};
|
||||
|
||||
vi.mocked(ClientFactory).mockReturnValue(
|
||||
factoryInstance as unknown as ClientFactory,
|
||||
);
|
||||
vi.mocked(DefaultAgentCardResolver).mockReturnValue(
|
||||
resolverInstance as unknown as DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue(
|
||||
mockClient as unknown as Client,
|
||||
);
|
||||
vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue(
|
||||
mockClient as unknown as Client,
|
||||
);
|
||||
vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({
|
||||
...mockAgentCard,
|
||||
url: 'http://test.agent/real/endpoint',
|
||||
} as AgentCard);
|
||||
|
||||
vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue(
|
||||
mockClient,
|
||||
vi.spyOn(ClientFactoryOptions, 'createFrom').mockImplementation(
|
||||
(_defaults, overrides) => overrides as unknown as ClientFactoryOptions,
|
||||
);
|
||||
|
||||
vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({
|
||||
...mockAgentCard,
|
||||
url: 'http://test.agent/real/endpoint',
|
||||
} as AgentCard);
|
||||
|
||||
vi.mocked(ClientFactoryOptions.createFrom).mockImplementation(
|
||||
(_defaults, overrides) => overrides as ClientFactoryOptions,
|
||||
);
|
||||
|
||||
vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue(
|
||||
authFetchMock,
|
||||
vi.mocked(createAuthenticatingFetchWithRetry).mockImplementation(() =>
|
||||
authFetchMock.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response),
|
||||
);
|
||||
|
||||
vi.stubGlobal(
|
||||
@@ -170,15 +181,19 @@ describe('A2AClientManager', () => {
|
||||
'TestAgent',
|
||||
'http://test.agent/card',
|
||||
);
|
||||
expect(agentCard).toMatchObject(mockAgentCard);
|
||||
expect(manager.getAgentCard('TestAgent')).toBe(agentCard);
|
||||
expect(manager.getClient('TestAgent')).toBeDefined();
|
||||
});
|
||||
|
||||
it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
expect(ClientFactoryOptions.createFrom).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should throw an error if an agent with the same name is already loaded', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
await expect(
|
||||
manager.loadAgent('TestAgent', 'http://another.agent/card'),
|
||||
manager.loadAgent('TestAgent', 'http://test.agent/card'),
|
||||
).rejects.toThrow("Agent with name 'TestAgent' is already loaded.");
|
||||
});
|
||||
|
||||
@@ -193,20 +208,12 @@ describe('A2AClientManager', () => {
|
||||
shouldRetryWithHeaders: vi.fn(),
|
||||
};
|
||||
await manager.loadAgent(
|
||||
'CustomAuthAgent',
|
||||
'http://custom.agent/card',
|
||||
'TestAgent',
|
||||
'http://test.agent/card',
|
||||
customAuthHandler as unknown as AuthenticationHandler,
|
||||
);
|
||||
|
||||
expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
customAuthHandler,
|
||||
);
|
||||
|
||||
// Card resolver should NOT use the authenticated fetch by default.
|
||||
const resolverInstance = vi.mocked(DefaultAgentCardResolver).mock
|
||||
.instances[0];
|
||||
expect(resolverInstance).toBeDefined();
|
||||
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
|
||||
.calls[0][0];
|
||||
expect(resolverOptions?.fetchImpl).not.toBe(authFetchMock);
|
||||
@@ -267,106 +274,163 @@ describe('A2AClientManager', () => {
|
||||
it('should log a debug message upon loading an agent', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
expect(debugLogger.debug).toHaveBeenCalledWith(
|
||||
"[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card",
|
||||
expect.stringContaining("Loaded agent 'TestAgent'"),
|
||||
);
|
||||
});
|
||||
|
||||
it('should clear the cache', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
expect(manager.getAgentCard('TestAgent')).toBeDefined();
|
||||
expect(manager.getClient('TestAgent')).toBeDefined();
|
||||
|
||||
manager.clearCache();
|
||||
|
||||
expect(manager.getAgentCard('TestAgent')).toBeUndefined();
|
||||
expect(manager.getClient('TestAgent')).toBeUndefined();
|
||||
expect(debugLogger.debug).toHaveBeenCalledWith(
|
||||
'[A2AClientManager] Cache cleared.',
|
||||
});
|
||||
|
||||
it('should throw if resolveAgentCard fails', async () => {
|
||||
const resolverInstance = {
|
||||
resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')),
|
||||
};
|
||||
vi.mocked(DefaultAgentCardResolver).mockReturnValue(
|
||||
resolverInstance as unknown as DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('FailAgent', 'http://fail.agent'),
|
||||
).rejects.toThrow('Resolution failed');
|
||||
});
|
||||
|
||||
it('should throw if factory.createFromAgentCard fails', async () => {
|
||||
const factoryInstance = {
|
||||
createFromAgentCard: vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Factory failed')),
|
||||
};
|
||||
vi.mocked(ClientFactory).mockReturnValue(
|
||||
factoryInstance as unknown as ClientFactory,
|
||||
);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('FailAgent', 'http://fail.agent'),
|
||||
).rejects.toThrow('Factory failed');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAgentCard and getClient', () => {
|
||||
it('should return undefined if agent is not found', () => {
|
||||
expect(manager.getAgentCard('Unknown')).toBeUndefined();
|
||||
expect(manager.getClient('Unknown')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessageStream', () => {
|
||||
beforeEach(async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent');
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
});
|
||||
|
||||
it('should send a message and return a stream', async () => {
|
||||
const mockResult = {
|
||||
kind: 'message',
|
||||
messageId: 'a',
|
||||
parts: [],
|
||||
role: 'agent',
|
||||
} as SendMessageResult;
|
||||
|
||||
sendMessageStreamMock.mockReturnValue(
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield mockResult;
|
||||
yield { kind: 'message' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
const results = [];
|
||||
for await (const res of stream) {
|
||||
results.push(res);
|
||||
for await (const result of stream) {
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
expect(results).toEqual([mockResult]);
|
||||
expect(sendMessageStreamMock).toHaveBeenCalledWith(
|
||||
expect(results).toHaveLength(1);
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use contextId and taskId when provided', async () => {
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { kind: 'message' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
|
||||
contextId: 'ctx123',
|
||||
taskId: 'task456',
|
||||
});
|
||||
// trigger execution
|
||||
for await (const _ of stream) {
|
||||
break;
|
||||
}
|
||||
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
message: expect.anything(),
|
||||
message: expect.objectContaining({
|
||||
contextId: 'ctx123',
|
||||
taskId: 'task456',
|
||||
}),
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use contextId and taskId when provided', async () => {
|
||||
sendMessageStreamMock.mockReturnValue(
|
||||
it('should correctly propagate AbortSignal to the stream', async () => {
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'a',
|
||||
parts: [],
|
||||
role: 'agent',
|
||||
} as SendMessageResult;
|
||||
yield { kind: 'message' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const expectedContextId = 'user-context-id';
|
||||
const expectedTaskId = 'user-task-id';
|
||||
|
||||
const controller = new AbortController();
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
|
||||
contextId: expectedContextId,
|
||||
taskId: expectedTaskId,
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
// trigger execution
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
break;
|
||||
}
|
||||
|
||||
const call = sendMessageStreamMock.mock.calls[0][0];
|
||||
expect(call.message.contextId).toBe(expectedContextId);
|
||||
expect(call.message.taskId).toBe(expectedTaskId);
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({ signal: controller.signal }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should propagate the original error on failure', async () => {
|
||||
sendMessageStreamMock.mockImplementationOnce(() => {
|
||||
throw new Error('Network error');
|
||||
it('should handle a multi-chunk stream with different event types', async () => {
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { kind: 'message', messageId: 'm1' };
|
||||
yield { kind: 'status-update', taskId: 't1' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
const results = [];
|
||||
for await (const result of stream) {
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
expect(results).toHaveLength(2);
|
||||
expect(results[0].kind).toBe('message');
|
||||
expect(results[1].kind).toBe('status-update');
|
||||
});
|
||||
|
||||
it('should throw prefixed error on failure', async () => {
|
||||
mockClient.sendMessageStream.mockImplementation(() => {
|
||||
throw new Error('Network failure');
|
||||
});
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
// empty
|
||||
}
|
||||
}).rejects.toThrow('Network error');
|
||||
}).rejects.toThrow(
|
||||
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if the agent is not found', async () => {
|
||||
const stream = manager.sendMessageStream('NonExistentAgent', 'Hello');
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
// empty
|
||||
}
|
||||
}).rejects.toThrow("Agent 'NonExistentAgent' not found.");
|
||||
});
|
||||
@@ -374,28 +438,23 @@ describe('A2AClientManager', () => {
|
||||
|
||||
describe('getTask', () => {
|
||||
beforeEach(async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent');
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
});
|
||||
|
||||
it('should get a task from the correct agent', async () => {
|
||||
getTaskMock.mockResolvedValue({
|
||||
id: 'task123',
|
||||
contextId: 'a',
|
||||
kind: 'task',
|
||||
status: { state: 'completed' },
|
||||
} as Task);
|
||||
const mockTask = { id: 'task123', kind: 'task' };
|
||||
mockClient.getTask.mockResolvedValue(mockTask);
|
||||
|
||||
await manager.getTask('TestAgent', 'task123');
|
||||
expect(getTaskMock).toHaveBeenCalledWith({
|
||||
id: 'task123',
|
||||
});
|
||||
const result = await manager.getTask('TestAgent', 'task123');
|
||||
expect(result).toBe(mockTask);
|
||||
expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' });
|
||||
});
|
||||
|
||||
it('should throw prefixed error on failure', async () => {
|
||||
getTaskMock.mockRejectedValueOnce(new Error('Network error'));
|
||||
mockClient.getTask.mockRejectedValue(new Error('Not found'));
|
||||
|
||||
await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow(
|
||||
'A2AClient getTask Error [TestAgent]: Network error',
|
||||
'A2AClient getTask Error [TestAgent]: Not found',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -408,28 +467,23 @@ describe('A2AClientManager', () => {
|
||||
|
||||
describe('cancelTask', () => {
|
||||
beforeEach(async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent');
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
});
|
||||
|
||||
it('should cancel a task on the correct agent', async () => {
|
||||
cancelTaskMock.mockResolvedValue({
|
||||
id: 'task123',
|
||||
contextId: 'a',
|
||||
kind: 'task',
|
||||
status: { state: 'canceled' },
|
||||
} as Task);
|
||||
const mockTask = { id: 'task123', kind: 'task' };
|
||||
mockClient.cancelTask.mockResolvedValue(mockTask);
|
||||
|
||||
await manager.cancelTask('TestAgent', 'task123');
|
||||
expect(cancelTaskMock).toHaveBeenCalledWith({
|
||||
id: 'task123',
|
||||
});
|
||||
const result = await manager.cancelTask('TestAgent', 'task123');
|
||||
expect(result).toBe(mockTask);
|
||||
expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' });
|
||||
});
|
||||
|
||||
it('should throw prefixed error on failure', async () => {
|
||||
cancelTaskMock.mockRejectedValueOnce(new Error('Network error'));
|
||||
mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel'));
|
||||
|
||||
await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow(
|
||||
'A2AClient cancelTask Error [TestAgent]: Network error',
|
||||
'A2AClient cancelTask Error [TestAgent]: Cannot cancel',
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -12,36 +12,41 @@ import type {
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
} from '@a2a-js/sdk';
|
||||
import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client';
|
||||
import {
|
||||
type Client,
|
||||
ClientFactory,
|
||||
ClientFactoryOptions,
|
||||
DefaultAgentCardResolver,
|
||||
RestTransportFactory,
|
||||
JsonRpcTransportFactory,
|
||||
type AuthenticationHandler,
|
||||
RestTransportFactory,
|
||||
createAuthenticatingFetchWithRetry,
|
||||
} from '@a2a-js/sdk/client';
|
||||
import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
|
||||
import * as grpc from '@grpc/grpc-js';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { Agent as UndiciAgent, ProxyAgent } from 'undici';
|
||||
import { normalizeAgentCard } from './a2aUtils.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { safeLookup } from '../utils/fetch.js';
|
||||
import { classifyAgentError } from './a2a-errors.js';
|
||||
|
||||
// Remote agents can take 10+ minutes (e.g. Deep Research).
|
||||
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
|
||||
const A2A_TIMEOUT = 1800000; // 30 minutes
|
||||
|
||||
/**
|
||||
* Result of sending a message, which can be a full message, a task,
|
||||
* or an incremental status/artifact update.
|
||||
*/
|
||||
export type SendMessageResult =
|
||||
| Message
|
||||
| Task
|
||||
| TaskStatusUpdateEvent
|
||||
| TaskArtifactUpdateEvent;
|
||||
|
||||
// Remote agents can take 10+ minutes (e.g. Deep Research).
|
||||
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
|
||||
const A2A_TIMEOUT = 1800000; // 30 minutes
|
||||
|
||||
/**
|
||||
* Manages A2A clients and caches loaded agent information.
|
||||
* Follows a singleton pattern to ensure a single client instance.
|
||||
* Orchestrates communication with remote A2A agents.
|
||||
* Manages protocol negotiation, authentication, and transport selection.
|
||||
*/
|
||||
export class A2AClientManager {
|
||||
private static instance: A2AClientManager;
|
||||
@@ -58,9 +63,6 @@ export class A2AClientManager {
|
||||
const agentOptions = {
|
||||
headersTimeout: A2A_TIMEOUT,
|
||||
bodyTimeout: A2A_TIMEOUT,
|
||||
connect: {
|
||||
lookup: safeLookup, // SSRF protection at connection level
|
||||
},
|
||||
};
|
||||
|
||||
if (proxyUrl) {
|
||||
@@ -73,7 +75,6 @@ export class A2AClientManager {
|
||||
}
|
||||
|
||||
this.a2aFetch = (input, init) =>
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
fetch(input, { ...init, dispatcher: this.a2aDispatcher } as RequestInit);
|
||||
}
|
||||
|
||||
@@ -139,22 +140,35 @@ export class A2AClientManager {
|
||||
};
|
||||
|
||||
const resolver = new DefaultAgentCardResolver({ fetchImpl: cardFetch });
|
||||
const rawCard = await resolver.resolve(agentCardUrl, '');
|
||||
// TODO: Remove normalizeAgentCard once @a2a-js/sdk handles
|
||||
// proto field name aliases (supportedInterfaces → additionalInterfaces,
|
||||
// protocolBinding → transport).
|
||||
const agentCard = normalizeAgentCard(rawCard);
|
||||
|
||||
const options = ClientFactoryOptions.createFrom(
|
||||
const grpcUrl =
|
||||
agentCard.additionalInterfaces?.find((i) => i.transport === 'GRPC')
|
||||
?.url ?? agentCard.url;
|
||||
|
||||
const clientOptions = ClientFactoryOptions.createFrom(
|
||||
ClientFactoryOptions.default,
|
||||
{
|
||||
transports: [
|
||||
new RestTransportFactory({ fetchImpl: authFetch }),
|
||||
new JsonRpcTransportFactory({ fetchImpl: authFetch }),
|
||||
new GrpcTransportFactory({
|
||||
grpcChannelCredentials: grpcUrl.startsWith('https://')
|
||||
? grpc.credentials.createSsl()
|
||||
: grpc.credentials.createInsecure(),
|
||||
}),
|
||||
],
|
||||
cardResolver: resolver,
|
||||
},
|
||||
);
|
||||
|
||||
try {
|
||||
const factory = new ClientFactory(options);
|
||||
const client = await factory.createFromUrl(agentCardUrl, '');
|
||||
const agentCard = await client.getAgentCard();
|
||||
const factory = new ClientFactory(clientOptions);
|
||||
const client = await factory.createFromAgentCard(agentCard);
|
||||
|
||||
this.clients.set(name, client);
|
||||
this.agentCards.set(name, agentCard);
|
||||
@@ -192,9 +206,7 @@ export class A2AClientManager {
|
||||
options?: { contextId?: string; taskId?: string; signal?: AbortSignal },
|
||||
): AsyncIterable<SendMessageResult> {
|
||||
const client = this.clients.get(agentName);
|
||||
if (!client) {
|
||||
throw new Error(`Agent '${agentName}' not found.`);
|
||||
}
|
||||
if (!client) throw new Error(`Agent '${agentName}' not found.`);
|
||||
|
||||
const messageParams: MessageSendParams = {
|
||||
message: {
|
||||
@@ -207,9 +219,19 @@ export class A2AClientManager {
|
||||
},
|
||||
};
|
||||
|
||||
yield* client.sendMessageStream(messageParams, {
|
||||
signal: options?.signal,
|
||||
});
|
||||
try {
|
||||
yield* client.sendMessageStream(messageParams, {
|
||||
signal: options?.signal,
|
||||
});
|
||||
} catch (error: unknown) {
|
||||
const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`;
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`${prefix}: ${error.message}`, { cause: error });
|
||||
}
|
||||
throw new Error(
|
||||
`${prefix}: Unexpected error during sendMessageStream: ${String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -238,9 +260,7 @@ export class A2AClientManager {
|
||||
*/
|
||||
async getTask(agentName: string, taskId: string): Promise<Task> {
|
||||
const client = this.clients.get(agentName);
|
||||
if (!client) {
|
||||
throw new Error(`Agent '${agentName}' not found.`);
|
||||
}
|
||||
if (!client) throw new Error(`Agent '${agentName}' not found.`);
|
||||
try {
|
||||
return await client.getTask({ id: taskId });
|
||||
} catch (error: unknown) {
|
||||
@@ -260,9 +280,7 @@ export class A2AClientManager {
|
||||
*/
|
||||
async cancelTask(agentName: string, taskId: string): Promise<Task> {
|
||||
const client = this.clients.get(agentName);
|
||||
if (!client) {
|
||||
throw new Error(`Agent '${agentName}' not found.`);
|
||||
}
|
||||
if (!client) throw new Error(`Agent '${agentName}' not found.`);
|
||||
try {
|
||||
return await client.cancelTask({ id: taskId });
|
||||
} catch (error: unknown) {
|
||||
|
||||
@@ -12,9 +12,6 @@ import {
|
||||
A2AResultReassembler,
|
||||
AUTH_REQUIRED_MSG,
|
||||
normalizeAgentCard,
|
||||
getGrpcCredentials,
|
||||
pinUrlToIp,
|
||||
splitAgentCardUrl,
|
||||
} from './a2aUtils.js';
|
||||
import type { SendMessageResult } from './a2a-client-manager.js';
|
||||
import type {
|
||||
@@ -26,12 +23,6 @@ import type {
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
} from '@a2a-js/sdk';
|
||||
import * as dnsPromises from 'node:dns/promises';
|
||||
import type { LookupAddress } from 'node:dns';
|
||||
|
||||
vi.mock('node:dns/promises', () => ({
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('a2aUtils', () => {
|
||||
beforeEach(() => {
|
||||
@@ -42,89 +33,6 @@ describe('a2aUtils', () => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('getGrpcCredentials', () => {
|
||||
it('should return secure credentials for https', () => {
|
||||
const credentials = getGrpcCredentials('https://test.agent');
|
||||
expect(credentials).toBeDefined();
|
||||
});
|
||||
|
||||
it('should return insecure credentials for http', () => {
|
||||
const credentials = getGrpcCredentials('http://test.agent');
|
||||
expect(credentials).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('pinUrlToIp', () => {
|
||||
it('should resolve and pin hostname to IP', async () => {
|
||||
vi.mocked(
|
||||
dnsPromises.lookup as unknown as (
|
||||
hostname: string,
|
||||
options: { all: true },
|
||||
) => Promise<LookupAddress[]>,
|
||||
).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]);
|
||||
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(
|
||||
'http://example.com:9000',
|
||||
'test-agent',
|
||||
);
|
||||
expect(hostname).toBe('example.com');
|
||||
expect(pinnedUrl).toBe('http://93.184.216.34:9000/');
|
||||
});
|
||||
|
||||
it('should handle raw host:port strings (standard for gRPC)', async () => {
|
||||
vi.mocked(
|
||||
dnsPromises.lookup as unknown as (
|
||||
hostname: string,
|
||||
options: { all: true },
|
||||
) => Promise<LookupAddress[]>,
|
||||
).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]);
|
||||
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(
|
||||
'example.com:9000',
|
||||
'test-agent',
|
||||
);
|
||||
expect(hostname).toBe('example.com');
|
||||
expect(pinnedUrl).toBe('93.184.216.34:9000');
|
||||
});
|
||||
|
||||
it('should throw error if resolution fails (fail closed)', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
|
||||
|
||||
await expect(
|
||||
pinUrlToIp('http://unreachable.com', 'test-agent'),
|
||||
).rejects.toThrow("Failed to resolve host for agent 'test-agent'");
|
||||
});
|
||||
|
||||
it('should throw error if resolved to private IP', async () => {
|
||||
vi.mocked(
|
||||
dnsPromises.lookup as unknown as (
|
||||
hostname: string,
|
||||
options: { all: true },
|
||||
) => Promise<LookupAddress[]>,
|
||||
).mockResolvedValue([{ address: '10.0.0.1', family: 4 }]);
|
||||
|
||||
await expect(
|
||||
pinUrlToIp('http://malicious.com', 'test-agent'),
|
||||
).rejects.toThrow('resolves to private IP range');
|
||||
});
|
||||
|
||||
it('should allow localhost/127.0.0.1/::1 exceptions', async () => {
|
||||
vi.mocked(
|
||||
dnsPromises.lookup as unknown as (
|
||||
hostname: string,
|
||||
options: { all: true },
|
||||
) => Promise<LookupAddress[]>,
|
||||
).mockResolvedValue([{ address: '127.0.0.1', family: 4 }]);
|
||||
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(
|
||||
'http://localhost:9000',
|
||||
'test-agent',
|
||||
);
|
||||
expect(hostname).toBe('localhost');
|
||||
expect(pinnedUrl).toBe('http://127.0.0.1:9000/');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTerminalState', () => {
|
||||
it('should return true for completed, failed, canceled, and rejected', () => {
|
||||
expect(isTerminalState('completed')).toBe(true);
|
||||
@@ -365,12 +273,12 @@ describe('a2aUtils', () => {
|
||||
expect(normalized.name).toBe('my-agent');
|
||||
// @ts-expect-error - testing dynamic preservation
|
||||
expect(normalized.customField).toBe('keep-me');
|
||||
expect(normalized.description).toBe('');
|
||||
expect(normalized.skills).toEqual([]);
|
||||
expect(normalized.defaultInputModes).toEqual([]);
|
||||
expect(normalized.description).toBeUndefined();
|
||||
expect(normalized.skills).toBeUndefined();
|
||||
expect(normalized.defaultInputModes).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should normalize and synchronize interfaces while preserving other fields', () => {
|
||||
it('should map supportedInterfaces to additionalInterfaces with protocolBinding → transport', () => {
|
||||
const raw = {
|
||||
name: 'test',
|
||||
supportedInterfaces: [
|
||||
@@ -384,13 +292,7 @@ describe('a2aUtils', () => {
|
||||
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
|
||||
// Should exist in both fields
|
||||
expect(normalized.additionalInterfaces).toHaveLength(1);
|
||||
expect(
|
||||
(normalized as unknown as Record<string, unknown>)[
|
||||
'supportedInterfaces'
|
||||
],
|
||||
).toHaveLength(1);
|
||||
|
||||
const intf = normalized.additionalInterfaces?.[0] as unknown as Record<
|
||||
string,
|
||||
@@ -399,43 +301,18 @@ describe('a2aUtils', () => {
|
||||
|
||||
expect(intf['transport']).toBe('GRPC');
|
||||
expect(intf['url']).toBe('grpc://test');
|
||||
|
||||
// Should fallback top-level url
|
||||
expect(normalized.url).toBe('grpc://test');
|
||||
});
|
||||
|
||||
it('should preserve existing top-level url if present', () => {
|
||||
it('should not overwrite additionalInterfaces if already present', () => {
|
||||
const raw = {
|
||||
name: 'test',
|
||||
url: 'http://existing',
|
||||
additionalInterfaces: [{ url: 'http://grpc', transport: 'GRPC' }],
|
||||
supportedInterfaces: [{ url: 'http://other', transport: 'REST' }],
|
||||
};
|
||||
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
expect(normalized.url).toBe('http://existing');
|
||||
});
|
||||
|
||||
it('should NOT prepend http:// scheme to raw IP:port strings for gRPC interfaces', () => {
|
||||
const raw = {
|
||||
name: 'raw-ip-grpc',
|
||||
supportedInterfaces: [{ url: '127.0.0.1:9000', transport: 'GRPC' }],
|
||||
};
|
||||
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
expect(normalized.additionalInterfaces?.[0].url).toBe('127.0.0.1:9000');
|
||||
expect(normalized.url).toBe('127.0.0.1:9000');
|
||||
});
|
||||
|
||||
it('should prepend http:// scheme to raw IP:port strings for REST interfaces', () => {
|
||||
const raw = {
|
||||
name: 'raw-ip-rest',
|
||||
supportedInterfaces: [{ url: '127.0.0.1:8080', transport: 'REST' }],
|
||||
};
|
||||
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
expect(normalized.additionalInterfaces?.[0].url).toBe(
|
||||
'http://127.0.0.1:8080',
|
||||
);
|
||||
expect(normalized.additionalInterfaces).toHaveLength(1);
|
||||
expect(normalized.additionalInterfaces?.[0].url).toBe('http://grpc');
|
||||
});
|
||||
|
||||
it('should NOT override existing transport if protocolBinding is also present', () => {
|
||||
@@ -448,48 +325,20 @@ describe('a2aUtils', () => {
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
expect(normalized.additionalInterfaces?.[0].transport).toBe('GRPC');
|
||||
});
|
||||
});
|
||||
|
||||
describe('splitAgentCardUrl', () => {
|
||||
const standard = '.well-known/agent-card.json';
|
||||
it('should not mutate the original card object', () => {
|
||||
const raw = {
|
||||
name: 'test',
|
||||
supportedInterfaces: [{ url: 'grpc://test', protocolBinding: 'GRPC' }],
|
||||
};
|
||||
|
||||
it('should return baseUrl as-is if it does not end with standard path', () => {
|
||||
const url = 'http://localhost:9001/custom/path';
|
||||
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
|
||||
});
|
||||
|
||||
it('should split correctly if URL ends with standard path', () => {
|
||||
const url = `http://localhost:9001/${standard}`;
|
||||
expect(splitAgentCardUrl(url)).toEqual({
|
||||
baseUrl: 'http://localhost:9001/',
|
||||
path: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle trailing slash in baseUrl when splitting', () => {
|
||||
const url = `http://example.com/api/${standard}`;
|
||||
expect(splitAgentCardUrl(url)).toEqual({
|
||||
baseUrl: 'http://example.com/api/',
|
||||
path: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should ignore hashes and query params when splitting', () => {
|
||||
const url = `http://localhost:9001/${standard}?foo=bar#baz`;
|
||||
expect(splitAgentCardUrl(url)).toEqual({
|
||||
baseUrl: 'http://localhost:9001/',
|
||||
path: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should return original URL if parsing fails', () => {
|
||||
const url = 'not-a-url';
|
||||
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
|
||||
});
|
||||
|
||||
it('should handle standard path appearing earlier in the path', () => {
|
||||
const url = `http://localhost:9001/${standard}/something-else`;
|
||||
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
|
||||
const normalized = normalizeAgentCard(raw);
|
||||
expect(normalized).not.toBe(raw);
|
||||
expect(normalized.additionalInterfaces).toBeDefined();
|
||||
// Original should not have additionalInterfaces added
|
||||
expect(
|
||||
(raw as Record<string, unknown>)['additionalInterfaces'],
|
||||
).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -4,9 +4,6 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as grpc from '@grpc/grpc-js';
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import { z } from 'zod';
|
||||
import type {
|
||||
Message,
|
||||
Part,
|
||||
@@ -18,37 +15,10 @@ import type {
|
||||
AgentCard,
|
||||
AgentInterface,
|
||||
} from '@a2a-js/sdk';
|
||||
import { isAddressPrivate } from '../utils/fetch.js';
|
||||
import type { SendMessageResult } from './a2a-client-manager.js';
|
||||
|
||||
export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`;
|
||||
|
||||
const AgentInterfaceSchema = z
|
||||
.object({
|
||||
url: z.string().default(''),
|
||||
transport: z.string().optional(),
|
||||
protocolBinding: z.string().optional(),
|
||||
})
|
||||
.passthrough();
|
||||
|
||||
const AgentCardSchema = z
|
||||
.object({
|
||||
name: z.string().default('unknown'),
|
||||
description: z.string().default(''),
|
||||
url: z.string().default(''),
|
||||
version: z.string().default(''),
|
||||
protocolVersion: z.string().default(''),
|
||||
capabilities: z.record(z.unknown()).default({}),
|
||||
skills: z.array(z.union([z.string(), z.record(z.unknown())])).default([]),
|
||||
defaultInputModes: z.array(z.string()).default([]),
|
||||
defaultOutputModes: z.array(z.string()).default([]),
|
||||
|
||||
additionalInterfaces: z.array(AgentInterfaceSchema).optional(),
|
||||
supportedInterfaces: z.array(AgentInterfaceSchema).optional(),
|
||||
preferredTransport: z.string().optional(),
|
||||
})
|
||||
.passthrough();
|
||||
|
||||
/**
|
||||
* Reassembles incremental A2A streaming updates into a coherent result.
|
||||
* Shows sequential status/messages followed by all reassembled artifacts.
|
||||
@@ -241,166 +211,45 @@ function extractPartText(part: Part): string {
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalizes an agent card by ensuring it has the required properties
|
||||
* and resolving any inconsistencies between protocol versions.
|
||||
* Normalizes proto field name aliases that the SDK doesn't handle yet.
|
||||
* The A2A proto spec uses `supported_interfaces` and `protocol_binding`,
|
||||
* while the SDK expects `additionalInterfaces` and `transport`.
|
||||
* TODO: Remove once @a2a-js/sdk handles these aliases natively.
|
||||
*/
|
||||
export function normalizeAgentCard(card: unknown): AgentCard {
|
||||
if (!isObject(card)) {
|
||||
throw new Error('Agent card is missing.');
|
||||
}
|
||||
|
||||
// Use Zod to validate and parse the card, ensuring safe defaults and narrowing types.
|
||||
const parsed = AgentCardSchema.parse(card);
|
||||
// Narrowing to AgentCard interface after runtime validation.
|
||||
// Shallow-copy to avoid mutating the SDK's cached object.
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const result = parsed as unknown as AgentCard;
|
||||
const result = { ...card } as unknown as AgentCard;
|
||||
|
||||
// Normalize interfaces and synchronize both interface fields.
|
||||
const normalizedInterfaces = extractNormalizedInterfaces(parsed);
|
||||
result.additionalInterfaces = normalizedInterfaces;
|
||||
|
||||
// Sync supportedInterfaces for backward compatibility.
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const legacyResult = result as unknown as Record<string, AgentInterface[]>;
|
||||
legacyResult['supportedInterfaces'] = normalizedInterfaces;
|
||||
|
||||
// Fallback preferredTransport: If not specified, default to GRPC if available.
|
||||
if (
|
||||
!result.preferredTransport &&
|
||||
normalizedInterfaces.some((i) => i.transport === 'GRPC')
|
||||
) {
|
||||
result.preferredTransport = 'GRPC';
|
||||
// Map supportedInterfaces → additionalInterfaces if needed
|
||||
if (!result.additionalInterfaces) {
|
||||
const raw = card;
|
||||
if (Array.isArray(raw['supportedInterfaces'])) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
result.additionalInterfaces = raw[
|
||||
'supportedInterfaces'
|
||||
] as AgentInterface[];
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: If top-level URL is missing, use the first interface's URL.
|
||||
if (result.url === '' && normalizedInterfaces.length > 0) {
|
||||
result.url = normalizedInterfaces[0].url;
|
||||
// Map protocolBinding → transport on each interface
|
||||
for (const intf of result.additionalInterfaces ?? []) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const raw = intf as unknown as Record<string, unknown>;
|
||||
const binding = raw['protocolBinding'];
|
||||
|
||||
if (!intf.transport && typeof binding === 'string') {
|
||||
intf.transport = binding;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns gRPC channel credentials based on the URL scheme.
|
||||
*/
|
||||
export function getGrpcCredentials(url: string): grpc.ChannelCredentials {
|
||||
return url.startsWith('https://')
|
||||
? grpc.credentials.createSsl()
|
||||
: grpc.credentials.createInsecure();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns gRPC channel options to ensure SSL/authority matches the original hostname
|
||||
* when connecting via a pinned IP address.
|
||||
*/
|
||||
export function getGrpcChannelOptions(
|
||||
hostname: string,
|
||||
): Record<string, unknown> {
|
||||
return {
|
||||
'grpc.default_authority': hostname,
|
||||
'grpc.ssl_target_name_override': hostname,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a hostname to its IP address and validates it against SSRF.
|
||||
* Returns the pinned IP-based URL and the original hostname.
|
||||
*/
|
||||
export async function pinUrlToIp(
|
||||
url: string,
|
||||
agentName: string,
|
||||
): Promise<{ pinnedUrl: string; hostname: string }> {
|
||||
if (!url) return { pinnedUrl: url, hostname: '' };
|
||||
|
||||
// gRPC URLs in A2A can be 'host:port' or 'dns:///host:port' or have schemes.
|
||||
// We normalize to host:port for resolution.
|
||||
const hasScheme = url.includes('://');
|
||||
const normalizedUrl = hasScheme ? url : `http://${url}`;
|
||||
|
||||
try {
|
||||
const parsed = new URL(normalizedUrl);
|
||||
const hostname = parsed.hostname;
|
||||
|
||||
const sanitizedHost =
|
||||
hostname.startsWith('[') && hostname.endsWith(']')
|
||||
? hostname.slice(1, -1)
|
||||
: hostname;
|
||||
|
||||
// Resolve DNS to check the actual target IP and pin it
|
||||
const addresses = await lookup(hostname, { all: true });
|
||||
const publicAddresses = addresses.filter(
|
||||
(addr) =>
|
||||
!isAddressPrivate(addr.address) ||
|
||||
sanitizedHost === 'localhost' ||
|
||||
sanitizedHost === '127.0.0.1' ||
|
||||
sanitizedHost === '::1',
|
||||
);
|
||||
|
||||
if (publicAddresses.length === 0) {
|
||||
if (addresses.length > 0) {
|
||||
throw new Error(
|
||||
`Refusing to load agent '${agentName}': transport URL '${url}' resolves to private IP range.`,
|
||||
);
|
||||
}
|
||||
throw new Error(
|
||||
`Failed to resolve any public IP addresses for host: ${hostname}`,
|
||||
);
|
||||
}
|
||||
|
||||
const pinnedIp = publicAddresses[0].address;
|
||||
const pinnedHostname = pinnedIp.includes(':') ? `[${pinnedIp}]` : pinnedIp;
|
||||
|
||||
// Reconstruct URL with IP
|
||||
parsed.hostname = pinnedHostname;
|
||||
let pinnedUrl = parsed.toString();
|
||||
|
||||
// If original didn't have scheme, remove it (standard for gRPC targets)
|
||||
if (!hasScheme) {
|
||||
pinnedUrl = pinnedUrl.replace(/^http:\/\//, '');
|
||||
// URL.toString() might append a trailing slash
|
||||
if (pinnedUrl.endsWith('/') && !url.endsWith('/')) {
|
||||
pinnedUrl = pinnedUrl.slice(0, -1);
|
||||
}
|
||||
}
|
||||
|
||||
return { pinnedUrl, hostname };
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.message.includes('Refusing')) throw e;
|
||||
throw new Error(`Failed to resolve host for agent '${agentName}': ${url}`, {
|
||||
cause: e,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Splts an agent card URL into a baseUrl and a standard path if it already
|
||||
* contains '.well-known/agent-card.json'.
|
||||
*/
|
||||
export function splitAgentCardUrl(url: string): {
|
||||
baseUrl: string;
|
||||
path?: string;
|
||||
} {
|
||||
const standardPath = '.well-known/agent-card.json';
|
||||
try {
|
||||
const parsedUrl = new URL(url);
|
||||
if (parsedUrl.pathname.endsWith(standardPath)) {
|
||||
// Reconstruct baseUrl from parsed components to avoid issues with hashes or query params.
|
||||
parsedUrl.pathname = parsedUrl.pathname.substring(
|
||||
0,
|
||||
parsedUrl.pathname.lastIndexOf(standardPath),
|
||||
);
|
||||
parsedUrl.search = '';
|
||||
parsedUrl.hash = '';
|
||||
// We return undefined for path if it's the standard one,
|
||||
// because the SDK's DefaultAgentCardResolver appends it automatically.
|
||||
return { baseUrl: parsedUrl.toString(), path: undefined };
|
||||
}
|
||||
} catch (_e) {
|
||||
// Ignore URL parsing errors here, let the resolver handle them.
|
||||
}
|
||||
return { baseUrl: url };
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts contextId and taskId from a Message, Task, or Update response.
|
||||
* Follows the pattern from the A2A CLI sample to maintain conversational continuity.
|
||||
@@ -446,65 +295,6 @@ export function extractIdsFromResponse(result: SendMessageResult): {
|
||||
return { contextId, taskId, clearTaskId };
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts and normalizes interfaces from the card, handling protocol version fallbacks.
|
||||
* Preserves all original fields to maintain SDK compatibility.
|
||||
*/
|
||||
function extractNormalizedInterfaces(
|
||||
card: Record<string, unknown>,
|
||||
): AgentInterface[] {
|
||||
const rawInterfaces =
|
||||
getArray(card, 'additionalInterfaces') ||
|
||||
getArray(card, 'supportedInterfaces');
|
||||
|
||||
if (!rawInterfaces) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const mapped: AgentInterface[] = [];
|
||||
for (const i of rawInterfaces) {
|
||||
if (isObject(i)) {
|
||||
// Use schema to validate interface object.
|
||||
const parsed = AgentInterfaceSchema.parse(i);
|
||||
// Narrowing to AgentInterface after runtime validation.
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const normalized = parsed as unknown as AgentInterface & {
|
||||
protocolBinding?: string;
|
||||
};
|
||||
|
||||
// Normalize 'transport' from 'protocolBinding' if missing.
|
||||
if (!normalized.transport && normalized.protocolBinding) {
|
||||
normalized.transport = normalized.protocolBinding;
|
||||
}
|
||||
|
||||
// Robust URL: Ensure the URL has a scheme (except for gRPC).
|
||||
if (
|
||||
normalized.url &&
|
||||
!normalized.url.includes('://') &&
|
||||
!normalized.url.startsWith('/') &&
|
||||
normalized.transport !== 'GRPC'
|
||||
) {
|
||||
// Default to http:// for insecure REST/JSON-RPC if scheme is missing.
|
||||
normalized.url = `http://${normalized.url}`;
|
||||
}
|
||||
|
||||
mapped.push(normalized as AgentInterface);
|
||||
}
|
||||
}
|
||||
return mapped;
|
||||
}
|
||||
|
||||
/**
|
||||
* Safely extracts an array property from an object.
|
||||
*/
|
||||
function getArray(
|
||||
obj: Record<string, unknown>,
|
||||
key: string,
|
||||
): unknown[] | undefined {
|
||||
const val = obj[key];
|
||||
return Array.isArray(val) ? val : undefined;
|
||||
}
|
||||
|
||||
// Type Guards
|
||||
|
||||
function isTextPart(part: Part): part is TextPart {
|
||||
|
||||
@@ -700,7 +700,6 @@ async function fetchAndCacheUserInfo(client: OAuth2Client): Promise<void> {
|
||||
return;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(
|
||||
'https://www.googleapis.com/oauth2/v2/userinfo',
|
||||
{
|
||||
|
||||
@@ -111,7 +111,6 @@ export class MCPOAuthProvider {
|
||||
scope: config.scopes?.join(' ') || '',
|
||||
};
|
||||
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(registrationUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -301,7 +300,6 @@ export class MCPOAuthProvider {
|
||||
? { Accept: 'text/event-stream' }
|
||||
: { Accept: 'application/json' };
|
||||
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(mcpServerUrl, {
|
||||
method: 'HEAD',
|
||||
headers,
|
||||
|
||||
@@ -97,7 +97,6 @@ export class OAuthUtils {
|
||||
resourceMetadataUrl: string,
|
||||
): Promise<OAuthProtectedResourceMetadata | null> {
|
||||
try {
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(resourceMetadataUrl);
|
||||
if (!response.ok) {
|
||||
return null;
|
||||
@@ -122,7 +121,6 @@ export class OAuthUtils {
|
||||
authServerMetadataUrl: string,
|
||||
): Promise<OAuthAuthorizationServerMetadata | null> {
|
||||
try {
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(authServerMetadataUrl);
|
||||
if (!response.ok) {
|
||||
return null;
|
||||
|
||||
@@ -546,7 +546,6 @@ export class ClearcutLogger {
|
||||
let result: LogResponse = {};
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(CLEARCUT_URL, {
|
||||
method: 'POST',
|
||||
body: safeJsonStringify(request),
|
||||
|
||||
@@ -1903,7 +1903,6 @@ export async function connectToMcpServer(
|
||||
acceptHeader = 'application/json';
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(urlToFetch, {
|
||||
method: 'HEAD',
|
||||
headers: {
|
||||
|
||||
@@ -5,27 +5,12 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest';
|
||||
import {
|
||||
isPrivateIp,
|
||||
isPrivateIpAsync,
|
||||
isAddressPrivate,
|
||||
safeLookup,
|
||||
safeFetch,
|
||||
fetchWithTimeout,
|
||||
PrivateIpError,
|
||||
} from './fetch.js';
|
||||
import * as dnsPromises from 'node:dns/promises';
|
||||
import * as dns from 'node:dns';
|
||||
import { isPrivateIp, isAddressPrivate, fetchWithTimeout } from './fetch.js';
|
||||
|
||||
vi.mock('node:dns/promises', () => ({
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
// We need to mock node:dns for safeLookup since it uses the callback API
|
||||
vi.mock('node:dns', () => ({
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock global fetch
|
||||
const originalFetch = global.fetch;
|
||||
global.fetch = vi.fn();
|
||||
@@ -114,150 +99,6 @@ describe('fetch utils', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPrivateIpAsync', () => {
|
||||
it('should identify private IPs directly', async () => {
|
||||
expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify domains resolving to private IPs', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockImplementation(
|
||||
async () =>
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
[{ address: '10.0.0.1', family: 4 }] as any,
|
||||
);
|
||||
expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify domains resolving to public IPs as non-private', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockImplementation(
|
||||
async () =>
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
[{ address: '8.8.8.8', family: 4 }] as any,
|
||||
);
|
||||
expect(await isPrivateIpAsync('http://google.com/')).toBe(false);
|
||||
});
|
||||
|
||||
it('should throw error if DNS resolution fails (fail closed)', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
|
||||
await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow(
|
||||
'Failed to verify if URL resolves to private IP',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return false for invalid URLs instead of throwing verification error', async () => {
|
||||
expect(await isPrivateIpAsync('not-a-url')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('safeLookup', () => {
|
||||
it('should filter out private IPs', async () => {
|
||||
const addresses = [
|
||||
{ address: '8.8.8.8', family: 4 },
|
||||
{ address: '10.0.0.1', family: 4 },
|
||||
];
|
||||
|
||||
vi.mocked(dns.lookup).mockImplementation(((
|
||||
_h: string,
|
||||
_o: dns.LookupOptions,
|
||||
cb: (
|
||||
err: Error | null,
|
||||
addr: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
) => {
|
||||
cb(null, addresses);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
const result = await new Promise<
|
||||
Array<{ address: string; family: number }>
|
||||
>((resolve, reject) => {
|
||||
safeLookup('example.com', { all: true }, (err, filtered) => {
|
||||
if (err) reject(err);
|
||||
else resolve(filtered);
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].address).toBe('8.8.8.8');
|
||||
});
|
||||
|
||||
it('should allow explicit localhost', async () => {
|
||||
const addresses = [{ address: '127.0.0.1', family: 4 }];
|
||||
|
||||
vi.mocked(dns.lookup).mockImplementation(((
|
||||
_h: string,
|
||||
_o: dns.LookupOptions,
|
||||
cb: (
|
||||
err: Error | null,
|
||||
addr: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
) => {
|
||||
cb(null, addresses);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
const result = await new Promise<
|
||||
Array<{ address: string; family: number }>
|
||||
>((resolve, reject) => {
|
||||
safeLookup('localhost', { all: true }, (err, filtered) => {
|
||||
if (err) reject(err);
|
||||
else resolve(filtered);
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].address).toBe('127.0.0.1');
|
||||
});
|
||||
|
||||
it('should error if all resolved IPs are private', async () => {
|
||||
const addresses = [{ address: '10.0.0.1', family: 4 }];
|
||||
|
||||
vi.mocked(dns.lookup).mockImplementation(((
|
||||
_h: string,
|
||||
_o: dns.LookupOptions,
|
||||
cb: (
|
||||
err: Error | null,
|
||||
addr: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
) => {
|
||||
cb(null, addresses);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
await expect(
|
||||
new Promise((resolve, reject) => {
|
||||
safeLookup('malicious.com', { all: true }, (err, filtered) => {
|
||||
if (err) reject(err);
|
||||
else resolve(filtered);
|
||||
});
|
||||
}),
|
||||
).rejects.toThrow(PrivateIpError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('safeFetch', () => {
|
||||
it('should forward to fetch with dispatcher', async () => {
|
||||
vi.mocked(global.fetch).mockResolvedValue(new Response('ok'));
|
||||
|
||||
const response = await safeFetch('https://example.com');
|
||||
expect(response.status).toBe(200);
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
expect.objectContaining({
|
||||
dispatcher: expect.any(Object),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle Refusing to connect errors', async () => {
|
||||
vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError());
|
||||
|
||||
await expect(safeFetch('http://10.0.0.1')).rejects.toThrow(
|
||||
'Access to private network is blocked',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fetchWithTimeout', () => {
|
||||
it('should handle timeouts', async () => {
|
||||
vi.mocked(global.fetch).mockImplementation(
|
||||
@@ -279,13 +120,5 @@ describe('fetch utils', () => {
|
||||
'Request timed out after 50ms',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle private IP errors via handleFetchError', async () => {
|
||||
vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError());
|
||||
|
||||
await expect(fetchWithTimeout('http://10.0.0.1', 1000)).rejects.toThrow(
|
||||
'Access to private network is blocked: http://10.0.0.1',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,37 +6,12 @@
|
||||
|
||||
import { getErrorMessage, isNodeError } from './errors.js';
|
||||
import { URL } from 'node:url';
|
||||
import * as dns from 'node:dns';
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
import ipaddr from 'ipaddr.js';
|
||||
|
||||
const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes
|
||||
const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes
|
||||
|
||||
// Configure default global dispatcher with higher timeouts
|
||||
setGlobalDispatcher(
|
||||
new Agent({
|
||||
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
|
||||
bodyTimeout: DEFAULT_BODY_TIMEOUT,
|
||||
}),
|
||||
);
|
||||
|
||||
// Local extension of RequestInit to support Node.js/undici dispatcher
|
||||
interface NodeFetchInit extends RequestInit {
|
||||
dispatcher?: Agent | ProxyAgent;
|
||||
}
|
||||
|
||||
/**
|
||||
* Error thrown when a connection to a private IP address is blocked for security reasons.
|
||||
*/
|
||||
export class PrivateIpError extends Error {
|
||||
constructor(message = 'Refusing to connect to private IP address') {
|
||||
super(message);
|
||||
this.name = 'PrivateIpError';
|
||||
}
|
||||
}
|
||||
|
||||
export class FetchError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
@@ -48,6 +23,14 @@ export class FetchError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
// Configure default global dispatcher with higher timeouts
|
||||
setGlobalDispatcher(
|
||||
new Agent({
|
||||
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
|
||||
bodyTimeout: DEFAULT_BODY_TIMEOUT,
|
||||
}),
|
||||
);
|
||||
|
||||
/**
|
||||
* Sanitizes a hostname by stripping IPv6 brackets if present.
|
||||
*/
|
||||
@@ -69,53 +52,6 @@ export function isLoopbackHost(hostname: string): boolean {
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* A custom DNS lookup implementation for undici agents that prevents
|
||||
* connection to private IP ranges (SSRF protection).
|
||||
*/
|
||||
export function safeLookup(
|
||||
hostname: string,
|
||||
options: dns.LookupOptions | number | null | undefined,
|
||||
callback: (
|
||||
err: Error | null,
|
||||
addresses: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
): void {
|
||||
// Use the callback-based dns.lookup to match undici's expected signature.
|
||||
// We explicitly handle the 'all' option to ensure we get an array of addresses.
|
||||
const lookupOptions =
|
||||
typeof options === 'number' ? { family: options } : { ...options };
|
||||
const finalOptions = { ...lookupOptions, all: true };
|
||||
|
||||
dns.lookup(hostname, finalOptions, (err, addresses) => {
|
||||
if (err) {
|
||||
callback(err, []);
|
||||
return;
|
||||
}
|
||||
|
||||
const addressArray = Array.isArray(addresses) ? addresses : [];
|
||||
const filtered = addressArray.filter(
|
||||
(addr) => !isAddressPrivate(addr.address) || isLoopbackHost(hostname),
|
||||
);
|
||||
|
||||
if (filtered.length === 0 && addressArray.length > 0) {
|
||||
callback(new PrivateIpError(), []);
|
||||
return;
|
||||
}
|
||||
|
||||
callback(null, filtered);
|
||||
});
|
||||
}
|
||||
|
||||
// Dedicated dispatcher with connection-level SSRF protection (safeLookup)
|
||||
const safeDispatcher = new Agent({
|
||||
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
|
||||
bodyTimeout: DEFAULT_BODY_TIMEOUT,
|
||||
connect: {
|
||||
lookup: safeLookup,
|
||||
},
|
||||
});
|
||||
|
||||
export function isPrivateIp(url: string): boolean {
|
||||
try {
|
||||
const hostname = new URL(url).hostname;
|
||||
@@ -125,37 +61,6 @@ export function isPrivateIp(url: string): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a URL resolves to a private IP address.
|
||||
* Performs DNS resolution to prevent DNS rebinding/SSRF bypasses.
|
||||
*/
|
||||
export async function isPrivateIpAsync(url: string): Promise<boolean> {
|
||||
try {
|
||||
const parsed = new URL(url);
|
||||
const hostname = parsed.hostname;
|
||||
|
||||
// Fast check for literal IPs or localhost
|
||||
if (isAddressPrivate(hostname)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Resolve DNS to check the actual target IP
|
||||
const addresses = await lookup(hostname, { all: true });
|
||||
return addresses.some((addr) => isAddressPrivate(addr.address));
|
||||
} catch (e) {
|
||||
if (
|
||||
e instanceof Error &&
|
||||
e.name === 'TypeError' &&
|
||||
e.message.includes('Invalid URL')
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
throw new Error(`Failed to verify if URL resolves to private IP: ${url}`, {
|
||||
cause: e,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* IANA Benchmark Testing Range (198.18.0.0/15).
|
||||
* Classified as 'unicast' by ipaddr.js but is reserved and should not be
|
||||
@@ -210,72 +115,15 @@ export function isAddressPrivate(address: string): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal helper to map varied fetch errors to a standardized FetchError.
|
||||
* Centralizes security-related error mapping (e.g. PrivateIpError).
|
||||
*/
|
||||
function handleFetchError(error: unknown, url: string): never {
|
||||
if (error instanceof PrivateIpError) {
|
||||
throw new FetchError(
|
||||
`Access to private network is blocked: ${url}`,
|
||||
'ERR_PRIVATE_NETWORK',
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
|
||||
if (error instanceof FetchError) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
throw new FetchError(
|
||||
getErrorMessage(error),
|
||||
isNodeError(error) ? error.code : undefined,
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Enhanced fetch with SSRF protection.
|
||||
* Prevents access to private/internal networks at the connection level.
|
||||
*/
|
||||
export async function safeFetch(
|
||||
input: RequestInfo | URL,
|
||||
init?: RequestInit,
|
||||
): Promise<Response> {
|
||||
const nodeInit: NodeFetchInit = {
|
||||
...init,
|
||||
dispatcher: safeDispatcher,
|
||||
};
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
return await fetch(input, nodeInit);
|
||||
} catch (error) {
|
||||
const url =
|
||||
input instanceof Request
|
||||
? input.url
|
||||
: typeof input === 'string'
|
||||
? input
|
||||
: input.toString();
|
||||
handleFetchError(error, url);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an undici ProxyAgent that incorporates safe DNS lookup.
|
||||
*/
|
||||
export function createSafeProxyAgent(proxyUrl: string): ProxyAgent {
|
||||
return new ProxyAgent({
|
||||
uri: proxyUrl,
|
||||
connect: {
|
||||
lookup: safeLookup,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a fetch with a specified timeout and connection-level SSRF protection.
|
||||
*/
|
||||
export async function fetchWithTimeout(
|
||||
url: string,
|
||||
timeout: number,
|
||||
@@ -294,21 +142,17 @@ export async function fetchWithTimeout(
|
||||
}
|
||||
}
|
||||
|
||||
const nodeInit: NodeFetchInit = {
|
||||
...options,
|
||||
signal: controller.signal,
|
||||
dispatcher: safeDispatcher,
|
||||
};
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
const response = await fetch(url, nodeInit);
|
||||
const response = await fetch(url, {
|
||||
...options,
|
||||
signal: controller.signal,
|
||||
});
|
||||
return response;
|
||||
} catch (error) {
|
||||
if (isNodeError(error) && error.code === 'ABORT_ERR') {
|
||||
throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT');
|
||||
}
|
||||
handleFetchError(error, url.toString());
|
||||
throw new FetchError(getErrorMessage(error), undefined, { cause: error });
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
|
||||
@@ -454,7 +454,6 @@ export async function exchangeCodeForToken(
|
||||
params.append('resource', resource);
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(config.tokenUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -508,7 +507,6 @@ export async function refreshAccessToken(
|
||||
params.append('resource', resource);
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(tokenUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
|
||||
@@ -42,7 +42,6 @@ async function checkForUpdates(
|
||||
const currentVersion = context.extension.packageJSON.version;
|
||||
|
||||
// Fetch extension details from the VSCode Marketplace.
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(
|
||||
'https://marketplace.visualstudio.com/_apis/public/gallery/extensionquery',
|
||||
{
|
||||
|
||||
@@ -356,7 +356,6 @@ describe('IDEServer', () => {
|
||||
});
|
||||
|
||||
it('should reject request without auth token', async () => {
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(`http://localhost:${port}/mcp`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
@@ -371,7 +370,6 @@ describe('IDEServer', () => {
|
||||
});
|
||||
|
||||
it('should allow request with valid auth token', async () => {
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(`http://localhost:${port}/mcp`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -389,7 +387,6 @@ describe('IDEServer', () => {
|
||||
});
|
||||
|
||||
it('should reject request with invalid auth token', async () => {
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(`http://localhost:${port}/mcp`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -416,7 +413,6 @@ describe('IDEServer', () => {
|
||||
];
|
||||
|
||||
for (const header of malformedHeaders) {
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(`http://localhost:${port}/mcp`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
|
||||
+41
-9
@@ -228,23 +228,35 @@ const packageJson = JSON.parse(
|
||||
// Helper to calc hash
|
||||
const sha256 = (content) => createHash('sha256').update(content).digest('hex');
|
||||
|
||||
// Read Main Bundle
|
||||
const geminiBundlePath = join(root, 'bundle/gemini.js');
|
||||
const geminiContent = readFileSync(geminiBundlePath);
|
||||
const geminiHash = sha256(geminiContent);
|
||||
|
||||
const assets = {
|
||||
'gemini.mjs': geminiBundlePath, // Use .js source but map to .mjs for runtime ESM
|
||||
'manifest.json': 'bundle/manifest.json',
|
||||
};
|
||||
|
||||
const manifest = {
|
||||
main: 'gemini.mjs',
|
||||
mainHash: geminiHash,
|
||||
mainHash: '',
|
||||
version: packageJson.version,
|
||||
files: [],
|
||||
};
|
||||
|
||||
// Add all javascript chunks from the bundle directory
|
||||
const jsFiles = globSync('*.js', { cwd: bundleDir });
|
||||
for (const jsFile of jsFiles) {
|
||||
const fsPath = join(bundleDir, jsFile);
|
||||
const content = readFileSync(fsPath);
|
||||
const hash = sha256(content);
|
||||
|
||||
// Node SEA requires the main entry point to be explicitly mapped
|
||||
if (jsFile === 'gemini.js') {
|
||||
assets['gemini.mjs'] = fsPath;
|
||||
manifest.mainHash = hash;
|
||||
} else {
|
||||
// Other chunks need to be mapped exactly as they are named so dynamic imports find them
|
||||
assets[jsFile] = fsPath;
|
||||
manifest.files.push({ key: jsFile, path: jsFile, hash: hash });
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to recursively find files from STAGING
|
||||
function addAssetsFromDir(baseDir, runtimePrefix) {
|
||||
const fullDir = join(stagingDir, baseDir);
|
||||
@@ -346,6 +358,22 @@ const targetBinaryPath = join(targetDir, binaryName);
|
||||
console.log(`Copying node binary from ${nodeBinary} to ${targetBinaryPath}...`);
|
||||
copyFileSync(nodeBinary, targetBinaryPath);
|
||||
|
||||
if (platform === 'darwin') {
|
||||
console.log(`Thinning universal binary for ${arch}...`);
|
||||
try {
|
||||
// Attempt to thin the binary. Will fail safely if it's not a fat binary.
|
||||
runCommand('lipo', [
|
||||
targetBinaryPath,
|
||||
'-thin',
|
||||
arch,
|
||||
'-output',
|
||||
targetBinaryPath,
|
||||
]);
|
||||
} catch (e) {
|
||||
console.log(`Skipping lipo thinning: ${e.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove existing signature using helper
|
||||
removeSignature(targetBinaryPath);
|
||||
|
||||
@@ -357,9 +385,7 @@ if (existsSync(bundleDir)) {
|
||||
|
||||
// Clean up source JS files from output (we only want embedded)
|
||||
const filesToRemove = [
|
||||
'gemini.js',
|
||||
'gemini.mjs',
|
||||
'gemini.js.map',
|
||||
'gemini.mjs.map',
|
||||
'gemini-sea.cjs',
|
||||
'sea-launch.cjs',
|
||||
@@ -373,6 +399,12 @@ filesToRemove.forEach((f) => {
|
||||
if (existsSync(p)) rmSync(p, { recursive: true, force: true });
|
||||
});
|
||||
|
||||
// Remove all chunk and entry .js/.js.map files
|
||||
const jsFilesToRemove = globSync('*.{js,js.map}', { cwd: targetDir });
|
||||
for (const f of jsFilesToRemove) {
|
||||
rmSync(join(targetDir, f));
|
||||
}
|
||||
|
||||
// Remove .sb files from targetDir
|
||||
const sbFilesToRemove = globSync('sandbox-macos-*.sb', { cwd: targetDir });
|
||||
for (const f of sbFilesToRemove) {
|
||||
|
||||
Reference in New Issue
Block a user