Fix/pending tools and trust overrides (#27854)

This commit is contained in:
jvargassanchez-dot
2026-06-15 22:24:50 +00:00
committed by GitHub
parent bca5667fc6
commit 0f8a157e5e
8 changed files with 249 additions and 38 deletions
@@ -22,6 +22,7 @@ vi.mock('../config/config.js', () => ({
getCheckpointingEnabled: () => false,
}),
loadEnvironment: vi.fn(),
setIsTrusted: vi.fn().mockReturnValue(false),
setTargetDir: vi.fn().mockReturnValue('/tmp'),
}));
@@ -62,6 +63,12 @@ vi.mock('./task.js', () => {
scheduleToolCalls: vi.fn().mockResolvedValue(undefined),
waitForPendingTools: vi.fn().mockResolvedValue(undefined),
getAndClearCompletedTools: vi.fn().mockReturnValue([]),
get hasPendingTools() {
return false;
},
get pendingToolsCount() {
return 0;
},
addToolResponsesToHistory: vi.fn(),
sendCompletedToolsToLlm: vi.fn().mockImplementation(async function* () {}),
cancelPendingTools: vi.fn(),
@@ -245,4 +252,52 @@ describe('CoderAgentExecutor', () => {
expect(executor.getTask(taskId)).toBeUndefined();
expect(wrapper.task.dispose).toHaveBeenCalled();
});
it('should yield the turn and transition to input-required if tools are pending', async () => {
const taskId = 'test-task-pending-tools';
const contextId = 'test-context';
const mockSocket = new EventEmitter();
(requestStorage.getStore as Mock).mockReturnValue({
req: { socket: mockSocket },
});
// Pre-create the task to safely modify its mocked methods before execution
const wrapper = await executor.createTask(
taskId,
contextId,
undefined,
mockEventBus,
);
const hasPendingToolsSpy = vi
.spyOn(wrapper.task, 'hasPendingTools', 'get')
.mockReturnValue(true);
vi.spyOn(wrapper.task, 'pendingToolsCount', 'get').mockReturnValue(1);
const requestContext = {
userMessage: {
messageId: 'msg-1',
taskId,
contextId,
parts: [{ kind: 'confirmation', callId: '1', outcome: 'proceed' }],
metadata: {
coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' },
},
},
} as unknown as RequestContext;
await executor.execute(requestContext, mockEventBus);
// Assert that the executor yielded the turn correctly without further progression
expect(hasPendingToolsSpy).toHaveBeenCalled();
expect(wrapper.task.getAndClearCompletedTools).not.toHaveBeenCalled();
expect(wrapper.task.sendCompletedToolsToLlm).not.toHaveBeenCalled();
expect(wrapper.task.setTaskStateAndPublishUpdate).toHaveBeenCalledWith(
'input-required',
expect.any(Object),
undefined,
undefined,
true,
);
});
});
+47 -35
View File
@@ -31,7 +31,12 @@ import {
getContextIdFromMetadata,
getAgentSettingsFromMetadata,
} from '../types.js';
import { loadConfig, loadEnvironment, setTargetDir } from '../config/config.js';
import {
loadConfig,
loadEnvironment,
setIsTrusted,
setTargetDir,
} from '../config/config.js';
import { loadSettings } from '../config/settings.js';
import { loadExtensions } from '../config/extension.js';
import { Task } from './task.js';
@@ -93,8 +98,8 @@ export class CoderAgentExecutor implements AgentExecutor {
taskId: string,
): Promise<Config> {
const workspaceRoot = setTargetDir(agentSettings);
const isTrusted = agentSettings.isTrusted ?? false;
loadEnvironment(); // Will override any global env with workspace envs
const isTrusted = setIsTrusted(agentSettings);
const settings = loadSettings(workspaceRoot, isTrusted);
const extensions = loadExtensions(workspaceRoot);
return loadConfig(
@@ -541,42 +546,49 @@ export class CoderAgentExecutor implements AgentExecutor {
if (abortSignal.aborted) throw new Error('Execution aborted');
const completedTools = currentTask.getAndClearCompletedTools();
if (completedTools.length > 0) {
// If all completed tool calls were canceled, manually add them to history and set state to input-required, final:true
if (completedTools.every((tool) => tool.status === 'cancelled')) {
logger.info(
`[CoderAgentExecutor] Task ${taskId}: All tool calls were cancelled. Updating history and ending agent turn.`,
);
currentTask.addToolResponsesToHistory(completedTools);
agentTurnActive = false;
const stateChange: StateChange = {
kind: CoderAgentEvent.StateChangeEvent,
};
currentTask.setTaskStateAndPublishUpdate(
'input-required',
stateChange,
undefined,
undefined,
true,
);
} else {
logger.info(
`[CoderAgentExecutor] Task ${taskId}: Found ${completedTools.length} completed tool calls. Sending results back to LLM.`,
);
agentEvents = currentTask.sendCompletedToolsToLlm(
completedTools,
abortSignal,
);
// Continue the loop to process the LLM response to the tool results.
}
} else {
if (currentTask.hasPendingTools) {
logger.info(
`[CoderAgentExecutor] Task ${taskId}: No more tool calls to process. Ending agent turn.`,
`[CoderAgentExecutor] Task ${taskId}: There are still ${currentTask.pendingToolsCount} pending tools waiting for approval. Yielding to user.`,
);
agentTurnActive = false;
} else {
const completedTools = currentTask.getAndClearCompletedTools();
if (completedTools.length > 0) {
// If all completed tool calls were canceled, manually add them to history and set state to input-required, final:true
if (completedTools.every((tool) => tool.status === 'cancelled')) {
logger.info(
`[CoderAgentExecutor] Task ${taskId}: All tool calls were cancelled. Updating history and ending agent turn.`,
);
currentTask.addToolResponsesToHistory(completedTools);
agentTurnActive = false;
const stateChange: StateChange = {
kind: CoderAgentEvent.StateChangeEvent,
};
currentTask.setTaskStateAndPublishUpdate(
'input-required',
stateChange,
undefined,
undefined,
true,
);
} else {
logger.info(
`[CoderAgentExecutor] Task ${taskId}: Found ${completedTools.length} completed tool calls. Sending results back to LLM.`,
);
agentEvents = currentTask.sendCompletedToolsToLlm(
completedTools,
abortSignal,
);
// Continue the loop to process the LLM response to the tool results.
}
} else {
logger.info(
`[CoderAgentExecutor] Task ${taskId}: No more tool calls to process. Ending agent turn.`,
);
agentTurnActive = false;
}
}
}
@@ -631,6 +631,35 @@ describe('Task', () => {
expect(handleEventDrivenToolCallSpy).toHaveBeenCalled();
});
describe('Pending Tools state', () => {
it('should correctly report pending tools presence and count', () => {
const mockConfig = createMockConfig();
const mockEventBus: ExecutionEventBus = {
publish: vi.fn(),
on: vi.fn(),
off: vi.fn(),
once: vi.fn(),
removeAllListeners: vi.fn(),
finished: vi.fn(),
};
// @ts-expect-error - Calling private constructor
const task = new Task(
'task-id',
'context-id',
mockConfig as Config,
mockEventBus,
);
expect(task.hasPendingTools).toBe(false);
expect(task.pendingToolsCount).toBe(0);
task['_registerToolCall']('tool-1', 'scheduled');
expect(task.hasPendingTools).toBe(true);
expect(task.pendingToolsCount).toBe(1);
});
});
});
describe('Serialization and Mapping', () => {
+8
View File
@@ -137,6 +137,14 @@ export class Task {
);
}
get hasPendingTools(): boolean {
return this.pendingToolCalls.size > 0;
}
get pendingToolsCount(): number {
return this.pendingToolCalls.size;
}
static async create(
id: string,
contextId: string,
@@ -23,6 +23,7 @@ import {
PRIORITY_YOLO_ALLOW_ALL,
createPolicyEngineConfig,
} from '@google/gemini-cli-core';
import type { AgentSettings } from '../types.js';
// Mock dependencies
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
@@ -612,3 +613,34 @@ describe('loadConfig', () => {
});
});
});
describe('setIsTrusted', () => {
beforeEach(() => {
vi.resetModules();
});
afterEach(() => {
vi.unstubAllEnvs();
});
it('should return true when GEMINI_FOLDER_TRUST env var is true', async () => {
vi.stubEnv('GEMINI_FOLDER_TRUST', 'true');
const { setIsTrusted } = await import('./config.js');
expect(setIsTrusted(undefined)).toBe(true);
expect(setIsTrusted({ isTrusted: false } as AgentSettings)).toBe(true);
});
it('should return false when GEMINI_FOLDER_TRUST env var is false', async () => {
vi.stubEnv('GEMINI_FOLDER_TRUST', 'false');
const { setIsTrusted } = await import('./config.js');
expect(setIsTrusted(undefined)).toBe(false);
expect(setIsTrusted({ isTrusted: true } as AgentSettings)).toBe(false);
});
it('should fallback to agentSettings.isTrusted if env var is undefined', async () => {
const { setIsTrusted } = await import('./config.js');
expect(setIsTrusted({ isTrusted: true } as AgentSettings)).toBe(true);
expect(setIsTrusted({ isTrusted: false } as AgentSettings)).toBe(false);
expect(setIsTrusted(undefined)).toBe(false);
});
});
+11
View File
@@ -34,6 +34,8 @@ import { logger } from '../utils/logger.js';
import type { Settings } from './settings.js';
import { type AgentSettings, CoderAgentEvent } from '../types.js';
const INITIAL_FOLDER_TRUST = process.env['GEMINI_FOLDER_TRUST'];
export async function loadConfig(
settings: Settings,
extensionLoader: ExtensionLoader,
@@ -182,6 +184,15 @@ export async function loadConfig(
return config;
}
export function setIsTrusted(
agentSettings: AgentSettings | undefined,
): boolean {
if (INITIAL_FOLDER_TRUST !== undefined) {
return INITIAL_FOLDER_TRUST === 'true';
}
return !!agentSettings?.isTrusted;
}
export function setTargetDir(agentSettings: AgentSettings | undefined): string {
const originalCWD = process.cwd();
const targetDir =
+8 -2
View File
@@ -26,7 +26,10 @@ import {
type ScheduledToolCall,
} from './types.js';
import { ToolErrorType } from '../tools/tool-error.js';
import { UPDATE_TOPIC_TOOL_NAME } from '../tools/tool-names.js';
import {
UPDATE_TOPIC_TOOL_NAME,
EDIT_TOOL_NAMES,
} from '../tools/tool-names.js';
import { PolicyDecision, type ApprovalMode } from '../policy/types.js';
import {
ToolConfirmationOutcome,
@@ -548,7 +551,10 @@ export class Scheduler {
private _isParallelizable(request: ToolCallRequestInfo): boolean {
// update_topic tool is forced as sequential call
if (request.name === UPDATE_TOPIC_TOOL_NAME) {
if (
request.name === UPDATE_TOPIC_TOOL_NAME ||
EDIT_TOOL_NAMES.has(request.name)
) {
return false;
}
if (request.args) {
@@ -79,7 +79,12 @@ import {
type Status,
type ToolCall,
} from './types.js';
import { UPDATE_TOPIC_TOOL_NAME } from '../tools/tool-names.js';
import {
UPDATE_TOPIC_TOOL_NAME,
WRITE_FILE_TOOL_NAME,
EDIT_TOOL_NAME,
EDIT_TOOL_NAMES,
} from '../tools/tool-names.js';
import { GeminiCliOperation } from '../telemetry/constants.js';
import type { EditorType } from '../utils/editor.js';
@@ -161,6 +166,12 @@ describe('Scheduler Parallel Execution', () => {
isReadOnly: false,
build: vi.fn(),
} as unknown as AnyDeclarativeTool;
const editTool = {
name: EDIT_TOOL_NAME,
kind: Kind.Execute,
isReadOnly: false,
build: vi.fn(),
} as unknown as AnyDeclarativeTool;
const agentTool1 = {
name: 'agent-tool-1',
kind: Kind.Agent,
@@ -203,6 +214,8 @@ describe('Scheduler Parallel Execution', () => {
if (name === 'agent-tool-1') return agentTool1;
if (name === 'agent-tool-2') return agentTool2;
if (name === UPDATE_TOPIC_TOOL_NAME) return topicTool;
if (name === WRITE_FILE_TOOL_NAME) return writeTool;
if (name === EDIT_TOOL_NAME) return editTool;
return undefined;
}),
getAllToolNames: vi
@@ -214,6 +227,8 @@ describe('Scheduler Parallel Execution', () => {
'agent-tool-1',
'agent-tool-2',
UPDATE_TOPIC_TOOL_NAME,
WRITE_FILE_TOOL_NAME,
EDIT_TOOL_NAME,
]),
} as unknown as Mocked<ToolRegistry>;
@@ -336,6 +351,9 @@ describe('Scheduler Parallel Execution', () => {
vi.mocked(writeTool.build).mockReturnValue(
mockInvocation as unknown as AnyToolInvocation,
);
vi.mocked(editTool.build).mockReturnValue(
mockInvocation as unknown as AnyToolInvocation,
);
vi.mocked(agentTool1.build).mockReturnValue(
mockInvocation as unknown as AnyToolInvocation,
);
@@ -597,4 +615,44 @@ describe('Scheduler Parallel Execution', () => {
expect(executionLog.slice(2, 4)).toContain('start-call-1');
expect(executionLog.slice(2, 4)).toContain('start-call-2');
});
it.each(Array.from(EDIT_TOOL_NAMES))(
'should execute %s sequentially even without wait_for_previous',
async (toolName) => {
const executionLog: string[] = [];
mockExecutor.execute.mockImplementation(async ({ call }) => {
const id = call.request.callId;
executionLog.push(`start-${id}`);
await new Promise<void>((resolve) => setTimeout(resolve, 10));
executionLog.push(`end-${id}`);
return {
status: 'success',
response: { callId: id, responseParts: [] },
} as unknown as SuccessfulToolCall;
});
const e1: ToolCallRequestInfo = {
callId: 'e1',
name: toolName,
args: { path: 'a.txt', wait_for_previous: false },
isClientInitiated: false,
prompt_id: 'p1',
schedulerId: ROOT_SCHEDULER_ID,
};
const e2: ToolCallRequestInfo = {
...e1,
callId: 'e2',
};
await scheduler.schedule([e1, e2], signal);
// Even though wait_for_previous is false, EDIT_TOOL_NAMES enforces sequential execution
expect(executionLog).toEqual([
'start-e1',
'end-e1',
'start-e2',
'end-e2',
]);
},
);
});