refactor(agents): implement submit_final_output tool for agent completion (#10377)

This commit is contained in:
Abhi
2025-10-02 14:07:58 -04:00
committed by GitHub
parent 4a70d6f22f
commit a6af7bbb46
4 changed files with 696 additions and 548 deletions

View File

@@ -7,6 +7,7 @@
import type { Config } from '../config/config.js';
import { reportError } from '../utils/errorReporting.js';
import { GeminiChat, StreamEventType } from '../core/geminiChat.js';
import { Type } from '@google/genai';
import type {
Content,
Part,
@@ -39,14 +40,13 @@ import { parseThought } from '../utils/thoughtUtils.js';
/** A callback function to report on agent activity. */
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
/**
 * Name of the completion tool that is always injected into the agent's tool
 * list and always allowed; the agent must call it to finish its run.
 */
const TASK_COMPLETE_TOOL_NAME = 'complete_task';
/**
* Executes an agent loop based on an {@link AgentDefinition}.
*
* This executor uses a simplified two-phase approach:
* 1. **Work Phase:** The agent runs in a loop, calling tools until it has
* gathered all necessary information to fulfill its goal.
* 2. **Extraction Phase:** A final prompt is sent to the model to summarize
* the work and extract the final result in the desired format.
* This executor runs the agent in a loop, calling tools until it calls the
* mandatory `complete_task` tool to signal completion.
*/
export class AgentExecutor {
readonly definition: AgentDefinition;
@@ -144,16 +144,13 @@ export class AgentExecutor {
try {
const chat = await this.createChatObject(inputs);
const tools = this.prepareToolsList();
let terminateReason = AgentTerminateMode.GOAL;
let terminateReason = AgentTerminateMode.ERROR;
let finalResult: string | null = null;
// Phase 1: Work Phase
// The agent works in a loop until it stops calling tools.
const query = this.definition.promptConfig.query
? templateString(this.definition.promptConfig.query, inputs)
: 'Get Started!';
let currentMessages: Content[] = [
{ role: 'user', parts: [{ text: query }] },
];
let currentMessage: Content = { role: 'user', parts: [{ text: query }] };
while (true) {
// Check for termination conditions like max turns or timeout.
@@ -171,7 +168,7 @@ export class AgentExecutor {
const promptId = `${this.runtimeContext.getSessionId()}#${this.agentId}#${turnCounter++}`;
const { functionCalls } = await this.callModel(
chat,
currentMessages,
currentMessage,
tools,
signal,
promptId,
@@ -182,46 +179,39 @@ export class AgentExecutor {
break;
}
// If the model stops calling tools, the work phase is complete.
// If the model stops calling tools without calling complete_task, it's an error.
if (functionCalls.length === 0) {
terminateReason = AgentTerminateMode.ERROR;
finalResult = `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}' to finalize the session.`;
this.emitActivity('ERROR', {
error: finalResult,
context: 'protocol_violation',
});
break;
}
currentMessages = await this.processFunctionCalls(
functionCalls,
signal,
promptId,
);
const { nextMessage, submittedOutput, taskCompleted } =
await this.processFunctionCalls(functionCalls, signal, promptId);
if (taskCompleted) {
finalResult = submittedOutput ?? 'Task completed successfully.';
terminateReason = AgentTerminateMode.GOAL;
break;
}
currentMessage = nextMessage;
}
// If the work phase was terminated early, skip extraction and return.
if (terminateReason !== AgentTerminateMode.GOAL) {
if (terminateReason === AgentTerminateMode.GOAL) {
return {
result: 'Agent execution was terminated before completion.',
result: finalResult || 'Task completed.',
terminate_reason: terminateReason,
};
}
// Phase 2: Extraction Phase
// A final message is sent to summarize findings and produce the output.
const extractionMessage = this.buildExtractionMessage();
const extractionMessages: Content[] = [
{ role: 'user', parts: [{ text: extractionMessage }] },
];
const extractionPromptId = `${this.runtimeContext.getSessionId()}#${this.agentId}#extraction`;
// TODO: Consider if we should keep tools to avoid cache reset.
const { textResponse } = await this.callModel(
chat,
extractionMessages,
[], // No tools are available in the extraction phase.
signal,
extractionPromptId,
);
return {
result: textResponse || 'No response generated',
result:
finalResult || 'Agent execution was terminated before completion.',
terminate_reason: terminateReason,
};
} catch (error) {
@@ -237,13 +227,13 @@ export class AgentExecutor {
*/
private async callModel(
chat: GeminiChat,
messages: Content[],
message: Content,
tools: FunctionDeclaration[],
signal: AbortSignal,
promptId: string,
): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> {
const messageParams = {
message: messages[0]?.parts || [],
message: message.parts || [],
config: {
abortSignal: signal,
tools: tools.length > 0 ? [{ functionDeclarations: tools }] : undefined,
@@ -349,45 +339,145 @@ export class AgentExecutor {
/**
* Executes function calls requested by the model and returns the results.
*
* @returns A new `Content` object to be added to the chat history.
* @returns A new `Content` object for history, any submitted output, and completion status.
*/
private async processFunctionCalls(
functionCalls: FunctionCall[],
signal: AbortSignal,
promptId: string,
): Promise<Content[]> {
): Promise<{
nextMessage: Content;
submittedOutput: string | null;
taskCompleted: boolean;
}> {
const allowedToolNames = new Set(this.toolRegistry.getAllToolNames());
// Always allow the completion tool
allowedToolNames.add(TASK_COMPLETE_TOOL_NAME);
// Filter out any tool calls that are not in the agent's allowed list.
const validatedFunctionCalls = functionCalls.filter((call) => {
if (!allowedToolNames.has(call.name as string)) {
console.warn(
`[AgentExecutor] Agent '${this.definition.name}' attempted to call ` +
`unauthorized tool '${call.name}'. This call has been blocked.`,
);
return false;
let submittedOutput: string | null = null;
let taskCompleted = false;
// We'll collect promises for the tool executions
const toolExecutionPromises: Array<Promise<Part[] | void>> = [];
// And we'll need a place to store the synchronous results (like complete_task or blocked calls)
const syncResponseParts: Part[] = [];
for (const [index, functionCall] of functionCalls.entries()) {
const callId = functionCall.id ?? `${promptId}-${index}`;
const args = (functionCall.args ?? {}) as Record<string, unknown>;
this.emitActivity('TOOL_CALL_START', {
name: functionCall.name,
args,
});
// The completion tool is handled inline: every path through this branch ends
// in `continue`, so it never reaches the regular tool execution below.
// A valid call records the submitted output and marks the task complete; a
// duplicate or malformed call only yields an error response for the model.
if (functionCall.name === TASK_COMPLETE_TOOL_NAME) {
  if (taskCompleted) {
    // We already have a completion from this turn. Ignore subsequent ones.
    const error =
      'Task already marked complete in this turn. Ignoring duplicate call.';
    syncResponseParts.push({
      functionResponse: {
        name: TASK_COMPLETE_TOOL_NAME,
        response: { error },
        id: callId,
      },
    });
    this.emitActivity('ERROR', {
      context: 'tool_call',
      name: functionCall.name,
      error,
    });
    continue;
  }

  const { outputConfig } = this.definition;

  if (outputConfig) {
    const outputName = outputConfig.outputName;
    const outputValue = args[outputName];

    if (outputValue !== undefined) {
      // The output schema is configurable and may describe a non-string
      // value. `String()` would collapse objects to "[object Object]" (and
      // flatten arrays), so serialize anything that is not already a string.
      submittedOutput =
        typeof outputValue === 'string'
          ? outputValue
          : JSON.stringify(outputValue, null, 2);
      taskCompleted = true;
      syncResponseParts.push({
        functionResponse: {
          name: TASK_COMPLETE_TOOL_NAME,
          response: { result: 'Output submitted and task completed.' },
          id: callId,
        },
      });
      this.emitActivity('TOOL_CALL_END', {
        name: functionCall.name,
        output: 'Output submitted and task completed.',
      });
    } else {
      // Failed to provide required output: the task stays incomplete so the
      // model sees the error and can call the tool again with the argument.
      const error = `Missing required argument '${outputName}' for completion.`;
      syncResponseParts.push({
        functionResponse: {
          name: TASK_COMPLETE_TOOL_NAME,
          response: { error },
          id: callId,
        },
      });
      this.emitActivity('ERROR', {
        context: 'tool_call',
        name: functionCall.name,
        error,
      });
    }
  } else {
    // No output expected. Just signal completion.
    taskCompleted = true;
    submittedOutput = 'Task completed successfully.';
    syncResponseParts.push({
      functionResponse: {
        name: TASK_COMPLETE_TOOL_NAME,
        response: { status: 'Task marked complete.' },
        id: callId,
      },
    });
    this.emitActivity('TOOL_CALL_END', {
      name: functionCall.name,
      output: 'Task marked complete.',
    });
  }
  continue;
}
return true;
});
const toolPromises = validatedFunctionCalls.map(
async (functionCall, index) => {
const callId = functionCall.id ?? `${promptId}-${index}`;
const args = functionCall.args ?? {};
// Handle standard tools
if (!allowedToolNames.has(functionCall.name as string)) {
const error = `Unauthorized tool call: '${functionCall.name}' is not available to this agent.`;
this.emitActivity('TOOL_CALL_START', {
name: functionCall.name,
args,
console.warn(`[AgentExecutor] Blocked call: ${error}`);
syncResponseParts.push({
functionResponse: {
name: functionCall.name as string,
id: callId,
response: { error },
},
});
const requestInfo: ToolCallRequestInfo = {
this.emitActivity('ERROR', {
context: 'tool_call_unauthorized',
name: functionCall.name,
callId,
name: functionCall.name as string,
args: args as Record<string, unknown>,
isClientInitiated: true,
prompt_id: promptId,
};
error,
});
continue;
}
const requestInfo: ToolCallRequestInfo = {
callId,
name: functionCall.name as string,
args,
isClientInitiated: true,
prompt_id: promptId,
};
// Create a promise for the tool execution
const executionPromise = (async () => {
const toolResponse = await executeToolCall(
this.runtimeContext,
requestInfo,
@@ -407,24 +497,39 @@ export class AgentExecutor {
});
}
return toolResponse;
},
);
return toolResponse.responseParts;
})();
const toolResponses = await Promise.all(toolPromises);
const toolResponseParts: Part[] = toolResponses
.flatMap((response) => response.responseParts)
.filter((part): part is Part => part !== undefined);
toolExecutionPromises.push(executionPromise);
}
// If all authorized tool calls failed, provide a generic error message
// to the model so it can try a different approach.
if (functionCalls.length > 0 && toolResponseParts.length === 0) {
// Wait for all tool executions to complete
const asyncResults = await Promise.all(toolExecutionPromises);
// Combine all response parts
const toolResponseParts: Part[] = [...syncResponseParts];
for (const result of asyncResults) {
if (result) {
toolResponseParts.push(...result);
}
}
// If all authorized tool calls failed (and task isn't complete), provide a generic error.
if (
functionCalls.length > 0 &&
toolResponseParts.length === 0 &&
!taskCompleted
) {
toolResponseParts.push({
text: 'All tool calls failed. Please analyze the errors and try an alternative approach.',
text: 'All tool calls failed or were unauthorized. Please analyze the errors and try an alternative approach.',
});
}
return [{ role: 'user', parts: toolResponseParts }];
return {
nextMessage: { role: 'user', parts: toolResponseParts },
submittedOutput,
taskCompleted,
};
}
/**
@@ -432,7 +537,7 @@ export class AgentExecutor {
*/
private prepareToolsList(): FunctionDeclaration[] {
const toolsList: FunctionDeclaration[] = [];
const { toolConfig } = this.definition;
const { toolConfig, outputConfig } = this.definition;
if (toolConfig) {
const toolNamesToLoad: string[] = [];
@@ -453,12 +558,36 @@ export class AgentExecutor {
);
}
// The completion tool is appended unconditionally. When the definition
// declares an output, that output becomes the tool's single required
// argument; otherwise the tool takes no arguments.
const completionDescription = outputConfig
  ? 'Call this tool to submit your final answer and complete the task. This is the ONLY way to finish.'
  : 'Call this tool to signal that you have completed your task. This is the ONLY way to finish.';

const completeTool: FunctionDeclaration = {
  name: TASK_COMPLETE_TOOL_NAME,
  description: completionDescription,
  parameters: {
    type: Type.OBJECT,
    properties: outputConfig
      ? {
          [outputConfig.outputName]: {
            description: outputConfig.description,
            // Use the configured schema if present, else a plain string.
            ...(outputConfig.schema ?? { type: Type.STRING }),
          },
        }
      : {},
    required: outputConfig ? [outputConfig.outputName] : [],
  },
};

toolsList.push(completeTool);
return toolsList;
}
/** Builds the system prompt from the agent definition and inputs. */
private async buildSystemPrompt(inputs: AgentInputs): Promise<string> {
const { promptConfig, outputConfig } = this.definition;
const { promptConfig } = this.definition;
if (!promptConfig.systemPrompt) {
return '';
}
@@ -470,45 +599,21 @@ export class AgentExecutor {
const dirContext = await getDirectoryContextString(this.runtimeContext);
finalPrompt += `\n\n# Environment Context\n${dirContext}`;
// Append completion criteria to guide the model's output.
if (outputConfig?.completion_criteria) {
finalPrompt += '\n\nEnsure you complete the following:\n';
for (const criteria of outputConfig.completion_criteria) {
finalPrompt += `- ${criteria}\n`;
}
}
// Append standard rules for non-interactive execution.
finalPrompt += `
Important Rules:
* You are running in a non-interactive mode. You CANNOT ask the user for input or clarification.
* Work systematically using available tools to complete your task.
* Always use absolute paths for file operations. Construct them using the provided "Environment Context".
* When you have completed your analysis and are ready to produce the final answer, stop calling tools.`;
* Always use absolute paths for file operations. Construct them using the provided "Environment Context".`;
finalPrompt += `
* When you have completed your task, you MUST call the \`${TASK_COMPLETE_TOOL_NAME}\` tool.
* Do not call any other tools in the same turn as \`${TASK_COMPLETE_TOOL_NAME}\`.
* This is the ONLY way to complete your mission. If you stop calling tools without calling this, you have failed.`;
return finalPrompt;
}
/**
 * Produces the prompt text used for the extraction phase.
 *
 * @returns A request for the configured output description, followed by a
 *   bullet list of completion criteria when any are defined; a generic
 *   summary request when the definition has no output description.
 */
private buildExtractionMessage(): string {
  const { outputConfig } = this.definition;

  // Fallback to a generic extraction message if no description is provided.
  if (!outputConfig?.description) {
    return 'Based on your work so far, provide a comprehensive summary of your analysis and findings. Do not perform any more function calls.';
  }

  const request = `Based on your work so far, provide: ${outputConfig.description}`;
  const criteriaItems = outputConfig.completion_criteria ?? [];
  if (criteriaItems.length === 0) {
    return request;
  }

  // One "- item" line per criterion, each terminated by a newline.
  const bulletLines = criteriaItems.map((item) => `- ${item}\n`).join('');
  return request + `\n\nBe sure you have addressed:\n` + bulletLines;
}
/**
* Applies template strings to initial messages.
*