feat(cli): A2A Server Primary Agent Support

- Added a new `/agents use <name>` command to allow users to set a remote agent as the primary agent for the current session.
- Implemented `A2AStreamingAdapter` to seamlessly integrate A2A server task polling with the existing CLI streaming UX.
- Mapped A2A server task status updates to Gemini `Thought` events, providing native thinking visualization in the terminal.
- Updated `nonInteractiveCli` and `useGeminiStream` to dynamically route requests through the `A2AStreamingAdapter` when a session primary agent is configured.
- Ensured context history continuity by maintaining `contextId` and `taskId` across requests within the adapter.
This commit is contained in:
Taylor Mullen
2026-02-23 16:37:37 -08:00
parent 8b1dc15182
commit 38e99a4556
23 changed files with 1395 additions and 20 deletions
File diff suppressed because one or more lines are too long
@@ -93,6 +93,7 @@ describe('ExtensionManager theme loading', () => {
startExtension: vi.fn().mockResolvedValue(undefined),
}),
getGeminiClient: () => ({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
isInitialized: () => false,
updateSystemInstruction: vi.fn(),
setTools: vi.fn(),
@@ -203,6 +204,7 @@ describe('ExtensionManager theme loading', () => {
setGeminiMdFilePaths: vi.fn(),
getEnableExtensionReloading: () => true,
getGeminiClient: () => ({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
isInitialized: () => false,
updateSystemInstruction: vi.fn(),
setTools: vi.fn(),
@@ -173,6 +173,7 @@ describe('runNonInteractive', () => {
publish: vi.fn(),
}),
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getMaxSessionTurns: vi.fn().mockReturnValue(10),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
+24 -10
View File
@@ -9,9 +9,8 @@ import type {
ToolCallRequestInfo,
ResumedSessionData,
UserFeedbackPayload,
ServerGeminiStreamEvent,
} from '@google/gemini-cli-core';
import { isSlashCommand } from './ui/utils/commandUtils.js';
import type { LoadedSettings } from './config/settings.js';
import {
convertSessionToClientHistory,
GeminiEventType,
@@ -30,7 +29,10 @@ import {
ToolErrorType,
Scheduler,
ROOT_SCHEDULER_ID,
A2AStreamingAdapter,
} from '@google/gemini-cli-core';
import { isSlashCommand } from './ui/utils/commandUtils.js';
import type { LoadedSettings } from './config/settings.js';
import type { Content, Part } from '@google/genai';
import readline from 'node:readline';
@@ -299,14 +301,26 @@ export async function runNonInteractive({
}
const toolCallRequests: ToolCallRequestInfo[] = [];
const responseStream = geminiClient.sendMessageStream(
currentMessages[0]?.parts || [],
abortController.signal,
prompt_id,
undefined,
false,
turnCount === 1 ? input : undefined,
);
let responseStream: AsyncIterable<ServerGeminiStreamEvent>;
const primaryAgent = config.getSessionPrimaryAgent();
if (primaryAgent) {
const adapter = new A2AStreamingAdapter(config);
responseStream = adapter.sendMessageStream(
primaryAgent,
currentMessages[0]?.parts || [],
abortController.signal,
prompt_id,
);
} else {
responseStream = geminiClient.sendMessageStream(
currentMessages[0]?.parts || [],
abortController.signal,
prompt_id,
undefined,
false,
turnCount === 1 ? input : undefined,
);
}
let responseText = '';
for await (const event of responseStream) {
@@ -85,8 +85,11 @@ export const createMockConfig = (overrides: Partial<Config> = {}): Config =>
getTelemetryUseCollector: vi.fn().mockReturnValue(false),
getTelemetryUseCliAuth: vi.fn().mockReturnValue(false),
getGeminiClient: vi.fn().mockReturnValue({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
isInitialized: vi.fn().mockReturnValue(true),
}),
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
setSessionPrimaryAgent: vi.fn(),
updateSystemInstructionIfInitialized: vi.fn().mockResolvedValue(undefined),
getModelRouterService: vi.fn().mockReturnValue({}),
getModelAvailabilityService: vi.fn().mockReturnValue({}),
@@ -265,6 +265,72 @@ async function configAction(
};
}
async function useAction(
context: CommandContext,
args: string,
): Promise<SlashCommandActionReturn | void> {
const { config } = context.services;
if (!config) {
return {
type: 'message',
messageType: 'error',
content: 'Config not loaded.',
};
}
const agentName = args.trim();
if (!agentName) {
return {
type: 'message',
messageType: 'error',
content: 'Usage: /agents use <agent-name> (or /agents use default)',
};
}
if (agentName.toLowerCase() === 'default') {
config.setSessionPrimaryAgent(null);
return {
type: 'message',
messageType: 'info',
content: 'Reverted to default primary agent.',
};
}
const agentRegistry = config.getAgentRegistry();
if (!agentRegistry) {
return {
type: 'message',
messageType: 'error',
content: 'Agent registry not found.',
};
}
const definition = agentRegistry.getDiscoveredDefinition(agentName);
if (!definition) {
return {
type: 'message',
messageType: 'error',
content: `Agent '${agentName}' not found.`,
};
}
if (definition.kind !== 'remote') {
return {
type: 'message',
messageType: 'error',
content: `Agent '${agentName}' is not a remote (A2A) agent. Only remote agents can be set as the primary agent currently.`,
};
}
config.setSessionPrimaryAgent(agentName);
return {
type: 'message',
messageType: 'info',
content: `Set primary agent to '${agentName}' for this session.`,
};
}
function completeAgentsToEnable(context: CommandContext, partialArg: string) {
const { config, settings } = context.services;
if (!config) return [];
@@ -295,6 +361,20 @@ function completeAllAgents(context: CommandContext, partialArg: string) {
return allAgents.filter((name: string) => name.startsWith(partialArg));
}
function completeRemoteAgents(context: CommandContext, partialArg: string) {
const { config } = context.services;
if (!config) return [];
const agentRegistry = config.getAgentRegistry();
if (!agentRegistry) return ['default'];
const remoteAgents = agentRegistry
.getAllDefinitions()
.filter((d) => d.kind === 'remote')
.map((d) => d.name);
remoteAgents.push('default');
return remoteAgents.filter((name: string) => name.startsWith(partialArg));
}
const enableCommand: SlashCommand = {
name: 'enable',
description: 'Enable a disabled agent',
@@ -322,6 +402,15 @@ const configCommand: SlashCommand = {
completion: completeAllAgents,
};
const useCommand: SlashCommand = {
name: 'use',
description: 'Use a specific agent for the current session',
kind: CommandKind.BUILT_IN,
autoExecute: false,
action: useAction,
completion: completeRemoteAgents,
};
const agentsRefreshCommand: SlashCommand = {
name: 'refresh',
altNames: ['reload'],
@@ -363,6 +452,7 @@ export const agentsCommand: SlashCommand = {
enableCommand,
disableCommand,
configCommand,
useCommand,
],
action: async (context: CommandContext, args) =>
// Default to list if no subcommand is provided
@@ -26,6 +26,7 @@ describe('authCommand', () => {
services: {
config: {
getGeminiClient: vi.fn(),
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
},
},
});
@@ -88,6 +88,7 @@ describe('bugCommand', () => {
getBugCommand: () => undefined,
getIdeMode: () => true,
getGeminiClient: () => ({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
getChat: () => ({
getHistory: () => [],
}),
@@ -131,6 +132,7 @@ describe('bugCommand', () => {
getBugCommand: () => undefined,
getIdeMode: () => true,
getGeminiClient: () => ({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
getChat: () => ({
getHistory: () => history,
}),
@@ -177,6 +179,7 @@ describe('bugCommand', () => {
getBugCommand: () => ({ urlTemplate: customTemplate }),
getIdeMode: () => true,
getGeminiClient: () => ({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
getChat: () => ({
getHistory: () => [],
}),
@@ -32,6 +32,7 @@ describe('copyCommand', () => {
services: {
config: {
getGeminiClient: () => ({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
getChat: mockGetChat,
}),
},
@@ -66,6 +66,7 @@ describe('directoryCommand', () => {
getWorkspaceContext: () => mockWorkspaceContext,
isRestrictiveSandbox: vi.fn().mockReturnValue(false),
getGeminiClient: vi.fn().mockReturnValue({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
addDirectoryContext: vi.fn(),
getChatRecordingService: vi.fn().mockReturnValue({
recordDirectories: vi.fn(),
@@ -48,6 +48,7 @@ describe('restoreCommand', () => {
getProjectTempDir: vi.fn().mockReturnValue(geminiTempDir),
},
getGeminiClient: vi.fn().mockReturnValue({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
setHistory: mockSetHistory,
}),
} as unknown as Config;
@@ -99,6 +99,7 @@ describe('rewindCommand', () => {
services: {
config: {
getGeminiClient: () => ({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
getChatRecordingService: mockGetChatRecordingService,
setHistory: mockSetHistory,
sendMessageStream: mockSendMessageStream,
@@ -293,7 +294,10 @@ describe('rewindCommand', () => {
it('should fail if client is not initialized', () => {
const context = createMockCommandContext({
services: {
config: { getGeminiClient: () => undefined },
config: {
getGeminiClient: () => undefined,
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
},
},
}) as unknown as CommandContext;
@@ -311,6 +315,7 @@ describe('rewindCommand', () => {
services: {
config: {
getGeminiClient: () => ({ getChatRecordingService: () => undefined }),
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
},
},
}) as unknown as CommandContext;
@@ -95,6 +95,7 @@ describe('useCommandCompletion', () => {
const mockCommandContext = {} as CommandContext;
const mockConfig = {
getGeminiClient: vi.fn(),
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
} as unknown as Config;
const testRootDir = '/';
@@ -499,6 +500,7 @@ describe('useCommandCompletion', () => {
it('should not trigger prompt completion for line comments', async () => {
const mockConfig = {
getGeminiClient: vi.fn(),
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
} as unknown as Config;
let hookResult: ReturnType<typeof useCommandCompletion> & {
@@ -531,6 +533,7 @@ describe('useCommandCompletion', () => {
it('should not trigger prompt completion for block comments', async () => {
const mockConfig = {
getGeminiClient: vi.fn(),
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
} as unknown as Config;
let hookResult: ReturnType<typeof useCommandCompletion> & {
@@ -565,6 +568,7 @@ describe('useCommandCompletion', () => {
it('should trigger prompt completion for regular text when enabled', async () => {
const mockConfig = {
getGeminiClient: vi.fn(),
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
} as unknown as Config;
let hookResult: ReturnType<typeof useCommandCompletion> & {
@@ -246,6 +246,7 @@ describe('useGeminiStream', () => {
getProjectRoot: vi.fn(() => '/test/dir'),
getCheckpointingEnabled: vi.fn(() => false),
getGeminiClient: mockGetGeminiClient,
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
getMcpClientManager: () => mockMcpClientManager as any,
getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
getUsageStatisticsEnabled: () => true,
+21 -8
View File
@@ -36,6 +36,7 @@ import {
CoreToolCallStatus,
buildUserSteeringHintPrompt,
generateSteeringAckMessage,
A2AStreamingAdapter,
} from '@google/gemini-cli-core';
import type {
Config,
@@ -1324,14 +1325,26 @@ export const useGeminiStream = (
lastPromptIdRef.current = prompt_id!;
try {
const stream = geminiClient.sendMessageStream(
queryToSend,
abortSignal,
prompt_id!,
undefined,
false,
query,
);
let stream: AsyncIterable<GeminiEvent>;
const primaryAgent = config.getSessionPrimaryAgent();
if (primaryAgent) {
const adapter = new A2AStreamingAdapter(config);
stream = adapter.sendMessageStream(
primaryAgent,
queryToSend,
abortSignal,
prompt_id!,
);
} else {
stream = geminiClient.sendMessageStream(
queryToSend,
abortSignal,
prompt_id!,
undefined,
false,
query,
);
}
const processingStatus = await processGeminiStreamEvents(
stream,
userMessageTimestamp,
@@ -54,6 +54,7 @@ describe('useSessionBrowser', () => {
setSessionId: vi.fn(),
getSessionId: vi.fn(),
getGeminiClient: vi.fn().mockReturnValue({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
getChatRecordingService: vi.fn().mockReturnValue({
deleteSession: vi.fn(),
}),
@@ -25,6 +25,7 @@ describe('useSessionResume', () => {
const mockConfig = {
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
};
const createMockHistoryManager = (): UseHistoryManagerReturn => ({
@@ -286,6 +287,7 @@ describe('useSessionResume', () => {
const newMockConfig = {
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
};
rerender({ config: newMockConfig as unknown as Config });
@@ -66,6 +66,7 @@ describe('GeminiAgent Session Resume', () => {
getFileSystemService: vi.fn(),
setFileSystemService: vi.fn(),
getGeminiClient: vi.fn().mockReturnValue({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
initialize: vi.fn().mockResolvedValue(undefined),
resumeChat: vi.fn().mockResolvedValue(undefined),
getChat: vi.fn().mockReturnValue({}),
@@ -113,6 +113,7 @@ describe('GeminiAgent', () => {
getActiveModel: vi.fn().mockReturnValue('gemini-pro'),
getModel: vi.fn().mockReturnValue('gemini-pro'),
getGeminiClient: vi.fn().mockReturnValue({
getSessionPrimaryAgent: vi.fn().mockReturnValue(null),
startChat: vi.fn().mockResolvedValue({}),
}),
getMessageBus: vi.fn().mockReturnValue({
@@ -0,0 +1,199 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { FinishReason, type PartListUnion } from '@google/genai';
import { type ServerGeminiStreamEvent, GeminiEventType } from '../core/turn.js';
import { A2AClientManager } from './a2a-client-manager.js';
import { extractTaskText, extractIdsFromResponse } from './a2aUtils.js';
import { ADCHandler } from './remote-invocation.js';
import type { Config } from '../config/config.js';
import type { Task } from '@a2a-js/sdk';
import type { ThoughtSummary } from '../utils/thoughtUtils.js';
export class A2AStreamingAdapter {
private static sessionState = new Map<
string,
{ contextId?: string; taskId?: string }
>();
constructor(private readonly config: Config) {}
async *sendMessageStream(
agentName: string,
request: PartListUnion,
signal: AbortSignal,
prompt_id: string,
): AsyncGenerator<ServerGeminiStreamEvent> {
const clientManager = A2AClientManager.getInstance();
const registry = this.config.getAgentRegistry();
if (!registry) {
yield {
type: GeminiEventType.Error,
value: { error: new Error('Agent registry not found.') },
};
return;
}
const definition = registry.getDiscoveredDefinition(agentName);
if (!definition || definition.kind !== 'remote') {
yield {
type: GeminiEventType.Error,
value: { error: new Error(`Remote agent '${agentName}' not found.`) },
};
return;
}
// Determine query
const { partToString } = await import('../utils/partUtils.js');
const queryText = partToString(request);
// Load agent if needed
if (!clientManager.getClient(agentName)) {
try {
await clientManager.loadAgent(
agentName,
definition.agentCardUrl,
new ADCHandler(),
);
} catch (err) {
yield {
type: GeminiEventType.Error,
value: {
error: new Error(`Failed to load agent ${agentName}: ${err}`),
},
};
return;
}
}
const priorState = A2AStreamingAdapter.sessionState.get(agentName);
const contextId = priorState?.contextId;
const taskId = priorState?.taskId;
const client = clientManager.getClient(agentName);
if (!client) {
yield {
type: GeminiEventType.Error,
value: { error: new Error(`Client for ${agentName} not initialized.`) },
};
return;
}
try {
let currentTask: Task | undefined;
const response = await client.sendMessage({
message: {
kind: 'message',
role: 'user',
messageId: prompt_id,
parts: [{ kind: 'text', text: queryText }],
contextId,
taskId,
},
configuration: { blocking: false }, // NON-BLOCKING!
});
let currentTaskId = taskId;
if (response.kind === 'message') {
A2AStreamingAdapter.sessionState.set(
agentName,
extractIdsFromResponse(response),
);
if (response.parts && response.parts.length > 0) {
const finalMessageText = response.parts
.map((p) => (p.kind === 'text' ? p.text : ''))
.join('\\n');
if (finalMessageText) {
yield {
type: GeminiEventType.Content,
value: finalMessageText,
traceId: response.messageId,
};
}
}
} else if (response.kind === 'task') {
currentTask = response;
currentTaskId = response.id;
A2AStreamingAdapter.sessionState.set(agentName, {
contextId: response.contextId,
taskId: response.id,
});
}
if (currentTask && currentTaskId) {
let lastStatusMessage = '';
while (true) {
if (signal.aborted) {
await client.cancelTask({ id: currentTaskId });
yield { type: GeminiEventType.UserCancelled };
return;
}
currentTask = await client.getTask({ id: currentTaskId });
const state = currentTask.status?.state;
const statusMessage =
currentTask.status?.message?.parts
?.map((p) => (p.kind === 'text' ? p.text : ''))
.join('\\n') || '';
if (statusMessage && statusMessage !== lastStatusMessage) {
const thought: ThoughtSummary = {
subject: 'Remote Action',
description: statusMessage,
};
yield {
type: GeminiEventType.Thought,
value: thought,
};
lastStatusMessage = statusMessage;
}
if (
state === 'completed' ||
state === 'failed' ||
state === 'canceled' ||
state === 'input-required'
) {
A2AStreamingAdapter.sessionState.set(
agentName,
extractIdsFromResponse(currentTask),
);
break;
}
await new Promise((resolve) => setTimeout(resolve, 1000));
}
const finalOutput = extractTaskText(currentTask);
if (finalOutput) {
yield {
type: GeminiEventType.Content,
value: finalOutput,
traceId: currentTask.id,
};
}
// Also fire finished event when complete.
if (currentTask.status?.state === 'completed') {
yield {
type: GeminiEventType.Finished,
value: {
reason: FinishReason.STOP,
usageMetadata: undefined,
},
};
}
}
} catch (err) {
yield {
type: GeminiEventType.Error,
value: { error: new Error(`A2A Execution Error: ${err}`) },
};
}
}
}
+4 -1
View File
@@ -122,7 +122,10 @@ export function extractIdsFromResponse(result: Message | Task): {
let taskId: string | undefined;
if (result.kind === 'message') {
taskId = result.taskId;
// We explicitly DO NOT return the taskId for a 'message' response.
// In the A2A SDK, when a server returns a final Message instead of a Task,
// the task is implicitly complete and should not be passed to subsequent requests,
// otherwise the server will throw a TaskNotFoundError.
contextId = result.contextId;
} else if (result.kind === 'task') {
taskId = result.id;
+9
View File
@@ -701,6 +701,7 @@ export class Config {
private lastModeSwitchTime: number = performance.now();
readonly userHintService: UserHintService;
private approvedPlanPath: string | undefined;
private sessionPrimaryAgent: string | null = null;
constructor(params: ConfigParameters) {
this.sessionId = params.sessionId;
@@ -1936,6 +1937,14 @@ export class Config {
return this.telemetrySettings.useCliAuth ?? false;
}
getSessionPrimaryAgent(): string | null {
return this.sessionPrimaryAgent;
}
setSessionPrimaryAgent(agentName: string | null): void {
this.sessionPrimaryAgent = agentName;
}
getGeminiClient(): GeminiClient {
return this.geminiClient;
}
+1
View File
@@ -195,6 +195,7 @@ export * from './hooks/types.js';
// Export agent types
export * from './agents/types.js';
export * from './agents/a2aStreamingAdapter.js';
// Export stdio utils
export * from './utils/stdio.js';