mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
feat(core): pause agent timeout budget while waiting for tool confirmation (#18415)
This commit is contained in:
@@ -27,6 +27,8 @@ export interface AgentSchedulingOptions {
|
||||
signal: AbortSignal;
|
||||
/** Optional function to get the preferred editor for tool modifications. */
|
||||
getPreferredEditor?: () => EditorType | undefined;
|
||||
/** Optional function to be notified when the scheduler is waiting for user confirmation. */
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -48,6 +50,7 @@ export async function scheduleAgentTools(
|
||||
toolRegistry,
|
||||
signal,
|
||||
getPreferredEditor,
|
||||
onWaitingForConfirmation,
|
||||
} = options;
|
||||
|
||||
// Create a proxy/override of the config to provide the agent-specific tool registry.
|
||||
@@ -60,6 +63,7 @@ export async function scheduleAgentTools(
|
||||
getPreferredEditor: getPreferredEditor ?? (() => undefined),
|
||||
schedulerId,
|
||||
parentCallId,
|
||||
onWaitingForConfirmation,
|
||||
});
|
||||
|
||||
return scheduler.schedule(requests, signal);
|
||||
|
||||
@@ -58,6 +58,7 @@ import { getModelConfigAlias } from './registry.js';
|
||||
import { getVersion } from '../utils/version.js';
|
||||
import { getToolCallContext } from '../utils/toolCallContext.js';
|
||||
import { scheduleAgentTools } from './agent-scheduler.js';
|
||||
import { DeadlineTimer } from '../utils/deadlineTimer.js';
|
||||
|
||||
/** A callback function to report on agent activity. */
|
||||
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
|
||||
@@ -231,6 +232,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
turnCounter: number,
|
||||
combinedSignal: AbortSignal,
|
||||
timeoutSignal: AbortSignal, // Pass the timeout controller's signal
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void,
|
||||
): Promise<AgentTurnResult> {
|
||||
const promptId = `${this.agentId}#${turnCounter}`;
|
||||
|
||||
@@ -265,7 +267,12 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
}
|
||||
|
||||
const { nextMessage, submittedOutput, taskCompleted } =
|
||||
await this.processFunctionCalls(functionCalls, combinedSignal, promptId);
|
||||
await this.processFunctionCalls(
|
||||
functionCalls,
|
||||
combinedSignal,
|
||||
promptId,
|
||||
onWaitingForConfirmation,
|
||||
);
|
||||
if (taskCompleted) {
|
||||
const finalResult = submittedOutput ?? 'Task completed successfully.';
|
||||
return {
|
||||
@@ -322,6 +329,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
| AgentTerminateMode.MAX_TURNS
|
||||
| AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
|
||||
externalSignal: AbortSignal, // The original signal passed to run()
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void,
|
||||
): Promise<string | null> {
|
||||
this.emitActivity('THOUGHT_CHUNK', {
|
||||
text: `Execution limit reached (${reason}). Attempting one final recovery turn with a grace period.`,
|
||||
@@ -355,6 +363,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
turnCounter, // This will be the "last" turn number
|
||||
combinedSignal,
|
||||
graceTimeoutController.signal, // Pass grace signal to identify a *grace* timeout
|
||||
onWaitingForConfirmation,
|
||||
);
|
||||
|
||||
if (
|
||||
@@ -415,14 +424,22 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
this.definition.runConfig.maxTimeMinutes ?? DEFAULT_MAX_TIME_MINUTES;
|
||||
const maxTurns = this.definition.runConfig.maxTurns ?? DEFAULT_MAX_TURNS;
|
||||
|
||||
const timeoutController = new AbortController();
|
||||
const timeoutId = setTimeout(
|
||||
() => timeoutController.abort(new Error('Agent timed out.')),
|
||||
const deadlineTimer = new DeadlineTimer(
|
||||
maxTimeMinutes * 60 * 1000,
|
||||
'Agent timed out.',
|
||||
);
|
||||
|
||||
// Track time spent waiting for user confirmation to credit it back to the agent.
|
||||
const onWaitingForConfirmation = (waiting: boolean) => {
|
||||
if (waiting) {
|
||||
deadlineTimer.pause();
|
||||
} else {
|
||||
deadlineTimer.resume();
|
||||
}
|
||||
};
|
||||
|
||||
// Combine the external signal with the internal timeout signal.
|
||||
const combinedSignal = AbortSignal.any([signal, timeoutController.signal]);
|
||||
const combinedSignal = AbortSignal.any([signal, deadlineTimer.signal]);
|
||||
|
||||
logAgentStart(
|
||||
this.runtimeContext,
|
||||
@@ -458,7 +475,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
// Check for timeout or external abort.
|
||||
if (combinedSignal.aborted) {
|
||||
// Determine which signal caused the abort.
|
||||
terminateReason = timeoutController.signal.aborted
|
||||
terminateReason = deadlineTimer.signal.aborted
|
||||
? AgentTerminateMode.TIMEOUT
|
||||
: AgentTerminateMode.ABORTED;
|
||||
break;
|
||||
@@ -469,7 +486,8 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
currentMessage,
|
||||
turnCounter++,
|
||||
combinedSignal,
|
||||
timeoutController.signal,
|
||||
deadlineTimer.signal,
|
||||
onWaitingForConfirmation,
|
||||
);
|
||||
|
||||
if (turnResult.status === 'stop') {
|
||||
@@ -498,6 +516,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
turnCounter, // Use current turnCounter for the recovery attempt
|
||||
terminateReason,
|
||||
signal, // Pass the external signal
|
||||
onWaitingForConfirmation,
|
||||
);
|
||||
|
||||
if (recoveryResult !== null) {
|
||||
@@ -551,7 +570,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
if (
|
||||
error instanceof Error &&
|
||||
error.name === 'AbortError' &&
|
||||
timeoutController.signal.aborted &&
|
||||
deadlineTimer.signal.aborted &&
|
||||
!signal.aborted // Ensure the external signal was not the cause
|
||||
) {
|
||||
terminateReason = AgentTerminateMode.TIMEOUT;
|
||||
@@ -563,6 +582,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
turnCounter, // Use current turnCounter
|
||||
AgentTerminateMode.TIMEOUT,
|
||||
signal,
|
||||
onWaitingForConfirmation,
|
||||
);
|
||||
|
||||
if (recoveryResult !== null) {
|
||||
@@ -591,7 +611,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
this.emitActivity('ERROR', { error: String(error) });
|
||||
throw error; // Re-throw other errors or external aborts.
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
deadlineTimer.abort();
|
||||
logAgentFinish(
|
||||
this.runtimeContext,
|
||||
new AgentFinishEvent(
|
||||
@@ -779,6 +799,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
functionCalls: FunctionCall[],
|
||||
signal: AbortSignal,
|
||||
promptId: string,
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void,
|
||||
): Promise<{
|
||||
nextMessage: Content;
|
||||
submittedOutput: string | null;
|
||||
@@ -979,6 +1000,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
parentCallId: this.parentCallId,
|
||||
toolRegistry: this.toolRegistry,
|
||||
signal,
|
||||
onWaitingForConfirmation,
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
@@ -109,9 +109,10 @@ export async function resolveConfirmation(
|
||||
modifier: ToolModificationHandler;
|
||||
getPreferredEditor: () => EditorType | undefined;
|
||||
schedulerId: string;
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void;
|
||||
},
|
||||
): Promise<ResolutionResult> {
|
||||
const { state } = deps;
|
||||
const { state, onWaitingForConfirmation } = deps;
|
||||
const callId = toolCall.request.callId;
|
||||
let outcome = ToolConfirmationOutcome.ModifyWithEditor;
|
||||
let lastDetails: SerializableConfirmationDetails | undefined;
|
||||
@@ -147,12 +148,14 @@ export async function resolveConfirmation(
|
||||
correlationId,
|
||||
});
|
||||
|
||||
onWaitingForConfirmation?.(true);
|
||||
const response = await waitForConfirmation(
|
||||
deps.messageBus,
|
||||
correlationId,
|
||||
signal,
|
||||
ideConfirmation,
|
||||
);
|
||||
onWaitingForConfirmation?.(false);
|
||||
outcome = response.outcome;
|
||||
|
||||
if ('onConfirm' in details && typeof details.onConfirm === 'function') {
|
||||
|
||||
@@ -51,6 +51,7 @@ export interface SchedulerOptions {
|
||||
getPreferredEditor: () => EditorType | undefined;
|
||||
schedulerId: string;
|
||||
parentCallId?: string;
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void;
|
||||
}
|
||||
|
||||
const createErrorResponse = (
|
||||
@@ -90,6 +91,7 @@ export class Scheduler {
|
||||
private readonly getPreferredEditor: () => EditorType | undefined;
|
||||
private readonly schedulerId: string;
|
||||
private readonly parentCallId?: string;
|
||||
private readonly onWaitingForConfirmation?: (waiting: boolean) => void;
|
||||
|
||||
private isProcessing = false;
|
||||
private isCancelling = false;
|
||||
@@ -101,6 +103,7 @@ export class Scheduler {
|
||||
this.getPreferredEditor = options.getPreferredEditor;
|
||||
this.schedulerId = options.schedulerId;
|
||||
this.parentCallId = options.parentCallId;
|
||||
this.onWaitingForConfirmation = options.onWaitingForConfirmation;
|
||||
this.state = new SchedulerStateManager(
|
||||
this.messageBus,
|
||||
this.schedulerId,
|
||||
@@ -437,6 +440,7 @@ export class Scheduler {
|
||||
modifier: this.modifier,
|
||||
getPreferredEditor: this.getPreferredEditor,
|
||||
schedulerId: this.schedulerId,
|
||||
onWaitingForConfirmation: this.onWaitingForConfirmation,
|
||||
});
|
||||
outcome = result.outcome;
|
||||
lastDetails = result.lastDetails;
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { Scheduler } from './scheduler.js';
|
||||
import { resolveConfirmation } from './confirmation.js';
|
||||
import { checkPolicy } from './policy.js';
|
||||
import { PolicyDecision } from '../policy/types.js';
|
||||
import { ToolConfirmationOutcome } from '../tools/tools.js';
|
||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { ToolCallRequestInfo } from './types.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
|
||||
vi.mock('./confirmation.js');
|
||||
vi.mock('./policy.js');
|
||||
|
||||
describe('Scheduler waiting callback', () => {
|
||||
let mockConfig: Config;
|
||||
let messageBus: MessageBus;
|
||||
let toolRegistry: ToolRegistry;
|
||||
let mockTool: MockTool;
|
||||
|
||||
beforeEach(() => {
|
||||
messageBus = createMockMessageBus();
|
||||
mockConfig = makeFakeConfig();
|
||||
|
||||
// Override methods to use our mocks
|
||||
vi.spyOn(mockConfig, 'getMessageBus').mockReturnValue(messageBus);
|
||||
|
||||
mockTool = new MockTool({ name: 'test_tool' });
|
||||
toolRegistry = new ToolRegistry(mockConfig, messageBus);
|
||||
vi.spyOn(mockConfig, 'getToolRegistry').mockReturnValue(toolRegistry);
|
||||
toolRegistry.registerTool(mockTool);
|
||||
|
||||
vi.mocked(checkPolicy).mockResolvedValue({
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
rule: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should trigger onWaitingForConfirmation callback', async () => {
|
||||
const onWaitingForConfirmation = vi.fn();
|
||||
const scheduler = new Scheduler({
|
||||
config: mockConfig,
|
||||
messageBus,
|
||||
getPreferredEditor: () => undefined,
|
||||
schedulerId: 'test-scheduler',
|
||||
onWaitingForConfirmation,
|
||||
});
|
||||
|
||||
vi.mocked(resolveConfirmation).mockResolvedValue({
|
||||
outcome: ToolConfirmationOutcome.ProceedOnce,
|
||||
});
|
||||
|
||||
const req: ToolCallRequestInfo = {
|
||||
callId: 'call-1',
|
||||
name: 'test_tool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
};
|
||||
|
||||
await scheduler.schedule(req, new AbortController().signal);
|
||||
|
||||
expect(resolveConfirmation).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
onWaitingForConfirmation,
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
82
packages/core/src/utils/deadlineTimer.test.ts
Normal file
82
packages/core/src/utils/deadlineTimer.test.ts
Normal file
@@ -0,0 +1,82 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { DeadlineTimer } from './deadlineTimer.js';
|
||||
|
||||
describe('DeadlineTimer', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should abort when timeout is reached', () => {
|
||||
const timer = new DeadlineTimer(1000);
|
||||
const signal = timer.signal;
|
||||
expect(signal.aborted).toBe(false);
|
||||
|
||||
vi.advanceTimersByTime(1000);
|
||||
expect(signal.aborted).toBe(true);
|
||||
expect(signal.reason).toBeInstanceOf(Error);
|
||||
expect((signal.reason as Error).message).toBe('Timeout exceeded.');
|
||||
});
|
||||
|
||||
it('should allow extending the deadline', () => {
|
||||
const timer = new DeadlineTimer(1000);
|
||||
const signal = timer.signal;
|
||||
|
||||
vi.advanceTimersByTime(500);
|
||||
expect(signal.aborted).toBe(false);
|
||||
|
||||
timer.extend(1000); // New deadline is 1000 + 1000 = 2000 from start
|
||||
|
||||
vi.advanceTimersByTime(600); // 1100 total
|
||||
expect(signal.aborted).toBe(false);
|
||||
|
||||
vi.advanceTimersByTime(900); // 2000 total
|
||||
expect(signal.aborted).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow pausing and resuming the timer', () => {
|
||||
const timer = new DeadlineTimer(1000);
|
||||
const signal = timer.signal;
|
||||
|
||||
vi.advanceTimersByTime(500);
|
||||
timer.pause();
|
||||
|
||||
vi.advanceTimersByTime(2000); // Wait a long time while paused
|
||||
expect(signal.aborted).toBe(false);
|
||||
|
||||
timer.resume();
|
||||
vi.advanceTimersByTime(400);
|
||||
expect(signal.aborted).toBe(false);
|
||||
|
||||
vi.advanceTimersByTime(200); // Total active time 500 + 400 + 200 = 1100
|
||||
expect(signal.aborted).toBe(true);
|
||||
});
|
||||
|
||||
it('should abort immediately when abort() is called', () => {
|
||||
const timer = new DeadlineTimer(1000);
|
||||
const signal = timer.signal;
|
||||
|
||||
timer.abort('cancelled');
|
||||
expect(signal.aborted).toBe(true);
|
||||
expect(signal.reason).toBe('cancelled');
|
||||
});
|
||||
|
||||
it('should not fire timeout if aborted manually', () => {
|
||||
const timer = new DeadlineTimer(1000);
|
||||
const signal = timer.signal;
|
||||
|
||||
timer.abort();
|
||||
vi.advanceTimersByTime(1000);
|
||||
// Already aborted, but shouldn't re-abort or throw
|
||||
expect(signal.aborted).toBe(true);
|
||||
});
|
||||
});
|
||||
94
packages/core/src/utils/deadlineTimer.ts
Normal file
94
packages/core/src/utils/deadlineTimer.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* A utility that manages a timeout and an AbortController, allowing the
|
||||
* timeout to be paused, resumed, and dynamically extended.
|
||||
*/
|
||||
export class DeadlineTimer {
|
||||
private readonly controller: AbortController;
|
||||
private timeoutId: NodeJS.Timeout | null = null;
|
||||
private remainingMs: number;
|
||||
private lastStartedAt: number;
|
||||
private isPaused = false;
|
||||
|
||||
constructor(timeoutMs: number, reason = 'Timeout exceeded.') {
|
||||
this.controller = new AbortController();
|
||||
this.remainingMs = timeoutMs;
|
||||
this.lastStartedAt = Date.now();
|
||||
this.schedule(timeoutMs, reason);
|
||||
}
|
||||
|
||||
/** The AbortSignal managed by this timer. */
|
||||
get signal(): AbortSignal {
|
||||
return this.controller.signal;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pauses the timer, clearing any active timeout.
|
||||
*/
|
||||
pause(): void {
|
||||
if (this.isPaused || this.controller.signal.aborted) return;
|
||||
|
||||
if (this.timeoutId) {
|
||||
clearTimeout(this.timeoutId);
|
||||
this.timeoutId = null;
|
||||
}
|
||||
|
||||
const elapsed = Date.now() - this.lastStartedAt;
|
||||
this.remainingMs = Math.max(0, this.remainingMs - elapsed);
|
||||
this.isPaused = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resumes the timer with the remaining budget.
|
||||
*/
|
||||
resume(reason = 'Timeout exceeded.'): void {
|
||||
if (!this.isPaused || this.controller.signal.aborted) return;
|
||||
|
||||
this.lastStartedAt = Date.now();
|
||||
this.schedule(this.remainingMs, reason);
|
||||
this.isPaused = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extends the current budget by the specified number of milliseconds.
|
||||
*/
|
||||
extend(ms: number, reason = 'Timeout exceeded.'): void {
|
||||
if (this.controller.signal.aborted) return;
|
||||
|
||||
if (this.isPaused) {
|
||||
this.remainingMs += ms;
|
||||
} else {
|
||||
if (this.timeoutId) {
|
||||
clearTimeout(this.timeoutId);
|
||||
}
|
||||
const elapsed = Date.now() - this.lastStartedAt;
|
||||
this.remainingMs = Math.max(0, this.remainingMs - elapsed) + ms;
|
||||
this.lastStartedAt = Date.now();
|
||||
this.schedule(this.remainingMs, reason);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Aborts the signal immediately and clears any pending timers.
|
||||
*/
|
||||
abort(reason?: unknown): void {
|
||||
if (this.timeoutId) {
|
||||
clearTimeout(this.timeoutId);
|
||||
this.timeoutId = null;
|
||||
}
|
||||
this.isPaused = false;
|
||||
this.controller.abort(reason);
|
||||
}
|
||||
|
||||
private schedule(ms: number, reason: string): void {
|
||||
this.timeoutId = setTimeout(() => {
|
||||
this.timeoutId = null;
|
||||
this.controller.abort(new Error(reason));
|
||||
}, ms);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user