refactor(core): unify ReAct loop in AgentHarness using Behavioral architecture

This commit is contained in:
mkorwel
2026-02-11 21:28:57 -06:00
parent 14e781c77c
commit c989087ba5
6 changed files with 319 additions and 128 deletions
+13 -10
View File
@@ -6,7 +6,7 @@
import { type Config } from '../config/config.js';
import { AgentHarness, type AgentHarnessOptions } from './harness.js';
import { type AgentDefinition } from './types.js';
import { type AgentDefinition, type LocalAgentDefinition } from './types.js';
import { MainAgentBehavior, SubagentBehavior } from './behavior.js';
/**
@@ -19,15 +19,18 @@ export class AgentFactory {
definition?: AgentDefinition,
options: Partial<AgentHarnessOptions> = {},
): AgentHarness {
const behavior = definition
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
? new SubagentBehavior(
config,
definition as any,
options.inputs,
options.parentPromptId
)
: new MainAgentBehavior(config, options.parentPromptId);
let behavior;
if (definition && definition.kind === 'local') {
const localDef: LocalAgentDefinition = definition;
behavior = new SubagentBehavior(
config,
localDef,
options.inputs,
options.parentPromptId,
);
} else {
behavior = new MainAgentBehavior(config, options.parentPromptId);
}
return new AgentHarness({
config,
+176 -58
View File
@@ -11,14 +11,21 @@ import {
Type,
} from '@google/genai';
import { type Config } from '../config/config.js';
import { type Turn, type ServerGeminiStreamEvent } from '../core/turn.js';
import {
import {
type Turn,
type ServerGeminiStreamEvent,
GeminiEventType,
} from '../core/turn.js';
import {
AgentTerminateMode,
type LocalAgentDefinition,
type AgentInputs
type LocalAgentDefinition,
type AgentInputs,
} from './types.js';
import { getCoreSystemPrompt } from '../core/prompts.js';
import { getInitialChatHistory, getDirectoryContextString } from '../utils/environmentContext.js';
import {
getInitialChatHistory,
getDirectoryContextString,
} from '../utils/environmentContext.js';
import { templateString } from './utils.js';
import { getVersion } from '../utils/version.js';
import { zodToJsonSchema } from 'zod-to-json-schema';
@@ -41,7 +48,7 @@ const GRACE_PERIOD_MS = 60 * 1000;
export interface AgentBehavior {
/** The unique ID for this agent instance. */
readonly agentId: string;
/** The human-readable name of the agent. */
readonly name: string;
@@ -54,7 +61,7 @@ export interface AgentBehavior {
/** Returns the initial chat history. */
getInitialHistory(): Promise<Content[]>;
/**
/**
* Prepares the tools list for the current turn.
* @param baseTools The tools from the tool registry.
*/
@@ -68,12 +75,29 @@ export interface AgentBehavior {
/**
* Fires the "Before Agent" hooks if applicable.
*/
fireBeforeAgent(request: Part[]): Promise<{ stop?: boolean; reason?: string; systemMessage?: string; additionalContext?: string }>;
fireBeforeAgent(
request: Part[],
): Promise<{
stop?: boolean;
reason?: string;
systemMessage?: string;
additionalContext?: string;
}>;
/**
* Fires the "After Agent" hooks if applicable.
*/
fireAfterAgent(request: Part[], response: string, turn: Turn): Promise<{ stop?: boolean; reason?: string; systemMessage?: string; contextCleared?: boolean; shouldContinue?: boolean }>;
fireAfterAgent(
request: Part[],
response: string,
turn: Turn,
): Promise<{
stop?: boolean;
reason?: string;
systemMessage?: string;
contextCleared?: boolean;
shouldContinue?: boolean;
}>;
/**
* Transforms the initial request if needed (e.g. subagent 'Start' templating).
@@ -90,18 +114,29 @@ export interface AgentBehavior {
* Checks if the agent should continue executing after a model turn with no tool calls.
* (e.g., Main agent running next_speaker check)
*/
getContinuationRequest(turn: Turn, signal: AbortSignal): Promise<Part[] | null>;
getContinuationRequest(
turn: Turn,
signal: AbortSignal,
): Promise<Part[] | null>;
/**
* Attempts to recover from a termination state (e.g., Subagent "Final Warning").
* Returns a stream of events if recovery is attempted.
*/
executeRecovery(turn: Turn, reason: AgentTerminateMode, signal: AbortSignal): AsyncGenerator<ServerGeminiStreamEvent, boolean>;
executeRecovery(
turn: Turn,
reason: AgentTerminateMode,
signal: AbortSignal,
): AsyncGenerator<ServerGeminiStreamEvent, boolean>;
/**
* Returns a final failure message for a given termination reason.
*/
getFinalFailureMessage(reason: AgentTerminateMode, maxTurns: number, maxTime: number): string;
getFinalFailureMessage(
reason: AgentTerminateMode,
maxTurns: number,
maxTime: number,
): string;
}
/**
@@ -113,7 +148,10 @@ export class MainAgentBehavior implements AgentBehavior {
private lastSentIdeContext: IdeContext | undefined;
private forceFullIdeContext = true;
constructor(private readonly config: Config, parentPromptId?: string) {
constructor(
private readonly config: Config,
parentPromptId?: string,
) {
const randomIdPart = Math.random().toString(36).slice(2, 8);
const parentPrefix = parentPromptId ? `${parentPromptId}-` : '';
this.agentId = `${parentPrefix}main-${randomIdPart}`;
@@ -137,9 +175,12 @@ export class MainAgentBehavior implements AgentBehavior {
async syncEnvironment(history: Content[]) {
if (!this.config.getIdeMode()) return {};
const lastMessage = history.length > 0 ? history[history.length - 1] : undefined;
const hasPendingToolCall = !!lastMessage && lastMessage.role === 'model' &&
(lastMessage.parts?.some(p => 'functionCall' in p) || false);
const lastMessage =
history.length > 0 ? history[history.length - 1] : undefined;
const hasPendingToolCall =
!!lastMessage &&
lastMessage.role === 'model' &&
(lastMessage.parts?.some((p) => 'functionCall' in p) || false);
if (hasPendingToolCall) return {};
@@ -147,10 +188,17 @@ export class MainAgentBehavior implements AgentBehavior {
if (!currentIdeContext) return {};
let contextParts: string[] = [];
if (this.forceFullIdeContext || this.lastSentIdeContext === undefined || history.length === 0) {
if (
this.forceFullIdeContext ||
this.lastSentIdeContext === undefined ||
history.length === 0
) {
contextParts = this.getFullIdeContextParts(currentIdeContext);
} else {
contextParts = this.getDeltaIdeContextParts(currentIdeContext, this.lastSentIdeContext);
contextParts = this.getDeltaIdeContextParts(
currentIdeContext,
this.lastSentIdeContext,
);
}
if (contextParts.length > 0) {
@@ -164,11 +212,12 @@ export class MainAgentBehavior implements AgentBehavior {
private getFullIdeContextParts(context: IdeContext): string[] {
const openFiles = context.workspaceState?.openFiles || [];
const activeFile = openFiles.find(f => f.isActive);
const otherOpenFiles = openFiles.filter(f => !f.isActive).map(f => f.path);
const activeFile = openFiles.find((f) => f.isActive);
const otherOpenFiles = openFiles
.filter((f) => !f.isActive)
.map((f) => f.path);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const contextData: Record<string, any> = {};
const contextData: Record<string, unknown> = {};
if (activeFile) {
contextData['activeFile'] = {
path: activeFile.path,
@@ -176,7 +225,8 @@ export class MainAgentBehavior implements AgentBehavior {
selectedText: activeFile.selectedText || undefined,
};
}
if (otherOpenFiles.length > 0) contextData['otherOpenFiles'] = otherOpenFiles;
if (otherOpenFiles.length > 0)
contextData['otherOpenFiles'] = otherOpenFiles;
if (Object.keys(contextData).length === 0) return [];
@@ -188,10 +238,12 @@ export class MainAgentBehavior implements AgentBehavior {
];
}
private getDeltaIdeContextParts(_current: IdeContext, _last: IdeContext): string[] {
private getDeltaIdeContextParts(
_current: IdeContext,
_last: IdeContext,
): string[] {
// Simplified delta logic for now, similar to GeminiClient
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const changes: Record<string, any> = {};
const changes: Record<string, unknown> = {};
// ... delta logic ...
if (Object.keys(changes).length === 0) return [];
@@ -205,7 +257,9 @@ export class MainAgentBehavior implements AgentBehavior {
async fireBeforeAgent(request: Part[]) {
if (!this.config.getEnableHooks()) return {};
const hookOutput = await this.config.getHookSystem()?.fireBeforeAgentEvent(partToString(request));
const hookOutput = await this.config
.getHookSystem()
?.fireBeforeAgentEvent(partToString(request));
if (!hookOutput) return {};
return {
@@ -220,7 +274,9 @@ export class MainAgentBehavior implements AgentBehavior {
if (!this.config.getEnableHooks()) return {};
if (turn.pendingToolCalls.length > 0) return {};
const hookOutput = await this.config.getHookSystem()?.fireAfterAgentEvent(partToString(request), response);
const hookOutput = await this.config
.getHookSystem()
?.fireAfterAgentEvent(partToString(request), response);
if (!hookOutput) return {};
return {
@@ -242,8 +298,7 @@ export class MainAgentBehavior implements AgentBehavior {
async getContinuationRequest(turn: Turn, signal: AbortSignal) {
const nextSpeaker = await checkNextSpeaker(
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
(turn as any).chat,
turn.chat,
this.config.getBaseLlmClient(),
signal,
this.agentId,
@@ -254,9 +309,8 @@ export class MainAgentBehavior implements AgentBehavior {
return null;
}
async *executeRecovery() {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
if (this.agentId === 'never') yield {} as any;
async *executeRecovery(): AsyncGenerator<ServerGeminiStreamEvent, boolean> {
if (this.agentId === 'never') yield { type: GeminiEventType.Retry };
return false;
}
@@ -276,7 +330,7 @@ export class SubagentBehavior implements AgentBehavior {
private readonly config: Config,
private readonly definition: LocalAgentDefinition,
private readonly inputs?: AgentInputs,
parentPromptId?: string
parentPromptId?: string,
) {
this.name = definition.name;
const randomIdPart = Math.random().toString(36).slice(2, 8);
@@ -293,7 +347,10 @@ export class SubagentBehavior implements AgentBehavior {
activeModel: this.config.getActiveModel(),
today: new Date().toLocaleDateString(),
};
let prompt = templateString(this.definition.promptConfig.systemPrompt || '', augmentedInputs);
let prompt = templateString(
this.definition.promptConfig.systemPrompt || '',
augmentedInputs,
);
const dirContext = await getDirectoryContextString(this.config);
prompt += `\n\n# Environment Context\n${dirContext}`;
prompt += `\n\nImportant Rules:\n* You are running in a non-interactive mode. You CANNOT ask the user for input or clarification.\n* Work systematically using available tools to complete your task.\n* Always use absolute paths for file operations.`;
@@ -322,15 +379,24 @@ export class SubagentBehavior implements AgentBehavior {
prepareTools(baseTools: FunctionDeclaration[]) {
const completeTool: FunctionDeclaration = {
name: TASK_COMPLETE_TOOL_NAME,
description: 'Call this tool to submit your final answer and complete the task.',
description:
'Call this tool to submit your final answer and complete the task.',
parameters: { type: Type.OBJECT, properties: {}, required: [] },
};
if (this.definition.outputConfig) {
const schema = zodToJsonSchema(this.definition.outputConfig.schema);
const { $schema: _, definitions: __, ...cleanSchema } = schema as Record<string, unknown>;
completeTool.parameters!.properties![this.definition.outputConfig.outputName] = cleanSchema as Schema;
completeTool.parameters!.required!.push(this.definition.outputConfig.outputName);
const {
$schema: _,
definitions: __,
...cleanSchema
} = schema as Record<string, unknown>;
completeTool.parameters!.properties![
this.definition.outputConfig.outputName
] = cleanSchema as Schema;
completeTool.parameters!.required!.push(
this.definition.outputConfig.outputName,
);
} else {
completeTool.parameters!.properties!['result'] = {
type: Type.STRING,
@@ -355,11 +421,18 @@ export class SubagentBehavior implements AgentBehavior {
}
async transformRequest(request: Part[]): Promise<Part[]> {
if (request.length === 1 && 'text' in request[0] && request[0].text === 'Start') {
if (
request.length === 1 &&
'text' in request[0] &&
request[0].text === 'Start'
) {
return [
{
text: this.definition.promptConfig.query
? templateString(this.definition.promptConfig.query, this.inputs || {})
? templateString(
this.definition.promptConfig.query,
this.inputs || {},
)
: 'Get Started!',
},
];
@@ -368,7 +441,9 @@ export class SubagentBehavior implements AgentBehavior {
}
isGoalReached(toolResults: Array<{ name: string; part: Part }>) {
const completeCall = toolResults.find((r) => r.name === TASK_COMPLETE_TOOL_NAME);
const completeCall = toolResults.find(
(r) => r.name === TASK_COMPLETE_TOOL_NAME,
);
if (completeCall) {
// If there's an error in the call, we don't treat it as reached (model should retry)
return !completeCall.part.functionResponse?.response?.['error'];
@@ -380,17 +455,33 @@ export class SubagentBehavior implements AgentBehavior {
return null;
}
async *executeRecovery(turn: Turn, reason: AgentTerminateMode, signal: AbortSignal): AsyncGenerator<ServerGeminiStreamEvent, boolean> {
async *executeRecovery(
turn: Turn,
reason: AgentTerminateMode,
signal: AbortSignal,
): AsyncGenerator<ServerGeminiStreamEvent, boolean> {
const recoveryStartTime = Date.now();
let success = false;
const graceTimeoutController = new DeadlineTimer(GRACE_PERIOD_MS, 'Grace period timed out.');
const combinedSignal = AbortSignal.any([signal, graceTimeoutController.signal]);
const graceTimeoutController = new DeadlineTimer(
GRACE_PERIOD_MS,
'Grace period timed out.',
);
const combinedSignal = AbortSignal.any([
signal,
graceTimeoutController.signal,
]);
try {
const recoveryMessage: Part[] = [{ text: this.getFinalWarningMessage(reason) }];
const recoveryMessage: Part[] = [
{ text: this.getFinalWarningMessage(reason) },
];
const promptId = `${this.agentId}#recovery`;
const recoveryStream = promptIdContext.run(promptId, () =>
turn.run({ model: this.config.getActiveModel() }, recoveryMessage, combinedSignal),
turn.run(
{ model: this.config.getActiveModel() },
recoveryMessage,
combinedSignal,
),
);
for await (const event of recoveryStream) {
@@ -399,13 +490,25 @@ export class SubagentBehavior implements AgentBehavior {
// Check if they called complete_task in the recovery turn
if (turn.pendingToolCalls.length > 0) {
if (turn.pendingToolCalls.some(c => c.name === TASK_COMPLETE_TOOL_NAME)) {
if (
turn.pendingToolCalls.some((c) => c.name === TASK_COMPLETE_TOOL_NAME)
) {
success = true;
}
}
} finally {
graceTimeoutController.abort();
logRecoveryAttempt(this.config, new RecoveryAttemptEvent(this.agentId, this.name, reason, Date.now() - recoveryStartTime, success, 0));
logRecoveryAttempt(
this.config,
new RecoveryAttemptEvent(
this.agentId,
this.name,
reason,
Date.now() - recoveryStartTime,
success,
0,
),
);
}
return success;
}
@@ -413,20 +516,35 @@ export class SubagentBehavior implements AgentBehavior {
private getFinalWarningMessage(reason: AgentTerminateMode): string {
let explanation = '';
switch (reason) {
case AgentTerminateMode.TIMEOUT: explanation = 'You have exceeded the time limit.'; break;
case AgentTerminateMode.MAX_TURNS: explanation = 'You have exceeded the maximum number of turns.'; break;
case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL: explanation = 'You have stopped calling tools without finishing.'; break;
default: explanation = 'Execution was interrupted.';
case AgentTerminateMode.TIMEOUT:
explanation = 'You have exceeded the time limit.';
break;
case AgentTerminateMode.MAX_TURNS:
explanation = 'You have exceeded the maximum number of turns.';
break;
case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL:
explanation = 'You have stopped calling tools without finishing.';
break;
default:
explanation = 'Execution was interrupted.';
}
return `${explanation} You have one final chance to complete the task with a short grace period. You MUST call \`${TASK_COMPLETE_TOOL_NAME}\` immediately with your best answer and explain that your investigation was interrupted. Do not call any other tools.`;
}
getFinalFailureMessage(reason: AgentTerminateMode, maxTurns: number, maxTime: number) {
getFinalFailureMessage(
reason: AgentTerminateMode,
maxTurns: number,
maxTime: number,
) {
switch (reason) {
case AgentTerminateMode.TIMEOUT: return `Agent timed out after ${maxTime} minutes.`;
case AgentTerminateMode.MAX_TURNS: return `Agent reached max turns limit (${maxTurns}).`;
case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL: return `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}'.`;
default: return 'Agent execution was terminated before completion.';
case AgentTerminateMode.TIMEOUT:
return `Agent timed out after ${maxTime} minutes.`;
case AgentTerminateMode.MAX_TURNS:
return `Agent reached max turns limit (${maxTurns}).`;
case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL:
return `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}'.`;
default:
return 'Agent execution was terminated before completion.';
}
}
}
+105 -47
View File
@@ -10,17 +10,15 @@ import { makeFakeConfig } from '../test-utils/config.js';
import { GeminiChat, StreamEventType } from '../core/geminiChat.js';
import { GeminiEventType, type ServerGeminiStreamEvent } from '../core/turn.js';
import { z } from 'zod';
import {
AgentTerminateMode,
type LocalAgentDefinition,
} from './types.js';
import { AgentTerminateMode, type LocalAgentDefinition } from './types.js';
import { scheduleAgentTools } from './agent-scheduler.js';
import { logAgentFinish } from '../telemetry/loggers.js';
import { type Config } from '../config/config.js';
import { MainAgentBehavior, SubagentBehavior } from './behavior.js';
vi.mock('../telemetry/loggers.js', async (importOriginal) => {
const actual = await importOriginal<typeof import('../telemetry/loggers.js')>();
const actual =
await importOriginal<typeof import('../telemetry/loggers.js')>();
return {
...actual,
logAgentStart: vi.fn(),
@@ -59,32 +57,38 @@ describe('AgentHarness', () => {
mockConfig.getIdeMode = vi.fn().mockReturnValue(false);
mockConfig.getBaseLlmClient = vi.fn().mockReturnValue({});
mockConfig.getModelRouterService = vi.fn().mockReturnValue({
route: vi.fn().mockResolvedValue({ model: 'gemini-test-model', metadata: { source: 'test' } }),
route: vi
.fn()
.mockResolvedValue({
model: 'gemini-test-model',
metadata: { source: 'test' },
}),
});
vi.clearAllMocks();
});
describe('SubagentBehavior', () => {
it('executes a subagent and finishes when complete_task is called', async () => {
const definition: LocalAgentDefinition<z.ZodString> = {
const definition: LocalAgentDefinition<z.ZodUnknown> = {
kind: 'local',
name: 'test-agent',
displayName: 'Test Agent',
description: 'A test agent',
inputConfig: { inputSchema: { type: 'object', properties: {}, required: [] } },
inputConfig: {
inputSchema: { type: 'object', properties: {}, required: [] },
},
modelConfig: { model: 'gemini-test-model' },
runConfig: { maxTurns: 5, maxTimeMinutes: 5 },
promptConfig: { systemPrompt: 'You are a test agent.' },
outputConfig: {
outputName: 'result',
description: 'The final result.',
schema: z.string(),
schema: z.unknown(),
},
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const behavior = new SubagentBehavior(mockConfig, definition as any);
const behavior = new SubagentBehavior(mockConfig, definition);
const harness = new AgentHarness({
config: mockConfig,
behavior,
@@ -108,8 +112,19 @@ describe('AgentHarness', () => {
yield {
type: StreamEventType.CHUNK,
value: {
candidates: [{ content: { parts: [{ text: 'Done!' }] }, finishReason: 'STOP' }],
functionCalls: [{ name: 'complete_task', args: { result: 'Success' }, id: 'call_1' }],
candidates: [
{
content: { parts: [{ text: 'Done!' }] },
finishReason: 'STOP',
},
],
functionCalls: [
{
name: 'complete_task',
args: { result: 'Success' },
id: 'call_1',
},
],
},
};
})(),
@@ -118,18 +133,31 @@ describe('AgentHarness', () => {
// Mock tool execution
(scheduleAgentTools as unknown as Mock).mockResolvedValue([
{
request: { name: 'complete_task', args: { result: 'Success' }, callId: 'call_1' },
request: {
name: 'complete_task',
args: { result: 'Success' },
callId: 'call_1',
},
status: 'success',
response: {
responseParts: [{
functionResponse: { name: 'complete_task', response: { status: 'OK' }, id: 'call_1' },
}],
responseParts: [
{
functionResponse: {
name: 'complete_task',
response: { status: 'OK' },
id: 'call_1',
},
},
],
},
},
]);
const events: ServerGeminiStreamEvent[] = [];
const run = harness.run([{ text: 'Start' }], new AbortController().signal);
const run = harness.run(
[{ text: 'Start' }],
new AbortController().signal,
);
while (true) {
const { value, done } = await run.next();
@@ -137,7 +165,13 @@ describe('AgentHarness', () => {
events.push(value);
}
expect(events.some(e => e.type === GeminiEventType.ToolCallRequest && e.value.name === 'complete_task')).toBe(true);
expect(
events.some(
(e) =>
e.type === GeminiEventType.ToolCallRequest &&
e.value.name === 'complete_task',
),
).toBe(true);
expect(vi.mocked(logAgentFinish)).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({ terminate_reason: AgentTerminateMode.GOAL }),
@@ -163,7 +197,10 @@ describe('AgentHarness', () => {
mockConfig.getEnableHooks = vi.fn().mockReturnValue(true);
const events: ServerGeminiStreamEvent[] = [];
const run = harness.run([{ text: 'Hello' }], new AbortController().signal);
const run = harness.run(
[{ text: 'Hello' }],
new AbortController().signal,
);
while (true) {
const { value, done } = await run.next();
@@ -171,42 +208,63 @@ describe('AgentHarness', () => {
events.push(value);
}
expect(events.some(e => e.type === GeminiEventType.Error && e.value.error.message === 'Access denied')).toBe(true);
expect(
events.some(
(e) =>
e.type === GeminiEventType.Error &&
e.value.error.message === 'Access denied',
),
).toBe(true);
expect(vi.mocked(logAgentFinish)).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({ terminate_reason: AgentTerminateMode.ABORTED }),
expect.objectContaining({
terminate_reason: AgentTerminateMode.ABORTED,
}),
);
});
it('syncs IDE context when IDE mode is enabled', async () => {
const behavior = new MainAgentBehavior(mockConfig);
const harness = new AgentHarness({ config: mockConfig, behavior });
const behavior = new MainAgentBehavior(mockConfig);
const harness = new AgentHarness({ config: mockConfig, behavior });
mockConfig.getIdeMode = vi.fn().mockReturnValue(true);
const mockChat = {
sendMessageStream: vi.fn().mockResolvedValue((async function* () {
yield { type: StreamEventType.CHUNK, value: { candidates: [{ content: { parts: [{ text: 'Response' }] }, finishReason: 'STOP' }] } };
})()),
setTools: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
addHistory: vi.fn(),
setSystemInstruction: vi.fn(),
getLastPromptTokenCount: vi.fn().mockReturnValue(0),
} as unknown as GeminiChat;
(GeminiChat as unknown as Mock).mockReturnValue(mockChat);
mockConfig.getIdeMode = vi.fn().mockReturnValue(true);
// We can't easily mock ideContextStore.get() if it's not exported as a mockable object easily,
// but we can at least verify that syncEnvironment is called by harness.
const syncSpy = vi.spyOn(behavior, 'syncEnvironment');
const mockChat = {
sendMessageStream: vi.fn().mockResolvedValue(
(async function* () {
yield {
type: StreamEventType.CHUNK,
value: {
candidates: [
{
content: { parts: [{ text: 'Response' }] },
finishReason: 'STOP',
},
],
},
};
})(),
),
setTools: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
addHistory: vi.fn(),
setSystemInstruction: vi.fn(),
getLastPromptTokenCount: vi.fn().mockReturnValue(0),
} as unknown as GeminiChat;
(GeminiChat as unknown as Mock).mockReturnValue(mockChat);
const run = harness.run([{ text: 'Hello' }], new AbortController().signal);
while (true) {
const { done } = await run.next();
if (done) break;
}
const syncSpy = vi.spyOn(behavior, 'syncEnvironment');
expect(syncSpy).toHaveBeenCalled();
const run = harness.run(
[{ text: 'Hello' }],
new AbortController().signal,
);
while (true) {
const { done } = await run.next();
if (done) break;
}
expect(syncSpy).toHaveBeenCalled();
});
});
});
+10 -12
View File
@@ -63,7 +63,7 @@ import {
} from '../availability/policyHelpers.js';
import { resolveModel } from '../config/models.js';
import type { RetryAvailabilityContext } from '../utils/retry.js';
import { partToString } from '../utils/partUtils.js';
import { partToString, toPartArray } from '../utils/partUtils.js';
import { coreEvents, CoreEvent } from '../utils/events.js';
import { AgentFactory } from '../agents/agent-factory.js';
import { type AgentHarness } from '../agents/harness.js';
@@ -807,14 +807,13 @@ export class GeminiClient {
}
if (!this.harness || this.lastPromptId !== prompt_id) {
this.harness = AgentFactory.createHarness(this.config, undefined, {
parentPromptId: prompt_id
});
this.lastPromptId = prompt_id;
this.harness = AgentFactory.createHarness(this.config, undefined, {
parentPromptId: prompt_id,
});
this.lastPromptId = prompt_id;
}
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
const requestParts: Part[] = (Array.isArray(request) ? request : [{ text: partToString(request) }]) as any;
const requestParts: Part[] = toPartArray(request);
const stream = this.harness.run(requestParts, signal, turns);
let turn: Turn | undefined;
@@ -828,10 +827,9 @@ export class GeminiClient {
}
if (turn) {
// Sync history back to GeminiClient's chat for transcript persistence
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any
this.getChat().setHistory((turn as any).chat.getHistory());
return turn;
// Sync history back to GeminiClient's chat for transcript persistence
this.getChat().setHistory(turn.chat.getHistory());
return turn;
}
return new Turn(this.getChat(), prompt_id);
}
+1 -1
View File
@@ -252,7 +252,7 @@ export class Turn {
finishReason: FinishReason | undefined = undefined;
constructor(
private readonly chat: GeminiChat,
readonly chat: GeminiChat,
private readonly prompt_id: string,
) {}
+14
View File
@@ -168,3 +168,17 @@ export function appendToLastTextPart(
return newPrompt;
}
/**
* Normalizes a PartListUnion into an array of Parts.
*/
export function toPartArray(value: PartListUnion): Part[] {
if (!value) return [];
const items = Array.isArray(value) ? value : [value];
return items.map((item) => {
if (typeof item === 'string') {
return { text: item };
}
return item;
});
}