mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-23 11:34:44 -07:00
Merge branch 'main' into gemini-cli-headless-monitor
This commit is contained in:
@@ -0,0 +1,143 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { TestRig, checkModelOutputContent } from './test-helper.js';
|
||||
|
||||
/**
 * Integration tests for Plan Mode.
 *
 * Each test drives the CLI through a fresh `TestRig`, then inspects the
 * recorded tool logs to check which tool calls were made and whether they
 * succeeded.
 */
describe('Plan Mode', () => {
  let rig: TestRig;

  beforeEach(() => {
    rig = new TestRig();
  });

  afterEach(async () => await rig.cleanup());

  it('should allow read-only tools but deny write tools in plan mode', async () => {
    await rig.setup(
      'should allow read-only tools but deny write tools in plan mode',
      {
        settings: {
          experimental: { plan: true },
          tools: {
            core: [
              'run_shell_command',
              'list_directory',
              'write_file',
              'read_file',
            ],
          },
        },
      },
    );

    // The prompt asks for a read-only action (listing files) followed by a
    // write action (creating "denied.txt" through a shell command).
    const result = await rig.run({
      approvalMode: 'plan',
      stdin:
        'Please list the files in the current directory, and then attempt to create a new file named "denied.txt" using a shell command.',
    });

    const lsCallFound = await rig.waitForToolCall('list_directory');
    expect(lsCallFound, 'Expected list_directory to be called').toBe(true);

    // The shell tool must never show up as a tool call in plan mode.
    const shellCallFound = await rig.waitForToolCall('run_shell_command');
    expect(
      shellCallFound,
      'Expected run_shell_command not to be called',
    ).toBe(false);

    const toolLogs = rig.readToolLogs();
    const lsLog = toolLogs.find((l) => l.toolRequest.name === 'list_directory');
    expect(
      toolLogs.find((l) => l.toolRequest.name === 'run_shell_command'),
    ).toBeUndefined();

    // Assert presence first so a missing entry fails with a readable message
    // instead of "expected undefined to be true".
    expect(
      lsLog,
      'Expected a list_directory entry in the tool logs',
    ).toBeDefined();
    expect(lsLog?.toolRequest.success).toBe(true);

    checkModelOutputContent(result, {
      expectedContent: ['Plan Mode', 'read-only'],
      testName: 'Plan Mode restrictions test',
    });
  });

  it('should allow write_file only in the plans directory in plan mode', async () => {
    await rig.setup(
      'should allow write_file only in the plans directory in plan mode',
      {
        settings: {
          experimental: { plan: true },
          tools: {
            core: ['write_file', 'read_file', 'list_directory'],
            allowed: ['write_file'],
          },
          general: { defaultApprovalMode: 'plan' },
        },
      },
    );

    // The prompt requests two writes: plan.md inside the plans directory
    // (expected to succeed) and hello.txt in the current directory (expected
    // to be rejected).
    await rig.run({
      approvalMode: 'plan',
      stdin:
        'Create a file called plan.md in the plans directory. Then create a file called hello.txt in the current directory',
    });

    const toolLogs = rig.readToolLogs();
    const writeLogs = toolLogs.filter(
      (l) => l.toolRequest.name === 'write_file',
    );

    const planWrite = writeLogs.find(
      (l) =>
        l.toolRequest.args.includes('plans') &&
        l.toolRequest.args.includes('plan.md'),
    );

    const blockedWrite = writeLogs.find((l) =>
      l.toolRequest.args.includes('hello.txt'),
    );

    // The model is nondeterministic: sometimes the blocked write appears in
    // the tool logs and sometimes it does not. Only check it when present.
    if (blockedWrite) {
      expect(blockedWrite.toolRequest.success).toBe(false);
    }

    expect(
      planWrite,
      'Expected a write_file call for plan.md in the plans directory',
    ).toBeDefined();
    expect(planWrite?.toolRequest.success).toBe(true);
  });

  it('should be able to enter plan mode from default mode', async () => {
    await rig.setup('should be able to enter plan mode from default mode', {
      settings: {
        experimental: { plan: true },
        tools: {
          core: ['enter_plan_mode'],
          allowed: ['enter_plan_mode'],
        },
      },
    });

    // Start in default mode and ask to enter plan mode.
    await rig.run({
      approvalMode: 'default',
      stdin:
        'I want to perform a complex refactoring. Please enter plan mode so we can design it first.',
    });

    const enterPlanCallFound = await rig.waitForToolCall(
      'enter_plan_mode',
      10000,
    );
    expect(enterPlanCallFound, 'Expected enter_plan_mode to be called').toBe(
      true,
    );

    const toolLogs = rig.readToolLogs();
    const enterLog = toolLogs.find(
      (l) => l.toolRequest.name === 'enter_plan_mode',
    );
    expect(
      enterLog,
      'Expected an enter_plan_mode entry in the tool logs',
    ).toBeDefined();
    expect(enterLog?.toolRequest.success).toBe(true);
  });
});
|
||||
@@ -0,0 +1,188 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import * as fs from 'node:fs';
|
||||
import * as os from 'node:os';
|
||||
import * as path from 'node:path';
|
||||
import { ExtensionManager } from './extension-manager.js';
|
||||
import { createTestMergedSettings } from './settings.js';
|
||||
import { createExtension } from '../test-utils/createExtension.js';
|
||||
import { EXTENSIONS_DIRECTORY_NAME } from './extensions/variables.js';
|
||||
|
||||
// Hoisted stub for `homedir`, shared by both module mocks below so a test can
// redirect the home directory with a single `mockReturnValue` call.
const mockHomedir = vi.hoisted(() => vi.fn(() => '/tmp/mock-home'));

// Keep the real `os` module, overriding only `homedir`.
vi.mock('os', async (importOriginal) => ({
  ...(await importOriginal<typeof os>()),
  homedir: mockHomedir,
}));

// Same override for the `homedir` export of the core package.
vi.mock('@google/gemini-cli-core', async (importOriginal) => ({
  ...(await importOriginal<typeof import('@google/gemini-cli-core')>()),
  homedir: mockHomedir,
}));
|
||||
|
||||
/**
 * Tests for `ExtensionManager` extension loading.
 *
 * Every test runs against a fresh temporary home directory (returned by the
 * mocked `homedir`) that contains an empty user extensions directory and a
 * temporary workspace directory.
 */
