diff --git a/evals/save_memory.eval.ts b/evals/save_memory.eval.ts index 48658113ce..c1ab748edb 100644 --- a/evals/save_memory.eval.ts +++ b/evals/save_memory.eval.ts @@ -6,11 +6,16 @@ import { describe, expect } from 'vitest'; import { evalTest } from './test-helper.js'; -import { validateModelOutput } from '../integration-tests/test-helper.js'; +import { + assertModelHasOutput, + checkModelOutputContent, +} from '../integration-tests/test-helper.js'; describe('save_memory', () => { + const TEST_PREFIX = 'Save memory test: '; + const rememberingFavoriteColor = "Agent remembers user's favorite color"; evalTest('ALWAYS_PASSES', { - name: 'should be able to save to memory', + name: rememberingFavoriteColor, params: { settings: { tools: { core: ['save_memory'] } }, }, @@ -18,13 +23,217 @@ describe('save_memory', () => { what is my favorite color? tell me that and surround it with $ symbol`, assert: async (rig, result) => { - const foundToolCall = await rig.waitForToolCall('save_memory'); - expect( - foundToolCall, - 'Expected to find a save_memory tool call', - ).toBeTruthy(); + const wasToolCalled = await rig.waitForToolCall('save_memory'); + expect(wasToolCalled, 'Expected save_memory tool to be called').toBe( + true, + ); - validateModelOutput(result, 'blue', 'Save memory test'); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: 'blue', + testName: `${TEST_PREFIX}${rememberingFavoriteColor}`, + }); + }, + }); + const rememberingCommandRestrictions = 'Agent remembers command restrictions'; + evalTest('ALWAYS_PASSES', { + name: rememberingCommandRestrictions, + params: { + settings: { tools: { core: ['save_memory'] } }, + }, + prompt: `I don't want you to ever run npm commands.`, + assert: async (rig, result) => { + const wasToolCalled = await rig.waitForToolCall('save_memory'); + expect(wasToolCalled, 'Expected save_memory tool to be called').toBe( + true, + ); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/not run npm commands|remember|ok/i], + testName: `${TEST_PREFIX}${rememberingCommandRestrictions}`, + }); + }, + }); + + const rememberingWorkflow = 'Agent remembers workflow preferences'; + evalTest('ALWAYS_PASSES', { + name: rememberingWorkflow, + params: { + settings: { tools: { core: ['save_memory'] } }, + }, + prompt: `I want you to always lint after building.`, + assert: async (rig, result) => { + const wasToolCalled = await rig.waitForToolCall('save_memory'); + expect(wasToolCalled, 'Expected save_memory tool to be called').toBe( + true, + ); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/always|ok|remember|will do/i], + testName: `${TEST_PREFIX}${rememberingWorkflow}`, + }); + }, + }); + + const ignoringTemporaryInformation = + 'Agent ignores temporary conversation details'; + evalTest('ALWAYS_PASSES', { + name: ignoringTemporaryInformation, + params: { + settings: { tools: { core: ['save_memory'] } }, + }, + prompt: `I'm going to get a coffee.`, + assert: async (rig, result) => { + await rig.waitForTelemetryReady(); + const wasToolCalled = rig + .readToolLogs() + .some((log) => log.toolRequest.name === 'save_memory'); + expect( + wasToolCalled, + 'save_memory should not be called for temporary information', + ).toBe(false); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + testName: `${TEST_PREFIX}${ignoringTemporaryInformation}`, + forbiddenContent: [/remember|will do/i], + }); + }, + }); + + const rememberingPetName = "Agent remembers user's pet's name"; + evalTest('ALWAYS_PASSES', { + name: rememberingPetName, + params: { + settings: { tools: { core: ['save_memory'] } }, + }, + prompt: `My dog's name is Buddy. What is my dog's name?`, + assert: async (rig, result) => { + const wasToolCalled = await rig.waitForToolCall('save_memory'); + expect(wasToolCalled, 'Expected save_memory tool to be called').toBe( + true, + ); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/Buddy/i], + testName: `${TEST_PREFIX}${rememberingPetName}`, + }); + }, + }); + + const rememberingCommandAlias = 'Agent remembers custom command aliases'; + evalTest('ALWAYS_PASSES', { + name: rememberingCommandAlias, + params: { + settings: { tools: { core: ['save_memory'] } }, + }, + prompt: `When I say 'start server', you should run 'npm run dev'.`, + assert: async (rig, result) => { + const wasToolCalled = await rig.waitForToolCall('save_memory'); + expect(wasToolCalled, 'Expected save_memory tool to be called').toBe( + true, + ); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/npm run dev|start server|ok|remember|will do/i], + testName: `${TEST_PREFIX}${rememberingCommandAlias}`, + }); + }, + }); + + const rememberingDbSchemaLocation = + "Agent remembers project's database schema location"; + evalTest('ALWAYS_PASSES', { + name: rememberingDbSchemaLocation, + params: { + settings: { tools: { core: ['save_memory'] } }, + }, + prompt: `The database schema for this project is located in \`db/schema.sql\`.`, + assert: async (rig, result) => { + const wasToolCalled = await rig.waitForToolCall('save_memory'); + expect(wasToolCalled, 'Expected save_memory tool to be called').toBe( + true, + ); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/database schema|ok|remember|will do/i], + testName: `${TEST_PREFIX}${rememberingDbSchemaLocation}`, + }); + }, + }); + + const rememberingCodingStyle = + "Agent remembers user's coding style preference"; + evalTest('ALWAYS_PASSES', { + name: rememberingCodingStyle, + params: { + settings: { tools: { core: ['save_memory'] } }, + }, + prompt: `I prefer to use tabs instead of spaces for indentation.`, + assert: async (rig, result) => { + const wasToolCalled = await rig.waitForToolCall('save_memory'); + expect(wasToolCalled, 'Expected save_memory tool to be called').toBe( + true, + ); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/tabs instead of spaces|ok|remember|will do/i], + testName: `${TEST_PREFIX}${rememberingCodingStyle}`, + }); + }, + }); + + const rememberingTestCommand = + 'Agent remembers specific project test command'; + evalTest('ALWAYS_PASSES', { + name: rememberingTestCommand, + params: { + settings: { tools: { core: ['save_memory'] } }, + }, + prompt: `The command to run all backend tests is \`npm run test:backend\`.`, + assert: async (rig, result) => { + const wasToolCalled = await rig.waitForToolCall('save_memory'); + expect(wasToolCalled, 'Expected save_memory tool to be called').toBe( + true, + ); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [ + /command to run all backend tests|ok|remember|will do/i, + ], + testName: `${TEST_PREFIX}${rememberingTestCommand}`, + }); + }, + }); + + const rememberingMainEntryPoint = + "Agent remembers project's main entry point"; + evalTest('ALWAYS_PASSES', { + name: rememberingMainEntryPoint, + params: { + settings: { tools: { core: ['save_memory'] } }, + }, + prompt: `The main entry point for this project is \`src/index.js\`.`, + assert: async (rig, result) => { + const wasToolCalled = await rig.waitForToolCall('save_memory'); + expect(wasToolCalled, 'Expected save_memory tool to be called').toBe( + true, + ); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [ + /main entry point for this project|ok|remember|will do/i, + ], + testName: `${TEST_PREFIX}${rememberingMainEntryPoint}`, + }); }, }); }); diff --git a/integration-tests/file-system.test.ts b/integration-tests/file-system.test.ts index a1041acfcd..bdcffedaf8 100644 --- a/integration-tests/file-system.test.ts +++ b/integration-tests/file-system.test.ts @@ -7,7 +7,12 @@ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; import { existsSync } from 'node:fs'; import * as path from 'node:path'; -import { TestRig, printDebugInfo, validateModelOutput } from './test-helper.js'; +import { + TestRig, + printDebugInfo, + assertModelHasOutput, + checkModelOutputContent, +} from './test-helper.js'; describe('file-system', () => { let rig: TestRig; @@ -43,8 +48,11 @@ describe('file-system', () => { 'Expected to find a read_file tool call', ).toBeTruthy(); - // Validate model output - will throw if no output, warn if missing expected content - validateModelOutput(result, 'hello world', 'File read test'); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: 'hello world', + testName: 'File read test', + }); }); it('should be able to write a file', async () => { @@ -74,8 +82,8 @@ describe('file-system', () => { 'Expected to find a write_file, edit, or replace tool call', ).toBeTruthy(); - // Validate model output - will throw if no output - validateModelOutput(result, null, 'File write test'); + assertModelHasOutput(result); + checkModelOutputContent(result, { testName: 'File write test' }); const fileContent = rig.readFile('test.txt'); diff --git a/integration-tests/google_web_search.test.ts b/integration-tests/google_web_search.test.ts index 391d4a7ec4..dc19d2df90 100644 --- a/integration-tests/google_web_search.test.ts +++ b/integration-tests/google_web_search.test.ts @@ -6,7 +6,12 @@ import { WEB_SEARCH_TOOL_NAME } from '../packages/core/src/tools/tool-names.js'; import { describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { TestRig, printDebugInfo, validateModelOutput } from './test-helper.js'; +import { + TestRig, + printDebugInfo, + assertModelHasOutput, + checkModelOutputContent, +} from './test-helper.js'; describe('web search tool', () => { let rig: TestRig; @@ -68,12 +73,11 @@ describe('web search tool', () => { `Expected to find a call to ${WEB_SEARCH_TOOL_NAME}`, ).toBeTruthy(); - // Validate model output - will throw if no output, warn if missing expected content - const hasExpectedContent = validateModelOutput( - result, - ['weather', 'london'], - 'Google web search test', - ); + assertModelHasOutput(result); + const hasExpectedContent = checkModelOutputContent(result, { + expectedContent: ['weather', 'london'], + testName: 'Google web search test', + }); // If content was missing, log the search queries used if (!hasExpectedContent) { diff --git a/integration-tests/list_directory.test.ts b/integration-tests/list_directory.test.ts index 2a9b34fee1..327cf1f33b 100644 --- a/integration-tests/list_directory.test.ts +++ b/integration-tests/list_directory.test.ts @@ -9,7 +9,8 @@ import { TestRig, poll, printDebugInfo, - validateModelOutput, + assertModelHasOutput, + checkModelOutputContent, } from './test-helper.js'; import { existsSync } from 'node:fs'; import { join } from 'node:path'; @@ -68,7 +69,10 @@ describe('list_directory', () => { throw e; } - // Validate model output - will throw if no output, warn if missing expected content - validateModelOutput(result, ['file1.txt', 'subdir'], 'List directory test'); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: ['file1.txt', 'subdir'], + testName: 'List directory test', + }); }); }); diff --git a/integration-tests/read_many_files.test.ts b/integration-tests/read_many_files.test.ts index cd1c096f65..6988d8a165 100644 --- a/integration-tests/read_many_files.test.ts +++ b/integration-tests/read_many_files.test.ts @@ -5,7 +5,12 @@ */ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { TestRig, printDebugInfo, validateModelOutput } from './test-helper.js'; +import { + TestRig, + printDebugInfo, + assertModelHasOutput, + checkModelOutputContent, +} from './test-helper.js'; describe('read_many_files', () => { let rig: TestRig; @@ -50,7 +55,7 @@ describe('read_many_files', () => { 'Expected to find either read_many_files or multiple read_file tool calls', ).toBeTruthy(); - // Validate model output - will throw if no output - validateModelOutput(result, null, 'Read many files test'); + assertModelHasOutput(result); + checkModelOutputContent(result, { testName: 'Read many files test' }); }); }); diff --git a/integration-tests/run_shell_command.test.ts b/integration-tests/run_shell_command.test.ts index 027f4cba8d..0587bb30df 100644 --- a/integration-tests/run_shell_command.test.ts +++ b/integration-tests/run_shell_command.test.ts @@ -5,7 +5,12 @@ */ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { TestRig, printDebugInfo, validateModelOutput } from './test-helper.js'; +import { + TestRig, + printDebugInfo, + assertModelHasOutput, + checkModelOutputContent, +} from './test-helper.js'; import { getShellConfiguration } from '../packages/core/src/utils/shell-utils.js'; const { shell } = getShellConfiguration(); @@ -115,13 +120,11 @@ describe('run_shell_command', () => { 'Expected to find a run_shell_command tool call', ).toBeTruthy(); - // Validate model output - will throw if no output, warn if missing expected content - // Model often reports exit code instead of showing output - validateModelOutput( - result, - ['hello-world', 'exit code 0'], - 'Shell command test', - ); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: ['hello-world', 'exit code 0'], + testName: 'Shell command test', + }); }); it('should be able to run a shell command via stdin', async () => { @@ -149,8 +152,11 @@ describe('run_shell_command', () => { 'Expected to find a run_shell_command tool call', ).toBeTruthy(); - // Validate model output - will throw if no output, warn if missing expected content - validateModelOutput(result, 'test-stdin', 'Shell command stdin test'); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: 'test-stdin', + testName: 'Shell command stdin test', + }); }); it.skip('should run allowed sub-command in non-interactive mode', async () => { @@ -494,12 +500,11 @@ describe('run_shell_command', () => { )[0]; expect(toolCall.toolRequest.success).toBe(true); - // Validate model output - will throw if no output, warn if missing expected content - validateModelOutput( - result, - 'test-allow-all', - 'Shell command stdin allow all', - ); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: 'test-allow-all', + testName: 'Shell command stdin allow all', + }); }); it('should propagate environment variables to the child process', async () => { @@ -528,7 +533,11 @@ describe('run_shell_command', () => { foundToolCall, 'Expected to find a run_shell_command tool call', ).toBeTruthy(); - validateModelOutput(result, varValue, 'Env var propagation test'); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: varValue, + testName: 'Env var propagation test', + }); expect(result).toContain(varValue); } finally { delete process.env[varName]; @@ -558,7 +567,11 @@ describe('run_shell_command', () => { 'Expected to find a run_shell_command tool call', ).toBeTruthy(); - validateModelOutput(result, fileName, 'Platform-specific listing test'); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: fileName, + testName: 'Platform-specific listing test', + }); expect(result).toContain(fileName); }); diff --git a/integration-tests/simple-mcp-server.test.ts b/integration-tests/simple-mcp-server.test.ts index 6db9927616..a489a00d72 100644 --- a/integration-tests/simple-mcp-server.test.ts +++ b/integration-tests/simple-mcp-server.test.ts @@ -11,7 +11,12 @@ */ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { TestRig, poll, validateModelOutput } from './test-helper.js'; +import { + TestRig, + poll, + assertModelHasOutput, + checkModelOutputContent, +} from './test-helper.js'; import { join } from 'node:path'; import { writeFileSync } from 'node:fs'; @@ -226,8 +231,11 @@ describe.skip('simple-mcp-server', () => { expect(foundToolCall, 'Expected to find an add tool call').toBeTruthy(); - // Validate model output - will throw if no output, fail if missing expected content - validateModelOutput(output, '15', 'MCP server test'); + assertModelHasOutput(output); + checkModelOutputContent(output, { + expectedContent: '15', + testName: 'MCP server test', + }); expect( output.includes('15'), 'Expected output to contain the sum (15)', diff --git a/integration-tests/stdin-context.test.ts b/integration-tests/stdin-context.test.ts index 41d1e7772b..8f304e25a7 100644 --- a/integration-tests/stdin-context.test.ts +++ b/integration-tests/stdin-context.test.ts @@ -5,7 +5,12 @@ */ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { TestRig, printDebugInfo, validateModelOutput } from './test-helper.js'; +import { + TestRig, + printDebugInfo, + assertModelHasOutput, + checkModelOutputContent, +} from './test-helper.js'; describe.skip('stdin context', () => { let rig: TestRig; @@ -67,7 +72,11 @@ describe.skip('stdin context', () => { } // Validate model output - validateModelOutput(result, randomString, 'STDIN context test'); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: randomString, + testName: 'STDIN context test', + }); expect( result.toLowerCase().includes(randomString), diff --git a/integration-tests/write_file.test.ts b/integration-tests/write_file.test.ts index 209f098add..8069b1ca87 100644 --- a/integration-tests/write_file.test.ts +++ b/integration-tests/write_file.test.ts @@ -9,7 +9,8 @@ import { TestRig, createToolCallErrorMessage, printDebugInfo, - validateModelOutput, + assertModelHasOutput, + checkModelOutputContent, } from './test-helper.js'; describe('write_file', () => { @@ -46,8 +47,11 @@ describe('write_file', () => { ), ).toBeTruthy(); - // Validate model output - will throw if no output, warn if missing expected content - validateModelOutput(result, 'dad.txt', 'Write file test'); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: 'dad.txt', + testName: 'Write file test', + }); const newFilePath = 'dad.txt'; diff --git a/packages/core/src/tools/memoryTool.test.ts b/packages/core/src/tools/memoryTool.test.ts index 4581b19232..6a3e03d8e5 100644 --- a/packages/core/src/tools/memoryTool.test.ts +++ b/packages/core/src/tools/memoryTool.test.ts @@ -25,12 +25,13 @@ import { } from '../test-utils/mock-message-bus.js'; // Mock dependencies -vi.mock(import('node:fs/promises'), async (importOriginal) => { +vi.mock('node:fs/promises', async (importOriginal) => { const actual = await importOriginal(); return { - ...actual, + ...(actual as object), mkdir: vi.fn(), readFile: vi.fn(), + writeFile: vi.fn(), }; }); @@ -42,41 +43,25 @@ vi.mock('os'); const MEMORY_SECTION_HEADER = '## Gemini Added Memories'; -// Define a type for our fsAdapter to ensure consistency -interface FsAdapter { - readFile: (path: string, encoding: 'utf-8') => Promise; - writeFile: (path: string, data: string, encoding: 'utf-8') => Promise; - mkdir: ( - path: string, - options: { recursive: boolean }, - ) => Promise; -} - describe('MemoryTool', () => { const mockAbortSignal = new AbortController().signal; - const mockFsAdapter: { - readFile: Mock; - writeFile: Mock; - mkdir: Mock; - } = { - readFile: vi.fn(), - writeFile: vi.fn(), - mkdir: vi.fn(), - }; - beforeEach(() => { vi.mocked(os.homedir).mockReturnValue(path.join('/mock', 'home')); - mockFsAdapter.readFile.mockReset(); - mockFsAdapter.writeFile.mockReset().mockResolvedValue(undefined); - mockFsAdapter.mkdir - .mockReset() - .mockResolvedValue(undefined as string | undefined); + vi.mocked(fs.mkdir).mockReset().mockResolvedValue(undefined); + vi.mocked(fs.readFile).mockReset().mockResolvedValue(''); + vi.mocked(fs.writeFile).mockReset().mockResolvedValue(undefined); + + // Clear the static allowlist before every single test to prevent pollution. + // We need to create a dummy tool and invocation to get access to the static property. + const tool = new MemoryTool(createMockMessageBus()); + const invocation = tool.build({ fact: 'dummy' }); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (invocation.constructor as any).allowlist.clear(); }); afterEach(() => { vi.restoreAllMocks(); - // Reset GEMINI_MD_FILENAME to its original value after each test setGeminiMdFilename(DEFAULT_CONTEXT_FILENAME); }); @@ -88,7 +73,7 @@ describe('MemoryTool', () => { }); it('should not update currentGeminiMdFilename if the new name is empty or whitespace', () => { - const initialName = getCurrentGeminiMdFilename(); // Get current before trying to change + const initialName = getCurrentGeminiMdFilename(); setGeminiMdFilename(' '); expect(getCurrentGeminiMdFilename()).toBe(initialName); @@ -104,114 +89,13 @@ describe('MemoryTool', () => { }); }); - describe('performAddMemoryEntry (static method)', () => { - let testFilePath: string; - - beforeEach(() => { - testFilePath = path.join( - os.homedir(), - GEMINI_DIR, - DEFAULT_CONTEXT_FILENAME, - ); - }); - - it('should create section and save a fact if file does not exist', async () => { - mockFsAdapter.readFile.mockRejectedValue({ code: 'ENOENT' }); // Simulate file not found - const fact = 'The sky is blue'; - await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter); - - expect(mockFsAdapter.mkdir).toHaveBeenCalledWith( - path.dirname(testFilePath), - { - recursive: true, - }, - ); - expect(mockFsAdapter.writeFile).toHaveBeenCalledOnce(); - const writeFileCall = mockFsAdapter.writeFile.mock.calls[0]; - expect(writeFileCall[0]).toBe(testFilePath); - const expectedContent = `${MEMORY_SECTION_HEADER}\n- ${fact}\n`; - expect(writeFileCall[1]).toBe(expectedContent); - expect(writeFileCall[2]).toBe('utf-8'); - }); - - it('should create section and save a fact if file is empty', async () => { - mockFsAdapter.readFile.mockResolvedValue(''); // Simulate empty file - const fact = 'The sky is blue'; - await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter); - const writeFileCall = mockFsAdapter.writeFile.mock.calls[0]; - const expectedContent = `${MEMORY_SECTION_HEADER}\n- ${fact}\n`; - expect(writeFileCall[1]).toBe(expectedContent); - }); - - it('should add a fact to an existing section', async () => { - const initialContent = `Some preamble.\n\n${MEMORY_SECTION_HEADER}\n- Existing fact 1\n`; - mockFsAdapter.readFile.mockResolvedValue(initialContent); - const fact = 'New fact 2'; - await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter); - - expect(mockFsAdapter.writeFile).toHaveBeenCalledOnce(); - const writeFileCall = mockFsAdapter.writeFile.mock.calls[0]; - const expectedContent = `Some preamble.\n\n${MEMORY_SECTION_HEADER}\n- Existing fact 1\n- ${fact}\n`; - expect(writeFileCall[1]).toBe(expectedContent); - }); - - it('should add a fact to an existing empty section', async () => { - const initialContent = `Some preamble.\n\n${MEMORY_SECTION_HEADER}\n`; // Empty section - mockFsAdapter.readFile.mockResolvedValue(initialContent); - const fact = 'First fact in section'; - await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter); - - expect(mockFsAdapter.writeFile).toHaveBeenCalledOnce(); - const writeFileCall = mockFsAdapter.writeFile.mock.calls[0]; - const expectedContent = `Some preamble.\n\n${MEMORY_SECTION_HEADER}\n- ${fact}\n`; - expect(writeFileCall[1]).toBe(expectedContent); - }); - - it('should add a fact when other ## sections exist and preserve spacing', async () => { - const initialContent = `${MEMORY_SECTION_HEADER}\n- Fact 1\n\n## Another Section\nSome other text.`; - mockFsAdapter.readFile.mockResolvedValue(initialContent); - const fact = 'Fact 2'; - await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter); - - expect(mockFsAdapter.writeFile).toHaveBeenCalledOnce(); - const writeFileCall = mockFsAdapter.writeFile.mock.calls[0]; - // Note: The implementation ensures a single newline at the end if content exists. - const expectedContent = `${MEMORY_SECTION_HEADER}\n- Fact 1\n- ${fact}\n\n## Another Section\nSome other text.\n`; - expect(writeFileCall[1]).toBe(expectedContent); - }); - - it('should correctly trim and add a fact that starts with a dash', async () => { - mockFsAdapter.readFile.mockResolvedValue(`${MEMORY_SECTION_HEADER}\n`); - const fact = '- - My fact with dashes'; - await MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter); - const writeFileCall = mockFsAdapter.writeFile.mock.calls[0]; - const expectedContent = `${MEMORY_SECTION_HEADER}\n- My fact with dashes\n`; - expect(writeFileCall[1]).toBe(expectedContent); - }); - - it('should handle error from fsAdapter.writeFile', async () => { - mockFsAdapter.readFile.mockResolvedValue(''); - mockFsAdapter.writeFile.mockRejectedValue(new Error('Disk full')); - const fact = 'This will fail'; - await expect( - MemoryTool.performAddMemoryEntry(fact, testFilePath, mockFsAdapter), - ).rejects.toThrow('[MemoryTool] Failed to add memory entry: Disk full'); - }); - }); - describe('execute (instance method)', () => { let memoryTool: MemoryTool; - let performAddMemoryEntrySpy: Mock; beforeEach(() => { - memoryTool = new MemoryTool(createMockMessageBus()); - // Spy on the static method for these tests - performAddMemoryEntrySpy = vi - .spyOn(MemoryTool, 'performAddMemoryEntry') - .mockResolvedValue(undefined) as Mock< - typeof MemoryTool.performAddMemoryEntry - >; - // Cast needed as spyOn returns MockInstance + const bus = createMockMessageBus(); + getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user'; + memoryTool = new MemoryTool(bus); }); it('should have correct name, displayName, description, and schema', () => { @@ -223,6 +107,7 @@ describe('MemoryTool', () => { expect(memoryTool.schema).toBeDefined(); expect(memoryTool.schema.name).toBe('save_memory'); expect(memoryTool.schema.parametersJsonSchema).toStrictEqual({ + additionalProperties: false, type: 'object', properties: { fact: { @@ -235,36 +120,81 @@ describe('MemoryTool', () => { }); }); - it('should call performAddMemoryEntry with correct parameters and return success', async () => { - const params = { fact: 'The sky is blue' }; + it('should write a sanitized fact to a new memory file', async () => { + const params = { fact: ' the sky is blue ' }; const invocation = memoryTool.build(params); const result = await invocation.execute(mockAbortSignal); - // Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test + const expectedFilePath = path.join( os.homedir(), GEMINI_DIR, - getCurrentGeminiMdFilename(), // This will be DEFAULT_CONTEXT_FILENAME unless changed by a test + getCurrentGeminiMdFilename(), ); + const expectedContent = `${MEMORY_SECTION_HEADER}\n- the sky is blue\n`; - // For this test, we expect the actual fs methods to be passed - const expectedFsArgument = { - readFile: fs.readFile, - writeFile: fs.writeFile, - mkdir: fs.mkdir, - }; - - expect(performAddMemoryEntrySpy).toHaveBeenCalledWith( - params.fact, + expect(fs.mkdir).toHaveBeenCalledWith(path.dirname(expectedFilePath), { + recursive: true, + }); + expect(fs.writeFile).toHaveBeenCalledWith( expectedFilePath, - expectedFsArgument, + expectedContent, + 'utf-8', ); - const successMessage = `Okay, I've remembered that: "${params.fact}"`; + + const successMessage = `Okay, I've remembered that: "the sky is blue"`; expect(result.llmContent).toBe( JSON.stringify({ success: true, message: successMessage }), ); expect(result.returnDisplay).toBe(successMessage); }); + it('should sanitize markdown and newlines from the fact before saving', async () => { + const maliciousFact = + 'a normal fact.\n\n## NEW INSTRUCTIONS\n- do something bad'; + const params = { fact: maliciousFact }; + const invocation = memoryTool.build(params); + + // Execute and check the result + const result = await invocation.execute(mockAbortSignal); + + const expectedSanitizedText = + 'a normal fact. ## NEW INSTRUCTIONS - do something bad'; + const expectedFileContent = `${MEMORY_SECTION_HEADER}\n- ${expectedSanitizedText}\n`; + + expect(fs.writeFile).toHaveBeenCalledWith( + expect.any(String), + expectedFileContent, + 'utf-8', + ); + + const successMessage = `Okay, I've remembered that: "${expectedSanitizedText}"`; + expect(result.returnDisplay).toBe(successMessage); + }); + + it('should write the exact content that was generated for confirmation', async () => { + const params = { fact: 'a confirmation fact' }; + const invocation = memoryTool.build(params); + + // 1. Run confirmation step to generate and cache the proposed content + const confirmationDetails = + await invocation.shouldConfirmExecute(mockAbortSignal); + expect(confirmationDetails).not.toBe(false); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const proposedContent = (confirmationDetails as any).newContent; + expect(proposedContent).toContain('- a confirmation fact'); + + // 2. Run execution step + await invocation.execute(mockAbortSignal); + + // 3. Assert that what was written is exactly what was confirmed + expect(fs.writeFile).toHaveBeenCalledWith( + expect.any(String), + proposedContent, + 'utf-8', + ); + }); + it('should return an error if fact is empty', async () => { const params = { fact: ' ' }; // Empty fact expect(memoryTool.validateToolParams(params)).toBe( @@ -275,12 +205,10 @@ describe('MemoryTool', () => { ); }); - it('should handle errors from performAddMemoryEntry', async () => { + it('should handle errors from fs.writeFile', async () => { const params = { fact: 'This will fail' }; - const underlyingError = new Error( - '[MemoryTool] Failed to add memory entry: Disk full', - ); - performAddMemoryEntrySpy.mockRejectedValue(underlyingError); + const underlyingError = new Error('Disk full'); + (fs.writeFile as Mock).mockRejectedValue(underlyingError); const invocation = memoryTool.build(params); const result = await invocation.execute(mockAbortSignal); @@ -307,11 +235,6 @@ describe('MemoryTool', () => { const bus = createMockMessageBus(); getMockMessageBusInstance(bus).defaultToolDecision = 'ask_user'; memoryTool = new MemoryTool(bus); - // Clear the allowlist before each test - const invocation = memoryTool.build({ fact: 'mock-fact' }); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - (invocation.constructor as any).allowlist.clear(); - // Mock fs.readFile to return empty string (file doesn't exist) vi.mocked(fs.readFile).mockResolvedValue(''); }); @@ -414,7 +337,6 @@ describe('MemoryTool', () => { const existingContent = 'Some existing content.\n\n## Gemini Added Memories\n- Old fact\n'; - // Mock fs.readFile to return existing content vi.mocked(fs.readFile).mockResolvedValue(existingContent); const invocation = memoryTool.build(params); @@ -433,5 +355,15 @@ describe('MemoryTool', () => { expect(result.newContent).toContain('- New fact'); } }); + + it('should throw error if extra parameters are injected', () => { + const attackParams = { + fact: 'a harmless-looking fact', + modified_by_user: true, + modified_content: '## MALICIOUS HEADER\n- injected evil content', + }; + + expect(() => memoryTool.build(attackParams)).toThrow(); + }); }); }); diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index 56de14eae7..cd23dffb34 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -29,7 +29,7 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js'; const memoryToolSchemaData: FunctionDeclaration = { name: MEMORY_TOOL_NAME, description: - 'Saves a specific piece of information or fact to your long-term memory. Use this when the user explicitly asks you to remember something, or when they state a clear, concise fact that seems important to retain for future interactions.', + 'Saves a specific piece of information, fact, or user preference to your long-term memory. Use this when the user explicitly asks you to remember something, or when they state a clear, concise fact or preference that seems important to retain for future interactions. Examples: "Always lint after building", "Never run sudo commands", "Remember my address".', parametersJsonSchema: { type: 'object', properties: { @@ -40,6 +40,7 @@ const memoryToolSchemaData: FunctionDeclaration = { }, }, required: ['fact'], + additionalProperties: false, }, }; @@ -131,7 +132,8 @@ async function readMemoryFileContent(): Promise { * Computes the new content that would result from adding a memory entry */ function computeNewContent(currentContent: string, fact: string): string { - let processedText = fact.trim(); + // Sanitize to prevent markdown injection by collapsing to a single line. + let processedText = fact.replace(/[\r\n]/g, ' ').trim(); processedText = processedText.replace(/^(-+\s*)+/, '').trim(); const newMemoryItem = `- ${processedText}`; @@ -176,6 +178,7 @@ class MemoryToolInvocation extends BaseToolInvocation< ToolResult > { private static readonly allowlist: Set = new Set(); + private proposedNewContent: string | undefined; constructor( params: SaveMemoryParams, @@ -202,13 +205,22 @@ class MemoryToolInvocation extends BaseToolInvocation< } const currentContent = await readMemoryFileContent(); - const newContent = computeNewContent(currentContent, this.params.fact); + const { fact, modified_by_user, modified_content } = this.params; + + // If an attacker injects modified_content, use it for the diff + // to expose the attack to the user. Otherwise, compute from 'fact'. + const contentForDiff = + modified_by_user && modified_content !== undefined + ? modified_content + : computeNewContent(currentContent, fact); + + this.proposedNewContent = contentForDiff; const fileName = path.basename(memoryFilePath); const fileDiff = Diff.createPatch( fileName, currentContent, - newContent, + this.proposedNewContent, 'Current', 'Proposed', DEFAULT_DIFF_OPTIONS, @@ -221,7 +233,7 @@ class MemoryToolInvocation extends BaseToolInvocation< filePath: memoryFilePath, fileDiff, originalContent: currentContent, - newContent, + newContent: this.proposedNewContent, onConfirm: async (outcome: ToolConfirmationOutcome) => { if (outcome === ToolConfirmationOutcome.ProceedAlways) { MemoryToolInvocation.allowlist.add(allowlistKey); @@ -236,44 +248,43 @@ class MemoryToolInvocation extends BaseToolInvocation< const { fact, modified_by_user, modified_content } = this.params; try { + let contentToWrite: string; + let successMessage: string; + + // Sanitize the fact for use in the success message, matching the sanitization + // that happened inside computeNewContent. + const sanitizedFact = fact.replace(/[\r\n]/g, ' ').trim(); + if (modified_by_user && modified_content !== undefined) { - // User modified the content in external editor, write it directly - await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), { - recursive: true, - }); - await fs.writeFile( - getGlobalMemoryFilePath(), - modified_content, - 'utf-8', - ); - const successMessage = `Okay, I've updated the memory file with your modifications.`; - return { - llmContent: JSON.stringify({ - success: true, - message: successMessage, - }), - returnDisplay: successMessage, - }; + // User modified the content, so that is the source of truth. + contentToWrite = modified_content; + successMessage = `Okay, I've updated the memory file with your modifications.`; } else { - // Use the normal memory entry logic - await MemoryTool.performAddMemoryEntry( - fact, - getGlobalMemoryFilePath(), - { - readFile: fs.readFile, - writeFile: fs.writeFile, - mkdir: fs.mkdir, - }, - ); - const successMessage = `Okay, I've remembered that: "${fact}"`; - return { - llmContent: JSON.stringify({ - success: true, - message: successMessage, - }), - returnDisplay: successMessage, - }; + // User approved the proposed change without modification. + // The source of truth is the exact content proposed during confirmation. + if (this.proposedNewContent === undefined) { + // This case can be hit in flows without a confirmation step (e.g., --auto-confirm). + // As a fallback, we recompute the content now. This is safe because + // computeNewContent sanitizes the input. + const currentContent = await readMemoryFileContent(); + this.proposedNewContent = computeNewContent(currentContent, fact); + } + contentToWrite = this.proposedNewContent; + successMessage = `Okay, I've remembered that: "${sanitizedFact}"`; } + + await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), { + recursive: true, + }); + await fs.writeFile(getGlobalMemoryFilePath(), contentToWrite, 'utf-8'); + + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error); @@ -335,41 +346,6 @@ export class MemoryTool ); } - static async performAddMemoryEntry( - text: string, - memoryFilePath: string, - fsAdapter: { - readFile: (path: string, encoding: 'utf-8') => Promise; - writeFile: ( - path: string, - data: string, - encoding: 'utf-8', - ) => Promise; - mkdir: ( - path: string, - options: { recursive: boolean }, - ) => Promise; - }, - ): Promise { - try { - await fsAdapter.mkdir(path.dirname(memoryFilePath), { recursive: true }); - let currentContent = ''; - try { - currentContent = await fsAdapter.readFile(memoryFilePath, 'utf-8'); - } catch (_e) { - // File doesn't exist, which is fine. currentContent will be empty. - } - - const newContent = computeNewContent(currentContent, text); - - await fsAdapter.writeFile(memoryFilePath, newContent, 'utf-8'); - } catch (error) { - throw new Error( - `[MemoryTool] Failed to add memory entry: ${error instanceof Error ? error.message : String(error)}`, - ); - } - } - getModifyContext(_abortSignal: AbortSignal): ModifyContext { return { getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(), @@ -377,7 +353,12 @@ export class MemoryTool readMemoryFileContent(), getProposedContent: async (params: SaveMemoryParams): Promise => { const currentContent = await readMemoryFileContent(); - return computeNewContent(currentContent, params.fact); + const { fact, modified_by_user, modified_content } = params; + // Ensure the editor is populated with the same content + // that the confirmation diff would show. + return modified_by_user && modified_content !== undefined + ? modified_content + : computeNewContent(currentContent, fact); }, createUpdatedParams: ( _oldContent: string, diff --git a/packages/test-utils/src/test-rig.ts b/packages/test-utils/src/test-rig.ts index 99f22817c2..2caca1d66d 100644 --- a/packages/test-utils/src/test-rig.ts +++ b/packages/test-utils/src/test-rig.ts @@ -105,51 +105,91 @@ export function printDebugInfo( return allTools; } -// Helper to validate model output and warn about unexpected content -export function validateModelOutput( - result: string, - expectedContent: string | (string | RegExp)[] | null = null, - testName = '', -) { - // First, check if there's any output at all (this should fail the test if missing) +// Helper to assert that the model returned some output +export function assertModelHasOutput(result: string) { if (!result || result.trim().length === 0) { throw new Error('Expected LLM to return some output'); } +} + +function contentExists(result: string, content: string | RegExp): boolean { + if (typeof content === 'string') { + return result.toLowerCase().includes(content.toLowerCase()); + } else if (content instanceof RegExp) { + return content.test(result); + } + return false; +} + +function findMismatchedContent( + result: string, + content: string | (string | RegExp)[], + shouldExist: boolean, +): (string | RegExp)[] { + const contents = Array.isArray(content) ? content : [content]; + return contents.filter((c) => contentExists(result, c) !== shouldExist); +} + +function logContentWarning( + problematicContent: (string | RegExp)[], + isMissing: boolean, + originalContent: string | (string | RegExp)[] | null | undefined, + result: string, +) { + const message = isMissing + ? 'LLM did not include expected content in response' + : 'LLM included forbidden content in response'; + + console.warn( + `Warning: ${message}: ${problematicContent.join(', ')}.`, + 'This is not ideal but not a test failure.', + ); + + const label = isMissing ? 'Expected content' : 'Forbidden content'; + console.warn(`${label}:`, originalContent); + console.warn('Actual output:', result); +} + +// Helper to check model output and warn about unexpected content +export function checkModelOutputContent( + result: string, + { + expectedContent = null, + testName = '', + forbiddenContent = null, + }: { + expectedContent?: string | (string | RegExp)[] | null; + testName?: string; + forbiddenContent?: string | (string | RegExp)[] | null; + } = {}, +): boolean { + let isValid = true; // If expectedContent is provided, check for it and warn if missing if (expectedContent) { - const contents = Array.isArray(expectedContent) - ? expectedContent - : [expectedContent]; - const missingContent = contents.filter((content) => { - if (typeof content === 'string') { - return !result.toLowerCase().includes(content.toLowerCase()); - } else if (content instanceof RegExp) { - return !content.test(result); - } - return false; - }); + const missingContent = findMismatchedContent(result, expectedContent, true); if (missingContent.length > 0) { - console.warn( - `Warning: LLM did not include expected content in response: ${missingContent.join( - ', ', - )}.`, - 'This is not ideal but not a test failure.', - ); - console.warn( - 'The tool was called successfully, which is the main requirement.', - ); - console.warn('Expected content:', expectedContent); - console.warn('Actual output:', result); - return false; - } else if (env['VERBOSE'] === 'true') { - console.log(`${testName}: Model output validated successfully.`); + logContentWarning(missingContent, true, expectedContent, result); + isValid = false; } - return true; } - return true; + // If forbiddenContent is provided, check for it and warn if present + if (forbiddenContent) { + const foundContent = findMismatchedContent(result, forbiddenContent, false); + + if (foundContent.length > 0) { + logContentWarning(foundContent, false, forbiddenContent, result); + isValid = false; + } + } + + if (isValid && env['VERBOSE'] === 'true') { + console.log(`${testName}: Model output content checked successfully.`); + } + + return isValid; } export interface ParsedLog {