diff --git a/packages/a2a-server/src/config/config.test.ts b/packages/a2a-server/src/config/config.test.ts index bd8771d1b5..cfe77311ea 100644 --- a/packages/a2a-server/src/config/config.test.ts +++ b/packages/a2a-server/src/config/config.test.ts @@ -19,6 +19,8 @@ import { AuthType, isHeadlessMode, FatalAuthenticationError, + PolicyDecision, + PRIORITY_YOLO_ALLOW_ALL, } from '@google/gemini-cli-core'; // Mock dependencies @@ -325,6 +327,29 @@ describe('loadConfig', () => { ); }); + it('should pass enableAgents to Config constructor', async () => { + const settings: Settings = { + experimental: { + enableAgents: false, + }, + }; + await loadConfig(settings, mockExtensionLoader, taskId); + expect(Config).toHaveBeenCalledWith( + expect.objectContaining({ + enableAgents: false, + }), + ); + }); + + it('should default enableAgents to true when not provided', async () => { + await loadConfig(mockSettings, mockExtensionLoader, taskId); + expect(Config).toHaveBeenCalledWith( + expect.objectContaining({ + enableAgents: true, + }), + ); + }); + describe('interactivity', () => { it('should set interactive true when not headless', async () => { vi.mocked(isHeadlessMode).mockReturnValue(false); @@ -349,6 +374,41 @@ describe('loadConfig', () => { }); }); + describe('YOLO mode', () => { + it('should enable YOLO mode and add policy rule when GEMINI_YOLO_MODE is true', async () => { + vi.stubEnv('GEMINI_YOLO_MODE', 'true'); + await loadConfig(mockSettings, mockExtensionLoader, taskId); + expect(Config).toHaveBeenCalledWith( + expect.objectContaining({ + approvalMode: 'yolo', + policyEngineConfig: expect.objectContaining({ + rules: expect.arrayContaining([ + expect.objectContaining({ + decision: PolicyDecision.ALLOW, + priority: PRIORITY_YOLO_ALLOW_ALL, + modes: ['yolo'], + allowRedirection: true, + }), + ]), + }), + }), + ); + }); + + it('should use default approval mode and empty rules when GEMINI_YOLO_MODE is not true', async () => { + vi.stubEnv('GEMINI_YOLO_MODE', 'false'); + await loadConfig(mockSettings, mockExtensionLoader, taskId); + expect(Config).toHaveBeenCalledWith( + expect.objectContaining({ + approvalMode: 'default', + policyEngineConfig: expect.objectContaining({ + rules: [], + }), + }), + ); + }); + }); + describe('authentication fallback', () => { beforeEach(() => { vi.stubEnv('USE_CCPA', 'true'); diff --git a/packages/a2a-server/src/config/config.ts b/packages/a2a-server/src/config/config.ts index 607695f173..9474c4d9c5 100644 --- a/packages/a2a-server/src/config/config.ts +++ b/packages/a2a-server/src/config/config.ts @@ -26,6 +26,8 @@ import { isHeadlessMode, FatalAuthenticationError, isCloudShell, + PolicyDecision, + PRIORITY_YOLO_ALLOW_ALL, type TelemetryTarget, type ConfigParameters, type ExtensionLoader, @@ -60,6 +62,11 @@ export async function loadConfig( } } + const approvalMode = + process.env['GEMINI_YOLO_MODE'] === 'true' + ? ApprovalMode.YOLO + : ApprovalMode.DEFAULT; + const configParams: ConfigParameters = { sessionId: taskId, clientName: 'a2a-server', @@ -74,10 +81,20 @@ export async function loadConfig( excludeTools: settings.excludeTools || settings.tools?.exclude || undefined, allowedTools: settings.allowedTools || settings.tools?.allowed || undefined, showMemoryUsage: settings.showMemoryUsage || false, - approvalMode: - process.env['GEMINI_YOLO_MODE'] === 'true' - ? ApprovalMode.YOLO - : ApprovalMode.DEFAULT, + approvalMode, + policyEngineConfig: { + rules: + approvalMode === ApprovalMode.YOLO + ? [ + { + decision: PolicyDecision.ALLOW, + priority: PRIORITY_YOLO_ALLOW_ALL, + modes: [ApprovalMode.YOLO], + allowRedirection: true, + }, + ] + : [], + }, mcpServers: settings.mcpServers, cwd: workspaceDir, telemetry: { @@ -110,6 +127,7 @@ export async function loadConfig( interactive: !isHeadlessMode(), enableInteractiveShell: !isHeadlessMode(), ptyInfo: 'auto', + enableAgents: settings.experimental?.enableAgents ?? true, }; const fileService = new FileDiscoveryService(workspaceDir, { diff --git a/packages/a2a-server/src/config/settings.test.ts b/packages/a2a-server/src/config/settings.test.ts index 7c51950535..ab80bced24 100644 --- a/packages/a2a-server/src/config/settings.test.ts +++ b/packages/a2a-server/src/config/settings.test.ts @@ -112,6 +112,18 @@ describe('loadSettings', () => { expect(result.fileFiltering?.respectGitIgnore).toBe(true); }); + it('should load experimental settings correctly', () => { + const settings = { + experimental: { + enableAgents: true, + }, + }; + fs.writeFileSync(USER_SETTINGS_PATH, JSON.stringify(settings)); + + const result = loadSettings(mockWorkspaceDir); + expect(result.experimental?.enableAgents).toBe(true); + }); + it('should overwrite top-level settings from workspace (shallow merge)', () => { const userSettings = { showMemoryUsage: false, diff --git a/packages/a2a-server/src/config/settings.ts b/packages/a2a-server/src/config/settings.ts index da9db4e069..ced11a4daa 100644 --- a/packages/a2a-server/src/config/settings.ts +++ b/packages/a2a-server/src/config/settings.ts @@ -48,6 +48,9 @@ export interface Settings { enableRecursiveFileSearch?: boolean; customIgnoreFilePaths?: string[]; }; + experimental?: { + enableAgents?: boolean; + }; } export interface SettingsError { diff --git a/packages/core/src/agents/a2a-client-manager.test.ts b/packages/core/src/agents/a2a-client-manager.test.ts index 0a0aa4d956..f4a39c1d36 100644 --- a/packages/core/src/agents/a2a-client-manager.test.ts +++ b/packages/core/src/agents/a2a-client-manager.test.ts @@ -66,11 +66,13 @@ describe('A2AClientManager', () => { }; const authFetchMock = vi.fn(); + const mockConfig = { + getProxy: vi.fn(), + } as unknown as Config; beforeEach(() => { vi.clearAllMocks(); - A2AClientManager.resetInstanceForTesting(); - manager = A2AClientManager.getInstance(); + manager = new A2AClientManager(mockConfig); // Re-create the instances as plain objects that can be spied on const factoryInstance = { @@ -124,12 +126,6 @@ describe('A2AClientManager', () => { vi.unstubAllGlobals(); }); - it('should enforce the singleton pattern', () => { - const instance1 = A2AClientManager.getInstance(); - const instance2 = A2AClientManager.getInstance(); - expect(instance1).toBe(instance2); - }); - describe('getInstance / dispatcher initialization', () => { it('should use UndiciAgent when no proxy is configured', async () => { await manager.loadAgent('TestAgent', 'http://test.agent/card'); @@ -152,12 +148,11 @@ describe('A2AClientManager', () => { }); it('should use ProxyAgent when a proxy is configured via Config', async () => { - A2AClientManager.resetInstanceForTesting(); - const mockConfig = { + const mockConfigWithProxy = { getProxy: () => 'http://my-proxy:8080', } as Config; - manager = A2AClientManager.getInstance(mockConfig); + manager = new A2AClientManager(mockConfigWithProxy); await manager.loadAgent('TestProxyAgent', 'http://test.proxy.agent/card'); const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock diff --git a/packages/core/src/agents/a2a-client-manager.ts b/packages/core/src/agents/a2a-client-manager.ts index 3a03c033d8..c15d34179c 100644 --- a/packages/core/src/agents/a2a-client-manager.ts +++ b/packages/core/src/agents/a2a-client-manager.ts @@ -49,8 +49,6 @@ const A2A_TIMEOUT = 1800000; // 30 minutes * Manages protocol negotiation, authentication, and transport selection. */ export class A2AClientManager { - private static instance: A2AClientManager; - // Each agent should manage their own context/taskIds/card/etc private clients = new Map(); private agentCards = new Map(); @@ -58,8 +56,8 @@ export class A2AClientManager { private a2aDispatcher: UndiciAgent | ProxyAgent; private a2aFetch: typeof fetch; - private constructor(config?: Config) { - const proxyUrl = config?.getProxy(); + constructor(private readonly config: Config) { + const proxyUrl = this.config.getProxy(); const agentOptions = { headersTimeout: A2A_TIMEOUT, bodyTimeout: A2A_TIMEOUT, @@ -78,25 +76,6 @@ export class A2AClientManager { fetch(input, { ...init, dispatcher: this.a2aDispatcher } as RequestInit); } - /** - * Gets the singleton instance of the A2AClientManager. - */ - static getInstance(config?: Config): A2AClientManager { - if (!A2AClientManager.instance) { - A2AClientManager.instance = new A2AClientManager(config); - } - return A2AClientManager.instance; - } - - /** - * Resets the singleton instance. Only for testing purposes. - * @internal - */ - static resetInstanceForTesting() { - // @ts-expect-error - Resetting singleton for testing - A2AClientManager.instance = undefined; - } - /** * Loads an agent by fetching its AgentCard and caches the client. * @param name The name to assign to the agent. diff --git a/packages/core/src/agents/registry.test.ts b/packages/core/src/agents/registry.test.ts index 49786de4b0..92bd3b2ec8 100644 --- a/packages/core/src/agents/registry.test.ts +++ b/packages/core/src/agents/registry.test.ts @@ -15,7 +15,7 @@ import type { } from '../config/config.js'; import { debugLogger } from '../utils/debugLogger.js'; import { coreEvents, CoreEvent } from '../utils/events.js'; -import { A2AClientManager } from './a2a-client-manager.js'; +import type { A2AClientManager } from './a2a-client-manager.js'; import { DEFAULT_GEMINI_FLASH_LITE_MODEL, DEFAULT_GEMINI_MODEL, @@ -40,9 +40,7 @@ vi.mock('./agentLoader.js', () => ({ })); vi.mock('./a2a-client-manager.js', () => ({ - A2AClientManager: { - getInstance: vi.fn(), - }, + A2AClientManager: vi.fn(), })); vi.mock('./auth-provider/factory.js', () => ({ @@ -450,7 +448,7 @@ describe('AgentRegistry', () => { ); // Mock A2AClientManager to avoid network calls - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }), clearCache: vi.fn(), } as unknown as A2AClientManager); @@ -548,7 +546,7 @@ describe('AgentRegistry', () => { inputConfig: { inputSchema: { type: 'object' } }, }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }), } as unknown as A2AClientManager); @@ -583,7 +581,7 @@ describe('AgentRegistry', () => { const loadAgentSpy = vi .fn() .mockResolvedValue({ name: 'RemoteAgentWithAuth' }); - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: loadAgentSpy, clearCache: vi.fn(), } as unknown as A2AClientManager); @@ -622,7 +620,7 @@ describe('AgentRegistry', () => { vi.mocked(A2AAuthProviderFactory.create).mockResolvedValue(undefined); const loadAgentSpy = vi.fn(); - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: loadAgentSpy, clearCache: vi.fn(), } as unknown as A2AClientManager); @@ -645,6 +643,9 @@ describe('AgentRegistry', () => { it('should log remote agent registration in debug mode', async () => { const debugConfig = makeMockedConfig({ debugMode: true }); const debugRegistry = new TestableAgentRegistry(debugConfig); + vi.spyOn(debugConfig, 'getA2AClientManager').mockReturnValue({ + loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }), + } as unknown as A2AClientManager); const debugLogSpy = vi .spyOn(debugLogger, 'log') .mockImplementation(() => {}); @@ -657,10 +658,6 @@ describe('AgentRegistry', () => { inputConfig: { inputSchema: { type: 'object' } }, }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ - loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }), - } as unknown as A2AClientManager); - await debugRegistry.testRegisterAgent(remoteAgent); expect(debugLogSpy).toHaveBeenCalledWith( @@ -688,7 +685,7 @@ describe('AgentRegistry', () => { new Error('ECONNREFUSED'), ); - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockRejectedValue(a2aError), } as unknown as A2AClientManager); @@ -714,7 +711,7 @@ describe('AgentRegistry', () => { inputConfig: { inputSchema: { type: 'object' } }, }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockRejectedValue(new Error('unexpected crash')), } as unknown as A2AClientManager); @@ -749,7 +746,7 @@ describe('AgentRegistry', () => { // No auth configured }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue({ name: 'SecuredAgent', securitySchemes: { @@ -783,7 +780,7 @@ describe('AgentRegistry', () => { }; const error = new Error('401 Unauthorized'); - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockRejectedValue(error), } as unknown as A2AClientManager); @@ -815,7 +812,7 @@ describe('AgentRegistry', () => { ], }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue(mockAgentCard), clearCache: vi.fn(), } as unknown as A2AClientManager); @@ -843,7 +840,7 @@ describe('AgentRegistry', () => { skills: [{ name: 'Skill1', description: 'Desc1' }], }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue(mockAgentCard), clearCache: vi.fn(), } as unknown as A2AClientManager); @@ -871,7 +868,7 @@ describe('AgentRegistry', () => { skills: [], }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue(mockAgentCard), clearCache: vi.fn(), } as unknown as A2AClientManager); @@ -902,7 +899,7 @@ describe('AgentRegistry', () => { skills: [{ name: 'Skill1', description: 'Desc1' }], }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue(mockAgentCard), clearCache: vi.fn(), } as unknown as A2AClientManager); @@ -930,7 +927,7 @@ describe('AgentRegistry', () => { inputConfig: { inputSchema: { type: 'object' } }, }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue({ name: 'EmptyDescAgent', description: 'Loaded from card', @@ -955,7 +952,7 @@ describe('AgentRegistry', () => { inputConfig: { inputSchema: { type: 'object' } }, }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue({ name: 'SkillFallbackAgent', description: 'Card description', @@ -1092,7 +1089,7 @@ describe('AgentRegistry', () => { inputConfig: { inputSchema: { type: 'object' } }, }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue({ name: 'RemotePolicyAgent' }), } as unknown as A2AClientManager); @@ -1141,7 +1138,7 @@ describe('AgentRegistry', () => { inputConfig: { inputSchema: { type: 'object' } }, }; - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({ loadAgent: vi.fn().mockResolvedValue({ name: 'OverwrittenAgent' }), } as unknown as A2AClientManager); @@ -1189,8 +1186,10 @@ describe('AgentRegistry', () => { }); const clearCacheSpy = vi.fn(); - vi.mocked(A2AClientManager.getInstance).mockReturnValue({ + vi.spyOn(config, 'getA2AClientManager').mockReturnValue({ clearCache: clearCacheSpy, + loadAgent: vi.fn(), + getClient: vi.fn(), } as unknown as A2AClientManager); const emitSpy = vi.spyOn(coreEvents, 'emitAgentsRefreshed'); diff --git a/packages/core/src/agents/registry.ts b/packages/core/src/agents/registry.ts index 3a815aa012..3c681266fa 100644 --- a/packages/core/src/agents/registry.ts +++ b/packages/core/src/agents/registry.ts @@ -13,7 +13,6 @@ import { CodebaseInvestigatorAgent } from './codebase-investigator.js'; import { CliHelpAgent } from './cli-help-agent.js'; import { GeneralistAgent } from './generalist-agent.js'; import { BrowserAgentDefinition } from './browser/browserAgentDefinition.js'; -import { A2AClientManager } from './a2a-client-manager.js'; import { A2AAuthProviderFactory } from './auth-provider/factory.js'; import type { AuthenticationHandler } from '@a2a-js/sdk/client'; import { type z } from 'zod'; @@ -69,7 +68,7 @@ export class AgentRegistry { * Clears the current registry and re-scans for agents. */ async reload(): Promise { - A2AClientManager.getInstance(this.config).clearCache(); + this.config.getA2AClientManager()?.clearCache(); await this.config.reloadAgents(); this.agents.clear(); this.allDefinitions.clear(); @@ -414,7 +413,13 @@ export class AgentRegistry { // Load the remote A2A agent card and register. try { - const clientManager = A2AClientManager.getInstance(this.config); + const clientManager = this.config.getA2AClientManager(); + if (!clientManager) { + debugLogger.warn( + `[AgentRegistry] Skipping remote agent '${definition.name}': A2AClientManager is not available.`, + ); + return; + } let authHandler: AuthenticationHandler | undefined; if (definition.auth) { const provider = await A2AAuthProviderFactory.create({ diff --git a/packages/core/src/agents/remote-invocation.test.ts b/packages/core/src/agents/remote-invocation.test.ts index e186cc7aa9..870071b321 100644 --- a/packages/core/src/agents/remote-invocation.test.ts +++ b/packages/core/src/agents/remote-invocation.test.ts @@ -13,21 +13,27 @@ import { afterEach, type Mock, } from 'vitest'; +import type { Client } from '@a2a-js/sdk/client'; import { RemoteAgentInvocation } from './remote-invocation.js'; import { - A2AClientManager, type SendMessageResult, + type A2AClientManager, } from './a2a-client-manager.js'; + import type { RemoteAgentDefinition } from './types.js'; import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; import { A2AAuthProviderFactory } from './auth-provider/factory.js'; import type { A2AAuthProvider } from './auth-provider/types.js'; +import type { AgentLoopContext } from '../config/agent-loop-context.js'; +import type { Config } from '../config/config.js'; // Mock A2AClientManager vi.mock('./a2a-client-manager.js', () => ({ - A2AClientManager: { - getInstance: vi.fn(), - }, + A2AClientManager: vi.fn().mockImplementation(() => ({ + getClient: vi.fn(), + loadAgent: vi.fn(), + sendMessageStream: vi.fn(), + })), })); // Mock A2AAuthProviderFactory @@ -49,16 +55,40 @@ describe('RemoteAgentInvocation', () => { }, }; - const mockClientManager = { - getClient: vi.fn(), - loadAgent: vi.fn(), - sendMessageStream: vi.fn(), + let mockClientManager: { + getClient: Mock; + loadAgent: Mock; + sendMessageStream: Mock; }; + let mockContext: AgentLoopContext; const mockMessageBus = createMockMessageBus(); + const mockClient = { + sendMessageStream: vi.fn(), + getTask: vi.fn(), + cancelTask: vi.fn(), + } as unknown as Client; + beforeEach(() => { vi.clearAllMocks(); - (A2AClientManager.getInstance as Mock).mockReturnValue(mockClientManager); + + mockClientManager = { + getClient: vi.fn(), + loadAgent: vi.fn(), + sendMessageStream: vi.fn(), + }; + + const mockConfig = { + getA2AClientManager: vi.fn().mockReturnValue(mockClientManager), + injectionService: { + getLatestInjectionIndex: vi.fn().mockReturnValue(0), + }, + } as unknown as Config; + + mockContext = { + config: mockConfig, + } as unknown as AgentLoopContext; + ( RemoteAgentInvocation as unknown as { sessionState?: Map; @@ -75,6 +105,7 @@ describe('RemoteAgentInvocation', () => { expect(() => { new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'valid' }, mockMessageBus, ); @@ -83,12 +114,17 @@ describe('RemoteAgentInvocation', () => { it('accepts missing query (defaults to "Get Started!")', () => { expect(() => { - new RemoteAgentInvocation(mockDefinition, {}, mockMessageBus); + new RemoteAgentInvocation( + mockDefinition, + mockContext, + {}, + mockMessageBus, + ); }).not.toThrow(); }); it('uses "Get Started!" default when query is missing during execution', async () => { - mockClientManager.getClient.mockReturnValue({}); + mockClientManager.getClient.mockReturnValue(mockClient); mockClientManager.sendMessageStream.mockImplementation( async function* () { yield { @@ -102,6 +138,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, {}, mockMessageBus, ); @@ -118,6 +155,7 @@ describe('RemoteAgentInvocation', () => { expect(() => { new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 123 }, mockMessageBus, ); @@ -141,6 +179,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi', }, @@ -187,6 +226,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( authDefinition, + mockContext, { query: 'hi' }, mockMessageBus, ); @@ -220,6 +260,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( authDefinition, + mockContext, { query: 'hi' }, mockMessageBus, ); @@ -231,7 +272,7 @@ describe('RemoteAgentInvocation', () => { }); it('should not load the agent if already present', async () => { - mockClientManager.getClient.mockReturnValue({}); + mockClientManager.getClient.mockReturnValue(mockClient); mockClientManager.sendMessageStream.mockImplementation( async function* () { yield { @@ -245,6 +286,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi', }, @@ -256,7 +298,7 @@ describe('RemoteAgentInvocation', () => { }); it('should persist contextId and taskId across invocations', async () => { - mockClientManager.getClient.mockReturnValue({}); + mockClientManager.getClient.mockReturnValue(mockClient); // First call return values mockClientManager.sendMessageStream.mockImplementationOnce( @@ -274,6 +316,7 @@ describe('RemoteAgentInvocation', () => { const invocation1 = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'first', }, @@ -305,6 +348,7 @@ describe('RemoteAgentInvocation', () => { const invocation2 = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'second', }, @@ -335,6 +379,7 @@ describe('RemoteAgentInvocation', () => { const invocation3 = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'third', }, @@ -356,6 +401,7 @@ describe('RemoteAgentInvocation', () => { const invocation4 = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'fourth', }, @@ -371,7 +417,7 @@ describe('RemoteAgentInvocation', () => { }); it('should handle streaming updates and reassemble output', async () => { - mockClientManager.getClient.mockReturnValue({}); + mockClientManager.getClient.mockReturnValue(mockClient); mockClientManager.sendMessageStream.mockImplementation( async function* () { yield { @@ -392,6 +438,7 @@ describe('RemoteAgentInvocation', () => { const updateOutput = vi.fn(); const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi' }, mockMessageBus, ); @@ -402,7 +449,7 @@ describe('RemoteAgentInvocation', () => { }); it('should abort when signal is aborted during streaming', async () => { - mockClientManager.getClient.mockReturnValue({}); + mockClientManager.getClient.mockReturnValue(mockClient); const controller = new AbortController(); mockClientManager.sendMessageStream.mockImplementation( async function* () { @@ -425,6 +472,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi' }, mockMessageBus, ); @@ -435,7 +483,7 @@ describe('RemoteAgentInvocation', () => { }); it('should handle errors gracefully', async () => { - mockClientManager.getClient.mockReturnValue({}); + mockClientManager.getClient.mockReturnValue(mockClient); mockClientManager.sendMessageStream.mockImplementation( async function* () { if (Math.random() < 0) yield {} as unknown as SendMessageResult; @@ -445,6 +493,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi', }, @@ -458,7 +507,7 @@ describe('RemoteAgentInvocation', () => { }); it('should use a2a helpers for extracting text', async () => { - mockClientManager.getClient.mockReturnValue({}); + mockClientManager.getClient.mockReturnValue(mockClient); // Mock a complex message part that needs extraction mockClientManager.sendMessageStream.mockImplementation( async function* () { @@ -476,6 +525,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi', }, @@ -488,7 +538,7 @@ describe('RemoteAgentInvocation', () => { }); it('should handle mixed response types during streaming (TaskStatusUpdateEvent + Message)', async () => { - mockClientManager.getClient.mockReturnValue({}); + mockClientManager.getClient.mockReturnValue(mockClient); mockClientManager.sendMessageStream.mockImplementation( async function* () { yield { @@ -518,6 +568,7 @@ describe('RemoteAgentInvocation', () => { const updateOutput = vi.fn(); const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi' }, mockMessageBus, ); @@ -532,17 +583,20 @@ describe('RemoteAgentInvocation', () => { }); it('should handle artifact reassembly with append: true', async () => { - mockClientManager.getClient.mockReturnValue({}); + mockClientManager.getClient.mockReturnValue(mockClient); mockClientManager.sendMessageStream.mockImplementation( async function* () { yield { kind: 'status-update', taskId: 'task-1', + contextId: 'ctx-1', + final: false, status: { state: 'working', message: { kind: 'message', role: 'agent', + messageId: 'm1', parts: [{ kind: 'text', text: 'Generating...' }], }, }, @@ -550,6 +604,7 @@ describe('RemoteAgentInvocation', () => { yield { kind: 'artifact-update', taskId: 'task-1', + contextId: 'ctx-1', append: false, artifact: { artifactId: 'art-1', @@ -560,18 +615,21 @@ describe('RemoteAgentInvocation', () => { yield { kind: 'artifact-update', taskId: 'task-1', + contextId: 'ctx-1', append: true, artifact: { artifactId: 'art-1', parts: [{ kind: 'text', text: ' Part 2' }], }, }; + return; }, ); const updateOutput = vi.fn(); const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi' }, mockMessageBus, ); @@ -591,6 +649,7 @@ describe('RemoteAgentInvocation', () => { it('should return info confirmation details', async () => { const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi', }, @@ -629,6 +688,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi' }, mockMessageBus, ); @@ -646,6 +706,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi' }, mockMessageBus, ); @@ -658,7 +719,7 @@ describe('RemoteAgentInvocation', () => { }); it('should include partial output when error occurs mid-stream', async () => { - mockClientManager.getClient.mockReturnValue({}); + mockClientManager.getClient.mockReturnValue(mockClient); mockClientManager.sendMessageStream.mockImplementation( async function* () { yield { @@ -674,6 +735,7 @@ describe('RemoteAgentInvocation', () => { const invocation = new RemoteAgentInvocation( mockDefinition, + mockContext, { query: 'hi' }, mockMessageBus, ); diff --git a/packages/core/src/agents/remote-invocation.ts b/packages/core/src/agents/remote-invocation.ts index 489f0f91cc..0933ca026e 100644 --- a/packages/core/src/agents/remote-invocation.ts +++ b/packages/core/src/agents/remote-invocation.ts @@ -16,10 +16,11 @@ import { type RemoteAgentDefinition, type AgentInputs, } from './types.js'; +import { type AgentLoopContext } from '../config/agent-loop-context.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; -import { +import type { A2AClientManager, - type SendMessageResult, + SendMessageResult, } from './a2a-client-manager.js'; import { extractIdsFromResponse, A2AResultReassembler } from './a2aUtils.js'; import type { AuthenticationHandler } from '@a2a-js/sdk/client'; @@ -47,13 +48,13 @@ export class RemoteAgentInvocation extends BaseToolInvocation< // State for the ongoing conversation with the remote agent private contextId: string | undefined; private taskId: string | undefined; - // TODO: See if we can reuse the singleton from AppContainer or similar, but for now use getInstance directly - // as per the current pattern in the codebase. - private readonly clientManager = A2AClientManager.getInstance(); + + private readonly clientManager: A2AClientManager; private authHandler: AuthenticationHandler | undefined; constructor( private readonly definition: RemoteAgentDefinition, + private readonly context: AgentLoopContext, params: AgentInputs, messageBus: MessageBus, _toolName?: string, @@ -72,6 +73,13 @@ export class RemoteAgentInvocation extends BaseToolInvocation< _toolName ?? definition.name, _toolDisplayName ?? definition.displayName, ); + const clientManager = this.context.config.getA2AClientManager(); + if (!clientManager) { + throw new Error( + `Failed to initialize RemoteAgentInvocation for '${definition.name}': A2AClientManager is not available.`, + ); + } + this.clientManager = clientManager; } getDescription(): string { diff --git a/packages/core/src/agents/subagent-tool-wrapper.ts b/packages/core/src/agents/subagent-tool-wrapper.ts index cf6d1e7112..30a30d76d0 100644 --- a/packages/core/src/agents/subagent-tool-wrapper.ts +++ b/packages/core/src/agents/subagent-tool-wrapper.ts @@ -75,6 +75,7 @@ export class SubagentToolWrapper extends BaseDeclarativeTool< if (definition.kind === 'remote') { return new RemoteAgentInvocation( definition, + this.context, params, effectiveMessageBus, _toolName, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 5b291977f5..eff489dcd6 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -1523,7 +1523,7 @@ describe('Server Config (config.ts)', () => { const paramsWithProxy: ConfigParameters = { ...baseParams, - proxy: 'invalid-proxy', + proxy: 'http://invalid-proxy:8080', }; new Config(paramsWithProxy); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 7dc4636c18..fcb6613756 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -405,6 +405,7 @@ import { SimpleExtensionLoader, } from '../utils/extensionLoader.js'; import { McpClientManager } from '../tools/mcp-client-manager.js'; +import { A2AClientManager } from '../agents/a2a-client-manager.js'; import { type McpContext } from '../tools/mcp-client.js'; import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js'; import { getErrorMessage } from '../utils/errors.js'; @@ -653,6 +654,7 @@ export interface ConfigParameters { export class Config implements McpContext, AgentLoopContext { private _toolRegistry!: ToolRegistry; private mcpClientManager?: McpClientManager; + private readonly a2aClientManager?: A2AClientManager; private allowedMcpServers: string[]; private blockedMcpServers: string[]; private allowedEnvironmentVariables: string[]; @@ -1188,6 +1190,7 @@ export class Config implements McpContext, AgentLoopContext { params.toolSandboxing ?? false, this.targetDir, ); + this.a2aClientManager = new A2AClientManager(this); this.shellExecutionConfig.sandboxManager = this._sandboxManager; this.modelRouterService = new ModelRouterService(this); } @@ -2000,6 +2003,10 @@ export class Config implements McpContext, AgentLoopContext { return this.mcpClientManager; } + getA2AClientManager(): A2AClientManager | undefined { + return this.a2aClientManager; + } + setUserInteractedWithMcp(): void { this.mcpClientManager?.setUserInteractedWithMcp(); } diff --git a/packages/core/src/policy/policy-engine.test.ts b/packages/core/src/policy/policy-engine.test.ts index 376e465604..b8865ba587 100644 --- a/packages/core/src/policy/policy-engine.test.ts +++ b/packages/core/src/policy/policy-engine.test.ts @@ -15,6 +15,7 @@ import { ApprovalMode, PRIORITY_SUBAGENT_TOOL, ALWAYS_ALLOW_PRIORITY_FRACTION, + PRIORITY_YOLO_ALLOW_ALL, } from './types.js'; import type { FunctionCall } from '@google/genai'; import { SafetyCheckDecision } from '../safety/protocol.js'; @@ -2852,7 +2853,7 @@ describe('PolicyEngine', () => { }, { decision: PolicyDecision.ALLOW, - priority: 998, + priority: PRIORITY_YOLO_ALLOW_ALL, modes: [ApprovalMode.YOLO], }, ]; @@ -2879,7 +2880,7 @@ describe('PolicyEngine', () => { }, { decision: PolicyDecision.ALLOW, - priority: 998, + priority: PRIORITY_YOLO_ALLOW_ALL, modes: [ApprovalMode.YOLO], }, ]; diff --git a/packages/core/src/policy/types.ts b/packages/core/src/policy/types.ts index 6e14e1fac9..a3a919e1cd 100644 --- a/packages/core/src/policy/types.ts +++ b/packages/core/src/policy/types.ts @@ -345,3 +345,9 @@ export const ALWAYS_ALLOW_PRIORITY_FRACTION = 950; */ export const ALWAYS_ALLOW_PRIORITY_OFFSET = ALWAYS_ALLOW_PRIORITY_FRACTION / 1000; + +/** + * Priority for the YOLO "allow all" rule. + * Matches the raw priority used in yolo.toml. + */ +export const PRIORITY_YOLO_ALLOW_ALL = 998;