mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 12:54:07 -07:00
feat(a2a): add agent acknowledgment command and enhance registry discovery (#22389)
This commit is contained in:
@@ -19,6 +19,8 @@ import {
|
|||||||
AuthType,
|
AuthType,
|
||||||
isHeadlessMode,
|
isHeadlessMode,
|
||||||
FatalAuthenticationError,
|
FatalAuthenticationError,
|
||||||
|
PolicyDecision,
|
||||||
|
PRIORITY_YOLO_ALLOW_ALL,
|
||||||
} from '@google/gemini-cli-core';
|
} from '@google/gemini-cli-core';
|
||||||
|
|
||||||
// Mock dependencies
|
// Mock dependencies
|
||||||
@@ -325,6 +327,29 @@ describe('loadConfig', () => {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should pass enableAgents to Config constructor', async () => {
|
||||||
|
const settings: Settings = {
|
||||||
|
experimental: {
|
||||||
|
enableAgents: false,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
await loadConfig(settings, mockExtensionLoader, taskId);
|
||||||
|
expect(Config).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
enableAgents: false,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should default enableAgents to true when not provided', async () => {
|
||||||
|
await loadConfig(mockSettings, mockExtensionLoader, taskId);
|
||||||
|
expect(Config).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
enableAgents: true,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
describe('interactivity', () => {
|
describe('interactivity', () => {
|
||||||
it('should set interactive true when not headless', async () => {
|
it('should set interactive true when not headless', async () => {
|
||||||
vi.mocked(isHeadlessMode).mockReturnValue(false);
|
vi.mocked(isHeadlessMode).mockReturnValue(false);
|
||||||
@@ -349,6 +374,41 @@ describe('loadConfig', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('YOLO mode', () => {
|
||||||
|
it('should enable YOLO mode and add policy rule when GEMINI_YOLO_MODE is true', async () => {
|
||||||
|
vi.stubEnv('GEMINI_YOLO_MODE', 'true');
|
||||||
|
await loadConfig(mockSettings, mockExtensionLoader, taskId);
|
||||||
|
expect(Config).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
approvalMode: 'yolo',
|
||||||
|
policyEngineConfig: expect.objectContaining({
|
||||||
|
rules: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
decision: PolicyDecision.ALLOW,
|
||||||
|
priority: PRIORITY_YOLO_ALLOW_ALL,
|
||||||
|
modes: ['yolo'],
|
||||||
|
allowRedirection: true,
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use default approval mode and empty rules when GEMINI_YOLO_MODE is not true', async () => {
|
||||||
|
vi.stubEnv('GEMINI_YOLO_MODE', 'false');
|
||||||
|
await loadConfig(mockSettings, mockExtensionLoader, taskId);
|
||||||
|
expect(Config).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
approvalMode: 'default',
|
||||||
|
policyEngineConfig: expect.objectContaining({
|
||||||
|
rules: [],
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('authentication fallback', () => {
|
describe('authentication fallback', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.stubEnv('USE_CCPA', 'true');
|
vi.stubEnv('USE_CCPA', 'true');
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ import {
|
|||||||
isHeadlessMode,
|
isHeadlessMode,
|
||||||
FatalAuthenticationError,
|
FatalAuthenticationError,
|
||||||
isCloudShell,
|
isCloudShell,
|
||||||
|
PolicyDecision,
|
||||||
|
PRIORITY_YOLO_ALLOW_ALL,
|
||||||
type TelemetryTarget,
|
type TelemetryTarget,
|
||||||
type ConfigParameters,
|
type ConfigParameters,
|
||||||
type ExtensionLoader,
|
type ExtensionLoader,
|
||||||
@@ -60,6 +62,11 @@ export async function loadConfig(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const approvalMode =
|
||||||
|
process.env['GEMINI_YOLO_MODE'] === 'true'
|
||||||
|
? ApprovalMode.YOLO
|
||||||
|
: ApprovalMode.DEFAULT;
|
||||||
|
|
||||||
const configParams: ConfigParameters = {
|
const configParams: ConfigParameters = {
|
||||||
sessionId: taskId,
|
sessionId: taskId,
|
||||||
clientName: 'a2a-server',
|
clientName: 'a2a-server',
|
||||||
@@ -74,10 +81,20 @@ export async function loadConfig(
|
|||||||
excludeTools: settings.excludeTools || settings.tools?.exclude || undefined,
|
excludeTools: settings.excludeTools || settings.tools?.exclude || undefined,
|
||||||
allowedTools: settings.allowedTools || settings.tools?.allowed || undefined,
|
allowedTools: settings.allowedTools || settings.tools?.allowed || undefined,
|
||||||
showMemoryUsage: settings.showMemoryUsage || false,
|
showMemoryUsage: settings.showMemoryUsage || false,
|
||||||
approvalMode:
|
approvalMode,
|
||||||
process.env['GEMINI_YOLO_MODE'] === 'true'
|
policyEngineConfig: {
|
||||||
? ApprovalMode.YOLO
|
rules:
|
||||||
: ApprovalMode.DEFAULT,
|
approvalMode === ApprovalMode.YOLO
|
||||||
|
? [
|
||||||
|
{
|
||||||
|
decision: PolicyDecision.ALLOW,
|
||||||
|
priority: PRIORITY_YOLO_ALLOW_ALL,
|
||||||
|
modes: [ApprovalMode.YOLO],
|
||||||
|
allowRedirection: true,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
: [],
|
||||||
|
},
|
||||||
mcpServers: settings.mcpServers,
|
mcpServers: settings.mcpServers,
|
||||||
cwd: workspaceDir,
|
cwd: workspaceDir,
|
||||||
telemetry: {
|
telemetry: {
|
||||||
@@ -110,6 +127,7 @@ export async function loadConfig(
|
|||||||
interactive: !isHeadlessMode(),
|
interactive: !isHeadlessMode(),
|
||||||
enableInteractiveShell: !isHeadlessMode(),
|
enableInteractiveShell: !isHeadlessMode(),
|
||||||
ptyInfo: 'auto',
|
ptyInfo: 'auto',
|
||||||
|
enableAgents: settings.experimental?.enableAgents ?? true,
|
||||||
};
|
};
|
||||||
|
|
||||||
const fileService = new FileDiscoveryService(workspaceDir, {
|
const fileService = new FileDiscoveryService(workspaceDir, {
|
||||||
|
|||||||
@@ -112,6 +112,18 @@ describe('loadSettings', () => {
|
|||||||
expect(result.fileFiltering?.respectGitIgnore).toBe(true);
|
expect(result.fileFiltering?.respectGitIgnore).toBe(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should load experimental settings correctly', () => {
|
||||||
|
const settings = {
|
||||||
|
experimental: {
|
||||||
|
enableAgents: true,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
fs.writeFileSync(USER_SETTINGS_PATH, JSON.stringify(settings));
|
||||||
|
|
||||||
|
const result = loadSettings(mockWorkspaceDir);
|
||||||
|
expect(result.experimental?.enableAgents).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
it('should overwrite top-level settings from workspace (shallow merge)', () => {
|
it('should overwrite top-level settings from workspace (shallow merge)', () => {
|
||||||
const userSettings = {
|
const userSettings = {
|
||||||
showMemoryUsage: false,
|
showMemoryUsage: false,
|
||||||
|
|||||||
@@ -48,6 +48,9 @@ export interface Settings {
|
|||||||
enableRecursiveFileSearch?: boolean;
|
enableRecursiveFileSearch?: boolean;
|
||||||
customIgnoreFilePaths?: string[];
|
customIgnoreFilePaths?: string[];
|
||||||
};
|
};
|
||||||
|
experimental?: {
|
||||||
|
enableAgents?: boolean;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SettingsError {
|
export interface SettingsError {
|
||||||
|
|||||||
@@ -66,11 +66,13 @@ describe('A2AClientManager', () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const authFetchMock = vi.fn();
|
const authFetchMock = vi.fn();
|
||||||
|
const mockConfig = {
|
||||||
|
getProxy: vi.fn(),
|
||||||
|
} as unknown as Config;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
A2AClientManager.resetInstanceForTesting();
|
manager = new A2AClientManager(mockConfig);
|
||||||
manager = A2AClientManager.getInstance();
|
|
||||||
|
|
||||||
// Re-create the instances as plain objects that can be spied on
|
// Re-create the instances as plain objects that can be spied on
|
||||||
const factoryInstance = {
|
const factoryInstance = {
|
||||||
@@ -124,12 +126,6 @@ describe('A2AClientManager', () => {
|
|||||||
vi.unstubAllGlobals();
|
vi.unstubAllGlobals();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should enforce the singleton pattern', () => {
|
|
||||||
const instance1 = A2AClientManager.getInstance();
|
|
||||||
const instance2 = A2AClientManager.getInstance();
|
|
||||||
expect(instance1).toBe(instance2);
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('getInstance / dispatcher initialization', () => {
|
describe('getInstance / dispatcher initialization', () => {
|
||||||
it('should use UndiciAgent when no proxy is configured', async () => {
|
it('should use UndiciAgent when no proxy is configured', async () => {
|
||||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||||
@@ -152,12 +148,11 @@ describe('A2AClientManager', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should use ProxyAgent when a proxy is configured via Config', async () => {
|
it('should use ProxyAgent when a proxy is configured via Config', async () => {
|
||||||
A2AClientManager.resetInstanceForTesting();
|
const mockConfigWithProxy = {
|
||||||
const mockConfig = {
|
|
||||||
getProxy: () => 'http://my-proxy:8080',
|
getProxy: () => 'http://my-proxy:8080',
|
||||||
} as Config;
|
} as Config;
|
||||||
|
|
||||||
manager = A2AClientManager.getInstance(mockConfig);
|
manager = new A2AClientManager(mockConfigWithProxy);
|
||||||
await manager.loadAgent('TestProxyAgent', 'http://test.proxy.agent/card');
|
await manager.loadAgent('TestProxyAgent', 'http://test.proxy.agent/card');
|
||||||
|
|
||||||
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
|
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
|
||||||
|
|||||||
@@ -49,8 +49,6 @@ const A2A_TIMEOUT = 1800000; // 30 minutes
|
|||||||
* Manages protocol negotiation, authentication, and transport selection.
|
* Manages protocol negotiation, authentication, and transport selection.
|
||||||
*/
|
*/
|
||||||
export class A2AClientManager {
|
export class A2AClientManager {
|
||||||
private static instance: A2AClientManager;
|
|
||||||
|
|
||||||
// Each agent should manage their own context/taskIds/card/etc
|
// Each agent should manage their own context/taskIds/card/etc
|
||||||
private clients = new Map<string, Client>();
|
private clients = new Map<string, Client>();
|
||||||
private agentCards = new Map<string, AgentCard>();
|
private agentCards = new Map<string, AgentCard>();
|
||||||
@@ -58,8 +56,8 @@ export class A2AClientManager {
|
|||||||
private a2aDispatcher: UndiciAgent | ProxyAgent;
|
private a2aDispatcher: UndiciAgent | ProxyAgent;
|
||||||
private a2aFetch: typeof fetch;
|
private a2aFetch: typeof fetch;
|
||||||
|
|
||||||
private constructor(config?: Config) {
|
constructor(private readonly config: Config) {
|
||||||
const proxyUrl = config?.getProxy();
|
const proxyUrl = this.config.getProxy();
|
||||||
const agentOptions = {
|
const agentOptions = {
|
||||||
headersTimeout: A2A_TIMEOUT,
|
headersTimeout: A2A_TIMEOUT,
|
||||||
bodyTimeout: A2A_TIMEOUT,
|
bodyTimeout: A2A_TIMEOUT,
|
||||||
@@ -78,25 +76,6 @@ export class A2AClientManager {
|
|||||||
fetch(input, { ...init, dispatcher: this.a2aDispatcher } as RequestInit);
|
fetch(input, { ...init, dispatcher: this.a2aDispatcher } as RequestInit);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Gets the singleton instance of the A2AClientManager.
|
|
||||||
*/
|
|
||||||
static getInstance(config?: Config): A2AClientManager {
|
|
||||||
if (!A2AClientManager.instance) {
|
|
||||||
A2AClientManager.instance = new A2AClientManager(config);
|
|
||||||
}
|
|
||||||
return A2AClientManager.instance;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Resets the singleton instance. Only for testing purposes.
|
|
||||||
* @internal
|
|
||||||
*/
|
|
||||||
static resetInstanceForTesting() {
|
|
||||||
// @ts-expect-error - Resetting singleton for testing
|
|
||||||
A2AClientManager.instance = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads an agent by fetching its AgentCard and caches the client.
|
* Loads an agent by fetching its AgentCard and caches the client.
|
||||||
* @param name The name to assign to the agent.
|
* @param name The name to assign to the agent.
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import type {
|
|||||||
} from '../config/config.js';
|
} from '../config/config.js';
|
||||||
import { debugLogger } from '../utils/debugLogger.js';
|
import { debugLogger } from '../utils/debugLogger.js';
|
||||||
import { coreEvents, CoreEvent } from '../utils/events.js';
|
import { coreEvents, CoreEvent } from '../utils/events.js';
|
||||||
import { A2AClientManager } from './a2a-client-manager.js';
|
import type { A2AClientManager } from './a2a-client-manager.js';
|
||||||
import {
|
import {
|
||||||
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||||
DEFAULT_GEMINI_MODEL,
|
DEFAULT_GEMINI_MODEL,
|
||||||
@@ -40,9 +40,7 @@ vi.mock('./agentLoader.js', () => ({
|
|||||||
}));
|
}));
|
||||||
|
|
||||||
vi.mock('./a2a-client-manager.js', () => ({
|
vi.mock('./a2a-client-manager.js', () => ({
|
||||||
A2AClientManager: {
|
A2AClientManager: vi.fn(),
|
||||||
getInstance: vi.fn(),
|
|
||||||
},
|
|
||||||
}));
|
}));
|
||||||
|
|
||||||
vi.mock('./auth-provider/factory.js', () => ({
|
vi.mock('./auth-provider/factory.js', () => ({
|
||||||
@@ -450,7 +448,7 @@ describe('AgentRegistry', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Mock A2AClientManager to avoid network calls
|
// Mock A2AClientManager to avoid network calls
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
||||||
clearCache: vi.fn(),
|
clearCache: vi.fn(),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
@@ -548,7 +546,7 @@ describe('AgentRegistry', () => {
|
|||||||
inputConfig: { inputSchema: { type: 'object' } },
|
inputConfig: { inputSchema: { type: 'object' } },
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
|
|
||||||
@@ -583,7 +581,7 @@ describe('AgentRegistry', () => {
|
|||||||
const loadAgentSpy = vi
|
const loadAgentSpy = vi
|
||||||
.fn()
|
.fn()
|
||||||
.mockResolvedValue({ name: 'RemoteAgentWithAuth' });
|
.mockResolvedValue({ name: 'RemoteAgentWithAuth' });
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: loadAgentSpy,
|
loadAgent: loadAgentSpy,
|
||||||
clearCache: vi.fn(),
|
clearCache: vi.fn(),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
@@ -622,7 +620,7 @@ describe('AgentRegistry', () => {
|
|||||||
|
|
||||||
vi.mocked(A2AAuthProviderFactory.create).mockResolvedValue(undefined);
|
vi.mocked(A2AAuthProviderFactory.create).mockResolvedValue(undefined);
|
||||||
const loadAgentSpy = vi.fn();
|
const loadAgentSpy = vi.fn();
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: loadAgentSpy,
|
loadAgent: loadAgentSpy,
|
||||||
clearCache: vi.fn(),
|
clearCache: vi.fn(),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
@@ -645,6 +643,9 @@ describe('AgentRegistry', () => {
|
|||||||
it('should log remote agent registration in debug mode', async () => {
|
it('should log remote agent registration in debug mode', async () => {
|
||||||
const debugConfig = makeMockedConfig({ debugMode: true });
|
const debugConfig = makeMockedConfig({ debugMode: true });
|
||||||
const debugRegistry = new TestableAgentRegistry(debugConfig);
|
const debugRegistry = new TestableAgentRegistry(debugConfig);
|
||||||
|
vi.spyOn(debugConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
|
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
||||||
|
} as unknown as A2AClientManager);
|
||||||
const debugLogSpy = vi
|
const debugLogSpy = vi
|
||||||
.spyOn(debugLogger, 'log')
|
.spyOn(debugLogger, 'log')
|
||||||
.mockImplementation(() => {});
|
.mockImplementation(() => {});
|
||||||
@@ -657,10 +658,6 @@ describe('AgentRegistry', () => {
|
|||||||
inputConfig: { inputSchema: { type: 'object' } },
|
inputConfig: { inputSchema: { type: 'object' } },
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
|
||||||
loadAgent: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
|
||||||
} as unknown as A2AClientManager);
|
|
||||||
|
|
||||||
await debugRegistry.testRegisterAgent(remoteAgent);
|
await debugRegistry.testRegisterAgent(remoteAgent);
|
||||||
|
|
||||||
expect(debugLogSpy).toHaveBeenCalledWith(
|
expect(debugLogSpy).toHaveBeenCalledWith(
|
||||||
@@ -688,7 +685,7 @@ describe('AgentRegistry', () => {
|
|||||||
new Error('ECONNREFUSED'),
|
new Error('ECONNREFUSED'),
|
||||||
);
|
);
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockRejectedValue(a2aError),
|
loadAgent: vi.fn().mockRejectedValue(a2aError),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
|
|
||||||
@@ -714,7 +711,7 @@ describe('AgentRegistry', () => {
|
|||||||
inputConfig: { inputSchema: { type: 'object' } },
|
inputConfig: { inputSchema: { type: 'object' } },
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockRejectedValue(new Error('unexpected crash')),
|
loadAgent: vi.fn().mockRejectedValue(new Error('unexpected crash')),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
|
|
||||||
@@ -749,7 +746,7 @@ describe('AgentRegistry', () => {
|
|||||||
// No auth configured
|
// No auth configured
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue({
|
loadAgent: vi.fn().mockResolvedValue({
|
||||||
name: 'SecuredAgent',
|
name: 'SecuredAgent',
|
||||||
securitySchemes: {
|
securitySchemes: {
|
||||||
@@ -783,7 +780,7 @@ describe('AgentRegistry', () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const error = new Error('401 Unauthorized');
|
const error = new Error('401 Unauthorized');
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockRejectedValue(error),
|
loadAgent: vi.fn().mockRejectedValue(error),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
|
|
||||||
@@ -815,7 +812,7 @@ describe('AgentRegistry', () => {
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
||||||
clearCache: vi.fn(),
|
clearCache: vi.fn(),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
@@ -843,7 +840,7 @@ describe('AgentRegistry', () => {
|
|||||||
skills: [{ name: 'Skill1', description: 'Desc1' }],
|
skills: [{ name: 'Skill1', description: 'Desc1' }],
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
||||||
clearCache: vi.fn(),
|
clearCache: vi.fn(),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
@@ -871,7 +868,7 @@ describe('AgentRegistry', () => {
|
|||||||
skills: [],
|
skills: [],
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
||||||
clearCache: vi.fn(),
|
clearCache: vi.fn(),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
@@ -902,7 +899,7 @@ describe('AgentRegistry', () => {
|
|||||||
skills: [{ name: 'Skill1', description: 'Desc1' }],
|
skills: [{ name: 'Skill1', description: 'Desc1' }],
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
loadAgent: vi.fn().mockResolvedValue(mockAgentCard),
|
||||||
clearCache: vi.fn(),
|
clearCache: vi.fn(),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
@@ -930,7 +927,7 @@ describe('AgentRegistry', () => {
|
|||||||
inputConfig: { inputSchema: { type: 'object' } },
|
inputConfig: { inputSchema: { type: 'object' } },
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue({
|
loadAgent: vi.fn().mockResolvedValue({
|
||||||
name: 'EmptyDescAgent',
|
name: 'EmptyDescAgent',
|
||||||
description: 'Loaded from card',
|
description: 'Loaded from card',
|
||||||
@@ -955,7 +952,7 @@ describe('AgentRegistry', () => {
|
|||||||
inputConfig: { inputSchema: { type: 'object' } },
|
inputConfig: { inputSchema: { type: 'object' } },
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue({
|
loadAgent: vi.fn().mockResolvedValue({
|
||||||
name: 'SkillFallbackAgent',
|
name: 'SkillFallbackAgent',
|
||||||
description: 'Card description',
|
description: 'Card description',
|
||||||
@@ -1092,7 +1089,7 @@ describe('AgentRegistry', () => {
|
|||||||
inputConfig: { inputSchema: { type: 'object' } },
|
inputConfig: { inputSchema: { type: 'object' } },
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue({ name: 'RemotePolicyAgent' }),
|
loadAgent: vi.fn().mockResolvedValue({ name: 'RemotePolicyAgent' }),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
|
|
||||||
@@ -1141,7 +1138,7 @@ describe('AgentRegistry', () => {
|
|||||||
inputConfig: { inputSchema: { type: 'object' } },
|
inputConfig: { inputSchema: { type: 'object' } },
|
||||||
};
|
};
|
||||||
|
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(mockConfig, 'getA2AClientManager').mockReturnValue({
|
||||||
loadAgent: vi.fn().mockResolvedValue({ name: 'OverwrittenAgent' }),
|
loadAgent: vi.fn().mockResolvedValue({ name: 'OverwrittenAgent' }),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
|
|
||||||
@@ -1189,8 +1186,10 @@ describe('AgentRegistry', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const clearCacheSpy = vi.fn();
|
const clearCacheSpy = vi.fn();
|
||||||
vi.mocked(A2AClientManager.getInstance).mockReturnValue({
|
vi.spyOn(config, 'getA2AClientManager').mockReturnValue({
|
||||||
clearCache: clearCacheSpy,
|
clearCache: clearCacheSpy,
|
||||||
|
loadAgent: vi.fn(),
|
||||||
|
getClient: vi.fn(),
|
||||||
} as unknown as A2AClientManager);
|
} as unknown as A2AClientManager);
|
||||||
|
|
||||||
const emitSpy = vi.spyOn(coreEvents, 'emitAgentsRefreshed');
|
const emitSpy = vi.spyOn(coreEvents, 'emitAgentsRefreshed');
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import { CodebaseInvestigatorAgent } from './codebase-investigator.js';
|
|||||||
import { CliHelpAgent } from './cli-help-agent.js';
|
import { CliHelpAgent } from './cli-help-agent.js';
|
||||||
import { GeneralistAgent } from './generalist-agent.js';
|
import { GeneralistAgent } from './generalist-agent.js';
|
||||||
import { BrowserAgentDefinition } from './browser/browserAgentDefinition.js';
|
import { BrowserAgentDefinition } from './browser/browserAgentDefinition.js';
|
||||||
import { A2AClientManager } from './a2a-client-manager.js';
|
|
||||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||||
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||||
import { type z } from 'zod';
|
import { type z } from 'zod';
|
||||||
@@ -69,7 +68,7 @@ export class AgentRegistry {
|
|||||||
* Clears the current registry and re-scans for agents.
|
* Clears the current registry and re-scans for agents.
|
||||||
*/
|
*/
|
||||||
async reload(): Promise<void> {
|
async reload(): Promise<void> {
|
||||||
A2AClientManager.getInstance(this.config).clearCache();
|
this.config.getA2AClientManager()?.clearCache();
|
||||||
await this.config.reloadAgents();
|
await this.config.reloadAgents();
|
||||||
this.agents.clear();
|
this.agents.clear();
|
||||||
this.allDefinitions.clear();
|
this.allDefinitions.clear();
|
||||||
@@ -414,7 +413,13 @@ export class AgentRegistry {
|
|||||||
|
|
||||||
// Load the remote A2A agent card and register.
|
// Load the remote A2A agent card and register.
|
||||||
try {
|
try {
|
||||||
const clientManager = A2AClientManager.getInstance(this.config);
|
const clientManager = this.config.getA2AClientManager();
|
||||||
|
if (!clientManager) {
|
||||||
|
debugLogger.warn(
|
||||||
|
`[AgentRegistry] Skipping remote agent '${definition.name}': A2AClientManager is not available.`,
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
let authHandler: AuthenticationHandler | undefined;
|
let authHandler: AuthenticationHandler | undefined;
|
||||||
if (definition.auth) {
|
if (definition.auth) {
|
||||||
const provider = await A2AAuthProviderFactory.create({
|
const provider = await A2AAuthProviderFactory.create({
|
||||||
|
|||||||
@@ -13,21 +13,27 @@ import {
|
|||||||
afterEach,
|
afterEach,
|
||||||
type Mock,
|
type Mock,
|
||||||
} from 'vitest';
|
} from 'vitest';
|
||||||
|
import type { Client } from '@a2a-js/sdk/client';
|
||||||
import { RemoteAgentInvocation } from './remote-invocation.js';
|
import { RemoteAgentInvocation } from './remote-invocation.js';
|
||||||
import {
|
import {
|
||||||
A2AClientManager,
|
|
||||||
type SendMessageResult,
|
type SendMessageResult,
|
||||||
|
type A2AClientManager,
|
||||||
} from './a2a-client-manager.js';
|
} from './a2a-client-manager.js';
|
||||||
|
|
||||||
import type { RemoteAgentDefinition } from './types.js';
|
import type { RemoteAgentDefinition } from './types.js';
|
||||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||||
import type { A2AAuthProvider } from './auth-provider/types.js';
|
import type { A2AAuthProvider } from './auth-provider/types.js';
|
||||||
|
import type { AgentLoopContext } from '../config/agent-loop-context.js';
|
||||||
|
import type { Config } from '../config/config.js';
|
||||||
|
|
||||||
// Mock A2AClientManager
|
// Mock A2AClientManager
|
||||||
vi.mock('./a2a-client-manager.js', () => ({
|
vi.mock('./a2a-client-manager.js', () => ({
|
||||||
A2AClientManager: {
|
A2AClientManager: vi.fn().mockImplementation(() => ({
|
||||||
getInstance: vi.fn(),
|
getClient: vi.fn(),
|
||||||
},
|
loadAgent: vi.fn(),
|
||||||
|
sendMessageStream: vi.fn(),
|
||||||
|
})),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Mock A2AAuthProviderFactory
|
// Mock A2AAuthProviderFactory
|
||||||
@@ -49,16 +55,40 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
const mockClientManager = {
|
let mockClientManager: {
|
||||||
getClient: vi.fn(),
|
getClient: Mock<A2AClientManager['getClient']>;
|
||||||
loadAgent: vi.fn(),
|
loadAgent: Mock<A2AClientManager['loadAgent']>;
|
||||||
sendMessageStream: vi.fn(),
|
sendMessageStream: Mock<A2AClientManager['sendMessageStream']>;
|
||||||
};
|
};
|
||||||
|
let mockContext: AgentLoopContext;
|
||||||
const mockMessageBus = createMockMessageBus();
|
const mockMessageBus = createMockMessageBus();
|
||||||
|
|
||||||
|
const mockClient = {
|
||||||
|
sendMessageStream: vi.fn(),
|
||||||
|
getTask: vi.fn(),
|
||||||
|
cancelTask: vi.fn(),
|
||||||
|
} as unknown as Client;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
(A2AClientManager.getInstance as Mock).mockReturnValue(mockClientManager);
|
|
||||||
|
mockClientManager = {
|
||||||
|
getClient: vi.fn(),
|
||||||
|
loadAgent: vi.fn(),
|
||||||
|
sendMessageStream: vi.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockConfig = {
|
||||||
|
getA2AClientManager: vi.fn().mockReturnValue(mockClientManager),
|
||||||
|
injectionService: {
|
||||||
|
getLatestInjectionIndex: vi.fn().mockReturnValue(0),
|
||||||
|
},
|
||||||
|
} as unknown as Config;
|
||||||
|
|
||||||
|
mockContext = {
|
||||||
|
config: mockConfig,
|
||||||
|
} as unknown as AgentLoopContext;
|
||||||
|
|
||||||
(
|
(
|
||||||
RemoteAgentInvocation as unknown as {
|
RemoteAgentInvocation as unknown as {
|
||||||
sessionState?: Map<string, { contextId?: string; taskId?: string }>;
|
sessionState?: Map<string, { contextId?: string; taskId?: string }>;
|
||||||
@@ -75,6 +105,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
expect(() => {
|
expect(() => {
|
||||||
new RemoteAgentInvocation(
|
new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 'valid' },
|
{ query: 'valid' },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -83,12 +114,17 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
it('accepts missing query (defaults to "Get Started!")', () => {
|
it('accepts missing query (defaults to "Get Started!")', () => {
|
||||||
expect(() => {
|
expect(() => {
|
||||||
new RemoteAgentInvocation(mockDefinition, {}, mockMessageBus);
|
new RemoteAgentInvocation(
|
||||||
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
|
{},
|
||||||
|
mockMessageBus,
|
||||||
|
);
|
||||||
}).not.toThrow();
|
}).not.toThrow();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('uses "Get Started!" default when query is missing during execution', async () => {
|
it('uses "Get Started!" default when query is missing during execution', async () => {
|
||||||
mockClientManager.getClient.mockReturnValue({});
|
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||||
mockClientManager.sendMessageStream.mockImplementation(
|
mockClientManager.sendMessageStream.mockImplementation(
|
||||||
async function* () {
|
async function* () {
|
||||||
yield {
|
yield {
|
||||||
@@ -102,6 +138,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{},
|
{},
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -118,6 +155,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
expect(() => {
|
expect(() => {
|
||||||
new RemoteAgentInvocation(
|
new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 123 },
|
{ query: 123 },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -141,6 +179,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{
|
{
|
||||||
query: 'hi',
|
query: 'hi',
|
||||||
},
|
},
|
||||||
@@ -187,6 +226,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
authDefinition,
|
authDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 'hi' },
|
{ query: 'hi' },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -220,6 +260,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
authDefinition,
|
authDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 'hi' },
|
{ query: 'hi' },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -231,7 +272,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should not load the agent if already present', async () => {
|
it('should not load the agent if already present', async () => {
|
||||||
mockClientManager.getClient.mockReturnValue({});
|
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||||
mockClientManager.sendMessageStream.mockImplementation(
|
mockClientManager.sendMessageStream.mockImplementation(
|
||||||
async function* () {
|
async function* () {
|
||||||
yield {
|
yield {
|
||||||
@@ -245,6 +286,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{
|
{
|
||||||
query: 'hi',
|
query: 'hi',
|
||||||
},
|
},
|
||||||
@@ -256,7 +298,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should persist contextId and taskId across invocations', async () => {
|
it('should persist contextId and taskId across invocations', async () => {
|
||||||
mockClientManager.getClient.mockReturnValue({});
|
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||||
|
|
||||||
// First call return values
|
// First call return values
|
||||||
mockClientManager.sendMessageStream.mockImplementationOnce(
|
mockClientManager.sendMessageStream.mockImplementationOnce(
|
||||||
@@ -274,6 +316,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation1 = new RemoteAgentInvocation(
|
const invocation1 = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{
|
{
|
||||||
query: 'first',
|
query: 'first',
|
||||||
},
|
},
|
||||||
@@ -305,6 +348,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation2 = new RemoteAgentInvocation(
|
const invocation2 = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{
|
{
|
||||||
query: 'second',
|
query: 'second',
|
||||||
},
|
},
|
||||||
@@ -335,6 +379,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation3 = new RemoteAgentInvocation(
|
const invocation3 = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{
|
{
|
||||||
query: 'third',
|
query: 'third',
|
||||||
},
|
},
|
||||||
@@ -356,6 +401,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation4 = new RemoteAgentInvocation(
|
const invocation4 = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{
|
{
|
||||||
query: 'fourth',
|
query: 'fourth',
|
||||||
},
|
},
|
||||||
@@ -371,7 +417,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should handle streaming updates and reassemble output', async () => {
|
it('should handle streaming updates and reassemble output', async () => {
|
||||||
mockClientManager.getClient.mockReturnValue({});
|
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||||
mockClientManager.sendMessageStream.mockImplementation(
|
mockClientManager.sendMessageStream.mockImplementation(
|
||||||
async function* () {
|
async function* () {
|
||||||
yield {
|
yield {
|
||||||
@@ -392,6 +438,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
const updateOutput = vi.fn();
|
const updateOutput = vi.fn();
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 'hi' },
|
{ query: 'hi' },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -402,7 +449,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should abort when signal is aborted during streaming', async () => {
|
it('should abort when signal is aborted during streaming', async () => {
|
||||||
mockClientManager.getClient.mockReturnValue({});
|
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
mockClientManager.sendMessageStream.mockImplementation(
|
mockClientManager.sendMessageStream.mockImplementation(
|
||||||
async function* () {
|
async function* () {
|
||||||
@@ -425,6 +472,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 'hi' },
|
{ query: 'hi' },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -435,7 +483,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should handle errors gracefully', async () => {
|
it('should handle errors gracefully', async () => {
|
||||||
mockClientManager.getClient.mockReturnValue({});
|
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||||
mockClientManager.sendMessageStream.mockImplementation(
|
mockClientManager.sendMessageStream.mockImplementation(
|
||||||
async function* () {
|
async function* () {
|
||||||
if (Math.random() < 0) yield {} as unknown as SendMessageResult;
|
if (Math.random() < 0) yield {} as unknown as SendMessageResult;
|
||||||
@@ -445,6 +493,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{
|
{
|
||||||
query: 'hi',
|
query: 'hi',
|
||||||
},
|
},
|
||||||
@@ -458,7 +507,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should use a2a helpers for extracting text', async () => {
|
it('should use a2a helpers for extracting text', async () => {
|
||||||
mockClientManager.getClient.mockReturnValue({});
|
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||||
// Mock a complex message part that needs extraction
|
// Mock a complex message part that needs extraction
|
||||||
mockClientManager.sendMessageStream.mockImplementation(
|
mockClientManager.sendMessageStream.mockImplementation(
|
||||||
async function* () {
|
async function* () {
|
||||||
@@ -476,6 +525,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{
|
{
|
||||||
query: 'hi',
|
query: 'hi',
|
||||||
},
|
},
|
||||||
@@ -488,7 +538,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should handle mixed response types during streaming (TaskStatusUpdateEvent + Message)', async () => {
|
it('should handle mixed response types during streaming (TaskStatusUpdateEvent + Message)', async () => {
|
||||||
mockClientManager.getClient.mockReturnValue({});
|
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||||
mockClientManager.sendMessageStream.mockImplementation(
|
mockClientManager.sendMessageStream.mockImplementation(
|
||||||
async function* () {
|
async function* () {
|
||||||
yield {
|
yield {
|
||||||
@@ -518,6 +568,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
const updateOutput = vi.fn();
|
const updateOutput = vi.fn();
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 'hi' },
|
{ query: 'hi' },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -532,17 +583,20 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should handle artifact reassembly with append: true', async () => {
|
it('should handle artifact reassembly with append: true', async () => {
|
||||||
mockClientManager.getClient.mockReturnValue({});
|
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||||
mockClientManager.sendMessageStream.mockImplementation(
|
mockClientManager.sendMessageStream.mockImplementation(
|
||||||
async function* () {
|
async function* () {
|
||||||
yield {
|
yield {
|
||||||
kind: 'status-update',
|
kind: 'status-update',
|
||||||
taskId: 'task-1',
|
taskId: 'task-1',
|
||||||
|
contextId: 'ctx-1',
|
||||||
|
final: false,
|
||||||
status: {
|
status: {
|
||||||
state: 'working',
|
state: 'working',
|
||||||
message: {
|
message: {
|
||||||
kind: 'message',
|
kind: 'message',
|
||||||
role: 'agent',
|
role: 'agent',
|
||||||
|
messageId: 'm1',
|
||||||
parts: [{ kind: 'text', text: 'Generating...' }],
|
parts: [{ kind: 'text', text: 'Generating...' }],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -550,6 +604,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
yield {
|
yield {
|
||||||
kind: 'artifact-update',
|
kind: 'artifact-update',
|
||||||
taskId: 'task-1',
|
taskId: 'task-1',
|
||||||
|
contextId: 'ctx-1',
|
||||||
append: false,
|
append: false,
|
||||||
artifact: {
|
artifact: {
|
||||||
artifactId: 'art-1',
|
artifactId: 'art-1',
|
||||||
@@ -560,18 +615,21 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
yield {
|
yield {
|
||||||
kind: 'artifact-update',
|
kind: 'artifact-update',
|
||||||
taskId: 'task-1',
|
taskId: 'task-1',
|
||||||
|
contextId: 'ctx-1',
|
||||||
append: true,
|
append: true,
|
||||||
artifact: {
|
artifact: {
|
||||||
artifactId: 'art-1',
|
artifactId: 'art-1',
|
||||||
parts: [{ kind: 'text', text: ' Part 2' }],
|
parts: [{ kind: 'text', text: ' Part 2' }],
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
return;
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
const updateOutput = vi.fn();
|
const updateOutput = vi.fn();
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 'hi' },
|
{ query: 'hi' },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -591,6 +649,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
it('should return info confirmation details', async () => {
|
it('should return info confirmation details', async () => {
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{
|
{
|
||||||
query: 'hi',
|
query: 'hi',
|
||||||
},
|
},
|
||||||
@@ -629,6 +688,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 'hi' },
|
{ query: 'hi' },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -646,6 +706,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 'hi' },
|
{ query: 'hi' },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
@@ -658,7 +719,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should include partial output when error occurs mid-stream', async () => {
|
it('should include partial output when error occurs mid-stream', async () => {
|
||||||
mockClientManager.getClient.mockReturnValue({});
|
mockClientManager.getClient.mockReturnValue(mockClient);
|
||||||
mockClientManager.sendMessageStream.mockImplementation(
|
mockClientManager.sendMessageStream.mockImplementation(
|
||||||
async function* () {
|
async function* () {
|
||||||
yield {
|
yield {
|
||||||
@@ -674,6 +735,7 @@ describe('RemoteAgentInvocation', () => {
|
|||||||
|
|
||||||
const invocation = new RemoteAgentInvocation(
|
const invocation = new RemoteAgentInvocation(
|
||||||
mockDefinition,
|
mockDefinition,
|
||||||
|
mockContext,
|
||||||
{ query: 'hi' },
|
{ query: 'hi' },
|
||||||
mockMessageBus,
|
mockMessageBus,
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -16,10 +16,11 @@ import {
|
|||||||
type RemoteAgentDefinition,
|
type RemoteAgentDefinition,
|
||||||
type AgentInputs,
|
type AgentInputs,
|
||||||
} from './types.js';
|
} from './types.js';
|
||||||
|
import { type AgentLoopContext } from '../config/agent-loop-context.js';
|
||||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||||
import {
|
import type {
|
||||||
A2AClientManager,
|
A2AClientManager,
|
||||||
type SendMessageResult,
|
SendMessageResult,
|
||||||
} from './a2a-client-manager.js';
|
} from './a2a-client-manager.js';
|
||||||
import { extractIdsFromResponse, A2AResultReassembler } from './a2aUtils.js';
|
import { extractIdsFromResponse, A2AResultReassembler } from './a2aUtils.js';
|
||||||
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||||
@@ -47,13 +48,13 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
|||||||
// State for the ongoing conversation with the remote agent
|
// State for the ongoing conversation with the remote agent
|
||||||
private contextId: string | undefined;
|
private contextId: string | undefined;
|
||||||
private taskId: string | undefined;
|
private taskId: string | undefined;
|
||||||
// TODO: See if we can reuse the singleton from AppContainer or similar, but for now use getInstance directly
|
|
||||||
// as per the current pattern in the codebase.
|
private readonly clientManager: A2AClientManager;
|
||||||
private readonly clientManager = A2AClientManager.getInstance();
|
|
||||||
private authHandler: AuthenticationHandler | undefined;
|
private authHandler: AuthenticationHandler | undefined;
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
private readonly definition: RemoteAgentDefinition,
|
private readonly definition: RemoteAgentDefinition,
|
||||||
|
private readonly context: AgentLoopContext,
|
||||||
params: AgentInputs,
|
params: AgentInputs,
|
||||||
messageBus: MessageBus,
|
messageBus: MessageBus,
|
||||||
_toolName?: string,
|
_toolName?: string,
|
||||||
@@ -72,6 +73,13 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
|||||||
_toolName ?? definition.name,
|
_toolName ?? definition.name,
|
||||||
_toolDisplayName ?? definition.displayName,
|
_toolDisplayName ?? definition.displayName,
|
||||||
);
|
);
|
||||||
|
const clientManager = this.context.config.getA2AClientManager();
|
||||||
|
if (!clientManager) {
|
||||||
|
throw new Error(
|
||||||
|
`Failed to initialize RemoteAgentInvocation for '${definition.name}': A2AClientManager is not available.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
this.clientManager = clientManager;
|
||||||
}
|
}
|
||||||
|
|
||||||
getDescription(): string {
|
getDescription(): string {
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ export class SubagentToolWrapper extends BaseDeclarativeTool<
|
|||||||
if (definition.kind === 'remote') {
|
if (definition.kind === 'remote') {
|
||||||
return new RemoteAgentInvocation(
|
return new RemoteAgentInvocation(
|
||||||
definition,
|
definition,
|
||||||
|
this.context,
|
||||||
params,
|
params,
|
||||||
effectiveMessageBus,
|
effectiveMessageBus,
|
||||||
_toolName,
|
_toolName,
|
||||||
|
|||||||
@@ -1523,7 +1523,7 @@ describe('Server Config (config.ts)', () => {
|
|||||||
|
|
||||||
const paramsWithProxy: ConfigParameters = {
|
const paramsWithProxy: ConfigParameters = {
|
||||||
...baseParams,
|
...baseParams,
|
||||||
proxy: 'invalid-proxy',
|
proxy: 'http://invalid-proxy:8080',
|
||||||
};
|
};
|
||||||
new Config(paramsWithProxy);
|
new Config(paramsWithProxy);
|
||||||
|
|
||||||
|
|||||||
@@ -405,6 +405,7 @@ import {
|
|||||||
SimpleExtensionLoader,
|
SimpleExtensionLoader,
|
||||||
} from '../utils/extensionLoader.js';
|
} from '../utils/extensionLoader.js';
|
||||||
import { McpClientManager } from '../tools/mcp-client-manager.js';
|
import { McpClientManager } from '../tools/mcp-client-manager.js';
|
||||||
|
import { A2AClientManager } from '../agents/a2a-client-manager.js';
|
||||||
import { type McpContext } from '../tools/mcp-client.js';
|
import { type McpContext } from '../tools/mcp-client.js';
|
||||||
import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js';
|
import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js';
|
||||||
import { getErrorMessage } from '../utils/errors.js';
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
@@ -653,6 +654,7 @@ export interface ConfigParameters {
|
|||||||
export class Config implements McpContext, AgentLoopContext {
|
export class Config implements McpContext, AgentLoopContext {
|
||||||
private _toolRegistry!: ToolRegistry;
|
private _toolRegistry!: ToolRegistry;
|
||||||
private mcpClientManager?: McpClientManager;
|
private mcpClientManager?: McpClientManager;
|
||||||
|
private readonly a2aClientManager?: A2AClientManager;
|
||||||
private allowedMcpServers: string[];
|
private allowedMcpServers: string[];
|
||||||
private blockedMcpServers: string[];
|
private blockedMcpServers: string[];
|
||||||
private allowedEnvironmentVariables: string[];
|
private allowedEnvironmentVariables: string[];
|
||||||
@@ -1188,6 +1190,7 @@ export class Config implements McpContext, AgentLoopContext {
|
|||||||
params.toolSandboxing ?? false,
|
params.toolSandboxing ?? false,
|
||||||
this.targetDir,
|
this.targetDir,
|
||||||
);
|
);
|
||||||
|
this.a2aClientManager = new A2AClientManager(this);
|
||||||
this.shellExecutionConfig.sandboxManager = this._sandboxManager;
|
this.shellExecutionConfig.sandboxManager = this._sandboxManager;
|
||||||
this.modelRouterService = new ModelRouterService(this);
|
this.modelRouterService = new ModelRouterService(this);
|
||||||
}
|
}
|
||||||
@@ -2000,6 +2003,10 @@ export class Config implements McpContext, AgentLoopContext {
|
|||||||
return this.mcpClientManager;
|
return this.mcpClientManager;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getA2AClientManager(): A2AClientManager | undefined {
|
||||||
|
return this.a2aClientManager;
|
||||||
|
}
|
||||||
|
|
||||||
setUserInteractedWithMcp(): void {
|
setUserInteractedWithMcp(): void {
|
||||||
this.mcpClientManager?.setUserInteractedWithMcp();
|
this.mcpClientManager?.setUserInteractedWithMcp();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import {
|
|||||||
ApprovalMode,
|
ApprovalMode,
|
||||||
PRIORITY_SUBAGENT_TOOL,
|
PRIORITY_SUBAGENT_TOOL,
|
||||||
ALWAYS_ALLOW_PRIORITY_FRACTION,
|
ALWAYS_ALLOW_PRIORITY_FRACTION,
|
||||||
|
PRIORITY_YOLO_ALLOW_ALL,
|
||||||
} from './types.js';
|
} from './types.js';
|
||||||
import type { FunctionCall } from '@google/genai';
|
import type { FunctionCall } from '@google/genai';
|
||||||
import { SafetyCheckDecision } from '../safety/protocol.js';
|
import { SafetyCheckDecision } from '../safety/protocol.js';
|
||||||
@@ -2852,7 +2853,7 @@ describe('PolicyEngine', () => {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
decision: PolicyDecision.ALLOW,
|
decision: PolicyDecision.ALLOW,
|
||||||
priority: 998,
|
priority: PRIORITY_YOLO_ALLOW_ALL,
|
||||||
modes: [ApprovalMode.YOLO],
|
modes: [ApprovalMode.YOLO],
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
@@ -2879,7 +2880,7 @@ describe('PolicyEngine', () => {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
decision: PolicyDecision.ALLOW,
|
decision: PolicyDecision.ALLOW,
|
||||||
priority: 998,
|
priority: PRIORITY_YOLO_ALLOW_ALL,
|
||||||
modes: [ApprovalMode.YOLO],
|
modes: [ApprovalMode.YOLO],
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|||||||
@@ -345,3 +345,9 @@ export const ALWAYS_ALLOW_PRIORITY_FRACTION = 950;
|
|||||||
*/
|
*/
|
||||||
export const ALWAYS_ALLOW_PRIORITY_OFFSET =
|
export const ALWAYS_ALLOW_PRIORITY_OFFSET =
|
||||||
ALWAYS_ALLOW_PRIORITY_FRACTION / 1000;
|
ALWAYS_ALLOW_PRIORITY_FRACTION / 1000;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Priority for the YOLO "allow all" rule.
|
||||||
|
* Matches the raw priority used in yolo.toml.
|
||||||
|
*/
|
||||||
|
export const PRIORITY_YOLO_ALLOW_ALL = 998;
|
||||||
|
|||||||
Reference in New Issue
Block a user