mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-24 03:54:43 -07:00
refactor(core): extract static concerns from CoreToolScheduler (#15589)
This commit is contained in:
@@ -16,7 +16,7 @@ import {
|
||||
} from './checkpointUtils.js';
|
||||
import type { GitService } from '../services/gitService.js';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
import type { ToolCallRequestInfo } from '../core/turn.js';
|
||||
import type { ToolCallRequestInfo } from '../scheduler/types.js';
|
||||
|
||||
describe('checkpoint utils', () => {
|
||||
describe('getToolCallDataSchema', () => {
|
||||
|
||||
@@ -10,7 +10,7 @@ import type { GeminiClient } from '../core/client.js';
|
||||
import { getErrorMessage } from './errors.js';
|
||||
import { z } from 'zod';
|
||||
import type { Content } from '@google/genai';
|
||||
import type { ToolCallRequestInfo } from '../core/turn.js';
|
||||
import type { ToolCallRequestInfo } from '../scheduler/types.js';
|
||||
|
||||
export interface ToolCallData<HistoryType = unknown, ArgsType = unknown> {
|
||||
history?: HistoryType;
|
||||
|
||||
@@ -32,6 +32,7 @@ import {
|
||||
readFileWithEncoding,
|
||||
fileExists,
|
||||
readWasmBinaryFromDisk,
|
||||
saveTruncatedContent,
|
||||
} from './fileUtils.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
|
||||
@@ -1022,4 +1023,213 @@ describe('fileUtils', () => {
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Tests for saveTruncatedContent: output larger than THRESHOLD characters is
// truncated to an inline preview (head + tail of TRUNCATE_LINES lines) while
// the full content is persisted to `<tempRootDir>/<callId>.output`.
// NOTE(review): relies on `tempRootDir`, `fileExists`, `path`, `fsPromises`
// and `vi` from the enclosing suite/file scope (not visible in this excerpt).
describe('saveTruncatedContent', () => {
  const THRESHOLD = 40_000; // character count above which truncation kicks in
  const TRUNCATE_LINES = 1000; // total lines kept in the inline preview

  it('should return content unchanged if below threshold', async () => {
    const content = 'Short content';
    const callId = 'test-call-id';

    const result = await saveTruncatedContent(
      content,
      callId,
      tempRootDir,
      THRESHOLD,
      TRUNCATE_LINES,
    );

    // Content passes through untouched and no output file is created.
    expect(result).toEqual({ content });
    const outputFile = path.join(tempRootDir, `${callId}.output`);
    expect(await fileExists(outputFile)).toBe(false);
  });

  it('should truncate content by lines when content has many lines', async () => {
    // Create content that exceeds the 40,000 character THRESHOLD with many lines
    const lines = Array(2000).fill('x'.repeat(100));
    const content = lines.join('\n');
    const callId = 'test-call-id';

    const result = await saveTruncatedContent(
      content,
      callId,
      tempRootDir,
      THRESHOLD,
      TRUNCATE_LINES,
    );

    const expectedOutputFile = path.join(tempRootDir, `${callId}.output`);
    expect(result.outputFile).toBe(expectedOutputFile);

    // The full, unmodified content is what gets saved to disk.
    const savedContent = await fsPromises.readFile(
      expectedOutputFile,
      'utf-8',
    );
    expect(savedContent).toBe(content);

    // Should contain the first and last lines with 1/5 head and 4/5 tail
    const head = Math.floor(TRUNCATE_LINES / 5);
    const beginning = lines.slice(0, head);
    const end = lines.slice(-(TRUNCATE_LINES - head));
    const expectedTruncated =
      beginning.join('\n') +
      '\n... [CONTENT TRUNCATED] ...\n' +
      end.join('\n');

    expect(result.content).toContain(
      'Tool output was too large and has been truncated',
    );
    expect(result.content).toContain('Truncated part of the output:');
    expect(result.content).toContain(expectedTruncated);
  });

  it('should wrap and truncate content when content has few but long lines', async () => {
    const content = 'a'.repeat(200_000); // A single very long line
    const callId = 'test-call-id';
    const wrapWidth = 120;

    // Manually wrap the content to generate the expected file content
    const wrappedLines: string[] = [];
    for (let i = 0; i < content.length; i += wrapWidth) {
      wrappedLines.push(content.substring(i, i + wrapWidth));
    }
    const expectedFileContent = wrappedLines.join('\n');

    const result = await saveTruncatedContent(
      content,
      callId,
      tempRootDir,
      THRESHOLD,
      TRUNCATE_LINES,
    );

    const expectedOutputFile = path.join(tempRootDir, `${callId}.output`);
    expect(result.outputFile).toBe(expectedOutputFile);

    // The wrapped (not raw) content is what gets saved to disk.
    const savedContent = await fsPromises.readFile(
      expectedOutputFile,
      'utf-8',
    );
    expect(savedContent).toBe(expectedFileContent);

    // Should contain the first and last lines with 1/5 head and 4/5 tail of the wrapped content
    const head = Math.floor(TRUNCATE_LINES / 5);
    const beginning = wrappedLines.slice(0, head);
    const end = wrappedLines.slice(-(TRUNCATE_LINES - head));
    const expectedTruncated =
      beginning.join('\n') +
      '\n... [CONTENT TRUNCATED] ...\n' +
      end.join('\n');
    expect(result.content).toContain(
      'Tool output was too large and has been truncated',
    );
    expect(result.content).toContain('Truncated part of the output:');
    expect(result.content).toContain(expectedTruncated);
  });

  it('should save to correct file path with call ID', async () => {
    const content = 'a'.repeat(200_000);
    const callId = 'unique-call-123';
    const wrapWidth = 120;

    // Manually wrap the content to generate the expected file content
    const wrappedLines: string[] = [];
    for (let i = 0; i < content.length; i += wrapWidth) {
      wrappedLines.push(content.substring(i, i + wrapWidth));
    }
    const expectedFileContent = wrappedLines.join('\n');

    const result = await saveTruncatedContent(
      content,
      callId,
      tempRootDir,
      THRESHOLD,
      TRUNCATE_LINES,
    );

    // File name is derived from the call id.
    const expectedPath = path.join(tempRootDir, `${callId}.output`);
    expect(result.outputFile).toBe(expectedPath);

    const savedContent = await fsPromises.readFile(expectedPath, 'utf-8');
    expect(savedContent).toBe(expectedFileContent);
  });

  it('should include helpful instructions in truncated message', async () => {
    const content = 'a'.repeat(200_000);
    const callId = 'test-call-id';

    const result = await saveTruncatedContent(
      content,
      callId,
      tempRootDir,
      THRESHOLD,
      TRUNCATE_LINES,
    );

    // The message tells the model how to page through the saved file.
    expect(result.content).toContain(
      'read_file tool with the absolute file path above',
    );
    expect(result.content).toContain(
      'read_file tool with offset=0, limit=100',
    );
    expect(result.content).toContain(
      'read_file tool with offset=N to skip N lines',
    );
    expect(result.content).toContain(
      'read_file tool with limit=M to read only M lines',
    );
  });

  it('should sanitize callId to prevent path traversal', async () => {
    const content = 'a'.repeat(200_000);
    const callId = '../../../../../etc/passwd';
    const wrapWidth = 120;

    // Manually wrap the content to generate the expected file content
    const wrappedLines: string[] = [];
    for (let i = 0; i < content.length; i += wrapWidth) {
      wrappedLines.push(content.substring(i, i + wrapWidth));
    }
    const expectedFileContent = wrappedLines.join('\n');

    await saveTruncatedContent(
      content,
      callId,
      tempRootDir,
      THRESHOLD,
      TRUNCATE_LINES,
    );

    // Only the basename survives, so the file lands inside tempRootDir.
    const expectedPath = path.join(tempRootDir, 'passwd.output');

    const savedContent = await fsPromises.readFile(expectedPath, 'utf-8');
    expect(savedContent).toBe(expectedFileContent);
  });

  it('should handle file write errors gracefully', async () => {
    const content = 'a'.repeat(50_000);
    const callId = 'test-call-id-fail';

    // Force the save-to-disk step to fail.
    const writeFileSpy = vi
      .spyOn(fsPromises, 'writeFile')
      .mockRejectedValue(new Error('File write failed'));

    const result = await saveTruncatedContent(
      content,
      callId,
      tempRootDir,
      THRESHOLD,
      TRUNCATE_LINES,
    );

    // Falls back to the inline preview with a note; no outputFile reported.
    expect(result.outputFile).toBeUndefined();
    expect(result.content).toContain(
      '[Note: Could not save full output to file]',
    );
    expect(writeFileSpy).toHaveBeenCalled();

    writeFileSpy.mockRestore();
  });
});
|
||||
});
|
||||
|
||||
@@ -15,6 +15,7 @@ import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import { BINARY_EXTENSIONS } from './ignorePatterns.js';
|
||||
import { createRequire as createModuleRequire } from 'node:module';
|
||||
import { debugLogger } from './debugLogger.js';
|
||||
import { READ_FILE_TOOL_NAME } from '../tools/tool-names.js';
|
||||
|
||||
const requireModule = createModuleRequire(import.meta.url);
|
||||
|
||||
@@ -515,3 +516,67 @@ export async function fileExists(filePath: string): Promise<boolean> {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export async function saveTruncatedContent(
|
||||
content: string,
|
||||
callId: string,
|
||||
projectTempDir: string,
|
||||
threshold: number,
|
||||
truncateLines: number,
|
||||
): Promise<{ content: string; outputFile?: string }> {
|
||||
if (content.length <= threshold) {
|
||||
return { content };
|
||||
}
|
||||
|
||||
let lines = content.split('\n');
|
||||
let fileContent = content;
|
||||
|
||||
// If the content is long but has few lines, wrap it to enable line-based truncation.
|
||||
if (lines.length <= truncateLines) {
|
||||
const wrapWidth = 120; // A reasonable width for wrapping.
|
||||
const wrappedLines: string[] = [];
|
||||
for (const line of lines) {
|
||||
if (line.length > wrapWidth) {
|
||||
for (let i = 0; i < line.length; i += wrapWidth) {
|
||||
wrappedLines.push(line.substring(i, i + wrapWidth));
|
||||
}
|
||||
} else {
|
||||
wrappedLines.push(line);
|
||||
}
|
||||
}
|
||||
lines = wrappedLines;
|
||||
fileContent = lines.join('\n');
|
||||
}
|
||||
|
||||
const head = Math.floor(truncateLines / 5);
|
||||
const beginning = lines.slice(0, head);
|
||||
const end = lines.slice(-(truncateLines - head));
|
||||
const truncatedContent =
|
||||
beginning.join('\n') + '\n... [CONTENT TRUNCATED] ...\n' + end.join('\n');
|
||||
|
||||
// Sanitize callId to prevent path traversal.
|
||||
const safeFileName = `${path.basename(callId)}.output`;
|
||||
const outputFile = path.join(projectTempDir, safeFileName);
|
||||
try {
|
||||
await fsPromises.writeFile(outputFile, fileContent);
|
||||
|
||||
return {
|
||||
content: `Tool output was too large and has been truncated.
|
||||
The full output has been saved to: ${outputFile}
|
||||
To read the complete output, use the ${READ_FILE_TOOL_NAME} tool with the absolute file path above. For large files, you can use the offset and limit parameters to read specific sections:
|
||||
- ${READ_FILE_TOOL_NAME} tool with offset=0, limit=100 to see the first 100 lines
|
||||
- ${READ_FILE_TOOL_NAME} tool with offset=N to skip N lines from the beginning
|
||||
- ${READ_FILE_TOOL_NAME} tool with limit=M to read only M lines at a time
|
||||
The truncated output below shows the beginning and end of the content. The marker '... [CONTENT TRUNCATED] ...' indicates where content was removed.
|
||||
This allows you to efficiently examine different parts of the output without loading the entire file.
|
||||
Truncated part of the output:
|
||||
${truncatedContent}`,
|
||||
outputFile,
|
||||
};
|
||||
} catch (_error) {
|
||||
return {
|
||||
content:
|
||||
truncatedContent + `\n[Note: Could not save full output to file]`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,14 +14,20 @@ import {
|
||||
getStructuredResponse,
|
||||
getStructuredResponseFromParts,
|
||||
getCitations,
|
||||
convertToFunctionResponse,
|
||||
} from './generateContentResponseUtilities.js';
|
||||
import type {
|
||||
GenerateContentResponse,
|
||||
Part,
|
||||
SafetyRating,
|
||||
CitationMetadata,
|
||||
PartListUnion,
|
||||
} from '@google/genai';
|
||||
import { FinishReason } from '@google/genai';
|
||||
import {
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
} from '../config/models.js';
|
||||
|
||||
const mockTextPart = (text: string): Part => ({ text });
|
||||
const mockFunctionCallPart = (
|
||||
@@ -72,6 +78,312 @@ const minimalMockResponse = (
|
||||
});
|
||||
|
||||
describe('generateContentResponseUtilities', () => {
|
||||
// Tests for convertToFunctionResponse: mapping arbitrary tool output
// (string / Part / Part[]) onto the functionResponse Part[] shape, including
// the model-dependent placement of binary (inlineData/fileData) parts.
describe('convertToFunctionResponse', () => {
  const toolName = 'testTool';
  const callId = 'call1';

  it('should handle simple string llmContent', () => {
    const llmContent = 'Simple text output';
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      DEFAULT_GEMINI_MODEL,
    );
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: { output: 'Simple text output' },
        },
      },
    ]);
  });

  it('should handle llmContent as a single Part with text', () => {
    const llmContent: Part = { text: 'Text from Part object' };
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      DEFAULT_GEMINI_MODEL,
    );
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: { output: 'Text from Part object' },
        },
      },
    ]);
  });

  it('should handle llmContent as a PartListUnion array with a single text Part', () => {
    const llmContent: PartListUnion = [{ text: 'Text from array' }];
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      DEFAULT_GEMINI_MODEL,
    );
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: { output: 'Text from array' },
        },
      },
    ]);
  });

  it('should handle llmContent as a PartListUnion array with multiple Parts', () => {
    const llmContent: PartListUnion = [{ text: 'part1' }, { text: 'part2' }];
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      DEFAULT_GEMINI_MODEL,
    );
    // Multiple text parts are joined with a newline into one output string.
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: { output: 'part1\npart2' },
        },
      },
    ]);
  });

  it('should handle llmContent with fileData for Gemini 3 model (should be siblings)', () => {
    const llmContent: Part = {
      fileData: { mimeType: 'application/pdf', fileUri: 'gs://...' },
    };
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      PREVIEW_GEMINI_MODEL,
    );
    // fileData is always a sibling part, even on multimodal-capable models.
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: { output: 'Binary content provided (1 item(s)).' },
        },
      },
      llmContent,
    ]);
  });

  it('should handle llmContent with inlineData for Gemini 3 model (should be nested)', () => {
    const llmContent: Part = {
      inlineData: { mimeType: 'image/png', data: 'base64...' },
    };
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      PREVIEW_GEMINI_MODEL,
    );
    // inlineData nests inside functionResponse.parts on multimodal models.
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: { output: 'Binary content provided (1 item(s)).' },
          parts: [llmContent],
        },
      },
    ]);
  });

  it('should handle llmContent with fileData for non-Gemini 3 models', () => {
    const llmContent: Part = {
      fileData: { mimeType: 'application/pdf', fileUri: 'gs://...' },
    };
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      DEFAULT_GEMINI_MODEL,
    );
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: { output: 'Binary content provided (1 item(s)).' },
        },
      },
      llmContent,
    ]);
  });

  it('should preserve existing functionResponse metadata', () => {
    const innerId = 'inner-call-id';
    const innerName = 'inner-tool-name';
    const responseMetadata = {
      flags: ['flag1'],
      isError: false,
      customData: { key: 'value' },
    };
    const input: Part = {
      functionResponse: {
        id: innerId,
        name: innerName,
        response: responseMetadata,
      },
    };

    const result = convertToFunctionResponse(
      toolName,
      callId,
      input,
      DEFAULT_GEMINI_MODEL,
    );

    // Passthrough: the inner response payload survives, but the part is
    // re-tagged with the outer call's id and tool name.
    expect(result).toHaveLength(1);
    expect(result[0].functionResponse).toEqual({
      id: callId,
      name: toolName,
      response: responseMetadata,
    });
  });

  it('should handle llmContent as an array of multiple Parts (text and inlineData)', () => {
    const llmContent: PartListUnion = [
      { text: 'Some textual description' },
      { inlineData: { mimeType: 'image/jpeg', data: 'base64data...' } },
      { text: 'Another text part' },
    ];
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      PREVIEW_GEMINI_MODEL,
    );
    // Text parts are joined for the output; inlineData is nested separately.
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: {
            output: 'Some textual description\nAnother text part',
          },
          parts: [
            {
              inlineData: { mimeType: 'image/jpeg', data: 'base64data...' },
            },
          ],
        },
      },
    ]);
  });

  it('should handle llmContent as an array with a single inlineData Part', () => {
    const llmContent: PartListUnion = [
      { inlineData: { mimeType: 'image/gif', data: 'gifdata...' } },
    ];
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      PREVIEW_GEMINI_MODEL,
    );
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: { output: 'Binary content provided (1 item(s)).' },
          parts: llmContent,
        },
      },
    ]);
  });

  it('should handle llmContent as a generic Part (not text, inlineData, or fileData)', () => {
    const llmContent: Part = { functionCall: { name: 'test', args: {} } };
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      PREVIEW_GEMINI_MODEL,
    );
    // Unrecognized part kinds are dropped; response ends up empty.
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: {},
        },
      },
    ]);
  });

  it('should handle empty string llmContent', () => {
    const llmContent = '';
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      PREVIEW_GEMINI_MODEL,
    );
    // An empty string still yields an explicit (empty) output field.
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: { output: '' },
        },
      },
    ]);
  });

  it('should handle llmContent as an empty array', () => {
    const llmContent: PartListUnion = [];
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      PREVIEW_GEMINI_MODEL,
    );
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: {},
        },
      },
    ]);
  });

  it('should handle llmContent as a Part with undefined inlineData/fileData/text', () => {
    const llmContent: Part = {}; // An empty part object
    const result = convertToFunctionResponse(
      toolName,
      callId,
      llmContent,
      PREVIEW_GEMINI_MODEL,
    );
    expect(result).toEqual([
      {
        functionResponse: {
          name: toolName,
          id: callId,
          response: {},
        },
      },
    ]);
  });
});
|
||||
|
||||
describe('getCitations', () => {
|
||||
it('should return empty array for no candidates', () => {
|
||||
expect(getCitations(minimalMockResponse(undefined))).toEqual([]);
|
||||
|
||||
@@ -8,8 +8,125 @@ import type {
|
||||
GenerateContentResponse,
|
||||
Part,
|
||||
FunctionCall,
|
||||
PartListUnion,
|
||||
} from '@google/genai';
|
||||
import { getResponseText } from './partUtils.js';
|
||||
import { supportsMultimodalFunctionResponse } from '../config/models.js';
|
||||
import { debugLogger } from './debugLogger.js';
|
||||
|
||||
/**
|
||||
* Formats tool output for a Gemini FunctionResponse.
|
||||
*/
|
||||
function createFunctionResponsePart(
|
||||
callId: string,
|
||||
toolName: string,
|
||||
output: string,
|
||||
): Part {
|
||||
return {
|
||||
functionResponse: {
|
||||
id: callId,
|
||||
name: toolName,
|
||||
response: { output },
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function toParts(input: PartListUnion): Part[] {
|
||||
const parts: Part[] = [];
|
||||
for (const part of Array.isArray(input) ? input : [input]) {
|
||||
if (typeof part === 'string') {
|
||||
parts.push({ text: part });
|
||||
} else if (part) {
|
||||
parts.push(part);
|
||||
}
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
export function convertToFunctionResponse(
|
||||
toolName: string,
|
||||
callId: string,
|
||||
llmContent: PartListUnion,
|
||||
model: string,
|
||||
): Part[] {
|
||||
if (typeof llmContent === 'string') {
|
||||
return [createFunctionResponsePart(callId, toolName, llmContent)];
|
||||
}
|
||||
|
||||
const parts = toParts(llmContent);
|
||||
|
||||
// Separate text from binary types
|
||||
const textParts: string[] = [];
|
||||
const inlineDataParts: Part[] = [];
|
||||
const fileDataParts: Part[] = [];
|
||||
|
||||
for (const part of parts) {
|
||||
if (part.text !== undefined) {
|
||||
textParts.push(part.text);
|
||||
} else if (part.inlineData) {
|
||||
inlineDataParts.push(part);
|
||||
} else if (part.fileData) {
|
||||
fileDataParts.push(part);
|
||||
} else if (part.functionResponse) {
|
||||
if (parts.length > 1) {
|
||||
debugLogger.warn(
|
||||
'convertToFunctionResponse received multiple parts with a functionResponse. Only the functionResponse will be used, other parts will be ignored',
|
||||
);
|
||||
}
|
||||
// Handle passthrough case
|
||||
return [
|
||||
{
|
||||
functionResponse: {
|
||||
id: callId,
|
||||
name: toolName,
|
||||
response: part.functionResponse.response,
|
||||
},
|
||||
},
|
||||
];
|
||||
}
|
||||
// Ignore other part types
|
||||
}
|
||||
|
||||
// Build the primary response part
|
||||
const part: Part = {
|
||||
functionResponse: {
|
||||
id: callId,
|
||||
name: toolName,
|
||||
response: textParts.length > 0 ? { output: textParts.join('\n') } : {},
|
||||
},
|
||||
};
|
||||
|
||||
const isMultimodalFRSupported = supportsMultimodalFunctionResponse(model);
|
||||
const siblingParts: Part[] = [...fileDataParts];
|
||||
|
||||
if (inlineDataParts.length > 0) {
|
||||
if (isMultimodalFRSupported) {
|
||||
// Nest inlineData if supported by the model
|
||||
(part.functionResponse as unknown as { parts: Part[] }).parts =
|
||||
inlineDataParts;
|
||||
} else {
|
||||
// Otherwise treat as siblings
|
||||
siblingParts.push(...inlineDataParts);
|
||||
}
|
||||
}
|
||||
|
||||
// Add descriptive text if the response object is empty but we have binary content
|
||||
if (
|
||||
textParts.length === 0 &&
|
||||
(inlineDataParts.length > 0 || fileDataParts.length > 0)
|
||||
) {
|
||||
const totalBinaryItems = inlineDataParts.length + fileDataParts.length;
|
||||
part.functionResponse!.response = {
|
||||
output: `Binary content provided (${totalBinaryItems} item(s)).`,
|
||||
};
|
||||
}
|
||||
|
||||
if (siblingParts.length > 0) {
|
||||
return [part, ...siblingParts];
|
||||
}
|
||||
|
||||
return [part];
|
||||
}
|
||||
|
||||
export function getResponseTextFromParts(parts: Part[]): string | undefined {
|
||||
if (!parts) {
|
||||
|
||||
@@ -5,10 +5,34 @@
|
||||
*/
|
||||
|
||||
import { expect, describe, it } from 'vitest';
|
||||
import { doesToolInvocationMatch } from './tool-utils.js';
|
||||
import { doesToolInvocationMatch, getToolSuggestion } from './tool-utils.js';
|
||||
import type { AnyToolInvocation, Config } from '../index.js';
|
||||
import { ReadFileTool } from '../tools/read-file.js';
|
||||
|
||||
// Tests for getToolSuggestion: Levenshtein-distance-based "Did you mean"
// suggestions for unknown tool names.
describe('getToolSuggestion', () => {
  it('should suggest the top N closest tool names for a typo', () => {
    const allToolNames = ['list_files', 'read_file', 'write_file'];

    // Test that the right tool is selected, with only 1 result, for typos
    const misspelledTool = getToolSuggestion('list_fils', allToolNames, 1);
    expect(misspelledTool).toBe(' Did you mean "list_files"?');

    // Test that the right tool is selected, with only 1 result, for prefixes
    const prefixedTool = getToolSuggestion(
      'github.list_files',
      allToolNames,
      1,
    );
    expect(prefixedTool).toBe(' Did you mean "list_files"?');

    // Test that the right tool is first
    const suggestionMultiple = getToolSuggestion('list_fils', allToolNames);
    expect(suggestionMultiple).toBe(
      ' Did you mean one of: "list_files", "read_file", "write_file"?',
    );
  });
});
|
||||
|
||||
describe('doesToolInvocationMatch', () => {
|
||||
it('should not match a partial command prefix', () => {
|
||||
const invocation = {
|
||||
|
||||
@@ -7,6 +7,44 @@
|
||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../index.js';
|
||||
import { isTool } from '../index.js';
|
||||
import { SHELL_TOOL_NAMES } from './shell-utils.js';
|
||||
import levenshtein from 'fast-levenshtein';
|
||||
|
||||
/**
|
||||
* Generates a suggestion string for a tool name that was not found in the registry.
|
||||
* It finds the closest matches based on Levenshtein distance.
|
||||
* @param unknownToolName The tool name that was not found.
|
||||
* @param allToolNames The list of all available tool names.
|
||||
* @param topN The number of suggestions to return. Defaults to 3.
|
||||
* @returns A suggestion string like " Did you mean 'tool'?" or " Did you mean one of: 'tool1', 'tool2'?", or an empty string if no suggestions are found.
|
||||
*/
|
||||
export function getToolSuggestion(
|
||||
unknownToolName: string,
|
||||
allToolNames: string[],
|
||||
topN = 3,
|
||||
): string {
|
||||
const matches = allToolNames.map((toolName) => ({
|
||||
name: toolName,
|
||||
distance: levenshtein.get(unknownToolName, toolName),
|
||||
}));
|
||||
|
||||
matches.sort((a, b) => a.distance - b.distance);
|
||||
|
||||
const topNResults = matches.slice(0, topN);
|
||||
|
||||
if (topNResults.length === 0) {
|
||||
return '';
|
||||
}
|
||||
|
||||
const suggestedNames = topNResults
|
||||
.map((match) => `"${match.name}"`)
|
||||
.join(', ');
|
||||
|
||||
if (topNResults.length > 1) {
|
||||
return ` Did you mean one of: ${suggestedNames}?`;
|
||||
} else {
|
||||
return ` Did you mean ${suggestedNames}?`;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a tool invocation matches any of a list of patterns.
|
||||
|
||||
Reference in New Issue
Block a user