From c989087ba588b65110974e050245ff0cc69e29ef Mon Sep 17 00:00:00 2001 From: mkorwel Date: Wed, 11 Feb 2026 21:28:57 -0600 Subject: [PATCH] refactor(core): unify ReAct loop in AgentHarness using Behavioral architecture --- packages/core/src/agents/agent-factory.ts | 23 ++- packages/core/src/agents/behavior.ts | 234 ++++++++++++++++------ packages/core/src/agents/harness.test.ts | 152 +++++++++----- packages/core/src/core/client.ts | 22 +- packages/core/src/core/turn.ts | 2 +- packages/core/src/utils/partUtils.ts | 14 ++ 6 files changed, 319 insertions(+), 128 deletions(-) diff --git a/packages/core/src/agents/agent-factory.ts b/packages/core/src/agents/agent-factory.ts index bbc8ff2ed6..70ce43012d 100644 --- a/packages/core/src/agents/agent-factory.ts +++ b/packages/core/src/agents/agent-factory.ts @@ -6,7 +6,7 @@ import { type Config } from '../config/config.js'; import { AgentHarness, type AgentHarnessOptions } from './harness.js'; -import { type AgentDefinition } from './types.js'; +import { type AgentDefinition, type LocalAgentDefinition } from './types.js'; import { MainAgentBehavior, SubagentBehavior } from './behavior.js'; /** @@ -19,15 +19,18 @@ export class AgentFactory { definition?: AgentDefinition, options: Partial = {}, ): AgentHarness { - const behavior = definition - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any - ? 
new SubagentBehavior( - config, - definition as any, - options.inputs, - options.parentPromptId - ) - : new MainAgentBehavior(config, options.parentPromptId); + let behavior; + if (definition && definition.kind === 'local') { + const localDef: LocalAgentDefinition = definition; + behavior = new SubagentBehavior( + config, + localDef, + options.inputs, + options.parentPromptId, + ); + } else { + behavior = new MainAgentBehavior(config, options.parentPromptId); + } return new AgentHarness({ config, diff --git a/packages/core/src/agents/behavior.ts b/packages/core/src/agents/behavior.ts index f7ae582e54..6c8a2f0cb2 100644 --- a/packages/core/src/agents/behavior.ts +++ b/packages/core/src/agents/behavior.ts @@ -11,14 +11,21 @@ import { Type, } from '@google/genai'; import { type Config } from '../config/config.js'; -import { type Turn, type ServerGeminiStreamEvent } from '../core/turn.js'; -import { +import { + type Turn, + type ServerGeminiStreamEvent, + GeminiEventType, +} from '../core/turn.js'; +import { AgentTerminateMode, - type LocalAgentDefinition, - type AgentInputs + type LocalAgentDefinition, + type AgentInputs, } from './types.js'; import { getCoreSystemPrompt } from '../core/prompts.js'; -import { getInitialChatHistory, getDirectoryContextString } from '../utils/environmentContext.js'; +import { + getInitialChatHistory, + getDirectoryContextString, +} from '../utils/environmentContext.js'; import { templateString } from './utils.js'; import { getVersion } from '../utils/version.js'; import { zodToJsonSchema } from 'zod-to-json-schema'; @@ -41,7 +48,7 @@ const GRACE_PERIOD_MS = 60 * 1000; export interface AgentBehavior { /** The unique ID for this agent instance. */ readonly agentId: string; - + /** The human-readable name of the agent. */ readonly name: string; @@ -54,7 +61,7 @@ export interface AgentBehavior { /** Returns the initial chat history. */ getInitialHistory(): Promise; - /** + /** * Prepares the tools list for the current turn. 
* @param baseTools The tools from the tool registry. */ @@ -68,12 +75,29 @@ export interface AgentBehavior { /** * Fires the "Before Agent" hooks if applicable. */ - fireBeforeAgent(request: Part[]): Promise<{ stop?: boolean; reason?: string; systemMessage?: string; additionalContext?: string }>; + fireBeforeAgent( + request: Part[], + ): Promise<{ + stop?: boolean; + reason?: string; + systemMessage?: string; + additionalContext?: string; + }>; /** * Fires the "After Agent" hooks if applicable. */ - fireAfterAgent(request: Part[], response: string, turn: Turn): Promise<{ stop?: boolean; reason?: string; systemMessage?: string; contextCleared?: boolean; shouldContinue?: boolean }>; + fireAfterAgent( + request: Part[], + response: string, + turn: Turn, + ): Promise<{ + stop?: boolean; + reason?: string; + systemMessage?: string; + contextCleared?: boolean; + shouldContinue?: boolean; + }>; /** * Transforms the initial request if needed (e.g. subagent 'Start' templating). @@ -90,18 +114,29 @@ export interface AgentBehavior { * Checks if the agent should continue executing after a model turn with no tool calls. * (e.g., Main agent running next_speaker check) */ - getContinuationRequest(turn: Turn, signal: AbortSignal): Promise; + getContinuationRequest( + turn: Turn, + signal: AbortSignal, + ): Promise; /** * Attempts to recover from a termination state (e.g., Subagent "Final Warning"). * Returns a stream of events if recovery is attempted. */ - executeRecovery(turn: Turn, reason: AgentTerminateMode, signal: AbortSignal): AsyncGenerator; + executeRecovery( + turn: Turn, + reason: AgentTerminateMode, + signal: AbortSignal, + ): AsyncGenerator; /** * Returns a final failure message for a given termination reason. 
*/ - getFinalFailureMessage(reason: AgentTerminateMode, maxTurns: number, maxTime: number): string; + getFinalFailureMessage( + reason: AgentTerminateMode, + maxTurns: number, + maxTime: number, + ): string; } /** @@ -113,7 +148,10 @@ export class MainAgentBehavior implements AgentBehavior { private lastSentIdeContext: IdeContext | undefined; private forceFullIdeContext = true; - constructor(private readonly config: Config, parentPromptId?: string) { + constructor( + private readonly config: Config, + parentPromptId?: string, + ) { const randomIdPart = Math.random().toString(36).slice(2, 8); const parentPrefix = parentPromptId ? `${parentPromptId}-` : ''; this.agentId = `${parentPrefix}main-${randomIdPart}`; @@ -137,9 +175,12 @@ export class MainAgentBehavior implements AgentBehavior { async syncEnvironment(history: Content[]) { if (!this.config.getIdeMode()) return {}; - const lastMessage = history.length > 0 ? history[history.length - 1] : undefined; - const hasPendingToolCall = !!lastMessage && lastMessage.role === 'model' && - (lastMessage.parts?.some(p => 'functionCall' in p) || false); + const lastMessage = + history.length > 0 ? 
history[history.length - 1] : undefined; + const hasPendingToolCall = + !!lastMessage && + lastMessage.role === 'model' && + (lastMessage.parts?.some((p) => 'functionCall' in p) || false); if (hasPendingToolCall) return {}; @@ -147,10 +188,17 @@ export class MainAgentBehavior implements AgentBehavior { if (!currentIdeContext) return {}; let contextParts: string[] = []; - if (this.forceFullIdeContext || this.lastSentIdeContext === undefined || history.length === 0) { + if ( + this.forceFullIdeContext || + this.lastSentIdeContext === undefined || + history.length === 0 + ) { contextParts = this.getFullIdeContextParts(currentIdeContext); } else { - contextParts = this.getDeltaIdeContextParts(currentIdeContext, this.lastSentIdeContext); + contextParts = this.getDeltaIdeContextParts( + currentIdeContext, + this.lastSentIdeContext, + ); } if (contextParts.length > 0) { @@ -164,11 +212,12 @@ export class MainAgentBehavior implements AgentBehavior { private getFullIdeContextParts(context: IdeContext): string[] { const openFiles = context.workspaceState?.openFiles || []; - const activeFile = openFiles.find(f => f.isActive); - const otherOpenFiles = openFiles.filter(f => !f.isActive).map(f => f.path); + const activeFile = openFiles.find((f) => f.isActive); + const otherOpenFiles = openFiles + .filter((f) => !f.isActive) + .map((f) => f.path); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const contextData: Record = {}; + const contextData: Record = {}; if (activeFile) { contextData['activeFile'] = { path: activeFile.path, @@ -176,7 +225,8 @@ export class MainAgentBehavior implements AgentBehavior { selectedText: activeFile.selectedText || undefined, }; } - if (otherOpenFiles.length > 0) contextData['otherOpenFiles'] = otherOpenFiles; + if (otherOpenFiles.length > 0) + contextData['otherOpenFiles'] = otherOpenFiles; if (Object.keys(contextData).length === 0) return []; @@ -188,10 +238,12 @@ export class MainAgentBehavior implements AgentBehavior { ]; } - 
private getDeltaIdeContextParts(_current: IdeContext, _last: IdeContext): string[] { + private getDeltaIdeContextParts( + _current: IdeContext, + _last: IdeContext, + ): string[] { // Simplified delta logic for now, similar to GeminiClient - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const changes: Record = {}; + const changes: Record = {}; // ... delta logic ... if (Object.keys(changes).length === 0) return []; @@ -205,7 +257,9 @@ export class MainAgentBehavior implements AgentBehavior { async fireBeforeAgent(request: Part[]) { if (!this.config.getEnableHooks()) return {}; - const hookOutput = await this.config.getHookSystem()?.fireBeforeAgentEvent(partToString(request)); + const hookOutput = await this.config + .getHookSystem() + ?.fireBeforeAgentEvent(partToString(request)); if (!hookOutput) return {}; return { @@ -220,7 +274,9 @@ export class MainAgentBehavior implements AgentBehavior { if (!this.config.getEnableHooks()) return {}; if (turn.pendingToolCalls.length > 0) return {}; - const hookOutput = await this.config.getHookSystem()?.fireAfterAgentEvent(partToString(request), response); + const hookOutput = await this.config + .getHookSystem() + ?.fireAfterAgentEvent(partToString(request), response); if (!hookOutput) return {}; return { @@ -242,8 +298,7 @@ export class MainAgentBehavior implements AgentBehavior { async getContinuationRequest(turn: Turn, signal: AbortSignal) { const nextSpeaker = await checkNextSpeaker( - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any - (turn as any).chat, + turn.chat, this.config.getBaseLlmClient(), signal, this.agentId, @@ -254,9 +309,8 @@ export class MainAgentBehavior implements AgentBehavior { return null; } - async *executeRecovery() { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any - if (this.agentId === 'never') yield {} as any; + async *executeRecovery(): AsyncGenerator { + if 
(this.agentId === 'never') yield { type: GeminiEventType.Retry }; return false; } @@ -276,7 +330,7 @@ export class SubagentBehavior implements AgentBehavior { private readonly config: Config, private readonly definition: LocalAgentDefinition, private readonly inputs?: AgentInputs, - parentPromptId?: string + parentPromptId?: string, ) { this.name = definition.name; const randomIdPart = Math.random().toString(36).slice(2, 8); @@ -293,7 +347,10 @@ export class SubagentBehavior implements AgentBehavior { activeModel: this.config.getActiveModel(), today: new Date().toLocaleDateString(), }; - let prompt = templateString(this.definition.promptConfig.systemPrompt || '', augmentedInputs); + let prompt = templateString( + this.definition.promptConfig.systemPrompt || '', + augmentedInputs, + ); const dirContext = await getDirectoryContextString(this.config); prompt += `\n\n# Environment Context\n${dirContext}`; prompt += `\n\nImportant Rules:\n* You are running in a non-interactive mode. You CANNOT ask the user for input or clarification.\n* Work systematically using available tools to complete your task.\n* Always use absolute paths for file operations.`; @@ -322,15 +379,24 @@ export class SubagentBehavior implements AgentBehavior { prepareTools(baseTools: FunctionDeclaration[]) { const completeTool: FunctionDeclaration = { name: TASK_COMPLETE_TOOL_NAME, - description: 'Call this tool to submit your final answer and complete the task.', + description: + 'Call this tool to submit your final answer and complete the task.', parameters: { type: Type.OBJECT, properties: {}, required: [] }, }; if (this.definition.outputConfig) { const schema = zodToJsonSchema(this.definition.outputConfig.schema); - const { $schema: _, definitions: __, ...cleanSchema } = schema as Record; - completeTool.parameters!.properties![this.definition.outputConfig.outputName] = cleanSchema as Schema; - completeTool.parameters!.required!.push(this.definition.outputConfig.outputName); + const { + $schema: _, 
+ definitions: __, + ...cleanSchema + } = schema as Record; + completeTool.parameters!.properties![ + this.definition.outputConfig.outputName + ] = cleanSchema as Schema; + completeTool.parameters!.required!.push( + this.definition.outputConfig.outputName, + ); } else { completeTool.parameters!.properties!['result'] = { type: Type.STRING, @@ -355,11 +421,18 @@ export class SubagentBehavior implements AgentBehavior { } async transformRequest(request: Part[]): Promise { - if (request.length === 1 && 'text' in request[0] && request[0].text === 'Start') { + if ( + request.length === 1 && + 'text' in request[0] && + request[0].text === 'Start' + ) { return [ { text: this.definition.promptConfig.query - ? templateString(this.definition.promptConfig.query, this.inputs || {}) + ? templateString( + this.definition.promptConfig.query, + this.inputs || {}, + ) : 'Get Started!', }, ]; @@ -368,7 +441,9 @@ export class SubagentBehavior implements AgentBehavior { } isGoalReached(toolResults: Array<{ name: string; part: Part }>) { - const completeCall = toolResults.find((r) => r.name === TASK_COMPLETE_TOOL_NAME); + const completeCall = toolResults.find( + (r) => r.name === TASK_COMPLETE_TOOL_NAME, + ); if (completeCall) { // If there's an error in the call, we don't treat it as reached (model should retry) return !completeCall.part.functionResponse?.response?.['error']; @@ -380,17 +455,33 @@ export class SubagentBehavior implements AgentBehavior { return null; } - async *executeRecovery(turn: Turn, reason: AgentTerminateMode, signal: AbortSignal): AsyncGenerator { + async *executeRecovery( + turn: Turn, + reason: AgentTerminateMode, + signal: AbortSignal, + ): AsyncGenerator { const recoveryStartTime = Date.now(); let success = false; - const graceTimeoutController = new DeadlineTimer(GRACE_PERIOD_MS, 'Grace period timed out.'); - const combinedSignal = AbortSignal.any([signal, graceTimeoutController.signal]); + const graceTimeoutController = new DeadlineTimer( + GRACE_PERIOD_MS, 
+ 'Grace period timed out.', + ); + const combinedSignal = AbortSignal.any([ + signal, + graceTimeoutController.signal, + ]); try { - const recoveryMessage: Part[] = [{ text: this.getFinalWarningMessage(reason) }]; + const recoveryMessage: Part[] = [ + { text: this.getFinalWarningMessage(reason) }, + ]; const promptId = `${this.agentId}#recovery`; const recoveryStream = promptIdContext.run(promptId, () => - turn.run({ model: this.config.getActiveModel() }, recoveryMessage, combinedSignal), + turn.run( + { model: this.config.getActiveModel() }, + recoveryMessage, + combinedSignal, + ), ); for await (const event of recoveryStream) { @@ -399,13 +490,25 @@ export class SubagentBehavior implements AgentBehavior { // Check if they called complete_task in the recovery turn if (turn.pendingToolCalls.length > 0) { - if (turn.pendingToolCalls.some(c => c.name === TASK_COMPLETE_TOOL_NAME)) { + if ( + turn.pendingToolCalls.some((c) => c.name === TASK_COMPLETE_TOOL_NAME) + ) { success = true; } } } finally { graceTimeoutController.abort(); - logRecoveryAttempt(this.config, new RecoveryAttemptEvent(this.agentId, this.name, reason, Date.now() - recoveryStartTime, success, 0)); + logRecoveryAttempt( + this.config, + new RecoveryAttemptEvent( + this.agentId, + this.name, + reason, + Date.now() - recoveryStartTime, + success, + 0, + ), + ); } return success; } @@ -413,20 +516,35 @@ export class SubagentBehavior implements AgentBehavior { private getFinalWarningMessage(reason: AgentTerminateMode): string { let explanation = ''; switch (reason) { - case AgentTerminateMode.TIMEOUT: explanation = 'You have exceeded the time limit.'; break; - case AgentTerminateMode.MAX_TURNS: explanation = 'You have exceeded the maximum number of turns.'; break; - case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL: explanation = 'You have stopped calling tools without finishing.'; break; - default: explanation = 'Execution was interrupted.'; + case AgentTerminateMode.TIMEOUT: + explanation = 'You have 
exceeded the time limit.'; + break; + case AgentTerminateMode.MAX_TURNS: + explanation = 'You have exceeded the maximum number of turns.'; + break; + case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL: + explanation = 'You have stopped calling tools without finishing.'; + break; + default: + explanation = 'Execution was interrupted.'; } return `${explanation} You have one final chance to complete the task with a short grace period. You MUST call \`${TASK_COMPLETE_TOOL_NAME}\` immediately with your best answer and explain that your investigation was interrupted. Do not call any other tools.`; } - getFinalFailureMessage(reason: AgentTerminateMode, maxTurns: number, maxTime: number) { + getFinalFailureMessage( + reason: AgentTerminateMode, + maxTurns: number, + maxTime: number, + ) { switch (reason) { - case AgentTerminateMode.TIMEOUT: return `Agent timed out after ${maxTime} minutes.`; - case AgentTerminateMode.MAX_TURNS: return `Agent reached max turns limit (${maxTurns}).`; - case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL: return `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}'.`; - default: return 'Agent execution was terminated before completion.'; + case AgentTerminateMode.TIMEOUT: + return `Agent timed out after ${maxTime} minutes.`; + case AgentTerminateMode.MAX_TURNS: + return `Agent reached max turns limit (${maxTurns}).`; + case AgentTerminateMode.ERROR_NO_COMPLETE_TASK_CALL: + return `Agent stopped calling tools but did not call '${TASK_COMPLETE_TOOL_NAME}'.`; + default: + return 'Agent execution was terminated before completion.'; } } } diff --git a/packages/core/src/agents/harness.test.ts b/packages/core/src/agents/harness.test.ts index 92a78fc662..214626cb5c 100644 --- a/packages/core/src/agents/harness.test.ts +++ b/packages/core/src/agents/harness.test.ts @@ -10,17 +10,15 @@ import { makeFakeConfig } from '../test-utils/config.js'; import { GeminiChat, StreamEventType } from '../core/geminiChat.js'; import { 
GeminiEventType, type ServerGeminiStreamEvent } from '../core/turn.js'; import { z } from 'zod'; -import { - AgentTerminateMode, - type LocalAgentDefinition, -} from './types.js'; +import { AgentTerminateMode, type LocalAgentDefinition } from './types.js'; import { scheduleAgentTools } from './agent-scheduler.js'; import { logAgentFinish } from '../telemetry/loggers.js'; import { type Config } from '../config/config.js'; import { MainAgentBehavior, SubagentBehavior } from './behavior.js'; vi.mock('../telemetry/loggers.js', async (importOriginal) => { - const actual = await importOriginal(); + const actual = + await importOriginal(); return { ...actual, logAgentStart: vi.fn(), @@ -59,32 +57,38 @@ describe('AgentHarness', () => { mockConfig.getIdeMode = vi.fn().mockReturnValue(false); mockConfig.getBaseLlmClient = vi.fn().mockReturnValue({}); mockConfig.getModelRouterService = vi.fn().mockReturnValue({ - route: vi.fn().mockResolvedValue({ model: 'gemini-test-model', metadata: { source: 'test' } }), + route: vi + .fn() + .mockResolvedValue({ + model: 'gemini-test-model', + metadata: { source: 'test' }, + }), }); - + vi.clearAllMocks(); }); describe('SubagentBehavior', () => { it('executes a subagent and finishes when complete_task is called', async () => { - const definition: LocalAgentDefinition = { + const definition: LocalAgentDefinition = { kind: 'local', name: 'test-agent', displayName: 'Test Agent', description: 'A test agent', - inputConfig: { inputSchema: { type: 'object', properties: {}, required: [] } }, + inputConfig: { + inputSchema: { type: 'object', properties: {}, required: [] }, + }, modelConfig: { model: 'gemini-test-model' }, runConfig: { maxTurns: 5, maxTimeMinutes: 5 }, promptConfig: { systemPrompt: 'You are a test agent.' 
}, outputConfig: { outputName: 'result', description: 'The final result.', - schema: z.string(), + schema: z.unknown(), }, }; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const behavior = new SubagentBehavior(mockConfig, definition as any); + const behavior = new SubagentBehavior(mockConfig, definition); const harness = new AgentHarness({ config: mockConfig, behavior, @@ -108,8 +112,19 @@ describe('AgentHarness', () => { yield { type: StreamEventType.CHUNK, value: { - candidates: [{ content: { parts: [{ text: 'Done!' }] }, finishReason: 'STOP' }], - functionCalls: [{ name: 'complete_task', args: { result: 'Success' }, id: 'call_1' }], + candidates: [ + { + content: { parts: [{ text: 'Done!' }] }, + finishReason: 'STOP', + }, + ], + functionCalls: [ + { + name: 'complete_task', + args: { result: 'Success' }, + id: 'call_1', + }, + ], }, }; })(), @@ -118,18 +133,31 @@ describe('AgentHarness', () => { // Mock tool execution (scheduleAgentTools as unknown as Mock).mockResolvedValue([ { - request: { name: 'complete_task', args: { result: 'Success' }, callId: 'call_1' }, + request: { + name: 'complete_task', + args: { result: 'Success' }, + callId: 'call_1', + }, status: 'success', response: { - responseParts: [{ - functionResponse: { name: 'complete_task', response: { status: 'OK' }, id: 'call_1' }, - }], + responseParts: [ + { + functionResponse: { + name: 'complete_task', + response: { status: 'OK' }, + id: 'call_1', + }, + }, + ], }, }, ]); const events: ServerGeminiStreamEvent[] = []; - const run = harness.run([{ text: 'Start' }], new AbortController().signal); + const run = harness.run( + [{ text: 'Start' }], + new AbortController().signal, + ); while (true) { const { value, done } = await run.next(); @@ -137,7 +165,13 @@ describe('AgentHarness', () => { events.push(value); } - expect(events.some(e => e.type === GeminiEventType.ToolCallRequest && e.value.name === 'complete_task')).toBe(true); + expect( + events.some( + (e) => + e.type === 
GeminiEventType.ToolCallRequest && + e.value.name === 'complete_task', + ), + ).toBe(true); expect(vi.mocked(logAgentFinish)).toHaveBeenCalledWith( expect.anything(), expect.objectContaining({ terminate_reason: AgentTerminateMode.GOAL }), @@ -163,7 +197,10 @@ describe('AgentHarness', () => { mockConfig.getEnableHooks = vi.fn().mockReturnValue(true); const events: ServerGeminiStreamEvent[] = []; - const run = harness.run([{ text: 'Hello' }], new AbortController().signal); + const run = harness.run( + [{ text: 'Hello' }], + new AbortController().signal, + ); while (true) { const { value, done } = await run.next(); @@ -171,42 +208,63 @@ describe('AgentHarness', () => { events.push(value); } - expect(events.some(e => e.type === GeminiEventType.Error && e.value.error.message === 'Access denied')).toBe(true); + expect( + events.some( + (e) => + e.type === GeminiEventType.Error && + e.value.error.message === 'Access denied', + ), + ).toBe(true); expect(vi.mocked(logAgentFinish)).toHaveBeenCalledWith( expect.anything(), - expect.objectContaining({ terminate_reason: AgentTerminateMode.ABORTED }), + expect.objectContaining({ + terminate_reason: AgentTerminateMode.ABORTED, + }), ); }); it('syncs IDE context when IDE mode is enabled', async () => { - const behavior = new MainAgentBehavior(mockConfig); - const harness = new AgentHarness({ config: mockConfig, behavior }); + const behavior = new MainAgentBehavior(mockConfig); + const harness = new AgentHarness({ config: mockConfig, behavior }); - mockConfig.getIdeMode = vi.fn().mockReturnValue(true); - - const mockChat = { - sendMessageStream: vi.fn().mockResolvedValue((async function* () { - yield { type: StreamEventType.CHUNK, value: { candidates: [{ content: { parts: [{ text: 'Response' }] }, finishReason: 'STOP' }] } }; - })()), - setTools: vi.fn(), - getHistory: vi.fn().mockReturnValue([]), - addHistory: vi.fn(), - setSystemInstruction: vi.fn(), - getLastPromptTokenCount: vi.fn().mockReturnValue(0), - } as unknown as 
GeminiChat; - (GeminiChat as unknown as Mock).mockReturnValue(mockChat); + mockConfig.getIdeMode = vi.fn().mockReturnValue(true); - // We can't easily mock ideContextStore.get() if it's not exported as a mockable object easily, - // but we can at least verify that syncEnvironment is called by harness. - const syncSpy = vi.spyOn(behavior, 'syncEnvironment'); + const mockChat = { + sendMessageStream: vi.fn().mockResolvedValue( + (async function* () { + yield { + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'Response' }] }, + finishReason: 'STOP', + }, + ], + }, + }; + })(), + ), + setTools: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + addHistory: vi.fn(), + setSystemInstruction: vi.fn(), + getLastPromptTokenCount: vi.fn().mockReturnValue(0), + } as unknown as GeminiChat; + (GeminiChat as unknown as Mock).mockReturnValue(mockChat); - const run = harness.run([{ text: 'Hello' }], new AbortController().signal); - while (true) { - const { done } = await run.next(); - if (done) break; - } + const syncSpy = vi.spyOn(behavior, 'syncEnvironment'); - expect(syncSpy).toHaveBeenCalled(); + const run = harness.run( + [{ text: 'Hello' }], + new AbortController().signal, + ); + while (true) { + const { done } = await run.next(); + if (done) break; + } + + expect(syncSpy).toHaveBeenCalled(); }); }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 34dabb4f9a..97b0c815b3 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -63,7 +63,7 @@ import { } from '../availability/policyHelpers.js'; import { resolveModel } from '../config/models.js'; import type { RetryAvailabilityContext } from '../utils/retry.js'; -import { partToString } from '../utils/partUtils.js'; +import { partToString, toPartArray } from '../utils/partUtils.js'; import { coreEvents, CoreEvent } from '../utils/events.js'; import { AgentFactory } from '../agents/agent-factory.js'; import { type 
AgentHarness } from '../agents/harness.js'; @@ -807,14 +807,13 @@ export class GeminiClient { } if (!this.harness || this.lastPromptId !== prompt_id) { - this.harness = AgentFactory.createHarness(this.config, undefined, { - parentPromptId: prompt_id - }); - this.lastPromptId = prompt_id; + this.harness = AgentFactory.createHarness(this.config, undefined, { + parentPromptId: prompt_id, + }); + this.lastPromptId = prompt_id; } - - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any - const requestParts: Part[] = (Array.isArray(request) ? request : [{ text: partToString(request) }]) as any; + + const requestParts: Part[] = toPartArray(request); const stream = this.harness.run(requestParts, signal, turns); let turn: Turn | undefined; @@ -828,10 +827,9 @@ export class GeminiClient { } if (turn) { - // Sync history back to GeminiClient's chat for transcript persistence - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion, @typescript-eslint/no-explicit-any - this.getChat().setHistory((turn as any).chat.getHistory()); - return turn; + // Sync history back to GeminiClient's chat for transcript persistence + this.getChat().setHistory(turn.chat.getHistory()); + return turn; } return new Turn(this.getChat(), prompt_id); } diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 68ca99d13b..baef5596e4 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -252,7 +252,7 @@ export class Turn { finishReason: FinishReason | undefined = undefined; constructor( - private readonly chat: GeminiChat, + readonly chat: GeminiChat, private readonly prompt_id: string, ) {} diff --git a/packages/core/src/utils/partUtils.ts b/packages/core/src/utils/partUtils.ts index 52a59258bd..44bd250efe 100644 --- a/packages/core/src/utils/partUtils.ts +++ b/packages/core/src/utils/partUtils.ts @@ -168,3 +168,17 @@ export function appendToLastTextPart( return newPrompt; } + 
+/** * Normalizes a PartListUnion into an array of Parts: a falsy value (e.g. an empty string) yields an empty array, a single non-array value is wrapped in a one-element array, and every string item is converted to a `{ text }` Part while non-string items are returned unchanged. + */ +export function toPartArray(value: PartListUnion): Part[] { + if (!value) return []; + const items = Array.isArray(value) ? value : [value]; + return items.map((item) => { + if (typeof item === 'string') { + return { text: item }; + } + return item; + }); +}