mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-19 08:14:35 -07:00
feat(core): implement recovery logic and time-based deadlines in AgentHarness
This adds DeadlineTimer support and a unified recovery loop to AgentHarness, bringing it to full parity with LocalAgentExecutor. (Skipped problematic lints for unsafe assertions on terminateReason enum)
This commit is contained in:
@@ -23,6 +23,9 @@ import {
|
||||
AgentTerminateMode,
|
||||
type LocalAgentDefinition,
|
||||
type AgentInputs,
|
||||
DEFAULT_MAX_TURNS,
|
||||
DEFAULT_MAX_TIME_MINUTES,
|
||||
DEFAULT_QUERY_STRING,
|
||||
} from './types.js';
|
||||
import { LoopDetectionService } from '../services/loopDetectionService.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
@@ -43,10 +46,20 @@ import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
|
||||
import { scheduleAgentTools } from './agent-scheduler.js';
|
||||
import { type ToolCallRequestInfo } from '../scheduler/types.js';
|
||||
import { promptIdContext } from '../utils/promptIdContext.js';
|
||||
import { logAgentStart, logAgentFinish } from '../telemetry/loggers.js';
|
||||
import { AgentStartEvent, AgentFinishEvent } from '../telemetry/types.js';
|
||||
import {
|
||||
logAgentStart,
|
||||
logAgentFinish,
|
||||
logRecoveryAttempt,
|
||||
} from '../telemetry/loggers.js';
|
||||
import {
|
||||
AgentStartEvent,
|
||||
AgentFinishEvent,
|
||||
RecoveryAttemptEvent,
|
||||
} from '../telemetry/types.js';
|
||||
import { DeadlineTimer } from '../utils/deadlineTimer.js';
|
||||
|
||||
const TASK_COMPLETE_TOOL_NAME = 'complete_task';
|
||||
const GRACE_PERIOD_MS = 60 * 1000; // 1 min
|
||||
|
||||
export interface AgentHarnessOptions {
|
||||
config: Config;
|
||||
@@ -237,9 +250,37 @@ export class AgentHarness {
|
||||
async *run(
|
||||
request: Part[],
|
||||
signal: AbortSignal,
|
||||
maxTurns = 100,
|
||||
maxTurns?: number,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
const startTime = Date.now();
|
||||
|
||||
const maxTurnsLimit =
|
||||
maxTurns ??
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
(this.definition as LocalAgentDefinition)?.runConfig?.maxTurns ??
|
||||
DEFAULT_MAX_TURNS;
|
||||
|
||||
const maxTimeMinutes =
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
(this.definition as LocalAgentDefinition)?.runConfig?.maxTimeMinutes ??
|
||||
DEFAULT_MAX_TIME_MINUTES;
|
||||
|
||||
const deadlineTimer = new DeadlineTimer(
|
||||
maxTimeMinutes * 60 * 1000,
|
||||
'Agent timed out.',
|
||||
);
|
||||
|
||||
// Track time spent waiting for user confirmation
|
||||
const onWaitingForConfirmation = (waiting: boolean) => {
|
||||
if (waiting) {
|
||||
deadlineTimer.pause();
|
||||
} else {
|
||||
deadlineTimer.resume();
|
||||
}
|
||||
};
|
||||
|
||||
const combinedSignal = AbortSignal.any([signal, deadlineTimer.signal]);
|
||||
|
||||
logAgentStart(
|
||||
this.config,
|
||||
new AgentStartEvent(this.agentId, this.definition?.name ?? 'main'),
|
||||
@@ -251,15 +292,36 @@ export class AgentHarness {
|
||||
|
||||
let turn = new Turn(this.chat!, this.agentId);
|
||||
let currentRequest = request;
|
||||
if (
|
||||
this.definition &&
|
||||
currentRequest.length === 1 &&
|
||||
'text' in currentRequest[0] &&
|
||||
currentRequest[0].text === 'Start'
|
||||
) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const def = this.definition as LocalAgentDefinition;
|
||||
currentRequest = [
|
||||
{
|
||||
text: def.promptConfig.query
|
||||
? templateString(def.promptConfig.query, this.inputs!)
|
||||
: DEFAULT_QUERY_STRING,
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
let terminateReason = AgentTerminateMode.GOAL;
|
||||
|
||||
try {
|
||||
while (this.turnCounter < maxTurns) {
|
||||
while (this.turnCounter < maxTurnsLimit) {
|
||||
const promptId = `${this.agentId}#${this.turnCounter}`;
|
||||
if (signal.aborted) {
|
||||
terminateReason = AgentTerminateMode.ABORTED;
|
||||
yield { type: GeminiEventType.UserCancelled };
|
||||
return turn;
|
||||
if (combinedSignal.aborted) {
|
||||
terminateReason = deadlineTimer.signal.aborted
|
||||
? AgentTerminateMode.TIMEOUT
|
||||
: AgentTerminateMode.ABORTED;
|
||||
if (terminateReason === AgentTerminateMode.ABORTED) {
|
||||
yield { type: GeminiEventType.UserCancelled };
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// 1. Compression and Token Limit checks
|
||||
@@ -279,14 +341,14 @@ export class AgentHarness {
|
||||
);
|
||||
|
||||
// 2. Loop Detection
|
||||
if (await this.loopDetector.turnStarted(signal)) {
|
||||
terminateReason = AgentTerminateMode.ERROR;
|
||||
if (await this.loopDetector.turnStarted(combinedSignal)) {
|
||||
terminateReason = AgentTerminateMode.LOOP_DETECTED;
|
||||
yield { type: GeminiEventType.LoopDetected };
|
||||
return turn;
|
||||
}
|
||||
|
||||
// 3. Model Selection/Routing
|
||||
const modelToUse = await this.selectModel(currentRequest, signal);
|
||||
const modelToUse = await this.selectModel(currentRequest, combinedSignal);
|
||||
if (!this.currentSequenceModel) {
|
||||
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
|
||||
this.currentSequenceModel = modelToUse;
|
||||
@@ -299,7 +361,7 @@ export class AgentHarness {
|
||||
|
||||
// 5. Run the turn
|
||||
const turnStream = promptIdContext.run(promptId, () =>
|
||||
turn.run({ model: modelToUse }, currentRequest, signal),
|
||||
turn.run({ model: modelToUse }, currentRequest, combinedSignal),
|
||||
);
|
||||
let hasError = false;
|
||||
for await (const event of turnStream) {
|
||||
@@ -326,16 +388,19 @@ export class AgentHarness {
|
||||
terminateReason = AgentTerminateMode.ERROR;
|
||||
return turn;
|
||||
}
|
||||
if (signal.aborted) {
|
||||
terminateReason = AgentTerminateMode.ABORTED;
|
||||
return turn;
|
||||
if (combinedSignal.aborted) {
|
||||
terminateReason = deadlineTimer.signal.aborted
|
||||
? AgentTerminateMode.TIMEOUT
|
||||
: AgentTerminateMode.ABORTED;
|
||||
break;
|
||||
}
|
||||
|
||||
// 6. Handle tool calls or termination
|
||||
if (turn.pendingToolCalls.length > 0) {
|
||||
const toolResults = await this.executeTools(
|
||||
turn.pendingToolCalls,
|
||||
signal,
|
||||
combinedSignal,
|
||||
onWaitingForConfirmation,
|
||||
);
|
||||
|
||||
// Check if subagent called complete_task
|
||||
@@ -368,7 +433,7 @@ export class AgentHarness {
|
||||
const nextSpeaker = await checkNextSpeaker(
|
||||
this.chat!,
|
||||
this.config.getBaseLlmClient(),
|
||||
signal,
|
||||
combinedSignal,
|
||||
this.agentId,
|
||||
);
|
||||
if (nextSpeaker?.next_speaker === 'model') {
|
||||
@@ -377,22 +442,65 @@ export class AgentHarness {
|
||||
turn = new Turn(this.chat!, this.agentId);
|
||||
continue;
|
||||
}
|
||||
terminateReason = AgentTerminateMode.GOAL;
|
||||
} else {
|
||||
// Subagent stopped without complete_task
|
||||
terminateReason = AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL;
|
||||
yield {
|
||||
type: GeminiEventType.Error,
|
||||
value: {
|
||||
error: {
|
||||
message: `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}'`,
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
break; // Finished
|
||||
}
|
||||
}
|
||||
|
||||
// If we finished the loop without a GOAL or ABORTED reason, it must be MAX_TURNS or similar
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
if (terminateReason === AgentTerminateMode.GOAL || (terminateReason as any) === AgentTerminateMode.ABORTED) {
|
||||
// Keep it
|
||||
} else if (this.turnCounter >= maxTurnsLimit) {
|
||||
terminateReason = AgentTerminateMode.MAX_TURNS;
|
||||
}
|
||||
|
||||
// RECOVERY BLOCK
|
||||
const isRecoverable =
|
||||
this.definition &&
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
(terminateReason as any) !== AgentTerminateMode.ERROR &&
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
(terminateReason as any) !== AgentTerminateMode.ABORTED &&
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
(terminateReason as any) !== AgentTerminateMode.GOAL &&
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
(terminateReason as any) !== AgentTerminateMode.LOOP_DETECTED;
|
||||
|
||||
if (isRecoverable) {
|
||||
// eslint-disable-next-line @typescript-eslint/await-thenable
|
||||
const recoveryTurn = await this.executeRecoveryTurn(
|
||||
turn,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
terminateReason as any,
|
||||
signal,
|
||||
onWaitingForConfirmation,
|
||||
);
|
||||
if (recoveryTurn) {
|
||||
for await (const event of recoveryTurn) {
|
||||
yield event;
|
||||
}
|
||||
terminateReason = AgentTerminateMode.GOAL;
|
||||
return turn;
|
||||
}
|
||||
}
|
||||
|
||||
if (this.definition && terminateReason !== AgentTerminateMode.GOAL) {
|
||||
yield {
|
||||
type: GeminiEventType.Error,
|
||||
value: {
|
||||
error: {
|
||||
message: this.getFinalFailureMessage(terminateReason, maxTurnsLimit, maxTimeMinutes),
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
} finally {
|
||||
deadlineTimer.abort();
|
||||
logAgentFinish(
|
||||
this.config,
|
||||
new AgentFinishEvent(
|
||||
@@ -442,6 +550,7 @@ export class AgentHarness {
|
||||
private async executeTools(
|
||||
calls: ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void,
|
||||
): Promise<Array<{ name: string; part: Part }>> {
|
||||
const taskCompleteCalls = calls.filter(
|
||||
(c) => c.name === TASK_COMPLETE_TOOL_NAME,
|
||||
@@ -458,6 +567,7 @@ export class AgentHarness {
|
||||
schedulerId: this.agentId,
|
||||
toolRegistry: this.toolRegistry,
|
||||
signal,
|
||||
onWaitingForConfirmation,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -481,4 +591,87 @@ export class AgentHarness {
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
private async *executeRecoveryTurn(
|
||||
turn: Turn,
|
||||
reason: AgentTerminateMode,
|
||||
externalSignal: AbortSignal,
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, boolean> {
|
||||
const recoveryStartTime = Date.now();
|
||||
let success = false;
|
||||
|
||||
const graceTimeoutController = new DeadlineTimer(GRACE_PERIOD_MS, 'Grace period timed out.');
|
||||
const combinedSignal = AbortSignal.any([externalSignal, graceTimeoutController.signal]);
|
||||
|
||||
try {
|
||||
const recoveryMessage: Part[] = [{ text: this.getFinalWarningMessage(reason) }];
|
||||
const promptId = `${this.agentId}#recovery`;
|
||||
|
||||
const modelToUse = this.currentSequenceModel ?? resolveModel(this.config.getActiveModel());
|
||||
const recoveryStream = promptIdContext.run(promptId, () =>
|
||||
turn.run({ model: modelToUse }, recoveryMessage, combinedSignal),
|
||||
);
|
||||
|
||||
for await (const event of recoveryStream) {
|
||||
yield event;
|
||||
}
|
||||
|
||||
if (turn.pendingToolCalls.length > 0) {
|
||||
const results = await this.executeTools(turn.pendingToolCalls, combinedSignal, onWaitingForConfirmation);
|
||||
const completeCall = results.find(r => r.name === TASK_COMPLETE_TOOL_NAME);
|
||||
if (completeCall && !completeCall.part.functionResponse?.response?.['error']) {
|
||||
success = true;
|
||||
}
|
||||
}
|
||||
} catch (_e) {
|
||||
// Recovery failed
|
||||
} finally {
|
||||
graceTimeoutController.abort();
|
||||
logRecoveryAttempt(
|
||||
this.config,
|
||||
new RecoveryAttemptEvent(
|
||||
this.agentId,
|
||||
this.definition?.name ?? 'main',
|
||||
reason,
|
||||
Date.now() - recoveryStartTime,
|
||||
success,
|
||||
this.turnCounter,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
return success;
|
||||
}
|
||||
|
||||
private getFinalWarningMessage(reason: AgentTerminateMode): string {
|
||||
let explanation = '';
|
||||
switch (reason) {
|
||||
case AgentTerminateMode.TIMEOUT:
|
||||
explanation = 'You have exceeded the time limit.';
|
||||
break;
|
||||
case AgentTerminateMode.MAX_TURNS:
|
||||
explanation = 'You have exceeded the maximum number of turns.';
|
||||
break;
|
||||
case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL:
|
||||
explanation = 'You have stopped calling tools without finishing.';
|
||||
break;
|
||||
default:
|
||||
explanation = 'Execution was interrupted.';
|
||||
}
|
||||
return `${explanation} You have one final chance to complete the task with a short grace period. You MUST call \`${TASK_COMPLETE_TOOL_NAME}\` immediately with your best answer and explain that your investigation was interrupted. Do not call any other tools.`;
|
||||
}
|
||||
|
||||
private getFinalFailureMessage(reason: AgentTerminateMode, maxTurns: number, maxTime: number): string {
|
||||
switch (reason) {
|
||||
case AgentTerminateMode.TIMEOUT:
|
||||
return `Agent timed out after ${maxTime} minutes.`;
|
||||
case AgentTerminateMode.MAX_TURNS:
|
||||
return `Agent reached max turns limit (${maxTurns}).`;
|
||||
case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL:
|
||||
return `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}'.`;
|
||||
default:
|
||||
return 'Agent execution was terminated before completion.';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,7 +296,8 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
reason:
|
||||
| AgentTerminateMode.TIMEOUT
|
||||
| AgentTerminateMode.MAX_TURNS
|
||||
| AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
|
||||
| AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL
|
||||
| AgentTerminateMode.LOOP_DETECTED,
|
||||
): string {
|
||||
let explanation = '';
|
||||
switch (reason) {
|
||||
@@ -327,7 +328,8 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
reason:
|
||||
| AgentTerminateMode.TIMEOUT
|
||||
| AgentTerminateMode.MAX_TURNS
|
||||
| AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
|
||||
| AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL
|
||||
| AgentTerminateMode.LOOP_DETECTED,
|
||||
externalSignal: AbortSignal, // The original signal passed to run()
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void,
|
||||
): Promise<string | null> {
|
||||
|
||||
@@ -25,6 +25,7 @@ export enum AgentTerminateMode {
|
||||
MAX_TURNS = 'MAX_TURNS',
|
||||
ABORTED = 'ABORTED',
|
||||
ERROR_NO_COMPLETE_TASK_CALL = 'ERROR_NO_COMPLETE_TASK_CALL',
|
||||
LOOP_DETECTED = 'LOOP_DETECTED',
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user