Refactored 4 files of tools package (#13235)

Co-authored-by: riddhi <duttariddhi@google.com>
This commit is contained in:
Riddhi Dutta
2025-11-18 01:01:29 +05:30
committed by GitHub
parent ba88707b1e
commit 1d1bdc57ce
4 changed files with 492 additions and 696 deletions
+124 -147
View File
@@ -176,53 +176,48 @@ describe('SmartEditTool', () => {
describe('calculateReplacement', () => {
const abortSignal = new AbortController().signal;
it('should perform an exact replacement', async () => {
const content = 'hello world';
const result = await calculateReplacement(mockConfig, {
params: {
file_path: 'test.txt',
instruction: 'test',
old_string: 'world',
new_string: 'moon',
},
currentContent: content,
abortSignal,
});
expect(result.newContent).toBe('hello moon');
expect(result.occurrences).toBe(1);
});
it('should perform a flexible, whitespace-insensitive replacement', async () => {
const content = ' hello\n world\n';
const result = await calculateReplacement(mockConfig, {
params: {
file_path: 'test.txt',
instruction: 'test',
old_string: 'hello\nworld',
new_string: 'goodbye\nmoon',
},
currentContent: content,
abortSignal,
});
expect(result.newContent).toBe(' goodbye\n moon\n');
expect(result.occurrences).toBe(1);
});
it('should return 0 occurrences if no match is found', async () => {
const content = 'hello world';
const result = await calculateReplacement(mockConfig, {
params: {
file_path: 'test.txt',
instruction: 'test',
old_string: 'nomatch',
new_string: 'moon',
},
currentContent: content,
abortSignal,
});
expect(result.newContent).toBe(content);
expect(result.occurrences).toBe(0);
});
it.each([
{
name: 'perform an exact replacement',
content: 'hello world',
old_string: 'world',
new_string: 'moon',
expected: 'hello moon',
occurrences: 1,
},
{
name: 'perform a flexible, whitespace-insensitive replacement',
content: ' hello\n world\n',
old_string: 'hello\nworld',
new_string: 'goodbye\nmoon',
expected: ' goodbye\n moon\n',
occurrences: 1,
},
{
name: 'return 0 occurrences if no match is found',
content: 'hello world',
old_string: 'nomatch',
new_string: 'moon',
expected: 'hello world',
occurrences: 0,
},
])(
'should $name',
async ({ content, old_string, new_string, expected, occurrences }) => {
const result = await calculateReplacement(mockConfig, {
params: {
file_path: 'test.txt',
instruction: 'test',
old_string,
new_string,
},
currentContent: content,
abortSignal,
});
expect(result.newContent).toBe(expected);
expect(result.occurrences).toBe(occurrences);
},
);
it('should perform a regex-based replacement for flexible intra-line whitespace', async () => {
// This case would fail with the previous exact and line-trimming flexible logic
@@ -496,60 +491,44 @@ describe('SmartEditTool', () => {
filePath = path.join(rootDir, testFile);
});
it('should return FILE_NOT_FOUND error', async () => {
const params: EditToolParams = {
file_path: filePath,
instruction: 'test',
old_string: 'any',
new_string: 'new',
};
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(ToolErrorType.FILE_NOT_FOUND);
});
it('should return ATTEMPT_TO_CREATE_EXISTING_FILE error', async () => {
fs.writeFileSync(filePath, 'existing content', 'utf8');
const params: EditToolParams = {
file_path: filePath,
instruction: 'test',
old_string: '',
new_string: 'new content',
};
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(
ToolErrorType.ATTEMPT_TO_CREATE_EXISTING_FILE,
);
});
it('should return NO_OCCURRENCE_FOUND error', async () => {
fs.writeFileSync(filePath, 'content', 'utf8');
const params: EditToolParams = {
file_path: filePath,
instruction: 'test',
old_string: 'not-found',
new_string: 'new',
};
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(ToolErrorType.EDIT_NO_OCCURRENCE_FOUND);
});
it('should return EXPECTED_OCCURRENCE_MISMATCH error', async () => {
fs.writeFileSync(filePath, 'one one two', 'utf8');
const params: EditToolParams = {
file_path: filePath,
instruction: 'test',
old_string: 'one',
new_string: 'new',
};
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(
ToolErrorType.EDIT_EXPECTED_OCCURRENCE_MISMATCH,
);
});
it.each([
{
name: 'FILE_NOT_FOUND',
setup: () => {}, // no file created
params: { old_string: 'any', new_string: 'new' },
expectedError: ToolErrorType.FILE_NOT_FOUND,
},
{
name: 'ATTEMPT_TO_CREATE_EXISTING_FILE',
setup: (fp: string) => fs.writeFileSync(fp, 'existing content', 'utf8'),
params: { old_string: '', new_string: 'new content' },
expectedError: ToolErrorType.ATTEMPT_TO_CREATE_EXISTING_FILE,
},
{
name: 'NO_OCCURRENCE_FOUND',
setup: (fp: string) => fs.writeFileSync(fp, 'content', 'utf8'),
params: { old_string: 'not-found', new_string: 'new' },
expectedError: ToolErrorType.EDIT_NO_OCCURRENCE_FOUND,
},
{
name: 'EXPECTED_OCCURRENCE_MISMATCH',
setup: (fp: string) => fs.writeFileSync(fp, 'one one two', 'utf8'),
params: { old_string: 'one', new_string: 'new' },
expectedError: ToolErrorType.EDIT_EXPECTED_OCCURRENCE_MISMATCH,
},
])(
'should return $name error',
async ({ setup, params, expectedError }) => {
setup(filePath);
const invocation = tool.build({
file_path: filePath,
instruction: 'test',
...params,
});
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(expectedError);
},
);
});
describe('expected_replacements', () => {
@@ -560,53 +539,51 @@ describe('SmartEditTool', () => {
filePath = path.join(rootDir, testFile);
});
it('should succeed when occurrences match expected_replacements', async () => {
fs.writeFileSync(filePath, 'foo foo foo', 'utf8');
const params: EditToolParams = {
file_path: filePath,
instruction: 'Replace all foo with bar',
old_string: 'foo',
new_string: 'bar',
expected_replacements: 3,
};
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error).toBeUndefined();
expect(fs.readFileSync(filePath, 'utf8')).toBe('bar bar bar');
});
it.each([
{
name: 'succeed when occurrences match expected_replacements',
content: 'foo foo foo',
expected: 3,
shouldSucceed: true,
finalContent: 'bar bar bar',
},
{
name: 'fail when occurrences do not match expected_replacements',
content: 'foo foo foo',
expected: 2,
shouldSucceed: false,
},
{
name: 'default to 1 expected replacement if not specified',
content: 'foo foo',
expected: undefined,
shouldSucceed: false,
},
])(
'should $name',
async ({ content, expected, shouldSucceed, finalContent }) => {
fs.writeFileSync(filePath, content, 'utf8');
const params: EditToolParams = {
file_path: filePath,
instruction: 'Replace all foo with bar',
old_string: 'foo',
new_string: 'bar',
...(expected !== undefined && { expected_replacements: expected }),
};
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
it('should fail when occurrences do not match expected_replacements', async () => {
fs.writeFileSync(filePath, 'foo foo foo', 'utf8');
const params: EditToolParams = {
file_path: filePath,
instruction: 'Replace all foo with bar',
old_string: 'foo',
new_string: 'bar',
expected_replacements: 2,
};
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(
ToolErrorType.EDIT_EXPECTED_OCCURRENCE_MISMATCH,
);
});
it('should default to 1 expected replacement if not specified', async () => {
fs.writeFileSync(filePath, 'foo foo', 'utf8');
const params: EditToolParams = {
file_path: filePath,
instruction: 'Replace foo with bar',
old_string: 'foo',
new_string: 'bar',
// expected_replacements is undefined, defaults to 1
};
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
// Should fail because there are 2 occurrences but default expectation is 1
expect(result.error?.type).toBe(
ToolErrorType.EDIT_EXPECTED_OCCURRENCE_MISMATCH,
);
});
if (shouldSucceed) {
expect(result.error).toBeUndefined();
if (finalContent)
expect(fs.readFileSync(filePath, 'utf8')).toBe(finalContent);
} else {
expect(result.error?.type).toBe(
ToolErrorType.EDIT_EXPECTED_OCCURRENCE_MISMATCH,
);
}
},
);
});
describe('IDE mode', () => {
+85 -148
View File
@@ -92,6 +92,72 @@ const createMockCallableTool = (
callTool: vi.fn(),
});
// Helper to create a DiscoveredMCPTool
const createMCPTool = (
serverName: string,
toolName: string,
description: string,
mockCallable: CallableTool = {} as CallableTool,
) => new DiscoveredMCPTool(mockCallable, serverName, toolName, description, {});
// Helper to create a mock spawn process for tool discovery
const createDiscoveryProcess = (toolDeclarations: FunctionDeclaration[]) => {
const mockProcess = {
stdout: { on: vi.fn(), removeListener: vi.fn() },
stderr: { on: vi.fn(), removeListener: vi.fn() },
on: vi.fn(),
};
mockProcess.stdout.on.mockImplementation((event, callback) => {
if (event === 'data') {
callback(
Buffer.from(
JSON.stringify([{ functionDeclarations: toolDeclarations }]),
),
);
}
return mockProcess as any;
});
mockProcess.on.mockImplementation((event, callback) => {
if (event === 'close') {
callback(0);
}
return mockProcess as any;
});
return mockProcess;
};
// Helper to create a mock spawn process for tool execution
const createExecutionProcess = (exitCode: number, stderrMessage?: string) => {
const mockProcess = {
stdout: { on: vi.fn(), removeListener: vi.fn() },
stderr: { on: vi.fn(), removeListener: vi.fn() },
stdin: { write: vi.fn(), end: vi.fn() },
on: vi.fn(),
connected: true,
disconnect: vi.fn(),
removeListener: vi.fn(),
};
if (stderrMessage) {
mockProcess.stderr.on.mockImplementation((event, callback) => {
if (event === 'data') {
callback(Buffer.from(stderrMessage));
}
});
}
mockProcess.on.mockImplementation((event, callback) => {
if (event === 'close') {
callback(exitCode);
}
});
return mockProcess;
};
const baseConfigParams: ConfigParameters = {
cwd: '/tmp',
model: 'test-model',
@@ -165,13 +231,10 @@ describe('ToolRegistry', () => {
name: 'excluded-tool-class',
displayName: 'Excluded Tool Class',
});
const mockCallable = {} as CallableTool;
const mcpTool = new DiscoveredMCPTool(
mockCallable,
const mcpTool = createMCPTool(
'mcp-server',
'excluded-mcp-tool',
'description',
{},
);
const allowedTool = new MockTool({
name: 'allowed-tool',
@@ -271,36 +334,10 @@ describe('ToolRegistry', () => {
it('should return only tools matching the server name, sorted by name', async () => {
const server1Name = 'mcp-server-uno';
const server2Name = 'mcp-server-dos';
const mockCallable = {} as CallableTool;
const mcpTool1_c = new DiscoveredMCPTool(
mockCallable,
server1Name,
'zebra-tool',
'd1',
{},
);
const mcpTool1_a = new DiscoveredMCPTool(
mockCallable,
server1Name,
'apple-tool',
'd2',
{},
);
const mcpTool1_b = new DiscoveredMCPTool(
mockCallable,
server1Name,
'banana-tool',
'd3',
{},
);
const mcpTool2 = new DiscoveredMCPTool(
mockCallable,
server2Name,
'tool-on-server2',
'd4',
{},
);
const mcpTool1_c = createMCPTool(server1Name, 'zebra-tool', 'd1');
const mcpTool1_a = createMCPTool(server1Name, 'apple-tool', 'd2');
const mcpTool1_b = createMCPTool(server1Name, 'banana-tool', 'd3');
const mcpTool2 = createMCPTool(server2Name, 'tool-on-server2', 'd4');
const nonMcpTool = new MockTool({ name: 'regular-tool' });
toolRegistry.registerTool(mcpTool1_c);
@@ -339,21 +376,8 @@ describe('ToolRegistry', () => {
'desc',
{},
);
const mockCallable = {} as CallableTool;
const mcpZebra = new DiscoveredMCPTool(
mockCallable,
'zebra-server',
'mcp-zebra',
'desc',
{},
);
const mcpApple = new DiscoveredMCPTool(
mockCallable,
'apple-server',
'mcp-apple',
'desc',
{},
);
const mcpZebra = createMCPTool('zebra-server', 'mcp-zebra', 'desc');
const mcpApple = createMCPTool('apple-server', 'mcp-apple', 'desc');
// Register in mixed order
toolRegistry.registerTool(mcpZebra);
@@ -394,34 +418,9 @@ describe('ToolRegistry', () => {
};
const mockSpawn = vi.mocked(spawn);
const mockChildProcess = {
stdout: { on: vi.fn() },
stderr: { on: vi.fn() },
on: vi.fn(),
};
mockSpawn.mockReturnValue(mockChildProcess as any);
// Simulate stdout data
mockChildProcess.stdout.on.mockImplementation((event, callback) => {
if (event === 'data') {
callback(
Buffer.from(
JSON.stringify([
{ function_declarations: [unsanitizedToolDeclaration] },
]),
),
);
}
return mockChildProcess as any;
});
// Simulate process close
mockChildProcess.on.mockImplementation((event, callback) => {
if (event === 'close') {
callback(0);
}
return mockChildProcess as any;
});
mockSpawn.mockReturnValue(
createDiscoveryProcess([unsanitizedToolDeclaration]) as any,
);
await toolRegistry.discoverAllTools();
@@ -458,28 +457,9 @@ describe('ToolRegistry', () => {
};
const mockSpawn = vi.mocked(spawn);
// --- Discovery Mock ---
const discoveryProcess = {
stdout: { on: vi.fn(), removeListener: vi.fn() },
stderr: { on: vi.fn(), removeListener: vi.fn() },
on: vi.fn(),
};
mockSpawn.mockReturnValueOnce(discoveryProcess as any);
discoveryProcess.stdout.on.mockImplementation((event, callback) => {
if (event === 'data') {
callback(
Buffer.from(
JSON.stringify([{ functionDeclarations: [toolDeclaration] }]),
),
);
}
});
discoveryProcess.on.mockImplementation((event, callback) => {
if (event === 'close') {
callback(0);
}
});
mockSpawn.mockReturnValueOnce(
createDiscoveryProcess([toolDeclaration]) as any,
);
await toolRegistry.discoverAllTools();
const discoveredTool = toolRegistry.getTool(
@@ -487,28 +467,9 @@ describe('ToolRegistry', () => {
);
expect(discoveredTool).toBeDefined();
// --- Execution Mock ---
const executionProcess = {
stdout: { on: vi.fn(), removeListener: vi.fn() },
stderr: { on: vi.fn(), removeListener: vi.fn() },
stdin: { write: vi.fn(), end: vi.fn() },
on: vi.fn(),
connected: true,
disconnect: vi.fn(),
removeListener: vi.fn(),
};
mockSpawn.mockReturnValueOnce(executionProcess as any);
executionProcess.stderr.on.mockImplementation((event, callback) => {
if (event === 'data') {
callback(Buffer.from('Something went wrong'));
}
});
executionProcess.on.mockImplementation((event, callback) => {
if (event === 'close') {
callback(1); // Non-zero exit code
}
});
mockSpawn.mockReturnValueOnce(
createExecutionProcess(1, 'Something went wrong') as any,
);
const invocation = (discoveredTool as DiscoveredTool).build({});
const result = await invocation.execute(new AbortController().signal);
@@ -524,7 +485,6 @@ describe('ToolRegistry', () => {
const discoveryCommand = 'my-discovery-command';
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
// Mock MessageBus
const mockMessageBus = {
publish: vi.fn(),
subscribe: vi.fn(),
@@ -539,41 +499,18 @@ describe('ToolRegistry', () => {
};
const mockSpawn = vi.mocked(spawn);
const discoveryProcess = {
stdout: { on: vi.fn(), removeListener: vi.fn() },
stderr: { on: vi.fn(), removeListener: vi.fn() },
on: vi.fn(),
kill: vi.fn(),
};
mockSpawn.mockReturnValueOnce(discoveryProcess as any);
discoveryProcess.stdout.on.mockImplementation((event, callback) => {
if (event === 'data') {
callback(
Buffer.from(
JSON.stringify([{ functionDeclarations: [toolDeclaration] }]),
),
);
}
});
discoveryProcess.on.mockImplementation((event, callback) => {
if (event === 'close') {
callback(0);
}
});
mockSpawn.mockReturnValueOnce(
createDiscoveryProcess([toolDeclaration]) as any,
);
await toolRegistry.discoverAllTools();
const tool = toolRegistry.getTool(
DISCOVERED_TOOL_PREFIX + 'policy-test-tool',
);
expect(tool).toBeDefined();
// Verify DiscoveredTool has the message bus
expect((tool as any).messageBus).toBe(mockMessageBus);
const invocation = tool!.build({});
// Verify DiscoveredToolInvocation has the message bus
expect((invocation as any).messageBus).toBe(mockMessageBus);
});
});
+143 -230
View File
@@ -71,46 +71,38 @@ describe('parsePrompt', () => {
expect(validUrls[0]).toBe('https://example.com./');
});
it('should detect URLs wrapped in punctuation as malformed', () => {
const prompt = 'Read (https://example.com)';
it.each([
{
name: 'URLs wrapped in punctuation',
prompt: 'Read (https://example.com)',
expectedErrorContent: ['Malformed URL detected', '(https://example.com)'],
},
{
name: 'unsupported protocols (httpshttps://)',
prompt: 'Summarize httpshttps://github.com/JuliaLang/julia/issues/58346',
expectedErrorContent: [
'Unsupported protocol',
'httpshttps://github.com/JuliaLang/julia/issues/58346',
],
},
{
name: 'unsupported protocols (ftp://)',
prompt: 'ftp://example.com/file.txt',
expectedErrorContent: ['Unsupported protocol'],
},
{
name: 'malformed URLs (http://)',
prompt: 'http://',
expectedErrorContent: ['Malformed URL detected'],
},
])('should detect $name as errors', ({ prompt, expectedErrorContent }) => {
const { validUrls, errors } = parsePrompt(prompt);
expect(validUrls).toHaveLength(0);
expect(errors).toHaveLength(1);
expect(errors[0]).toContain('Malformed URL detected');
expect(errors[0]).toContain('(https://example.com)');
});
it('should detect unsupported protocols (httpshttps://)', () => {
const prompt =
'Summarize httpshttps://github.com/JuliaLang/julia/issues/58346';
const { validUrls, errors } = parsePrompt(prompt);
expect(validUrls).toHaveLength(0);
expect(errors).toHaveLength(1);
expect(errors[0]).toContain('Unsupported protocol');
expect(errors[0]).toContain(
'httpshttps://github.com/JuliaLang/julia/issues/58346',
);
});
it('should detect unsupported protocols (ftp://)', () => {
const prompt = 'ftp://example.com/file.txt';
const { validUrls, errors } = parsePrompt(prompt);
expect(validUrls).toHaveLength(0);
expect(errors).toHaveLength(1);
expect(errors[0]).toContain('Unsupported protocol');
});
it('should detect malformed URLs', () => {
// http:// is not a valid URL in Node's new URL()
const prompt = 'http://';
const { validUrls, errors } = parsePrompt(prompt);
expect(validUrls).toHaveLength(0);
expect(errors).toHaveLength(1);
expect(errors[0]).toContain('Malformed URL detected');
expectedErrorContent.forEach((content) => {
expect(errors[0]).toContain(content);
});
});
it('should handle prompts with no URLs', () => {
@@ -153,24 +145,25 @@ describe('WebFetchTool', () => {
});
describe('validateToolParamValues', () => {
it('should throw if prompt is empty', () => {
it.each([
{
name: 'empty prompt',
prompt: '',
expectedError: "The 'prompt' parameter cannot be empty",
},
{
name: 'prompt with no URLs',
prompt: 'hello world',
expectedError: "The 'prompt' must contain at least one valid URL",
},
{
name: 'prompt with malformed URLs',
prompt: 'fetch httpshttps://example.com',
expectedError: 'Error(s) in prompt URLs:',
},
])('should throw if $name', ({ prompt, expectedError }) => {
const tool = new WebFetchTool(mockConfig);
expect(() => tool.build({ prompt: '' })).toThrow(
"The 'prompt' parameter cannot be empty",
);
});
it('should throw if prompt contains no URLs', () => {
const tool = new WebFetchTool(mockConfig);
expect(() => tool.build({ prompt: 'hello world' })).toThrow(
"The 'prompt' must contain at least one valid URL",
);
});
it('should throw if prompt contains malformed URLs (httpshttps://)', () => {
const tool = new WebFetchTool(mockConfig);
const prompt = 'fetch httpshttps://example.com';
expect(() => tool.build({ prompt })).toThrow('Error(s) in prompt URLs:');
expect(() => tool.build({ prompt })).toThrow(expectedError);
});
it('should pass if prompt contains at least one valid URL', () => {
@@ -267,105 +260,71 @@ describe('WebFetchTool', () => {
});
});
it('should convert HTML content using html-to-text', async () => {
const htmlContent = '<html><body><h1>Hello</h1></body></html>';
vi.spyOn(fetchUtils, 'fetchWithTimeout').mockResolvedValue({
ok: true,
headers: new Headers({ 'content-type': 'text/html; charset=utf-8' }),
text: () => Promise.resolve(htmlContent),
} as Response);
it.each([
{
name: 'HTML content using html-to-text',
content: '<html><body><h1>Hello</h1></body></html>',
contentType: 'text/html; charset=utf-8',
shouldConvert: true,
},
{
name: 'raw text for JSON content',
content: '{"key": "value"}',
contentType: 'application/json',
shouldConvert: false,
},
{
name: 'raw text for plain text content',
content: 'Just some text.',
contentType: 'text/plain',
shouldConvert: false,
},
{
name: 'content with no Content-Type header as HTML',
content: '<p>No header</p>',
contentType: null,
shouldConvert: true,
},
])(
'should handle $name',
async ({ content, contentType, shouldConvert }) => {
const headers = contentType
? new Headers({ 'content-type': contentType })
: new Headers();
// Mock fallback LLM call to return the content passed to it
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
}));
vi.spyOn(fetchUtils, 'fetchWithTimeout').mockResolvedValue({
ok: true,
headers,
text: () => Promise.resolve(content),
} as Response);
const tool = new WebFetchTool(mockConfig);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
// Mock fallback LLM call to return the content passed to it
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
candidates: [
{ content: { parts: [{ text: req[0].parts[0].text }] } },
],
}));
expect(convert).toHaveBeenCalledWith(htmlContent, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: true } },
{ selector: 'img', format: 'skip' },
],
});
expect(result.llmContent).toContain(`Converted: ${htmlContent}`);
});
const tool = new WebFetchTool(mockConfig);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
it('should return raw text for JSON content', async () => {
const jsonContent = '{"key": "value"}';
vi.spyOn(fetchUtils, 'fetchWithTimeout').mockResolvedValue({
ok: true,
headers: new Headers({ 'content-type': 'application/json' }),
text: () => Promise.resolve(jsonContent),
} as Response);
// Mock fallback LLM call to return the content passed to it
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
}));
const tool = new WebFetchTool(mockConfig);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(convert).not.toHaveBeenCalled();
expect(result.llmContent).toContain(jsonContent);
});
it('should return raw text for plain text content', async () => {
const textContent = 'Just some text.';
vi.spyOn(fetchUtils, 'fetchWithTimeout').mockResolvedValue({
ok: true,
headers: new Headers({ 'content-type': 'text/plain' }),
text: () => Promise.resolve(textContent),
} as Response);
// Mock fallback LLM call to return the content passed to it
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
}));
const tool = new WebFetchTool(mockConfig);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(convert).not.toHaveBeenCalled();
expect(result.llmContent).toContain(textContent);
});
it('should treat content with no Content-Type header as HTML', async () => {
const content = '<p>No header</p>';
vi.spyOn(fetchUtils, 'fetchWithTimeout').mockResolvedValue({
ok: true,
headers: new Headers(),
text: () => Promise.resolve(content),
} as Response);
// Mock fallback LLM call to return the content passed to it
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
}));
const tool = new WebFetchTool(mockConfig);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(convert).toHaveBeenCalledWith(content, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: true } },
{ selector: 'img', format: 'skip' },
],
});
expect(result.llmContent).toContain(`Converted: ${content}`);
});
if (shouldConvert) {
expect(convert).toHaveBeenCalledWith(content, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: true } },
{ selector: 'img', format: 'skip' },
],
});
expect(result.llmContent).toContain(`Converted: ${content}`);
} else {
expect(convert).not.toHaveBeenCalled();
expect(result.llmContent).toContain(content);
}
},
);
});
describe('shouldConfirmExecute', () => {
@@ -452,6 +411,28 @@ describe('WebFetchTool', () => {
let messageBus: MessageBus;
let mockUUID: Mock;
const createToolWithMessageBus = (bus?: MessageBus) => {
const tool = new WebFetchTool(mockConfig, bus);
const params = { prompt: 'fetch https://example.com' };
return { tool, invocation: tool.build(params) };
};
const simulateMessageBusResponse = (
subscribeSpy: ReturnType<typeof vi.spyOn>,
confirmed: boolean,
correlationId = 'test-correlation-id',
) => {
const responseHandler = subscribeSpy.mock.calls[0][1] as (
response: ToolConfirmationResponse,
) => void;
const response: ToolConfirmationResponse = {
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId,
confirmed,
};
responseHandler(response);
};
beforeEach(() => {
policyEngine = new PolicyEngine();
messageBus = new MessageBus(policyEngine);
@@ -460,21 +441,15 @@ describe('WebFetchTool', () => {
});
it('should use message bus for confirmation when available', async () => {
const tool = new WebFetchTool(mockConfig, messageBus);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
// Mock message bus publish and subscribe
const { invocation } = createToolWithMessageBus(messageBus);
const publishSpy = vi.spyOn(messageBus, 'publish');
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
const unsubscribeSpy = vi.spyOn(messageBus, 'unsubscribe');
// Start confirmation process
const confirmationPromise = invocation.shouldConfirmExecute(
new AbortController().signal,
);
// Verify confirmation request was published
expect(publishSpy).toHaveBeenCalledWith({
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: {
@@ -484,49 +459,28 @@ describe('WebFetchTool', () => {
correlationId: 'test-correlation-id',
});
// Verify subscription to response
expect(subscribeSpy).toHaveBeenCalledWith(
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
expect.any(Function),
);
// Simulate confirmation response
const responseHandler = subscribeSpy.mock.calls[0][1];
const response: ToolConfirmationResponse = {
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'test-correlation-id',
confirmed: true,
};
responseHandler(response);
simulateMessageBusResponse(subscribeSpy, true);
const result = await confirmationPromise;
expect(result).toBe(false); // No further confirmation needed
expect(result).toBe(false);
expect(unsubscribeSpy).toHaveBeenCalled();
});
it('should reject promise when confirmation is denied via message bus', async () => {
const tool = new WebFetchTool(mockConfig, messageBus);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const { invocation } = createToolWithMessageBus(messageBus);
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
const confirmationPromise = invocation.shouldConfirmExecute(
new AbortController().signal,
);
// Simulate denial response
const responseHandler = subscribeSpy.mock.calls[0][1];
const response: ToolConfirmationResponse = {
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'test-correlation-id',
confirmed: false,
};
simulateMessageBusResponse(subscribeSpy, false);
responseHandler(response);
// Should reject with error when denied
await expect(confirmationPromise).rejects.toThrow(
'Tool execution for "WebFetch" denied by policy.',
);
@@ -534,16 +488,11 @@ describe('WebFetchTool', () => {
it('should handle timeout gracefully', async () => {
vi.useFakeTimers();
const tool = new WebFetchTool(mockConfig, messageBus);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const { invocation } = createToolWithMessageBus(messageBus);
const confirmationPromise = invocation.shouldConfirmExecute(
new AbortController().signal,
);
// Fast-forward past timeout
await vi.advanceTimersByTimeAsync(30000);
const result = await confirmationPromise;
expect(result).not.toBe(false);
@@ -553,16 +502,12 @@ describe('WebFetchTool', () => {
});
it('should handle abort signal during confirmation', async () => {
const tool = new WebFetchTool(mockConfig, messageBus);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const { invocation } = createToolWithMessageBus(messageBus);
const abortController = new AbortController();
const confirmationPromise = invocation.shouldConfirmExecute(
abortController.signal,
);
// Abort the operation
abortController.abort();
await expect(confirmationPromise).rejects.toThrow(
@@ -571,42 +516,25 @@ describe('WebFetchTool', () => {
});
it('should fall back to legacy confirmation when no message bus', async () => {
const tool = new WebFetchTool(mockConfig); // No message bus
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const { invocation } = createToolWithMessageBus(); // No message bus
const result = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
// Should use legacy confirmation flow (returns confirmation details, not false)
expect(result).not.toBe(false);
expect(result).toHaveProperty('type', 'info');
});
it('should ignore responses with wrong correlation ID', async () => {
vi.useFakeTimers();
const tool = new WebFetchTool(mockConfig, messageBus);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const { invocation } = createToolWithMessageBus(messageBus);
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
const confirmationPromise = invocation.shouldConfirmExecute(
new AbortController().signal,
);
// Send response with wrong correlation ID
const responseHandler = subscribeSpy.mock.calls[0][1];
const wrongResponse: ToolConfirmationResponse = {
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'wrong-id',
confirmed: true,
};
simulateMessageBusResponse(subscribeSpy, true, 'wrong-id');
responseHandler(wrongResponse);
// Should timeout since correct response wasn't received
await vi.advanceTimersByTimeAsync(30000);
const result = await confirmationPromise;
expect(result).not.toBe(false);
@@ -616,11 +544,7 @@ describe('WebFetchTool', () => {
});
it('should handle message bus publish errors gracefully', async () => {
const tool = new WebFetchTool(mockConfig, messageBus);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
// Mock publish to throw error
const { invocation } = createToolWithMessageBus(messageBus);
vi.spyOn(messageBus, 'publish').mockImplementation(() => {
throw new Error('Message bus error');
});
@@ -628,7 +552,7 @@ describe('WebFetchTool', () => {
const result = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(result).toBe(false); // Should gracefully fall back
expect(result).toBe(false);
});
it('should execute normally after confirmation approval', async () => {
@@ -644,28 +568,17 @@ describe('WebFetchTool', () => {
],
});
const tool = new WebFetchTool(mockConfig, messageBus);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
const { invocation } = createToolWithMessageBus(messageBus);
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
// Start confirmation
const confirmationPromise = invocation.shouldConfirmExecute(
new AbortController().signal,
);
// Approve via message bus
const responseHandler = subscribeSpy.mock.calls[0][1];
responseHandler({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'test-correlation-id',
confirmed: true,
});
simulateMessageBusResponse(subscribeSpy, true);
await confirmationPromise;
// Execute the tool
const result = await invocation.execute(new AbortController().signal);
expect(result.error).toBeUndefined();
expect(result.llmContent).toContain('Fetched content');
+140 -171
View File
@@ -16,7 +16,12 @@ import {
import type { WriteFileToolParams } from './write-file.js';
import { getCorrectedFileContent, WriteFileTool } from './write-file.js';
import { ToolErrorType } from './tool-error.js';
import type { FileDiff, ToolEditConfirmationDetails } from './tools.js';
import type {
FileDiff,
ToolEditConfirmationDetails,
ToolInvocation,
ToolResult,
} from './tools.js';
import { ToolConfirmationOutcome } from './tools.js';
import { type EditToolParams } from './edit.js';
import type { Config } from '../config/config.js';
@@ -538,6 +543,7 @@ describe('WriteFileTool', () => {
});
it('should not await ideConfirmation promise', async () => {
const IDE_DIFF_DELAY_MS = 50;
const filePath = path.join(rootDir, 'ide_no_await_file.txt');
const params = { file_path: filePath, content: 'test' };
const invocation = tool.build(params);
@@ -547,7 +553,7 @@ describe('WriteFileTool', () => {
setTimeout(() => {
diffPromiseResolved = true;
resolve({ status: 'accepted', content: 'ide-modified-content' });
}, 50); // A small delay to ensure the check happens before resolution
}, IDE_DIFF_DELAY_MS);
});
mockIdeClient.openDiff.mockReturnValue(diffPromise);
@@ -571,6 +577,20 @@ describe('WriteFileTool', () => {
describe('execute', () => {
const abortSignal = new AbortController().signal;
async function confirmExecution(
invocation: ToolInvocation<WriteFileToolParams, ToolResult>,
signal: AbortSignal = abortSignal,
) {
const confirmDetails = await invocation.shouldConfirmExecute(signal);
if (
typeof confirmDetails === 'object' &&
'onConfirm' in confirmDetails &&
confirmDetails.onConfirm
) {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
}
it('should write a new file with a relative path', async () => {
const relativePath = 'execute_relative_new_file.txt';
const filePath = path.join(rootDir, relativePath);
@@ -624,14 +644,7 @@ describe('WriteFileTool', () => {
const params = { file_path: filePath, content: proposedContent };
const invocation = tool.build(params);
const confirmDetails = await invocation.shouldConfirmExecute(abortSignal);
if (
typeof confirmDetails === 'object' &&
'onConfirm' in confirmDetails &&
confirmDetails.onConfirm
) {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
await confirmExecution(invocation);
const result = await invocation.execute(abortSignal);
@@ -681,14 +694,7 @@ describe('WriteFileTool', () => {
const params = { file_path: filePath, content: proposedContent };
const invocation = tool.build(params);
const confirmDetails = await invocation.shouldConfirmExecute(abortSignal);
if (
typeof confirmDetails === 'object' &&
'onConfirm' in confirmDetails &&
confirmDetails.onConfirm
) {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
await confirmExecution(invocation);
const result = await invocation.execute(abortSignal);
@@ -725,15 +731,8 @@ describe('WriteFileTool', () => {
const params = { file_path: filePath, content };
const invocation = tool.build(params);
// Simulate confirmation if your logic requires it before execute, or remove if not needed for this path
const confirmDetails = await invocation.shouldConfirmExecute(abortSignal);
if (
typeof confirmDetails === 'object' &&
'onConfirm' in confirmDetails &&
confirmDetails.onConfirm
) {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
}
await confirmExecution(invocation);
await invocation.execute(abortSignal);
@@ -743,52 +742,44 @@ describe('WriteFileTool', () => {
expect(fs.readFileSync(filePath, 'utf8')).toBe(content);
});
it('should include modification message when proposed content is modified', async () => {
const filePath = path.join(rootDir, 'new_file_modified.txt');
const content = 'New file content modified by user';
mockEnsureCorrectFileContent.mockResolvedValue(content);
const params = {
file_path: filePath,
content,
it.each([
{
modified_by_user: true,
};
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).toMatch(/User modified the `content`/);
});
it('should not include modification message when proposed content is not modified', async () => {
const filePath = path.join(rootDir, 'new_file_unmodified.txt');
const content = 'New file content not modified';
mockEnsureCorrectFileContent.mockResolvedValue(content);
const params = {
file_path: filePath,
content,
shouldIncludeMessage: true,
testCase: 'when modified_by_user is true',
},
{
modified_by_user: false,
};
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
shouldIncludeMessage: false,
testCase: 'when modified_by_user is false',
},
{
modified_by_user: undefined,
shouldIncludeMessage: false,
testCase: 'when modified_by_user is not provided',
},
])(
'should $testCase include modification message',
async ({ modified_by_user, shouldIncludeMessage }) => {
const filePath = path.join(rootDir, `new_file_${modified_by_user}.txt`);
const content = 'New file content';
mockEnsureCorrectFileContent.mockResolvedValue(content);
expect(result.llmContent).not.toMatch(/User modified the `content`/);
});
const params: WriteFileToolParams = {
file_path: filePath,
content,
...(modified_by_user !== undefined && { modified_by_user }),
};
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
it('should not include modification message when modified_by_user is not provided', async () => {
const filePath = path.join(rootDir, 'new_file_unmodified.txt');
const content = 'New file content not modified';
mockEnsureCorrectFileContent.mockResolvedValue(content);
const params = {
file_path: filePath,
content,
};
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.llmContent).not.toMatch(/User modified the `content`/);
});
if (shouldIncludeMessage) {
expect(result.llmContent).toMatch(/User modified the `content`/);
} else {
expect(result.llmContent).not.toMatch(/User modified the `content`/);
}
},
);
});
describe('workspace boundary validation', () => {
@@ -814,114 +805,92 @@ describe('WriteFileTool', () => {
describe('specific error types for write failures', () => {
const abortSignal = new AbortController().signal;
it('should return PERMISSION_DENIED error when write fails with EACCES', async () => {
const filePath = path.join(rootDir, 'permission_denied_file.txt');
const content = 'test content';
it.each([
{
errorCode: 'EACCES',
errorType: ToolErrorType.PERMISSION_DENIED,
errorMessage: 'Permission denied',
expectedMessagePrefix: 'Permission denied writing to file',
mockFsExistsSync: false,
restoreAllMocks: false,
},
{
errorCode: 'ENOSPC',
errorType: ToolErrorType.NO_SPACE_LEFT,
errorMessage: 'No space left on device',
expectedMessagePrefix: 'No space left on device',
mockFsExistsSync: false,
restoreAllMocks: false,
},
{
errorCode: 'EISDIR',
errorType: ToolErrorType.TARGET_IS_DIRECTORY,
errorMessage: 'Is a directory',
expectedMessagePrefix: 'Target is a directory, not a file',
mockFsExistsSync: true,
restoreAllMocks: false,
},
{
errorCode: undefined,
errorType: ToolErrorType.FILE_WRITE_FAILURE,
errorMessage: 'Generic write error',
expectedMessagePrefix: 'Error writing to file',
mockFsExistsSync: false,
restoreAllMocks: true,
},
])(
'should return $errorType error when write fails with $errorCode',
async ({
errorCode,
errorType,
errorMessage,
expectedMessagePrefix,
mockFsExistsSync,
restoreAllMocks,
}) => {
const filePath = path.join(rootDir, `${errorType}_file.txt`);
const content = 'test content';
// Mock FileSystemService writeTextFile to throw EACCES error
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() => {
const error = new Error('Permission denied') as NodeJS.ErrnoException;
error.code = 'EACCES';
return Promise.reject(error);
});
const params = { file_path: filePath, content };
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.error?.type).toBe(ToolErrorType.PERMISSION_DENIED);
expect(result.llmContent).toContain(
`Permission denied writing to file: ${filePath} (EACCES)`,
);
expect(result.returnDisplay).toContain(
`Permission denied writing to file: ${filePath} (EACCES)`,
);
});
it('should return NO_SPACE_LEFT error when write fails with ENOSPC', async () => {
const filePath = path.join(rootDir, 'no_space_file.txt');
const content = 'test content';
// Mock FileSystemService writeTextFile to throw ENOSPC error
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() => {
const error = new Error(
'No space left on device',
) as NodeJS.ErrnoException;
error.code = 'ENOSPC';
return Promise.reject(error);
});
const params = { file_path: filePath, content };
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.error?.type).toBe(ToolErrorType.NO_SPACE_LEFT);
expect(result.llmContent).toContain(
`No space left on device: ${filePath} (ENOSPC)`,
);
expect(result.returnDisplay).toContain(
`No space left on device: ${filePath} (ENOSPC)`,
);
});
it('should return TARGET_IS_DIRECTORY error when write fails with EISDIR', async () => {
const dirPath = path.join(rootDir, 'test_directory');
const content = 'test content';
// Mock fs.existsSync to return false to bypass validation
const originalExistsSync = fs.existsSync;
vi.spyOn(fs, 'existsSync').mockImplementation((path) => {
if (path === dirPath) {
return false; // Pretend directory doesn't exist to bypass validation
if (restoreAllMocks) {
vi.restoreAllMocks();
}
return originalExistsSync(path as string);
});
// Mock FileSystemService writeTextFile to throw EISDIR error
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() => {
const error = new Error('Is a directory') as NodeJS.ErrnoException;
error.code = 'EISDIR';
return Promise.reject(error);
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let existsSyncSpy: any;
const params = { file_path: dirPath, content };
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
try {
if (mockFsExistsSync) {
const originalExistsSync = fs.existsSync;
existsSyncSpy = vi
.spyOn(fs, 'existsSync')
.mockImplementation((path) =>
path === filePath ? false : originalExistsSync(path as string),
);
}
expect(result.error?.type).toBe(ToolErrorType.TARGET_IS_DIRECTORY);
expect(result.llmContent).toContain(
`Target is a directory, not a file: ${dirPath} (EISDIR)`,
);
expect(result.returnDisplay).toContain(
`Target is a directory, not a file: ${dirPath} (EISDIR)`,
);
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() => {
const error = new Error(errorMessage) as NodeJS.ErrnoException;
if (errorCode) error.code = errorCode;
return Promise.reject(error);
});
vi.spyOn(fs, 'existsSync').mockImplementation(originalExistsSync);
});
const params = { file_path: filePath, content };
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
it('should return FILE_WRITE_FAILURE for generic write errors', async () => {
const filePath = path.join(rootDir, 'generic_error_file.txt');
const content = 'test content';
// Ensure fs.existsSync is not mocked for this test
vi.restoreAllMocks();
// Mock FileSystemService writeTextFile to throw generic error
vi.spyOn(fsService, 'writeTextFile').mockImplementationOnce(() =>
Promise.reject(new Error('Generic write error')),
);
const params = { file_path: filePath, content };
const invocation = tool.build(params);
const result = await invocation.execute(abortSignal);
expect(result.error?.type).toBe(ToolErrorType.FILE_WRITE_FAILURE);
expect(result.llmContent).toContain(
'Error writing to file: Generic write error',
);
expect(result.returnDisplay).toContain(
'Error writing to file: Generic write error',
);
});
expect(result.error?.type).toBe(errorType);
const errorSuffix = errorCode ? ` (${errorCode})` : '';
const expectedMessage = errorCode
? `${expectedMessagePrefix}: ${filePath}${errorSuffix}`
: `${expectedMessagePrefix}: ${errorMessage}`;
expect(result.llmContent).toContain(expectedMessage);
expect(result.returnDisplay).toContain(expectedMessage);
} finally {
if (existsSyncSpy) {
existsSyncSpy.mockRestore();
}
}
},
);
});
});