describe('ExtensionManager', () => {
  let tempHomeDir: string;
  let tempWorkspaceDir: string;
  let userExtensionsDir: string;
  let extensionManager: ExtensionManager;

  beforeEach(() => {
    vi.clearAllMocks();
    tempHomeDir = fs.mkdtempSync(
      path.join(os.tmpdir(), 'gemini-cli-test-home-'),
    );
    tempWorkspaceDir = fs.mkdtempSync(
      path.join(tempHomeDir, 'gemini-cli-test-workspace-'),
    );
    mockHomedir.mockReturnValue(tempHomeDir);
    userExtensionsDir = path.join(tempHomeDir, EXTENSIONS_DIRECTORY_NAME);
    fs.mkdirSync(userExtensionsDir, { recursive: true });

    extensionManager = new ExtensionManager({
      settings: createTestMergedSettings(),
      workspaceDir: tempWorkspaceDir,
      requestConsent: vi.fn().mockResolvedValue(true),
      requestSetting: null,
    });
  });

  afterEach(() => {
    // Best-effort cleanup: a failure to remove the temp tree must not fail
    // the test that just ran.
    try {
      fs.rmSync(tempHomeDir, { recursive: true, force: true });
    } catch (_e) {
      // Ignore
    }
  });

  describe('loadExtensions parallel loading', () => {
    it('should prevent concurrent loading and return the same promise', async () => {
      createExtension({
        extensionsDir: userExtensionsDir,
        name: 'ext1',
        version: '1.0.0',
      });
      createExtension({
        extensionsDir: userExtensionsDir,
        name: 'ext2',
        version: '1.0.0',
      });

      // Call loadExtensions twice concurrently
      const promise1 = extensionManager.loadExtensions();
      const promise2 = extensionManager.loadExtensions();

      // They should resolve to the exact same array
      const [extensions1, extensions2] = await Promise.all([
        promise1,
        promise2,
      ]);

      expect(extensions1).toBe(extensions2);
      expect(extensions1).toHaveLength(2);

      const names = extensions1.map((ext) => ext.name).sort();
      expect(names).toEqual(['ext1', 'ext2']);
    });

    it('should throw an error if loadExtensions is called after it has already resolved', async () => {
      createExtension({
        extensionsDir: userExtensionsDir,
        name: 'ext1',
        version: '1.0.0',
      });

      await extensionManager.loadExtensions();

      await expect(extensionManager.loadExtensions()).rejects.toThrow(
        'Extensions already loaded, only load extensions once.',
      );
    });

    it('should not throw if extension directory does not exist', async () => {
      fs.rmSync(userExtensionsDir, { recursive: true, force: true });

      const extensions = await extensionManager.loadExtensions();
      expect(extensions).toEqual([]);
    });

    it('should throw if there are duplicate extension names', async () => {
      // We manually create two extensions with different dirs but same name in config
      const ext1Dir = path.join(userExtensionsDir, 'ext1-dir');
      const ext2Dir = path.join(userExtensionsDir, 'ext2-dir');
      fs.mkdirSync(ext1Dir, { recursive: true });
      fs.mkdirSync(ext2Dir, { recursive: true });

      const config = JSON.stringify({
        name: 'duplicate-ext',
        version: '1.0.0',
      });
      fs.writeFileSync(path.join(ext1Dir, 'gemini-extension.json'), config);
      fs.writeFileSync(
        path.join(ext1Dir, 'metadata.json'),
        JSON.stringify({ type: 'local', source: ext1Dir }),
      );

      fs.writeFileSync(path.join(ext2Dir, 'gemini-extension.json'), config);
      fs.writeFileSync(
        path.join(ext2Dir, 'metadata.json'),
        JSON.stringify({ type: 'local', source: ext2Dir }),
      );

      await expect(extensionManager.loadExtensions()).rejects.toThrow(
        'Extension with name duplicate-ext already was loaded.',
      );
    });

    it('should wait for loadExtensions to finish when loadExtension is called concurrently', async () => {
      // Create an initial extension that loadExtensions will find
      createExtension({
        extensionsDir: userExtensionsDir,
        name: 'ext1',
        version: '1.0.0',
      });

      // Start the parallel load (it will read ext1)
      const loadAllPromise = extensionManager.loadExtensions();

      // Create a second extension dynamically in a DIFFERENT directory
      // so that loadExtensions (which scans userExtensionsDir) doesn't find it.
      const externalDir = fs.mkdtempSync(
        path.join(os.tmpdir(), 'external-ext-'),
      );

      // This directory lives outside tempHomeDir, so afterEach will not
      // remove it. Clean it up in `finally` so a failed assertion or a
      // rejected load does not leak it.
      try {
        fs.writeFileSync(
          path.join(externalDir, 'gemini-extension.json'),
          JSON.stringify({ name: 'ext2', version: '1.0.0' }),
        );
        fs.writeFileSync(
          path.join(externalDir, 'metadata.json'),
          JSON.stringify({ type: 'local', source: externalDir }),
        );

        // Concurrently call loadExtension (simulating an install or update)
        const loadSinglePromise = extensionManager.loadExtension(externalDir);

        // Wait for both to complete
        await Promise.all([loadAllPromise, loadSinglePromise]);

        // Both extensions should now be present in the loadedExtensions array
        const extensions = extensionManager.getExtensions();
        expect(extensions).toHaveLength(2);
        const names = extensions.map((ext) => ext.name).sort();
        expect(names).toEqual(['ext1', 'ext2']);
      } finally {
        fs.rmSync(externalDir, { recursive: true, force: true });
      }
    });
  });
});
|
||||
@@ -102,6 +102,7 @@ export class ExtensionManager extends ExtensionLoader {
|
||||
private telemetryConfig: Config;
|
||||
private workspaceDir: string;
|
||||
private loadedExtensions: GeminiCLIExtension[] | undefined;
|
||||
private loadingPromise: Promise<GeminiCLIExtension[]> | null = null;
|
||||
|
||||
constructor(options: ExtensionManagerParams) {
|
||||
super(options.eventEmitter);
|
||||
@@ -519,31 +520,103 @@ Would you like to attempt to install via "git clone" instead?`,
|
||||
throw new Error('Extensions already loaded, only load extensions once.');
|
||||
}
|
||||
|
||||
if (this.settings.admin.extensions.enabled === false) {
|
||||
this.loadedExtensions = [];
|
||||
return this.loadedExtensions;
|
||||
if (this.loadingPromise) {
|
||||
return this.loadingPromise;
|
||||
}
|
||||
|
||||
const extensionsDir = ExtensionStorage.getUserExtensionsDir();
|
||||
this.loadedExtensions = [];
|
||||
if (!fs.existsSync(extensionsDir)) {
|
||||
return this.loadedExtensions;
|
||||
}
|
||||
for (const subdir of fs.readdirSync(extensionsDir)) {
|
||||
const extensionDir = path.join(extensionsDir, subdir);
|
||||
await this.loadExtension(extensionDir);
|
||||
}
|
||||
return this.loadedExtensions;
|
||||
this.loadingPromise = (async () => {
|
||||
try {
|
||||
if (this.settings.admin.extensions.enabled === false) {
|
||||
this.loadedExtensions = [];
|
||||
return this.loadedExtensions;
|
||||
}
|
||||
|
||||
const extensionsDir = ExtensionStorage.getUserExtensionsDir();
|
||||
if (!fs.existsSync(extensionsDir)) {
|
||||
this.loadedExtensions = [];
|
||||
return this.loadedExtensions;
|
||||
}
|
||||
|
||||
const subdirs = await fs.promises.readdir(extensionsDir);
|
||||
const extensionPromises = subdirs.map((subdir) => {
|
||||
const extensionDir = path.join(extensionsDir, subdir);
|
||||
return this._buildExtension(extensionDir);
|
||||
});
|
||||
|
||||
const builtExtensionsOrNull = await Promise.all(extensionPromises);
|
||||
const builtExtensions = builtExtensionsOrNull.filter(
|
||||
(ext): ext is GeminiCLIExtension => ext !== null,
|
||||
);
|
||||
|
||||
const seenNames = new Set<string>();
|
||||
for (const ext of builtExtensions) {
|
||||
if (seenNames.has(ext.name)) {
|
||||
throw new Error(
|
||||
`Extension with name ${ext.name} already was loaded.`,
|
||||
);
|
||||
}
|
||||
seenNames.add(ext.name);
|
||||
}
|
||||
|
||||
this.loadedExtensions = builtExtensions;
|
||||
|
||||
await Promise.all(
|
||||
this.loadedExtensions.map((ext) => this.maybeStartExtension(ext)),
|
||||
);
|
||||
|
||||
return this.loadedExtensions;
|
||||
} finally {
|
||||
this.loadingPromise = null;
|
||||
}
|
||||
})();
|
||||
|
||||
return this.loadingPromise;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds `extension` to the list of extensions and starts it if appropriate.
|
||||
*
|
||||
* @internal visible for testing only
|
||||
*/
|
||||
private async loadExtension(
|
||||
async loadExtension(
|
||||
extensionDir: string,
|
||||
): Promise<GeminiCLIExtension | null> {
|
||||
if (this.loadingPromise) {
|
||||
await this.loadingPromise;
|
||||
}
|
||||
this.loadedExtensions ??= [];
|
||||
if (!fs.statSync(extensionDir).isDirectory()) {
|
||||
const extension = await this._buildExtension(extensionDir);
|
||||
if (!extension) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (
|
||||
this.getExtensions().find(
|
||||
(installed) => installed.name === extension.name,
|
||||
)
|
||||
) {
|
||||
throw new Error(
|
||||
`Extension with name ${extension.name} already was loaded.`,
|
||||
);
|
||||
}
|
||||
|
||||
this.loadedExtensions = [...this.loadedExtensions, extension];
|
||||
await this.maybeStartExtension(extension);
|
||||
return extension;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an extension without side effects (does not mutate loadedExtensions or start it).
|
||||
*/
|
||||
private async _buildExtension(
|
||||
extensionDir: string,
|
||||
): Promise<GeminiCLIExtension | null> {
|
||||
try {
|
||||
const stats = await fs.promises.stat(extensionDir);
|
||||
if (!stats.isDirectory()) {
|
||||
return null;
|
||||
}
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -592,13 +665,6 @@ Would you like to attempt to install via "git clone" instead?`,
|
||||
|
||||
try {
|
||||
let config = await this.loadExtensionConfig(effectiveExtensionPath);
|
||||
if (
|
||||
this.getExtensions().find((extension) => extension.name === config.name)
|
||||
) {
|
||||
throw new Error(
|
||||
`Extension with name ${config.name} already was loaded.`,
|
||||
);
|
||||
}
|
||||
|
||||
const extensionId = getExtensionId(config, installMetadata);
|
||||
|
||||
@@ -768,7 +834,7 @@ Would you like to attempt to install via "git clone" instead?`,
|
||||
);
|
||||
}
|
||||
|
||||
const extension: GeminiCLIExtension = {
|
||||
return {
|
||||
name: config.name,
|
||||
version: config.version,
|
||||
path: effectiveExtensionPath,
|
||||
@@ -788,10 +854,6 @@ Would you like to attempt to install via "git clone" instead?`,
|
||||
agents: agentLoadResult.agents,
|
||||
themes: config.themes,
|
||||
};
|
||||
this.loadedExtensions = [...this.loadedExtensions, extension];
|
||||
|
||||
await this.maybeStartExtension(extension);
|
||||
return extension;
|
||||
} catch (e) {
|
||||
debugLogger.error(
|
||||
`Warning: Skipping extension in ${effectiveExtensionPath}: ${getErrorMessage(
|
||||
|
||||
@@ -177,6 +177,14 @@ describe('GeminiAgent', () => {
|
||||
|
||||
expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION);
|
||||
expect(response.authMethods).toHaveLength(3);
|
||||
const geminiAuth = response.authMethods?.find(
|
||||
(m) => m.id === AuthType.USE_GEMINI,
|
||||
);
|
||||
expect(geminiAuth?._meta).toEqual({
|
||||
'api-key': {
|
||||
provider: 'google',
|
||||
},
|
||||
});
|
||||
expect(response.agentCapabilities?.loadSession).toBe(true);
|
||||
});
|
||||
|
||||
@@ -187,6 +195,7 @@ describe('GeminiAgent', () => {
|
||||
|
||||
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
undefined,
|
||||
);
|
||||
expect(mockSettings.setValue).toHaveBeenCalledWith(
|
||||
SettingScope.User,
|
||||
@@ -195,6 +204,25 @@ describe('GeminiAgent', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should authenticate correctly with api-key in _meta', async () => {
|
||||
await agent.authenticate({
|
||||
methodId: AuthType.USE_GEMINI,
|
||||
_meta: {
|
||||
'api-key': 'test-api-key',
|
||||
},
|
||||
} as unknown as acp.AuthenticateRequest);
|
||||
|
||||
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
|
||||
AuthType.USE_GEMINI,
|
||||
'test-api-key',
|
||||
);
|
||||
expect(mockSettings.setValue).toHaveBeenCalledWith(
|
||||
SettingScope.User,
|
||||
'security.auth.selectedType',
|
||||
AuthType.USE_GEMINI,
|
||||
);
|
||||
});
|
||||
|
||||
it('should create a new session', async () => {
|
||||
mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({
|
||||
apiKey: 'test-key',
|
||||
|
||||
@@ -37,12 +37,17 @@ import {
|
||||
partListUnionToString,
|
||||
LlmRole,
|
||||
ApprovalMode,
|
||||
getVersion,
|
||||
convertSessionToClientHistory,
|
||||
} from '@google/gemini-cli-core';
|
||||
import * as acp from '@agentclientprotocol/sdk';
|
||||
import { AcpFileSystemService } from './fileSystemService.js';
|
||||
import { getAcpErrorMessage } from './acpErrors.js';
|
||||
import { Readable, Writable } from 'node:stream';
|
||||
|
||||
/**
 * Type guard for values that carry an optional `_meta` bag.
 *
 * Returns true only when `obj` is a non-null object with a `_meta` key whose
 * value is either `undefined` or a plain (non-null, non-array) object. The
 * value is checked as well as the key so that the narrowed type
 * `{ _meta?: Record<string, unknown> }` is actually true at runtime.
 *
 * @param obj The value to inspect.
 * @returns Whether `obj` can safely be read as `{ _meta?: Record<string, unknown> }`.
 */
function hasMeta(obj: unknown): obj is { _meta?: Record<string, unknown> } {
  if (typeof obj !== 'object' || obj === null || !('_meta' in obj)) {
    return false;
  }
  // The key is known to exist; read it as `unknown` and validate its shape.
  const meta: unknown = (obj as { _meta?: unknown })._meta;
  return (
    meta === undefined ||
    (typeof meta === 'object' && meta !== null && !Array.isArray(meta))
  );
}
|
||||
import type { Content, Part, FunctionCall } from '@google/genai';
|
||||
import type { LoadedSettings } from '../config/settings.js';
|
||||
import { SettingScope, loadSettings } from '../config/settings.js';
|
||||
@@ -81,6 +86,7 @@ export async function runZedIntegration(
|
||||
export class GeminiAgent {
|
||||
private sessions: Map<string, Session> = new Map();
|
||||
private clientCapabilities: acp.ClientCapabilities | undefined;
|
||||
private apiKey: string | undefined;
|
||||
|
||||
constructor(
|
||||
private config: Config,
|
||||
@@ -97,25 +103,35 @@ export class GeminiAgent {
|
||||
{
|
||||
id: AuthType.LOGIN_WITH_GOOGLE,
|
||||
name: 'Log in with Google',
|
||||
description: null,
|
||||
description: 'Log in with your Google account',
|
||||
},
|
||||
{
|
||||
id: AuthType.USE_GEMINI,
|
||||
name: 'Use Gemini API key',
|
||||
description:
|
||||
'Requires setting the `GEMINI_API_KEY` environment variable',
|
||||
name: 'Gemini API key',
|
||||
description: 'Use an API key with Gemini Developer API',
|
||||
_meta: {
|
||||
'api-key': {
|
||||
provider: 'google',
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
id: AuthType.USE_VERTEX_AI,
|
||||
name: 'Vertex AI',
|
||||
description: null,
|
||||
description: 'Use an API key with Vertex AI GenAI API',
|
||||
},
|
||||
];
|
||||
|
||||
await this.config.initialize();
|
||||
const version = await getVersion();
|
||||
return {
|
||||
protocolVersion: acp.PROTOCOL_VERSION,
|
||||
authMethods,
|
||||
agentInfo: {
|
||||
name: 'gemini-cli',
|
||||
title: 'Gemini CLI',
|
||||
version,
|
||||
},
|
||||
agentCapabilities: {
|
||||
loadSession: true,
|
||||
promptCapabilities: {
|
||||
@@ -131,7 +147,8 @@ export class GeminiAgent {
|
||||
};
|
||||
}
|
||||
|
||||
async authenticate({ methodId }: acp.AuthenticateRequest): Promise<void> {
|
||||
async authenticate(req: acp.AuthenticateRequest): Promise<void> {
|
||||
const { methodId } = req;
|
||||
const method = z.nativeEnum(AuthType).parse(methodId);
|
||||
const selectedAuthType = this.settings.merged.security.auth.selectedType;
|
||||
|
||||
@@ -139,17 +156,21 @@ export class GeminiAgent {
|
||||
if (selectedAuthType && selectedAuthType !== method) {
|
||||
await clearCachedCredentialFile();
|
||||
}
|
||||
// Check for api-key in _meta
|
||||
const meta = hasMeta(req) ? req._meta : undefined;
|
||||
const apiKey =
|
||||
typeof meta?.['api-key'] === 'string' ? meta['api-key'] : undefined;
|
||||
|
||||
// Refresh auth with the requested method
|
||||
// This will reuse existing credentials if they're valid,
|
||||
// or perform new authentication if needed
|
||||
try {
|
||||
await this.config.refreshAuth(method);
|
||||
if (apiKey) {
|
||||
this.apiKey = apiKey;
|
||||
}
|
||||
await this.config.refreshAuth(method, apiKey ?? this.apiKey);
|
||||
} catch (e) {
|
||||
throw new acp.RequestError(
|
||||
getErrorStatus(e) || 401,
|
||||
getAcpErrorMessage(e),
|
||||
);
|
||||
throw new acp.RequestError(-32000, getAcpErrorMessage(e));
|
||||
}
|
||||
this.settings.setValue(
|
||||
SettingScope.User,
|
||||
@@ -177,7 +198,7 @@ export class GeminiAgent {
|
||||
let isAuthenticated = false;
|
||||
let authErrorMessage = '';
|
||||
try {
|
||||
await config.refreshAuth(authType);
|
||||
await config.refreshAuth(authType, this.apiKey);
|
||||
isAuthenticated = true;
|
||||
|
||||
// Extra validation for Gemini API key
|
||||
@@ -199,7 +220,7 @@ export class GeminiAgent {
|
||||
|
||||
if (!isAuthenticated) {
|
||||
throw new acp.RequestError(
|
||||
401,
|
||||
-32000,
|
||||
authErrorMessage || 'Authentication required.',
|
||||
);
|
||||
}
|
||||
@@ -302,7 +323,7 @@ export class GeminiAgent {
|
||||
// This satisfies the security requirement to verify the user before executing
|
||||
// potentially unsafe server definitions.
|
||||
try {
|
||||
await config.refreshAuth(selectedAuthType);
|
||||
await config.refreshAuth(selectedAuthType, this.apiKey);
|
||||
} catch (e) {
|
||||
debugLogger.error(`Authentication failed: ${e}`);
|
||||
throw acp.RequestError.authRequired();
|
||||
|
||||
@@ -53,14 +53,14 @@ describe('A2AClientManager', () => {
|
||||
let manager: A2AClientManager;
|
||||
|
||||
// Stable mocks initialized once
|
||||
const sendMessageMock = vi.fn();
|
||||
const sendMessageStreamMock = vi.fn();
|
||||
const getTaskMock = vi.fn();
|
||||
const cancelTaskMock = vi.fn();
|
||||
const getAgentCardMock = vi.fn();
|
||||
const authFetchMock = vi.fn();
|
||||
|
||||
const mockClient = {
|
||||
sendMessage: sendMessageMock,
|
||||
sendMessageStream: sendMessageStreamMock,
|
||||
getTask: getTaskMock,
|
||||
cancelTask: cancelTaskMock,
|
||||
getAgentCard: getAgentCardMock,
|
||||
@@ -178,75 +178,91 @@ describe('A2AClientManager', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessage', () => {
|
||||
describe('sendMessageStream', () => {
|
||||
beforeEach(async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent');
|
||||
});
|
||||
|
||||
it('should send a message to the correct agent', async () => {
|
||||
sendMessageMock.mockResolvedValue({
|
||||
it('should send a message and return a stream', async () => {
|
||||
const mockResult = {
|
||||
kind: 'message',
|
||||
messageId: 'a',
|
||||
parts: [],
|
||||
role: 'agent',
|
||||
} as SendMessageResult);
|
||||
} as SendMessageResult;
|
||||
|
||||
await manager.sendMessage('TestAgent', 'Hello');
|
||||
expect(sendMessageMock).toHaveBeenCalledWith(
|
||||
sendMessageStreamMock.mockReturnValue(
|
||||
(async function* () {
|
||||
yield mockResult;
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
const results = [];
|
||||
for await (const res of stream) {
|
||||
results.push(res);
|
||||
}
|
||||
|
||||
expect(results).toEqual([mockResult]);
|
||||
expect(sendMessageStreamMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
message: expect.anything(),
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use contextId and taskId when provided', async () => {
|
||||
sendMessageMock.mockResolvedValue({
|
||||
kind: 'message',
|
||||
messageId: 'a',
|
||||
parts: [],
|
||||
role: 'agent',
|
||||
} as SendMessageResult);
|
||||
sendMessageStreamMock.mockReturnValue(
|
||||
(async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'a',
|
||||
parts: [],
|
||||
role: 'agent',
|
||||
} as SendMessageResult;
|
||||
})(),
|
||||
);
|
||||
|
||||
const expectedContextId = 'user-context-id';
|
||||
const expectedTaskId = 'user-task-id';
|
||||
|
||||
await manager.sendMessage('TestAgent', 'Hello', {
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
|
||||
contextId: expectedContextId,
|
||||
taskId: expectedTaskId,
|
||||
});
|
||||
|
||||
const call = sendMessageMock.mock.calls[0][0];
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
}
|
||||
|
||||
const call = sendMessageStreamMock.mock.calls[0][0];
|
||||
expect(call.message.contextId).toBe(expectedContextId);
|
||||
expect(call.message.taskId).toBe(expectedTaskId);
|
||||
});
|
||||
|
||||
it('should return result from client', async () => {
|
||||
const mockResult = {
|
||||
contextId: 'server-context-id',
|
||||
id: 'ctx-1',
|
||||
kind: 'task',
|
||||
status: { state: 'working' },
|
||||
};
|
||||
|
||||
sendMessageMock.mockResolvedValueOnce(mockResult as SendMessageResult);
|
||||
|
||||
const response = await manager.sendMessage('TestAgent', 'Hello');
|
||||
|
||||
expect(response).toEqual(mockResult);
|
||||
});
|
||||
|
||||
it('should throw prefixed error on failure', async () => {
|
||||
sendMessageMock.mockRejectedValueOnce(new Error('Network error'));
|
||||
sendMessageStreamMock.mockImplementationOnce(() => {
|
||||
throw new Error('Network error');
|
||||
});
|
||||
|
||||
await expect(manager.sendMessage('TestAgent', 'Hello')).rejects.toThrow(
|
||||
'A2AClient SendMessage Error [TestAgent]: Network error',
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
}
|
||||
}).rejects.toThrow(
|
||||
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network error',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if the agent is not found', async () => {
|
||||
await expect(
|
||||
manager.sendMessage('NonExistentAgent', 'Hello'),
|
||||
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
|
||||
const stream = manager.sendMessageStream('NonExistentAgent', 'Hello');
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
}
|
||||
}).rejects.toThrow("Agent 'NonExistentAgent' not found.");
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -4,7 +4,14 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { AgentCard, Message, MessageSendParams, Task } from '@a2a-js/sdk';
|
||||
import type {
|
||||
AgentCard,
|
||||
Message,
|
||||
MessageSendParams,
|
||||
Task,
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
} from '@a2a-js/sdk';
|
||||
import {
|
||||
type Client,
|
||||
ClientFactory,
|
||||
@@ -18,7 +25,11 @@ import {
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
export type SendMessageResult = Message | Task;
|
||||
export type SendMessageResult =
|
||||
| Message
|
||||
| Task
|
||||
| TaskStatusUpdateEvent
|
||||
| TaskArtifactUpdateEvent;
|
||||
|
||||
/**
|
||||
* Manages A2A clients and caches loaded agent information.
|
||||
@@ -110,18 +121,18 @@ export class A2AClientManager {
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends a message to a loaded agent.
|
||||
* Sends a message to a loaded agent and returns a stream of responses.
|
||||
* @param agentName The name of the agent to send the message to.
|
||||
* @param message The message content.
|
||||
* @param options Optional context and task IDs to maintain conversation state.
|
||||
* @returns The response from the agent (Message or Task).
|
||||
* @returns An async iterable of responses from the agent (Message or Task).
|
||||
* @throws Error if the agent returns an error response.
|
||||
*/
|
||||
async sendMessage(
|
||||
async *sendMessageStream(
|
||||
agentName: string,
|
||||
message: string,
|
||||
options?: { contextId?: string; taskId?: string },
|
||||
): Promise<SendMessageResult> {
|
||||
options?: { contextId?: string; taskId?: string; signal?: AbortSignal },
|
||||
): AsyncIterable<SendMessageResult> {
|
||||
const client = this.clients.get(agentName);
|
||||
if (!client) {
|
||||
throw new Error(`Agent '${agentName}' not found.`);
|
||||
@@ -136,20 +147,19 @@ export class A2AClientManager {
|
||||
contextId: options?.contextId,
|
||||
taskId: options?.taskId,
|
||||
},
|
||||
configuration: {
|
||||
blocking: true,
|
||||
},
|
||||
};
|
||||
|
||||
try {
|
||||
return await client.sendMessage(messageParams);
|
||||
yield* client.sendMessageStream(messageParams, {
|
||||
signal: options?.signal,
|
||||
});
|
||||
} catch (error: unknown) {
|
||||
const prefix = `A2AClient SendMessage Error [${agentName}]`;
|
||||
const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`;
|
||||
if (error instanceof Error) {
|
||||
throw new Error(`${prefix}: ${error.message}`, { cause: error });
|
||||
}
|
||||
throw new Error(
|
||||
`${prefix}: Unexpected error during sendMessage: ${String(error)}`,
|
||||
`${prefix}: Unexpected error during sendMessageStream: ${String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,12 +7,40 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import {
|
||||
extractMessageText,
|
||||
extractTaskText,
|
||||
extractIdsFromResponse,
|
||||
isTerminalState,
|
||||
A2AResultReassembler,
|
||||
} from './a2aUtils.js';
|
||||
import type { Message, Task, TextPart, DataPart, FilePart } from '@a2a-js/sdk';
|
||||
import type { SendMessageResult } from './a2a-client-manager.js';
|
||||
import type {
|
||||
Message,
|
||||
Task,
|
||||
TextPart,
|
||||
DataPart,
|
||||
FilePart,
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
} from '@a2a-js/sdk';
|
||||
|
||||
describe('a2aUtils', () => {
|
||||
describe('isTerminalState', () => {
|
||||
it('should return true for completed, failed, canceled, and rejected', () => {
|
||||
expect(isTerminalState('completed')).toBe(true);
|
||||
expect(isTerminalState('failed')).toBe(true);
|
||||
expect(isTerminalState('canceled')).toBe(true);
|
||||
expect(isTerminalState('rejected')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for working, submitted, input-required, auth-required, and unknown', () => {
|
||||
expect(isTerminalState('working')).toBe(false);
|
||||
expect(isTerminalState('submitted')).toBe(false);
|
||||
expect(isTerminalState('input-required')).toBe(false);
|
||||
expect(isTerminalState('auth-required')).toBe(false);
|
||||
expect(isTerminalState('unknown')).toBe(false);
|
||||
expect(isTerminalState(undefined)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractIdsFromResponse', () => {
|
||||
it('should extract IDs from a message response', () => {
|
||||
const message: Message = {
|
||||
@@ -25,7 +53,11 @@ describe('a2aUtils', () => {
|
||||
};
|
||||
|
||||
const result = extractIdsFromResponse(message);
|
||||
expect(result).toEqual({ contextId: 'ctx-1', taskId: 'task-1' });
|
||||
expect(result).toEqual({
|
||||
contextId: 'ctx-1',
|
||||
taskId: 'task-1',
|
||||
clearTaskId: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('should extract IDs from an in-progress task response', () => {
|
||||
@@ -37,7 +69,76 @@ describe('a2aUtils', () => {
|
||||
};
|
||||
|
||||
const result = extractIdsFromResponse(task);
|
||||
expect(result).toEqual({ contextId: 'ctx-2', taskId: 'task-2' });
|
||||
expect(result).toEqual({
|
||||
contextId: 'ctx-2',
|
||||
taskId: 'task-2',
|
||||
clearTaskId: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('should set clearTaskId true for terminal task response', () => {
|
||||
const task: Task = {
|
||||
id: 'task-3',
|
||||
contextId: 'ctx-3',
|
||||
kind: 'task',
|
||||
status: { state: 'completed' },
|
||||
};
|
||||
|
||||
const result = extractIdsFromResponse(task);
|
||||
expect(result.clearTaskId).toBe(true);
|
||||
});
|
||||
|
||||
it('should set clearTaskId true for terminal status update', () => {
|
||||
const update = {
|
||||
kind: 'status-update',
|
||||
contextId: 'ctx-4',
|
||||
taskId: 'task-4',
|
||||
final: true,
|
||||
status: { state: 'failed' },
|
||||
};
|
||||
|
||||
const result = extractIdsFromResponse(
|
||||
update as unknown as TaskStatusUpdateEvent,
|
||||
);
|
||||
expect(result.contextId).toBe('ctx-4');
|
||||
expect(result.taskId).toBe('task-4');
|
||||
expect(result.clearTaskId).toBe(true);
|
||||
});
|
||||
|
||||
it('should extract IDs from an artifact-update event', () => {
|
||||
const update = {
|
||||
kind: 'artifact-update',
|
||||
taskId: 'task-5',
|
||||
contextId: 'ctx-5',
|
||||
artifact: {
|
||||
artifactId: 'art-1',
|
||||
parts: [{ kind: 'text', text: 'artifact content' }],
|
||||
},
|
||||
} as unknown as TaskArtifactUpdateEvent;
|
||||
|
||||
const result = extractIdsFromResponse(update);
|
||||
expect(result).toEqual({
|
||||
contextId: 'ctx-5',
|
||||
taskId: 'task-5',
|
||||
clearTaskId: false,
|
||||
});
|
||||
});
|
||||
|
||||
it('should extract taskId from status update event', () => {
|
||||
const update = {
|
||||
kind: 'status-update',
|
||||
taskId: 'task-6',
|
||||
contextId: 'ctx-6',
|
||||
final: false,
|
||||
status: { state: 'working' },
|
||||
};
|
||||
|
||||
const result = extractIdsFromResponse(
|
||||
update as unknown as TaskStatusUpdateEvent,
|
||||
);
|
||||
expect(result.taskId).toBe('task-6');
|
||||
expect(result.contextId).toBe('ctx-6');
|
||||
expect(result.clearTaskId).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -123,49 +224,65 @@ describe('a2aUtils', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractTaskText', () => {
|
||||
it('should extract basic task info (clean)', () => {
|
||||
const task: Task = {
|
||||
id: 'task-1',
|
||||
contextId: 'ctx-1',
|
||||
kind: 'task',
|
||||
describe('A2AResultReassembler', () => {
|
||||
it('should reassemble sequential messages and incremental artifacts', () => {
|
||||
const reassembler = new A2AResultReassembler();
|
||||
|
||||
// 1. Initial status
|
||||
reassembler.update({
|
||||
kind: 'status-update',
|
||||
taskId: 't1',
|
||||
status: {
|
||||
state: 'working',
|
||||
message: {
|
||||
kind: 'message',
|
||||
role: 'agent',
|
||||
messageId: 'm1',
|
||||
parts: [{ kind: 'text', text: 'Processing...' } as TextPart],
|
||||
},
|
||||
parts: [{ kind: 'text', text: 'Analyzing...' }],
|
||||
} as Message,
|
||||
},
|
||||
};
|
||||
} as unknown as SendMessageResult);
|
||||
|
||||
const result = extractTaskText(task);
|
||||
expect(result).not.toContain('ID: task-1');
|
||||
expect(result).not.toContain('State: working');
|
||||
expect(result).toBe('Processing...');
|
||||
});
|
||||
// 2. First artifact chunk
|
||||
reassembler.update({
|
||||
kind: 'artifact-update',
|
||||
taskId: 't1',
|
||||
append: false,
|
||||
artifact: {
|
||||
artifactId: 'a1',
|
||||
name: 'Code',
|
||||
parts: [{ kind: 'text', text: 'print(' }],
|
||||
},
|
||||
} as unknown as SendMessageResult);
|
||||
|
||||
it('should extract artifacts with headers', () => {
|
||||
const task: Task = {
|
||||
id: 'task-1',
|
||||
contextId: 'ctx-1',
|
||||
kind: 'task',
|
||||
status: { state: 'completed' },
|
||||
artifacts: [
|
||||
{
|
||||
artifactId: 'art-1',
|
||||
name: 'Report',
|
||||
parts: [{ kind: 'text', text: 'This is the report.' } as TextPart],
|
||||
},
|
||||
],
|
||||
};
|
||||
// 3. Second status
|
||||
reassembler.update({
|
||||
kind: 'status-update',
|
||||
taskId: 't1',
|
||||
status: {
|
||||
state: 'working',
|
||||
message: {
|
||||
kind: 'message',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Processing...' }],
|
||||
} as Message,
|
||||
},
|
||||
} as unknown as SendMessageResult);
|
||||
|
||||
const result = extractTaskText(task);
|
||||
expect(result).toContain('Artifact (Report):');
|
||||
expect(result).toContain('This is the report.');
|
||||
expect(result).not.toContain('Artifacts:');
|
||||
expect(result).not.toContain(' - Name: Report');
|
||||
// 4. Second artifact chunk (append)
|
||||
reassembler.update({
|
||||
kind: 'artifact-update',
|
||||
taskId: 't1',
|
||||
append: true,
|
||||
artifact: {
|
||||
artifactId: 'a1',
|
||||
parts: [{ kind: 'text', text: '"Done")' }],
|
||||
},
|
||||
} as unknown as SendMessageResult);
|
||||
|
||||
const output = reassembler.toString();
|
||||
expect(output).toBe(
|
||||
'Analyzing...\n\nProcessing...\n\nArtifact (Code):\nprint("Done")',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,12 +6,120 @@
|
||||
|
||||
import type {
|
||||
Message,
|
||||
Task,
|
||||
Part,
|
||||
TextPart,
|
||||
DataPart,
|
||||
FilePart,
|
||||
Artifact,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
} from '@a2a-js/sdk';
|
||||
import type { SendMessageResult } from './a2a-client-manager.js';
|
||||
|
||||
/**
 * Reassembles incremental A2A streaming updates into a coherent result.
 *
 * Status/message text is kept as a sequential log, while artifact text is
 * accumulated per artifact id. `toString()` renders the message log first,
 * followed by every artifact that has textual content.
 */
export class A2AResultReassembler {
  /** Extracted text of each logged message, in arrival order. */
  private messageLog: string[] = [];
  /** Latest artifact object per artifact id; Map order is first-seen order. */
  private artifacts = new Map<string, Artifact>();
  /** Text fragments received per artifact id; concatenated when rendering. */
  private artifactChunks = new Map<string, string[]>();

  /**
   * Processes a new chunk from the A2A stream.
   *
   * - `status-update` / `task`: logs the status message, if any.
   * - `artifact-update`: appends to, or replaces, the artifact's parts and text
   *   depending on `chunk.append`.
   * - `task`: additionally replaces stored state for every artifact it carries.
   * - `message`: logs the message itself.
   *
   * Chunks without a `kind` property and chunks of any other kind are ignored.
   *
   * @param chunk - One item yielded by the message stream.
   */
  update(chunk: SendMessageResult): void {
    if (!('kind' in chunk)) return;

    switch (chunk.kind) {
      case 'status-update':
        this.pushMessage(chunk.status?.message);
        break;

      case 'artifact-update':
        if (chunk.artifact) {
          const id = chunk.artifact.artifactId;
          const existing = this.artifacts.get(id);

          if (chunk.append && existing) {
            // Clone incoming parts so later mutation of the chunk by the
            // caller cannot affect the stored artifact.
            for (const part of chunk.artifact.parts) {
              existing.parts.push(structuredClone(part));
            }
          } else {
            // Either a replacement or the first time this id is seen.
            this.artifacts.set(id, structuredClone(chunk.artifact));
          }

          const newText = extractPartsText(chunk.artifact.parts, '');
          let chunks = this.artifactChunks.get(id);
          if (!chunks) {
            chunks = [];
            this.artifactChunks.set(id, chunks);
          }
          if (!chunk.append) {
            // A non-append update supersedes everything received so far.
            chunks.length = 0;
          }
          chunks.push(newText);
        }
        break;

      case 'task':
        this.pushMessage(chunk.status?.message);
        if (chunk.artifacts) {
          for (const art of chunk.artifacts) {
            this.artifacts.set(art.artifactId, structuredClone(art));
            this.artifactChunks.set(art.artifactId, [
              extractPartsText(art.parts, ''),
            ]);
          }
        }
        break;

      case 'message': {
        this.pushMessage(chunk);
        break;
      }

      default:
        break;
    }
  }

  /**
   * Appends the message's text to the log. Messages with no extractable text
   * are skipped, as is text identical to the most recent log entry.
   */
  private pushMessage(message: Message | undefined): void {
    if (!message) return;
    const text = extractPartsText(message.parts, '\n');
    if (text && this.messageLog[this.messageLog.length - 1] !== text) {
      this.messageLog.push(text);
    }
  }

  /**
   * Returns a human-readable string representation of the current reassembled state.
   *
   * Logged messages are separated by blank lines, followed by one
   * `Artifact (<name>):` / `Artifact:` section per artifact. Artifacts whose
   * accumulated text is empty are omitted so that no dangling header is
   * rendered.
   */
  toString(): string {
    const joinedMessages = this.messageLog.join('\n\n');

    const artifactsOutput = Array.from(this.artifacts.entries())
      .map(([id, artifact]) => {
        const content = (this.artifactChunks.get(id) ?? []).join('');
        // Nothing textual to show for this artifact: skip its header too.
        if (!content) return '';
        const header = artifact.name
          ? `Artifact (${artifact.name}):`
          : 'Artifact:';
        return `${header}\n${content}`;
      })
      .filter(Boolean)
      .join('\n\n');

    if (joinedMessages && artifactsOutput) {
      return `${joinedMessages}\n\n${artifactsOutput}`;
    }
    return joinedMessages || artifactsOutput;
  }
}
|
||||
|
||||
/**
|
||||
* Extracts a human-readable text representation from a Message object.
|
||||
@@ -22,7 +130,23 @@ export function extractMessageText(message: Message | undefined): string {
|
||||
return '';
|
||||
}
|
||||
|
||||
return extractPartsText(message.parts);
|
||||
return extractPartsText(message.parts, '\n');
|
||||
}
|
||||
|
||||
/**
 * Collects the text of every part and joins the non-empty results with
 * `separator`. Yields an empty string when `parts` is missing or empty.
 */
function extractPartsText(
  parts: Part[] | undefined,
  separator: string,
): string {
  const pieces: string[] = [];
  for (const part of parts ?? []) {
    const text = extractPartText(part);
    // Parts that produce no text are dropped rather than joined as blanks.
    if (text) {
      pieces.push(text);
    }
  }
  return pieces.join(separator);
}
|
||||
|
||||
/**
|
||||
@@ -52,50 +176,6 @@ function extractPartText(part: Part): string {
|
||||
return '';
|
||||
}
|
||||
|
||||
/**
 * Builds a readable text summary of a Task.
 *
 * The status message text (when present) comes first, followed by one section
 * per artifact that has textual content. Each artifact section is prefixed
 * with `Artifact (<name>):` or, for unnamed artifacts, `Artifact:`. Task id
 * and state are not included. Sections are separated by a blank line.
 */
export function extractTaskText(task: Task): string {
  const statusText = extractMessageText(task.status?.message);

  const artifactSections = (task.artifacts ?? []).flatMap((artifact) => {
    const body = extractPartsText(artifact.parts);
    if (!body) {
      // Artifacts without any text contribute nothing, not even a header.
      return [];
    }
    const title = artifact.name ? `Artifact (${artifact.name}):` : 'Artifact:';
    return [`${title}\n${body}`];
  });

  const sections = statusText
    ? [statusText, ...artifactSections]
    : artifactSections;
  return sections.join('\n\n');
}
|
||||
|
||||
/**
 * Produces newline-separated text for a list of parts, skipping parts that
 * yield no text. A missing or empty list gives an empty string.
 */
function extractPartsText(parts: Part[] | undefined): string {
  if (!parts?.length) {
    return '';
  }
  const lines: string[] = [];
  parts.forEach((part) => {
    const text = extractPartText(part);
    if (text) {
      lines.push(text);
    }
  });
  return lines.join('\n');
}
|
||||
|
||||
// Type Guards
|
||||
|
||||
function isTextPart(part: Part): part is TextPart {
|
||||
@@ -110,36 +190,58 @@ function isFilePart(part: Part): part is FilePart {
|
||||
return part.kind === 'file';
|
||||
}
|
||||
|
||||
/**
 * Type guard that narrows a stream result to a status-update event by
 * checking its `kind` discriminator.
 */
function isStatusUpdateEvent(
  result: SendMessageResult,
): result is TaskStatusUpdateEvent {
  const { kind } = result;
  return kind === 'status-update';
}
|
||||
|
||||
/**
 * Tells whether a task state is terminal.
 *
 * @param state - The task state to classify; may be undefined.
 * @returns `true` for `completed`, `failed`, `canceled` and `rejected`;
 *   `false` for every other value, including `undefined`.
 */
export function isTerminalState(state: TaskState | undefined): boolean {
  switch (state) {
    case 'completed':
    case 'failed':
    case 'canceled':
    case 'rejected':
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
/**
|
||||
* Extracts contextId and taskId from a Message, Task, or Update response.
|
||||
* Follows the pattern from the A2A CLI sample to maintain conversational continuity.
|
||||
*/
|
||||
export function extractIdsFromResponse(result: Message | Task): {
|
||||
export function extractIdsFromResponse(result: SendMessageResult): {
|
||||
contextId?: string;
|
||||
taskId?: string;
|
||||
clearTaskId?: boolean;
|
||||
} {
|
||||
let contextId: string | undefined;
|
||||
let taskId: string | undefined;
|
||||
let clearTaskId = false;
|
||||
|
||||
if (result.kind === 'message') {
|
||||
taskId = result.taskId;
|
||||
contextId = result.contextId;
|
||||
} else if (result.kind === 'task') {
|
||||
taskId = result.id;
|
||||
contextId = result.contextId;
|
||||
|
||||
// If the task is in a final state (and not input-required), we clear the taskId
|
||||
// so that the next interaction starts a fresh task (or keeps context without being bound to the old task).
|
||||
if (
|
||||
result.status &&
|
||||
result.status.state !== 'input-required' &&
|
||||
(result.status.state === 'completed' ||
|
||||
result.status.state === 'failed' ||
|
||||
result.status.state === 'canceled')
|
||||
) {
|
||||
taskId = undefined;
|
||||
if ('kind' in result) {
|
||||
const kind = result.kind;
|
||||
if (kind === 'message' || kind === 'artifact-update') {
|
||||
taskId = result.taskId;
|
||||
contextId = result.contextId;
|
||||
} else if (kind === 'task') {
|
||||
taskId = result.id;
|
||||
contextId = result.contextId;
|
||||
if (isTerminalState(result.status?.state)) {
|
||||
clearTaskId = true;
|
||||
}
|
||||
} else if (isStatusUpdateEvent(result)) {
|
||||
taskId = result.taskId;
|
||||
contextId = result.contextId;
|
||||
// Note: We ignore the 'final' flag here per A2A protocol best practices,
|
||||
// as a stream can close while a task is still in a 'working' state.
|
||||
if (isTerminalState(result.status?.state)) {
|
||||
clearTaskId = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { contextId, taskId };
|
||||
return { contextId, taskId, clearTaskId };
|
||||
}
|
||||
|
||||
@@ -14,7 +14,10 @@ import {
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { RemoteAgentInvocation } from './remote-invocation.js';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import {
|
||||
A2AClientManager,
|
||||
type SendMessageResult,
|
||||
} from './a2a-client-manager.js';
|
||||
import type { RemoteAgentDefinition } from './types.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
@@ -41,7 +44,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
const mockClientManager = {
|
||||
getClient: vi.fn(),
|
||||
loadAgent: vi.fn(),
|
||||
sendMessage: vi.fn(),
|
||||
sendMessageStream: vi.fn(),
|
||||
};
|
||||
const mockMessageBus = createMockMessageBus();
|
||||
|
||||
@@ -78,12 +81,16 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
it('uses "Get Started!" default when query is missing during execution', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.sendMessage.mockResolvedValue({
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello' }],
|
||||
});
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello' }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
@@ -92,10 +99,10 @@ describe('RemoteAgentInvocation', () => {
|
||||
);
|
||||
await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(mockClientManager.sendMessage).toHaveBeenCalledWith(
|
||||
expect(mockClientManager.sendMessageStream).toHaveBeenCalledWith(
|
||||
'test-agent',
|
||||
'Get Started!',
|
||||
expect.any(Object),
|
||||
expect.objectContaining({ signal: expect.any(Object) }),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -113,12 +120,16 @@ describe('RemoteAgentInvocation', () => {
|
||||
describe('Execution Logic', () => {
|
||||
it('should lazy load the agent with ADCHandler if not present', async () => {
|
||||
mockClientManager.getClient.mockReturnValue(undefined);
|
||||
mockClientManager.sendMessage.mockResolvedValue({
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello' }],
|
||||
});
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello' }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
@@ -141,12 +152,16 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
it('should not load the agent if already present', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.sendMessage.mockResolvedValue({
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello' }],
|
||||
});
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello' }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
@@ -164,14 +179,18 @@ describe('RemoteAgentInvocation', () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
|
||||
// First call return values
|
||||
mockClientManager.sendMessage.mockResolvedValueOnce({
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Response 1' }],
|
||||
contextId: 'ctx-1',
|
||||
taskId: 'task-1',
|
||||
});
|
||||
mockClientManager.sendMessageStream.mockImplementationOnce(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Response 1' }],
|
||||
contextId: 'ctx-1',
|
||||
taskId: 'task-1',
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const invocation1 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
@@ -184,21 +203,25 @@ describe('RemoteAgentInvocation', () => {
|
||||
// Execute first time
|
||||
const result1 = await invocation1.execute(new AbortController().signal);
|
||||
expect(result1.returnDisplay).toBe('Response 1');
|
||||
expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith(
|
||||
expect(mockClientManager.sendMessageStream).toHaveBeenLastCalledWith(
|
||||
'test-agent',
|
||||
'first',
|
||||
{ contextId: undefined, taskId: undefined },
|
||||
{ contextId: undefined, taskId: undefined, signal: expect.any(Object) },
|
||||
);
|
||||
|
||||
// Prepare for second call with simulated state persistence
|
||||
mockClientManager.sendMessage.mockResolvedValueOnce({
|
||||
kind: 'message',
|
||||
messageId: 'msg-2',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Response 2' }],
|
||||
contextId: 'ctx-1',
|
||||
taskId: 'task-2',
|
||||
});
|
||||
mockClientManager.sendMessageStream.mockImplementationOnce(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-2',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Response 2' }],
|
||||
contextId: 'ctx-1',
|
||||
taskId: 'task-2',
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const invocation2 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
@@ -210,21 +233,25 @@ describe('RemoteAgentInvocation', () => {
|
||||
const result2 = await invocation2.execute(new AbortController().signal);
|
||||
expect(result2.returnDisplay).toBe('Response 2');
|
||||
|
||||
expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith(
|
||||
expect(mockClientManager.sendMessageStream).toHaveBeenLastCalledWith(
|
||||
'test-agent',
|
||||
'second',
|
||||
{ contextId: 'ctx-1', taskId: 'task-1' }, // Used state from first call
|
||||
{ contextId: 'ctx-1', taskId: 'task-1', signal: expect.any(Object) }, // Used state from first call
|
||||
);
|
||||
|
||||
// Third call: Task completes
|
||||
mockClientManager.sendMessage.mockResolvedValueOnce({
|
||||
kind: 'task',
|
||||
id: 'task-2',
|
||||
contextId: 'ctx-1',
|
||||
status: { state: 'completed', message: undefined },
|
||||
artifacts: [],
|
||||
history: [],
|
||||
});
|
||||
mockClientManager.sendMessageStream.mockImplementationOnce(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'task',
|
||||
id: 'task-2',
|
||||
contextId: 'ctx-1',
|
||||
status: { state: 'completed', message: undefined },
|
||||
artifacts: [],
|
||||
history: [],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const invocation3 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
@@ -236,12 +263,16 @@ describe('RemoteAgentInvocation', () => {
|
||||
await invocation3.execute(new AbortController().signal);
|
||||
|
||||
// Fourth call: Should start new task (taskId undefined)
|
||||
mockClientManager.sendMessage.mockResolvedValueOnce({
|
||||
kind: 'message',
|
||||
messageId: 'msg-3',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'New Task' }],
|
||||
});
|
||||
mockClientManager.sendMessageStream.mockImplementationOnce(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-3',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'New Task' }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const invocation4 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
@@ -252,17 +283,84 @@ describe('RemoteAgentInvocation', () => {
|
||||
);
|
||||
await invocation4.execute(new AbortController().signal);
|
||||
|
||||
expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith(
|
||||
expect(mockClientManager.sendMessageStream).toHaveBeenLastCalledWith(
|
||||
'test-agent',
|
||||
'fourth',
|
||||
{ contextId: 'ctx-1', taskId: undefined }, // taskId cleared!
|
||||
{ contextId: 'ctx-1', taskId: undefined, signal: expect.any(Object) }, // taskId cleared!
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle streaming updates and reassemble output', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello' }],
|
||||
};
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Hello World' }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const updateOutput = vi.fn();
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
await invocation.execute(new AbortController().signal, updateOutput);
|
||||
|
||||
expect(updateOutput).toHaveBeenCalledWith('Hello');
|
||||
expect(updateOutput).toHaveBeenCalledWith('Hello\n\nHello World');
|
||||
});
|
||||
|
||||
it('should abort when signal is aborted during streaming', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
const controller = new AbortController();
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Partial' }],
|
||||
};
|
||||
// Simulate abort between chunks
|
||||
controller.abort();
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-2',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Partial response continued' }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
const result = await invocation.execute(controller.signal);
|
||||
|
||||
expect(result.error).toBeDefined();
|
||||
expect(result.error?.message).toContain('Operation aborted');
|
||||
});
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.sendMessage.mockRejectedValue(
|
||||
new Error('Network error'),
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
if (Math.random() < 0) yield {} as unknown as SendMessageResult;
|
||||
throw new Error('Network error');
|
||||
},
|
||||
);
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
@@ -282,15 +380,19 @@ describe('RemoteAgentInvocation', () => {
|
||||
it('should use a2a helpers for extracting text', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
// Mock a complex message part that needs extraction
|
||||
mockClientManager.sendMessage.mockResolvedValue({
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [
|
||||
{ kind: 'text', text: 'Extracted text' },
|
||||
{ kind: 'data', data: { foo: 'bar' } },
|
||||
],
|
||||
});
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-1',
|
||||
role: 'agent',
|
||||
parts: [
|
||||
{ kind: 'text', text: 'Extracted text' },
|
||||
{ kind: 'data', data: { foo: 'bar' } },
|
||||
],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
@@ -304,6 +406,105 @@ describe('RemoteAgentInvocation', () => {
|
||||
// Just check that text is present, exact formatting depends on helper
|
||||
expect(result.returnDisplay).toContain('Extracted text');
|
||||
});
|
||||
|
||||
it('should handle mixed response types during streaming (TaskStatusUpdateEvent + Message)', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'status-update',
|
||||
taskId: 'task-1',
|
||||
contextId: 'ctx-1',
|
||||
final: false,
|
||||
status: {
|
||||
state: 'working',
|
||||
message: {
|
||||
kind: 'message',
|
||||
role: 'agent',
|
||||
messageId: 'm1',
|
||||
parts: [{ kind: 'text', text: 'Thinking...' }],
|
||||
},
|
||||
},
|
||||
};
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'msg-final',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Final Answer' }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const updateOutput = vi.fn();
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
const result = await invocation.execute(
|
||||
new AbortController().signal,
|
||||
updateOutput,
|
||||
);
|
||||
|
||||
expect(updateOutput).toHaveBeenCalledWith('Thinking...');
|
||||
expect(updateOutput).toHaveBeenCalledWith('Thinking...\n\nFinal Answer');
|
||||
expect(result.returnDisplay).toBe('Thinking...\n\nFinal Answer');
|
||||
});
|
||||
|
||||
it('should handle artifact reassembly with append: true', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'status-update',
|
||||
taskId: 'task-1',
|
||||
status: {
|
||||
state: 'working',
|
||||
message: {
|
||||
kind: 'message',
|
||||
role: 'agent',
|
||||
parts: [{ kind: 'text', text: 'Generating...' }],
|
||||
},
|
||||
},
|
||||
};
|
||||
yield {
|
||||
kind: 'artifact-update',
|
||||
taskId: 'task-1',
|
||||
append: false,
|
||||
artifact: {
|
||||
artifactId: 'art-1',
|
||||
name: 'Result',
|
||||
parts: [{ kind: 'text', text: 'Part 1' }],
|
||||
},
|
||||
};
|
||||
yield {
|
||||
kind: 'artifact-update',
|
||||
taskId: 'task-1',
|
||||
append: true,
|
||||
artifact: {
|
||||
artifactId: 'art-1',
|
||||
parts: [{ kind: 'text', text: ' Part 2' }],
|
||||
},
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const updateOutput = vi.fn();
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
await invocation.execute(new AbortController().signal, updateOutput);
|
||||
|
||||
expect(updateOutput).toHaveBeenCalledWith('Generating...');
|
||||
expect(updateOutput).toHaveBeenCalledWith(
|
||||
'Generating...\n\nArtifact (Result):\nPart 1',
|
||||
);
|
||||
expect(updateOutput).toHaveBeenCalledWith(
|
||||
'Generating...\n\nArtifact (Result):\nPart 1 Part 2',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Confirmations', () => {
|
||||
|
||||
@@ -18,14 +18,12 @@ import type {
|
||||
} from './types.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import {
|
||||
extractMessageText,
|
||||
extractTaskText,
|
||||
extractIdsFromResponse,
|
||||
} from './a2aUtils.js';
|
||||
import { extractIdsFromResponse, A2AResultReassembler } from './a2aUtils.js';
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { AnsiOutput } from '../utils/terminalSerializer.js';
|
||||
import type { SendMessageResult } from './a2a-client-manager.js';
|
||||
|
||||
/**
|
||||
* Authentication handler implementation using Google Application Default Credentials (ADC).
|
||||
@@ -123,10 +121,14 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
};
|
||||
}
|
||||
|
||||
async execute(_signal: AbortSignal): Promise<ToolResult> {
|
||||
async execute(
|
||||
_signal: AbortSignal,
|
||||
updateOutput?: (output: string | AnsiOutput) => void,
|
||||
): Promise<ToolResult> {
|
||||
// 1. Ensure the agent is loaded (cached by manager)
|
||||
// We assume the user has provided an access token via some mechanism (TODO),
|
||||
// or we rely on ADC.
|
||||
const reassembler = new A2AResultReassembler();
|
||||
try {
|
||||
const priorState = RemoteAgentInvocation.sessionState.get(
|
||||
this.definition.name,
|
||||
@@ -146,49 +148,73 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
|
||||
const message = this.params.query;
|
||||
|
||||
const response = await this.clientManager.sendMessage(
|
||||
const stream = this.clientManager.sendMessageStream(
|
||||
this.definition.name,
|
||||
message,
|
||||
{
|
||||
contextId: this.contextId,
|
||||
taskId: this.taskId,
|
||||
signal: _signal,
|
||||
},
|
||||
);
|
||||
|
||||
// Extracts IDs, taskID will be undefined if the task is completed/failed/canceled.
|
||||
const { contextId, taskId } = extractIdsFromResponse(response);
|
||||
let finalResponse: SendMessageResult | undefined;
|
||||
|
||||
this.contextId = contextId ?? this.contextId;
|
||||
this.taskId = taskId;
|
||||
for await (const chunk of stream) {
|
||||
if (_signal.aborted) {
|
||||
throw new Error('Operation aborted');
|
||||
}
|
||||
finalResponse = chunk;
|
||||
reassembler.update(chunk);
|
||||
|
||||
if (updateOutput) {
|
||||
updateOutput(reassembler.toString());
|
||||
}
|
||||
|
||||
const {
|
||||
contextId: newContextId,
|
||||
taskId: newTaskId,
|
||||
clearTaskId,
|
||||
} = extractIdsFromResponse(chunk);
|
||||
|
||||
if (newContextId) {
|
||||
this.contextId = newContextId;
|
||||
}
|
||||
|
||||
this.taskId = clearTaskId ? undefined : (newTaskId ?? this.taskId);
|
||||
}
|
||||
|
||||
if (!finalResponse) {
|
||||
throw new Error('No response from remote agent.');
|
||||
}
|
||||
|
||||
const finalOutput = reassembler.toString();
|
||||
|
||||
debugLogger.debug(
|
||||
`[RemoteAgent] Final response from ${this.definition.name}:\n${JSON.stringify(finalResponse, null, 2)}`,
|
||||
);
|
||||
|
||||
return {
|
||||
llmContent: [{ text: finalOutput }],
|
||||
returnDisplay: finalOutput,
|
||||
};
|
||||
} catch (error: unknown) {
|
||||
const partialOutput = reassembler.toString();
|
||||
const errorMessage = `Error calling remote agent: ${error instanceof Error ? error.message : String(error)}`;
|
||||
const fullDisplay = partialOutput
|
||||
? `${partialOutput}\n\n${errorMessage}`
|
||||
: errorMessage;
|
||||
return {
|
||||
llmContent: [{ text: fullDisplay }],
|
||||
returnDisplay: fullDisplay,
|
||||
error: { message: errorMessage },
|
||||
};
|
||||
} finally {
|
||||
// Persist state even on partial failures or aborts to maintain conversational continuity.
|
||||
RemoteAgentInvocation.sessionState.set(this.definition.name, {
|
||||
contextId: this.contextId,
|
||||
taskId: this.taskId,
|
||||
});
|
||||
|
||||
// Extract the output text
|
||||
const outputText =
|
||||
response.kind === 'task'
|
||||
? extractTaskText(response)
|
||||
: response.kind === 'message'
|
||||
? extractMessageText(response)
|
||||
: JSON.stringify(response);
|
||||
|
||||
debugLogger.debug(
|
||||
`[RemoteAgent] Response from ${this.definition.name}:\n${JSON.stringify(response, null, 2)}`,
|
||||
);
|
||||
|
||||
return {
|
||||
llmContent: [{ text: outputText }],
|
||||
returnDisplay: outputText,
|
||||
};
|
||||
} catch (error: unknown) {
|
||||
const errorMessage = `Error calling remote agent: ${error instanceof Error ? error.message : String(error)}`;
|
||||
return {
|
||||
llmContent: [{ text: errorMessage }],
|
||||
returnDisplay: errorMessage,
|
||||
error: { message: errorMessage },
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,6 +95,7 @@ const mockConfig = {
|
||||
getNoBrowser: () => false,
|
||||
getProxy: () => 'http://test.proxy.com:8080',
|
||||
isBrowserLaunchSuppressed: () => false,
|
||||
getExperimentalZedIntegration: () => false,
|
||||
} as unknown as Config;
|
||||
|
||||
// Mock fetch globally
|
||||
|
||||
@@ -271,9 +271,12 @@ async function initOauthClient(
|
||||
|
||||
await triggerPostAuthCallbacks(client.credentials);
|
||||
} else {
|
||||
const userConsent = await getConsentForOauth('');
|
||||
if (!userConsent) {
|
||||
throw new FatalCancellationError('Authentication cancelled by user.');
|
||||
// In Zed integration, we skip the interactive consent and directly open the browser
|
||||
if (!config.getExperimentalZedIntegration()) {
|
||||
const userConsent = await getConsentForOauth('');
|
||||
if (!userConsent) {
|
||||
throw new FatalCancellationError('Authentication cancelled by user.');
|
||||
}
|
||||
}
|
||||
|
||||
const webLogin = await authWithWeb(client);
|
||||
|
||||
@@ -499,6 +499,7 @@ describe('Server Config (config.ts)', () => {
|
||||
expect(createContentGeneratorConfig).toHaveBeenCalledWith(
|
||||
config,
|
||||
authType,
|
||||
undefined,
|
||||
);
|
||||
// Verify that contentGeneratorConfig is updated
|
||||
expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig);
|
||||
|
||||
@@ -1126,7 +1126,7 @@ export class Config {
|
||||
return this.contentGenerator;
|
||||
}
|
||||
|
||||
async refreshAuth(authMethod: AuthType) {
|
||||
async refreshAuth(authMethod: AuthType, apiKey?: string) {
|
||||
// Reset availability service when switching auth
|
||||
this.modelAvailabilityService.reset();
|
||||
|
||||
@@ -1152,6 +1152,7 @@ export class Config {
|
||||
const newContentGeneratorConfig = await createContentGeneratorConfig(
|
||||
this,
|
||||
authMethod,
|
||||
apiKey,
|
||||
);
|
||||
this.contentGenerator = await createContentGenerator(
|
||||
newContentGeneratorConfig,
|
||||
|
||||
@@ -90,9 +90,13 @@ export type ContentGeneratorConfig = {
|
||||
export async function createContentGeneratorConfig(
|
||||
config: Config,
|
||||
authType: AuthType | undefined,
|
||||
apiKey?: string,
|
||||
): Promise<ContentGeneratorConfig> {
|
||||
const geminiApiKey =
|
||||
process.env['GEMINI_API_KEY'] || (await loadApiKey()) || undefined;
|
||||
apiKey ||
|
||||
process.env['GEMINI_API_KEY'] ||
|
||||
(await loadApiKey()) ||
|
||||
undefined;
|
||||
const googleApiKey = process.env['GOOGLE_API_KEY'] || undefined;
|
||||
const googleCloudProject =
|
||||
process.env['GOOGLE_CLOUD_PROJECT'] ||
|
||||
|
||||
Reference in New Issue
Block a user