feat(a2a): add agent acknowledgment command and enhance registry discovery (#22389)

This commit is contained in:
Alisa
2026-03-17 15:47:05 -07:00
committed by GitHub
parent fb9264bf80
commit 7ae39fd622
15 changed files with 250 additions and 94 deletions

View File

@@ -19,6 +19,8 @@ import {
AuthType,
isHeadlessMode,
FatalAuthenticationError,
PolicyDecision,
PRIORITY_YOLO_ALLOW_ALL,
} from '@google/gemini-cli-core';
// Mock dependencies
@@ -325,6 +327,29 @@ describe('loadConfig', () => {
);
});
it('should pass enableAgents to Config constructor', async () => {
const settings: Settings = {
experimental: {
enableAgents: false,
},
};
await loadConfig(settings, mockExtensionLoader, taskId);
expect(Config).toHaveBeenCalledWith(
expect.objectContaining({
enableAgents: false,
}),
);
});
it('should default enableAgents to true when not provided', async () => {
await loadConfig(mockSettings, mockExtensionLoader, taskId);
expect(Config).toHaveBeenCalledWith(
expect.objectContaining({
enableAgents: true,
}),
);
});
describe('interactivity', () => {
it('should set interactive true when not headless', async () => {
vi.mocked(isHeadlessMode).mockReturnValue(false);
@@ -349,6 +374,41 @@ describe('loadConfig', () => {
});
});
describe('YOLO mode', () => {
it('should enable YOLO mode and add policy rule when GEMINI_YOLO_MODE is true', async () => {
vi.stubEnv('GEMINI_YOLO_MODE', 'true');
await loadConfig(mockSettings, mockExtensionLoader, taskId);
expect(Config).toHaveBeenCalledWith(
expect.objectContaining({
approvalMode: 'yolo',
policyEngineConfig: expect.objectContaining({
rules: expect.arrayContaining([
expect.objectContaining({
decision: PolicyDecision.ALLOW,
priority: PRIORITY_YOLO_ALLOW_ALL,
modes: ['yolo'],
allowRedirection: true,
}),
]),
}),
}),
);
});
it('should use default approval mode and empty rules when GEMINI_YOLO_MODE is not true', async () => {
vi.stubEnv('GEMINI_YOLO_MODE', 'false');
await loadConfig(mockSettings, mockExtensionLoader, taskId);
expect(Config).toHaveBeenCalledWith(
expect.objectContaining({
approvalMode: 'default',
policyEngineConfig: expect.objectContaining({
rules: [],
}),
}),
);
});
});
describe('authentication fallback', () => {
beforeEach(() => {
vi.stubEnv('USE_CCPA', 'true');

View File

@@ -26,6 +26,8 @@ import {
isHeadlessMode,
FatalAuthenticationError,
isCloudShell,
PolicyDecision,
PRIORITY_YOLO_ALLOW_ALL,
type TelemetryTarget,
type ConfigParameters,
type ExtensionLoader,
@@ -60,6 +62,11 @@ export async function loadConfig(
}
}
const approvalMode =
process.env['GEMINI_YOLO_MODE'] === 'true'
? ApprovalMode.YOLO
: ApprovalMode.DEFAULT;
const configParams: ConfigParameters = {
sessionId: taskId,
clientName: 'a2a-server',
@@ -74,10 +81,20 @@ export async function loadConfig(
excludeTools: settings.excludeTools || settings.tools?.exclude || undefined,
allowedTools: settings.allowedTools || settings.tools?.allowed || undefined,
showMemoryUsage: settings.showMemoryUsage || false,
approvalMode:
process.env['GEMINI_YOLO_MODE'] === 'true'
? ApprovalMode.YOLO
: ApprovalMode.DEFAULT,
approvalMode,
policyEngineConfig: {
rules:
approvalMode === ApprovalMode.YOLO
? [
{
decision: PolicyDecision.ALLOW,
priority: PRIORITY_YOLO_ALLOW_ALL,
modes: [ApprovalMode.YOLO],
allowRedirection: true,
},
]
: [],
},
mcpServers: settings.mcpServers,
cwd: workspaceDir,
telemetry: {
@@ -110,6 +127,7 @@ export async function loadConfig(
interactive: !isHeadlessMode(),
enableInteractiveShell: !isHeadlessMode(),
ptyInfo: 'auto',
enableAgents: settings.experimental?.enableAgents ?? true,
};
const fileService = new FileDiscoveryService(workspaceDir, {

View File

@@ -112,6 +112,18 @@ describe('loadSettings', () => {
expect(result.fileFiltering?.respectGitIgnore).toBe(true);
});
it('should load experimental settings correctly', () => {
const settings = {
experimental: {
enableAgents: true,
},
};
fs.writeFileSync(USER_SETTINGS_PATH, JSON.stringify(settings));
const result = loadSettings(mockWorkspaceDir);
expect(result.experimental?.enableAgents).toBe(true);
});
it('should overwrite top-level settings from workspace (shallow merge)', () => {
const userSettings = {
showMemoryUsage: false,

View File

@@ -48,6 +48,9 @@ export interface Settings {
enableRecursiveFileSearch?: boolean;
customIgnoreFilePaths?: string[];
};
experimental?: {
enableAgents?: boolean;
};
}
export interface SettingsError {

View File

@@ -66,11 +66,13 @@ describe('A2AClientManager', () => {
};
const authFetchMock = vi.fn();
const mockConfig = {
getProxy: vi.fn(),
} as unknown as Config;
beforeEach(() => {
vi.clearAllMocks();
A2AClientManager.resetInstanceForTesting();
manager = A2AClientManager.getInstance();
manager = new A2AClientManager(mockConfig);
// Re-create the instances as plain objects that can be spied on
const factoryInstance = {
@@ -124,12 +126,6 @@ describe('A2AClientManager', () => {
vi.unstubAllGlobals();
});
it('should enforce the singleton pattern', () => {
const instance1 = A2AClientManager.getInstance();
const instance2 = A2AClientManager.getInstance();
expect(instance1).toBe(instance2);
});
describe('getInstance / dispatcher initialization', () => {
it('should use UndiciAgent when no proxy is configured', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
@@ -152,12 +148,11 @@ describe('A2AClientManager', () => {
});
it('should use ProxyAgent when a proxy is configured via Config', async () => {
A2AClientManager.resetInstanceForTesting();
const mockConfig = {
const mockConfigWithProxy = {
getProxy: () => 'http://my-proxy:8080',
} as Config;
manager = A2AClientManager.getInstance(mockConfig);
manager = new A2AClientManager(mockConfigWithProxy);
await manager.loadAgent('TestProxyAgent', 'http://test.proxy.agent/card');
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock

View File

@@ -49,8 +49,6 @@ const A2A_TIMEOUT = 1800000; // 30 minutes
* Manages protocol negotiation, authentication, and transport selection.
*/
export class A2AClientManager {
private static instance: A2AClientManager;
// Each agent should manage their own context/taskIds/card/etc
private clients = new Map<string, Client>();
private agentCards = new Map<string, AgentCard>();
@@ -58,8 +56,8 @@ export class A2AClientManager {
private a2aDispatcher: UndiciAgent | ProxyAgent;
private a2aFetch: typeof fetch;
private constructor(config?: Config) {
const proxyUrl = config?.getProxy();
constructor(private readonly config: Config) {
const proxyUrl = this.config.getProxy();
const agentOptions = {
headersTimeout: A2A_TIMEOUT,
bodyTimeout: A2A_TIMEOUT,
@@ -78,25 +76,6 @@ export class A2AClientManager {
fetch(input, { ...init, dispatcher: this.a2aDispatcher } as RequestInit);
}
/**
* Gets the singleton instance of the A2AClientManager.
*/
static getInstance(config?: Config): A2AClientManager {
if (!A2AClientManager.instance) {
A2AClientManager.instance = new A2AClientManager(config);
}
return A2AClientManager.instance;
}
/**
* Resets the singleton instance. Only for testing purposes.
* @internal
*/
static resetInstanceForTesting() {
// @ts-expect-error - Resetting singleton for testing
A2AClientManager.instance = undefined;
}
/**
* Loads an agent by fetching its AgentCard and caches the client.
* @param name The name to assign to the agent.

View File

@@ -15,7 +15,7 @@ import type {
} from '../config/config.js';
import { debugLogger } from '../utils/debugLogger.js';
import { coreEvents, CoreEvent } from '../utils/events.js';
import { A2AClientManager } from './a2a-client-manager.js';
import type { A2AClientManager } from './a2a-client-manager.js';
import {
DEFAULT_GEMINI_FLASH_LITE_MODEL,
DEFAULT_GEMINI_MODEL,
@@ -40,9 +40,7 @@ vi.mock('./agentLoader.js', () => ({
}));
vi.mock('./a2a-client-manager.js', () => ({
A2AClientManager: {
getInstance: vi.fn(),
},
A2AClientManager: vi.fn(),
}));
vi.mock('./auth-provider/factory.js', () => ({
@@ -450,7 +448,7 @@ describe('AgentRegistry', () => {
);
// Mock A2AClientManager to avoid network calls
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
clearCache: vi.fn(),
} as unknown as A2AClientManager);
@@ -548,7 +546,7 @@ describe('AgentRegistry', () => {
inputConfig: { inputSchema: { type: 'object' } },
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
} as unknown as A2AClientManager);
@@ -583,7 +581,7 @@ describe('AgentRegistry', () => {
const loadAgentSpy = vi
.fn()
.mockResolvedValue({ name: 'RemoteAgentWithAuth' });
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: loadAgentSpy,
clearCache: vi.fn(),
} as unknown as A2AClientManager);
@@ -622,7 +620,7 @@ describe('AgentRegistry', () => {
vi.mocked(A2AAuthProviderFactory.create).mockResolvedValue(undefined);
const loadAgentSpy = vi.fn();
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: loadAgentSpy,
clearCache: vi.fn(),
} as unknown as A2AClientManager);
@@ -645,6 +643,9 @@ describe('AgentRegistry', () => {
it('should log remote agent registration in debug mode', async () => {
const debugConfig = makeMockedConfig({ debugMode: true });
const debugRegistry = new TestableAgentRegistry(debugConfig);
vi.spyOn(debugConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
} as unknown as A2AClientManager);
const debugLogSpy = vi
.spyOn(debugLogger, 'log')
.mockImplementation(() => {});
@@ -657,10 +658,6 @@ describe('AgentRegistry', () => {
inputConfig: { inputSchema: { type: 'object' } },
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
} as unknown as A2AClientManager);
await debugRegistry.testRegisterAgent(remoteAgent);
expect(debugLogSpy).toHaveBeenCalledWith(
@@ -688,7 +685,7 @@ describe('AgentRegistry', () => {
new Error('ECONNREFUSED'),
);
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockRejectedValue(a2aError),
} as unknown as A2AClientManager);
@@ -714,7 +711,7 @@ describe('AgentRegistry', () => {
inputConfig: { inputSchema: { type: 'object' } },
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockRejectedValue(new Error('unexpected crash')),
} as unknown as A2AClientManager);
@@ -749,7 +746,7 @@ describe('AgentRegistry', () => {
// No auth configured
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue({
name: 'SecuredAgent',
securitySchemes: {
@@ -783,7 +780,7 @@ describe('AgentRegistry', () => {
};
const error = new Error('401 Unauthorized');
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockRejectedValue(error),
} as unknown as A2AClientManager);
@@ -815,7 +812,7 @@ describe('AgentRegistry', () => {
],
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
clearCache: vi.fn(),
} as unknown as A2AClientManager);
@@ -843,7 +840,7 @@ describe('AgentRegistry', () => {
skills: [{ name: 'Skill1', description: 'Desc1' }],
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
clearCache: vi.fn(),
} as unknown as A2AClientManager);
@@ -871,7 +868,7 @@ describe('AgentRegistry', () => {
skills: [],
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
clearCache: vi.fn(),
} as unknown as A2AClientManager);
@@ -902,7 +899,7 @@ describe('AgentRegistry', () => {
skills: [{ name: 'Skill1', description: 'Desc1' }],
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
clearCache: vi.fn(),
} as unknown as A2AClientManager);
@@ -930,7 +927,7 @@ describe('AgentRegistry', () => {
inputConfig: { inputSchema: { type: 'object' } },
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue({
name: 'EmptyDescAgent',
description: 'Loaded from card',
@@ -955,7 +952,7 @@ describe('AgentRegistry', () => {
inputConfig: { inputSchema: { type: 'object' } },
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue({
name: 'SkillFallbackAgent',
description: 'Card description',
@@ -1092,7 +1089,7 @@ describe('AgentRegistry', () => {
inputConfig: { inputSchema: { type: 'object' } },
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue({ name: 'RemotePolicyAgent' }),
} as unknown as A2AClientManager);
@@ -1141,7 +1138,7 @@ describe('AgentRegistry', () => {
inputConfig: { inputSchema: { type: 'object' } },
};
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
loadAgent: vi.fn().mockResolvedValue({ name: 'OverwrittenAgent' }),
} as unknown as A2AClientManager);
@@ -1189,8 +1186,10 @@ describe('AgentRegistry', () => {
});
const clearCacheSpy = vi.fn();
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
vi.spyOn(config, 'getA2AClientManager').mockReturnValue({
clearCache: clearCacheSpy,
loadAgent: vi.fn(),
getClient: vi.fn(),
} as unknown as A2AClientManager);
const emitSpy = vi.spyOn(coreEvents, 'emitAgentsRefreshed');

View File

@@ -13,7 +13,6 @@ import { CodebaseInvestigatorAgent } from './codebase-investigator.js';
import { CliHelpAgent } from './cli-help-agent.js';
import { GeneralistAgent } from './generalist-agent.js';
import { BrowserAgentDefinition } from './browser/browserAgentDefinition.js';
import { A2AClientManager } from './a2a-client-manager.js';
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
import { type z } from 'zod';
@@ -69,7 +68,7 @@ export class AgentRegistry {
* Clears the current registry and re-scans for agents.
*/
async reload(): Promise<void> {
A2AClientManager.getInstance(this.config).clearCache();
this.config.getA2AClientManager()?.clearCache();
await this.config.reloadAgents();
this.agents.clear();
this.allDefinitions.clear();
@@ -414,7 +413,13 @@ export class AgentRegistry {
// Load the remote A2A agent card and register.
try {
const clientManager = A2AClientManager.getInstance(this.config);
const clientManager = this.config.getA2AClientManager();
if (!clientManager) {
debugLogger.warn(
`[AgentRegistry] Skipping remote agent '${definition.name}': A2AClientManager is not available.`,
);
return;
}
let authHandler: AuthenticationHandler | undefined;
if (definition.auth) {
const provider = await A2AAuthProviderFactory.create({

View File

@@ -13,21 +13,27 @@ import {
afterEach,
type Mock,
} from 'vitest';
import type { Client } from '@a2a-js/sdk/client';
import { RemoteAgentInvocation } from './remote-invocation.js';
import {
A2AClientManager,
type SendMessageResult,
type A2AClientManager,
} from './a2a-client-manager.js';
import type { RemoteAgentDefinition } from './types.js';
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
import type { A2AAuthProvider } from './auth-provider/types.js';
import type { AgentLoopContext } from '../config/agent-loop-context.js';
import type { Config } from '../config/config.js';
// Mock A2AClientManager
vi.mock('./a2a-client-manager.js', () => ({
A2AClientManager: {
getInstance: vi.fn(),
},
A2AClientManager: vi.fn().mockImplementation(() => ({
getClient: vi.fn(),
loadAgent: vi.fn(),
sendMessageStream: vi.fn(),
})),
}));
// Mock A2AAuthProviderFactory
@@ -49,16 +55,40 @@ describe('RemoteAgentInvocation', () => {
},
};
const mockClientManager = {
getClient: vi.fn(),
loadAgent: vi.fn(),
sendMessageStream: vi.fn(),
let mockClientManager: {
getClient: Mock<A2AClientManager['getClient']>;
loadAgent: Mock<A2AClientManager['loadAgent']>;
sendMessageStream: Mock<A2AClientManager['sendMessageStream']>;
};
let mockContext: AgentLoopContext;
const mockMessageBus = createMockMessageBus();
const mockClient = {
sendMessageStream: vi.fn(),
getTask: vi.fn(),
cancelTask: vi.fn(),
} as unknown as Client;
beforeEach(() => {
vi.clearAllMocks();
(A2AClientManager.getInstance as Mock).mockReturnValue(mockClientManager);
mockClientManager = {
getClient: vi.fn(),
loadAgent: vi.fn(),
sendMessageStream: vi.fn(),
};
const mockConfig = {
getA2AClientManager: vi.fn().mockReturnValue(mockClientManager),
injectionService: {
getLatestInjectionIndex: vi.fn().mockReturnValue(0),
},
} as unknown as Config;
mockContext = {
config: mockConfig,
} as unknown as AgentLoopContext;
(
RemoteAgentInvocation as unknown as {
sessionState?: Map<string, { contextId?: string; taskId?: string }>;
@@ -75,6 +105,7 @@ describe('RemoteAgentInvocation', () => {
expect(() => {
new RemoteAgentInvocation(
mockDefinition,
mockContext,
{ query: 'valid' },
mockMessageBus,
);
@@ -83,12 +114,17 @@ describe('RemoteAgentInvocation', () => {
it('accepts missing query (defaults to "Get Started!")', () => {
expect(() => {
new RemoteAgentInvocation(mockDefinition, {}, mockMessageBus);
new RemoteAgentInvocation(
mockDefinition,
mockContext,
{},
mockMessageBus,
);
}).not.toThrow();
});
it('uses "Get Started!" default when query is missing during execution', async () => {
mockClientManager.getClient.mockReturnValue({});
mockClientManager.getClient.mockReturnValue(mockClient);
mockClientManager.sendMessageStream.mockImplementation(
async function* () {
yield {
@@ -102,6 +138,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{},
mockMessageBus,
);
@@ -118,6 +155,7 @@ describe('RemoteAgentInvocation', () => {
expect(() => {
new RemoteAgentInvocation(
mockDefinition,
mockContext,
{ query: 123 },
mockMessageBus,
);
@@ -141,6 +179,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{
query: 'hi',
},
@@ -187,6 +226,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
authDefinition,
mockContext,
{ query: 'hi' },
mockMessageBus,
);
@@ -220,6 +260,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
authDefinition,
mockContext,
{ query: 'hi' },
mockMessageBus,
);
@@ -231,7 +272,7 @@ describe('RemoteAgentInvocation', () => {
});
it('should not load the agent if already present', async () => {
mockClientManager.getClient.mockReturnValue({});
mockClientManager.getClient.mockReturnValue(mockClient);
mockClientManager.sendMessageStream.mockImplementation(
async function* () {
yield {
@@ -245,6 +286,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{
query: 'hi',
},
@@ -256,7 +298,7 @@ describe('RemoteAgentInvocation', () => {
});
it('should persist contextId and taskId across invocations', async () => {
mockClientManager.getClient.mockReturnValue({});
mockClientManager.getClient.mockReturnValue(mockClient);
// First call return values
mockClientManager.sendMessageStream.mockImplementationOnce(
@@ -274,6 +316,7 @@ describe('RemoteAgentInvocation', () => {
const invocation1 = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{
query: 'first',
},
@@ -305,6 +348,7 @@ describe('RemoteAgentInvocation', () => {
const invocation2 = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{
query: 'second',
},
@@ -335,6 +379,7 @@ describe('RemoteAgentInvocation', () => {
const invocation3 = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{
query: 'third',
},
@@ -356,6 +401,7 @@ describe('RemoteAgentInvocation', () => {
const invocation4 = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{
query: 'fourth',
},
@@ -371,7 +417,7 @@ describe('RemoteAgentInvocation', () => {
});
it('should handle streaming updates and reassemble output', async () => {
mockClientManager.getClient.mockReturnValue({});
mockClientManager.getClient.mockReturnValue(mockClient);
mockClientManager.sendMessageStream.mockImplementation(
async function* () {
yield {
@@ -392,6 +438,7 @@ describe('RemoteAgentInvocation', () => {
const updateOutput = vi.fn();
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{ query: 'hi' },
mockMessageBus,
);
@@ -402,7 +449,7 @@ describe('RemoteAgentInvocation', () => {
});
it('should abort when signal is aborted during streaming', async () => {
mockClientManager.getClient.mockReturnValue({});
mockClientManager.getClient.mockReturnValue(mockClient);
const controller = new AbortController();
mockClientManager.sendMessageStream.mockImplementation(
async function* () {
@@ -425,6 +472,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{ query: 'hi' },
mockMessageBus,
);
@@ -435,7 +483,7 @@ describe('RemoteAgentInvocation', () => {
});
it('should handle errors gracefully', async () => {
mockClientManager.getClient.mockReturnValue({});
mockClientManager.getClient.mockReturnValue(mockClient);
mockClientManager.sendMessageStream.mockImplementation(
async function* () {
if (Math.random() < 0) yield {} as unknown as SendMessageResult;
@@ -445,6 +493,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{
query: 'hi',
},
@@ -458,7 +507,7 @@ describe('RemoteAgentInvocation', () => {
});
it('should use a2a helpers for extracting text', async () => {
mockClientManager.getClient.mockReturnValue({});
mockClientManager.getClient.mockReturnValue(mockClient);
// Mock a complex message part that needs extraction
mockClientManager.sendMessageStream.mockImplementation(
async function* () {
@@ -476,6 +525,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{
query: 'hi',
},
@@ -488,7 +538,7 @@ describe('RemoteAgentInvocation', () => {
});
it('should handle mixed response types during streaming (TaskStatusUpdateEvent + Message)', async () => {
mockClientManager.getClient.mockReturnValue({});
mockClientManager.getClient.mockReturnValue(mockClient);
mockClientManager.sendMessageStream.mockImplementation(
async function* () {
yield {
@@ -518,6 +568,7 @@ describe('RemoteAgentInvocation', () => {
const updateOutput = vi.fn();
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{ query: 'hi' },
mockMessageBus,
);
@@ -532,17 +583,20 @@ describe('RemoteAgentInvocation', () => {
});
it('should handle artifact reassembly with append: true', async () => {
mockClientManager.getClient.mockReturnValue({});
mockClientManager.getClient.mockReturnValue(mockClient);
mockClientManager.sendMessageStream.mockImplementation(
async function* () {
yield {
kind: 'status-update',
taskId: 'task-1',
contextId: 'ctx-1',
final: false,
status: {
state: 'working',
message: {
kind: 'message',
role: 'agent',
messageId: 'm1',
parts: [{ kind: 'text', text: 'Generating...' }],
},
},
@@ -550,6 +604,7 @@ describe('RemoteAgentInvocation', () => {
yield {
kind: 'artifact-update',
taskId: 'task-1',
contextId: 'ctx-1',
append: false,
artifact: {
artifactId: 'art-1',
@@ -560,18 +615,21 @@ describe('RemoteAgentInvocation', () => {
yield {
kind: 'artifact-update',
taskId: 'task-1',
contextId: 'ctx-1',
append: true,
artifact: {
artifactId: 'art-1',
parts: [{ kind: 'text', text: ' Part 2' }],
},
};
return;
},
);
const updateOutput = vi.fn();
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{ query: 'hi' },
mockMessageBus,
);
@@ -591,6 +649,7 @@ describe('RemoteAgentInvocation', () => {
it('should return info confirmation details', async () => {
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{
query: 'hi',
},
@@ -629,6 +688,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{ query: 'hi' },
mockMessageBus,
);
@@ -646,6 +706,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{ query: 'hi' },
mockMessageBus,
);
@@ -658,7 +719,7 @@ describe('RemoteAgentInvocation', () => {
});
it('should include partial output when error occurs mid-stream', async () => {
mockClientManager.getClient.mockReturnValue({});
mockClientManager.getClient.mockReturnValue(mockClient);
mockClientManager.sendMessageStream.mockImplementation(
async function* () {
yield {
@@ -674,6 +735,7 @@ describe('RemoteAgentInvocation', () => {
const invocation = new RemoteAgentInvocation(
mockDefinition,
mockContext,
{ query: 'hi' },
mockMessageBus,
);

View File

@@ -16,10 +16,11 @@ import {
type RemoteAgentDefinition,
type AgentInputs,
} from './types.js';
import { type AgentLoopContext } from '../config/agent-loop-context.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import {
import type {
A2AClientManager,
type SendMessageResult,
SendMessageResult,
} from './a2a-client-manager.js';
import { extractIdsFromResponse, A2AResultReassembler } from './a2aUtils.js';
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
@@ -47,13 +48,13 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
// State for the ongoing conversation with the remote agent
private contextId: string | undefined;
private taskId: string | undefined;
// TODO: See if we can reuse the singleton from AppContainer or similar, but for now use getInstance directly
// as per the current pattern in the codebase.
private readonly clientManager = A2AClientManager.getInstance();
private readonly clientManager: A2AClientManager;
private authHandler: AuthenticationHandler | undefined;
constructor(
private readonly definition: RemoteAgentDefinition,
private readonly context: AgentLoopContext,
params: AgentInputs,
messageBus: MessageBus,
_toolName?: string,
@@ -72,6 +73,13 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
_toolName ?? definition.name,
_toolDisplayName ?? definition.displayName,
);
const clientManager = this.context.config.getA2AClientManager();
if (!clientManager) {
throw new Error(
`Failed to initialize RemoteAgentInvocation for '${definition.name}': A2AClientManager is not available.`,
);
}
this.clientManager = clientManager;
}
getDescription(): string {

View File

@@ -75,6 +75,7 @@ export class SubagentToolWrapper extends BaseDeclarativeTool<
if (definition.kind === 'remote') {
return new RemoteAgentInvocation(
definition,
this.context,
params,
effectiveMessageBus,
_toolName,

View File

@@ -1523,7 +1523,7 @@ describe('Server Config (config.ts)', () => {
const paramsWithProxy: ConfigParameters = {
...baseParams,
proxy: 'invalid-proxy',
proxy: 'http://invalid-proxy:8080',
};
new Config(paramsWithProxy);

View File

@@ -405,6 +405,7 @@ import {
SimpleExtensionLoader,
} from '../utils/extensionLoader.js';
import { McpClientManager } from '../tools/mcp-client-manager.js';
import { A2AClientManager } from '../agents/a2a-client-manager.js';
import { type McpContext } from '../tools/mcp-client.js';
import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js';
import { getErrorMessage } from '../utils/errors.js';
@@ -653,6 +654,7 @@ export interface ConfigParameters {
export class Config implements McpContext, AgentLoopContext {
private _toolRegistry!: ToolRegistry;
private mcpClientManager?: McpClientManager;
private readonly a2aClientManager?: A2AClientManager;
private allowedMcpServers: string[];
private blockedMcpServers: string[];
private allowedEnvironmentVariables: string[];
@@ -1188,6 +1190,7 @@ export class Config implements McpContext, AgentLoopContext {
params.toolSandboxing ?? false,
this.targetDir,
);
this.a2aClientManager = new A2AClientManager(this);
this.shellExecutionConfig.sandboxManager = this._sandboxManager;
this.modelRouterService = new ModelRouterService(this);
}
@@ -2000,6 +2003,10 @@ export class Config implements McpContext, AgentLoopContext {
return this.mcpClientManager;
}
getA2AClientManager(): A2AClientManager | undefined {
return this.a2aClientManager;
}
setUserInteractedWithMcp(): void {
this.mcpClientManager?.setUserInteractedWithMcp();
}

View File

@@ -15,6 +15,7 @@ import {
ApprovalMode,
PRIORITY_SUBAGENT_TOOL,
ALWAYS_ALLOW_PRIORITY_FRACTION,
PRIORITY_YOLO_ALLOW_ALL,
} from './types.js';
import type { FunctionCall } from '@google/genai';
import { SafetyCheckDecision } from '../safety/protocol.js';
@@ -2852,7 +2853,7 @@ describe('PolicyEngine', () => {
},
{
decision: PolicyDecision.ALLOW,
priority: 998,
priority: PRIORITY_YOLO_ALLOW_ALL,
modes: [ApprovalMode.YOLO],
},
];
@@ -2879,7 +2880,7 @@ describe('PolicyEngine', () => {
},
{
decision: PolicyDecision.ALLOW,
priority: 998,
priority: PRIORITY_YOLO_ALLOW_ALL,
modes: [ApprovalMode.YOLO],
},
];

View File

@@ -345,3 +345,9 @@ export const ALWAYS_ALLOW_PRIORITY_FRACTION = 950;
*/
export const ALWAYS_ALLOW_PRIORITY_OFFSET =
ALWAYS_ALLOW_PRIORITY_FRACTION / 1000;
/**
* Priority for the YOLO "allow all" rule.
* Matches the raw priority used in yolo.toml.
*/
export const PRIORITY_YOLO_ALLOW_ALL = 998;