mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-13 21:07:00 -07:00
refactor(core): unify ReAct loop in AgentHarness using Behavioral architecture
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
|
||||
import { type Config } from '../config/config.js';
|
||||
import { AgentHarness, type AgentHarnessOptions } from './harness.js';
|
||||
import { type AgentDefinition } from './types.js';
|
||||
import { type AgentDefinition, type LocalAgentDefinition } from './types.js';
|
||||
import { MainAgentBehavior, SubagentBehavior } from './behavior.js';
|
||||
|
||||
/**
|
||||
@@ -19,15 +19,18 @@ export class AgentFactory {
|
||||
definition?: AgentDefinition,
|
||||
options: Partial<AgentHarnessOptions> = {},
|
||||
): AgentHarness {
|
||||
const behavior = definition
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
? new SubagentBehavior(
|
||||
config,
|
||||
definition as any,
|
||||
options.inputs,
|
||||
options.parentPromptId
|
||||
)
|
||||
: new MainAgentBehavior(config, options.parentPromptId);
|
||||
let behavior;
|
||||
if (definition && definition.kind === 'local') {
|
||||
const localDef: LocalAgentDefinition = definition;
|
||||
behavior = new SubagentBehavior(
|
||||
config,
|
||||
localDef,
|
||||
options.inputs,
|
||||
options.parentPromptId,
|
||||
);
|
||||
} else {
|
||||
behavior = new MainAgentBehavior(config, options.parentPromptId);
|
||||
}
|
||||
|
||||
return new AgentHarness({
|
||||
config,
|
||||
|
||||
@@ -11,14 +11,21 @@ import {
|
||||
Type,
|
||||
} from '@google/genai';
|
||||
import { type Config } from '../config/config.js';
|
||||
import { type Turn, type ServerGeminiStreamEvent } from '../core/turn.js';
|
||||
import {
|
||||
import {
|
||||
type Turn,
|
||||
type ServerGeminiStreamEvent,
|
||||
GeminiEventType,
|
||||
} from '../core/turn.js';
|
||||
import {
|
||||
AgentTerminateMode,
|
||||
type LocalAgentDefinition,
|
||||
type AgentInputs
|
||||
type LocalAgentDefinition,
|
||||
type AgentInputs,
|
||||
} from './types.js';
|
||||
import { getCoreSystemPrompt } from '../core/prompts.js';
|
||||
import { getInitialChatHistory, getDirectoryContextString } from '../utils/environmentContext.js';
|
||||
import {
|
||||
getInitialChatHistory,
|
||||
getDirectoryContextString,
|
||||
} from '../utils/environmentContext.js';
|
||||
import { templateString } from './utils.js';
|
||||
import { getVersion } from '../utils/version.js';
|
||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
@@ -41,7 +48,7 @@ const GRACE_PERIOD_MS = 60 * 1000;
|
||||
export interface AgentBehavior {
|
||||
/** The unique ID for this agent instance. */
|
||||
readonly agentId: string;
|
||||
|
||||
|
||||
/** The human-readable name of the agent. */
|
||||
readonly name: string;
|
||||
|
||||
@@ -54,7 +61,7 @@ export interface AgentBehavior {
|
||||
/** Returns the initial chat history. */
|
||||
getInitialHistory(): Promise<Content[]>;
|
||||
|
||||
/**
|
||||
/**
|
||||
* Prepares the tools list for the current turn.
|
||||
* @param baseTools The tools from the tool registry.
|
||||
*/
|
||||
@@ -68,12 +75,29 @@ export interface AgentBehavior {
|
||||
/**
|
||||
* Fires the "Before Agent" hooks if applicable.
|
||||
*/
|
||||
fireBeforeAgent(request: Part[]): Promise<{ stop?: boolean; reason?: string; systemMessage?: string; additionalContext?: string }>;
|
||||
fireBeforeAgent(
|
||||
request: Part[],
|
||||
): Promise<{
|
||||
stop?: boolean;
|
||||
reason?: string;
|
||||
systemMessage?: string;
|
||||
additionalContext?: string;
|
||||
}>;
|
||||
|
||||
/**
|
||||
* Fires the "After Agent" hooks if applicable.
|
||||
*/
|
||||
fireAfterAgent(request: Part[], response: string, turn: Turn): Promise<{ stop?: boolean; reason?: string; systemMessage?: string; contextCleared?: boolean; shouldContinue?: boolean }>;
|
||||
fireAfterAgent(
|
||||
request: Part[],
|
||||
response: string,
|
||||
turn: Turn,
|
||||
): Promise<{
|
||||
stop?: boolean;
|
||||
reason?: string;
|
||||
systemMessage?: string;
|
||||
contextCleared?: boolean;
|
||||
shouldContinue?: boolean;
|
||||
}>;
|
||||
|
||||
/**
|
||||
* Transforms the initial request if needed (e.g. subagent 'Start' templating).
|
||||
@@ -90,18 +114,29 @@ export interface AgentBehavior {
|
||||
* Checks if the agent should continue executing after a model turn with no tool calls.
|
||||
* (e.g., Main agent running next_speaker check)
|
||||
*/
|
||||
getContinuationRequest(turn: Turn, signal: AbortSignal): Promise<Part[] | null>;
|
||||
getContinuationRequest(
|
||||
turn: Turn,
|
||||
signal: AbortSignal,
|
||||
): Promise<Part[] | null>;
|
||||
|
||||
/**
|
||||
* Attempts to recover from a termination state (e.g., Subagent "Final Warning").
|
||||
* Returns a stream of events if recovery is attempted.
|
||||
*/
|
||||
executeRecovery(turn: Turn, reason: AgentTerminateMode, signal: AbortSignal): AsyncGenerator<ServerGeminiStreamEvent, boolean>;
|
||||
executeRecovery(
|
||||
turn: Turn,
|
||||
reason: AgentTerminateMode,
|
||||
signal: AbortSignal,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, boolean>;
|
||||
|
||||
/**
|
||||
* Returns a final failure message for a given termination reason.
|
||||
*/
|
||||
getFinalFailureMessage(reason: AgentTerminateMode, maxTurns: number, maxTime: number): string;
|
||||
getFinalFailureMessage(
|
||||
reason: AgentTerminateMode,
|
||||
maxTurns: number,
|
||||
maxTime: number,
|
||||
): string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -113,7 +148,10 @@ export class MainAgentBehavior implements AgentBehavior {
|
||||
private lastSentIdeContext: IdeContext | undefined;
|
||||
private forceFullIdeContext = true;
|
||||
|
||||
constructor(private readonly config: Config, parentPromptId?: string) {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
parentPromptId?: string,
|
||||
) {
|
||||
const randomIdPart = Math.random().toString(36).slice(2, 8);
|
||||
const parentPrefix = parentPromptId ? `${parentPromptId}-` : '';
|
||||
this.agentId = `${parentPrefix}main-${randomIdPart}`;
|
||||
@@ -137,9 +175,12 @@ export class MainAgentBehavior implements AgentBehavior {
|
||||
async syncEnvironment(history: Content[]) {
|
||||
if (!this.config.getIdeMode()) return {};
|
||||
|
||||
const lastMessage = history.length > 0 ? history[history.length - 1] : undefined;
|
||||
const hasPendingToolCall = !!lastMessage && lastMessage.role === 'model' &&
|
||||
(lastMessage.parts?.some(p => 'functionCall' in p) || false);
|
||||
const lastMessage =
|
||||
history.length > 0 ? history[history.length - 1] : undefined;
|
||||
const hasPendingToolCall =
|
||||
!!lastMessage &&
|
||||
lastMessage.role === 'model' &&
|
||||
(lastMessage.parts?.some((p) => 'functionCall' in p) || false);
|
||||
|
||||
if (hasPendingToolCall) return {};
|
||||
|
||||
@@ -147,10 +188,17 @@ export class MainAgentBehavior implements AgentBehavior {
|
||||
if (!currentIdeContext) return {};
|
||||
|
||||
let contextParts: string[] = [];
|
||||
if (this.forceFullIdeContext || this.lastSentIdeContext === undefined || history.length === 0) {
|
||||
if (
|
||||
this.forceFullIdeContext ||
|
||||
this.lastSentIdeContext === undefined ||
|
||||
history.length === 0
|
||||
) {
|
||||
contextParts = this.getFullIdeContextParts(currentIdeContext);
|
||||
} else {
|
||||
contextParts = this.getDeltaIdeContextParts(currentIdeContext, this.lastSentIdeContext);
|
||||
contextParts = this.getDeltaIdeContextParts(
|
||||
currentIdeContext,
|
||||
this.lastSentIdeContext,
|
||||
);
|
||||
}
|
||||
|
||||
if (contextParts.length > 0) {
|
||||
@@ -164,11 +212,12 @@ export class MainAgentBehavior implements AgentBehavior {
|
||||
|
||||
private getFullIdeContextParts(context: IdeContext): string[] {
|
||||
const openFiles = context.workspaceState?.openFiles || [];
|
||||
const activeFile = openFiles.find(f => f.isActive);
|
||||
const otherOpenFiles = openFiles.filter(f => !f.isActive).map(f => f.path);
|
||||
const activeFile = openFiles.find((f) => f.isActive);
|
||||
const otherOpenFiles = openFiles
|
||||
.filter((f) => !f.isActive)
|
||||
.map((f) => f.path);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const contextData: Record<string, any> = {};
|
||||
const contextData: Record<string, unknown> = {};
|
||||
if (activeFile) {
|
||||
contextData['activeFile'] = {
|
||||
path: activeFile.path,
|
||||
@@ -176,7 +225,8 @@ export class MainAgentBehavior implements AgentBehavior {
|
||||
selectedText: activeFile.selectedText || undefined,
|
||||
};
|
||||
}
|
||||
if (otherOpenFiles.length > 0) contextData['otherOpenFiles'] = otherOpenFiles;
|
||||
if (otherOpenFiles.length > 0)
|
||||
contextData['otherOpenFiles'] = otherOpenFiles;
|
||||
|
||||
if (Object.keys(contextData).length === 0) return [];
|
||||
|
||||
@@ -188,10 +238,12 @@ export class MainAgentBehavior implements AgentBehavior {
|
||||
];
|
||||
}
|
||||
|
||||
private getDeltaIdeContextParts(_current: IdeContext, _last: IdeContext): string[] {
|
||||
private getDeltaIdeContextParts(
|
||||
_current: IdeContext,
|
||||
_last: IdeContext,
|
||||
): string[] {
|
||||
// Simplified delta logic for now, similar to GeminiClient
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const changes: Record<string, any> = {};
|
||||
const changes: Record<string, unknown> = {};
|
||||
// ... delta logic ...
|
||||
if (Object.keys(changes).length === 0) return [];
|
||||
|
||||
@@ -205,7 +257,9 @@ export class MainAgentBehavior implements AgentBehavior {
|
||||
|
||||
async fireBeforeAgent(request: Part[]) {
|
||||
if (!this.config.getEnableHooks()) return {};
|
||||
const hookOutput = await this.config.getHookSystem()?.fireBeforeAgentEvent(partToString(request));
|
||||
const hookOutput = await this.config
|
||||
.getHookSystem()
|
||||
?.fireBeforeAgentEvent(partToString(request));
|
||||
if (!hookOutput) return {};
|
||||
|
||||
return {
|
||||
@@ -220,7 +274,9 @@ export class MainAgentBehavior implements AgentBehavior {
|
||||
if (!this.config.getEnableHooks()) return {};
|
||||
if (turn.pendingToolCalls.length > 0) return {};
|
||||
|
||||
const hookOutput = await this.config.getHookSystem()?.fireAfterAgentEvent(partToString(request), response);
|
||||
const hookOutput = await this.config
|
||||
.getHookSystem()
|
||||
?.fireAfterAgentEvent(partToString(request), response);
|
||||
if (!hookOutput) return {};
|
||||
|
||||
return {
|
||||
@@ -242,8 +298,7 @@ export class MainAgentBehavior implements AgentBehavior {
|
||||
|
||||
async getContinuationRequest(turn: Turn, signal: AbortSignal) {
|
||||
const nextSpeaker = await checkNextSpeaker(
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
(turn as any).chat,
|
||||
turn.chat,
|
||||
this.config.getBaseLlmClient(),
|
||||
signal,
|
||||
this.agentId,
|
||||
@@ -254,9 +309,8 @@ export class MainAgentBehavior implements AgentBehavior {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
 * Recovery hook for the main agent.
 *
 * This implementation attempts no recovery: the generator completes
 * with `false` without emitting any event in practice.
 *
 * @returns `false`, signalling that recovery did not succeed.
 */
async *executeRecovery(): AsyncGenerator<ServerGeminiStreamEvent, boolean> {
  // NOTE(review): the constructor builds agentId as `<prefix>main-<random>`,
  // so this comparison is not expected to be true. Presumably the guarded
  // yield exists only so the generator body contains a `yield` — confirm.
  if (this.agentId === 'never') yield { type: GeminiEventType.Retry };
  return false;
}
|
||||
|
||||
@@ -276,7 +330,7 @@ export class SubagentBehavior implements AgentBehavior {
|
||||
private readonly config: Config,
|
||||
private readonly definition: LocalAgentDefinition,
|
||||
private readonly inputs?: AgentInputs,
|
||||
parentPromptId?: string
|
||||
parentPromptId?: string,
|
||||
) {
|
||||
this.name = definition.name;
|
||||
const randomIdPart = Math.random().toString(36).slice(2, 8);
|
||||
@@ -293,7 +347,10 @@ export class SubagentBehavior implements AgentBehavior {
|
||||
activeModel: this.config.getActiveModel(),
|
||||
today: new Date().toLocaleDateString(),
|
||||
};
|
||||
let prompt = templateString(this.definition.promptConfig.systemPrompt || '', augmentedInputs);
|
||||
let prompt = templateString(
|
||||
this.definition.promptConfig.systemPrompt || '',
|
||||
augmentedInputs,
|
||||
);
|
||||
const dirContext = await getDirectoryContextString(this.config);
|
||||
prompt += `\n\n# Environment Context\n${dirContext}`;
|
||||
prompt += `\n\nImportant Rules:\n* You are running in a non-interactive mode. You CANNOT ask the user for input or clarification.\n* Work systematically using available tools to complete your task.\n* Always use absolute paths for file operations.`;
|
||||
@@ -322,15 +379,24 @@ export class SubagentBehavior implements AgentBehavior {
|
||||
prepareTools(baseTools: FunctionDeclaration[]) {
|
||||
const completeTool: FunctionDeclaration = {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
description: 'Call this tool to submit your final answer and complete the task.',
|
||||
description:
|
||||
'Call this tool to submit your final answer and complete the task.',
|
||||
parameters: { type: Type.OBJECT, properties: {}, required: [] },
|
||||
};
|
||||
|
||||
if (this.definition.outputConfig) {
|
||||
const schema = zodToJsonSchema(this.definition.outputConfig.schema);
|
||||
const { $schema: _, definitions: __, ...cleanSchema } = schema as Record<string, unknown>;
|
||||
completeTool.parameters!.properties![this.definition.outputConfig.outputName] = cleanSchema as Schema;
|
||||
completeTool.parameters!.required!.push(this.definition.outputConfig.outputName);
|
||||
const {
|
||||
$schema: _,
|
||||
definitions: __,
|
||||
...cleanSchema
|
||||
} = schema as Record<string, unknown>;
|
||||
completeTool.parameters!.properties![
|
||||
this.definition.outputConfig.outputName
|
||||
] = cleanSchema as Schema;
|
||||
completeTool.parameters!.required!.push(
|
||||
this.definition.outputConfig.outputName,
|
||||
);
|
||||
} else {
|
||||
completeTool.parameters!.properties!['result'] = {
|
||||
type: Type.STRING,
|
||||
@@ -355,11 +421,18 @@ export class SubagentBehavior implements AgentBehavior {
|
||||
}
|
||||
|
||||
async transformRequest(request: Part[]): Promise<Part[]> {
|
||||
if (request.length === 1 && 'text' in request[0] && request[0].text === 'Start') {
|
||||
if (
|
||||
request.length === 1 &&
|
||||
'text' in request[0] &&
|
||||
request[0].text === 'Start'
|
||||
) {
|
||||
return [
|
||||
{
|
||||
text: this.definition.promptConfig.query
|
||||
? templateString(this.definition.promptConfig.query, this.inputs || {})
|
||||
? templateString(
|
||||
this.definition.promptConfig.query,
|
||||
this.inputs || {},
|
||||
)
|
||||
: 'Get Started!',
|
||||
},
|
||||
];
|
||||
@@ -368,7 +441,9 @@ export class SubagentBehavior implements AgentBehavior {
|
||||
}
|
||||
|
||||
isGoalReached(toolResults: Array<{ name: string; part: Part }>) {
|
||||
const completeCall = toolResults.find((r) => r.name === TASK_COMPLETE_TOOL_NAME);
|
||||
const completeCall = toolResults.find(
|
||||
(r) => r.name === TASK_COMPLETE_TOOL_NAME,
|
||||
);
|
||||
if (completeCall) {
|
||||
// If there's an error in the call, we don't treat it as reached (model should retry)
|
||||
return !completeCall.part.functionResponse?.response?.['error'];
|
||||
@@ -380,17 +455,33 @@ export class SubagentBehavior implements AgentBehavior {
|
||||
return null;
|
||||
}
|
||||
|
||||
async *executeRecovery(turn: Turn, reason: AgentTerminateMode, signal: AbortSignal): AsyncGenerator<ServerGeminiStreamEvent, boolean> {
|
||||
async *executeRecovery(
|
||||
turn: Turn,
|
||||
reason: AgentTerminateMode,
|
||||
signal: AbortSignal,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, boolean> {
|
||||
const recoveryStartTime = Date.now();
|
||||
let success = false;
|
||||
const graceTimeoutController = new DeadlineTimer(GRACE_PERIOD_MS, 'Grace period timed out.');
|
||||
const combinedSignal = AbortSignal.any([signal, graceTimeoutController.signal]);
|
||||
const graceTimeoutController = new DeadlineTimer(
|
||||
GRACE_PERIOD_MS,
|
||||
'Grace period timed out.',
|
||||
);
|
||||
const combinedSignal = AbortSignal.any([
|
||||
signal,
|
||||
graceTimeoutController.signal,
|
||||
]);
|
||||
|
||||
try {
|
||||
const recoveryMessage: Part[] = [{ text: this.getFinalWarningMessage(reason) }];
|
||||
const recoveryMessage: Part[] = [
|
||||
{ text: this.getFinalWarningMessage(reason) },
|
||||
];
|
||||
const promptId = `${this.agentId}#recovery`;
|
||||
const recoveryStream = promptIdContext.run(promptId, () =>
|
||||
turn.run({ model: this.config.getActiveModel() }, recoveryMessage, combinedSignal),
|
||||
turn.run(
|
||||
{ model: this.config.getActiveModel() },
|
||||
recoveryMessage,
|
||||
combinedSignal,
|
||||
),
|
||||
);
|
||||
|
||||
for await (const event of recoveryStream) {
|
||||
@@ -399,13 +490,25 @@ export class SubagentBehavior implements AgentBehavior {
|
||||
|
||||
// Check if they called complete_task in the recovery turn
|
||||
if (turn.pendingToolCalls.length > 0) {
|
||||
if (turn.pendingToolCalls.some(c => c.name === TASK_COMPLETE_TOOL_NAME)) {
|
||||
if (
|
||||
turn.pendingToolCalls.some((c) => c.name === TASK_COMPLETE_TOOL_NAME)
|
||||
) {
|
||||
success = true;
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
graceTimeoutController.abort();
|
||||
logRecoveryAttempt(this.config, new RecoveryAttemptEvent(this.agentId, this.name, reason, Date.now() - recoveryStartTime, success, 0));
|
||||
logRecoveryAttempt(
|
||||
this.config,
|
||||
new RecoveryAttemptEvent(
|
||||
this.agentId,
|
||||
this.name,
|
||||
reason,
|
||||
Date.now() - recoveryStartTime,
|
||||
success,
|
||||
0,
|
||||
),
|
||||
);
|
||||
}
|
||||
return success;
|
||||
}
|
||||
@@ -413,20 +516,35 @@ export class SubagentBehavior implements AgentBehavior {
|
||||
/**
 * Builds the "final warning" text used as the recovery-turn message.
 *
 * The message starts with a short explanation chosen from the
 * termination reason, followed by a fixed instruction telling the model
 * to call the completion tool immediately and not call any other tools.
 *
 * @param reason Why the agent run is being terminated.
 * @returns The full warning message.
 */
private getFinalWarningMessage(reason: AgentTerminateMode): string {
  let explanation = '';
  switch (reason) {
    case AgentTerminateMode.TIMEOUT:
      explanation = 'You have exceeded the time limit.';
      break;
    case AgentTerminateMode.MAX_TURNS:
      explanation = 'You have exceeded the maximum number of turns.';
      break;
    case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL:
      explanation = 'You have stopped calling tools without finishing.';
      break;
    default:
      // Any other termination reason gets a generic explanation.
      explanation = 'Execution was interrupted.';
  }
  return `${explanation} You have one final chance to complete the task with a short grace period. You MUST call \`${TASK_COMPLETE_TOOL_NAME}\` immediately with your best answer and explain that your investigation was interrupted. Do not call any other tools.`;
}
|
||||
|
||||
/**
 * Returns the final failure message for a given termination reason.
 *
 * @param reason Why the agent run was terminated.
 * @param maxTurns Turn limit, interpolated into the MAX_TURNS message.
 * @param maxTime Time limit, reported in the TIMEOUT message as minutes.
 * @returns A one-sentence description of the failure; reasons without a
 *   dedicated case fall back to a generic message.
 */
getFinalFailureMessage(
  reason: AgentTerminateMode,
  maxTurns: number,
  maxTime: number,
) {
  switch (reason) {
    case AgentTerminateMode.TIMEOUT:
      return `Agent timed out after ${maxTime} minutes.`;
    case AgentTerminateMode.MAX_TURNS:
      return `Agent reached max turns limit (${maxTurns}).`;
    case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL:
      return `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}'.`;
    default:
      return 'Agent execution was terminated before completion.';
  }
}
|
||||
}
|
||||
|
||||
@@ -10,17 +10,15 @@ import { makeFakeConfig } from '../test-utils/config.js';
|
||||
import { GeminiChat, StreamEventType } from '../core/geminiChat.js';
|
||||
import { GeminiEventType, type ServerGeminiStreamEvent } from '../core/turn.js';
|
||||
import { z } from 'zod';
|
||||
import {
|
||||
AgentTerminateMode,
|
||||
type LocalAgentDefinition,
|
||||
} from './types.js';
|
||||
import { AgentTerminateMode, type LocalAgentDefinition } from './types.js';
|
||||
import { scheduleAgentTools } from './agent-scheduler.js';
|
||||
import { logAgentFinish } from '../telemetry/loggers.js';
|
||||
import { type Config } from '../config/config.js';
|
||||
import { MainAgentBehavior, SubagentBehavior } from './behavior.js';
|
||||
|
||||
vi.mock('../telemetry/loggers.js', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('../telemetry/loggers.js')>();
|
||||
const actual =
|
||||
await importOriginal<typeof import('../telemetry/loggers.js')>();
|
||||
return {
|
||||
...actual,
|
||||
logAgentStart: vi.fn(),
|
||||
@@ -59,32 +57,38 @@ describe('AgentHarness', () => {
|
||||
mockConfig.getIdeMode = vi.fn().mockReturnValue(false);
|
||||
mockConfig.getBaseLlmClient = vi.fn().mockReturnValue({});
|
||||
mockConfig.getModelRouterService = vi.fn().mockReturnValue({
|
||||
route: vi.fn().mockResolvedValue({ model: 'gemini-test-model', metadata: { source: 'test' } }),
|
||||
route: vi
|
||||
.fn()
|
||||
.mockResolvedValue({
|
||||
model: 'gemini-test-model',
|
||||
metadata: { source: 'test' },
|
||||
}),
|
||||
});
|
||||
|
||||
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('SubagentBehavior', () => {
|
||||
it('executes a subagent and finishes when complete_task is called', async () => {
|
||||
const definition: LocalAgentDefinition<z.ZodString> = {
|
||||
const definition: LocalAgentDefinition<z.ZodUnknown> = {
|
||||
kind: 'local',
|
||||
name: 'test-agent',
|
||||
displayName: 'Test Agent',
|
||||
description: 'A test agent',
|
||||
inputConfig: { inputSchema: { type: 'object', properties: {}, required: [] } },
|
||||
inputConfig: {
|
||||
inputSchema: { type: 'object', properties: {}, required: [] },
|
||||
},
|
||||
modelConfig: { model: 'gemini-test-model' },
|
||||
runConfig: { maxTurns: 5, maxTimeMinutes: 5 },
|
||||
promptConfig: { systemPrompt: 'You are a test agent.' },
|
||||
outputConfig: {
|
||||
outputName: 'result',
|
||||
description: 'The final result.',
|
||||
schema: z.string(),
|
||||
schema: z.unknown(),
|
||||
},
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const behavior = new SubagentBehavior(mockConfig, definition as any);
|
||||
const behavior = new SubagentBehavior(mockConfig, definition);
|
||||
const harness = new AgentHarness({
|
||||
config: mockConfig,
|
||||
behavior,
|
||||
@@ -108,8 +112,19 @@ describe('AgentHarness', () => {
|
||||
yield {
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [{ content: { parts: [{ text: 'Done!' }] }, finishReason: 'STOP' }],
|
||||
functionCalls: [{ name: 'complete_task', args: { result: 'Success' }, id: 'call_1' }],
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Done!' }] },
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
functionCalls: [
|
||||
{
|
||||
name: 'complete_task',
|
||||
args: { result: 'Success' },
|
||||
id: 'call_1',
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
})(),
|
||||
@@ -118,18 +133,31 @@ describe('AgentHarness', () => {
|
||||
// Mock tool execution
|
||||
(scheduleAgentTools as unknown as Mock).mockResolvedValue([
|
||||
{
|
||||
request: { name: 'complete_task', args: { result: 'Success' }, callId: 'call_1' },
|
||||
request: {
|
||||
name: 'complete_task',
|
||||
args: { result: 'Success' },
|
||||
callId: 'call_1',
|
||||
},
|
||||
status: 'success',
|
||||
response: {
|
||||
responseParts: [{
|
||||
functionResponse: { name: 'complete_task', response: { status: 'OK' }, id: 'call_1' },
|
||||
}],
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'complete_task',
|
||||
response: { status: 'OK' },
|
||||
id: 'call_1',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
const events: ServerGeminiStreamEvent[] = [];
|
||||
const run = harness.run([{ text: 'Start' }], new AbortController().signal);
|
||||
const run = harness.run(
|
||||
[{ text: 'Start' }],
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
while (true) {
|
||||
const { value, done } = await run.next();
|
||||
@@ -137,7 +165,13 @@ describe('AgentHarness', () => {
|
||||
events.push(value);
|
||||
}
|
||||
|
||||
expect(events.some(e => e.type === GeminiEventType.ToolCallRequest && e.value.name === 'complete_task')).toBe(true);
|
||||
expect(
|
||||
events.some(
|
||||
(e) =>
|
||||
e.type === GeminiEventType.ToolCallRequest &&
|
||||
e.value.name === 'complete_task',
|
||||
),
|
||||
).toBe(true);
|
||||
expect(vi.mocked(logAgentFinish)).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({ terminate_reason: AgentTerminateMode.GOAL }),
|
||||
@@ -163,7 +197,10 @@ describe('AgentHarness', () => {
|
||||
mockConfig.getEnableHooks = vi.fn().mockReturnValue(true);
|
||||
|
||||
const events: ServerGeminiStreamEvent[] = [];
|
||||
const run = harness.run([{ text: 'Hello' }], new AbortController().signal);
|
||||
const run = harness.run(
|
||||
[{ text: 'Hello' }],
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
while (true) {
|
||||
const { value, done } = await run.next();
|
||||
@@ -171,42 +208,63 @@ describe('AgentHarness', () => {
|
||||
events.push(value);
|
||||
}
|
||||
|
||||
expect(events.some(e => e.type === GeminiEventType.Error && e.value.error.message === 'Access denied')).toBe(true);
|
||||
expect(
|
||||
events.some(
|
||||
(e) =>
|
||||
e.type === GeminiEventType.Error &&
|
||||
e.value.error.message === 'Access denied',
|
||||
),
|
||||
).toBe(true);
|
||||
expect(vi.mocked(logAgentFinish)).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({ terminate_reason: AgentTerminateMode.ABORTED }),
|
||||
expect.objectContaining({
|
||||
terminate_reason: AgentTerminateMode.ABORTED,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('syncs IDE context when IDE mode is enabled', async () => {
|
||||
const behavior = new MainAgentBehavior(mockConfig);
|
||||
const harness = new AgentHarness({ config: mockConfig, behavior });
|
||||
const behavior = new MainAgentBehavior(mockConfig);
|
||||
const harness = new AgentHarness({ config: mockConfig, behavior });
|
||||
|
||||
mockConfig.getIdeMode = vi.fn().mockReturnValue(true);
|
||||
|
||||
const mockChat = {
|
||||
sendMessageStream: vi.fn().mockResolvedValue((async function* () {
|
||||
yield { type: StreamEventType.CHUNK, value: { candidates: [{ content: { parts: [{ text: 'Response' }] }, finishReason: 'STOP' }] } };
|
||||
})()),
|
||||
setTools: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
addHistory: vi.fn(),
|
||||
setSystemInstruction: vi.fn(),
|
||||
getLastPromptTokenCount: vi.fn().mockReturnValue(0),
|
||||
} as unknown as GeminiChat;
|
||||
(GeminiChat as unknown as Mock).mockReturnValue(mockChat);
|
||||
mockConfig.getIdeMode = vi.fn().mockReturnValue(true);
|
||||
|
||||
// We can't easily mock ideContextStore.get() if it's not exported as a mockable object easily,
|
||||
// but we can at least verify that syncEnvironment is called by harness.
|
||||
const syncSpy = vi.spyOn(behavior, 'syncEnvironment');
|
||||
const mockChat = {
|
||||
sendMessageStream: vi.fn().mockResolvedValue(
|
||||
(async function* () {
|
||||
yield {
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Response' }] },
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
})(),
|
||||
),
|
||||
setTools: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
addHistory: vi.fn(),
|
||||
setSystemInstruction: vi.fn(),
|
||||
getLastPromptTokenCount: vi.fn().mockReturnValue(0),
|
||||
} as unknown as GeminiChat;
|
||||
(GeminiChat as unknown as Mock).mockReturnValue(mockChat);
|
||||
|
||||
const run = harness.run([{ text: 'Hello' }], new AbortController().signal);
|
||||
while (true) {
|
||||
const { done } = await run.next();
|
||||
if (done) break;
|
||||
}
|
||||
const syncSpy = vi.spyOn(behavior, 'syncEnvironment');
|
||||
|
||||
expect(syncSpy).toHaveBeenCalled();
|
||||
const run = harness.run(
|
||||
[{ text: 'Hello' }],
|
||||
new AbortController().signal,
|
||||
);
|
||||
while (true) {
|
||||
const { done } = await run.next();
|
||||
if (done) break;
|
||||
}
|
||||
|
||||
expect(syncSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -63,7 +63,7 @@ import {
|
||||
} from '../availability/policyHelpers.js';
|
||||
import { resolveModel } from '../config/models.js';
|
||||
import type { RetryAvailabilityContext } from '../utils/retry.js';
|
||||
import { partToString } from '../utils/partUtils.js';
|
||||
import { partToString, toPartArray } from '../utils/partUtils.js';
|
||||
import { coreEvents, CoreEvent } from '../utils/events.js';
|
||||
import { AgentFactory } from '../agents/agent-factory.js';
|
||||
import { type AgentHarness } from '../agents/harness.js';
|
||||
@@ -807,14 +807,13 @@ export class GeminiClient {
|
||||
}
|
||||
|
||||
if (!this.harness || this.lastPromptId !== prompt_id) {
|
||||
this.harness = AgentFactory.createHarness(this.config, undefined, {
|
||||
parentPromptId: prompt_id
|
||||
});
|
||||
this.lastPromptId = prompt_id;
|
||||
this.harness = AgentFactory.createHarness(this.config, undefined, {
|
||||
parentPromptId: prompt_id,
|
||||
});
|
||||
this.lastPromptId = prompt_id;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
const requestParts: Part[] = (Array.isArray(request) ? request : [{ text: partToString(request) }]) as any;
|
||||
|
||||
const requestParts: Part[] = toPartArray(request);
|
||||
const stream = this.harness.run(requestParts, signal, turns);
|
||||
|
||||
let turn: Turn | undefined;
|
||||
@@ -828,10 +827,9 @@ export class GeminiClient {
|
||||
}
|
||||
|
||||
if (turn) {
|
||||
// Sync history back to GeminiClient's chat for transcript persistence
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
|
||||
this.getChat().setHistory((turn as any).chat.getHistory());
|
||||
return turn;
|
||||
// Sync history back to GeminiClient's chat for transcript persistence
|
||||
this.getChat().setHistory(turn.chat.getHistory());
|
||||
return turn;
|
||||
}
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
}
|
||||
|
||||
@@ -252,7 +252,7 @@ export class Turn {
|
||||
finishReason: FinishReason | undefined = undefined;
|
||||
|
||||
constructor(
|
||||
private readonly chat: GeminiChat,
|
||||
readonly chat: GeminiChat,
|
||||
private readonly prompt_id: string,
|
||||
) {}
|
||||
|
||||
|
||||
@@ -168,3 +168,17 @@ export function appendToLastTextPart(
|
||||
|
||||
return newPrompt;
|
||||
}
|
||||
|
||||
/**
 * Normalizes a PartListUnion into an array of Parts.
 *
 * - A falsy value (including the empty string) produces an empty array.
 * - A single item is wrapped in a one-element array.
 * - Every string item is converted to a `{ text }` part; non-string
 *   items are passed through unchanged (same object reference).
 *
 * @param value A single part/string or an array of parts/strings.
 * @returns A new array containing one Part per input item.
 */
export function toPartArray(value: PartListUnion): Part[] {
  if (!value) {
    return [];
  }
  const entries = Array.isArray(value) ? value : [value];
  return entries.map((entry) =>
    typeof entry === 'string' ? { text: entry } : entry,
  );
}
|
||||
|
||||
Reference in New Issue
Block a user