/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import {
  describe,
  it,
  expect,
  vi,
  beforeEach,
  afterEach,
  type MockedClass,
} from 'vitest';
import { AgentExecutor, type ActivityCallback } from './executor.js';
import type {
  AgentDefinition,
  AgentInputs,
  SubagentActivityEvent,
} from './types.js';
import { AgentTerminateMode } from './types.js';
import { makeFakeConfig } from '../test-utils/config.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { LSTool } from '../tools/ls.js';
import { ReadFileTool } from '../tools/read-file.js';
import {
  GeminiChat,
  StreamEventType,
  type StreamEvent,
} from '../core/geminiChat.js';
import type {
  FunctionCall,
  Part,
  GenerateContentResponse,
} from '@google/genai';
import type { Config } from '../config/config.js';
import { MockTool } from '../test-utils/mock-tool.js';
import { getDirectoryContextString } from '../utils/environmentContext.js';

// vi.hoisted guarantees these mocks exist before the hoisted vi.mock
// factories below are evaluated.
const { mockSendMessageStream, mockExecuteToolCall } = vi.hoisted(() => ({
  mockSendMessageStream: vi.fn(),
  mockExecuteToolCall: vi.fn(),
}));

vi.mock('../core/geminiChat.js', async (importOriginal) => {
  const actual = await importOriginal();
  return {
    ...(actual as object),
    GeminiChat: vi.fn().mockImplementation(() => ({
      sendMessageStream: mockSendMessageStream,
    })),
  };
});

vi.mock('../core/nonInteractiveToolExecutor.js', () => ({
  executeToolCall: mockExecuteToolCall,
}));

vi.mock('../utils/environmentContext.js');

const MockedGeminiChat = GeminiChat as MockedClass<typeof GeminiChat>;

// A mock tool that is NOT on the NON_INTERACTIVE_TOOL_ALLOWLIST
const MOCK_TOOL_NOT_ALLOWED = new MockTool({ name: 'write_file' });

const createMockResponseChunk = (
  parts: Part[],
  functionCalls?: FunctionCall[],
): GenerateContentResponse =>
  ({
    candidates: [{ index: 0, content: { role: 'model', parts } }],
    functionCalls,
  }) as unknown as GenerateContentResponse;

// Queues a single model turn; call once per expected sendMessageStream call.
const mockModelResponse = (
  functionCalls: FunctionCall[],
  thought?: string,
  text?: string,
) => {
  const parts: Part[] = [];
  if (thought) {
    parts.push({
      text: `**${thought}** This is the reasoning part.`,
      thought: true,
    });
  }
  if (text) parts.push({ text });

  const responseChunk = createMockResponseChunk(
    parts,
    // Ensure functionCalls is undefined if the array is empty, matching API behavior
    functionCalls.length > 0 ? functionCalls : undefined,
  );

  mockSendMessageStream.mockImplementationOnce(async () =>
    (async function* () {
      yield {
        type: StreamEventType.CHUNK,
        value: responseChunk,
      } as StreamEvent;
    })(),
  );
};

let mockConfig: Config;
let parentToolRegistry: ToolRegistry;

const createTestDefinition = (
  tools: Array<string> = [LSTool.Name],
  runConfigOverrides: Partial<AgentDefinition['runConfig']> = {},
  outputConfigOverrides: Partial<AgentDefinition['outputConfig']> = {},
): AgentDefinition => ({
  name: 'TestAgent',
  description: 'An agent for testing.',
  inputConfig: {
    inputs: { goal: { type: 'string', required: true, description: 'goal' } },
  },
  modelConfig: { model: 'gemini-test-model', temp: 0, top_p: 1 },
  runConfig: { max_time_minutes: 5, max_turns: 5, ...runConfigOverrides },
  promptConfig: { systemPrompt: 'Achieve the goal: ${goal}.' },
  toolConfig: { tools },
  outputConfig: { description: 'The final result.', ...outputConfigOverrides },
});
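
// The suites below drive AgentExecutor end to end against the mocked
// GeminiChat stream and non-interactive tool executor: creation/validation,
// the execution loop, and forced-termination behavior.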

describe('AgentExecutor', () => {
  let activities: SubagentActivityEvent[];
  let onActivity: ActivityCallback;
  let abortController: AbortController;
  let signal: AbortSignal;

  beforeEach(async () => {
    mockSendMessageStream.mockClear();
    mockExecuteToolCall.mockClear();
    vi.clearAllMocks();

    // Use fake timers for timeout and concurrency testing
    vi.useFakeTimers();

    mockConfig = makeFakeConfig();
    parentToolRegistry = new ToolRegistry(mockConfig);
    parentToolRegistry.registerTool(new LSTool(mockConfig));
    parentToolRegistry.registerTool(new ReadFileTool(mockConfig));
    parentToolRegistry.registerTool(MOCK_TOOL_NOT_ALLOWED);
    vi.spyOn(mockConfig, 'getToolRegistry').mockResolvedValue(
      parentToolRegistry,
    );
    vi.mocked(getDirectoryContextString).mockResolvedValue(
      'Mocked Environment Context',
    );

    activities = [];
    onActivity = (activity) => activities.push(activity);
    abortController = new AbortController();
    signal = abortController.signal;
  });

  afterEach(() => {
    vi.useRealTimers();
  });

  describe('create (Initialization and Validation)', () => {
    it('should create successfully with allowed tools', async () => {
      const definition = createTestDefinition([LSTool.Name]);
      const executor = await AgentExecutor.create(
        definition,
        mockConfig,
        onActivity,
      );
      expect(executor).toBeInstanceOf(AgentExecutor);
    });

    it('SECURITY: should throw if a tool is not on the non-interactive allowlist', async () => {
      const definition = createTestDefinition([MOCK_TOOL_NOT_ALLOWED.name]);
      await expect(
        AgentExecutor.create(definition, mockConfig, onActivity),
      ).rejects.toThrow(
        `Tool "${MOCK_TOOL_NOT_ALLOWED.name}" is not on the allow-list for non-interactive execution`,
      );
    });

    it('should create an isolated ToolRegistry for the agent', async () => {
      const definition = createTestDefinition([
        LSTool.Name,
        ReadFileTool.Name,
      ]);
      const executor = await AgentExecutor.create(
        definition,
        mockConfig,
        onActivity,
      );

      // @ts-expect-error - accessing private property for test validation
      const agentRegistry = executor.toolRegistry as ToolRegistry;
      expect(agentRegistry).not.toBe(parentToolRegistry);
      expect(agentRegistry.getAllToolNames()).toEqual(
        expect.arrayContaining([LSTool.Name, ReadFileTool.Name]),
      );
      expect(agentRegistry.getAllToolNames()).toHaveLength(2);
      expect(agentRegistry.getTool(MOCK_TOOL_NOT_ALLOWED.name)).toBeUndefined();
    });
  });

  describe('run (Execution Loop and Logic)', () => {
    it('should execute a successful work and extraction phase (Happy Path) and emit activities', async () => {
      const definition = createTestDefinition();
      const executor = await AgentExecutor.create(
        definition,
        mockConfig,
        onActivity,
      );
      const inputs: AgentInputs = { goal: 'Find files' };

      // Turn 1: Model calls ls
      mockModelResponse(
        [{ name: LSTool.Name, args: { path: '.' }, id: 'call1' }],
        'T1: Listing',
      );
      mockExecuteToolCall.mockResolvedValueOnce({
        callId: 'call1',
        resultDisplay: 'file1.txt',
        responseParts: [
          {
            functionResponse: {
              name: LSTool.Name,
              response: { result: 'file1.txt' },
              id: 'call1',
            },
          },
        ],
        error: undefined,
      });

      // Turn 2: Model stops
      mockModelResponse([], 'T2: Done');
      // Extraction Phase
      mockModelResponse([], undefined, 'Result: file1.txt.');

      const output = await executor.run(inputs, signal);

      expect(mockSendMessageStream).toHaveBeenCalledTimes(3);
      expect(mockExecuteToolCall).toHaveBeenCalledTimes(1);

      // Verify System Prompt Templating
      const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
      const chatConfig = chatConstructorArgs[1];
      expect(chatConfig?.systemInstruction).toContain(
        'Achieve the goal: Find files.',
      );
      // Verify environment context is appended
      expect(chatConfig?.systemInstruction).toContain(
        '# Environment Context\nMocked Environment Context',
      );
      // Verify standard rules are appended
      expect(chatConfig?.systemInstruction).toContain(
        'You are running in a non-interactive mode.',
      );
      // Verify absolute path rule is appended
      expect(chatConfig?.systemInstruction).toContain(
        'Always use absolute paths for file operations.',
      );

      // Verify Extraction Phase Call (Specific arguments)
      expect(mockSendMessageStream).toHaveBeenCalledWith(
        'gemini-test-model',
        expect.objectContaining({
          // Extraction message should be based on outputConfig.description
          message: expect.arrayContaining([
            {
              text: expect.stringContaining(
                'Based on your work so far, provide: The final result.',
              ),
            },
          ]),
          config: expect.objectContaining({ tools: undefined }), // No tools in extraction
        }),
        expect.stringContaining('#extraction'),
      );

      expect(output.result).toBe('Result: file1.txt.');
      expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);

      // Verify Activity Stream (Observability)
      expect(activities).toEqual(
        expect.arrayContaining([
          // Thought subjects are extracted by the executor (parseThought)
          expect.objectContaining({
            type: 'THOUGHT_CHUNK',
            data: { text: 'T1: Listing' },
          }),
          expect.objectContaining({
            type: 'TOOL_CALL_START',
            data: { name: LSTool.Name, args: { path: '.' } },
          }),
          expect.objectContaining({
            type: 'TOOL_CALL_END',
            data: { name: LSTool.Name, output: 'file1.txt' },
          }),
          expect.objectContaining({
            type: 'THOUGHT_CHUNK',
            data: { text: 'T2: Done' },
          }),
        ]),
      );
    });

    it('should execute parallel tool calls concurrently', async () => {
      const definition = createTestDefinition([
        LSTool.Name,
        ReadFileTool.Name,
      ]);
      const executor = await AgentExecutor.create(
        definition,
        mockConfig,
        onActivity,
      );

      const call1 = {
        name: LSTool.Name,
        args: { path: '/dir1' },
        id: 'call1',
      };
      // Using LSTool twice for simplicity in mocking standardized responses.
      const call2 = {
        name: LSTool.Name,
        args: { path: '/dir2' },
        id: 'call2',
      };

      // Turn 1: Model calls two tools simultaneously
      mockModelResponse([call1, call2], 'T1: Listing both');

      // Use concurrency tracking to ensure parallelism
      let activeCalls = 0;
      let maxActiveCalls = 0;
      mockExecuteToolCall.mockImplementation(async (_ctx, reqInfo) => {
        activeCalls++;
        maxActiveCalls = Math.max(maxActiveCalls, activeCalls);
        // Simulate latency. We must advance the fake timers for this to resolve.
        await new Promise((resolve) => setTimeout(resolve, 100));
        activeCalls--;
        return {
          callId: reqInfo.callId,
          resultDisplay: `Result for ${reqInfo.name}`,
          responseParts: [
            {
              functionResponse: {
                name: reqInfo.name,
                response: {},
                id: reqInfo.callId,
              },
            },
          ],
          error: undefined,
        };
      });

      // Turn 2: Model stops
      mockModelResponse([]);
      // Extraction
      mockModelResponse([], undefined, 'Done.');

      const runPromise = executor.run({ goal: 'Parallel test' }, signal);

      // Advance timers while the parallel calls (Promise.all + setTimeout) are running
      await vi.advanceTimersByTimeAsync(150);

      await runPromise;

      expect(mockExecuteToolCall).toHaveBeenCalledTimes(2);
      expect(maxActiveCalls).toBe(2);

      // Verify the input to the next model call (Turn 2) contains both responses
      // sendMessageStream calls: [0] Turn 1, [1] Turn 2, [2] Extraction
      const turn2Input = mockSendMessageStream.mock.calls[1][1];
      const turn2Parts = turn2Input.message as Part[];

      // Promise.all preserves the order of the input array.
      expect(turn2Parts.length).toBe(2);
      expect(turn2Parts[0]).toEqual(
        expect.objectContaining({
          functionResponse: expect.objectContaining({ id: 'call1' }),
        }),
      );
      expect(turn2Parts[1]).toEqual(
        expect.objectContaining({
          functionResponse: expect.objectContaining({ id: 'call2' }),
        }),
      );
    });

    it('should handle tool execution failure gracefully and report error', async () => {
      const definition = createTestDefinition([LSTool.Name]);
      const executor = await AgentExecutor.create(
        definition,
        mockConfig,
        onActivity,
      );

      // Turn 1: Model calls ls, but it fails
      mockModelResponse([
        { name: LSTool.Name, args: { path: '/invalid' }, id: 'call1' },
      ]);
      const errorMessage = 'Internal failure.';
      mockExecuteToolCall.mockResolvedValueOnce({
        callId: 'call1',
        resultDisplay: `Error: ${errorMessage}`,
        responseParts: undefined, // Failed tools might return undefined parts
        error: { message: errorMessage },
      });

      // Turn 2: Model stops
      mockModelResponse([]);
      mockModelResponse([], undefined, 'Failed.');

      await executor.run({ goal: 'Failure test' }, signal);

      // Verify that the error was reported in the activity stream
      expect(activities).toContainEqual(
        expect.objectContaining({
          type: 'ERROR',
          data: {
            error: errorMessage,
            context: 'tool_call',
            name: LSTool.Name,
          },
        }),
      );

      // Verify the input to the next model call (Turn 2) contains the fallback error message
      const turn2Input = mockSendMessageStream.mock.calls[1][1];
      const turn2Parts = turn2Input.message as Part[];
      expect(turn2Parts).toEqual([
        {
          text: 'All tool calls failed. Please analyze the errors and try an alternative approach.',
        },
      ]);
    });

    it('SECURITY: should block calls to tools not registered for the agent at runtime', async () => {
      // Agent definition only includes LSTool
      const definition = createTestDefinition([LSTool.Name]);
      const executor = await AgentExecutor.create(
        definition,
        mockConfig,
        onActivity,
      );

      // Turn 1: Model hallucinates a call to ReadFileTool
      // (ReadFileTool exists in the parent registry but not the agent's isolated registry)
      mockModelResponse([
        {
          name: ReadFileTool.Name,
          args: { path: 'config.txt' },
          id: 'call_blocked',
        },
      ]);
      // Turn 2: Model stops
      mockModelResponse([]);
      // Extraction
      mockModelResponse([], undefined, 'Done.');

      const consoleWarnSpy = vi
        .spyOn(console, 'warn')
        .mockImplementation(() => {});

      await executor.run({ goal: 'Security test' }, signal);

      // Verify executeToolCall was NEVER called because the tool was unauthorized
      expect(mockExecuteToolCall).not.toHaveBeenCalled();
      expect(consoleWarnSpy).toHaveBeenCalledWith(
        expect.stringContaining(
          `attempted to call unauthorized tool '${ReadFileTool.Name}'`,
        ),
      );
      consoleWarnSpy.mockRestore();

      // Verify the input to the next model call (Turn 2) indicates failure (as the only call was blocked)
      const turn2Input = mockSendMessageStream.mock.calls[1][1];
      const turn2Parts = turn2Input.message as Part[];
      expect(turn2Parts[0].text).toContain('All tool calls failed');
    });

    it('should use OutputConfig completion_criteria in the extraction message', async () => {
      const definition = createTestDefinition(
        [LSTool.Name],
        {},
        {
          description: 'A summary.',
          completion_criteria: ['Must include file names', 'Must be concise'],
        },
      );
      const executor = await AgentExecutor.create(
        definition,
        mockConfig,
        onActivity,
      );

      // Turn 1: Model stops immediately
      mockModelResponse([]);
      // Extraction Phase
      mockModelResponse([], undefined, 'Result: Done.');

      await executor.run({ goal: 'Extraction test' }, signal);

      // Verify the extraction call (the second call)
      const extractionCallArgs = mockSendMessageStream.mock.calls[1][1];
      const extractionMessageParts = extractionCallArgs.message as Part[];
      const extractionText = extractionMessageParts[0].text;

      expect(extractionText).toContain(
        'Based on your work so far, provide: A summary.',
      );
      expect(extractionText).toContain('Be sure you have addressed:');
      expect(extractionText).toContain('- Must include file names');
      expect(extractionText).toContain('- Must be concise');
    });
  });

  describe('run (Termination Conditions)', () => {
    // Queues one looping ls turn and stubs a successful tool result.
    const mockKeepAliveResponse = () => {
      mockModelResponse(
        [{ name: LSTool.Name, args: { path: '.' }, id: 'loop' }],
        'Looping',
      );
      mockExecuteToolCall.mockResolvedValue({
        callId: 'loop',
        resultDisplay: 'ok',
        responseParts: [
          { functionResponse: { name: LSTool.Name, response: {}, id: 'loop' } },
        ],
        error: undefined,
      });
    };

    it('should terminate when max_turns is reached', async () => {
      const MAX_TURNS = 2;
      const definition = createTestDefinition([LSTool.Name], {
        max_turns: MAX_TURNS,
      });
      const executor = await AgentExecutor.create(
        definition,
        mockConfig,
        onActivity,
      );

      // Turn 1
      mockKeepAliveResponse();
      // Turn 2
      mockKeepAliveResponse();

      const output = await executor.run({ goal: 'Termination test' }, signal);

      expect(output.terminate_reason).toBe(AgentTerminateMode.MAX_TURNS);
      expect(mockSendMessageStream).toHaveBeenCalledTimes(MAX_TURNS);
      // Extraction phase should be skipped when termination is forced
      expect(mockSendMessageStream).not.toHaveBeenCalledWith(
        expect.any(String),
        expect.any(Object),
        expect.stringContaining('#extraction'),
      );
    });

    it('should terminate if timeout is reached', async () => {
      const definition = createTestDefinition([LSTool.Name], {
        max_time_minutes: 5,
        max_turns: 100,
      });
      const executor = await AgentExecutor.create(
        definition,
        mockConfig,
        onActivity,
      );

      // Turn 1 setup
      mockModelResponse(
        [{ name: LSTool.Name, args: { path: '.' }, id: 'loop' }],
        'Looping',
      );

      // Mock a tool call that takes a long time, causing the overall timeout
      mockExecuteToolCall.mockImplementation(async () => {
        // Advance time past the 5-minute limit during the tool call execution
        await vi.advanceTimersByTimeAsync(5 * 60 * 1000 + 1);
        return {
          callId: 'loop',
          resultDisplay: 'ok',
          responseParts: [
            {
              functionResponse: { name: LSTool.Name, response: {}, id: 'loop' },
            },
          ],
          error: undefined,
        };
      });

      const output = await executor.run({ goal: 'Termination test' }, signal);

      expect(output.terminate_reason).toBe(AgentTerminateMode.TIMEOUT);
      // Should only have called the model once before the timeout check stopped it
      expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
    });

    it('should terminate when AbortSignal is triggered mid-stream', async () => {
      const definition = createTestDefinition();
      const executor = await AgentExecutor.create(
        definition,
        mockConfig,
        onActivity,
      );

      // Mock the model response stream
      mockSendMessageStream.mockImplementation(async () =>
        (async function* () {
          // Yield the first chunk
          yield {
            type: StreamEventType.CHUNK,
            value: createMockResponseChunk([
              { text: '**Thinking** Step 1', thought: true },
            ]),
          } as StreamEvent;
          // Simulate abort happening mid-stream
          abortController.abort();
          // The loop in callModel should break immediately due to signal check.
        })(),
      );

      const output = await executor.run({ goal: 'Termination test' }, signal);

      expect(output.terminate_reason).toBe(AgentTerminateMode.ABORTED);
    });
  });
});