diff --git a/integration-tests/plan-mode.test.ts b/integration-tests/plan-mode.test.ts new file mode 100644 index 0000000000..784bb890a0 --- /dev/null +++ b/integration-tests/plan-mode.test.ts @@ -0,0 +1,143 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { TestRig, checkModelOutputContent } from './test-helper.js'; + +describe('Plan Mode', () => { + let rig: TestRig; + + beforeEach(() => { + rig = new TestRig(); + }); + + afterEach(async () => await rig.cleanup()); + + it('should allow read-only tools but deny write tools in plan mode', async () => { + await rig.setup( + 'should allow read-only tools but deny write tools in plan mode', + { + settings: { + experimental: { plan: true }, + tools: { + core: [ + 'run_shell_command', + 'list_directory', + 'write_file', + 'read_file', + ], + }, + }, + }, + ); + + // We use a prompt that asks for both a read-only action and a write action. + // "List files" (read-only) followed by "touch denied.txt" (write). 
+ const result = await rig.run({ + approvalMode: 'plan', + stdin: + 'Please list the files in the current directory, and then attempt to create a new file named "denied.txt" using a shell command.', + }); + + const lsCallFound = await rig.waitForToolCall('list_directory'); + expect(lsCallFound, 'Expected list_directory to be called').toBe(true); + + const shellCallFound = await rig.waitForToolCall('run_shell_command'); + expect(shellCallFound, 'Expected run_shell_command to fail').toBe(false); + + const toolLogs = rig.readToolLogs(); + const lsLog = toolLogs.find((l) => l.toolRequest.name === 'list_directory'); + expect( + toolLogs.find((l) => l.toolRequest.name === 'run_shell_command'), + ).toBeUndefined(); + + expect(lsLog?.toolRequest.success).toBe(true); + + checkModelOutputContent(result, { + expectedContent: ['Plan Mode', 'read-only'], + testName: 'Plan Mode restrictions test', + }); + }); + + it('should allow write_file only in the plans directory in plan mode', async () => { + await rig.setup( + 'should allow write_file only in the plans directory in plan mode', + { + settings: { + experimental: { plan: true }, + tools: { + core: ['write_file', 'read_file', 'list_directory'], + allowed: ['write_file'], + }, + general: { defaultApprovalMode: 'plan' }, + }, + }, + ); + + // We ask the agent to write plan.md in the plans directory, which should trigger a successful write_file call. + // We also ask for a write outside the plans directory to verify that it is denied. + await rig.run({ + approvalMode: 'plan', + stdin: + 'Create a file called plan.md in the plans directory. 
Then create a file called hello.txt in the current directory', + }); + + const toolLogs = rig.readToolLogs(); + const writeLogs = toolLogs.filter( + (l) => l.toolRequest.name === 'write_file', + ); + + const planWrite = writeLogs.find( + (l) => + l.toolRequest.args.includes('plans') && + l.toolRequest.args.includes('plan.md'), + ); + + const blockedWrite = writeLogs.find((l) => + l.toolRequest.args.includes('hello.txt'), + ); + + // The model is non-deterministic: sometimes a blocked write appears in the tool logs and sometimes it doesn't. + if (blockedWrite) { + expect(blockedWrite?.toolRequest.success).toBe(false); + } + + expect(planWrite?.toolRequest.success).toBe(true); + }); + + it('should be able to enter plan mode from default mode', async () => { + await rig.setup('should be able to enter plan mode from default mode', { + settings: { + experimental: { plan: true }, + tools: { + core: ['enter_plan_mode'], + allowed: ['enter_plan_mode'], + }, + }, + }); + + // Start in default mode and ask to enter plan mode. + await rig.run({ + approvalMode: 'default', + stdin: + 'I want to perform a complex refactoring. 
Please enter plan mode so we can design it first.', + }); + + const enterPlanCallFound = await rig.waitForToolCall( + 'enter_plan_mode', + 10000, + ); + expect(enterPlanCallFound, 'Expected enter_plan_mode to be called').toBe( + true, + ); + + const toolLogs = rig.readToolLogs(); + const enterLog = toolLogs.find( + (l) => l.toolRequest.name === 'enter_plan_mode', + ); + expect(enterLog?.toolRequest.success).toBe(true); + }); +}); diff --git a/packages/cli/src/config/extension-manager.test.ts b/packages/cli/src/config/extension-manager.test.ts new file mode 100644 index 0000000000..4ab52e24b5 --- /dev/null +++ b/packages/cli/src/config/extension-manager.test.ts @@ -0,0 +1,188 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; +import * as fs from 'node:fs'; +import * as os from 'node:os'; +import * as path from 'node:path'; +import { ExtensionManager } from './extension-manager.js'; +import { createTestMergedSettings } from './settings.js'; +import { createExtension } from '../test-utils/createExtension.js'; +import { EXTENSIONS_DIRECTORY_NAME } from './extensions/variables.js'; + +const mockHomedir = vi.hoisted(() => vi.fn(() => '/tmp/mock-home')); + +vi.mock('os', async (importOriginal) => { + const mockedOs = await importOriginal(); + return { + ...mockedOs, + homedir: mockHomedir, + }; +}); + +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + homedir: mockHomedir, + }; +}); + +describe('ExtensionManager', () => { + let tempHomeDir: string; + let tempWorkspaceDir: string; + let userExtensionsDir: string; + let extensionManager: ExtensionManager; + + beforeEach(() => { + vi.clearAllMocks(); + tempHomeDir = fs.mkdtempSync( + path.join(os.tmpdir(), 'gemini-cli-test-home-'), + ); + tempWorkspaceDir = fs.mkdtempSync( + path.join(tempHomeDir, 
'gemini-cli-test-workspace-'), + ); + mockHomedir.mockReturnValue(tempHomeDir); + userExtensionsDir = path.join(tempHomeDir, EXTENSIONS_DIRECTORY_NAME); + fs.mkdirSync(userExtensionsDir, { recursive: true }); + + extensionManager = new ExtensionManager({ + settings: createTestMergedSettings(), + workspaceDir: tempWorkspaceDir, + requestConsent: vi.fn().mockResolvedValue(true), + requestSetting: null, + }); + }); + + afterEach(() => { + try { + fs.rmSync(tempHomeDir, { recursive: true, force: true }); + } catch (_e) { + // Ignore + } + }); + + describe('loadExtensions parallel loading', () => { + it('should prevent concurrent loading and return the same promise', async () => { + createExtension({ + extensionsDir: userExtensionsDir, + name: 'ext1', + version: '1.0.0', + }); + createExtension({ + extensionsDir: userExtensionsDir, + name: 'ext2', + version: '1.0.0', + }); + + // Call loadExtensions twice concurrently + const promise1 = extensionManager.loadExtensions(); + const promise2 = extensionManager.loadExtensions(); + + // They should resolve to the exact same array + const [extensions1, extensions2] = await Promise.all([ + promise1, + promise2, + ]); + + expect(extensions1).toBe(extensions2); + expect(extensions1).toHaveLength(2); + + const names = extensions1.map((ext) => ext.name).sort(); + expect(names).toEqual(['ext1', 'ext2']); + }); + + it('should throw an error if loadExtensions is called after it has already resolved', async () => { + createExtension({ + extensionsDir: userExtensionsDir, + name: 'ext1', + version: '1.0.0', + }); + + await extensionManager.loadExtensions(); + + await expect(extensionManager.loadExtensions()).rejects.toThrow( + 'Extensions already loaded, only load extensions once.', + ); + }); + + it('should not throw if extension directory does not exist', async () => { + fs.rmSync(userExtensionsDir, { recursive: true, force: true }); + + const extensions = await extensionManager.loadExtensions(); + expect(extensions).toEqual([]); + }); 
+ + it('should throw if there are duplicate extension names', async () => { + // We manually create two extensions with different dirs but same name in config + const ext1Dir = path.join(userExtensionsDir, 'ext1-dir'); + const ext2Dir = path.join(userExtensionsDir, 'ext2-dir'); + fs.mkdirSync(ext1Dir, { recursive: true }); + fs.mkdirSync(ext2Dir, { recursive: true }); + + const config = JSON.stringify({ + name: 'duplicate-ext', + version: '1.0.0', + }); + fs.writeFileSync(path.join(ext1Dir, 'gemini-extension.json'), config); + fs.writeFileSync( + path.join(ext1Dir, 'metadata.json'), + JSON.stringify({ type: 'local', source: ext1Dir }), + ); + + fs.writeFileSync(path.join(ext2Dir, 'gemini-extension.json'), config); + fs.writeFileSync( + path.join(ext2Dir, 'metadata.json'), + JSON.stringify({ type: 'local', source: ext2Dir }), + ); + + await expect(extensionManager.loadExtensions()).rejects.toThrow( + 'Extension with name duplicate-ext already was loaded.', + ); + }); + + it('should wait for loadExtensions to finish when loadExtension is called concurrently', async () => { + // Create an initial extension that loadExtensions will find + createExtension({ + extensionsDir: userExtensionsDir, + name: 'ext1', + version: '1.0.0', + }); + + // Start the parallel load (it will read ext1) + const loadAllPromise = extensionManager.loadExtensions(); + + // Create a second extension dynamically in a DIFFERENT directory + // so that loadExtensions (which scans userExtensionsDir) doesn't find it. 
+ const externalDir = fs.mkdtempSync( + path.join(os.tmpdir(), 'external-ext-'), + ); + fs.writeFileSync( + path.join(externalDir, 'gemini-extension.json'), + JSON.stringify({ name: 'ext2', version: '1.0.0' }), + ); + fs.writeFileSync( + path.join(externalDir, 'metadata.json'), + JSON.stringify({ type: 'local', source: externalDir }), + ); + + // Concurrently call loadExtension (simulating an install or update) + const loadSinglePromise = extensionManager.loadExtension(externalDir); + + // Wait for both to complete + await Promise.all([loadAllPromise, loadSinglePromise]); + + // Both extensions should now be present in the loadedExtensions array + const extensions = extensionManager.getExtensions(); + expect(extensions).toHaveLength(2); + const names = extensions.map((ext) => ext.name).sort(); + expect(names).toEqual(['ext1', 'ext2']); + + fs.rmSync(externalDir, { recursive: true, force: true }); + }); + }); +}); diff --git a/packages/cli/src/config/extension-manager.ts b/packages/cli/src/config/extension-manager.ts index 8d3b5fa15f..93ad3f3536 100644 --- a/packages/cli/src/config/extension-manager.ts +++ b/packages/cli/src/config/extension-manager.ts @@ -102,6 +102,7 @@ export class ExtensionManager extends ExtensionLoader { private telemetryConfig: Config; private workspaceDir: string; private loadedExtensions: GeminiCLIExtension[] | undefined; + private loadingPromise: Promise | null = null; constructor(options: ExtensionManagerParams) { super(options.eventEmitter); @@ -519,31 +520,103 @@ Would you like to attempt to install via "git clone" instead?`, throw new Error('Extensions already loaded, only load extensions once.'); } - if (this.settings.admin.extensions.enabled === false) { - this.loadedExtensions = []; - return this.loadedExtensions; + if (this.loadingPromise) { + return this.loadingPromise; } - const extensionsDir = ExtensionStorage.getUserExtensionsDir(); - this.loadedExtensions = []; - if (!fs.existsSync(extensionsDir)) { - return 
this.loadedExtensions; - } - for (const subdir of fs.readdirSync(extensionsDir)) { - const extensionDir = path.join(extensionsDir, subdir); - await this.loadExtension(extensionDir); - } - return this.loadedExtensions; + this.loadingPromise = (async () => { + try { + if (this.settings.admin.extensions.enabled === false) { + this.loadedExtensions = []; + return this.loadedExtensions; + } + + const extensionsDir = ExtensionStorage.getUserExtensionsDir(); + if (!fs.existsSync(extensionsDir)) { + this.loadedExtensions = []; + return this.loadedExtensions; + } + + const subdirs = await fs.promises.readdir(extensionsDir); + const extensionPromises = subdirs.map((subdir) => { + const extensionDir = path.join(extensionsDir, subdir); + return this._buildExtension(extensionDir); + }); + + const builtExtensionsOrNull = await Promise.all(extensionPromises); + const builtExtensions = builtExtensionsOrNull.filter( + (ext): ext is GeminiCLIExtension => ext !== null, + ); + + const seenNames = new Set(); + for (const ext of builtExtensions) { + if (seenNames.has(ext.name)) { + throw new Error( + `Extension with name ${ext.name} already was loaded.`, + ); + } + seenNames.add(ext.name); + } + + this.loadedExtensions = builtExtensions; + + await Promise.all( + this.loadedExtensions.map((ext) => this.maybeStartExtension(ext)), + ); + + return this.loadedExtensions; + } finally { + this.loadingPromise = null; + } + })(); + + return this.loadingPromise; } /** * Adds `extension` to the list of extensions and starts it if appropriate. 
+ * + * @internal visible for testing only */ - private async loadExtension( + async loadExtension( extensionDir: string, ): Promise { + if (this.loadingPromise) { + await this.loadingPromise; + } this.loadedExtensions ??= []; - if (!fs.statSync(extensionDir).isDirectory()) { + const extension = await this._buildExtension(extensionDir); + if (!extension) { + return null; + } + + if ( + this.getExtensions().find( + (installed) => installed.name === extension.name, + ) + ) { + throw new Error( + `Extension with name ${extension.name} already was loaded.`, + ); + } + + this.loadedExtensions = [...this.loadedExtensions, extension]; + await this.maybeStartExtension(extension); + return extension; + } + + /** + * Builds an extension without side effects (does not mutate loadedExtensions or start it). + */ + private async _buildExtension( + extensionDir: string, + ): Promise { + try { + const stats = await fs.promises.stat(extensionDir); + if (!stats.isDirectory()) { + return null; + } + } catch { return null; } @@ -592,13 +665,6 @@ Would you like to attempt to install via "git clone" instead?`, try { let config = await this.loadExtensionConfig(effectiveExtensionPath); - if ( - this.getExtensions().find((extension) => extension.name === config.name) - ) { - throw new Error( - `Extension with name ${config.name} already was loaded.`, - ); - } const extensionId = getExtensionId(config, installMetadata); @@ -768,7 +834,7 @@ Would you like to attempt to install via "git clone" instead?`, ); } - const extension: GeminiCLIExtension = { + return { name: config.name, version: config.version, path: effectiveExtensionPath, @@ -788,10 +854,6 @@ Would you like to attempt to install via "git clone" instead?`, agents: agentLoadResult.agents, themes: config.themes, }; - this.loadedExtensions = [...this.loadedExtensions, extension]; - - await this.maybeStartExtension(extension); - return extension; } catch (e) { debugLogger.error( `Warning: Skipping extension in 
${effectiveExtensionPath}: ${getErrorMessage( diff --git a/packages/cli/src/zed-integration/zedIntegration.test.ts b/packages/cli/src/zed-integration/zedIntegration.test.ts index f1cb22bfda..37da3035c3 100644 --- a/packages/cli/src/zed-integration/zedIntegration.test.ts +++ b/packages/cli/src/zed-integration/zedIntegration.test.ts @@ -177,6 +177,14 @@ describe('GeminiAgent', () => { expect(response.protocolVersion).toBe(acp.PROTOCOL_VERSION); expect(response.authMethods).toHaveLength(3); + const geminiAuth = response.authMethods?.find( + (m) => m.id === AuthType.USE_GEMINI, + ); + expect(geminiAuth?._meta).toEqual({ + 'api-key': { + provider: 'google', + }, + }); expect(response.agentCapabilities?.loadSession).toBe(true); }); @@ -187,6 +195,7 @@ describe('GeminiAgent', () => { expect(mockConfig.refreshAuth).toHaveBeenCalledWith( AuthType.LOGIN_WITH_GOOGLE, + undefined, ); expect(mockSettings.setValue).toHaveBeenCalledWith( SettingScope.User, @@ -195,6 +204,25 @@ describe('GeminiAgent', () => { ); }); + it('should authenticate correctly with api-key in _meta', async () => { + await agent.authenticate({ + methodId: AuthType.USE_GEMINI, + _meta: { + 'api-key': 'test-api-key', + }, + } as unknown as acp.AuthenticateRequest); + + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.USE_GEMINI, + 'test-api-key', + ); + expect(mockSettings.setValue).toHaveBeenCalledWith( + SettingScope.User, + 'security.auth.selectedType', + AuthType.USE_GEMINI, + ); + }); + it('should create a new session', async () => { mockConfig.getContentGeneratorConfig = vi.fn().mockReturnValue({ apiKey: 'test-key', diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index f6c0a63349..d4f1b27b92 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -37,12 +37,17 @@ import { partListUnionToString, LlmRole, ApprovalMode, + getVersion, 
convertSessionToClientHistory, } from '@google/gemini-cli-core'; import * as acp from '@agentclientprotocol/sdk'; import { AcpFileSystemService } from './fileSystemService.js'; import { getAcpErrorMessage } from './acpErrors.js'; import { Readable, Writable } from 'node:stream'; + +function hasMeta(obj: unknown): obj is { _meta?: Record } { + return typeof obj === 'object' && obj !== null && '_meta' in obj; +} import type { Content, Part, FunctionCall } from '@google/genai'; import type { LoadedSettings } from '../config/settings.js'; import { SettingScope, loadSettings } from '../config/settings.js'; @@ -81,6 +86,7 @@ export async function runZedIntegration( export class GeminiAgent { private sessions: Map = new Map(); private clientCapabilities: acp.ClientCapabilities | undefined; + private apiKey: string | undefined; constructor( private config: Config, @@ -97,25 +103,35 @@ export class GeminiAgent { { id: AuthType.LOGIN_WITH_GOOGLE, name: 'Log in with Google', - description: null, + description: 'Log in with your Google account', }, { id: AuthType.USE_GEMINI, - name: 'Use Gemini API key', - description: - 'Requires setting the `GEMINI_API_KEY` environment variable', + name: 'Gemini API key', + description: 'Use an API key with Gemini Developer API', + _meta: { + 'api-key': { + provider: 'google', + }, + }, }, { id: AuthType.USE_VERTEX_AI, name: 'Vertex AI', - description: null, + description: 'Use an API key with Vertex AI GenAI API', }, ]; await this.config.initialize(); + const version = await getVersion(); return { protocolVersion: acp.PROTOCOL_VERSION, authMethods, + agentInfo: { + name: 'gemini-cli', + title: 'Gemini CLI', + version, + }, agentCapabilities: { loadSession: true, promptCapabilities: { @@ -131,7 +147,8 @@ export class GeminiAgent { }; } - async authenticate({ methodId }: acp.AuthenticateRequest): Promise { + async authenticate(req: acp.AuthenticateRequest): Promise { + const { methodId } = req; const method = 
z.nativeEnum(AuthType).parse(methodId); const selectedAuthType = this.settings.merged.security.auth.selectedType; @@ -139,17 +156,21 @@ export class GeminiAgent { if (selectedAuthType && selectedAuthType !== method) { await clearCachedCredentialFile(); } + // Check for api-key in _meta + const meta = hasMeta(req) ? req._meta : undefined; + const apiKey = + typeof meta?.['api-key'] === 'string' ? meta['api-key'] : undefined; // Refresh auth with the requested method // This will reuse existing credentials if they're valid, // or perform new authentication if needed try { - await this.config.refreshAuth(method); + if (apiKey) { + this.apiKey = apiKey; + } + await this.config.refreshAuth(method, apiKey ?? this.apiKey); } catch (e) { - throw new acp.RequestError( - getErrorStatus(e) || 401, - getAcpErrorMessage(e), - ); + throw new acp.RequestError(-32000, getAcpErrorMessage(e)); } this.settings.setValue( SettingScope.User, @@ -177,7 +198,7 @@ export class GeminiAgent { let isAuthenticated = false; let authErrorMessage = ''; try { - await config.refreshAuth(authType); + await config.refreshAuth(authType, this.apiKey); isAuthenticated = true; // Extra validation for Gemini API key @@ -199,7 +220,7 @@ export class GeminiAgent { if (!isAuthenticated) { throw new acp.RequestError( - 401, + -32000, authErrorMessage || 'Authentication required.', ); } @@ -302,7 +323,7 @@ export class GeminiAgent { // This satisfies the security requirement to verify the user before executing // potentially unsafe server definitions. 
try { - await config.refreshAuth(selectedAuthType); + await config.refreshAuth(selectedAuthType, this.apiKey); } catch (e) { debugLogger.error(`Authentication failed: ${e}`); throw acp.RequestError.authRequired(); diff --git a/packages/core/src/agents/a2a-client-manager.test.ts b/packages/core/src/agents/a2a-client-manager.test.ts index 42e31d2405..58e68759fe 100644 --- a/packages/core/src/agents/a2a-client-manager.test.ts +++ b/packages/core/src/agents/a2a-client-manager.test.ts @@ -53,14 +53,14 @@ describe('A2AClientManager', () => { let manager: A2AClientManager; // Stable mocks initialized once - const sendMessageMock = vi.fn(); + const sendMessageStreamMock = vi.fn(); const getTaskMock = vi.fn(); const cancelTaskMock = vi.fn(); const getAgentCardMock = vi.fn(); const authFetchMock = vi.fn(); const mockClient = { - sendMessage: sendMessageMock, + sendMessageStream: sendMessageStreamMock, getTask: getTaskMock, cancelTask: cancelTaskMock, getAgentCard: getAgentCardMock, @@ -178,75 +178,91 @@ describe('A2AClientManager', () => { }); }); - describe('sendMessage', () => { + describe('sendMessageStream', () => { beforeEach(async () => { await manager.loadAgent('TestAgent', 'http://test.agent'); }); - it('should send a message to the correct agent', async () => { - sendMessageMock.mockResolvedValue({ + it('should send a message and return a stream', async () => { + const mockResult = { kind: 'message', messageId: 'a', parts: [], role: 'agent', - } as SendMessageResult); + } as SendMessageResult; - await manager.sendMessage('TestAgent', 'Hello'); - expect(sendMessageMock).toHaveBeenCalledWith( + sendMessageStreamMock.mockReturnValue( + (async function* () { + yield mockResult; + })(), + ); + + const stream = manager.sendMessageStream('TestAgent', 'Hello'); + const results = []; + for await (const res of stream) { + results.push(res); + } + + expect(results).toEqual([mockResult]); + expect(sendMessageStreamMock).toHaveBeenCalledWith( expect.objectContaining({ message: 
expect.anything(), }), + expect.any(Object), ); }); it('should use contextId and taskId when provided', async () => { - sendMessageMock.mockResolvedValue({ - kind: 'message', - messageId: 'a', - parts: [], - role: 'agent', - } as SendMessageResult); + sendMessageStreamMock.mockReturnValue( + (async function* () { + yield { + kind: 'message', + messageId: 'a', + parts: [], + role: 'agent', + } as SendMessageResult; + })(), + ); const expectedContextId = 'user-context-id'; const expectedTaskId = 'user-task-id'; - await manager.sendMessage('TestAgent', 'Hello', { + const stream = manager.sendMessageStream('TestAgent', 'Hello', { contextId: expectedContextId, taskId: expectedTaskId, }); - const call = sendMessageMock.mock.calls[0][0]; + for await (const _ of stream) { + // consume stream + } + + const call = sendMessageStreamMock.mock.calls[0][0]; expect(call.message.contextId).toBe(expectedContextId); expect(call.message.taskId).toBe(expectedTaskId); }); - it('should return result from client', async () => { - const mockResult = { - contextId: 'server-context-id', - id: 'ctx-1', - kind: 'task', - status: { state: 'working' }, - }; - - sendMessageMock.mockResolvedValueOnce(mockResult as SendMessageResult); - - const response = await manager.sendMessage('TestAgent', 'Hello'); - - expect(response).toEqual(mockResult); - }); - it('should throw prefixed error on failure', async () => { - sendMessageMock.mockRejectedValueOnce(new Error('Network error')); + sendMessageStreamMock.mockImplementationOnce(() => { + throw new Error('Network error'); + }); - await expect(manager.sendMessage('TestAgent', 'Hello')).rejects.toThrow( - 'A2AClient SendMessage Error [TestAgent]: Network error', + const stream = manager.sendMessageStream('TestAgent', 'Hello'); + await expect(async () => { + for await (const _ of stream) { + // consume + } + }).rejects.toThrow( + '[A2AClientManager] sendMessageStream Error [TestAgent]: Network error', ); }); it('should throw an error if the agent is not 
found', async () => { - await expect( - manager.sendMessage('NonExistentAgent', 'Hello'), - ).rejects.toThrow("Agent 'NonExistentAgent' not found."); + const stream = manager.sendMessageStream('NonExistentAgent', 'Hello'); + await expect(async () => { + for await (const _ of stream) { + // consume + } + }).rejects.toThrow("Agent 'NonExistentAgent' not found."); }); }); diff --git a/packages/core/src/agents/a2a-client-manager.ts b/packages/core/src/agents/a2a-client-manager.ts index 82adf2653c..694905cdc5 100644 --- a/packages/core/src/agents/a2a-client-manager.ts +++ b/packages/core/src/agents/a2a-client-manager.ts @@ -4,7 +4,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { AgentCard, Message, MessageSendParams, Task } from '@a2a-js/sdk'; +import type { + AgentCard, + Message, + MessageSendParams, + Task, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, +} from '@a2a-js/sdk'; import { type Client, ClientFactory, @@ -18,7 +25,11 @@ import { import { v4 as uuidv4 } from 'uuid'; import { debugLogger } from '../utils/debugLogger.js'; -export type SendMessageResult = Message | Task; +export type SendMessageResult = + | Message + | Task + | TaskStatusUpdateEvent + | TaskArtifactUpdateEvent; /** * Manages A2A clients and caches loaded agent information. @@ -110,18 +121,18 @@ export class A2AClientManager { } /** - * Sends a message to a loaded agent. + * Sends a message to a loaded agent and returns a stream of responses. * @param agentName The name of the agent to send the message to. * @param message The message content. * @param options Optional context and task IDs to maintain conversation state. - * @returns The response from the agent (Message or Task). + * @returns An async iterable of responses from the agent (Message, Task, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent). * @throws Error if the agent returns an error response. 
*/ - async sendMessage( + async *sendMessageStream( agentName: string, message: string, - options?: { contextId?: string; taskId?: string }, - ): Promise { + options?: { contextId?: string; taskId?: string; signal?: AbortSignal }, + ): AsyncIterable { const client = this.clients.get(agentName); if (!client) { throw new Error(`Agent '${agentName}' not found.`); @@ -136,20 +147,19 @@ export class A2AClientManager { contextId: options?.contextId, taskId: options?.taskId, }, - configuration: { - blocking: true, - }, }; try { - return await client.sendMessage(messageParams); + yield* client.sendMessageStream(messageParams, { + signal: options?.signal, + }); } catch (error: unknown) { - const prefix = `A2AClient SendMessage Error [${agentName}]`; + const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`; if (error instanceof Error) { throw new Error(`${prefix}: ${error.message}`, { cause: error }); } throw new Error( - `${prefix}: Unexpected error during sendMessage: ${String(error)}`, + `${prefix}: Unexpected error during sendMessageStream: ${String(error)}`, ); } } diff --git a/packages/core/src/agents/a2aUtils.test.ts b/packages/core/src/agents/a2aUtils.test.ts index dcb911f2c0..711650ea80 100644 --- a/packages/core/src/agents/a2aUtils.test.ts +++ b/packages/core/src/agents/a2aUtils.test.ts @@ -7,12 +7,40 @@ import { describe, it, expect } from 'vitest'; import { extractMessageText, - extractTaskText, extractIdsFromResponse, + isTerminalState, + A2AResultReassembler, } from './a2aUtils.js'; -import type { Message, Task, TextPart, DataPart, FilePart } from '@a2a-js/sdk'; +import type { SendMessageResult } from './a2a-client-manager.js'; +import type { + Message, + Task, + TextPart, + DataPart, + FilePart, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, +} from '@a2a-js/sdk'; describe('a2aUtils', () => { + describe('isTerminalState', () => { + it('should return true for completed, failed, canceled, and rejected', () => { + 
expect(isTerminalState('completed')).toBe(true); + expect(isTerminalState('failed')).toBe(true); + expect(isTerminalState('canceled')).toBe(true); + expect(isTerminalState('rejected')).toBe(true); + }); + + it('should return false for working, submitted, input-required, auth-required, and unknown', () => { + expect(isTerminalState('working')).toBe(false); + expect(isTerminalState('submitted')).toBe(false); + expect(isTerminalState('input-required')).toBe(false); + expect(isTerminalState('auth-required')).toBe(false); + expect(isTerminalState('unknown')).toBe(false); + expect(isTerminalState(undefined)).toBe(false); + }); + }); + describe('extractIdsFromResponse', () => { it('should extract IDs from a message response', () => { const message: Message = { @@ -25,7 +53,11 @@ describe('a2aUtils', () => { }; const result = extractIdsFromResponse(message); - expect(result).toEqual({ contextId: 'ctx-1', taskId: 'task-1' }); + expect(result).toEqual({ + contextId: 'ctx-1', + taskId: 'task-1', + clearTaskId: false, + }); }); it('should extract IDs from an in-progress task response', () => { @@ -37,7 +69,76 @@ describe('a2aUtils', () => { }; const result = extractIdsFromResponse(task); - expect(result).toEqual({ contextId: 'ctx-2', taskId: 'task-2' }); + expect(result).toEqual({ + contextId: 'ctx-2', + taskId: 'task-2', + clearTaskId: false, + }); + }); + + it('should set clearTaskId true for terminal task response', () => { + const task: Task = { + id: 'task-3', + contextId: 'ctx-3', + kind: 'task', + status: { state: 'completed' }, + }; + + const result = extractIdsFromResponse(task); + expect(result.clearTaskId).toBe(true); + }); + + it('should set clearTaskId true for terminal status update', () => { + const update = { + kind: 'status-update', + contextId: 'ctx-4', + taskId: 'task-4', + final: true, + status: { state: 'failed' }, + }; + + const result = extractIdsFromResponse( + update as unknown as TaskStatusUpdateEvent, + ); + expect(result.contextId).toBe('ctx-4'); + 
expect(result.taskId).toBe('task-4'); + expect(result.clearTaskId).toBe(true); + }); + + it('should extract IDs from an artifact-update event', () => { + const update = { + kind: 'artifact-update', + taskId: 'task-5', + contextId: 'ctx-5', + artifact: { + artifactId: 'art-1', + parts: [{ kind: 'text', text: 'artifact content' }], + }, + } as unknown as TaskArtifactUpdateEvent; + + const result = extractIdsFromResponse(update); + expect(result).toEqual({ + contextId: 'ctx-5', + taskId: 'task-5', + clearTaskId: false, + }); + }); + + it('should extract taskId from status update event', () => { + const update = { + kind: 'status-update', + taskId: 'task-6', + contextId: 'ctx-6', + final: false, + status: { state: 'working' }, + }; + + const result = extractIdsFromResponse( + update as unknown as TaskStatusUpdateEvent, + ); + expect(result.taskId).toBe('task-6'); + expect(result.contextId).toBe('ctx-6'); + expect(result.clearTaskId).toBe(false); }); }); @@ -123,49 +224,65 @@ describe('a2aUtils', () => { }); }); - describe('extractTaskText', () => { - it('should extract basic task info (clean)', () => { - const task: Task = { - id: 'task-1', - contextId: 'ctx-1', - kind: 'task', + describe('A2AResultReassembler', () => { + it('should reassemble sequential messages and incremental artifacts', () => { + const reassembler = new A2AResultReassembler(); + + // 1. Initial status + reassembler.update({ + kind: 'status-update', + taskId: 't1', status: { state: 'working', message: { kind: 'message', role: 'agent', - messageId: 'm1', - parts: [{ kind: 'text', text: 'Processing...' } as TextPart], - }, + parts: [{ kind: 'text', text: 'Analyzing...' }], + } as Message, }, - }; + } as unknown as SendMessageResult); - const result = extractTaskText(task); - expect(result).not.toContain('ID: task-1'); - expect(result).not.toContain('State: working'); - expect(result).toBe('Processing...'); - }); + // 2. 
First artifact chunk + reassembler.update({ + kind: 'artifact-update', + taskId: 't1', + append: false, + artifact: { + artifactId: 'a1', + name: 'Code', + parts: [{ kind: 'text', text: 'print(' }], + }, + } as unknown as SendMessageResult); - it('should extract artifacts with headers', () => { - const task: Task = { - id: 'task-1', - contextId: 'ctx-1', - kind: 'task', - status: { state: 'completed' }, - artifacts: [ - { - artifactId: 'art-1', - name: 'Report', - parts: [{ kind: 'text', text: 'This is the report.' } as TextPart], - }, - ], - }; + // 3. Second status + reassembler.update({ + kind: 'status-update', + taskId: 't1', + status: { + state: 'working', + message: { + kind: 'message', + role: 'agent', + parts: [{ kind: 'text', text: 'Processing...' }], + } as Message, + }, + } as unknown as SendMessageResult); - const result = extractTaskText(task); - expect(result).toContain('Artifact (Report):'); - expect(result).toContain('This is the report.'); - expect(result).not.toContain('Artifacts:'); - expect(result).not.toContain(' - Name: Report'); + // 4. Second artifact chunk (append) + reassembler.update({ + kind: 'artifact-update', + taskId: 't1', + append: true, + artifact: { + artifactId: 'a1', + parts: [{ kind: 'text', text: '"Done")' }], + }, + } as unknown as SendMessageResult); + + const output = reassembler.toString(); + expect(output).toBe( + 'Analyzing...\n\nProcessing...\n\nArtifact (Code):\nprint("Done")', + ); }); }); }); diff --git a/packages/core/src/agents/a2aUtils.ts b/packages/core/src/agents/a2aUtils.ts index f1e66309d6..e753d047d0 100644 --- a/packages/core/src/agents/a2aUtils.ts +++ b/packages/core/src/agents/a2aUtils.ts @@ -6,12 +6,120 @@ import type { Message, - Task, Part, TextPart, DataPart, FilePart, + Artifact, + TaskState, + TaskStatusUpdateEvent, } from '@a2a-js/sdk'; +import type { SendMessageResult } from './a2a-client-manager.js'; + +/** + * Reassembles incremental A2A streaming updates into a coherent result. 
+ * Shows sequential status/messages followed by all reassembled artifacts. + */ +export class A2AResultReassembler { + private messageLog: string[] = []; + private artifacts = new Map(); + private artifactChunks = new Map(); + + /** + * Processes a new chunk from the A2A stream. + */ + update(chunk: SendMessageResult) { + if (!('kind' in chunk)) return; + + switch (chunk.kind) { + case 'status-update': + this.pushMessage(chunk.status?.message); + break; + + case 'artifact-update': + if (chunk.artifact) { + const id = chunk.artifact.artifactId; + const existing = this.artifacts.get(id); + + if (chunk.append && existing) { + for (const part of chunk.artifact.parts) { + existing.parts.push(structuredClone(part)); + } + } else { + this.artifacts.set(id, structuredClone(chunk.artifact)); + } + + const newText = extractPartsText(chunk.artifact.parts, ''); + let chunks = this.artifactChunks.get(id); + if (!chunks) { + chunks = []; + this.artifactChunks.set(id, chunks); + } + if (chunk.append) { + chunks.push(newText); + } else { + chunks.length = 0; + chunks.push(newText); + } + } + break; + + case 'task': + this.pushMessage(chunk.status?.message); + if (chunk.artifacts) { + for (const art of chunk.artifacts) { + this.artifacts.set(art.artifactId, structuredClone(art)); + this.artifactChunks.set(art.artifactId, [ + extractPartsText(art.parts, ''), + ]); + } + } + break; + + case 'message': { + this.pushMessage(chunk); + break; + } + + default: + break; + } + } + + private pushMessage(message: Message | undefined) { + if (!message) return; + const text = extractPartsText(message.parts, '\n'); + if (text && this.messageLog[this.messageLog.length - 1] !== text) { + this.messageLog.push(text); + } + } + + /** + * Returns a human-readable string representation of the current reassembled state. 
+ */ + toString(): string { + const joinedMessages = this.messageLog.join('\n\n'); + + const artifactsOutput = Array.from(this.artifacts.keys()) + .map((id) => { + const chunks = this.artifactChunks.get(id); + const artifact = this.artifacts.get(id); + if (!chunks || !artifact) return ''; + const content = chunks.join(''); + const header = artifact.name + ? `Artifact (${artifact.name}):` + : 'Artifact:'; + return `${header}\n${content}`; + }) + .filter(Boolean) + .join('\n\n'); + + if (joinedMessages && artifactsOutput) { + return `${joinedMessages}\n\n${artifactsOutput}`; + } + return joinedMessages || artifactsOutput; + } +} /** * Extracts a human-readable text representation from a Message object. @@ -22,7 +130,23 @@ export function extractMessageText(message: Message | undefined): string { return ''; } - return extractPartsText(message.parts); + return extractPartsText(message.parts, '\n'); +} + +/** + * Extracts text from an array of parts, joining them with the specified separator. + */ +function extractPartsText( + parts: Part[] | undefined, + separator: string, +): string { + if (!parts || parts.length === 0) { + return ''; + } + return parts + .map((p) => extractPartText(p)) + .filter(Boolean) + .join(separator); } /** @@ -52,50 +176,6 @@ function extractPartText(part: Part): string { return ''; } -/** - * Extracts a clean, human-readable text summary from a Task object. - * Includes the status message and any artifact content with context headers. - * Technical metadata like ID and State are omitted for better clarity and token efficiency. 
- */ -export function extractTaskText(task: Task): string { - const parts: string[] = []; - - // Status Message - const statusMessageText = extractMessageText(task.status?.message); - if (statusMessageText) { - parts.push(statusMessageText); - } - - // Artifacts - if (task.artifacts) { - for (const artifact of task.artifacts) { - const artifactContent = extractPartsText(artifact.parts); - - if (artifactContent) { - const header = artifact.name - ? `Artifact (${artifact.name}):` - : 'Artifact:'; - parts.push(`${header}\n${artifactContent}`); - } - } - } - - return parts.join('\n\n'); -} - -/** - * Extracts text from an array of parts. - */ -function extractPartsText(parts: Part[] | undefined): string { - if (!parts || parts.length === 0) { - return ''; - } - return parts - .map((p) => extractPartText(p)) - .filter(Boolean) - .join('\n'); -} - // Type Guards function isTextPart(part: Part): part is TextPart { @@ -110,36 +190,58 @@ function isFilePart(part: Part): part is FilePart { return part.kind === 'file'; } +function isStatusUpdateEvent( + result: SendMessageResult, +): result is TaskStatusUpdateEvent { + return result.kind === 'status-update'; +} + /** - * Extracts contextId and taskId from a Message or Task response. + * Returns true if the given state is a terminal state for a task. + */ +export function isTerminalState(state: TaskState | undefined): boolean { + return ( + state === 'completed' || + state === 'failed' || + state === 'canceled' || + state === 'rejected' + ); +} + +/** + * Extracts contextId and taskId from a Message, Task, or Update response. * Follows the pattern from the A2A CLI sample to maintain conversational continuity. 
*/ -export function extractIdsFromResponse(result: Message | Task): { +export function extractIdsFromResponse(result: SendMessageResult): { contextId?: string; taskId?: string; + clearTaskId?: boolean; } { let contextId: string | undefined; let taskId: string | undefined; + let clearTaskId = false; - if (result.kind === 'message') { - taskId = result.taskId; - contextId = result.contextId; - } else if (result.kind === 'task') { - taskId = result.id; - contextId = result.contextId; - - // If the task is in a final state (and not input-required), we clear the taskId - // so that the next interaction starts a fresh task (or keeps context without being bound to the old task). - if ( - result.status && - result.status.state !== 'input-required' && - (result.status.state === 'completed' || - result.status.state === 'failed' || - result.status.state === 'canceled') - ) { - taskId = undefined; + if ('kind' in result) { + const kind = result.kind; + if (kind === 'message' || kind === 'artifact-update') { + taskId = result.taskId; + contextId = result.contextId; + } else if (kind === 'task') { + taskId = result.id; + contextId = result.contextId; + if (isTerminalState(result.status?.state)) { + clearTaskId = true; + } + } else if (isStatusUpdateEvent(result)) { + taskId = result.taskId; + contextId = result.contextId; + // Note: We ignore the 'final' flag here per A2A protocol best practices, + // as a stream can close while a task is still in a 'working' state. 
+ if (isTerminalState(result.status?.state)) { + clearTaskId = true; + } } } - return { contextId, taskId }; + return { contextId, taskId, clearTaskId }; } diff --git a/packages/core/src/agents/remote-invocation.test.ts b/packages/core/src/agents/remote-invocation.test.ts index 7baa77d941..9688b61d78 100644 --- a/packages/core/src/agents/remote-invocation.test.ts +++ b/packages/core/src/agents/remote-invocation.test.ts @@ -14,7 +14,10 @@ import { type Mock, } from 'vitest'; import { RemoteAgentInvocation } from './remote-invocation.js'; -import { A2AClientManager } from './a2a-client-manager.js'; +import { + A2AClientManager, + type SendMessageResult, +} from './a2a-client-manager.js'; import type { RemoteAgentDefinition } from './types.js'; import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; @@ -41,7 +44,7 @@ describe('RemoteAgentInvocation', () => { const mockClientManager = { getClient: vi.fn(), loadAgent: vi.fn(), - sendMessage: vi.fn(), + sendMessageStream: vi.fn(), }; const mockMessageBus = createMockMessageBus(); @@ -78,12 +81,16 @@ describe('RemoteAgentInvocation', () => { it('uses "Get Started!" 
default when query is missing during execution', async () => { mockClientManager.getClient.mockReturnValue({}); - mockClientManager.sendMessage.mockResolvedValue({ - kind: 'message', - messageId: 'msg-1', - role: 'agent', - parts: [{ kind: 'text', text: 'Hello' }], - }); + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + yield { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Hello' }], + }; + }, + ); const invocation = new RemoteAgentInvocation( mockDefinition, @@ -92,10 +99,10 @@ describe('RemoteAgentInvocation', () => { ); await invocation.execute(new AbortController().signal); - expect(mockClientManager.sendMessage).toHaveBeenCalledWith( + expect(mockClientManager.sendMessageStream).toHaveBeenCalledWith( 'test-agent', 'Get Started!', - expect.any(Object), + expect.objectContaining({ signal: expect.any(Object) }), ); }); @@ -113,12 +120,16 @@ describe('RemoteAgentInvocation', () => { describe('Execution Logic', () => { it('should lazy load the agent with ADCHandler if not present', async () => { mockClientManager.getClient.mockReturnValue(undefined); - mockClientManager.sendMessage.mockResolvedValue({ - kind: 'message', - messageId: 'msg-1', - role: 'agent', - parts: [{ kind: 'text', text: 'Hello' }], - }); + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + yield { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Hello' }], + }; + }, + ); const invocation = new RemoteAgentInvocation( mockDefinition, @@ -141,12 +152,16 @@ describe('RemoteAgentInvocation', () => { it('should not load the agent if already present', async () => { mockClientManager.getClient.mockReturnValue({}); - mockClientManager.sendMessage.mockResolvedValue({ - kind: 'message', - messageId: 'msg-1', - role: 'agent', - parts: [{ kind: 'text', text: 'Hello' }], - }); + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + 
yield { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Hello' }], + }; + }, + ); const invocation = new RemoteAgentInvocation( mockDefinition, @@ -164,14 +179,18 @@ describe('RemoteAgentInvocation', () => { mockClientManager.getClient.mockReturnValue({}); // First call return values - mockClientManager.sendMessage.mockResolvedValueOnce({ - kind: 'message', - messageId: 'msg-1', - role: 'agent', - parts: [{ kind: 'text', text: 'Response 1' }], - contextId: 'ctx-1', - taskId: 'task-1', - }); + mockClientManager.sendMessageStream.mockImplementationOnce( + async function* () { + yield { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Response 1' }], + contextId: 'ctx-1', + taskId: 'task-1', + }; + }, + ); const invocation1 = new RemoteAgentInvocation( mockDefinition, @@ -184,21 +203,25 @@ describe('RemoteAgentInvocation', () => { // Execute first time const result1 = await invocation1.execute(new AbortController().signal); expect(result1.returnDisplay).toBe('Response 1'); - expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith( + expect(mockClientManager.sendMessageStream).toHaveBeenLastCalledWith( 'test-agent', 'first', - { contextId: undefined, taskId: undefined }, + { contextId: undefined, taskId: undefined, signal: expect.any(Object) }, ); // Prepare for second call with simulated state persistence - mockClientManager.sendMessage.mockResolvedValueOnce({ - kind: 'message', - messageId: 'msg-2', - role: 'agent', - parts: [{ kind: 'text', text: 'Response 2' }], - contextId: 'ctx-1', - taskId: 'task-2', - }); + mockClientManager.sendMessageStream.mockImplementationOnce( + async function* () { + yield { + kind: 'message', + messageId: 'msg-2', + role: 'agent', + parts: [{ kind: 'text', text: 'Response 2' }], + contextId: 'ctx-1', + taskId: 'task-2', + }; + }, + ); const invocation2 = new RemoteAgentInvocation( mockDefinition, @@ -210,21 +233,25 @@ 
describe('RemoteAgentInvocation', () => { const result2 = await invocation2.execute(new AbortController().signal); expect(result2.returnDisplay).toBe('Response 2'); - expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith( + expect(mockClientManager.sendMessageStream).toHaveBeenLastCalledWith( 'test-agent', 'second', - { contextId: 'ctx-1', taskId: 'task-1' }, // Used state from first call + { contextId: 'ctx-1', taskId: 'task-1', signal: expect.any(Object) }, // Used state from first call ); // Third call: Task completes - mockClientManager.sendMessage.mockResolvedValueOnce({ - kind: 'task', - id: 'task-2', - contextId: 'ctx-1', - status: { state: 'completed', message: undefined }, - artifacts: [], - history: [], - }); + mockClientManager.sendMessageStream.mockImplementationOnce( + async function* () { + yield { + kind: 'task', + id: 'task-2', + contextId: 'ctx-1', + status: { state: 'completed', message: undefined }, + artifacts: [], + history: [], + }; + }, + ); const invocation3 = new RemoteAgentInvocation( mockDefinition, @@ -236,12 +263,16 @@ describe('RemoteAgentInvocation', () => { await invocation3.execute(new AbortController().signal); // Fourth call: Should start new task (taskId undefined) - mockClientManager.sendMessage.mockResolvedValueOnce({ - kind: 'message', - messageId: 'msg-3', - role: 'agent', - parts: [{ kind: 'text', text: 'New Task' }], - }); + mockClientManager.sendMessageStream.mockImplementationOnce( + async function* () { + yield { + kind: 'message', + messageId: 'msg-3', + role: 'agent', + parts: [{ kind: 'text', text: 'New Task' }], + }; + }, + ); const invocation4 = new RemoteAgentInvocation( mockDefinition, @@ -252,17 +283,84 @@ describe('RemoteAgentInvocation', () => { ); await invocation4.execute(new AbortController().signal); - expect(mockClientManager.sendMessage).toHaveBeenLastCalledWith( + expect(mockClientManager.sendMessageStream).toHaveBeenLastCalledWith( 'test-agent', 'fourth', - { contextId: 'ctx-1', taskId: 
undefined }, // taskId cleared! + { contextId: 'ctx-1', taskId: undefined, signal: expect.any(Object) }, // taskId cleared! ); }); + it('should handle streaming updates and reassemble output', async () => { + mockClientManager.getClient.mockReturnValue({}); + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + yield { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Hello' }], + }; + yield { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Hello World' }], + }; + }, + ); + + const updateOutput = vi.fn(); + const invocation = new RemoteAgentInvocation( + mockDefinition, + { query: 'hi' }, + mockMessageBus, + ); + await invocation.execute(new AbortController().signal, updateOutput); + + expect(updateOutput).toHaveBeenCalledWith('Hello'); + expect(updateOutput).toHaveBeenCalledWith('Hello\n\nHello World'); + }); + + it('should abort when signal is aborted during streaming', async () => { + mockClientManager.getClient.mockReturnValue({}); + const controller = new AbortController(); + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + yield { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [{ kind: 'text', text: 'Partial' }], + }; + // Simulate abort between chunks + controller.abort(); + yield { + kind: 'message', + messageId: 'msg-2', + role: 'agent', + parts: [{ kind: 'text', text: 'Partial response continued' }], + }; + }, + ); + + const invocation = new RemoteAgentInvocation( + mockDefinition, + { query: 'hi' }, + mockMessageBus, + ); + const result = await invocation.execute(controller.signal); + + expect(result.error).toBeDefined(); + expect(result.error?.message).toContain('Operation aborted'); + }); + it('should handle errors gracefully', async () => { mockClientManager.getClient.mockReturnValue({}); - mockClientManager.sendMessage.mockRejectedValue( - new Error('Network error'), + 
mockClientManager.sendMessageStream.mockImplementation( + async function* () { + if (Math.random() < 0) yield {} as unknown as SendMessageResult; + throw new Error('Network error'); + }, ); const invocation = new RemoteAgentInvocation( @@ -282,15 +380,19 @@ describe('RemoteAgentInvocation', () => { it('should use a2a helpers for extracting text', async () => { mockClientManager.getClient.mockReturnValue({}); // Mock a complex message part that needs extraction - mockClientManager.sendMessage.mockResolvedValue({ - kind: 'message', - messageId: 'msg-1', - role: 'agent', - parts: [ - { kind: 'text', text: 'Extracted text' }, - { kind: 'data', data: { foo: 'bar' } }, - ], - }); + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + yield { + kind: 'message', + messageId: 'msg-1', + role: 'agent', + parts: [ + { kind: 'text', text: 'Extracted text' }, + { kind: 'data', data: { foo: 'bar' } }, + ], + }; + }, + ); const invocation = new RemoteAgentInvocation( mockDefinition, @@ -304,6 +406,105 @@ describe('RemoteAgentInvocation', () => { // Just check that text is present, exact formatting depends on helper expect(result.returnDisplay).toContain('Extracted text'); }); + + it('should handle mixed response types during streaming (TaskStatusUpdateEvent + Message)', async () => { + mockClientManager.getClient.mockReturnValue({}); + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + yield { + kind: 'status-update', + taskId: 'task-1', + contextId: 'ctx-1', + final: false, + status: { + state: 'working', + message: { + kind: 'message', + role: 'agent', + messageId: 'm1', + parts: [{ kind: 'text', text: 'Thinking...' 
}], + }, + }, + }; + yield { + kind: 'message', + messageId: 'msg-final', + role: 'agent', + parts: [{ kind: 'text', text: 'Final Answer' }], + }; + }, + ); + + const updateOutput = vi.fn(); + const invocation = new RemoteAgentInvocation( + mockDefinition, + { query: 'hi' }, + mockMessageBus, + ); + const result = await invocation.execute( + new AbortController().signal, + updateOutput, + ); + + expect(updateOutput).toHaveBeenCalledWith('Thinking...'); + expect(updateOutput).toHaveBeenCalledWith('Thinking...\n\nFinal Answer'); + expect(result.returnDisplay).toBe('Thinking...\n\nFinal Answer'); + }); + + it('should handle artifact reassembly with append: true', async () => { + mockClientManager.getClient.mockReturnValue({}); + mockClientManager.sendMessageStream.mockImplementation( + async function* () { + yield { + kind: 'status-update', + taskId: 'task-1', + status: { + state: 'working', + message: { + kind: 'message', + role: 'agent', + parts: [{ kind: 'text', text: 'Generating...' }], + }, + }, + }; + yield { + kind: 'artifact-update', + taskId: 'task-1', + append: false, + artifact: { + artifactId: 'art-1', + name: 'Result', + parts: [{ kind: 'text', text: 'Part 1' }], + }, + }; + yield { + kind: 'artifact-update', + taskId: 'task-1', + append: true, + artifact: { + artifactId: 'art-1', + parts: [{ kind: 'text', text: ' Part 2' }], + }, + }; + }, + ); + + const updateOutput = vi.fn(); + const invocation = new RemoteAgentInvocation( + mockDefinition, + { query: 'hi' }, + mockMessageBus, + ); + await invocation.execute(new AbortController().signal, updateOutput); + + expect(updateOutput).toHaveBeenCalledWith('Generating...'); + expect(updateOutput).toHaveBeenCalledWith( + 'Generating...\n\nArtifact (Result):\nPart 1', + ); + expect(updateOutput).toHaveBeenCalledWith( + 'Generating...\n\nArtifact (Result):\nPart 1 Part 2', + ); + }); }); describe('Confirmations', () => { diff --git a/packages/core/src/agents/remote-invocation.ts 
b/packages/core/src/agents/remote-invocation.ts index ea43c901a2..b76f216f34 100644 --- a/packages/core/src/agents/remote-invocation.ts +++ b/packages/core/src/agents/remote-invocation.ts @@ -18,14 +18,12 @@ import type { } from './types.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { A2AClientManager } from './a2a-client-manager.js'; -import { - extractMessageText, - extractTaskText, - extractIdsFromResponse, -} from './a2aUtils.js'; +import { extractIdsFromResponse, A2AResultReassembler } from './a2aUtils.js'; import { GoogleAuth } from 'google-auth-library'; import type { AuthenticationHandler } from '@a2a-js/sdk/client'; import { debugLogger } from '../utils/debugLogger.js'; +import type { AnsiOutput } from '../utils/terminalSerializer.js'; +import type { SendMessageResult } from './a2a-client-manager.js'; /** * Authentication handler implementation using Google Application Default Credentials (ADC). @@ -123,10 +121,14 @@ export class RemoteAgentInvocation extends BaseToolInvocation< }; } - async execute(_signal: AbortSignal): Promise { + async execute( + _signal: AbortSignal, + updateOutput?: (output: string | AnsiOutput) => void, + ): Promise { // 1. Ensure the agent is loaded (cached by manager) // We assume the user has provided an access token via some mechanism (TODO), // or we rely on ADC. + const reassembler = new A2AResultReassembler(); try { const priorState = RemoteAgentInvocation.sessionState.get( this.definition.name, @@ -146,49 +148,73 @@ export class RemoteAgentInvocation extends BaseToolInvocation< const message = this.params.query; - const response = await this.clientManager.sendMessage( + const stream = this.clientManager.sendMessageStream( this.definition.name, message, { contextId: this.contextId, taskId: this.taskId, + signal: _signal, }, ); - // Extracts IDs, taskID will be undefined if the task is completed/failed/canceled. 
- const { contextId, taskId } = extractIdsFromResponse(response); + let finalResponse: SendMessageResult | undefined; - this.contextId = contextId ?? this.contextId; - this.taskId = taskId; + for await (const chunk of stream) { + if (_signal.aborted) { + throw new Error('Operation aborted'); + } + finalResponse = chunk; + reassembler.update(chunk); + if (updateOutput) { + updateOutput(reassembler.toString()); + } + + const { + contextId: newContextId, + taskId: newTaskId, + clearTaskId, + } = extractIdsFromResponse(chunk); + + if (newContextId) { + this.contextId = newContextId; + } + + this.taskId = clearTaskId ? undefined : (newTaskId ?? this.taskId); + } + + if (!finalResponse) { + throw new Error('No response from remote agent.'); + } + + const finalOutput = reassembler.toString(); + + debugLogger.debug( + `[RemoteAgent] Final response from ${this.definition.name}:\n${JSON.stringify(finalResponse, null, 2)}`, + ); + + return { + llmContent: [{ text: finalOutput }], + returnDisplay: finalOutput, + }; + } catch (error: unknown) { + const partialOutput = reassembler.toString(); + const errorMessage = `Error calling remote agent: ${error instanceof Error ? error.message : String(error)}`; + const fullDisplay = partialOutput + ? `${partialOutput}\n\n${errorMessage}` + : errorMessage; + return { + llmContent: [{ text: fullDisplay }], + returnDisplay: fullDisplay, + error: { message: errorMessage }, + }; + } finally { + // Persist state even on partial failures or aborts to maintain conversational continuity. RemoteAgentInvocation.sessionState.set(this.definition.name, { contextId: this.contextId, taskId: this.taskId, }); - - // Extract the output text - const outputText = - response.kind === 'task' - ? extractTaskText(response) - : response.kind === 'message' - ? 
extractMessageText(response) - : JSON.stringify(response); - - debugLogger.debug( - `[RemoteAgent] Response from ${this.definition.name}:\n${JSON.stringify(response, null, 2)}`, - ); - - return { - llmContent: [{ text: outputText }], - returnDisplay: outputText, - }; - } catch (error: unknown) { - const errorMessage = `Error calling remote agent: ${error instanceof Error ? error.message : String(error)}`; - return { - llmContent: [{ text: errorMessage }], - returnDisplay: errorMessage, - error: { message: errorMessage }, - }; } } } diff --git a/packages/core/src/code_assist/oauth2.test.ts b/packages/core/src/code_assist/oauth2.test.ts index 5726f76451..c1fe162e63 100644 --- a/packages/core/src/code_assist/oauth2.test.ts +++ b/packages/core/src/code_assist/oauth2.test.ts @@ -95,6 +95,7 @@ const mockConfig = { getNoBrowser: () => false, getProxy: () => 'http://test.proxy.com:8080', isBrowserLaunchSuppressed: () => false, + getExperimentalZedIntegration: () => false, } as unknown as Config; // Mock fetch globally diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index 14e65f5906..31bc3c0e5e 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -271,9 +271,12 @@ async function initOauthClient( await triggerPostAuthCallbacks(client.credentials); } else { - const userConsent = await getConsentForOauth(''); - if (!userConsent) { - throw new FatalCancellationError('Authentication cancelled by user.'); + // In Zed integration, we skip the interactive consent and directly open the browser + if (!config.getExperimentalZedIntegration()) { + const userConsent = await getConsentForOauth(''); + if (!userConsent) { + throw new FatalCancellationError('Authentication cancelled by user.'); + } } const webLogin = await authWithWeb(client); diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index a9e9a78415..e92f464fa2 100644 --- 
a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -499,6 +499,7 @@ describe('Server Config (config.ts)', () => { expect(createContentGeneratorConfig).toHaveBeenCalledWith( config, authType, + undefined, ); // Verify that contentGeneratorConfig is updated expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index dceb65c9a8..7297693b8e 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -1126,7 +1126,7 @@ export class Config { return this.contentGenerator; } - async refreshAuth(authMethod: AuthType) { + async refreshAuth(authMethod: AuthType, apiKey?: string) { // Reset availability service when switching auth this.modelAvailabilityService.reset(); @@ -1152,6 +1152,7 @@ export class Config { const newContentGeneratorConfig = await createContentGeneratorConfig( this, authMethod, + apiKey, ); this.contentGenerator = await createContentGenerator( newContentGeneratorConfig, diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 7adae874aa..98d8d50020 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -90,9 +90,13 @@ export type ContentGeneratorConfig = { export async function createContentGeneratorConfig( config: Config, authType: AuthType | undefined, + apiKey?: string, ): Promise { const geminiApiKey = - process.env['GEMINI_API_KEY'] || (await loadApiKey()) || undefined; + apiKey || + process.env['GEMINI_API_KEY'] || + (await loadApiKey()) || + undefined; const googleApiKey = process.env['GOOGLE_API_KEY'] || undefined; const googleCloudProject = process.env['GOOGLE_CLOUD_PROJECT'] ||