chore(mocktools): final step in unify mock tool definitions (#9166)

This commit is contained in:
Adam Weidman
2025-09-29 22:43:06 +02:00
committed by GitHub
parent 94f43c79d0
commit 0c3fcb7030
8 changed files with 171 additions and 228 deletions
@@ -30,7 +30,11 @@ import {
ApprovalMode,
} from '../index.js';
import type { Part, PartListUnion } from '@google/genai';
import { MockModifiableTool, MockTool } from '../test-utils/tools.js';
import {
MockModifiableTool,
MockTool,
MOCK_TOOL_SHOULD_CONFIRM_EXECUTE,
} from '../test-utils/mock-tool.js';
import * as fs from 'node:fs/promises';
import * as path from 'node:path';
@@ -205,8 +209,10 @@ async function waitForStatus(
describe('CoreToolScheduler', () => {
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
const mockTool = new MockTool();
mockTool.shouldConfirm = true;
const mockTool = new MockTool({
name: 'mockTool',
shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE,
});
const declarativeTool = mockTool;
const mockToolRegistry = {
getTool: () => declarativeTool,
@@ -399,6 +405,7 @@ describe('CoreToolScheduler', () => {
describe('CoreToolScheduler with payload', () => {
it('should update args and diff and execute tool when payload is provided', async () => {
const mockTool = new MockModifiableTool();
mockTool.executeFn = vi.fn();
const declarativeTool = mockTool;
const mockToolRegistry = {
getTool: () => declarativeTool,
@@ -813,13 +820,15 @@ describe('CoreToolScheduler edit cancellation', () => {
describe('CoreToolScheduler YOLO mode', () => {
it('should execute tool requiring confirmation directly without waiting', async () => {
// Arrange
const mockTool = new MockTool();
mockTool.executeFn.mockReturnValue({
const executeFn = vi.fn().mockResolvedValue({
llmContent: 'Tool executed',
returnDisplay: 'Tool executed',
});
// This tool would normally require confirmation.
mockTool.shouldConfirm = true;
const mockTool = new MockTool({
name: 'mockTool',
execute: executeFn,
shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE,
});
const declarativeTool = mockTool;
const mockToolRegistry = {
@@ -894,7 +903,7 @@ describe('CoreToolScheduler YOLO mode', () => {
// Assert
// 1. The tool's execute method was called directly.
expect(mockTool.executeFn).toHaveBeenCalledWith({ param: 'value' });
expect(executeFn).toHaveBeenCalledWith({ param: 'value' });
// 2. The tool call status never entered 'awaiting_approval'.
const statusUpdates = onToolCallsUpdate.mock.calls
@@ -927,8 +936,8 @@ describe('CoreToolScheduler request queueing', () => {
resolveFirstCall = resolve;
});
const mockTool = new MockTool();
mockTool.executeFn.mockImplementation(() => firstCallPromise);
const executeFn = vi.fn().mockImplementation(() => firstCallPromise);
const mockTool = new MockTool({ name: 'mockTool', execute: executeFn });
const declarativeTool = mockTool;
const mockToolRegistry = {
@@ -1011,8 +1020,7 @@ describe('CoreToolScheduler request queueing', () => {
);
// Ensure the second tool call hasn't been executed yet.
expect(mockTool.executeFn).toHaveBeenCalledTimes(1);
expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 });
expect(executeFn).toHaveBeenCalledWith({ a: 1 });
// Complete the first tool call.
resolveFirstCall!({
@@ -1034,9 +1042,9 @@ describe('CoreToolScheduler request queueing', () => {
await vi.waitFor(() => {
// Now the second tool call should have been executed.
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
expect(executeFn).toHaveBeenCalledTimes(2);
});
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
expect(executeFn).toHaveBeenCalledWith({ b: 2 });
// Wait for the second completion.
await vi.waitFor(() => {
@@ -1050,13 +1058,15 @@ describe('CoreToolScheduler request queueing', () => {
it('should auto-approve a tool call if it is on the allowedTools list', async () => {
// Arrange
const mockTool = new MockTool('mockTool');
mockTool.executeFn.mockReturnValue({
const executeFn = vi.fn().mockResolvedValue({
llmContent: 'Tool executed',
returnDisplay: 'Tool executed',
});
// This tool would normally require confirmation.
mockTool.shouldConfirm = true;
const mockTool = new MockTool({
name: 'mockTool',
execute: executeFn,
shouldConfirmExecute: MOCK_TOOL_SHOULD_CONFIRM_EXECUTE,
});
const declarativeTool = mockTool;
const toolRegistry = {
@@ -1132,7 +1142,7 @@ describe('CoreToolScheduler request queueing', () => {
// Assert
// 1. The tool's execute method was called directly.
expect(mockTool.executeFn).toHaveBeenCalledWith({ param: 'value' });
expect(executeFn).toHaveBeenCalledWith({ param: 'value' });
// 2. The tool call status never entered 'awaiting_approval'.
const statusUpdates = onToolCallsUpdate.mock.calls
@@ -1159,7 +1169,11 @@ describe('CoreToolScheduler request queueing', () => {
});
it('should handle two synchronous calls to schedule', async () => {
const mockTool = new MockTool();
const executeFn = vi.fn().mockResolvedValue({
llmContent: 'Tool executed',
returnDisplay: 'Tool executed',
});
const mockTool = new MockTool({ name: 'mockTool', execute: executeFn });
const declarativeTool = mockTool;
const mockToolRegistry = {
getTool: () => declarativeTool,
@@ -1241,9 +1255,9 @@ describe('CoreToolScheduler request queueing', () => {
await Promise.all([schedulePromise1, schedulePromise2]);
// Ensure the tool was called twice with the correct arguments.
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 });
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
expect(executeFn).toHaveBeenCalledTimes(2);
expect(executeFn).toHaveBeenCalledWith({ a: 1 });
expect(executeFn).toHaveBeenCalledWith({ b: 2 });
// Ensure completion callbacks were called twice.
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
@@ -5,6 +5,7 @@
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import type { Mock } from 'vitest';
import { executeToolCall } from './nonInteractiveToolExecutor.js';
import type {
ToolRegistry,
@@ -19,16 +20,18 @@ import {
ApprovalMode,
} from '../index.js';
import type { Part } from '@google/genai';
import { MockTool } from '../test-utils/tools.js';
import { MockTool } from '../test-utils/mock-tool.js';
describe('executeToolCall', () => {
let mockToolRegistry: ToolRegistry;
let mockTool: MockTool;
let executeFn: Mock;
let abortController: AbortController;
let mockConfig: Config;
beforeEach(() => {
mockTool = new MockTool();
executeFn = vi.fn();
mockTool = new MockTool({ name: 'testTool', execute: executeFn });
mockToolRegistry = {
getTool: vi.fn(),
@@ -77,7 +80,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Success!',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
mockTool.executeFn.mockReturnValue(toolResult);
executeFn.mockResolvedValue(toolResult);
const response = await executeToolCall(
mockConfig,
@@ -86,7 +89,7 @@ describe('executeToolCall', () => {
);
expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool');
expect(mockTool.executeFn).toHaveBeenCalledWith(request.args);
expect(executeFn).toHaveBeenCalledWith(request.args);
expect(response).toStrictEqual({
callId: 'call1',
error: undefined,
@@ -207,7 +210,7 @@ describe('executeToolCall', () => {
},
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
mockTool.executeFn.mockReturnValue(executionErrorResult);
executeFn.mockResolvedValue(executionErrorResult);
const response = await executeToolCall(
mockConfig,
@@ -243,9 +246,7 @@ describe('executeToolCall', () => {
prompt_id: 'prompt-id-5',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
mockTool.executeFn.mockImplementation(() => {
throw new Error('Something went very wrong');
});
executeFn.mockRejectedValue(new Error('Something went very wrong'));
const response = await executeToolCall(
mockConfig,
@@ -287,7 +288,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Image processed',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
mockTool.executeFn.mockReturnValue(toolResult);
executeFn.mockResolvedValue(toolResult);
const response = await executeToolCall(
mockConfig,
@@ -330,7 +331,7 @@ describe('executeToolCall', () => {
returnDisplay: 'String returned',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
mockTool.executeFn.mockReturnValue(toolResult);
executeFn.mockResolvedValue(toolResult);
const response = await executeToolCall(
mockConfig,
@@ -358,7 +359,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Image data returned',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
mockTool.executeFn.mockReturnValue(toolResult);
executeFn.mockResolvedValue(toolResult);
const response = await executeToolCall(
mockConfig,
@@ -17,7 +17,7 @@ import type {
ToolCallRequestInfo,
ToolCallResponseInfo,
} from '../core/turn.js';
import { MockTool } from '../test-utils/tools.js';
import { MockTool } from '../test-utils/mock-tool.js';
describe('Circular Reference Handling', () => {
it('should handle circular references in tool function arguments', () => {
@@ -59,7 +59,7 @@ describe('Circular Reference Handling', () => {
errorType: undefined,
};
const tool = new MockTool('mock-tool');
const tool = new MockTool({ name: 'mock-tool' });
const mockCompletedToolCall: CompletedToolCall = {
status: 'success',
request: mockRequest,
@@ -109,7 +109,7 @@ describe('Circular Reference Handling', () => {
errorType: undefined,
};
const tool = new MockTool('mock-tool');
const tool = new MockTool({ name: 'mock-tool' });
const mockCompletedToolCall: CompletedToolCall = {
status: 'success',
request: mockRequest,
@@ -21,7 +21,7 @@ import type {
} from '../core/coreToolScheduler.js';
import { ToolErrorType } from '../tools/tool-error.js';
import { ToolConfirmationOutcome } from '../tools/tools.js';
import { MockTool } from '../test-utils/tools.js';
import { MockTool } from '../test-utils/mock-tool.js';
const createFakeCompletedToolCall = (
name: string,
@@ -37,7 +37,7 @@ const createFakeCompletedToolCall = (
isClientInitiated: false,
prompt_id: 'prompt-id-1',
};
const tool = new MockTool(name);
const tool = new MockTool({ name });
if (success) {
return {
+106 -3
View File
@@ -4,6 +4,10 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type {
ModifiableDeclarativeTool,
ModifyContext,
} from '../tools/modifiable-tool.js';
import type {
ToolCallConfirmationDetails,
ToolInvocation,
@@ -27,7 +31,7 @@ interface MockToolOptions {
) => Promise<ToolCallConfirmationDetails | false>;
execute?: (
params: { [key: string]: unknown },
signal: AbortSignal,
signal?: AbortSignal,
updateOutput?: (output: string) => void,
) => Promise<ToolResult>;
params?: object;
@@ -48,7 +52,11 @@ class MockToolInvocation extends BaseToolInvocation<
signal: AbortSignal,
updateOutput?: (output: string) => void,
): Promise<ToolResult> {
return this.tool.execute(this.params, signal, updateOutput);
if (updateOutput) {
return this.tool.execute(this.params, signal, updateOutput);
} else {
return this.tool.execute(this.params);
}
}
override shouldConfirmExecute(
@@ -75,7 +83,7 @@ export class MockTool extends BaseDeclarativeTool<
) => Promise<ToolCallConfirmationDetails | false>;
execute: (
params: { [key: string]: unknown },
signal: AbortSignal,
signal?: AbortSignal,
updateOutput?: (output: string) => void,
) => Promise<ToolResult>;
@@ -113,3 +121,98 @@ export class MockTool extends BaseDeclarativeTool<
return new MockToolInvocation(this, params);
}
}
export const MOCK_TOOL_SHOULD_CONFIRM_EXECUTE = () =>
Promise.resolve({
type: 'exec' as const,
title: 'Confirm mockTool',
command: 'mockTool',
rootCommand: 'mockTool',
onConfirm: async () => {},
});
export class MockModifiableToolInvocation extends BaseToolInvocation<
Record<string, unknown>,
ToolResult
> {
constructor(
private readonly tool: MockModifiableTool,
params: Record<string, unknown>,
) {
super(params);
}
async execute(_abortSignal: AbortSignal): Promise<ToolResult> {
const result = this.tool.executeFn(this.params);
return (
result ?? {
llmContent: `Tool ${this.tool.name} executed successfully.`,
returnDisplay: `Tool ${this.tool.name} executed successfully.`,
}
);
}
override async shouldConfirmExecute(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.tool.shouldConfirm) {
return {
type: 'edit',
title: 'Confirm Mock Tool',
fileName: 'test.txt',
filePath: 'test.txt',
fileDiff: 'diff',
originalContent: 'originalContent',
newContent: 'newContent',
onConfirm: async () => {},
};
}
return false;
}
getDescription(): string {
return `A mock modifiable tool invocation for ${this.tool.name}`;
}
}
/**
* Configurable mock modifiable tool for testing.
*/
export class MockModifiableTool
extends BaseDeclarativeTool<Record<string, unknown>, ToolResult>
implements ModifiableDeclarativeTool<Record<string, unknown>>
{
// Should be overrided in test file. Functionality will be updated in follow
// up PR which has MockModifiableTool expect MockTool
executeFn: (params: Record<string, unknown>) => ToolResult | undefined = () =>
undefined;
shouldConfirm = true;
constructor(name = 'mockModifiableTool') {
super(name, name, 'A mock modifiable tool for testing.', Kind.Other, {
type: 'object',
properties: { param: { type: 'string' } },
});
}
getModifyContext(
_abortSignal: AbortSignal,
): ModifyContext<Record<string, unknown>> {
return {
getFilePath: () => 'test.txt',
getCurrentContent: async () => 'old content',
getProposedContent: async () => 'new content',
createUpdatedParams: (
_oldContent: string,
modifiedProposedContent: string,
_originalParams: Record<string, unknown>,
) => ({ newContent: modifiedProposedContent }),
};
}
protected createInvocation(
params: Record<string, unknown>,
): ToolInvocation<Record<string, unknown>, ToolResult> {
return new MockModifiableToolInvocation(this, params);
}
}
-169
View File
@@ -1,169 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi } from 'vitest';
import type {
ToolCallConfirmationDetails,
ToolInvocation,
ToolResult,
} from '../tools/tools.js';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
} from '../tools/tools.js';
import type {
ModifiableDeclarativeTool,
ModifyContext,
} from '../tools/modifiable-tool.js';
class MockToolInvocation extends BaseToolInvocation<
{ [key: string]: unknown },
ToolResult
> {
constructor(
private readonly tool: MockTool,
params: { [key: string]: unknown },
) {
super(params);
}
async execute(_abortSignal: AbortSignal): Promise<ToolResult> {
const result = this.tool.executeFn(this.params);
return (
result ?? {
llmContent: `Tool ${this.tool.name} executed successfully.`,
returnDisplay: `Tool ${this.tool.name} executed successfully.`,
}
);
}
override async shouldConfirmExecute(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.tool.shouldConfirm) {
return {
type: 'exec' as const,
title: `Confirm ${this.tool.displayName}`,
command: this.tool.name,
rootCommand: this.tool.name,
onConfirm: async () => {},
};
}
return false;
}
getDescription(): string {
return `A mock tool invocation for ${this.tool.name}`;
}
}
/**
* A highly configurable mock tool for testing purposes.
*/
export class MockTool extends BaseDeclarativeTool<
{ [key: string]: unknown },
ToolResult
> {
executeFn = vi.fn();
shouldConfirm = false;
constructor(
name = 'mock-tool',
displayName?: string,
description = 'A mock tool for testing.',
params = {
type: 'object',
properties: { param: { type: 'string' } },
},
) {
super(name, displayName ?? name, description, Kind.Other, params);
}
protected createInvocation(params: {
[key: string]: unknown;
}): ToolInvocation<{ [key: string]: unknown }, ToolResult> {
return new MockToolInvocation(this, params);
}
}
export class MockModifiableToolInvocation extends BaseToolInvocation<
Record<string, unknown>,
ToolResult
> {
constructor(
private readonly tool: MockModifiableTool,
params: Record<string, unknown>,
) {
super(params);
}
async execute(_abortSignal: AbortSignal): Promise<ToolResult> {
const result = this.tool.executeFn(this.params);
return (
result ?? {
llmContent: `Tool ${this.tool.name} executed successfully.`,
returnDisplay: `Tool ${this.tool.name} executed successfully.`,
}
);
}
override async shouldConfirmExecute(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.tool.shouldConfirm) {
return {
type: 'edit',
title: 'Confirm Mock Tool',
fileName: 'test.txt',
filePath: 'test.txt',
fileDiff: 'diff',
originalContent: 'originalContent',
newContent: 'newContent',
onConfirm: async () => {},
};
}
return false;
}
getDescription(): string {
return `A mock modifiable tool invocation for ${this.tool.name}`;
}
}
/**
* Configurable mock modifiable tool for testing.
*/
export class MockModifiableTool
extends MockTool
implements ModifiableDeclarativeTool<Record<string, unknown>>
{
constructor(name = 'mockModifiableTool') {
super(name);
this.shouldConfirm = true;
}
getModifyContext(
_abortSignal: AbortSignal,
): ModifyContext<Record<string, unknown>> {
return {
getFilePath: () => 'test.txt',
getCurrentContent: async () => 'old content',
getProposedContent: async () => 'new content',
createUpdatedParams: (
_oldContent: string,
modifiedProposedContent: string,
_originalParams: Record<string, unknown>,
) => ({ newContent: modifiedProposedContent }),
};
}
protected override createInvocation(
params: Record<string, unknown>,
): ToolInvocation<Record<string, unknown>, ToolResult> {
return new MockModifiableToolInvocation(this, params);
}
}
+10 -10
View File
@@ -16,7 +16,7 @@ import { mcpToTool } from '@google/genai';
import { spawn } from 'node:child_process';
import fs from 'node:fs';
import { MockTool } from '../test-utils/tools.js';
import { MockTool } from '../test-utils/mock-tool.js';
import { McpClientManager } from './mcp-client-manager.js';
import { ToolErrorType } from './tool-error.js';
@@ -149,7 +149,7 @@ describe('ToolRegistry', () => {
describe('registerTool', () => {
it('should register a new tool', () => {
const tool = new MockTool();
const tool = new MockTool({ name: 'mock-tool' });
toolRegistry.registerTool(tool);
expect(toolRegistry.getTool('mock-tool')).toBe(tool);
});
@@ -158,9 +158,9 @@ describe('ToolRegistry', () => {
describe('getAllTools', () => {
it('should return all registered tools sorted alphabetically by displayName', () => {
// Register tools with displayNames in non-alphabetical order
const toolC = new MockTool('c-tool', 'Tool C');
const toolA = new MockTool('a-tool', 'Tool A');
const toolB = new MockTool('b-tool', 'Tool B');
const toolC = new MockTool({ name: 'c-tool', displayName: 'Tool C' });
const toolA = new MockTool({ name: 'a-tool', displayName: 'Tool A' });
const toolB = new MockTool({ name: 'b-tool', displayName: 'Tool B' });
toolRegistry.registerTool(toolC);
toolRegistry.registerTool(toolA);
@@ -177,9 +177,9 @@ describe('ToolRegistry', () => {
describe('getAllToolNames', () => {
it('should return all registered tool names', () => {
// Register tools with displayNames in non-alphabetical order
const toolC = new MockTool('c-tool', 'Tool C');
const toolA = new MockTool('a-tool', 'Tool A');
const toolB = new MockTool('b-tool', 'Tool B');
const toolC = new MockTool({ name: 'c-tool', displayName: 'Tool C' });
const toolA = new MockTool({ name: 'a-tool', displayName: 'Tool A' });
const toolB = new MockTool({ name: 'b-tool', displayName: 'Tool B' });
toolRegistry.registerTool(toolC);
toolRegistry.registerTool(toolA);
@@ -194,7 +194,7 @@ describe('ToolRegistry', () => {
describe('getToolsByServer', () => {
it('should return an empty array if no tools match the server name', () => {
toolRegistry.registerTool(new MockTool());
toolRegistry.registerTool(new MockTool({ name: 'mock-tool' }));
expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
});
@@ -231,7 +231,7 @@ describe('ToolRegistry', () => {
'd4',
{},
);
const nonMcpTool = new MockTool('regular-tool');
const nonMcpTool = new MockTool({ name: 'regular-tool' });
toolRegistry.registerTool(mcpTool1_c);
toolRegistry.registerTool(mcpTool1_a);