diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index ac8d9f1bd6..b7e85962a5 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -291,6 +291,7 @@ describe('Gemini Client (client.ts)', () => { it('should call chat.addHistory with the provided content', async () => { const mockChat = { addHistory: vi.fn(), + setTools: vi.fn(), } as unknown as GeminiChat; client['chat'] = mockChat; @@ -389,6 +390,7 @@ describe('Gemini Client (client.ts)', () => { getHistory: mockGetHistory, addHistory: vi.fn(), setHistory: vi.fn(), + setTools: vi.fn(), getLastPromptTokenCount: vi.fn(), } as unknown as GeminiChat; }); @@ -805,6 +807,7 @@ describe('Gemini Client (client.ts)', () => { const mockChat = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), } as unknown as GeminiChat; @@ -868,6 +871,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -926,6 +930,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1003,6 +1008,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1119,6 +1125,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1167,6 +1174,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1232,6 +1240,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1289,6 +1298,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1349,6 +1359,7 @@ ${JSON.stringify( const lastPromptTokenCount = 900; const mockChat: Partial = { getLastPromptTokenCount: vi.fn().mockReturnValue(lastPromptTokenCount), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), }; client['chat'] = mockChat as GeminiChat; @@ -1409,6 +1420,7 @@ ${JSON.stringify( const lastPromptTokenCount = 900; const mockChat: Partial = { getLastPromptTokenCount: vi.fn().mockReturnValue(lastPromptTokenCount), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), }; client['chat'] = mockChat as GeminiChat; @@ -1467,6 +1479,7 @@ ${JSON.stringify( .fn() .mockReturnValue([{ role: 'user', parts: [{ text: 'old' }] }]), addHistory: vi.fn(), + setTools: vi.fn(), getChatRecordingService: vi.fn().mockReturnValue({ getConversation: vi.fn(), getConversationFilePath: vi.fn(), @@ -1479,6 +1492,7 @@ ${JSON.stringify( .fn() .mockReturnValue([{ role: 'user', parts: [{ text: 'old' }] }]), addHistory: vi.fn(), + setTools: vi.fn(), getChatRecordingService: vi.fn().mockReturnValue({ getConversation: vi.fn(), getConversationFilePath: vi.fn(), @@ -1616,6 +1630,7 @@ ${JSON.stringify( const lastPromptTokenCount = 10000; const mockChat: Partial = { getLastPromptTokenCount: vi.fn().mockReturnValue(lastPromptTokenCount), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), }; client['chat'] = mockChat as GeminiChat; @@ -1689,6 +1704,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1892,6 +1908,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1947,6 +1964,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -1984,6 +2002,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -2028,6 +2047,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), setHistory: vi.fn(), + setTools: vi.fn(), // Assume history is not empty for delta checks getHistory: vi .fn() @@ -2443,6 +2463,7 @@ ${JSON.stringify( addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), // Default empty history setHistory: vi.fn(), + setTools: vi.fn(), getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -2783,6 +2804,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -2820,6 +2842,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -2857,6 +2880,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -3069,6 +3093,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; @@ -3103,6 +3128,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), + setTools: vi.fn(), getHistory: vi.fn().mockReturnValue([]), getLastPromptTokenCount: vi.fn(), }; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 91434d12b3..4781dd7618 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -256,9 +256,20 @@ export class GeminiClient { this.forceFullIdeContext = true; } - async setTools(): Promise { + private lastUsedModelId?: string; + + async setTools(modelId?: string): Promise { + if (!this.chat) { + return; + } + + if (modelId && modelId === this.lastUsedModelId) { + return; + } + this.lastUsedModelId = modelId; + const toolRegistry = this.config.getToolRegistry(); - const toolDeclarations = toolRegistry.getFunctionDeclarations(); + const toolDeclarations = toolRegistry.getFunctionDeclarations(modelId); const tools: Tool[] = [{ functionDeclarations: toolDeclarations }]; this.getChat().setTools(tools); } @@ -321,6 +332,7 @@ export class GeminiClient { ): Promise { this.forceFullIdeContext = true; this.hasFailedCompressionAttempt = false; + this.lastUsedModelId = undefined; const toolRegistry = this.config.getToolRegistry(); const toolDeclarations = toolRegistry.getFunctionDeclarations(); @@ -339,6 +351,13 @@ export class GeminiClient { tools, history, resumedSessionData, + async (modelId: string) => { + this.lastUsedModelId = modelId; + const toolRegistry = this.config.getToolRegistry(); + const toolDeclarations = + toolRegistry.getFunctionDeclarations(modelId); + return [{ functionDeclarations: toolDeclarations }]; + }, ); } catch (error) { await reportError( @@ -653,6 +672,10 @@ export class GeminiClient { yield { type: GeminiEventType.ModelInfo, value: modelToUse }; } this.currentSequenceModel = modelToUse; + + // Update tools with the final modelId to ensure model-dependent descriptions are used. + await this.setTools(modelToUse); + const resultStream = turn.run( modelConfigKey, request, diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index df98e3ebd7..8f2c4b9267 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -247,6 +247,7 @@ export class GeminiChat { private tools: Tool[] = [], private history: Content[] = [], resumedSessionData?: ResumedSessionData, + private readonly onModelChanged?: (modelId: string) => Promise, ) { validateHistory(history); this.chatRecordingService = new ChatRecordingService(config); @@ -580,6 +581,10 @@ export class GeminiChat { } } + if (this.onModelChanged) { + this.tools = await this.onModelChanged(modelToUse); + } + // Track final request parameters for AfterModel hooks lastModelToUse = modelToUse; lastConfig = config; diff --git a/packages/core/src/tools/__snapshots__/read-file.test.ts.snap b/packages/core/src/tools/__snapshots__/read-file.test.ts.snap new file mode 100644 index 0000000000..c6adf2819d --- /dev/null +++ b/packages/core/src/tools/__snapshots__/read-file.test.ts.snap @@ -0,0 +1,5 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`ReadFileTool > getSchema > should return the base schema when no modelId is provided 1`] = `"Reads and returns the content of a specified file. If the file is large, the content will be truncated. The tool's response will clearly indicate if truncation has occurred and will provide details on how to read more of the file using the 'offset' and 'limit' parameters. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files. For text files, it can read specific line ranges."`; + +exports[`ReadFileTool > getSchema > should return the schema from the resolver when modelId is provided 1`] = `"Reads and returns the content of a specified file. If the file is large, the content will be truncated. The tool's response will clearly indicate if truncation has occurred and will provide details on how to read more of the file using the 'offset' and 'limit' parameters. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files. For text files, it can read specific line ranges."`; diff --git a/packages/core/src/tools/__snapshots__/shell.test.ts.snap b/packages/core/src/tools/__snapshots__/shell.test.ts.snap index 73245052a7..471ce45f6e 100644 --- a/packages/core/src/tools/__snapshots__/shell.test.ts.snap +++ b/packages/core/src/tools/__snapshots__/shell.test.ts.snap @@ -33,3 +33,37 @@ exports[`ShellTool > getDescription > should return the windows description when Background PIDs: Only included if background processes were started. Process Group PGID: Only included if available." `; + +exports[`ShellTool > getSchema > should return the base schema when no modelId is provided 1`] = ` +"This tool executes a given shell command as \`bash -c \`. Command can start background processes using \`&\`. Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`. + + Efficiency Guidelines: + - Quiet Flags: Always prefer silent or quiet flags (e.g., \`npm install --silent\`, \`git --no-pager\`) to reduce output volume while still capturing necessary information. + - Pagination: Always disable terminal pagination to ensure commands terminate (e.g., use \`git --no-pager\`, \`systemctl --no-pager\`, or set \`PAGER=cat\`). + + The following information is returned: + + Output: Combined stdout/stderr. Can be \`(empty)\` or partial on error and for any unwaited background processes. + Exit Code: Only included if non-zero (command failed). + Error: Only included if a process-level error occurred (e.g., spawn failure). + Signal: Only included if process was terminated by a signal. + Background PIDs: Only included if background processes were started. + Process Group PGID: Only included if available." +`; + +exports[`ShellTool > getSchema > should return the schema from the resolver when modelId is provided 1`] = ` +"This tool executes a given shell command as \`bash -c \`. Command can start background processes using \`&\`. Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`. + + Efficiency Guidelines: + - Quiet Flags: Always prefer silent or quiet flags (e.g., \`npm install --silent\`, \`git --no-pager\`) to reduce output volume while still capturing necessary information. + - Pagination: Always disable terminal pagination to ensure commands terminate (e.g., use \`git --no-pager\`, \`systemctl --no-pager\`, or set \`PAGER=cat\`). + + The following information is returned: + + Output: Combined stdout/stderr. Can be \`(empty)\` or partial on error and for any unwaited background processes. + Exit Code: Only included if non-zero (command failed). + Error: Only included if a process-level error occurred (e.g., spawn failure). + Signal: Only included if process was terminated by a signal. + Background PIDs: Only included if background processes were started. + Process Group PGID: Only included if available." +`; diff --git a/packages/core/src/tools/definitions/coreTools.ts b/packages/core/src/tools/definitions/coreTools.ts new file mode 100644 index 0000000000..cfc33b7b6a --- /dev/null +++ b/packages/core/src/tools/definitions/coreTools.ts @@ -0,0 +1,291 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Type } from '@google/genai'; +import type { ToolDefinition } from './types.js'; +import * as os from 'node:os'; + +// Centralized tool names to avoid circular dependencies +export const GLOB_TOOL_NAME = 'glob'; +export const GREP_TOOL_NAME = 'grep_search'; +export const LS_TOOL_NAME = 'list_directory'; +export const READ_FILE_TOOL_NAME = 'read_file'; +export const SHELL_TOOL_NAME = 'run_shell_command'; +export const WRITE_FILE_TOOL_NAME = 'write_file'; + +// ============================================================================ +// READ_FILE TOOL +// ============================================================================ + +export const READ_FILE_DEFINITION: ToolDefinition = { + base: { + name: READ_FILE_TOOL_NAME, + description: `Reads and returns the content of a specified file. If the file is large, the content will be truncated. The tool's response will clearly indicate if truncation has occurred and will provide details on how to read more of the file using the 'offset' and 'limit' parameters. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files. For text files, it can read specific line ranges.`, + parametersJsonSchema: { + type: Type.OBJECT, + properties: { + file_path: { + description: 'The path to the file to read.', + type: Type.STRING, + }, + offset: { + description: + "Optional: For text files, the 0-based line number to start reading from. Requires 'limit' to be set. Use for paginating through large files.", + type: Type.NUMBER, + }, + limit: { + description: + "Optional: For text files, maximum number of lines to read. Use with 'offset' to paginate through large files. If omitted, reads the entire file (if feasible, up to a default limit).", + type: Type.NUMBER, + }, + }, + required: ['file_path'], + }, + }, +}; + +// ============================================================================ +// WRITE_FILE TOOL +// ============================================================================ + +export const WRITE_FILE_DEFINITION: ToolDefinition = { + base: { + name: WRITE_FILE_TOOL_NAME, + description: `Writes content to a specified file in the local filesystem. + + The user has the ability to modify \`content\`. If modified, this will be stated in the response.`, + parametersJsonSchema: { + type: Type.OBJECT, + properties: { + file_path: { + description: 'The path to the file to write to.', + type: Type.STRING, + }, + content: { + description: 'The content to write to the file.', + type: Type.STRING, + }, + }, + required: ['file_path', 'content'], + }, + }, +}; + +// ============================================================================ +// GREP TOOL +// ============================================================================ + +export const GREP_DEFINITION: ToolDefinition = { + base: { + name: GREP_TOOL_NAME, + description: + 'Searches for a regular expression pattern within file contents. Max 100 matches.', + parametersJsonSchema: { + type: Type.OBJECT, + properties: { + pattern: { + description: `The regular expression (regex) pattern to search for within file contents (e.g., 'function\\s+myFunction', 'import\\s+\\{.*\\}\\s+from\\s+.*').`, + type: Type.STRING, + }, + dir_path: { + description: + 'Optional: The absolute path to the directory to search within. If omitted, searches the current working directory.', + type: Type.STRING, + }, + include: { + description: `Optional: A glob pattern to filter which files are searched (e.g., '*.js', '*.{ts,tsx}', 'src/**'). If omitted, searches all files (respecting potential global ignores).`, + type: Type.STRING, + }, + }, + required: ['pattern'], + }, + }, +}; + +// ============================================================================ +// GLOB TOOL +// ============================================================================ + +export const GLOB_DEFINITION: ToolDefinition = { + base: { + name: GLOB_TOOL_NAME, + description: + 'Efficiently finds files matching specific glob patterns (e.g., `src/**/*.ts`, `**/*.md`), returning absolute paths sorted by modification time (newest first). Ideal for quickly locating files based on their name or path structure, especially in large codebases.', + parametersJsonSchema: { + type: Type.OBJECT, + properties: { + pattern: { + description: + "The glob pattern to match against (e.g., '**/*.py', 'docs/*.md').", + type: Type.STRING, + }, + dir_path: { + description: + 'Optional: The absolute path to the directory to search within. If omitted, searches the root directory.', + type: Type.STRING, + }, + case_sensitive: { + description: + 'Optional: Whether the search should be case-sensitive. Defaults to false.', + type: Type.BOOLEAN, + }, + respect_git_ignore: { + description: + 'Optional: Whether to respect .gitignore patterns when finding files. Only available in git repositories. Defaults to true.', + type: Type.BOOLEAN, + }, + respect_gemini_ignore: { + description: + 'Optional: Whether to respect .geminiignore patterns when finding files. Defaults to true.', + type: Type.BOOLEAN, + }, + }, + required: ['pattern'], + }, + }, +}; + +// ============================================================================ +// LS TOOL +// ============================================================================ + +export const LS_DEFINITION: ToolDefinition = { + base: { + name: LS_TOOL_NAME, + description: + 'Lists the names of files and subdirectories directly within a specified directory path. Can optionally ignore entries matching provided glob patterns.', + parametersJsonSchema: { + type: Type.OBJECT, + properties: { + dir_path: { + description: 'The path to the directory to list', + type: Type.STRING, + }, + ignore: { + description: 'List of glob patterns to ignore', + items: { + type: Type.STRING, + }, + type: Type.ARRAY, + }, + file_filtering_options: { + description: + 'Optional: Whether to respect ignore patterns from .gitignore or .geminiignore', + type: Type.OBJECT, + properties: { + respect_git_ignore: { + description: + 'Optional: Whether to respect .gitignore patterns when listing files. Only available in git repositories. Defaults to true.', + type: Type.BOOLEAN, + }, + respect_gemini_ignore: { + description: + 'Optional: Whether to respect .geminiignore patterns when listing files. Defaults to true.', + type: Type.BOOLEAN, + }, + }, + }, + }, + required: ['dir_path'], + }, + }, +}; + +// ============================================================================ +// SHELL TOOL +// ============================================================================ + +/** + * Generates the platform-specific description for the shell tool. + */ +export function getShellToolDescription( + enableInteractiveShell: boolean, + enableEfficiency: boolean, +): string { + const efficiencyGuidelines = enableEfficiency + ? ` + + Efficiency Guidelines: + - Quiet Flags: Always prefer silent or quiet flags (e.g., \`npm install --silent\`, \`git --no-pager\`) to reduce output volume while still capturing necessary information. + - Pagination: Always disable terminal pagination to ensure commands terminate (e.g., use \`git --no-pager\`, \`systemctl --no-pager\`, or set \`PAGER=cat\`).` + : ''; + + const returnedInfo = ` + + The following information is returned: + + Output: Combined stdout/stderr. Can be \`(empty)\` or partial on error and for any unwaited background processes. + Exit Code: Only included if non-zero (command failed). + Error: Only included if a process-level error occurred (e.g., spawn failure). + Signal: Only included if process was terminated by a signal. + Background PIDs: Only included if background processes were started. + Process Group PGID: Only included if available.`; + + if (os.platform() === 'win32') { + const backgroundInstructions = enableInteractiveShell + ? 'To run a command in the background, set the `is_background` parameter to true. Do NOT use PowerShell background constructs.' + : 'Command can start background processes using PowerShell constructs such as `Start-Process -NoNewWindow` or `Start-Job`.'; + return `This tool executes a given shell command as \`powershell.exe -NoProfile -Command \`. ${backgroundInstructions}${efficiencyGuidelines}${returnedInfo}`; + } else { + const backgroundInstructions = enableInteractiveShell + ? 'To run a command in the background, set the `is_background` parameter to true. Do NOT use `&` to background commands.' + : 'Command can start background processes using `&`.'; + return `This tool executes a given shell command as \`bash -c \`. ${backgroundInstructions} Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.${efficiencyGuidelines}${returnedInfo}`; + } +} + +/** + * Returns the platform-specific description for the 'command' parameter. + */ +export function getCommandDescription(): string { + if (os.platform() === 'win32') { + return 'Exact command to execute as `powershell.exe -NoProfile -Command `'; + } + return 'Exact bash command to execute as `bash -c `'; +} + +/** + * Returns the tool definition for the shell tool, customized for the platform. + */ +export function getShellDefinition( + enableInteractiveShell: boolean, + enableEfficiency: boolean, +): ToolDefinition { + return { + base: { + name: SHELL_TOOL_NAME, + description: getShellToolDescription( + enableInteractiveShell, + enableEfficiency, + ), + parametersJsonSchema: { + type: Type.OBJECT, + properties: { + command: { + type: Type.STRING, + description: getCommandDescription(), + }, + description: { + type: Type.STRING, + description: + 'Brief description of the command for the user. Be specific and concise. Ideally a single sentence. Can be up to 3 sentences for clarity. No line breaks.', + }, + dir_path: { + type: Type.STRING, + description: + '(OPTIONAL) The path of the directory to run the command in. If not provided, the project root directory is used. Must be a directory within the workspace and must already exist.', + }, + is_background: { + type: Type.BOOLEAN, + description: + 'Set to true if this command should be run in the background (e.g. for long-running servers or watchers). The command will be started, allowed to run for a brief moment to check for immediate errors, and then moved to the background.', + }, + }, + required: ['command'], + }, + }, + }; +} diff --git a/packages/core/src/tools/definitions/resolver.test.ts b/packages/core/src/tools/definitions/resolver.test.ts new file mode 100644 index 0000000000..a765608ac7 --- /dev/null +++ b/packages/core/src/tools/definitions/resolver.test.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { Type } from '@google/genai'; +import { resolveToolDeclaration } from './resolver.js'; +import type { ToolDefinition } from './types.js'; + +describe('resolveToolDeclaration', () => { + const mockDefinition: ToolDefinition = { + base: { + name: 'test_tool', + description: 'A test tool description', + parameters: { + type: Type.OBJECT, + properties: { + param1: { type: Type.STRING }, + }, + }, + }, + }; + + it('should return the base definition when no modelId is provided', () => { + const result = resolveToolDeclaration(mockDefinition); + expect(result).toEqual(mockDefinition.base); + }); + + it('should return the base definition when a modelId is provided (current implementation)', () => { + const result = resolveToolDeclaration(mockDefinition, 'gemini-1.5-pro'); + expect(result).toEqual(mockDefinition.base); + }); + + it('should return the same object reference as base (current implementation)', () => { + const result = resolveToolDeclaration(mockDefinition); + expect(result).toBe(mockDefinition.base); + }); +}); diff --git a/packages/core/src/tools/definitions/resolver.ts b/packages/core/src/tools/definitions/resolver.ts new file mode 100644 index 0000000000..8176e48104 --- /dev/null +++ b/packages/core/src/tools/definitions/resolver.ts @@ -0,0 +1,22 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type FunctionDeclaration } from '@google/genai'; +import type { ToolDefinition } from './types.js'; + +/** + * Resolves the declaration for a tool. + * + * @param definition The tool definition containing the base declaration. + * @param _modelId Optional model identifier (ignored in this plain refactor). + * @returns The FunctionDeclaration to be sent to the API. + */ +export function resolveToolDeclaration( + definition: ToolDefinition, + _modelId?: string, +): FunctionDeclaration { + return definition.base; +} diff --git a/packages/core/src/tools/definitions/types.ts b/packages/core/src/tools/definitions/types.ts new file mode 100644 index 0000000000..dc928e0a66 --- /dev/null +++ b/packages/core/src/tools/definitions/types.ts @@ -0,0 +1,15 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type FunctionDeclaration } from '@google/genai'; + +/** + * Defines a tool's identity using a structured declaration. + */ +export interface ToolDefinition { + /** The base declaration for the tool. */ + base: FunctionDeclaration; +} diff --git a/packages/core/src/tools/read-file.test.ts b/packages/core/src/tools/read-file.test.ts index 15071f2620..494b007dec 100644 --- a/packages/core/src/tools/read-file.test.ts +++ b/packages/core/src/tools/read-file.test.ts @@ -563,4 +563,19 @@ describe('ReadFileTool', () => { }); }); }); + + describe('getSchema', () => { + it('should return the base schema when no modelId is provided', () => { + const schema = tool.getSchema(); + expect(schema.name).toBe(ReadFileTool.Name); + expect(schema.description).toMatchSnapshot(); + }); + + it('should return the schema from the resolver when modelId is provided', () => { + const modelId = 'gemini-2.0-flash'; + const schema = tool.getSchema(modelId); + expect(schema.name).toBe(ReadFileTool.Name); + expect(schema.description).toMatchSnapshot(); + }); + }); }); diff --git a/packages/core/src/tools/read-file.ts b/packages/core/src/tools/read-file.ts index b71f5c8e29..8aa823ecda 100644 --- a/packages/core/src/tools/read-file.ts +++ b/packages/core/src/tools/read-file.ts @@ -23,6 +23,8 @@ import { logFileOperation } from '../telemetry/loggers.js'; import { FileOperationEvent } from '../telemetry/types.js'; import { READ_FILE_TOOL_NAME } from './tool-names.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; +import { READ_FILE_DEFINITION } from './definitions/coreTools.js'; +import { resolveToolDeclaration } from './definitions/resolver.js'; /** * Parameters for the ReadFile tool @@ -172,28 +174,9 @@ export class ReadFileTool extends BaseDeclarativeTool< super( ReadFileTool.Name, 'ReadFile', - `Reads and returns the content of a specified file. If the file is large, the content will be truncated. The tool's response will clearly indicate if truncation has occurred and will provide details on how to read more of the file using the 'offset' and 'limit' parameters. Handles text, images (PNG, JPG, GIF, WEBP, SVG, BMP), audio files (MP3, WAV, AIFF, AAC, OGG, FLAC), and PDF files. For text files, it can read specific line ranges.`, + READ_FILE_DEFINITION.base.description!, Kind.Read, - { - properties: { - file_path: { - description: 'The path to the file to read.', - type: 'string', - }, - offset: { - description: - "Optional: For text files, the 0-based line number to start reading from. Requires 'limit' to be set. Use for paginating through large files.", - type: 'number', - }, - limit: { - description: - "Optional: For text files, maximum number of lines to read. Use with 'offset' to paginate through large files. If omitted, reads the entire file (if feasible, up to a default limit).", - type: 'number', - }, - }, - required: ['file_path'], - type: 'object', - }, + READ_FILE_DEFINITION.base.parameters!, messageBus, true, false, @@ -258,4 +241,8 @@ export class ReadFileTool extends BaseDeclarativeTool< _toolDisplayName, ); } + + override getSchema(modelId?: string) { + return resolveToolDeclaration(READ_FILE_DEFINITION, modelId); + } } diff --git a/packages/core/src/tools/shell.test.ts b/packages/core/src/tools/shell.test.ts index e1b16f0a4a..5fc3ca7f25 100644 --- a/packages/core/src/tools/shell.test.ts +++ b/packages/core/src/tools/shell.test.ts @@ -825,4 +825,19 @@ describe('ShellTool', () => { } }); }); + + describe('getSchema', () => { + it('should return the base schema when no modelId is provided', () => { + const schema = shellTool.getSchema(); + expect(schema.name).toBe(SHELL_TOOL_NAME); + expect(schema.description).toMatchSnapshot(); + }); + + it('should return the schema from the resolver when modelId is provided', () => { + const modelId = 'gemini-2.0-flash'; + const schema = shellTool.getSchema(modelId); + expect(schema.name).toBe(SHELL_TOOL_NAME); + expect(schema.description).toMatchSnapshot(); + }); + }); }); diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts index 1c7192e254..ff20b8a7b2 100644 --- a/packages/core/src/tools/shell.ts +++ b/packages/core/src/tools/shell.ts @@ -43,6 +43,8 @@ import { } from '../utils/shell-utils.js'; import { SHELL_TOOL_NAME } from './tool-names.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { getShellDefinition } from './definitions/coreTools.js'; +import { resolveToolDeclaration } from './definitions/resolver.js'; export const OUTPUT_UPDATE_INTERVAL_MS = 1000; @@ -451,50 +453,6 @@ export class ShellToolInvocation extends BaseToolInvocation< } } -function getShellToolDescription( - enableInteractiveShell: boolean, - enableEfficiency: boolean, -): string { - const efficiencyGuidelines = enableEfficiency - ? ` - - Efficiency Guidelines: - - Quiet Flags: Always prefer silent or quiet flags (e.g., \`npm install --silent\`, \`git --no-pager\`) to reduce output volume while still capturing necessary information. - - Pagination: Always disable terminal pagination to ensure commands terminate (e.g., use \`git --no-pager\`, \`systemctl --no-pager\`, or set \`PAGER=cat\`).` - : ''; - - const returnedInfo = ` - - The following information is returned: - - Output: Combined stdout/stderr. Can be \`(empty)\` or partial on error and for any unwaited background processes. - Exit Code: Only included if non-zero (command failed). - Error: Only included if a process-level error occurred (e.g., spawn failure). - Signal: Only included if process was terminated by a signal. - Background PIDs: Only included if background processes were started. - Process Group PGID: Only included if available.`; - - if (os.platform() === 'win32') { - const backgroundInstructions = enableInteractiveShell - ? 'To run a command in the background, set the `is_background` parameter to true. Do NOT use PowerShell background constructs.' - : 'Command can start background processes using PowerShell constructs such as `Start-Process -NoNewWindow` or `Start-Job`.'; - return `This tool executes a given shell command as \`powershell.exe -NoProfile -Command \`. ${backgroundInstructions}${efficiencyGuidelines}${returnedInfo}`; - } else { - const backgroundInstructions = enableInteractiveShell - ? 'To run a command in the background, set the `is_background` parameter to true. Do NOT use `&` to background commands.' - : 'Command can start background processes using `&`.'; - return `This tool executes a given shell command as \`bash -c \`. ${backgroundInstructions} Command is executed as a subprocess that leads its own process group. Command process group can be terminated as \`kill -- -PGID\` or signaled as \`kill -s SIGNAL -- -PGID\`.${efficiencyGuidelines}${returnedInfo}`; - } -} - -function getCommandDescription(): string { - if (os.platform() === 'win32') { - return 'Exact command to execute as `powershell.exe -NoProfile -Command `'; - } else { - return 'Exact bash command to execute as `bash -c `'; - } -} - export class ShellTool extends BaseDeclarativeTool< ShellToolParams, ToolResult @@ -508,39 +466,16 @@ export class ShellTool extends BaseDeclarativeTool< void initializeShellParsers().catch(() => { // Errors are surfaced when parsing commands. }); + const definition = getShellDefinition( + config.getEnableInteractiveShell(), + config.getEnableShellOutputEfficiency(), + ); super( ShellTool.Name, 'Shell', - getShellToolDescription( - config.getEnableInteractiveShell(), - config.getEnableShellOutputEfficiency(), - ), + definition.base.description!, Kind.Execute, - { - type: 'object', - properties: { - command: { - type: 'string', - description: getCommandDescription(), - }, - description: { - type: 'string', - description: - 'Brief description of the command for the user. Be specific and concise. Ideally a single sentence. Can be up to 3 sentences for clarity. No line breaks.', - }, - dir_path: { - type: 'string', - description: - '(OPTIONAL) The path of the directory to run the command in. If not provided, the project root directory is used. Must be a directory within the workspace and must already exist.', - }, - is_background: { - type: 'boolean', - description: - 'Set to true if this command should be run in the background (e.g. for long-running servers or watchers). The command will be started, allowed to run for a brief moment to check for immediate errors, and then moved to the background.', - }, - }, - required: ['command'], - }, + definition.base.parametersJsonSchema, messageBus, false, // output is not markdown true, // output can be updated @@ -578,4 +513,12 @@ export class ShellTool extends BaseDeclarativeTool< _toolDisplayName, ); } + + override getSchema(modelId?: string) { + const definition = getShellDefinition( + this.config.getEnableInteractiveShell(), + this.config.getEnableShellOutputEfficiency(), + ); + return resolveToolDeclaration(definition, modelId); + } } diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index c26349f50f..963830200d 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -261,6 +261,17 @@ describe('ToolRegistry', () => { toolRegistry.registerTool(tool); expect(toolRegistry.getTool('mock-tool')).toBe(tool); }); + + it('should pass modelId to getSchema when getting function declarations', () => { + const tool = new MockTool({ name: 'mock-tool' }); + const getSchemaSpy = vi.spyOn(tool, 'getSchema'); + toolRegistry.registerTool(tool); + + const modelId = 'test-model-id'; + toolRegistry.getFunctionDeclarations(modelId); + + expect(getSchemaSpy).toHaveBeenCalledWith(modelId); + }); }); describe('excluded tools', () => { diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index ae4278986b..94082dcb57 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -498,12 +498,13 @@ export class ToolRegistry { * Retrieves the list of tool schemas (FunctionDeclaration array). * Extracts the declarations from the ToolListUnion structure. * Includes discovered (vs registered) tools if configured. + * @param modelId Optional model identifier to get model-specific schemas. * @returns An array of FunctionDeclarations. */ - getFunctionDeclarations(): FunctionDeclaration[] { + getFunctionDeclarations(modelId?: string): FunctionDeclaration[] { const declarations: FunctionDeclaration[] = []; this.getActiveTools().forEach((tool) => { - declarations.push(tool.schema); + declarations.push(tool.getSchema(modelId)); }); return declarations; } @@ -511,14 +512,18 @@ export class ToolRegistry { /** * Retrieves a filtered list of tool schemas based on a list of tool names. * @param toolNames - An array of tool names to include. + * @param modelId Optional model identifier to get model-specific schemas. * @returns An array of FunctionDeclarations for the specified tools. */ - getFunctionDeclarationsFiltered(toolNames: string[]): FunctionDeclaration[] { + getFunctionDeclarationsFiltered( + toolNames: string[], + modelId?: string, + ): FunctionDeclaration[] { const declarations: FunctionDeclaration[] = []; for (const name of toolNames) { const tool = this.getTool(name); if (tool) { - declarations.push(tool.schema); + declarations.push(tool.getSchema(modelId)); } } return declarations; diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 65aeb0884f..2811653b20 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -312,8 +312,15 @@ export interface ToolBuilder< /** * Function declaration schema from @google/genai. + * @param modelId Optional model identifier to get a model-specific schema. */ - schema: FunctionDeclaration; + getSchema(modelId?: string): FunctionDeclaration; + + /** + * Function declaration schema for the default model. + * @deprecated Use getSchema(modelId) for model-specific schemas. + */ + readonly schema: FunctionDeclaration; /** * Whether the tool's output should be rendered as markdown. @@ -355,7 +362,7 @@ export abstract class DeclarativeTool< readonly extensionId?: string, ) {} - get schema(): FunctionDeclaration { + getSchema(_modelId?: string): FunctionDeclaration { return { name: this.name, description: this.description, @@ -363,6 +370,10 @@ export abstract class DeclarativeTool< }; } + get schema(): FunctionDeclaration { + return this.getSchema(); + } + /** * Validates the raw tool parameters. * Subclasses should override this to add custom validation logic