diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index 6e758a3475..4a626d93ce 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -163,8 +163,6 @@ describe('useReactToolScheduler in YOLO Mode', () => { // Check that execute WAS called expect(mockToolRequiresConfirmation.execute).toHaveBeenCalledWith( request.args, - expect.any(AbortSignal), - undefined, ); // Check that onComplete was called with success @@ -311,11 +309,7 @@ describe('useReactToolScheduler', () => { await vi.runAllTimersAsync(); }); - expect(mockTool.execute).toHaveBeenCalledWith( - request.args, - expect.any(AbortSignal), - undefined, - ); + expect(mockTool.execute).toHaveBeenCalledWith(request.args); expect(onComplete).toHaveBeenCalledWith([ expect.objectContaining({ status: 'success', diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 4aaa0b3d45..d24179d056 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -30,7 +30,11 @@ import { ApprovalMode, } from '../index.js'; import type { Part, PartListUnion } from '@google/genai'; -import { MockModifiableTool, MockTool } from '../test-utils/tools.js'; +import { + MockModifiableTool, + MockTool, + MOCK_TOOL_SHOULD_CONFIRM_EXECUTE, +} from '../test-utils/mock-tool.js'; import * as fs from 'node:fs/promises'; import * as path from 'node:path'; @@ -205,8 +209,10 @@ async function waitForStatus( describe('CoreToolScheduler', () => { it('should cancel a tool call if the signal is aborted before confirmation', async () => { - const mockTool = new MockTool(); - mockTool.shouldConfirm = true; + const mockTool = new MockTool({ + name: 'mockTool', + shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE, + }); const declarativeTool = mockTool; const mockToolRegistry = { getTool: () => declarativeTool, @@ -399,6 +405,7 @@ describe('CoreToolScheduler', () => { describe('CoreToolScheduler with payload', () => { it('should update args and diff and execute tool when payload is provided', async () => { const mockTool = new MockModifiableTool(); + mockTool.executeFn = vi.fn(); const declarativeTool = mockTool; const mockToolRegistry = { getTool: () => declarativeTool, @@ -813,13 +820,15 @@ describe('CoreToolScheduler edit cancellation', () => { describe('CoreToolScheduler YOLO mode', () => { it('should execute tool requiring confirmation directly without waiting', async () => { // Arrange - const mockTool = new MockTool(); - mockTool.executeFn.mockReturnValue({ + const executeFn = vi.fn().mockResolvedValue({ llmContent: 'Tool executed', returnDisplay: 'Tool executed', }); - // This tool would normally require confirmation. - mockTool.shouldConfirm = true; + const mockTool = new MockTool({ + name: 'mockTool', + execute: executeFn, + shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE, + }); const declarativeTool = mockTool; const mockToolRegistry = { @@ -894,7 +903,7 @@ describe('CoreToolScheduler YOLO mode', () => { // Assert // 1. The tool's execute method was called directly. - expect(mockTool.executeFn).toHaveBeenCalledWith({ param: 'value' }); + expect(executeFn).toHaveBeenCalledWith({ param: 'value' }); // 2. The tool call status never entered 'awaiting_approval'. const statusUpdates = onToolCallsUpdate.mock.calls @@ -927,8 +936,8 @@ describe('CoreToolScheduler request queueing', () => { resolveFirstCall = resolve; }); - const mockTool = new MockTool(); - mockTool.executeFn.mockImplementation(() => firstCallPromise); + const executeFn = vi.fn().mockImplementation(() => firstCallPromise); + const mockTool = new MockTool({ name: 'mockTool', execute: executeFn }); const declarativeTool = mockTool; const mockToolRegistry = { @@ -1011,8 +1020,7 @@ describe('CoreToolScheduler request queueing', () => { ); // Ensure the second tool call hasn't been executed yet. - expect(mockTool.executeFn).toHaveBeenCalledTimes(1); - expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 }); + expect(executeFn).toHaveBeenCalledWith({ a: 1 }); // Complete the first tool call. resolveFirstCall!({ @@ -1034,9 +1042,9 @@ describe('CoreToolScheduler request queueing', () => { await vi.waitFor(() => { // Now the second tool call should have been executed. - expect(mockTool.executeFn).toHaveBeenCalledTimes(2); + expect(executeFn).toHaveBeenCalledTimes(2); }); - expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 }); + expect(executeFn).toHaveBeenCalledWith({ b: 2 }); // Wait for the second completion. await vi.waitFor(() => { @@ -1050,13 +1058,15 @@ describe('CoreToolScheduler request queueing', () => { it('should auto-approve a tool call if it is on the allowedTools list', async () => { // Arrange - const mockTool = new MockTool('mockTool'); - mockTool.executeFn.mockReturnValue({ + const executeFn = vi.fn().mockResolvedValue({ llmContent: 'Tool executed', returnDisplay: 'Tool executed', }); - // This tool would normally require confirmation. - mockTool.shouldConfirm = true; + const mockTool = new MockTool({ + name: 'mockTool', + execute: executeFn, + shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE, + }); const declarativeTool = mockTool; const toolRegistry = { @@ -1132,7 +1142,7 @@ describe('CoreToolScheduler request queueing', () => { // Assert // 1. The tool's execute method was called directly. - expect(mockTool.executeFn).toHaveBeenCalledWith({ param: 'value' }); + expect(executeFn).toHaveBeenCalledWith({ param: 'value' }); // 2. The tool call status never entered 'awaiting_approval'. const statusUpdates = onToolCallsUpdate.mock.calls @@ -1159,7 +1169,11 @@ describe('CoreToolScheduler request queueing', () => { }); it('should handle two synchronous calls to schedule', async () => { - const mockTool = new MockTool(); + const executeFn = vi.fn().mockResolvedValue({ + llmContent: 'Tool executed', + returnDisplay: 'Tool executed', + }); + const mockTool = new MockTool({ name: 'mockTool', execute: executeFn }); const declarativeTool = mockTool; const mockToolRegistry = { getTool: () => declarativeTool, @@ -1241,9 +1255,9 @@ describe('CoreToolScheduler request queueing', () => { await Promise.all([schedulePromise1, schedulePromise2]); // Ensure the tool was called twice with the correct arguments. - expect(mockTool.executeFn).toHaveBeenCalledTimes(2); - expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 }); - expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 }); + expect(executeFn).toHaveBeenCalledTimes(2); + expect(executeFn).toHaveBeenCalledWith({ a: 1 }); + expect(executeFn).toHaveBeenCalledWith({ b: 2 }); // Ensure completion callbacks were called twice. expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2); diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts index e8387c2d13..3c05aca66d 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.test.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts @@ -5,6 +5,7 @@ */ import { describe, it, expect, vi, beforeEach } from 'vitest'; +import type { Mock } from 'vitest'; import { executeToolCall } from './nonInteractiveToolExecutor.js'; import type { ToolRegistry, @@ -19,16 +20,18 @@ import { ApprovalMode, } from '../index.js'; import type { Part } from '@google/genai'; -import { MockTool } from '../test-utils/tools.js'; +import { MockTool } from '../test-utils/mock-tool.js'; describe('executeToolCall', () => { let mockToolRegistry: ToolRegistry; let mockTool: MockTool; + let executeFn: Mock; let abortController: AbortController; let mockConfig: Config; beforeEach(() => { - mockTool = new MockTool(); + executeFn = vi.fn(); + mockTool = new MockTool({ name: 'testTool', execute: executeFn }); mockToolRegistry = { getTool: vi.fn(), @@ -77,7 +80,7 @@ describe('executeToolCall', () => { returnDisplay: 'Success!', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - mockTool.executeFn.mockReturnValue(toolResult); + executeFn.mockResolvedValue(toolResult); const response = await executeToolCall( mockConfig, @@ -86,7 +89,7 @@ describe('executeToolCall', () => { ); expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool'); - expect(mockTool.executeFn).toHaveBeenCalledWith(request.args); + expect(executeFn).toHaveBeenCalledWith(request.args); expect(response).toStrictEqual({ callId: 'call1', error: undefined, @@ -207,7 +210,7 @@ describe('executeToolCall', () => { }, }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - mockTool.executeFn.mockReturnValue(executionErrorResult); + executeFn.mockResolvedValue(executionErrorResult); const response = await executeToolCall( mockConfig, @@ -243,9 +246,7 @@ describe('executeToolCall', () => { prompt_id: 'prompt-id-5', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - mockTool.executeFn.mockImplementation(() => { - throw new Error('Something went very wrong'); - }); + executeFn.mockRejectedValue(new Error('Something went very wrong')); const response = await executeToolCall( mockConfig, @@ -287,7 +288,7 @@ describe('executeToolCall', () => { returnDisplay: 'Image processed', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - mockTool.executeFn.mockReturnValue(toolResult); + executeFn.mockResolvedValue(toolResult); const response = await executeToolCall( mockConfig, @@ -330,7 +331,7 @@ describe('executeToolCall', () => { returnDisplay: 'String returned', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - mockTool.executeFn.mockReturnValue(toolResult); + executeFn.mockResolvedValue(toolResult); const response = await executeToolCall( mockConfig, @@ -358,7 +359,7 @@ describe('executeToolCall', () => { returnDisplay: 'Image data returned', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - mockTool.executeFn.mockReturnValue(toolResult); + executeFn.mockResolvedValue(toolResult); const response = await executeToolCall( mockConfig, diff --git a/packages/core/src/telemetry/loggers.test.circular.ts b/packages/core/src/telemetry/loggers.test.circular.ts index 240604bfe8..23a6b6aeed 100644 --- a/packages/core/src/telemetry/loggers.test.circular.ts +++ b/packages/core/src/telemetry/loggers.test.circular.ts @@ -17,7 +17,7 @@ import type { ToolCallRequestInfo, ToolCallResponseInfo, } from '../core/turn.js'; -import { MockTool } from '../test-utils/tools.js'; +import { MockTool } from '../test-utils/mock-tool.js'; describe('Circular Reference Handling', () => { it('should handle circular references in tool function arguments', () => { @@ -59,7 +59,7 @@ describe('Circular Reference Handling', () => { errorType: undefined, }; - const tool = new MockTool('mock-tool'); + const tool = new MockTool({ name: 'mock-tool' }); const mockCompletedToolCall: CompletedToolCall = { status: 'success', request: mockRequest, @@ -109,7 +109,7 @@ describe('Circular Reference Handling', () => { errorType: undefined, }; - const tool = new MockTool('mock-tool'); + const tool = new MockTool({ name: 'mock-tool' }); const mockCompletedToolCall: CompletedToolCall = { status: 'success', request: mockRequest, diff --git a/packages/core/src/telemetry/uiTelemetry.test.ts b/packages/core/src/telemetry/uiTelemetry.test.ts index cecc9cea2b..d5ef8099da 100644 --- a/packages/core/src/telemetry/uiTelemetry.test.ts +++ b/packages/core/src/telemetry/uiTelemetry.test.ts @@ -21,7 +21,7 @@ import type { } from '../core/coreToolScheduler.js'; import { ToolErrorType } from '../tools/tool-error.js'; import { ToolConfirmationOutcome } from '../tools/tools.js'; -import { MockTool } from '../test-utils/tools.js'; +import { MockTool } from '../test-utils/mock-tool.js'; const createFakeCompletedToolCall = ( name: string, @@ -37,7 +37,7 @@ const createFakeCompletedToolCall = ( isClientInitiated: false, prompt_id: 'prompt-id-1', }; - const tool = new MockTool(name); + const tool = new MockTool({ name }); if (success) { return { diff --git a/packages/core/src/test-utils/mock-tool.ts b/packages/core/src/test-utils/mock-tool.ts index af140f8e47..75bdf26c5e 100644 --- a/packages/core/src/test-utils/mock-tool.ts +++ b/packages/core/src/test-utils/mock-tool.ts @@ -4,6 +4,10 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { + ModifiableDeclarativeTool, + ModifyContext, +} from '../tools/modifiable-tool.js'; import type { ToolCallConfirmationDetails, ToolInvocation, @@ -27,7 +31,7 @@ interface MockToolOptions { ) => Promise; execute?: ( params: { [key: string]: unknown }, - signal: AbortSignal, + signal?: AbortSignal, updateOutput?: (output: string) => void, ) => Promise; params?: object; @@ -48,7 +52,11 @@ class MockToolInvocation extends BaseToolInvocation< signal: AbortSignal, updateOutput?: (output: string) => void, ): Promise { - return this.tool.execute(this.params, signal, updateOutput); + if (updateOutput) { + return this.tool.execute(this.params, signal, updateOutput); + } else { + return this.tool.execute(this.params); + } } override shouldConfirmExecute( @@ -75,7 +83,7 @@ export class MockTool extends BaseDeclarativeTool< ) => Promise; execute: ( params: { [key: string]: unknown }, - signal: AbortSignal, + signal?: AbortSignal, updateOutput?: (output: string) => void, ) => Promise; @@ -113,3 +121,98 @@ export class MockTool extends BaseDeclarativeTool< return new MockToolInvocation(this, params); } } + +export const MOCK_TOOL_SHOULD_CONFIRM_EXECUTE = () => + Promise.resolve({ + type: 'exec' as const, + title: 'Confirm mockTool', + command: 'mockTool', + rootCommand: 'mockTool', + onConfirm: async () => {}, + }); + +export class MockModifiableToolInvocation extends BaseToolInvocation< + Record, + ToolResult +> { + constructor( + private readonly tool: MockModifiableTool, + params: Record, + ) { + super(params); + } + + async execute(_abortSignal: AbortSignal): Promise { + const result = this.tool.executeFn(this.params); + return ( + result ?? { + llmContent: `Tool ${this.tool.name} executed successfully.`, + returnDisplay: `Tool ${this.tool.name} executed successfully.`, + } + ); + } + + override async shouldConfirmExecute( + _abortSignal: AbortSignal, + ): Promise { + if (this.tool.shouldConfirm) { + return { + type: 'edit', + title: 'Confirm Mock Tool', + fileName: 'test.txt', + filePath: 'test.txt', + fileDiff: 'diff', + originalContent: 'originalContent', + newContent: 'newContent', + onConfirm: async () => {}, + }; + } + return false; + } + + getDescription(): string { + return `A mock modifiable tool invocation for ${this.tool.name}`; + } +} + +/** + * Configurable mock modifiable tool for testing. + */ +export class MockModifiableTool + extends BaseDeclarativeTool, ToolResult> + implements ModifiableDeclarativeTool> +{ + // Should be overrided in test file. Functionality will be updated in follow + // up PR which has MockModifiableTool expect MockTool + executeFn: (params: Record) => ToolResult | undefined = () => + undefined; + shouldConfirm = true; + + constructor(name = 'mockModifiableTool') { + super(name, name, 'A mock modifiable tool for testing.', Kind.Other, { + type: 'object', + properties: { param: { type: 'string' } }, + }); + } + + getModifyContext( + _abortSignal: AbortSignal, + ): ModifyContext> { + return { + getFilePath: () => 'test.txt', + getCurrentContent: async () => 'old content', + getProposedContent: async () => 'new content', + createUpdatedParams: ( + _oldContent: string, + modifiedProposedContent: string, + _originalParams: Record, + ) => ({ newContent: modifiedProposedContent }), + }; + } + + protected createInvocation( + params: Record, + ): ToolInvocation, ToolResult> { + return new MockModifiableToolInvocation(this, params); + } +} diff --git a/packages/core/src/test-utils/tools.ts b/packages/core/src/test-utils/tools.ts deleted file mode 100644 index fca72b5357..0000000000 --- a/packages/core/src/test-utils/tools.ts +++ /dev/null @@ -1,169 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { vi } from 'vitest'; -import type { - ToolCallConfirmationDetails, - ToolInvocation, - ToolResult, -} from '../tools/tools.js'; -import { - BaseDeclarativeTool, - BaseToolInvocation, - Kind, -} from '../tools/tools.js'; -import type { - ModifiableDeclarativeTool, - ModifyContext, -} from '../tools/modifiable-tool.js'; - -class MockToolInvocation extends BaseToolInvocation< - { [key: string]: unknown }, - ToolResult -> { - constructor( - private readonly tool: MockTool, - params: { [key: string]: unknown }, - ) { - super(params); - } - - async execute(_abortSignal: AbortSignal): Promise { - const result = this.tool.executeFn(this.params); - return ( - result ?? { - llmContent: `Tool ${this.tool.name} executed successfully.`, - returnDisplay: `Tool ${this.tool.name} executed successfully.`, - } - ); - } - - override async shouldConfirmExecute( - _abortSignal: AbortSignal, - ): Promise { - if (this.tool.shouldConfirm) { - return { - type: 'exec' as const, - title: `Confirm ${this.tool.displayName}`, - command: this.tool.name, - rootCommand: this.tool.name, - onConfirm: async () => {}, - }; - } - return false; - } - - getDescription(): string { - return `A mock tool invocation for ${this.tool.name}`; - } -} - -/** - * A highly configurable mock tool for testing purposes. - */ -export class MockTool extends BaseDeclarativeTool< - { [key: string]: unknown }, - ToolResult -> { - executeFn = vi.fn(); - shouldConfirm = false; - - constructor( - name = 'mock-tool', - displayName?: string, - description = 'A mock tool for testing.', - params = { - type: 'object', - properties: { param: { type: 'string' } }, - }, - ) { - super(name, displayName ?? name, description, Kind.Other, params); - } - - protected createInvocation(params: { - [key: string]: unknown; - }): ToolInvocation<{ [key: string]: unknown }, ToolResult> { - return new MockToolInvocation(this, params); - } -} - -export class MockModifiableToolInvocation extends BaseToolInvocation< - Record, - ToolResult -> { - constructor( - private readonly tool: MockModifiableTool, - params: Record, - ) { - super(params); - } - - async execute(_abortSignal: AbortSignal): Promise { - const result = this.tool.executeFn(this.params); - return ( - result ?? { - llmContent: `Tool ${this.tool.name} executed successfully.`, - returnDisplay: `Tool ${this.tool.name} executed successfully.`, - } - ); - } - - override async shouldConfirmExecute( - _abortSignal: AbortSignal, - ): Promise { - if (this.tool.shouldConfirm) { - return { - type: 'edit', - title: 'Confirm Mock Tool', - fileName: 'test.txt', - filePath: 'test.txt', - fileDiff: 'diff', - originalContent: 'originalContent', - newContent: 'newContent', - onConfirm: async () => {}, - }; - } - return false; - } - - getDescription(): string { - return `A mock modifiable tool invocation for ${this.tool.name}`; - } -} - -/** - * Configurable mock modifiable tool for testing. - */ -export class MockModifiableTool - extends MockTool - implements ModifiableDeclarativeTool> -{ - constructor(name = 'mockModifiableTool') { - super(name); - this.shouldConfirm = true; - } - - getModifyContext( - _abortSignal: AbortSignal, - ): ModifyContext> { - return { - getFilePath: () => 'test.txt', - getCurrentContent: async () => 'old content', - getProposedContent: async () => 'new content', - createUpdatedParams: ( - _oldContent: string, - modifiedProposedContent: string, - _originalParams: Record, - ) => ({ newContent: modifiedProposedContent }), - }; - } - - protected override createInvocation( - params: Record, - ): ToolInvocation, ToolResult> { - return new MockModifiableToolInvocation(this, params); - } -} diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index f42106872e..a75bf96d67 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -16,7 +16,7 @@ import { mcpToTool } from '@google/genai'; import { spawn } from 'node:child_process'; import fs from 'node:fs'; -import { MockTool } from '../test-utils/tools.js'; +import { MockTool } from '../test-utils/mock-tool.js'; import { McpClientManager } from './mcp-client-manager.js'; import { ToolErrorType } from './tool-error.js'; @@ -149,7 +149,7 @@ describe('ToolRegistry', () => { describe('registerTool', () => { it('should register a new tool', () => { - const tool = new MockTool(); + const tool = new MockTool({ name: 'mock-tool' }); toolRegistry.registerTool(tool); expect(toolRegistry.getTool('mock-tool')).toBe(tool); }); @@ -158,9 +158,9 @@ describe('ToolRegistry', () => { describe('getAllTools', () => { it('should return all registered tools sorted alphabetically by displayName', () => { // Register tools with displayNames in non-alphabetical order - const toolC = new MockTool('c-tool', 'Tool C'); - const toolA = new MockTool('a-tool', 'Tool A'); - const toolB = new MockTool('b-tool', 'Tool B'); + const toolC = new MockTool({ name: 'c-tool', displayName: 'Tool C' }); + const toolA = new MockTool({ name: 'a-tool', displayName: 'Tool A' }); + const toolB = new MockTool({ name: 'b-tool', displayName: 'Tool B' }); toolRegistry.registerTool(toolC); toolRegistry.registerTool(toolA); @@ -177,9 +177,9 @@ describe('ToolRegistry', () => { describe('getAllToolNames', () => { it('should return all registered tool names', () => { // Register tools with displayNames in non-alphabetical order - const toolC = new MockTool('c-tool', 'Tool C'); - const toolA = new MockTool('a-tool', 'Tool A'); - const toolB = new MockTool('b-tool', 'Tool B'); + const toolC = new MockTool({ name: 'c-tool', displayName: 'Tool C' }); + const toolA = new MockTool({ name: 'a-tool', displayName: 'Tool A' }); + const toolB = new MockTool({ name: 'b-tool', displayName: 'Tool B' }); toolRegistry.registerTool(toolC); toolRegistry.registerTool(toolA); @@ -194,7 +194,7 @@ describe('ToolRegistry', () => { describe('getToolsByServer', () => { it('should return an empty array if no tools match the server name', () => { - toolRegistry.registerTool(new MockTool()); + toolRegistry.registerTool(new MockTool({ name: 'mock-tool' })); expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]); }); @@ -231,7 +231,7 @@ describe('ToolRegistry', () => { 'd4', {}, ); - const nonMcpTool = new MockTool('regular-tool'); + const nonMcpTool = new MockTool({ name: 'regular-tool' }); toolRegistry.registerTool(mcpTool1_c); toolRegistry.registerTool(mcpTool1_a);