mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 14:10:37 -07:00
Grants subagent a recovery turn for when it hits TIMEOUT, MAX_TURNS or NO_TOOL_CALL failures. (#12344)
This commit is contained in:
@@ -571,22 +571,28 @@ describe('AgentExecutor', () => {
|
||||
},
|
||||
});
|
||||
|
||||
// Turn 2 (protocol violation)
|
||||
mockModelResponse([], 'I think I am done.');
|
||||
|
||||
// Turn 3 (recovery turn - also fails)
|
||||
mockModelResponse([], 'I still give up.');
|
||||
|
||||
const output = await executor.run({ goal: 'Strict test' }, signal);
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(3);
|
||||
|
||||
const expectedError = `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}' to finalize the session.`;
|
||||
const expectedError = `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}'.`;
|
||||
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.ERROR);
|
||||
expect(output.terminate_reason).toBe(
|
||||
AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
|
||||
);
|
||||
expect(output.result).toBe(expectedError);
|
||||
|
||||
// Telemetry check for error
|
||||
expect(mockedLogAgentFinish).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
expect.objectContaining({
|
||||
terminate_reason: AgentTerminateMode.ERROR,
|
||||
terminate_reason: AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -901,11 +907,13 @@ describe('AgentExecutor', () => {
|
||||
|
||||
mockWorkResponse('t1');
|
||||
mockWorkResponse('t2');
|
||||
// Recovery turn
|
||||
mockModelResponse([], 'I give up');
|
||||
|
||||
const output = await executor.run({ goal: 'Turns test' }, signal);
|
||||
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.MAX_TURNS);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(MAX);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(MAX + 1);
|
||||
});
|
||||
|
||||
it('should terminate with TIMEOUT if a model call takes too long', async () => {
|
||||
@@ -931,6 +939,8 @@ describe('AgentExecutor', () => {
|
||||
});
|
||||
})();
|
||||
});
|
||||
// Recovery turn
|
||||
mockModelResponse([], 'I give up');
|
||||
|
||||
const runPromise = executor.run({ goal: 'Timeout test' }, signal);
|
||||
|
||||
@@ -941,7 +951,7 @@ describe('AgentExecutor', () => {
|
||||
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.TIMEOUT);
|
||||
expect(output.result).toContain('Agent timed out after 0.5 minutes.');
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Verify activity stream reported the timeout
|
||||
expect(activities).toContainEqual(
|
||||
@@ -992,10 +1002,13 @@ describe('AgentExecutor', () => {
|
||||
};
|
||||
});
|
||||
|
||||
// Recovery turn
|
||||
mockModelResponse([], 'I give up');
|
||||
|
||||
const output = await executor.run({ goal: 'Timeout test' }, signal);
|
||||
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.TIMEOUT);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should terminate when AbortSignal is triggered', async () => {
|
||||
@@ -1019,4 +1032,315 @@ describe('AgentExecutor', () => {
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.ABORTED);
|
||||
});
|
||||
});
|
||||
|
||||
describe('run (Recovery Turns)', () => {
|
||||
const mockWorkResponse = (id: string) => {
|
||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||
mockExecuteToolCall.mockResolvedValueOnce({
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: id,
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: id,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
it('should recover successfully if complete_task is called during the grace turn after MAX_TURNS', async () => {
|
||||
const MAX = 1;
|
||||
const definition = createTestDefinition([LS_TOOL_NAME], {
|
||||
max_turns: MAX,
|
||||
});
|
||||
const executor = await AgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
|
||||
// Turn 1 (hits max_turns)
|
||||
mockWorkResponse('t1');
|
||||
|
||||
// Recovery Turn (succeeds)
|
||||
mockModelResponse(
|
||||
[
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'Recovered!' },
|
||||
id: 't2',
|
||||
},
|
||||
],
|
||||
'Recovering from max turns',
|
||||
);
|
||||
|
||||
const output = await executor.run({ goal: 'Turns recovery' }, signal);
|
||||
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
|
||||
expect(output.result).toBe('Recovered!');
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(MAX + 1); // 1 regular + 1 recovery
|
||||
|
||||
expect(activities).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'THOUGHT_CHUNK',
|
||||
data: {
|
||||
text: 'Execution limit reached (MAX_TURNS). Attempting one final recovery turn with a grace period.',
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(activities).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'THOUGHT_CHUNK',
|
||||
data: { text: 'Graceful recovery succeeded.' },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should fail if complete_task is NOT called during the grace turn after MAX_TURNS', async () => {
|
||||
const MAX = 1;
|
||||
const definition = createTestDefinition([LS_TOOL_NAME], {
|
||||
max_turns: MAX,
|
||||
});
|
||||
const executor = await AgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
|
||||
// Turn 1 (hits max_turns)
|
||||
mockWorkResponse('t1');
|
||||
|
||||
// Recovery Turn (fails by calling no tools)
|
||||
mockModelResponse([], 'I give up again.');
|
||||
|
||||
const output = await executor.run(
|
||||
{ goal: 'Turns recovery fail' },
|
||||
signal,
|
||||
);
|
||||
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.MAX_TURNS);
|
||||
expect(output.result).toContain('Agent reached max turns limit');
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(MAX + 1);
|
||||
|
||||
expect(activities).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'ERROR',
|
||||
data: expect.objectContaining({
|
||||
context: 'recovery_turn',
|
||||
error: 'Graceful recovery attempt failed. Reason: stop',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should recover successfully from a protocol violation (no complete_task)', async () => {
|
||||
const definition = createTestDefinition();
|
||||
const executor = await AgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
|
||||
// Turn 1: Normal work
|
||||
mockWorkResponse('t1');
|
||||
|
||||
// Turn 2: Protocol violation (no tool calls)
|
||||
mockModelResponse([], 'I think I am done, but I forgot the right tool.');
|
||||
|
||||
// Turn 3: Recovery turn (succeeds)
|
||||
mockModelResponse(
|
||||
[
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'Recovered from violation!' },
|
||||
id: 't3',
|
||||
},
|
||||
],
|
||||
'My mistake, here is the completion.',
|
||||
);
|
||||
|
||||
const output = await executor.run({ goal: 'Violation recovery' }, signal);
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(3);
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
|
||||
expect(output.result).toBe('Recovered from violation!');
|
||||
|
||||
expect(activities).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'THOUGHT_CHUNK',
|
||||
data: {
|
||||
text: 'Execution limit reached (ERROR_NO_COMPLETE_TASK_CALL). Attempting one final recovery turn with a grace period.',
|
||||
},
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should fail recovery from a protocol violation if it violates again', async () => {
|
||||
const definition = createTestDefinition();
|
||||
const executor = await AgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
|
||||
// Turn 1: Normal work
|
||||
mockWorkResponse('t1');
|
||||
|
||||
// Turn 2: Protocol violation (no tool calls)
|
||||
mockModelResponse([], 'I think I am done, but I forgot the right tool.');
|
||||
|
||||
// Turn 3: Recovery turn (fails again)
|
||||
mockModelResponse([], 'I still dont know what to do.');
|
||||
|
||||
const output = await executor.run(
|
||||
{ goal: 'Violation recovery fail' },
|
||||
signal,
|
||||
);
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(3);
|
||||
expect(output.terminate_reason).toBe(
|
||||
AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
|
||||
);
|
||||
expect(output.result).toContain(
|
||||
`Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}'`,
|
||||
);
|
||||
|
||||
expect(activities).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'ERROR',
|
||||
data: expect.objectContaining({
|
||||
context: 'recovery_turn',
|
||||
error: 'Graceful recovery attempt failed. Reason: stop',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should recover successfully from a TIMEOUT', async () => {
|
||||
const definition = createTestDefinition([LS_TOOL_NAME], {
|
||||
max_time_minutes: 0.5, // 30 seconds
|
||||
});
|
||||
const executor = await AgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
|
||||
// Mock a model call that gets interrupted by the timeout.
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
// This promise never resolves, it waits for abort.
|
||||
await new Promise<void>((resolve) => {
|
||||
signal?.addEventListener('abort', () => resolve());
|
||||
});
|
||||
})();
|
||||
});
|
||||
|
||||
// Recovery turn (succeeds)
|
||||
mockModelResponse(
|
||||
[
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'Recovered from timeout!' },
|
||||
id: 't2',
|
||||
},
|
||||
],
|
||||
'Apologies for the delay, finishing up.',
|
||||
);
|
||||
|
||||
const runPromise = executor.run({ goal: 'Timeout recovery' }, signal);
|
||||
|
||||
// Advance time past the timeout to trigger the abort and recovery.
|
||||
await vi.advanceTimersByTimeAsync(31 * 1000);
|
||||
|
||||
const output = await runPromise;
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2); // 1 failed + 1 recovery
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
|
||||
expect(output.result).toBe('Recovered from timeout!');
|
||||
|
||||
expect(activities).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'THOUGHT_CHUNK',
|
||||
data: {
|
||||
text: 'Execution limit reached (TIMEOUT). Attempting one final recovery turn with a grace period.',
|
||||
},
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should fail recovery from a TIMEOUT if the grace period also times out', async () => {
|
||||
const definition = createTestDefinition([LS_TOOL_NAME], {
|
||||
max_time_minutes: 0.5, // 30 seconds
|
||||
});
|
||||
const executor = await AgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})();
|
||||
});
|
||||
|
||||
// Mock the recovery call to also be long-running
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})();
|
||||
});
|
||||
|
||||
const runPromise = executor.run(
|
||||
{ goal: 'Timeout recovery fail' },
|
||||
signal,
|
||||
);
|
||||
|
||||
// 1. Trigger the main timeout
|
||||
await vi.advanceTimersByTimeAsync(31 * 1000);
|
||||
// 2. Let microtasks run (start recovery turn)
|
||||
await vi.advanceTimersByTimeAsync(1);
|
||||
// 3. Trigger the grace period timeout (60s)
|
||||
await vi.advanceTimersByTimeAsync(61 * 1000);
|
||||
|
||||
const output = await runPromise;
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.TIMEOUT);
|
||||
expect(output.result).toContain('Agent timed out after 0.5 minutes.');
|
||||
|
||||
expect(activities).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'ERROR',
|
||||
data: expect.objectContaining({
|
||||
context: 'recovery_turn',
|
||||
error: 'Graceful recovery attempt failed. Reason: stop',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -49,6 +49,19 @@ import { debugLogger } from '../utils/debugLogger.js';
|
||||
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
|
||||
|
||||
const TASK_COMPLETE_TOOL_NAME = 'complete_task';
|
||||
const GRACE_PERIOD_MS = 60 * 1000; // 1 min
|
||||
|
||||
/** The possible outcomes of a single agent turn. */
|
||||
type AgentTurnResult =
|
||||
| {
|
||||
status: 'continue';
|
||||
nextMessage: Content;
|
||||
}
|
||||
| {
|
||||
status: 'stop';
|
||||
terminateReason: AgentTerminateMode;
|
||||
finalResult: string | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Executes an agent loop based on an {@link AgentDefinition}.
|
||||
@@ -146,6 +159,173 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
this.agentId = `${parentPrefix}${this.definition.name}-${randomIdPart}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes a single turn of the agent's logic, from calling the model
|
||||
* to processing its response.
|
||||
*
|
||||
* @returns An {@link AgentTurnResult} object indicating whether to continue
|
||||
* or stop the agent loop.
|
||||
*/
|
||||
private async executeTurn(
|
||||
chat: GeminiChat,
|
||||
currentMessage: Content,
|
||||
tools: FunctionDeclaration[],
|
||||
turnCounter: number,
|
||||
combinedSignal: AbortSignal,
|
||||
timeoutSignal: AbortSignal, // Pass the timeout controller's signal
|
||||
): Promise<AgentTurnResult> {
|
||||
const promptId = `${this.agentId}#${turnCounter}`;
|
||||
|
||||
const { functionCalls } = await promptIdContext.run(promptId, async () =>
|
||||
this.callModel(chat, currentMessage, tools, combinedSignal, promptId),
|
||||
);
|
||||
|
||||
if (combinedSignal.aborted) {
|
||||
const terminateReason = timeoutSignal.aborted
|
||||
? AgentTerminateMode.TIMEOUT
|
||||
: AgentTerminateMode.ABORTED;
|
||||
return {
|
||||
status: 'stop',
|
||||
terminateReason,
|
||||
finalResult: null, // 'run' method will set the final timeout string
|
||||
};
|
||||
}
|
||||
|
||||
// If the model stops calling tools without calling complete_task, it's an error.
|
||||
if (functionCalls.length === 0) {
|
||||
this.emitActivity('ERROR', {
|
||||
error: `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}' to finalize the session.`,
|
||||
context: 'protocol_violation',
|
||||
});
|
||||
return {
|
||||
status: 'stop',
|
||||
terminateReason: AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
|
||||
finalResult: null,
|
||||
};
|
||||
}
|
||||
|
||||
const { nextMessage, submittedOutput, taskCompleted } =
|
||||
await this.processFunctionCalls(functionCalls, combinedSignal, promptId);
|
||||
|
||||
if (taskCompleted) {
|
||||
const finalResult = submittedOutput ?? 'Task completed successfully.';
|
||||
return {
|
||||
status: 'stop',
|
||||
terminateReason: AgentTerminateMode.GOAL,
|
||||
finalResult,
|
||||
};
|
||||
}
|
||||
|
||||
// Task is not complete, continue to the next turn.
|
||||
return {
|
||||
status: 'continue',
|
||||
nextMessage,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a specific warning message for the agent's final turn.
|
||||
*/
|
||||
private getFinalWarningMessage(
|
||||
reason:
|
||||
| AgentTerminateMode.TIMEOUT
|
||||
| AgentTerminateMode.MAX_TURNS
|
||||
| AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
|
||||
): string {
|
||||
let explanation = '';
|
||||
switch (reason) {
|
||||
case AgentTerminateMode.TIMEOUT:
|
||||
explanation = 'You have exceeded the time limit.';
|
||||
break;
|
||||
case AgentTerminateMode.MAX_TURNS:
|
||||
explanation = 'You have exceeded the maximum number of turns.';
|
||||
break;
|
||||
case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL:
|
||||
explanation = 'You have stopped calling tools without finishing.';
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown terminate reason: ${reason}`);
|
||||
}
|
||||
return `${explanation} You have one final chance to complete the task with a short grace period. You MUST call \`${TASK_COMPLETE_TOOL_NAME}\` immediately with your best answer and explain that your investigation was interrupted. Do not call any other tools.`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts a single, final recovery turn if the agent stops for a recoverable reason.
|
||||
* Gives the agent a grace period to call `complete_task`.
|
||||
*
|
||||
* @returns The final result string if recovery was successful, or `null` if it failed.
|
||||
*/
|
||||
private async executeFinalWarningTurn(
|
||||
chat: GeminiChat,
|
||||
tools: FunctionDeclaration[],
|
||||
turnCounter: number,
|
||||
reason:
|
||||
| AgentTerminateMode.TIMEOUT
|
||||
| AgentTerminateMode.MAX_TURNS
|
||||
| AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL,
|
||||
externalSignal: AbortSignal, // The original signal passed to run()
|
||||
): Promise<string | null> {
|
||||
this.emitActivity('THOUGHT_CHUNK', {
|
||||
text: `Execution limit reached (${reason}). Attempting one final recovery turn with a grace period.`,
|
||||
});
|
||||
|
||||
const gracePeriodMs = GRACE_PERIOD_MS;
|
||||
const graceTimeoutController = new AbortController();
|
||||
const graceTimeoutId = setTimeout(
|
||||
() => graceTimeoutController.abort(new Error('Grace period timed out.')),
|
||||
gracePeriodMs,
|
||||
);
|
||||
|
||||
try {
|
||||
const recoveryMessage: Content = {
|
||||
role: 'user',
|
||||
parts: [{ text: this.getFinalWarningMessage(reason) }],
|
||||
};
|
||||
|
||||
// We monitor both the external signal and our new grace period timeout
|
||||
const combinedSignal = AbortSignal.any([
|
||||
externalSignal,
|
||||
graceTimeoutController.signal,
|
||||
]);
|
||||
|
||||
const turnResult = await this.executeTurn(
|
||||
chat,
|
||||
recoveryMessage,
|
||||
tools,
|
||||
turnCounter, // This will be the "last" turn number
|
||||
combinedSignal,
|
||||
graceTimeoutController.signal, // Pass grace signal to identify a *grace* timeout
|
||||
);
|
||||
|
||||
if (
|
||||
turnResult.status === 'stop' &&
|
||||
turnResult.terminateReason === AgentTerminateMode.GOAL
|
||||
) {
|
||||
// Success!
|
||||
this.emitActivity('THOUGHT_CHUNK', {
|
||||
text: 'Graceful recovery succeeded.',
|
||||
});
|
||||
return turnResult.finalResult ?? 'Task completed during grace period.';
|
||||
}
|
||||
|
||||
// Any other outcome (continue, error, non-GOAL stop) is a failure.
|
||||
this.emitActivity('ERROR', {
|
||||
error: `Graceful recovery attempt failed. Reason: ${turnResult.status}`,
|
||||
context: 'recovery_turn',
|
||||
});
|
||||
return null;
|
||||
} catch (error) {
|
||||
// This catch block will likely catch the 'Grace period timed out' error.
|
||||
this.emitActivity('ERROR', {
|
||||
error: `Graceful recovery attempt failed: ${String(error)}`,
|
||||
context: 'recovery_turn',
|
||||
});
|
||||
return null;
|
||||
} finally {
|
||||
clearTimeout(graceTimeoutId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the agent.
|
||||
*
|
||||
@@ -174,22 +354,25 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
new AgentStartEvent(this.agentId, this.definition.name),
|
||||
);
|
||||
|
||||
let chat: GeminiChat | undefined;
|
||||
let tools: FunctionDeclaration[] | undefined;
|
||||
try {
|
||||
const chat = await this.createChatObject(inputs);
|
||||
const tools = this.prepareToolsList();
|
||||
|
||||
chat = await this.createChatObject(inputs);
|
||||
tools = this.prepareToolsList();
|
||||
const query = this.definition.promptConfig.query
|
||||
? templateString(this.definition.promptConfig.query, inputs)
|
||||
: 'Get Started!';
|
||||
let currentMessage: Content = { role: 'user', parts: [{ text: query }] };
|
||||
|
||||
while (true) {
|
||||
// Check for termination conditions like max turns or timeout.
|
||||
// Check for termination conditions like max turns.
|
||||
const reason = this.checkTermination(startTime, turnCounter);
|
||||
if (reason) {
|
||||
terminateReason = reason;
|
||||
break;
|
||||
}
|
||||
|
||||
// Check for timeout or external abort.
|
||||
if (combinedSignal.aborted) {
|
||||
// Determine which signal caused the abort.
|
||||
terminateReason = timeoutController.signal.aborted
|
||||
@@ -198,62 +381,78 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
break;
|
||||
}
|
||||
|
||||
const promptId = `${this.agentId}#${turnCounter++}`;
|
||||
|
||||
const { functionCalls } = await promptIdContext.run(
|
||||
promptId,
|
||||
async () =>
|
||||
this.callModel(
|
||||
chat,
|
||||
currentMessage,
|
||||
tools,
|
||||
combinedSignal,
|
||||
promptId,
|
||||
),
|
||||
const turnResult = await this.executeTurn(
|
||||
chat,
|
||||
currentMessage,
|
||||
tools,
|
||||
turnCounter++,
|
||||
combinedSignal,
|
||||
timeoutController.signal,
|
||||
);
|
||||
|
||||
if (combinedSignal.aborted) {
|
||||
terminateReason = timeoutController.signal.aborted
|
||||
? AgentTerminateMode.TIMEOUT
|
||||
: AgentTerminateMode.ABORTED;
|
||||
break;
|
||||
if (turnResult.status === 'stop') {
|
||||
terminateReason = turnResult.terminateReason;
|
||||
// Only set finalResult if the turn provided one (e.g., error or goal).
|
||||
if (turnResult.finalResult) {
|
||||
finalResult = turnResult.finalResult;
|
||||
}
|
||||
break; // Exit the loop for *any* stop reason.
|
||||
}
|
||||
|
||||
// If the model stops calling tools without calling complete_task, it's an error.
|
||||
if (functionCalls.length === 0) {
|
||||
terminateReason = AgentTerminateMode.ERROR;
|
||||
finalResult = `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}' to finalize the session.`;
|
||||
this.emitActivity('ERROR', {
|
||||
error: finalResult,
|
||||
context: 'protocol_violation',
|
||||
});
|
||||
break;
|
||||
}
|
||||
// If status is 'continue', update message for the next loop
|
||||
currentMessage = turnResult.nextMessage;
|
||||
}
|
||||
|
||||
const { nextMessage, submittedOutput, taskCompleted } =
|
||||
await this.processFunctionCalls(
|
||||
functionCalls,
|
||||
combinedSignal,
|
||||
promptId,
|
||||
);
|
||||
// === UNIFIED RECOVERY BLOCK ===
|
||||
// Only attempt recovery if it's a known recoverable reason.
|
||||
// We don't recover from GOAL (already done) or ABORTED (user cancelled).
|
||||
if (
|
||||
terminateReason !== AgentTerminateMode.ERROR &&
|
||||
terminateReason !== AgentTerminateMode.ABORTED &&
|
||||
terminateReason !== AgentTerminateMode.GOAL
|
||||
) {
|
||||
const recoveryResult = await this.executeFinalWarningTurn(
|
||||
chat,
|
||||
tools,
|
||||
turnCounter, // Use current turnCounter for the recovery attempt
|
||||
terminateReason,
|
||||
signal, // Pass the external signal
|
||||
);
|
||||
|
||||
if (taskCompleted) {
|
||||
finalResult = submittedOutput ?? 'Task completed successfully.';
|
||||
if (recoveryResult !== null) {
|
||||
// Recovery Succeeded
|
||||
terminateReason = AgentTerminateMode.GOAL;
|
||||
break;
|
||||
finalResult = recoveryResult;
|
||||
} else {
|
||||
// Recovery Failed. Set the final error message based on the *original* reason.
|
||||
if (terminateReason === AgentTerminateMode.TIMEOUT) {
|
||||
finalResult = `Agent timed out after ${this.definition.runConfig.max_time_minutes} minutes.`;
|
||||
this.emitActivity('ERROR', {
|
||||
error: finalResult,
|
||||
context: 'timeout',
|
||||
});
|
||||
} else if (terminateReason === AgentTerminateMode.MAX_TURNS) {
|
||||
finalResult = `Agent reached max turns limit (${this.definition.runConfig.max_turns}).`;
|
||||
this.emitActivity('ERROR', {
|
||||
error: finalResult,
|
||||
context: 'max_turns',
|
||||
});
|
||||
} else if (
|
||||
terminateReason === AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL
|
||||
) {
|
||||
// The finalResult was already set by executeTurn, but we re-emit just in case.
|
||||
finalResult =
|
||||
finalResult ||
|
||||
`Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}'.`;
|
||||
this.emitActivity('ERROR', {
|
||||
error: finalResult,
|
||||
context: 'protocol_violation',
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
currentMessage = nextMessage;
|
||||
}
|
||||
|
||||
if (terminateReason === AgentTerminateMode.TIMEOUT) {
|
||||
finalResult = `Agent timed out after ${this.definition.runConfig.max_time_minutes} minutes.`;
|
||||
this.emitActivity('ERROR', {
|
||||
error: finalResult,
|
||||
context: 'timeout',
|
||||
});
|
||||
}
|
||||
|
||||
// === FINAL RETURN LOGIC ===
|
||||
if (terminateReason === AgentTerminateMode.GOAL) {
|
||||
return {
|
||||
result: finalResult || 'Task completed.',
|
||||
@@ -275,6 +474,29 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
!signal.aborted // Ensure the external signal was not the cause
|
||||
) {
|
||||
terminateReason = AgentTerminateMode.TIMEOUT;
|
||||
|
||||
// Also use the unified recovery logic here
|
||||
if (chat && tools) {
|
||||
const recoveryResult = await this.executeFinalWarningTurn(
|
||||
chat,
|
||||
tools,
|
||||
turnCounter, // Use current turnCounter
|
||||
AgentTerminateMode.TIMEOUT,
|
||||
signal,
|
||||
);
|
||||
|
||||
if (recoveryResult !== null) {
|
||||
// Recovery Succeeded
|
||||
terminateReason = AgentTerminateMode.GOAL;
|
||||
finalResult = recoveryResult;
|
||||
return {
|
||||
result: finalResult,
|
||||
terminate_reason: terminateReason,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Recovery failed or wasn't possible
|
||||
finalResult = `Agent timed out after ${this.definition.runConfig.max_time_minutes} minutes.`;
|
||||
this.emitActivity('ERROR', {
|
||||
error: finalResult,
|
||||
|
||||
@@ -21,6 +21,7 @@ export enum AgentTerminateMode {
|
||||
GOAL = 'GOAL',
|
||||
MAX_TURNS = 'MAX_TURNS',
|
||||
ABORTED = 'ABORTED',
|
||||
ERROR_NO_COMPLETE_TASK_CALL = 'ERROR_NO_COMPLETE_TASK_CALL',
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user