mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
feat(a2a): implement secure acknowledgement and correct agent indexing
This commit is contained in:
@@ -66,6 +66,18 @@ export class AcknowledgedAgentsService {
|
||||
hash: string,
|
||||
): Promise<boolean> {
|
||||
await this.load();
|
||||
return this.isAcknowledgedSync(projectPath, agentName, hash);
|
||||
}
|
||||
|
||||
/**
|
||||
* Synchronous check for acknowledgment.
|
||||
* Note: Assumes load() has already been called and awaited (e.g. during registry init).
|
||||
*/
|
||||
isAcknowledgedSync(
|
||||
projectPath: string,
|
||||
agentName: string,
|
||||
hash: string,
|
||||
): boolean {
|
||||
const projectAgents = this.acknowledgedAgents[projectPath];
|
||||
if (!projectAgents) return false;
|
||||
return projectAgents[agentName] === hash;
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import * as crypto from 'node:crypto';
|
||||
import { AgentRegistry, getModelConfigAlias } from './registry.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
import type { AgentDefinition, LocalAgentDefinition } from './types.js';
|
||||
@@ -29,10 +30,23 @@ import { SimpleExtensionLoader } from '../utils/extensionLoader.js';
|
||||
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||
import { ThinkingLevel } from '@google/genai';
|
||||
import type { AcknowledgedAgentsService } from './acknowledgedAgents.js';
|
||||
import * as sdkClient from '@a2a-js/sdk/client';
|
||||
import { safeFetch } from '../utils/fetch.js';
|
||||
import { PolicyDecision } from '../policy/types.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
import type { A2AAuthProvider } from './auth-provider/types.js';
|
||||
|
||||
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
|
||||
const actual = await importOriginal();
|
||||
return {
|
||||
...(actual as Record<string, unknown>),
|
||||
DefaultAgentCardResolver: vi.fn().mockImplementation((options) => ({
|
||||
fetchImpl: options?.fetchImpl,
|
||||
resolve: vi.fn().mockResolvedValue({ name: 'RemoteAgent' }),
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('./agentLoader.js', () => ({
|
||||
loadAgentsFromDirectory: vi
|
||||
.fn()
|
||||
@@ -417,7 +431,7 @@ describe('AgentRegistry', () => {
|
||||
expect(registry.getDefinition('extension-agent')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should use agentCardUrl as hash for acknowledgement of remote agents', async () => {
|
||||
it('should use agentCardUrl and content-based hash for acknowledgement of remote agents', async () => {
|
||||
mockConfig = makeMockedConfig({ enableAgents: true });
|
||||
// Trust the folder so it attempts to load project agents
|
||||
vi.spyOn(mockConfig, 'isTrustedFolder').mockReturnValue(true);
|
||||
@@ -453,21 +467,75 @@ describe('AgentRegistry', () => {
|
||||
clearCache: vi.fn(),
|
||||
} as unknown as A2AClientManager);
|
||||
|
||||
// Mock the resolver to return a consistent card content for hashing
|
||||
const mockCardContent = { name: 'RemoteAgent' };
|
||||
const expectedContentHash = crypto
|
||||
.createHash('sha256')
|
||||
.update(JSON.stringify(mockCardContent))
|
||||
.digest('hex');
|
||||
const expectedHash = `https://example.com/card#${expectedContentHash}`;
|
||||
|
||||
vi.mocked(sdkClient.DefaultAgentCardResolver).mockImplementation(
|
||||
() =>
|
||||
({
|
||||
resolve: vi.fn().mockResolvedValue(mockCardContent),
|
||||
}) as unknown as sdkClient.DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Verify ackService was called with the URL, not the file hash
|
||||
// Verify ackService was called with the content-based hash
|
||||
expect(ackService.isAcknowledged).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
'RemoteAgent',
|
||||
'https://example.com/card',
|
||||
expectedHash,
|
||||
);
|
||||
|
||||
// Also verify that the agent's metadata was updated to use the URL as hash
|
||||
// Use getDefinition because registerAgent might have been called
|
||||
// Also verify that the agent's metadata was updated to use the content-based hash
|
||||
expect(registry.getDefinition('RemoteAgent')?.metadata?.hash).toBe(
|
||||
'https://example.com/card',
|
||||
expectedHash,
|
||||
);
|
||||
});
|
||||
|
||||
it('should use safeFetch in DefaultAgentCardResolver during initialization', async () => {
|
||||
mockConfig = makeMockedConfig({ enableAgents: true });
|
||||
vi.spyOn(mockConfig, 'isTrustedFolder').mockReturnValue(true);
|
||||
vi.spyOn(mockConfig, 'getFolderTrust').mockReturnValue(true);
|
||||
|
||||
const registry = new TestableAgentRegistry(mockConfig);
|
||||
|
||||
const remoteAgent: AgentDefinition = {
|
||||
kind: 'remote',
|
||||
name: 'RemoteAgent',
|
||||
description: 'A remote agent',
|
||||
agentCardUrl: 'https://example.com/card',
|
||||
inputConfig: { inputSchema: { type: 'object' } },
|
||||
};
|
||||
|
||||
vi.mocked(tomlLoader.loadAgentsFromDirectory).mockResolvedValue({
|
||||
agents: [remoteAgent],
|
||||
errors: [],
|
||||
});
|
||||
|
||||
// Track constructor calls
|
||||
const resolverMock = vi.mocked(sdkClient.DefaultAgentCardResolver);
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Find the call for our remote agent
|
||||
const call = resolverMock.mock.calls.find((args) => {
|
||||
const options = args[0] as { fetchImpl?: typeof fetch };
|
||||
// We look for a call that was provided with a fetch implementation.
|
||||
// In our current implementation, we wrap safeFetch.
|
||||
return typeof options?.fetchImpl === 'function';
|
||||
});
|
||||
|
||||
expect(call).toBeDefined();
|
||||
const options = call?.[0] as { fetchImpl?: typeof fetch };
|
||||
|
||||
// We passed safeFetch directly
|
||||
expect(options?.fetchImpl).toBe(safeFetch);
|
||||
});
|
||||
});
|
||||
|
||||
describe('registration logic', () => {
|
||||
@@ -874,6 +942,17 @@ describe('AgentRegistry', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should maintain registration under canonical name', async () => {
|
||||
const originalName = 'my-agent';
|
||||
const definition = { ...MOCK_AGENT_V1, name: originalName };
|
||||
|
||||
await registry.testRegisterAgent(definition);
|
||||
|
||||
const registered = registry.getDefinition(originalName);
|
||||
expect(registered).toBeDefined();
|
||||
expect(registry.getAllAgentNames()).toEqual([originalName]);
|
||||
});
|
||||
|
||||
it('should reject an agent definition missing a name', async () => {
|
||||
const invalidAgent = { ...MOCK_AGENT_V1, name: '' };
|
||||
const debugWarnSpy = vi
|
||||
|
||||
@@ -9,14 +9,18 @@ import { CoreEvent, coreEvents } from '../utils/events.js';
|
||||
import type { AgentOverride, Config } from '../config/config.js';
|
||||
import type { AgentDefinition, LocalAgentDefinition } from './types.js';
|
||||
import { loadAgentsFromDirectory } from './agentLoader.js';
|
||||
import { splitAgentCardUrl } from './a2aUtils.js';
|
||||
import { CodebaseInvestigatorAgent } from './codebase-investigator.js';
|
||||
import { CliHelpAgent } from './cli-help-agent.js';
|
||||
import { GeneralistAgent } from './generalist-agent.js';
|
||||
import { BrowserAgentDefinition } from './browser/browserAgentDefinition.js';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
import { DefaultAgentCardResolver } from '@a2a-js/sdk/client';
|
||||
import { safeFetch } from '../utils/fetch.js';
|
||||
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||
import { type z } from 'zod';
|
||||
import * as crypto from 'node:crypto';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { isAutoModel } from '../config/models.js';
|
||||
import {
|
||||
@@ -155,13 +159,38 @@ export class AgentRegistry {
|
||||
const agentsToRegister: AgentDefinition[] = [];
|
||||
|
||||
for (const agent of projectAgents.agents) {
|
||||
// If it's a remote agent, use the agentCardUrl as the hash.
|
||||
// This allows multiple remote agents in a single file to be tracked independently.
|
||||
// For remote agents, we must use a content-based hash of the AgentCard
|
||||
// to prevent Indirect Prompt Injection if the remote card is modified.
|
||||
if (agent.kind === 'remote') {
|
||||
if (!agent.metadata) {
|
||||
agent.metadata = {};
|
||||
try {
|
||||
// We use a dedicated resolver here to fetch the card for hashing.
|
||||
// This is separate from loadAgent to keep hashing logic isolated.
|
||||
// We provide safeFetch to ensure SSRF and DNS rebinding protection.
|
||||
const resolver = new DefaultAgentCardResolver({
|
||||
fetchImpl: safeFetch,
|
||||
});
|
||||
const { baseUrl, path } = splitAgentCardUrl(agent.agentCardUrl);
|
||||
const rawCard = await resolver.resolve(baseUrl, path);
|
||||
const cardContent = JSON.stringify(rawCard);
|
||||
const contentHash = crypto
|
||||
.createHash('sha256')
|
||||
.update(cardContent)
|
||||
.digest('hex');
|
||||
|
||||
if (!agent.metadata) {
|
||||
agent.metadata = {};
|
||||
}
|
||||
// Combining URL and content hash ensures we track specific card versions at specific locations.
|
||||
agent.metadata.hash = `${agent.agentCardUrl}#${contentHash}`;
|
||||
} catch (e) {
|
||||
debugLogger.warn(
|
||||
`[AgentRegistry] Could not fetch remote card for hashing "${agent.name}":`,
|
||||
e,
|
||||
);
|
||||
// If we can't fetch the card, we can't verify its acknowledgement securely.
|
||||
unacknowledgedAgents.push(agent);
|
||||
continue;
|
||||
}
|
||||
agent.metadata.hash = agent.agentCardUrl;
|
||||
}
|
||||
|
||||
if (!agent.metadata?.hash) {
|
||||
@@ -178,6 +207,10 @@ export class AgentRegistry {
|
||||
if (isAcknowledged) {
|
||||
agentsToRegister.push(agent);
|
||||
} else {
|
||||
// Register unacknowledged agents so they are visible to the LLM.
|
||||
// They will be registered with ASK_USER policy, triggering the
|
||||
// acknowledgement flow when the LLM tries to call them.
|
||||
agentsToRegister.push(agent);
|
||||
unacknowledgedAgents.push(agent);
|
||||
}
|
||||
}
|
||||
@@ -312,6 +345,17 @@ export class AgentRegistry {
|
||||
}
|
||||
|
||||
const mergedDefinition = this.applyOverrides(definition, settingsOverrides);
|
||||
|
||||
// Ensure we don't accidentally overwrite an existing agent with a different canonical name
|
||||
if (
|
||||
mergedDefinition.name !== definition.name &&
|
||||
this.agents.has(mergedDefinition.name)
|
||||
) {
|
||||
throw new Error(
|
||||
`Cannot register agent '${definition.name}' as '${mergedDefinition.name}': Name collision with an already registered agent.`,
|
||||
);
|
||||
}
|
||||
|
||||
this.agents.set(mergedDefinition.name, mergedDefinition);
|
||||
|
||||
this.registerModelConfigs(mergedDefinition);
|
||||
@@ -339,12 +383,21 @@ export class AgentRegistry {
|
||||
policyEngine.removeRulesForTool(definition.name, 'AgentRegistry (Dynamic)');
|
||||
|
||||
// Add the new dynamic policy
|
||||
const isAcknowledged =
|
||||
definition.kind === 'local' &&
|
||||
(!definition.metadata?.hash ||
|
||||
(this.config.getProjectRoot() &&
|
||||
this.config
|
||||
.getAcknowledgedAgentsService()
|
||||
?.isAcknowledgedSync?.(
|
||||
this.config.getProjectRoot(),
|
||||
definition.name,
|
||||
definition.metadata.hash,
|
||||
)));
|
||||
|
||||
policyEngine.addRule({
|
||||
toolName: definition.name,
|
||||
decision:
|
||||
definition.kind === 'local'
|
||||
? PolicyDecision.ALLOW
|
||||
: PolicyDecision.ASK_USER,
|
||||
decision: isAcknowledged ? PolicyDecision.ALLOW : PolicyDecision.ASK_USER,
|
||||
priority: PRIORITY_SUBAGENT_TOOL,
|
||||
source: 'AgentRegistry (Dynamic)',
|
||||
});
|
||||
|
||||
@@ -12,6 +12,7 @@ import { coreEvents } from '../utils/events.js';
|
||||
import * as tomlLoader from './agentLoader.js';
|
||||
import { type Config } from '../config/config.js';
|
||||
import { AcknowledgedAgentsService } from './acknowledgedAgents.js';
|
||||
import { PolicyDecision } from '../policy/types.js';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import * as os from 'node:os';
|
||||
@@ -103,13 +104,22 @@ describe('AgentRegistry Acknowledgement', () => {
|
||||
await fs.rm(tempDir, { recursive: true, force: true });
|
||||
});
|
||||
|
||||
it('should not register unacknowledged project agents and emit event', async () => {
|
||||
it('should register unacknowledged project agents and emit event', async () => {
|
||||
const emitSpy = vi.spyOn(coreEvents, 'emitAgentsDiscovered');
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
expect(registry.getDefinition('ProjectAgent')).toBeUndefined();
|
||||
// Now unacknowledged agents ARE registered (but with ASK_USER policy)
|
||||
expect(registry.getDefinition('ProjectAgent')).toBeDefined();
|
||||
expect(emitSpy).toHaveBeenCalledWith([MOCK_AGENT_WITH_HASH]);
|
||||
|
||||
// Verify policy
|
||||
const policyEngine = config.getPolicyEngine();
|
||||
expect(
|
||||
await policyEngine?.check({ name: 'ProjectAgent', args: {} }, undefined),
|
||||
).toMatchObject({
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
});
|
||||
});
|
||||
|
||||
it('should register acknowledged project agents', async () => {
|
||||
@@ -134,6 +144,14 @@ describe('AgentRegistry Acknowledgement', () => {
|
||||
|
||||
expect(registry.getDefinition('ProjectAgent')).toBeDefined();
|
||||
expect(emitSpy).not.toHaveBeenCalled();
|
||||
|
||||
// Verify policy is ALLOW for acknowledged agent
|
||||
const policyEngine = config.getPolicyEngine();
|
||||
expect(
|
||||
await policyEngine?.check({ name: 'ProjectAgent', args: {} }, undefined),
|
||||
).toMatchObject({
|
||||
decision: PolicyDecision.ALLOW,
|
||||
});
|
||||
});
|
||||
|
||||
it('should register agents without hash (legacy/safe?)', async () => {
|
||||
|
||||
Reference in New Issue
Block a user