mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-25 21:41:12 -07:00
feat(a2a): add agent acknowledgment command and enhance registry discovery (#22389)
This commit is contained in:
@@ -19,6 +19,8 @@ import {
|
||||
AuthType,
|
||||
isHeadlessMode,
|
||||
FatalAuthenticationError,
|
||||
PolicyDecision,
|
||||
PRIORITY_YOLO_ALLOW_ALL,
|
||||
} from '@google/gemini-cli-core';
|
||||
|
||||
// Mock dependencies
|
||||
@@ -325,6 +327,29 @@ describe('loadConfig', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass enableAgents to Config constructor', async () => {
|
||||
const settings: Settings = {
|
||||
experimental: {
|
||||
enableAgents: false,
|
||||
},
|
||||
};
|
||||
await loadConfig(settings, mockExtensionLoader, taskId);
|
||||
expect(Config).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
enableAgents: false,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should default enableAgents to true when not provided', async () => {
|
||||
await loadConfig(mockSettings, mockExtensionLoader, taskId);
|
||||
expect(Config).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
enableAgents: true,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
describe('interactivity', () => {
|
||||
it('should set interactive true when not headless', async () => {
|
||||
vi.mocked(isHeadlessMode).mockReturnValue(false);
|
||||
@@ -349,6 +374,41 @@ describe('loadConfig', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('YOLO mode', () => {
|
||||
it('should enable YOLO mode and add policy rule when GEMINI_YOLO_MODE is true', async () => {
|
||||
vi.stubEnv('GEMINI_YOLO_MODE', 'true');
|
||||
await loadConfig(mockSettings, mockExtensionLoader, taskId);
|
||||
expect(Config).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
approvalMode: 'yolo',
|
||||
policyEngineConfig: expect.objectContaining({
|
||||
rules: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: PRIORITY_YOLO_ALLOW_ALL,
|
||||
modes: ['yolo'],
|
||||
allowRedirection: true,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use default approval mode and empty rules when GEMINI_YOLO_MODE is not true', async () => {
|
||||
vi.stubEnv('GEMINI_YOLO_MODE', 'false');
|
||||
await loadConfig(mockSettings, mockExtensionLoader, taskId);
|
||||
expect(Config).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
approvalMode: 'default',
|
||||
policyEngineConfig: expect.objectContaining({
|
||||
rules: [],
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('authentication fallback', () => {
|
||||
beforeEach(() => {
|
||||
vi.stubEnv('USE_CCPA', 'true');
|
||||
|
||||
@@ -26,6 +26,8 @@ import {
|
||||
isHeadlessMode,
|
||||
FatalAuthenticationError,
|
||||
isCloudShell,
|
||||
PolicyDecision,
|
||||
PRIORITY_YOLO_ALLOW_ALL,
|
||||
type TelemetryTarget,
|
||||
type ConfigParameters,
|
||||
type ExtensionLoader,
|
||||
@@ -60,6 +62,11 @@ export async function loadConfig(
|
||||
}
|
||||
}
|
||||
|
||||
const approvalMode =
|
||||
process.env['GEMINI_YOLO_MODE'] === 'true'
|
||||
? ApprovalMode.YOLO
|
||||
: ApprovalMode.DEFAULT;
|
||||
|
||||
const configParams: ConfigParameters = {
|
||||
sessionId: taskId,
|
||||
clientName: 'a2a-server',
|
||||
@@ -74,10 +81,20 @@ export async function loadConfig(
|
||||
excludeTools: settings.excludeTools || settings.tools?.exclude || undefined,
|
||||
allowedTools: settings.allowedTools || settings.tools?.allowed || undefined,
|
||||
showMemoryUsage: settings.showMemoryUsage || false,
|
||||
approvalMode:
|
||||
process.env['GEMINI_YOLO_MODE'] === 'true'
|
||||
? ApprovalMode.YOLO
|
||||
: ApprovalMode.DEFAULT,
|
||||
approvalMode,
|
||||
policyEngineConfig: {
|
||||
rules:
|
||||
approvalMode === ApprovalMode.YOLO
|
||||
? [
|
||||
{
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: PRIORITY_YOLO_ALLOW_ALL,
|
||||
modes: [ApprovalMode.YOLO],
|
||||
allowRedirection: true,
|
||||
},
|
||||
]
|
||||
: [],
|
||||
},
|
||||
mcpServers: settings.mcpServers,
|
||||
cwd: workspaceDir,
|
||||
telemetry: {
|
||||
@@ -110,6 +127,7 @@ export async function loadConfig(
|
||||
interactive: !isHeadlessMode(),
|
||||
enableInteractiveShell: !isHeadlessMode(),
|
||||
ptyInfo: 'auto',
|
||||
enableAgents: settings.experimental?.enableAgents ?? true,
|
||||
};
|
||||
|
||||
const fileService = new FileDiscoveryService(workspaceDir, {
|
||||
|
||||
@@ -112,6 +112,18 @@ describe('loadSettings', () => {
|
||||
expect(result.fileFiltering?.respectGitIgnore).toBe(true);
|
||||
});
|
||||
|
||||
it('should load experimental settings correctly', () => {
|
||||
const settings = {
|
||||
experimental: {
|
||||
enableAgents: true,
|
||||
},
|
||||
};
|
||||
fs.writeFileSync(USER_SETTINGS_PATH, JSON.stringify(settings));
|
||||
|
||||
const result = loadSettings(mockWorkspaceDir);
|
||||
expect(result.experimental?.enableAgents).toBe(true);
|
||||
});
|
||||
|
||||
it('should overwrite top-level settings from workspace (shallow merge)', () => {
|
||||
const userSettings = {
|
||||
showMemoryUsage: false,
|
||||
|
||||
@@ -48,6 +48,9 @@ export interface Settings {
|
||||
enableRecursiveFileSearch?: boolean;
|
||||
customIgnoreFilePaths?: string[];
|
||||
};
|
||||
experimental?: {
|
||||
enableAgents?: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
export interface SettingsError {
|
||||
|
||||
@@ -66,11 +66,13 @@ describe('A2AClientManager', () => {
|
||||
};
|
||||
|
||||
const authFetchMock = vi.fn();
|
||||
const mockConfig = {
|
||||
getProxy: vi.fn(),
|
||||
} as unknown as Config;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
A2AClientManager.resetInstanceForTesting();
|
||||
manager = A2AClientManager.getInstance();
|
||||
manager = new A2AClientManager(mockConfig);
|
||||
|
||||
// Re-create the instances as plain objects that can be spied on
|
||||
const factoryInstance = {
|
||||
@@ -124,12 +126,6 @@ describe('A2AClientManager', () => {
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it('should enforce the singleton pattern', () => {
|
||||
const instance1 = A2AClientManager.getInstance();
|
||||
const instance2 = A2AClientManager.getInstance();
|
||||
expect(instance1).toBe(instance2);
|
||||
});
|
||||
|
||||
describe('getInstance / dispatcher initialization', () => {
|
||||
it('should use UndiciAgent when no proxy is configured', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
@@ -152,12 +148,11 @@ describe('A2AClientManager', () => {
|
||||
});
|
||||
|
||||
it('should use ProxyAgent when a proxy is configured via Config', async () => {
|
||||
A2AClientManager.resetInstanceForTesting();
|
||||
const mockConfig = {
|
||||
const mockConfigWithProxy = {
|
||||
getProxy: () => 'http://my-proxy:8080',
|
||||
} as Config;
|
||||
|
||||
manager = A2AClientManager.getInstance(mockConfig);
|
||||
manager = new A2AClientManager(mockConfigWithProxy);
|
||||
await manager.loadAgent('TestProxyAgent', 'http://test.proxy.agent/card');
|
||||
|
||||
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
|
||||
|
||||
@@ -49,8 +49,6 @@ const A2A_TIMEOUT = 1800000; // 30 minutes
|
||||
* Manages protocol negotiation, authentication, and transport selection.
|
||||
*/
|
||||
export class A2AClientManager {
|
||||
private static instance: A2AClientManager;
|
||||
|
||||
// Each agent should manage their own context/taskIds/card/etc
|
||||
private clients = new Map<string, Client>();
|
||||
private agentCards = new Map<string, AgentCard>();
|
||||
@@ -58,8 +56,8 @@ export class A2AClientManager {
|
||||
private a2aDispatcher: UndiciAgent | ProxyAgent;
|
||||
private a2aFetch: typeof fetch;
|
||||
|
||||
private constructor(config?: Config) {
|
||||
const proxyUrl = config?.getProxy();
|
||||
constructor(private readonly config: Config) {
|
||||
const proxyUrl = this.config.getProxy();
|
||||
const agentOptions = {
|
||||
headersTimeout: A2A_TIMEOUT,
|
||||
bodyTimeout: A2A_TIMEOUT,
|
||||
@@ -78,25 +76,6 @@ export class A2AClientManager {
|
||||
fetch(input, { ...init, dispatcher: this.a2aDispatcher } as RequestInit);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the singleton instance of the A2AClientManager.
|
||||
*/
|
||||
static getInstance(config?: Config): A2AClientManager {
|
||||
if (!A2AClientManager.instance) {
|
||||
A2AClientManager.instance = new A2AClientManager(config);
|
||||
}
|
||||
return A2AClientManager.instance;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets the singleton instance. Only for testing purposes.
|
||||
* @internal
|
||||
*/
|
||||
static resetInstanceForTesting() {
|
||||
// @ts-expect-error - Resetting singleton for testing
|
||||
A2AClientManager.instance = undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads an agent by fetching its AgentCard and caches the client.
|
||||
* @param name The name to assign to the agent.
|
||||
|
||||
@@ -15,7 +15,7 @@ import type {
|
||||
} from '../config/config.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { coreEvents, CoreEvent } from '../utils/events.js';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import type { A2AClientManager } from './a2a-client-manager.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
@@ -40,9 +40,7 @@ vi.mock('./agentLoader.js', () => ({
|
||||
}));
|
||||
|
||||
vi.mock('./a2a-client-manager.js', () => ({
|
||||
A2AClientManager: {
|
||||
getInstance: vi.fn(),
|
||||
},
|
||||
A2AClientManager: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('./auth-provider/factory.js', () => ({
|
||||
@@ -450,7 +448,7 @@ describe('AgentRegistry', () => {
|
||||
);
|
||||
|
||||
// Mock A2AClientManager to avoid network calls
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
||||
clearCache: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
@@ -548,7 +546,7 @@ describe('AgentRegistry', () => {
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
@@ -583,7 +581,7 @@ describe('AgentRegistry', () => {
|
||||
const loadAgentSpy = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ name: 'RemoteAgentWithAuth' });
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: loadAgentSpy,
|
||||
clearCache: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
@@ -622,7 +620,7 @@ describe('AgentRegistry', () => {
|
||||
|
||||
vi.mocked(A2AAuthProviderFactory.create).mockResolvedValue(undefined);
|
||||
const loadAgentSpy = vi.fn();
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: loadAgentSpy,
|
||||
clearCache: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
@@ -645,6 +643,9 @@ describe('AgentRegistry', () => {
|
||||
it('should log remote agent registration in debug mode', async () => {
|
||||
const debugConfig = makeMockedConfig({ debugMode: true });
|
||||
const debugRegistry = new TestableAgentRegistry(debugConfig);
|
||||
vi.spyOn(debugConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
||||
} as unknown as A2AClientManager);
|
||||
const debugLogSpy = vi
|
||||
.spyOn(debugLogger, 'log')
|
||||
.mockImplementation(() => {});
|
||||
@@ -657,10 +658,6 @@ describe('AgentRegistry', () => {
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
await debugRegistry.testRegisterAgent(remoteAgent);
|
||||
|
||||
expect(debugLogSpy).toHaveBeenCalledWith(
|
||||
@@ -688,7 +685,7 @@ describe('AgentRegistry', () => {
|
||||
new Error('ECONNREFUSED'),
|
||||
);
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockRejectedValue(a2aError),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
@@ -714,7 +711,7 @@ describe('AgentRegistry', () => {
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockRejectedValue(new Error('unexpected crash')),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
@@ -749,7 +746,7 @@ describe('AgentRegistry', () => {
|
||||
// No auth configured
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue({
|
||||
name: 'SecuredAgent',
|
||||
securitySchemes: {
|
||||
@@ -783,7 +780,7 @@ describe('AgentRegistry', () => {
|
||||
};
|
||||
|
||||
const error = new Error('401 Unauthorized');
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockRejectedValue(error),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
@@ -815,7 +812,7 @@ describe('AgentRegistry', () => {
|
||||
],
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
||||
clearCache: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
@@ -843,7 +840,7 @@ describe('AgentRegistry', () => {
|
||||
skills: [{ name: 'Skill1', description: 'Desc1' }],
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
||||
clearCache: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
@@ -871,7 +868,7 @@ describe('AgentRegistry', () => {
|
||||
skills: [],
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
||||
clearCache: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
@@ -902,7 +899,7 @@ describe('AgentRegistry', () => {
|
||||
skills: [{ name: 'Skill1', description: 'Desc1' }],
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
||||
clearCache: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
@@ -930,7 +927,7 @@ describe('AgentRegistry', () => {
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue({
|
||||
name: 'EmptyDescAgent',
|
||||
description: 'Loaded from card',
|
||||
@@ -955,7 +952,7 @@ describe('AgentRegistry', () => {
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue({
|
||||
name: 'SkillFallbackAgent',
|
||||
description: 'Card description',
|
||||
@@ -1092,7 +1089,7 @@ describe('AgentRegistry', () => {
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue({ name: 'RemotePolicyAgent' }),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
@@ -1141,7 +1138,7 @@ describe('AgentRegistry', () => {
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
};
|
||||
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||
loadAgent: vi.fn().mockResolvedValue({ name: 'OverwrittenAgent' }),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
@@ -1189,8 +1186,10 @@ describe('AgentRegistry', () => {
|
||||
});
|
||||
|
||||
const clearCacheSpy = vi.fn();
|
||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
||||
vi.spyOn(config, 'getA2AClientManager').mockReturnValue({
|
||||
clearCache: clearCacheSpy,
|
||||
loadAgent: vi.fn(),
|
||||
getClient: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
const emitSpy = vi.spyOn(coreEvents, 'emitAgentsRefreshed');
|
||||
|
||||
@@ -13,7 +13,6 @@ import { CodebaseInvestigatorAgent } from './codebase-investigator.js';
|
||||
import { CliHelpAgent } from './cli-help-agent.js';
|
||||
import { GeneralistAgent } from './generalist-agent.js';
|
||||
import { BrowserAgentDefinition } from './browser/browserAgentDefinition.js';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||
import { type z } from 'zod';
|
||||
@@ -69,7 +68,7 @@ export class AgentRegistry {
|
||||
* Clears the current registry and re-scans for agents.
|
||||
*/
|
||||
async reload(): Promise<void> {
|
||||
A2AClientManager.getInstance(this.config).clearCache();
|
||||
this.config.getA2AClientManager()?.clearCache();
|
||||
await this.config.reloadAgents();
|
||||
this.agents.clear();
|
||||
this.allDefinitions.clear();
|
||||
@@ -414,7 +413,13 @@ export class AgentRegistry {
|
||||
|
||||
// Load the remote A2A agent card and register.
|
||||
try {
|
||||
const clientManager = A2AClientManager.getInstance(this.config);
|
||||
const clientManager = this.config.getA2AClientManager();
|
||||
if (!clientManager) {
|
||||
debugLogger.warn(
|
||||
`[AgentRegistry] Skipping remote agent '${definition.name}': A2AClientManager is not available.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
let authHandler: AuthenticationHandler | undefined;
|
||||
if (definition.auth) {
|
||||
const provider = await A2AAuthProviderFactory.create({
|
||||
|
||||
@@ -13,21 +13,27 @@ import {
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import type { Client } from '@a2a-js/sdk/client';
|
||||
import { RemoteAgentInvocation } from './remote-invocation.js';
|
||||
import {
|
||||
A2AClientManager,
|
||||
type SendMessageResult,
|
||||
type A2AClientManager,
|
||||
} from './a2a-client-manager.js';
|
||||
|
||||
import type { RemoteAgentDefinition } from './types.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
import type { A2AAuthProvider } from './auth-provider/types.js';
|
||||
import type { AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
// Mock A2AClientManager
|
||||
vi.mock('./a2a-client-manager.js', () => ({
|
||||
A2AClientManager: {
|
||||
getInstance: vi.fn(),
|
||||
},
|
||||
A2AClientManager: vi.fn().mockImplementation(() => ({
|
||||
getClient: vi.fn(),
|
||||
loadAgent: vi.fn(),
|
||||
sendMessageStream: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock A2AAuthProviderFactory
|
||||
@@ -49,16 +55,40 @@ describe('RemoteAgentInvocation', () => {
|
||||
},
|
||||
};
|
||||
|
||||
const mockClientManager = {
|
||||
getClient: vi.fn(),
|
||||
loadAgent: vi.fn(),
|
||||
sendMessageStream: vi.fn(),
|
||||
let mockClientManager: {
|
||||
getClient: Mock<A2AClientManager['getClient']>;
|
||||
loadAgent: Mock<A2AClientManager['loadAgent']>;
|
||||
sendMessageStream: Mock<A2AClientManager['sendMessageStream']>;
|
||||
};
|
||||
let mockContext: AgentLoopContext;
|
||||
const mockMessageBus = createMockMessageBus();
|
||||
|
||||
const mockClient = {
|
||||
sendMessageStream: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
cancelTask: vi.fn(),
|
||||
} as unknown as Client;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
(A2AClientManager.getInstance as Mock).mockReturnValue(mockClientManager);
|
||||
|
||||
mockClientManager = {
|
||||
getClient: vi.fn(),
|
||||
loadAgent: vi.fn(),
|
||||
sendMessageStream: vi.fn(),
|
||||
};
|
||||
|
||||
const mockConfig = {
|
||||
getA2AClientManager: vi.fn().mockReturnValue(mockClientManager),
|
||||
injectionService: {
|
||||
getLatestInjectionIndex: vi.fn().mockReturnValue(0),
|
||||
},
|
||||
} as unknown as Config;
|
||||
|
||||
mockContext = {
|
||||
config: mockConfig,
|
||||
} as unknown as AgentLoopContext;
|
||||
|
||||
(
|
||||
RemoteAgentInvocation as unknown as {
|
||||
sessionState?: Map<string, { contextId?: string; taskId?: string }>;
|
||||
@@ -75,6 +105,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
expect(() => {
|
||||
new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{ query: 'valid' },
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -83,12 +114,17 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
it('accepts missing query (defaults to "Get Started!")', () => {
|
||||
expect(() => {
|
||||
new RemoteAgentInvocation(mockDefinition, {}, mockMessageBus);
|
||||
new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{},
|
||||
mockMessageBus,
|
||||
);
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
it('uses "Get Started!" default when query is missing during execution', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
@@ -102,6 +138,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{},
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -118,6 +155,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
expect(() => {
|
||||
new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{ query: 123 },
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -141,6 +179,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{
|
||||
query: 'hi',
|
||||
},
|
||||
@@ -187,6 +226,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
authDefinition,
|
||||
mockContext,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -220,6 +260,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
authDefinition,
|
||||
mockContext,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -231,7 +272,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
});
|
||||
|
||||
it('should not load the agent if already present', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
@@ -245,6 +286,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{
|
||||
query: 'hi',
|
||||
},
|
||||
@@ -256,7 +298,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
});
|
||||
|
||||
it('should persist contextId and taskId across invocations', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||
|
||||
// First call return values
|
||||
mockClientManager.sendMessageStream.mockImplementationOnce(
|
||||
@@ -274,6 +316,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation1 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{
|
||||
query: 'first',
|
||||
},
|
||||
@@ -305,6 +348,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation2 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{
|
||||
query: 'second',
|
||||
},
|
||||
@@ -335,6 +379,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation3 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{
|
||||
query: 'third',
|
||||
},
|
||||
@@ -356,6 +401,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation4 = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{
|
||||
query: 'fourth',
|
||||
},
|
||||
@@ -371,7 +417,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
});
|
||||
|
||||
it('should handle streaming updates and reassemble output', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
@@ -392,6 +438,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
const updateOutput = vi.fn();
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -402,7 +449,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
});
|
||||
|
||||
it('should abort when signal is aborted during streaming', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||
const controller = new AbortController();
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
@@ -425,6 +472,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -435,7 +483,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
});
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
if (Math.random() < 0) yield {} as unknown as SendMessageResult;
|
||||
@@ -445,6 +493,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{
|
||||
query: 'hi',
|
||||
},
|
||||
@@ -458,7 +507,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
});
|
||||
|
||||
it('should use a2a helpers for extracting text', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||
// Mock a complex message part that needs extraction
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
@@ -476,6 +525,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{
|
||||
query: 'hi',
|
||||
},
|
||||
@@ -488,7 +538,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
});
|
||||
|
||||
it('should handle mixed response types during streaming (TaskStatusUpdateEvent + Message)', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
@@ -518,6 +568,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
const updateOutput = vi.fn();
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -532,17 +583,20 @@ describe('RemoteAgentInvocation', () => {
|
||||
});
|
||||
|
||||
it('should handle artifact reassembly with append: true', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
kind: 'status-update',
|
||||
taskId: 'task-1',
|
||||
contextId: 'ctx-1',
|
||||
final: false,
|
||||
status: {
|
||||
state: 'working',
|
||||
message: {
|
||||
kind: 'message',
|
||||
role: 'agent',
|
||||
messageId: 'm1',
|
||||
parts: [{ kind: 'text', text: 'Generating...' }],
|
||||
},
|
||||
},
|
||||
@@ -550,6 +604,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
yield {
|
||||
kind: 'artifact-update',
|
||||
taskId: 'task-1',
|
||||
contextId: 'ctx-1',
|
||||
append: false,
|
||||
artifact: {
|
||||
artifactId: 'art-1',
|
||||
@@ -560,18 +615,21 @@ describe('RemoteAgentInvocation', () => {
|
||||
yield {
|
||||
kind: 'artifact-update',
|
||||
taskId: 'task-1',
|
||||
contextId: 'ctx-1',
|
||||
append: true,
|
||||
artifact: {
|
||||
artifactId: 'art-1',
|
||||
parts: [{ kind: 'text', text: ' Part 2' }],
|
||||
},
|
||||
};
|
||||
return;
|
||||
},
|
||||
);
|
||||
|
||||
const updateOutput = vi.fn();
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -591,6 +649,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
it('should return info confirmation details', async () => {
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{
|
||||
query: 'hi',
|
||||
},
|
||||
@@ -629,6 +688,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -646,6 +706,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
@@ -658,7 +719,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
});
|
||||
|
||||
it('should include partial output when error occurs mid-stream', async () => {
|
||||
mockClientManager.getClient.mockReturnValue({});
|
||||
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||
mockClientManager.sendMessageStream.mockImplementation(
|
||||
async function* () {
|
||||
yield {
|
||||
@@ -674,6 +735,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
|
||||
const invocation = new RemoteAgentInvocation(
|
||||
mockDefinition,
|
||||
mockContext,
|
||||
{ query: 'hi' },
|
||||
mockMessageBus,
|
||||
);
|
||||
|
||||
@@ -16,10 +16,11 @@ import {
|
||||
type RemoteAgentDefinition,
|
||||
type AgentInputs,
|
||||
} from './types.js';
|
||||
import { type AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
import type {
|
||||
A2AClientManager,
|
||||
type SendMessageResult,
|
||||
SendMessageResult,
|
||||
} from './a2a-client-manager.js';
|
||||
import { extractIdsFromResponse, A2AResultReassembler } from './a2aUtils.js';
|
||||
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||
@@ -47,13 +48,13 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
// State for the ongoing conversation with the remote agent
|
||||
private contextId: string | undefined;
|
||||
private taskId: string | undefined;
|
||||
// TODO: See if we can reuse the singleton from AppContainer or similar, but for now use getInstance directly
|
||||
// as per the current pattern in the codebase.
|
||||
private readonly clientManager = A2AClientManager.getInstance();
|
||||
|
||||
private readonly clientManager: A2AClientManager;
|
||||
private authHandler: AuthenticationHandler | undefined;
|
||||
|
||||
constructor(
|
||||
private readonly definition: RemoteAgentDefinition,
|
||||
private readonly context: AgentLoopContext,
|
||||
params: AgentInputs,
|
||||
messageBus: MessageBus,
|
||||
_toolName?: string,
|
||||
@@ -72,6 +73,13 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
_toolName ?? definition.name,
|
||||
_toolDisplayName ?? definition.displayName,
|
||||
);
|
||||
const clientManager = this.context.config.getA2AClientManager();
|
||||
if (!clientManager) {
|
||||
throw new Error(
|
||||
`Failed to initialize RemoteAgentInvocation for '${definition.name}': A2AClientManager is not available.`,
|
||||
);
|
||||
}
|
||||
this.clientManager = clientManager;
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
|
||||
@@ -75,6 +75,7 @@ export class SubagentToolWrapper extends BaseDeclarativeTool<
|
||||
if (definition.kind === 'remote') {
|
||||
return new RemoteAgentInvocation(
|
||||
definition,
|
||||
this.context,
|
||||
params,
|
||||
effectiveMessageBus,
|
||||
_toolName,
|
||||
|
||||
@@ -1523,7 +1523,7 @@ describe('Server Config (config.ts)', () => {
|
||||
|
||||
const paramsWithProxy: ConfigParameters = {
|
||||
...baseParams,
|
||||
proxy: 'invalid-proxy',
|
||||
proxy: 'http://invalid-proxy:8080',
|
||||
};
|
||||
new Config(paramsWithProxy);
|
||||
|
||||
|
||||
@@ -405,6 +405,7 @@ import {
|
||||
SimpleExtensionLoader,
|
||||
} from '../utils/extensionLoader.js';
|
||||
import { McpClientManager } from '../tools/mcp-client-manager.js';
|
||||
import { A2AClientManager } from '../agents/a2a-client-manager.js';
|
||||
import { type McpContext } from '../tools/mcp-client.js';
|
||||
import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
@@ -653,6 +654,7 @@ export interface ConfigParameters {
|
||||
export class Config implements McpContext, AgentLoopContext {
|
||||
private _toolRegistry!: ToolRegistry;
|
||||
private mcpClientManager?: McpClientManager;
|
||||
private readonly a2aClientManager?: A2AClientManager;
|
||||
private allowedMcpServers: string[];
|
||||
private blockedMcpServers: string[];
|
||||
private allowedEnvironmentVariables: string[];
|
||||
@@ -1188,6 +1190,7 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
params.toolSandboxing ?? false,
|
||||
this.targetDir,
|
||||
);
|
||||
this.a2aClientManager = new A2AClientManager(this);
|
||||
this.shellExecutionConfig.sandboxManager = this._sandboxManager;
|
||||
this.modelRouterService = new ModelRouterService(this);
|
||||
}
|
||||
@@ -2000,6 +2003,10 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
return this.mcpClientManager;
|
||||
}
|
||||
|
||||
getA2AClientManager(): A2AClientManager | undefined {
|
||||
return this.a2aClientManager;
|
||||
}
|
||||
|
||||
setUserInteractedWithMcp(): void {
|
||||
this.mcpClientManager?.setUserInteractedWithMcp();
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
ApprovalMode,
|
||||
PRIORITY_SUBAGENT_TOOL,
|
||||
ALWAYS_ALLOW_PRIORITY_FRACTION,
|
||||
PRIORITY_YOLO_ALLOW_ALL,
|
||||
} from './types.js';
|
||||
import type { FunctionCall } from '@google/genai';
|
||||
import { SafetyCheckDecision } from '../safety/protocol.js';
|
||||
@@ -2852,7 +2853,7 @@ describe('PolicyEngine', () => {
|
||||
},
|
||||
{
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 998,
|
||||
priority: PRIORITY_YOLO_ALLOW_ALL,
|
||||
modes: [ApprovalMode.YOLO],
|
||||
},
|
||||
];
|
||||
@@ -2879,7 +2880,7 @@ describe('PolicyEngine', () => {
|
||||
},
|
||||
{
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: 998,
|
||||
priority: PRIORITY_YOLO_ALLOW_ALL,
|
||||
modes: [ApprovalMode.YOLO],
|
||||
},
|
||||
];
|
||||
|
||||
@@ -345,3 +345,9 @@ export const ALWAYS_ALLOW_PRIORITY_FRACTION = 950;
|
||||
*/
|
||||
export const ALWAYS_ALLOW_PRIORITY_OFFSET =
|
||||
ALWAYS_ALLOW_PRIORITY_FRACTION / 1000;
|
||||
|
||||
/**
|
||||
* Priority for the YOLO "allow all" rule.
|
||||
* Matches the raw priority used in yolo.toml.
|
||||
*/
|
||||
export const PRIORITY_YOLO_ALLOW_ALL = 998;
|
||||
|
||||
Reference in New Issue
Block a user