mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-15 22:07:29 -07:00
Fix/pending tools and trust overrides (#27854)
This commit is contained in:
committed by
GitHub
parent
bca5667fc6
commit
0f8a157e5e
@@ -22,6 +22,7 @@ vi.mock('../config/config.js', () => ({
|
||||
getCheckpointingEnabled: () => false,
|
||||
}),
|
||||
loadEnvironment: vi.fn(),
|
||||
setIsTrusted: vi.fn().mockReturnValue(false),
|
||||
setTargetDir: vi.fn().mockReturnValue('/tmp'),
|
||||
}));
|
||||
|
||||
@@ -62,6 +63,12 @@ vi.mock('./task.js', () => {
|
||||
scheduleToolCalls: vi.fn().mockResolvedValue(undefined),
|
||||
waitForPendingTools: vi.fn().mockResolvedValue(undefined),
|
||||
getAndClearCompletedTools: vi.fn().mockReturnValue([]),
|
||||
get hasPendingTools() {
|
||||
return false;
|
||||
},
|
||||
get pendingToolsCount() {
|
||||
return 0;
|
||||
},
|
||||
addToolResponsesToHistory: vi.fn(),
|
||||
sendCompletedToolsToLlm: vi.fn().mockImplementation(async function* () {}),
|
||||
cancelPendingTools: vi.fn(),
|
||||
@@ -245,4 +252,52 @@ describe('CoderAgentExecutor', () => {
|
||||
expect(executor.getTask(taskId)).toBeUndefined();
|
||||
expect(wrapper.task.dispose).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should yield the turn and transition to input-required if tools are pending', async () => {
|
||||
const taskId = 'test-task-pending-tools';
|
||||
const contextId = 'test-context';
|
||||
|
||||
const mockSocket = new EventEmitter();
|
||||
(requestStorage.getStore as Mock).mockReturnValue({
|
||||
req: { socket: mockSocket },
|
||||
});
|
||||
|
||||
// Pre-create the task to safely modify its mocked methods before execution
|
||||
const wrapper = await executor.createTask(
|
||||
taskId,
|
||||
contextId,
|
||||
undefined,
|
||||
mockEventBus,
|
||||
);
|
||||
const hasPendingToolsSpy = vi
|
||||
.spyOn(wrapper.task, 'hasPendingTools', 'get')
|
||||
.mockReturnValue(true);
|
||||
vi.spyOn(wrapper.task, 'pendingToolsCount', 'get').mockReturnValue(1);
|
||||
|
||||
const requestContext = {
|
||||
userMessage: {
|
||||
messageId: 'msg-1',
|
||||
taskId,
|
||||
contextId,
|
||||
parts: [{ kind: 'confirmation', callId: '1', outcome: 'proceed' }],
|
||||
metadata: {
|
||||
coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' },
|
||||
},
|
||||
},
|
||||
} as unknown as RequestContext;
|
||||
|
||||
await executor.execute(requestContext, mockEventBus);
|
||||
|
||||
// Assert that the executor yielded the turn correctly without further progression
|
||||
expect(hasPendingToolsSpy).toHaveBeenCalled();
|
||||
expect(wrapper.task.getAndClearCompletedTools).not.toHaveBeenCalled();
|
||||
expect(wrapper.task.sendCompletedToolsToLlm).not.toHaveBeenCalled();
|
||||
expect(wrapper.task.setTaskStateAndPublishUpdate).toHaveBeenCalledWith(
|
||||
'input-required',
|
||||
expect.any(Object),
|
||||
undefined,
|
||||
undefined,
|
||||
true,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -31,7 +31,12 @@ import {
|
||||
getContextIdFromMetadata,
|
||||
getAgentSettingsFromMetadata,
|
||||
} from '../types.js';
|
||||
import { loadConfig, loadEnvironment, setTargetDir } from '../config/config.js';
|
||||
import {
|
||||
loadConfig,
|
||||
loadEnvironment,
|
||||
setIsTrusted,
|
||||
setTargetDir,
|
||||
} from '../config/config.js';
|
||||
import { loadSettings } from '../config/settings.js';
|
||||
import { loadExtensions } from '../config/extension.js';
|
||||
import { Task } from './task.js';
|
||||
@@ -93,8 +98,8 @@ export class CoderAgentExecutor implements AgentExecutor {
|
||||
taskId: string,
|
||||
): Promise<Config> {
|
||||
const workspaceRoot = setTargetDir(agentSettings);
|
||||
const isTrusted = agentSettings.isTrusted ?? false;
|
||||
loadEnvironment(); // Will override any global env with workspace envs
|
||||
const isTrusted = setIsTrusted(agentSettings);
|
||||
const settings = loadSettings(workspaceRoot, isTrusted);
|
||||
const extensions = loadExtensions(workspaceRoot);
|
||||
return loadConfig(
|
||||
@@ -541,42 +546,49 @@ export class CoderAgentExecutor implements AgentExecutor {
|
||||
|
||||
if (abortSignal.aborted) throw new Error('Execution aborted');
|
||||
|
||||
const completedTools = currentTask.getAndClearCompletedTools();
|
||||
|
||||
if (completedTools.length > 0) {
|
||||
// If all completed tool calls were canceled, manually add them to history and set state to input-required, final:true
|
||||
if (completedTools.every((tool) => tool.status === 'cancelled')) {
|
||||
logger.info(
|
||||
`[CoderAgentExecutor] Task ${taskId}: All tool calls were cancelled. Updating history and ending agent turn.`,
|
||||
);
|
||||
currentTask.addToolResponsesToHistory(completedTools);
|
||||
agentTurnActive = false;
|
||||
const stateChange: StateChange = {
|
||||
kind: CoderAgentEvent.StateChangeEvent,
|
||||
};
|
||||
currentTask.setTaskStateAndPublishUpdate(
|
||||
'input-required',
|
||||
stateChange,
|
||||
undefined,
|
||||
undefined,
|
||||
true,
|
||||
);
|
||||
} else {
|
||||
logger.info(
|
||||
`[CoderAgentExecutor] Task ${taskId}: Found ${completedTools.length} completed tool calls. Sending results back to LLM.`,
|
||||
);
|
||||
|
||||
agentEvents = currentTask.sendCompletedToolsToLlm(
|
||||
completedTools,
|
||||
abortSignal,
|
||||
);
|
||||
// Continue the loop to process the LLM response to the tool results.
|
||||
}
|
||||
} else {
|
||||
if (currentTask.hasPendingTools) {
|
||||
logger.info(
|
||||
`[CoderAgentExecutor] Task ${taskId}: No more tool calls to process. Ending agent turn.`,
|
||||
`[CoderAgentExecutor] Task ${taskId}: There are still ${currentTask.pendingToolsCount} pending tools waiting for approval. Yielding to user.`,
|
||||
);
|
||||
agentTurnActive = false;
|
||||
} else {
|
||||
const completedTools = currentTask.getAndClearCompletedTools();
|
||||
|
||||
if (completedTools.length > 0) {
|
||||
// If all completed tool calls were canceled, manually add them to history and set state to input-required, final:true
|
||||
if (completedTools.every((tool) => tool.status === 'cancelled')) {
|
||||
logger.info(
|
||||
`[CoderAgentExecutor] Task ${taskId}: All tool calls were cancelled. Updating history and ending agent turn.`,
|
||||
);
|
||||
currentTask.addToolResponsesToHistory(completedTools);
|
||||
agentTurnActive = false;
|
||||
const stateChange: StateChange = {
|
||||
kind: CoderAgentEvent.StateChangeEvent,
|
||||
};
|
||||
currentTask.setTaskStateAndPublishUpdate(
|
||||
'input-required',
|
||||
stateChange,
|
||||
undefined,
|
||||
undefined,
|
||||
true,
|
||||
);
|
||||
} else {
|
||||
logger.info(
|
||||
`[CoderAgentExecutor] Task ${taskId}: Found ${completedTools.length} completed tool calls. Sending results back to LLM.`,
|
||||
);
|
||||
|
||||
agentEvents = currentTask.sendCompletedToolsToLlm(
|
||||
completedTools,
|
||||
abortSignal,
|
||||
);
|
||||
// Continue the loop to process the LLM response to the tool results.
|
||||
}
|
||||
} else {
|
||||
logger.info(
|
||||
`[CoderAgentExecutor] Task ${taskId}: No more tool calls to process. Ending agent turn.`,
|
||||
);
|
||||
agentTurnActive = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -631,6 +631,35 @@ describe('Task', () => {
|
||||
|
||||
expect(handleEventDrivenToolCallSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe('Pending Tools state', () => {
|
||||
it('should correctly report pending tools presence and count', () => {
|
||||
const mockConfig = createMockConfig();
|
||||
const mockEventBus: ExecutionEventBus = {
|
||||
publish: vi.fn(),
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
once: vi.fn(),
|
||||
removeAllListeners: vi.fn(),
|
||||
finished: vi.fn(),
|
||||
};
|
||||
|
||||
// @ts-expect-error - Calling private constructor
|
||||
const task = new Task(
|
||||
'task-id',
|
||||
'context-id',
|
||||
mockConfig as Config,
|
||||
mockEventBus,
|
||||
);
|
||||
|
||||
expect(task.hasPendingTools).toBe(false);
|
||||
expect(task.pendingToolsCount).toBe(0);
|
||||
|
||||
task['_registerToolCall']('tool-1', 'scheduled');
|
||||
expect(task.hasPendingTools).toBe(true);
|
||||
expect(task.pendingToolsCount).toBe(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Serialization and Mapping', () => {
|
||||
|
||||
@@ -137,6 +137,14 @@ export class Task {
|
||||
);
|
||||
}
|
||||
|
||||
get hasPendingTools(): boolean {
|
||||
return this.pendingToolCalls.size > 0;
|
||||
}
|
||||
|
||||
get pendingToolsCount(): number {
|
||||
return this.pendingToolCalls.size;
|
||||
}
|
||||
|
||||
static async create(
|
||||
id: string,
|
||||
contextId: string,
|
||||
|
||||
@@ -23,6 +23,7 @@ import {
|
||||
PRIORITY_YOLO_ALLOW_ALL,
|
||||
createPolicyEngineConfig,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { AgentSettings } from '../types.js';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
@@ -612,3 +613,34 @@ describe('loadConfig', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('setIsTrusted', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetModules();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
it('should return true when GEMINI_FOLDER_TRUST env var is true', async () => {
|
||||
vi.stubEnv('GEMINI_FOLDER_TRUST', 'true');
|
||||
const { setIsTrusted } = await import('./config.js');
|
||||
expect(setIsTrusted(undefined)).toBe(true);
|
||||
expect(setIsTrusted({ isTrusted: false } as AgentSettings)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false when GEMINI_FOLDER_TRUST env var is false', async () => {
|
||||
vi.stubEnv('GEMINI_FOLDER_TRUST', 'false');
|
||||
const { setIsTrusted } = await import('./config.js');
|
||||
expect(setIsTrusted(undefined)).toBe(false);
|
||||
expect(setIsTrusted({ isTrusted: true } as AgentSettings)).toBe(false);
|
||||
});
|
||||
|
||||
it('should fallback to agentSettings.isTrusted if env var is undefined', async () => {
|
||||
const { setIsTrusted } = await import('./config.js');
|
||||
expect(setIsTrusted({ isTrusted: true } as AgentSettings)).toBe(true);
|
||||
expect(setIsTrusted({ isTrusted: false } as AgentSettings)).toBe(false);
|
||||
expect(setIsTrusted(undefined)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -34,6 +34,8 @@ import { logger } from '../utils/logger.js';
|
||||
import type { Settings } from './settings.js';
|
||||
import { type AgentSettings, CoderAgentEvent } from '../types.js';
|
||||
|
||||
const INITIAL_FOLDER_TRUST = process.env['GEMINI_FOLDER_TRUST'];
|
||||
|
||||
export async function loadConfig(
|
||||
settings: Settings,
|
||||
extensionLoader: ExtensionLoader,
|
||||
@@ -182,6 +184,15 @@ export async function loadConfig(
|
||||
return config;
|
||||
}
|
||||
|
||||
export function setIsTrusted(
|
||||
agentSettings: AgentSettings | undefined,
|
||||
): boolean {
|
||||
if (INITIAL_FOLDER_TRUST !== undefined) {
|
||||
return INITIAL_FOLDER_TRUST === 'true';
|
||||
}
|
||||
return !!agentSettings?.isTrusted;
|
||||
}
|
||||
|
||||
export function setTargetDir(agentSettings: AgentSettings | undefined): string {
|
||||
const originalCWD = process.cwd();
|
||||
const targetDir =
|
||||
|
||||
@@ -26,7 +26,10 @@ import {
|
||||
type ScheduledToolCall,
|
||||
} from './types.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import { UPDATE_TOPIC_TOOL_NAME } from '../tools/tool-names.js';
|
||||
import {
|
||||
UPDATE_TOPIC_TOOL_NAME,
|
||||
EDIT_TOOL_NAMES,
|
||||
} from '../tools/tool-names.js';
|
||||
import { PolicyDecision, type ApprovalMode } from '../policy/types.js';
|
||||
import {
|
||||
ToolConfirmationOutcome,
|
||||
@@ -548,7 +551,10 @@ export class Scheduler {
|
||||
|
||||
private _isParallelizable(request: ToolCallRequestInfo): boolean {
|
||||
// update_topic tool is forced as sequential call
|
||||
if (request.name === UPDATE_TOPIC_TOOL_NAME) {
|
||||
if (
|
||||
request.name === UPDATE_TOPIC_TOOL_NAME ||
|
||||
EDIT_TOOL_NAMES.has(request.name)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
if (request.args) {
|
||||
|
||||
@@ -79,7 +79,12 @@ import {
|
||||
type Status,
|
||||
type ToolCall,
|
||||
} from './types.js';
|
||||
import { UPDATE_TOPIC_TOOL_NAME } from '../tools/tool-names.js';
|
||||
import {
|
||||
UPDATE_TOPIC_TOOL_NAME,
|
||||
WRITE_FILE_TOOL_NAME,
|
||||
EDIT_TOOL_NAME,
|
||||
EDIT_TOOL_NAMES,
|
||||
} from '../tools/tool-names.js';
|
||||
import { GeminiCliOperation } from '../telemetry/constants.js';
|
||||
import type { EditorType } from '../utils/editor.js';
|
||||
|
||||
@@ -161,6 +166,12 @@ describe('Scheduler Parallel Execution', () => {
|
||||
isReadOnly: false,
|
||||
build: vi.fn(),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
const editTool = {
|
||||
name: EDIT_TOOL_NAME,
|
||||
kind: Kind.Execute,
|
||||
isReadOnly: false,
|
||||
build: vi.fn(),
|
||||
} as unknown as AnyDeclarativeTool;
|
||||
const agentTool1 = {
|
||||
name: 'agent-tool-1',
|
||||
kind: Kind.Agent,
|
||||
@@ -203,6 +214,8 @@ describe('Scheduler Parallel Execution', () => {
|
||||
if (name === 'agent-tool-1') return agentTool1;
|
||||
if (name === 'agent-tool-2') return agentTool2;
|
||||
if (name === UPDATE_TOPIC_TOOL_NAME) return topicTool;
|
||||
if (name === WRITE_FILE_TOOL_NAME) return writeTool;
|
||||
if (name === EDIT_TOOL_NAME) return editTool;
|
||||
return undefined;
|
||||
}),
|
||||
getAllToolNames: vi
|
||||
@@ -214,6 +227,8 @@ describe('Scheduler Parallel Execution', () => {
|
||||
'agent-tool-1',
|
||||
'agent-tool-2',
|
||||
UPDATE_TOPIC_TOOL_NAME,
|
||||
WRITE_FILE_TOOL_NAME,
|
||||
EDIT_TOOL_NAME,
|
||||
]),
|
||||
} as unknown as Mocked<ToolRegistry>;
|
||||
|
||||
@@ -336,6 +351,9 @@ describe('Scheduler Parallel Execution', () => {
|
||||
vi.mocked(writeTool.build).mockReturnValue(
|
||||
mockInvocation as unknown as AnyToolInvocation,
|
||||
);
|
||||
vi.mocked(editTool.build).mockReturnValue(
|
||||
mockInvocation as unknown as AnyToolInvocation,
|
||||
);
|
||||
vi.mocked(agentTool1.build).mockReturnValue(
|
||||
mockInvocation as unknown as AnyToolInvocation,
|
||||
);
|
||||
@@ -597,4 +615,44 @@ describe('Scheduler Parallel Execution', () => {
|
||||
expect(executionLog.slice(2, 4)).toContain('start-call-1');
|
||||
expect(executionLog.slice(2, 4)).toContain('start-call-2');
|
||||
});
|
||||
|
||||
it.each(Array.from(EDIT_TOOL_NAMES))(
|
||||
'should execute %s sequentially even without wait_for_previous',
|
||||
async (toolName) => {
|
||||
const executionLog: string[] = [];
|
||||
mockExecutor.execute.mockImplementation(async ({ call }) => {
|
||||
const id = call.request.callId;
|
||||
executionLog.push(`start-${id}`);
|
||||
await new Promise<void>((resolve) => setTimeout(resolve, 10));
|
||||
executionLog.push(`end-${id}`);
|
||||
return {
|
||||
status: 'success',
|
||||
response: { callId: id, responseParts: [] },
|
||||
} as unknown as SuccessfulToolCall;
|
||||
});
|
||||
|
||||
const e1: ToolCallRequestInfo = {
|
||||
callId: 'e1',
|
||||
name: toolName,
|
||||
args: { path: 'a.txt', wait_for_previous: false },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
schedulerId: ROOT_SCHEDULER_ID,
|
||||
};
|
||||
const e2: ToolCallRequestInfo = {
|
||||
...e1,
|
||||
callId: 'e2',
|
||||
};
|
||||
|
||||
await scheduler.schedule([e1, e2], signal);
|
||||
|
||||
// Even though wait_for_previous is false, EDIT_TOOL_NAMES enforces sequential execution
|
||||
expect(executionLog).toEqual([
|
||||
'start-e1',
|
||||
'end-e1',
|
||||
'start-e2',
|
||||
'end-e2',
|
||||
]);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user