diff --git a/packages/a2a-server/src/agent/executor.test.ts b/packages/a2a-server/src/agent/executor.test.ts index 2b77f3006c..f6e0c03fdb 100644 --- a/packages/a2a-server/src/agent/executor.test.ts +++ b/packages/a2a-server/src/agent/executor.test.ts @@ -22,6 +22,7 @@ vi.mock('../config/config.js', () => ({ getCheckpointingEnabled: () => false, }), loadEnvironment: vi.fn(), + setIsTrusted: vi.fn().mockReturnValue(false), setTargetDir: vi.fn().mockReturnValue('/tmp'), })); @@ -62,6 +63,12 @@ vi.mock('./task.js', () => { scheduleToolCalls: vi.fn().mockResolvedValue(undefined), waitForPendingTools: vi.fn().mockResolvedValue(undefined), getAndClearCompletedTools: vi.fn().mockReturnValue([]), + get hasPendingTools() { + return false; + }, + get pendingToolsCount() { + return 0; + }, addToolResponsesToHistory: vi.fn(), sendCompletedToolsToLlm: vi.fn().mockImplementation(async function* () {}), cancelPendingTools: vi.fn(), @@ -245,4 +252,52 @@ describe('CoderAgentExecutor', () => { expect(executor.getTask(taskId)).toBeUndefined(); expect(wrapper.task.dispose).toHaveBeenCalled(); }); + + it('should yield the turn and transition to input-required if tools are pending', async () => { + const taskId = 'test-task-pending-tools'; + const contextId = 'test-context'; + + const mockSocket = new EventEmitter(); + (requestStorage.getStore as Mock).mockReturnValue({ + req: { socket: mockSocket }, + }); + + // Pre-create the task to safely modify its mocked methods before execution + const wrapper = await executor.createTask( + taskId, + contextId, + undefined, + mockEventBus, + ); + const hasPendingToolsSpy = vi + .spyOn(wrapper.task, 'hasPendingTools', 'get') + .mockReturnValue(true); + vi.spyOn(wrapper.task, 'pendingToolsCount', 'get').mockReturnValue(1); + + const requestContext = { + userMessage: { + messageId: 'msg-1', + taskId, + contextId, + parts: [{ kind: 'confirmation', callId: '1', outcome: 'proceed' }], + metadata: { + coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' }, + }, + }, + } as unknown as RequestContext; + + await executor.execute(requestContext, mockEventBus); + + // Assert that the executor yielded the turn correctly without further progression + expect(hasPendingToolsSpy).toHaveBeenCalled(); + expect(wrapper.task.getAndClearCompletedTools).not.toHaveBeenCalled(); + expect(wrapper.task.sendCompletedToolsToLlm).not.toHaveBeenCalled(); + expect(wrapper.task.setTaskStateAndPublishUpdate).toHaveBeenCalledWith( + 'input-required', + expect.any(Object), + undefined, + undefined, + true, + ); + }); }); diff --git a/packages/a2a-server/src/agent/executor.ts b/packages/a2a-server/src/agent/executor.ts index 97f4010a89..a7adc925a9 100644 --- a/packages/a2a-server/src/agent/executor.ts +++ b/packages/a2a-server/src/agent/executor.ts @@ -31,7 +31,12 @@ import { getContextIdFromMetadata, getAgentSettingsFromMetadata, } from '../types.js'; -import { loadConfig, loadEnvironment, setTargetDir } from '../config/config.js'; +import { + loadConfig, + loadEnvironment, + setIsTrusted, + setTargetDir, +} from '../config/config.js'; import { loadSettings } from '../config/settings.js'; import { loadExtensions } from '../config/extension.js'; import { Task } from './task.js'; @@ -93,8 +98,8 @@ export class CoderAgentExecutor implements AgentExecutor { taskId: string, ): Promise { const workspaceRoot = setTargetDir(agentSettings); - const isTrusted = agentSettings.isTrusted ?? false; loadEnvironment(); // Will override any global env with workspace envs + const isTrusted = setIsTrusted(agentSettings); const settings = loadSettings(workspaceRoot, isTrusted); const extensions = loadExtensions(workspaceRoot); return loadConfig( @@ -541,42 +546,49 @@ export class CoderAgentExecutor implements AgentExecutor { if (abortSignal.aborted) throw new Error('Execution aborted'); - const completedTools = currentTask.getAndClearCompletedTools(); - - if (completedTools.length > 0) { - // If all completed tool calls were canceled, manually add them to history and set state to input-required, final:true - if (completedTools.every((tool) => tool.status === 'cancelled')) { - logger.info( - `[CoderAgentExecutor] Task ${taskId}: All tool calls were cancelled. Updating history and ending agent turn.`, - ); - currentTask.addToolResponsesToHistory(completedTools); - agentTurnActive = false; - const stateChange: StateChange = { - kind: CoderAgentEvent.StateChangeEvent, - }; - currentTask.setTaskStateAndPublishUpdate( - 'input-required', - stateChange, - undefined, - undefined, - true, - ); - } else { - logger.info( - `[CoderAgentExecutor] Task ${taskId}: Found ${completedTools.length} completed tool calls. Sending results back to LLM.`, - ); - - agentEvents = currentTask.sendCompletedToolsToLlm( - completedTools, - abortSignal, - ); - // Continue the loop to process the LLM response to the tool results. - } - } else { + if (currentTask.hasPendingTools) { logger.info( - `[CoderAgentExecutor] Task ${taskId}: No more tool calls to process. Ending agent turn.`, + `[CoderAgentExecutor] Task ${taskId}: There are still ${currentTask.pendingToolsCount} pending tools waiting for approval. Yielding to user.`, ); agentTurnActive = false; + } else { + const completedTools = currentTask.getAndClearCompletedTools(); + + if (completedTools.length > 0) { + // If all completed tool calls were canceled, manually add them to history and set state to input-required, final:true + if (completedTools.every((tool) => tool.status === 'cancelled')) { + logger.info( + `[CoderAgentExecutor] Task ${taskId}: All tool calls were cancelled. Updating history and ending agent turn.`, + ); + currentTask.addToolResponsesToHistory(completedTools); + agentTurnActive = false; + const stateChange: StateChange = { + kind: CoderAgentEvent.StateChangeEvent, + }; + currentTask.setTaskStateAndPublishUpdate( + 'input-required', + stateChange, + undefined, + undefined, + true, + ); + } else { + logger.info( + `[CoderAgentExecutor] Task ${taskId}: Found ${completedTools.length} completed tool calls. Sending results back to LLM.`, + ); + + agentEvents = currentTask.sendCompletedToolsToLlm( + completedTools, + abortSignal, + ); + // Continue the loop to process the LLM response to the tool results. + } + } else { + logger.info( + `[CoderAgentExecutor] Task ${taskId}: No more tool calls to process. Ending agent turn.`, + ); + agentTurnActive = false; + } } } diff --git a/packages/a2a-server/src/agent/task.test.ts b/packages/a2a-server/src/agent/task.test.ts index 1d0e010c2a..5eb5098aeb 100644 --- a/packages/a2a-server/src/agent/task.test.ts +++ b/packages/a2a-server/src/agent/task.test.ts @@ -631,6 +631,35 @@ describe('Task', () => { expect(handleEventDrivenToolCallSpy).toHaveBeenCalled(); }); + + describe('Pending Tools state', () => { + it('should correctly report pending tools presence and count', () => { + const mockConfig = createMockConfig(); + const mockEventBus: ExecutionEventBus = { + publish: vi.fn(), + on: vi.fn(), + off: vi.fn(), + once: vi.fn(), + removeAllListeners: vi.fn(), + finished: vi.fn(), + }; + + // @ts-expect-error - Calling private constructor + const task = new Task( + 'task-id', + 'context-id', + mockConfig as Config, + mockEventBus, + ); + + expect(task.hasPendingTools).toBe(false); + expect(task.pendingToolsCount).toBe(0); + + task['_registerToolCall']('tool-1', 'scheduled'); + expect(task.hasPendingTools).toBe(true); + expect(task.pendingToolsCount).toBe(1); + }); + }); }); describe('Serialization and Mapping', () => { diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index ee13adc8cb..ac1ced0482 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -137,6 +137,14 @@ export class Task { ); } + get hasPendingTools(): boolean { + return this.pendingToolCalls.size > 0; + } + + get pendingToolsCount(): number { + return this.pendingToolCalls.size; + } + static async create( id: string, contextId: string, diff --git a/packages/a2a-server/src/config/config.test.ts b/packages/a2a-server/src/config/config.test.ts index 9f890d854c..d12bf534fb 100644 --- a/packages/a2a-server/src/config/config.test.ts +++ b/packages/a2a-server/src/config/config.test.ts @@ -23,6 +23,7 @@ import { PRIORITY_YOLO_ALLOW_ALL, createPolicyEngineConfig, } from '@google/gemini-cli-core'; +import type { AgentSettings } from '../types.js'; // Mock dependencies vi.mock('@google/gemini-cli-core', async (importOriginal) => { @@ -612,3 +613,34 @@ describe('loadConfig', () => { }); }); }); + +describe('setIsTrusted', () => { + beforeEach(() => { + vi.resetModules(); + }); + + afterEach(() => { + vi.unstubAllEnvs(); + }); + + it('should return true when GEMINI_FOLDER_TRUST env var is true', async () => { + vi.stubEnv('GEMINI_FOLDER_TRUST', 'true'); + const { setIsTrusted } = await import('./config.js'); + expect(setIsTrusted(undefined)).toBe(true); + expect(setIsTrusted({ isTrusted: false } as AgentSettings)).toBe(true); + }); + + it('should return false when GEMINI_FOLDER_TRUST env var is false', async () => { + vi.stubEnv('GEMINI_FOLDER_TRUST', 'false'); + const { setIsTrusted } = await import('./config.js'); + expect(setIsTrusted(undefined)).toBe(false); + expect(setIsTrusted({ isTrusted: true } as AgentSettings)).toBe(false); + }); + + it('should fallback to agentSettings.isTrusted if env var is undefined', async () => { + const { setIsTrusted } = await import('./config.js'); + expect(setIsTrusted({ isTrusted: true } as AgentSettings)).toBe(true); + expect(setIsTrusted({ isTrusted: false } as AgentSettings)).toBe(false); + expect(setIsTrusted(undefined)).toBe(false); + }); +}); diff --git a/packages/a2a-server/src/config/config.ts b/packages/a2a-server/src/config/config.ts index 3071827240..3f937f0a5c 100644 --- a/packages/a2a-server/src/config/config.ts +++ b/packages/a2a-server/src/config/config.ts @@ -34,6 +34,8 @@ import { logger } from '../utils/logger.js'; import type { Settings } from './settings.js'; import { type AgentSettings, CoderAgentEvent } from '../types.js'; +const INITIAL_FOLDER_TRUST = process.env['GEMINI_FOLDER_TRUST']; + export async function loadConfig( settings: Settings, extensionLoader: ExtensionLoader, @@ -182,6 +184,15 @@ export async function loadConfig( return config; } +export function setIsTrusted( + agentSettings: AgentSettings | undefined, +): boolean { + if (INITIAL_FOLDER_TRUST !== undefined) { + return INITIAL_FOLDER_TRUST === 'true'; + } + return !!agentSettings?.isTrusted; +} + export function setTargetDir(agentSettings: AgentSettings | undefined): string { const originalCWD = process.cwd(); const targetDir = diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index a6e0294962..5a4f986951 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -26,7 +26,10 @@ import { type ScheduledToolCall, } from './types.js'; import { ToolErrorType } from '../tools/tool-error.js'; -import { UPDATE_TOPIC_TOOL_NAME } from '../tools/tool-names.js'; +import { + UPDATE_TOPIC_TOOL_NAME, + EDIT_TOOL_NAMES, +} from '../tools/tool-names.js'; import { PolicyDecision, type ApprovalMode } from '../policy/types.js'; import { ToolConfirmationOutcome, @@ -548,7 +551,10 @@ export class Scheduler { private _isParallelizable(request: ToolCallRequestInfo): boolean { // update_topic tool is forced as sequential call - if (request.name === UPDATE_TOPIC_TOOL_NAME) { + if ( + request.name === UPDATE_TOPIC_TOOL_NAME || + EDIT_TOOL_NAMES.has(request.name) + ) { return false; } if (request.args) { diff --git a/packages/core/src/scheduler/scheduler_parallel.test.ts b/packages/core/src/scheduler/scheduler_parallel.test.ts index eef052d707..e29c1c0b04 100644 --- a/packages/core/src/scheduler/scheduler_parallel.test.ts +++ b/packages/core/src/scheduler/scheduler_parallel.test.ts @@ -79,7 +79,12 @@ import { type Status, type ToolCall, } from './types.js'; -import { UPDATE_TOPIC_TOOL_NAME } from '../tools/tool-names.js'; +import { + UPDATE_TOPIC_TOOL_NAME, + WRITE_FILE_TOOL_NAME, + EDIT_TOOL_NAME, + EDIT_TOOL_NAMES, +} from '../tools/tool-names.js'; import { GeminiCliOperation } from '../telemetry/constants.js'; import type { EditorType } from '../utils/editor.js'; @@ -161,6 +166,12 @@ describe('Scheduler Parallel Execution', () => { isReadOnly: false, build: vi.fn(), } as unknown as AnyDeclarativeTool; + const editTool = { + name: EDIT_TOOL_NAME, + kind: Kind.Execute, + isReadOnly: false, + build: vi.fn(), + } as unknown as AnyDeclarativeTool; const agentTool1 = { name: 'agent-tool-1', kind: Kind.Agent, @@ -203,6 +214,8 @@ describe('Scheduler Parallel Execution', () => { if (name === 'agent-tool-1') return agentTool1; if (name === 'agent-tool-2') return agentTool2; if (name === UPDATE_TOPIC_TOOL_NAME) return topicTool; + if (name === WRITE_FILE_TOOL_NAME) return writeTool; + if (name === EDIT_TOOL_NAME) return editTool; return undefined; }), getAllToolNames: vi @@ -214,6 +227,8 @@ describe('Scheduler Parallel Execution', () => { 'agent-tool-1', 'agent-tool-2', UPDATE_TOPIC_TOOL_NAME, + WRITE_FILE_TOOL_NAME, + EDIT_TOOL_NAME, ]), } as unknown as Mocked; @@ -336,6 +351,9 @@ describe('Scheduler Parallel Execution', () => { vi.mocked(writeTool.build).mockReturnValue( mockInvocation as unknown as AnyToolInvocation, ); + vi.mocked(editTool.build).mockReturnValue( + mockInvocation as unknown as AnyToolInvocation, + ); vi.mocked(agentTool1.build).mockReturnValue( mockInvocation as unknown as AnyToolInvocation, ); @@ -597,4 +615,44 @@ describe('Scheduler Parallel Execution', () => { expect(executionLog.slice(2, 4)).toContain('start-call-1'); expect(executionLog.slice(2, 4)).toContain('start-call-2'); }); + + it.each(Array.from(EDIT_TOOL_NAMES))( + 'should execute %s sequentially even without wait_for_previous', + async (toolName) => { + const executionLog: string[] = []; + mockExecutor.execute.mockImplementation(async ({ call }) => { + const id = call.request.callId; + executionLog.push(`start-${id}`); + await new Promise((resolve) => setTimeout(resolve, 10)); + executionLog.push(`end-${id}`); + return { + status: 'success', + response: { callId: id, responseParts: [] }, + } as unknown as SuccessfulToolCall; + }); + + const e1: ToolCallRequestInfo = { + callId: 'e1', + name: toolName, + args: { path: 'a.txt', wait_for_previous: false }, + isClientInitiated: false, + prompt_id: 'p1', + schedulerId: ROOT_SCHEDULER_ID, + }; + const e2: ToolCallRequestInfo = { + ...e1, + callId: 'e2', + }; + + await scheduler.schedule([e1, e2], signal); + + // Even though wait_for_previous is false, EDIT_TOOL_NAMES enforces sequential execution + expect(executionLog).toEqual([ + 'start-e1', + 'end-e1', + 'start-e2', + 'end-e2', + ]); + }, + ); });