feat(core): pause agent timeout budget while waiting for tool confirmation (#18415)

This commit is contained in:
Abhi
2026-02-07 23:03:47 -05:00
committed by GitHub
parent bc8ffa6631
commit 11951592aa
7 changed files with 299 additions and 10 deletions

View File

@@ -27,6 +27,8 @@ export interface AgentSchedulingOptions {
signal: AbortSignal;
/** Optional function to get the preferred editor for tool modifications. */
getPreferredEditor?: () => EditorType | undefined;
/** Optional function to be notified when the scheduler is waiting for user confirmation. */
onWaitingForConfirmation?: (waiting: boolean) => void;
}
/**
@@ -48,6 +50,7 @@ export async function scheduleAgentTools(
toolRegistry,
signal,
getPreferredEditor,
onWaitingForConfirmation,
} = options;
// Create a proxy/override of the config to provide the agent-specific tool registry.
@@ -60,6 +63,7 @@ export async function scheduleAgentTools(
getPreferredEditor: getPreferredEditor ?? (() => undefined),
schedulerId,
parentCallId,
onWaitingForConfirmation,
});
return scheduler.schedule(requests, signal);

View File

@@ -58,6 +58,7 @@ import { getModelConfigAlias } from './registry.js';
import { getVersion } from '../utils/version.js';
import { getToolCallContext } from '../utils/toolCallContext.js';
import { scheduleAgentTools } from './agent-scheduler.js';
import { DeadlineTimer } from '../utils/deadlineTimer.js';
/** A callback function to report on agent activity. */
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
@@ -231,6 +232,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
turnCounter: number,
combinedSignal: AbortSignal,
timeoutSignal: AbortSignal, // Pass the timeout controller's signal
onWaitingForConfirmation?: (waiting: boolean) => void,
): Promise<AgentTurnResult> {
const promptId = `${this.agentId}#${turnCounter}`;
@@ -265,7 +267,12 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}
const { nextMessage, submittedOutput, taskCompleted } =
await this.processFunctionCalls(functionCalls, combinedSignal, promptId);
await this.processFunctionCalls(
functionCalls,
combinedSignal,
promptId,
onWaitingForConfirmation,
);
if (taskCompleted) {
const finalResult = submittedOutput ?? 'Task completed successfully.';
return {
@@ -322,6 +329,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
| AgentTerminateMode.MAX_TURNS
| AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
externalSignal: AbortSignal, // The original signal passed to run()
onWaitingForConfirmation?: (waiting: boolean) => void,
): Promise<string | null> {
this.emitActivity('THOUGHT_CHUNK', {
text: `Execution limit reached (${reason}). Attempting one final recovery turn with a grace period.`,
@@ -355,6 +363,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
turnCounter, // This will be the "last" turn number
combinedSignal,
graceTimeoutController.signal, // Pass grace signal to identify a *grace* timeout
onWaitingForConfirmation,
);
if (
@@ -415,14 +424,22 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.definition.runConfig.maxTimeMinutes ?? DEFAULT_MAX_TIME_MINUTES;
const maxTurns = this.definition.runConfig.maxTurns ?? DEFAULT_MAX_TURNS;
const timeoutController = new AbortController();
const timeoutId = setTimeout(
() => timeoutController.abort(new Error('Agent timed out.')),
const deadlineTimer = new DeadlineTimer(
maxTimeMinutes * 60 * 1000,
'Agent timed out.',
);
// Track time spent waiting for user confirmation to credit it back to the agent.
const onWaitingForConfirmation = (waiting: boolean) => {
if (waiting) {
deadlineTimer.pause();
} else {
deadlineTimer.resume();
}
};
// Combine the external signal with the internal timeout signal.
const combinedSignal = AbortSignal.any([signal, timeoutController.signal]);
const combinedSignal = AbortSignal.any([signal, deadlineTimer.signal]);
logAgentStart(
this.runtimeContext,
@@ -458,7 +475,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
// Check for timeout or external abort.
if (combinedSignal.aborted) {
// Determine which signal caused the abort.
terminateReason = timeoutController.signal.aborted
terminateReason = deadlineTimer.signal.aborted
? AgentTerminateMode.TIMEOUT
: AgentTerminateMode.ABORTED;
break;
@@ -469,7 +486,8 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
currentMessage,
turnCounter++,
combinedSignal,
timeoutController.signal,
deadlineTimer.signal,
onWaitingForConfirmation,
);
if (turnResult.status === 'stop') {
@@ -498,6 +516,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
turnCounter, // Use current turnCounter for the recovery attempt
terminateReason,
signal, // Pass the external signal
onWaitingForConfirmation,
);
if (recoveryResult !== null) {
@@ -551,7 +570,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
if (
error instanceof Error &&
error.name === 'AbortError' &&
timeoutController.signal.aborted &&
deadlineTimer.signal.aborted &&
!signal.aborted // Ensure the external signal was not the cause
) {
terminateReason = AgentTerminateMode.TIMEOUT;
@@ -563,6 +582,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
turnCounter, // Use current turnCounter
AgentTerminateMode.TIMEOUT,
signal,
onWaitingForConfirmation,
);
if (recoveryResult !== null) {
@@ -591,7 +611,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.emitActivity('ERROR', { error: String(error) });
throw error; // Re-throw other errors or external aborts.
} finally {
clearTimeout(timeoutId);
deadlineTimer.abort();
logAgentFinish(
this.runtimeContext,
new AgentFinishEvent(
@@ -779,6 +799,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
functionCalls: FunctionCall[],
signal: AbortSignal,
promptId: string,
onWaitingForConfirmation?: (waiting: boolean) => void,
): Promise<{
nextMessage: Content;
submittedOutput: string | null;
@@ -979,6 +1000,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
parentCallId: this.parentCallId,
toolRegistry: this.toolRegistry,
signal,
onWaitingForConfirmation,
},
);

View File

@@ -109,9 +109,10 @@ export async function resolveConfirmation(
modifier: ToolModificationHandler;
getPreferredEditor: () => EditorType | undefined;
schedulerId: string;
onWaitingForConfirmation?: (waiting: boolean) => void;
},
): Promise<ResolutionResult> {
const { state } = deps;
const { state, onWaitingForConfirmation } = deps;
const callId = toolCall.request.callId;
let outcome = ToolConfirmationOutcome.ModifyWithEditor;
let lastDetails: SerializableConfirmationDetails | undefined;
@@ -147,12 +148,14 @@ export async function resolveConfirmation(
correlationId,
});
onWaitingForConfirmation?.(true);
const response = await waitForConfirmation(
deps.messageBus,
correlationId,
signal,
ideConfirmation,
);
onWaitingForConfirmation?.(false);
outcome = response.outcome;
if ('onConfirm' in details && typeof details.onConfirm === 'function') {

View File

@@ -51,6 +51,7 @@ export interface SchedulerOptions {
getPreferredEditor: () => EditorType | undefined;
schedulerId: string;
parentCallId?: string;
onWaitingForConfirmation?: (waiting: boolean) => void;
}
const createErrorResponse = (
@@ -90,6 +91,7 @@ export class Scheduler {
private readonly getPreferredEditor: () => EditorType | undefined;
private readonly schedulerId: string;
private readonly parentCallId?: string;
private readonly onWaitingForConfirmation?: (waiting: boolean) => void;
private isProcessing = false;
private isCancelling = false;
@@ -101,6 +103,7 @@ export class Scheduler {
this.getPreferredEditor = options.getPreferredEditor;
this.schedulerId = options.schedulerId;
this.parentCallId = options.parentCallId;
this.onWaitingForConfirmation = options.onWaitingForConfirmation;
this.state = new SchedulerStateManager(
this.messageBus,
this.schedulerId,
@@ -437,6 +440,7 @@ export class Scheduler {
modifier: this.modifier,
getPreferredEditor: this.getPreferredEditor,
schedulerId: this.schedulerId,
onWaitingForConfirmation: this.onWaitingForConfirmation,
});
outcome = result.outcome;
lastDetails = result.lastDetails;

View File

@@ -0,0 +1,80 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { Scheduler } from './scheduler.js';
import { resolveConfirmation } from './confirmation.js';
import { checkPolicy } from './policy.js';
import { PolicyDecision } from '../policy/types.js';
import { ToolConfirmationOutcome } from '../tools/tools.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { MockTool } from '../test-utils/mock-tool.js';
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
import { makeFakeConfig } from '../test-utils/config.js';
import type { Config } from '../config/config.js';
import type { ToolCallRequestInfo } from './types.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
vi.mock('./confirmation.js');
vi.mock('./policy.js');
describe('Scheduler waiting callback', () => {
let mockConfig: Config;
let messageBus: MessageBus;
let toolRegistry: ToolRegistry;
let mockTool: MockTool;
beforeEach(() => {
messageBus = createMockMessageBus();
mockConfig = makeFakeConfig();
// Override methods to use our mocks
vi.spyOn(mockConfig, 'getMessageBus').mockReturnValue(messageBus);
mockTool = new MockTool({ name: 'test_tool' });
toolRegistry = new ToolRegistry(mockConfig, messageBus);
vi.spyOn(mockConfig, 'getToolRegistry').mockReturnValue(toolRegistry);
toolRegistry.registerTool(mockTool);
vi.mocked(checkPolicy).mockResolvedValue({
decision: PolicyDecision.ASK_USER,
rule: undefined,
});
});
it('should trigger onWaitingForConfirmation callback', async () => {
const onWaitingForConfirmation = vi.fn();
const scheduler = new Scheduler({
config: mockConfig,
messageBus,
getPreferredEditor: () => undefined,
schedulerId: 'test-scheduler',
onWaitingForConfirmation,
});
vi.mocked(resolveConfirmation).mockResolvedValue({
outcome: ToolConfirmationOutcome.ProceedOnce,
});
const req: ToolCallRequestInfo = {
callId: 'call-1',
name: 'test_tool',
args: {},
isClientInitiated: false,
prompt_id: 'test-prompt',
};
await scheduler.schedule(req, new AbortController().signal);
expect(resolveConfirmation).toHaveBeenCalledWith(
expect.anything(),
expect.anything(),
expect.objectContaining({
onWaitingForConfirmation,
}),
);
});
});

View File

@@ -0,0 +1,82 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { DeadlineTimer } from './deadlineTimer.js';
describe('DeadlineTimer', () => {
beforeEach(() => {
vi.useFakeTimers();
});
afterEach(() => {
vi.restoreAllMocks();
});
it('should abort when timeout is reached', () => {
const timer = new DeadlineTimer(1000);
const signal = timer.signal;
expect(signal.aborted).toBe(false);
vi.advanceTimersByTime(1000);
expect(signal.aborted).toBe(true);
expect(signal.reason).toBeInstanceOf(Error);
expect((signal.reason as Error).message).toBe('Timeout exceeded.');
});
it('should allow extending the deadline', () => {
const timer = new DeadlineTimer(1000);
const signal = timer.signal;
vi.advanceTimersByTime(500);
expect(signal.aborted).toBe(false);
timer.extend(1000); // New deadline is 1000 + 1000 = 2000 from start
vi.advanceTimersByTime(600); // 1100 total
expect(signal.aborted).toBe(false);
vi.advanceTimersByTime(900); // 2000 total
expect(signal.aborted).toBe(true);
});
it('should allow pausing and resuming the timer', () => {
const timer = new DeadlineTimer(1000);
const signal = timer.signal;
vi.advanceTimersByTime(500);
timer.pause();
vi.advanceTimersByTime(2000); // Wait a long time while paused
expect(signal.aborted).toBe(false);
timer.resume();
vi.advanceTimersByTime(400);
expect(signal.aborted).toBe(false);
vi.advanceTimersByTime(200); // Total active time 500 + 400 + 200 = 1100
expect(signal.aborted).toBe(true);
});
it('should abort immediately when abort() is called', () => {
const timer = new DeadlineTimer(1000);
const signal = timer.signal;
timer.abort('cancelled');
expect(signal.aborted).toBe(true);
expect(signal.reason).toBe('cancelled');
});
it('should not fire timeout if aborted manually', () => {
const timer = new DeadlineTimer(1000);
const signal = timer.signal;
timer.abort();
vi.advanceTimersByTime(1000);
// Already aborted, but shouldn't re-abort or throw
expect(signal.aborted).toBe(true);
});
});

View File

@@ -0,0 +1,94 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
* A utility that manages a timeout and an AbortController, allowing the
* timeout to be paused, resumed, and dynamically extended.
*/
export class DeadlineTimer {
private readonly controller: AbortController;
private timeoutId: NodeJS.Timeout | null = null;
private remainingMs: number;
private lastStartedAt: number;
private isPaused = false;
constructor(timeoutMs: number, reason = 'Timeout exceeded.') {
this.controller = new AbortController();
this.remainingMs = timeoutMs;
this.lastStartedAt = Date.now();
this.schedule(timeoutMs, reason);
}
/** The AbortSignal managed by this timer. */
get signal(): AbortSignal {
return this.controller.signal;
}
/**
* Pauses the timer, clearing any active timeout.
*/
pause(): void {
if (this.isPaused || this.controller.signal.aborted) return;
if (this.timeoutId) {
clearTimeout(this.timeoutId);
this.timeoutId = null;
}
const elapsed = Date.now() - this.lastStartedAt;
this.remainingMs = Math.max(0, this.remainingMs - elapsed);
this.isPaused = true;
}
/**
* Resumes the timer with the remaining budget.
*/
resume(reason = 'Timeout exceeded.'): void {
if (!this.isPaused || this.controller.signal.aborted) return;
this.lastStartedAt = Date.now();
this.schedule(this.remainingMs, reason);
this.isPaused = false;
}
/**
* Extends the current budget by the specified number of milliseconds.
*/
extend(ms: number, reason = 'Timeout exceeded.'): void {
if (this.controller.signal.aborted) return;
if (this.isPaused) {
this.remainingMs += ms;
} else {
if (this.timeoutId) {
clearTimeout(this.timeoutId);
}
const elapsed = Date.now() - this.lastStartedAt;
this.remainingMs = Math.max(0, this.remainingMs - elapsed) + ms;
this.lastStartedAt = Date.now();
this.schedule(this.remainingMs, reason);
}
}
/**
* Aborts the signal immediately and clears any pending timers.
*/
abort(reason?: unknown): void {
if (this.timeoutId) {
clearTimeout(this.timeoutId);
this.timeoutId = null;
}
this.isPaused = false;
this.controller.abort(reason);
}
private schedule(ms: number, reason: string): void {
this.timeoutId = setTimeout(() => {
this.timeoutId = null;
this.controller.abort(new Error(reason));
}, ms);
}
}