diff --git a/packages/core/src/agents/codebase-investigator.test.ts b/packages/core/src/agents/codebase-investigator.test.ts new file mode 100644 index 0000000000..3d8453cb97 --- /dev/null +++ b/packages/core/src/agents/codebase-investigator.test.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { CodebaseInvestigatorAgent } from './codebase-investigator.js'; +import { + GLOB_TOOL_NAME, + GREP_TOOL_NAME, + LS_TOOL_NAME, + READ_FILE_TOOL_NAME, +} from '../tools/tool-names.js'; +import { DEFAULT_GEMINI_MODEL } from '../config/models.js'; + +describe('CodebaseInvestigatorAgent', () => { + it('should have the correct agent definition', () => { + expect(CodebaseInvestigatorAgent.name).toBe('codebase_investigator'); + expect(CodebaseInvestigatorAgent.displayName).toBe( + 'Codebase Investigator Agent', + ); + expect(CodebaseInvestigatorAgent.description).toBeDefined(); + expect( + CodebaseInvestigatorAgent.inputConfig.inputs['objective'].required, + ).toBe(true); + expect(CodebaseInvestigatorAgent.outputConfig?.outputName).toBe('report'); + expect(CodebaseInvestigatorAgent.modelConfig?.model).toBe( + DEFAULT_GEMINI_MODEL, + ); + expect(CodebaseInvestigatorAgent.toolConfig?.tools).toEqual([ + LS_TOOL_NAME, + READ_FILE_TOOL_NAME, + GLOB_TOOL_NAME, + GREP_TOOL_NAME, + ]); + }); + + it('should process output to a formatted JSON string', () => { + const report = { + SummaryOfFindings: 'summary', + ExplorationTrace: ['trace'], + RelevantLocations: [], + }; + const processed = CodebaseInvestigatorAgent.processOutput?.(report); + expect(processed).toBe(JSON.stringify(report, null, 2)); + }); +}); diff --git a/packages/core/src/agents/executor.test.ts b/packages/core/src/agents/executor.test.ts index 3d58df3704..960dc30bd1 100644 --- a/packages/core/src/agents/executor.test.ts +++ b/packages/core/src/agents/executor.test.ts @@ -329,6 +329,47 @@ describe('AgentExecutor', () => { new RegExp(`^${parentId}-${definition.name}-`), ); }); + + it('should correctly apply templates to initialMessages', async () => { + const definition = createTestDefinition(); + // Override promptConfig to use initialMessages instead of systemPrompt + definition.promptConfig = { + initialMessages: [ + { role: 'user', parts: [{ text: 'Goal: ${goal}' }] }, + { role: 'model', parts: [{ text: 'OK, starting on ${goal}.' }] }, + ], + }; + const inputs = { goal: 'TestGoal' }; + + // Mock a response to prevent the loop from running forever + mockModelResponse([ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'done' }, + id: 'call1', + }, + ]); + + const executor = await AgentExecutor.create( + definition, + mockConfig, + onActivity, + ); + await executor.run(inputs, signal); + + const chatConstructorArgs = MockedGeminiChat.mock.calls[0]; + const startHistory = chatConstructorArgs[2]; // history is the 3rd arg + + expect(startHistory).toBeDefined(); + expect(startHistory).toHaveLength(2); + + // Perform checks on defined objects to satisfy TS + const firstPart = startHistory?.[0]?.parts?.[0]; + expect(firstPart?.text).toBe('Goal: TestGoal'); + + const secondPart = startHistory?.[1]?.parts?.[0]; + expect(secondPart?.text).toBe('OK, starting on TestGoal.'); + }); }); describe('run (Execution Loop and Logic)', () => { @@ -420,9 +461,16 @@ describe('AgentExecutor', () => { const chatConstructorArgs = MockedGeminiChat.mock.calls[0]; const chatConfig = chatConstructorArgs[1]; - expect(chatConfig?.systemInstruction).toContain( + const systemInstruction = chatConfig?.systemInstruction as string; + + expect(systemInstruction).toContain( `MUST call the \`${TASK_COMPLETE_TOOL_NAME}\` tool`, ); + expect(systemInstruction).toContain('Mocked Environment Context'); + expect(systemInstruction).toContain( + 'You are running in a non-interactive mode', + ); + expect(systemInstruction).toContain('Always use absolute paths'); const turn1Params = getMockMessageParams(0); @@ -921,6 +969,203 @@ describe('AgentExecutor', () => { }); }); + describe('Edge Cases and Error Handling', () => { + it('should report an error if complete_task output fails schema validation', async () => { + const definition = createTestDefinition( + [], + {}, + 'default', + z.string().min(10), // The schema is for the output value itself + ); + const executor = await AgentExecutor.create( + definition, + mockConfig, + onActivity, + ); + + // Turn 1: Invalid arg (too short) + mockModelResponse([ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'short' }, + id: 'call1', + }, + ]); + + // Turn 2: Corrected + mockModelResponse([ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'This is a much longer and valid result' }, + id: 'call2', + }, + ]); + + const output = await executor.run({ goal: 'Validation test' }, signal); + + expect(mockSendMessageStream).toHaveBeenCalledTimes(2); + + const expectedError = + 'Output validation failed: {"formErrors":["String must contain at least 10 character(s)"],"fieldErrors":{}}'; + + // Check that the error was reported in the activity stream + expect(activities).toContainEqual( + expect.objectContaining({ + type: 'ERROR', + data: { + context: 'tool_call', + name: TASK_COMPLETE_TOOL_NAME, + error: expect.stringContaining('Output validation failed'), + }, + }), + ); + + // Check that the error was sent back to the model for the next turn + const turn2Params = getMockMessageParams(1); + const turn2Parts = turn2Params.message; + expect(turn2Parts).toEqual([ + expect.objectContaining({ + functionResponse: expect.objectContaining({ + name: TASK_COMPLETE_TOOL_NAME, + response: { error: expectedError }, + id: 'call1', + }), + }), + ]); + + // Check that the agent eventually succeeded + expect(output.result).toContain('This is a much longer and valid result'); + expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL); + }); + + it('should throw and log if GeminiChat creation fails', async () => { + const definition = createTestDefinition(); + const initError = new Error('Chat creation failed'); + MockedGeminiChat.mockImplementationOnce(() => { + throw initError; + }); + + // We expect the error to be thrown during the run, not creation + const executor = await AgentExecutor.create( + definition, + mockConfig, + onActivity, + ); + + await expect(executor.run({ goal: 'test' }, signal)).rejects.toThrow( + `Failed to create chat object: ${initError}`, + ); + + // Ensure the error was reported via the activity callback + expect(activities).toContainEqual( + expect.objectContaining({ + type: 'ERROR', + data: expect.objectContaining({ + error: `Error: Failed to create chat object: ${initError}`, + }), + }), + ); + + // Ensure the agent run was logged as a failure + expect(mockedLogAgentFinish).toHaveBeenCalledWith( + mockConfig, + expect.objectContaining({ + terminate_reason: AgentTerminateMode.ERROR, + }), + ); + }); + + it('should handle a failed tool call and feed the error to the model', async () => { + const definition = createTestDefinition([LS_TOOL_NAME]); + const executor = await AgentExecutor.create( + definition, + mockConfig, + onActivity, + ); + const toolErrorMessage = 'Tool failed spectacularly'; + + // Turn 1: Model calls a tool that will fail + mockModelResponse([ + { name: LS_TOOL_NAME, args: { path: '/fake' }, id: 'call1' }, + ]); + mockExecuteToolCall.mockResolvedValueOnce({ + status: 'error', + request: { + callId: 'call1', + name: LS_TOOL_NAME, + args: { path: '/fake' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'call1', + resultDisplay: '', + responseParts: [ + { + functionResponse: { + name: LS_TOOL_NAME, + response: { error: toolErrorMessage }, + id: 'call1', + }, + }, + ], + error: { + type: 'ToolError', + message: toolErrorMessage, + }, + errorType: 'ToolError', + contentLength: 0, + }, + }); + + // Turn 2: Model sees the error and completes + mockModelResponse([ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'Aborted due to tool failure.' }, + id: 'call2', + }, + ]); + + const output = await executor.run({ goal: 'Tool failure test' }, signal); + + expect(mockExecuteToolCall).toHaveBeenCalledTimes(1); + expect(mockSendMessageStream).toHaveBeenCalledTimes(2); + + // Verify the error was reported in the activity stream + expect(activities).toContainEqual( + expect.objectContaining({ + type: 'ERROR', + data: { + context: 'tool_call', + name: LS_TOOL_NAME, + error: toolErrorMessage, + }, + }), + ); + + // Verify the error was sent back to the model + const turn2Params = getMockMessageParams(1); + const parts = turn2Params.message; + expect(parts).toEqual([ + expect.objectContaining({ + functionResponse: expect.objectContaining({ + name: LS_TOOL_NAME, + id: 'call1', + response: { + error: toolErrorMessage, + }, + }), + }), + ]); + + expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL); + expect(output.result).toBe('Aborted due to tool failure.'); + }); + }); + describe('run (Termination Conditions)', () => { const mockWorkResponse = (id: string) => { mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); diff --git a/packages/core/src/code_assist/codeAssist.test.ts b/packages/core/src/code_assist/codeAssist.test.ts new file mode 100644 index 0000000000..1608a7d976 --- /dev/null +++ b/packages/core/src/code_assist/codeAssist.test.ts @@ -0,0 +1,163 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { AuthType } from '../core/contentGenerator.js'; +import { getOauthClient } from './oauth2.js'; +import { setupUser } from './setup.js'; +import { CodeAssistServer } from './server.js'; +import { + createCodeAssistContentGenerator, + getCodeAssistServer, +} from './codeAssist.js'; +import type { Config } from '../config/config.js'; +import { LoggingContentGenerator } from '../core/loggingContentGenerator.js'; +import { UserTierId } from './types.js'; + +// Mock dependencies +vi.mock('./oauth2.js'); +vi.mock('./setup.js'); +vi.mock('./server.js'); +vi.mock('../core/loggingContentGenerator.js'); + +const mockedGetOauthClient = vi.mocked(getOauthClient); +const mockedSetupUser = vi.mocked(setupUser); +const MockedCodeAssistServer = vi.mocked(CodeAssistServer); +const MockedLoggingContentGenerator = vi.mocked(LoggingContentGenerator); + +describe('codeAssist', () => { + beforeEach(() => { + vi.resetAllMocks(); + }); + + describe('createCodeAssistContentGenerator', () => { + const httpOptions = {}; + const mockConfig = {} as Config; + const mockAuthClient = { a: 'client' }; + const mockUserData = { + projectId: 'test-project', + userTier: UserTierId.FREE, + }; + + it('should create a server for LOGIN_WITH_GOOGLE', async () => { + mockedGetOauthClient.mockResolvedValue(mockAuthClient as never); + mockedSetupUser.mockResolvedValue(mockUserData); + + const generator = await createCodeAssistContentGenerator( + httpOptions, + AuthType.LOGIN_WITH_GOOGLE, + mockConfig, + 'session-123', + ); + + expect(getOauthClient).toHaveBeenCalledWith( + AuthType.LOGIN_WITH_GOOGLE, + mockConfig, + ); + expect(setupUser).toHaveBeenCalledWith(mockAuthClient); + expect(MockedCodeAssistServer).toHaveBeenCalledWith( + mockAuthClient, + 'test-project', + httpOptions, + 'session-123', + 'free-tier', + ); + expect(generator).toBeInstanceOf(MockedCodeAssistServer); + }); + + it('should create a server for CLOUD_SHELL', async () => { + mockedGetOauthClient.mockResolvedValue(mockAuthClient as never); + mockedSetupUser.mockResolvedValue(mockUserData); + + const generator = await createCodeAssistContentGenerator( + httpOptions, + AuthType.CLOUD_SHELL, + mockConfig, + ); + + expect(getOauthClient).toHaveBeenCalledWith( + AuthType.CLOUD_SHELL, + mockConfig, + ); + expect(setupUser).toHaveBeenCalledWith(mockAuthClient); + expect(MockedCodeAssistServer).toHaveBeenCalledWith( + mockAuthClient, + 'test-project', + httpOptions, + undefined, // No session ID + 'free-tier', + ); + expect(generator).toBeInstanceOf(MockedCodeAssistServer); + }); + + it('should throw an error for unsupported auth types', async () => { + await expect( + createCodeAssistContentGenerator( + httpOptions, + 'api-key' as AuthType, // Use literal string to avoid enum resolution issues + mockConfig, + ), + ).rejects.toThrow('Unsupported authType: api-key'); + }); + }); + + describe('getCodeAssistServer', () => { + it('should return the server if it is a CodeAssistServer', () => { + const mockServer = new MockedCodeAssistServer({} as never, '', {}); + const mockConfig = { + getContentGenerator: () => mockServer, + } as unknown as Config; + + const server = getCodeAssistServer(mockConfig); + expect(server).toBe(mockServer); + }); + + it('should unwrap and return the server if it is wrapped in a LoggingContentGenerator', () => { + const mockServer = new MockedCodeAssistServer({} as never, '', {}); + const mockLogger = new MockedLoggingContentGenerator( + {} as never, + {} as never, + ); + vi.spyOn(mockLogger, 'getWrapped').mockReturnValue(mockServer); + + const mockConfig = { + getContentGenerator: () => mockLogger, + } as unknown as Config; + + const server = getCodeAssistServer(mockConfig); + expect(server).toBe(mockServer); + expect(mockLogger.getWrapped).toHaveBeenCalled(); + }); + + it('should return undefined if the content generator is not a CodeAssistServer', () => { + const mockGenerator = { a: 'generator' }; // Not a CodeAssistServer + const mockConfig = { + getContentGenerator: () => mockGenerator, + } as unknown as Config; + + const server = getCodeAssistServer(mockConfig); + expect(server).toBeUndefined(); + }); + + it('should return undefined if the wrapped generator is not a CodeAssistServer', () => { + const mockGenerator = { a: 'generator' }; // Not a CodeAssistServer + const mockLogger = new MockedLoggingContentGenerator( + {} as never, + {} as never, + ); + vi.spyOn(mockLogger, 'getWrapped').mockReturnValue( + mockGenerator as never, + ); + + const mockConfig = { + getContentGenerator: () => mockLogger, + } as unknown as Config; + + const server = getCodeAssistServer(mockConfig); + expect(server).toBeUndefined(); + }); + }); +}); diff --git a/packages/core/src/code_assist/experiments/client_metadata.test.ts b/packages/core/src/code_assist/experiments/client_metadata.test.ts new file mode 100644 index 0000000000..da8eef5683 --- /dev/null +++ b/packages/core/src/code_assist/experiments/client_metadata.test.ts @@ -0,0 +1,112 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { ReleaseChannel, getReleaseChannel } from '../../utils/channel.js'; + +// Mock dependencies before importing the module under test +vi.mock('../../utils/channel.js', async () => { + const actual = await vi.importActual('../../utils/channel.js'); + return { + ...(actual as object), + getReleaseChannel: vi.fn(), + }; +}); + +describe('client_metadata', () => { + const originalPlatform = process.platform; + const originalArch = process.arch; + const originalCliVersion = process.env['CLI_VERSION']; + const originalNodeVersion = process.version; + + beforeEach(async () => { + // Reset modules to clear the cached `clientMetadataPromise` + vi.resetModules(); + // Re-import the module to get a fresh instance + await import('./client_metadata.js'); + // Provide a default mock implementation for each test + vi.mocked(getReleaseChannel).mockResolvedValue(ReleaseChannel.STABLE); + }); + + afterEach(() => { + // Restore original process properties to avoid side-effects between tests + Object.defineProperty(process, 'platform', { value: originalPlatform }); + Object.defineProperty(process, 'arch', { value: originalArch }); + process.env['CLI_VERSION'] = originalCliVersion; + Object.defineProperty(process, 'version', { value: originalNodeVersion }); + vi.clearAllMocks(); + }); + + describe('getPlatform', () => { + const testCases = [ + { platform: 'darwin', arch: 'x64', expected: 'DARWIN_AMD64' }, + { platform: 'darwin', arch: 'arm64', expected: 'DARWIN_ARM64' }, + { platform: 'linux', arch: 'x64', expected: 'LINUX_AMD64' }, + { platform: 'linux', arch: 'arm64', expected: 'LINUX_ARM64' }, + { platform: 'win32', arch: 'x64', expected: 'WINDOWS_AMD64' }, + { platform: 'sunos', arch: 'x64', expected: 'PLATFORM_UNSPECIFIED' }, + { platform: 'win32', arch: 'arm', expected: 'PLATFORM_UNSPECIFIED' }, + ]; + + for (const { platform, arch, expected } of testCases) { + it(`should return ${expected} for platform ${platform} and arch ${arch}`, async () => { + Object.defineProperty(process, 'platform', { value: platform }); + Object.defineProperty(process, 'arch', { value: arch }); + const { getClientMetadata } = await import('./client_metadata.js'); + + const metadata = await getClientMetadata(); + expect(metadata.platform).toBe(expected); + }); + } + }); + + describe('getClientMetadata', () => { + it('should use CLI_VERSION for ideVersion if set', async () => { + process.env['CLI_VERSION'] = '1.2.3'; + Object.defineProperty(process, 'version', { value: 'v18.0.0' }); + const { getClientMetadata } = await import('./client_metadata.js'); + + const metadata = await getClientMetadata(); + expect(metadata.ideVersion).toBe('1.2.3'); + }); + + it('should use process.version for ideVersion as a fallback', async () => { + delete process.env['CLI_VERSION']; + Object.defineProperty(process, 'version', { value: 'v20.0.0' }); + const { getClientMetadata } = await import('./client_metadata.js'); + + const metadata = await getClientMetadata(); + expect(metadata.ideVersion).toBe('v20.0.0'); + }); + + it('should call getReleaseChannel to get the update channel', async () => { + vi.mocked(getReleaseChannel).mockResolvedValue(ReleaseChannel.NIGHTLY); + const { getClientMetadata } = await import('./client_metadata.js'); + + const metadata = await getClientMetadata(); + + expect(metadata.updateChannel).toBe('nightly'); + expect(getReleaseChannel).toHaveBeenCalled(); + }); + + it('should cache the client metadata promise', async () => { + const { getClientMetadata } = await import('./client_metadata.js'); + + const firstCall = await getClientMetadata(); + const secondCall = await getClientMetadata(); + + expect(firstCall).toBe(secondCall); + // Ensure the underlying functions are only called once + expect(getReleaseChannel).toHaveBeenCalledTimes(1); + }); + + it('should always return the IDE name as GEMINI_CLI', async () => { + const { getClientMetadata } = await import('./client_metadata.js'); + const metadata = await getClientMetadata(); + expect(metadata.ideName).toBe('GEMINI_CLI'); + }); + }); +}); diff --git a/packages/core/src/code_assist/experiments/experiments.test.ts b/packages/core/src/code_assist/experiments/experiments.test.ts new file mode 100644 index 0000000000..393e648ccf --- /dev/null +++ b/packages/core/src/code_assist/experiments/experiments.test.ts @@ -0,0 +1,115 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import type { CodeAssistServer } from '../server.js'; +import { getClientMetadata } from './client_metadata.js'; +import type { ListExperimentsResponse, Flag } from './types.js'; + +// Mock dependencies before importing the module under test +vi.mock('../server.js'); +vi.mock('./client_metadata.js'); + +describe('experiments', () => { + let mockServer: CodeAssistServer; + + beforeEach(() => { + // Reset modules to clear the cached `experimentsPromise` + vi.resetModules(); + + // Mock the dependencies that `getExperiments` relies on + vi.mocked(getClientMetadata).mockResolvedValue({ + ideName: 'GEMINI_CLI', + ideVersion: '1.0.0', + platform: 'LINUX_AMD64', + updateChannel: 'stable', + }); + + // Create a mock instance of the server for each test + mockServer = { + listExperiments: vi.fn(), + } as unknown as CodeAssistServer; + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should fetch and parse experiments from the server', async () => { + const { getExperiments } = await import('./experiments.js'); + const mockApiResponse: ListExperimentsResponse = { + flags: [ + { name: 'flag1', boolValue: true }, + { name: 'flag2', stringValue: 'value' }, + ], + experimentIds: [123, 456], + }; + vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse); + + const experiments = await getExperiments(mockServer); + + // Verify that the dependencies were called + expect(getClientMetadata).toHaveBeenCalled(); + expect(mockServer.listExperiments).toHaveBeenCalledWith( + await getClientMetadata(), + ); + + // Verify that the response was parsed correctly + expect(experiments.flags['flag1']).toEqual({ + name: 'flag1', + boolValue: true, + }); + expect(experiments.flags['flag2']).toEqual({ + name: 'flag2', + stringValue: 'value', + }); + expect(experiments.experimentIds).toEqual([123, 456]); + }); + + it('should handle an empty or partial response from the server', async () => { + const { getExperiments } = await import('./experiments.js'); + const mockApiResponse: ListExperimentsResponse = {}; // No flags or experimentIds + vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse); + + const experiments = await getExperiments(mockServer); + + expect(experiments.flags).toEqual({}); + expect(experiments.experimentIds).toEqual([]); + }); + + it('should ignore flags that are missing a name', async () => { + const { getExperiments } = await import('./experiments.js'); + const mockApiResponse: ListExperimentsResponse = { + flags: [ + { boolValue: true } as Flag, // No name + { name: 'flag2', stringValue: 'value' }, + ], + }; + vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse); + + const experiments = await getExperiments(mockServer); + + expect(Object.keys(experiments.flags)).toHaveLength(1); + expect(experiments.flags['flag2']).toBeDefined(); + expect(experiments.flags['undefined']).toBeUndefined(); + }); + + it('should cache the experiments promise to avoid multiple fetches', async () => { + const { getExperiments } = await import('./experiments.js'); + const mockApiResponse: ListExperimentsResponse = { + experimentIds: [1, 2, 3], + }; + vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse); + + const firstCall = await getExperiments(mockServer); + const secondCall = await getExperiments(mockServer); + + expect(firstCall).toBe(secondCall); // Should be the exact same promise object + // Verify the underlying functions were only called once + expect(getClientMetadata).toHaveBeenCalledTimes(1); + expect(mockServer.listExperiments).toHaveBeenCalledTimes(1); + }); +}); diff --git a/packages/core/src/code_assist/oauth-credential-storage.test.ts b/packages/core/src/code_assist/oauth-credential-storage.test.ts index c3fff6a59b..d75892f953 100644 --- a/packages/core/src/code_assist/oauth-credential-storage.test.ts +++ b/packages/core/src/code_assist/oauth-credential-storage.test.ts @@ -173,6 +173,58 @@ describe('OAuthCredentialStorage', () => { expect(result).toEqual(mockCredentials); }); + + it('should throw an error if the migration file contains invalid JSON', async () => { + vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue( + null, + ); + vi.spyOn(fs, 'readFile').mockResolvedValue('invalid json'); + + await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow( + 'Failed to load OAuth credentials', + ); + }); + + it('should not delete the old file if saving migrated credentials fails', async () => { + vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue( + null, + ); + vi.spyOn(fs, 'readFile').mockResolvedValue( + JSON.stringify(mockCredentials), + ); + vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockRejectedValue( + new Error('Save failed'), + ); + + await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow( + 'Failed to load OAuth credentials', + ); + + expect(fs.rm).not.toHaveBeenCalled(); + }); + + it('should return credentials even if access_token is missing from storage', async () => { + const partialMcpCredentials = { + ...mockMcpCredentials, + token: { + ...mockMcpCredentials.token, + accessToken: undefined, + }, + }; + vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue( + partialMcpCredentials, + ); + + const result = await OAuthCredentialStorage.loadCredentials(); + + expect(result).toEqual({ + access_token: undefined, + refresh_token: mockCredentials.refresh_token, + token_type: mockCredentials.token_type, + scope: mockCredentials.scope, + expiry_date: mockCredentials.expiry_date, + }); + }); }); describe('saveCredentials', () => { @@ -195,6 +247,28 @@ describe('OAuthCredentialStorage', () => { 'Attempted to save credentials without an access token.', ); }); + + it('should handle saving credentials with null or undefined optional fields', async () => { + const partialCredentials: Credentials = { + access_token: 'only_access_token', + refresh_token: null, // test null + scope: undefined, // test undefined + }; + + await OAuthCredentialStorage.saveCredentials(partialCredentials); + + expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith({ + serverName: 'main-account', + token: { + accessToken: 'only_access_token', + refreshToken: undefined, + tokenType: 'Bearer', // default + scope: undefined, + expiresAt: undefined, + }, + updatedAt: expect.any(Number), + }); + }); }); describe('clearCredentials', () => { diff --git a/packages/core/src/code_assist/server.test.ts b/packages/core/src/code_assist/server.test.ts index 86c73f601a..099c05ad5e 100644 --- a/packages/core/src/code_assist/server.test.ts +++ b/packages/core/src/code_assist/server.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { beforeEach, describe, it, expect, vi } from 'vitest'; +import { beforeEach, describe, it, expect, vi, afterEach } from 'vitest'; import { CodeAssistServer } from './server.js'; import { OAuth2Client } from 'google-auth-library'; import { UserTierId } from './types.js'; @@ -29,15 +29,16 @@ describe('CodeAssistServer', () => { }); it('should call the generateContent endpoint', async () => { - const client = new OAuth2Client(); + const mockRequest = vi.fn(); + const client = { request: mockRequest } as unknown as OAuth2Client; const server = new CodeAssistServer( client, 'test-project', - {}, + { headers: { 'x-custom-header': 'test-value' } }, 'test-session', UserTierId.FREE, ); - const mockResponse = { + const mockResponseData = { response: { candidates: [ { @@ -52,7 +53,7 @@ describe('CodeAssistServer', () => { ], }, }; - vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse); + mockRequest.mockResolvedValue({ data: mockResponseData }); const response = await server.generateContent( { @@ -62,18 +63,59 @@ describe('CodeAssistServer', () => { 'user-prompt-id', ); - expect(server.requestPost).toHaveBeenCalledWith( - 'generateContent', - expect.any(Object), - undefined, - ); + expect(mockRequest).toHaveBeenCalledWith({ + url: expect.stringContaining(':generateContent'), + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'x-custom-header': 'test-value', + }, + responseType: 'json', + body: expect.any(String), + signal: undefined, + }); + + const requestBody = JSON.parse(mockRequest.mock.calls[0][0].body); + expect(requestBody.user_prompt_id).toBe('user-prompt-id'); + expect(requestBody.project).toBe('test-project'); + expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe( 'response', ); }); - it('should call the generateContentStream endpoint', async () => { - const client = new OAuth2Client(); + describe('getMethodUrl', () => { + const originalEnv = process.env; + + beforeEach(() => { + // Reset the environment variables to their original state + process.env = { ...originalEnv }; + }); + + afterEach(() => { + // Restore the original environment variables + process.env = originalEnv; + }); + + it('should construct the default URL correctly', () => { + const server = new CodeAssistServer({} as never); + const url = server.getMethodUrl('testMethod'); + expect(url).toBe( + 'https://cloudcode-pa.googleapis.com/v1internal:testMethod', + ); + }); + + it('should use the CODE_ASSIST_ENDPOINT environment variable if set', () => { + process.env['CODE_ASSIST_ENDPOINT'] = 'https://custom-endpoint.com'; + const server = new CodeAssistServer({} as never); + const url = server.getMethodUrl('testMethod'); + expect(url).toBe('https://custom-endpoint.com/v1internal:testMethod'); + }); + }); + + it('should call the generateContentStream endpoint and parse SSE', async () => { + const mockRequest = vi.fn(); + const client = { request: mockRequest } as unknown as OAuth2Client; const server = new CodeAssistServer( client, 'test-project', @@ -81,24 +123,21 @@ describe('CodeAssistServer', () => { 'test-session', UserTierId.FREE, ); - const mockResponse = (async function* () { - yield { - response: { - candidates: [ - { - index: 0, - content: { - role: 'model', - parts: [{ text: 'response' }], - }, - finishReason: 'STOP', - safetyRatings: [], - }, - ], - }, - }; - })(); - vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse); + + // Create a mock readable stream + const { Readable } = await import('node:stream'); + const mockStream = new Readable({ + read() {}, + }); + + const mockResponseData1 = { + response: { candidates: [{ content: { parts: [{ text: 'Hello' }] } }] }, + }; + const mockResponseData2 = { + response: { candidates: [{ content: { parts: [{ text: ' World' }] } }] }, + }; + + mockRequest.mockResolvedValue({ data: mockStream }); const stream = await server.generateContentStream( { @@ -108,14 +147,61 @@ describe('CodeAssistServer', () => { 'user-prompt-id', ); + // Push SSE data to the stream + // Use setTimeout to ensure the stream processing has started + setTimeout(() => { + mockStream.push('data: ' + JSON.stringify(mockResponseData1) + '\n\n'); + mockStream.push('id: 123\n'); // Should be ignored + mockStream.push('data: ' + JSON.stringify(mockResponseData2) + '\n\n'); + mockStream.push(null); // End the stream + }, 0); + + const results = []; for await (const res of stream) { - expect(server.requestStreamingPost).toHaveBeenCalledWith( - 'streamGenerateContent', - expect.any(Object), - undefined, - ); - expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response'); + results.push(res); } + + expect(mockRequest).toHaveBeenCalledWith({ + url: expect.stringContaining(':streamGenerateContent'), + method: 'POST', + params: { alt: 'sse' }, + responseType: 'stream', + body: expect.any(String), + headers: { + 'Content-Type': 'application/json', + }, + signal: undefined, + }); + + expect(results).toHaveLength(2); + expect(results[0].candidates?.[0].content?.parts?.[0].text).toBe('Hello'); + expect(results[1].candidates?.[0].content?.parts?.[0].text).toBe(' World'); + }); + + it('should ignore malformed SSE data', async () => { + const mockRequest = vi.fn(); + const client = { request: mockRequest } as unknown as OAuth2Client; + const server = new CodeAssistServer(client); + + const { Readable } = await import('node:stream'); + const mockStream = new Readable({ + read() {}, + }); + + mockRequest.mockResolvedValue({ data: mockStream }); + + const stream = await server.requestStreamingPost('testStream', {}); + + setTimeout(() => { + mockStream.push('this is a malformed line\n'); + mockStream.push(null); + }, 0); + + const results = []; + for await (const res of stream) { + results.push(res); + } + expect(results).toHaveLength(0); }); it('should call the onboardUser endpoint', async () => { @@ -253,6 +339,22 @@ describe('CodeAssistServer', () => { }); }); + it('should re-throw non-VPC-SC errors from loadCodeAssist', async () => { + const client = new OAuth2Client(); + const server = new CodeAssistServer(client); + const genericError = new Error('Something else went wrong'); + vi.spyOn(server, 'requestPost').mockRejectedValue(genericError); + + await expect(server.loadCodeAssist({ metadata: {} })).rejects.toThrow( + 'Something else went wrong', + ); + + expect(server.requestPost).toHaveBeenCalledWith( + 'loadCodeAssist', + expect.any(Object), + ); + }); + it('should call the listExperiments endpoint with metadata', async () => { const client = new OAuth2Client(); const server = new CodeAssistServer( diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index 412d6f46b2..8670fda289 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -232,18 +232,16 @@ export class CodeAssistServer implements ContentGenerator { let bufferedLines: string[] = []; for await (const line of rl) { - // blank lines are used to separate JSON objects in the stream - if (line === '') { + if (line.startsWith('data: ')) { + bufferedLines.push(line.slice(6).trim()); + } else if (line === '') { if (bufferedLines.length === 0) { continue; // no data to yield } yield JSON.parse(bufferedLines.join('\n')) as T; bufferedLines = []; // Reset the buffer after yielding - } else if (line.startsWith('data: ')) { - bufferedLines.push(line.slice(6).trim()); - } else { - throw new Error(`Unexpected line format in response: ${line}`); } + // Ignore other lines like comments or id fields } })(); } diff --git a/packages/core/src/core/loggingContentGenerator.test.ts b/packages/core/src/core/loggingContentGenerator.test.ts new file mode 100644 index 0000000000..e591f86be9 --- /dev/null +++ b/packages/core/src/core/loggingContentGenerator.test.ts @@ -0,0 +1,257 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +const logApiRequest = vi.hoisted(() => vi.fn()); +const logApiResponse = vi.hoisted(() => vi.fn()); +const logApiError = vi.hoisted(() => vi.fn()); + +vi.mock('../telemetry/loggers.js', () => ({ + logApiRequest, + logApiResponse, + logApiError, +})); + +const runInDevTraceSpan = vi.hoisted(() => + vi.fn(async (meta, fn) => fn({ metadata: {}, endSpan: vi.fn() })), +); + +vi.mock('../telemetry/trace.js', () => ({ + runInDevTraceSpan, +})); + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import type { + GenerateContentResponse, + EmbedContentResponse, +} from '@google/genai'; +import type { ContentGenerator } from './contentGenerator.js'; +import { LoggingContentGenerator } from './loggingContentGenerator.js'; +import type { Config } from '../config/config.js'; +import { ApiRequestEvent } from '../telemetry/types.js'; + +describe('LoggingContentGenerator', () => { + let wrapped: ContentGenerator; + let config: Config; + let loggingContentGenerator: LoggingContentGenerator; + + beforeEach(() => { + wrapped = { + generateContent: vi.fn(), + generateContentStream: vi.fn(), + countTokens: vi.fn(), + embedContent: vi.fn(), + }; + config = { + getGoogleAIConfig: vi.fn(), + getVertexAIConfig: vi.fn(), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'API_KEY', + }), + } as unknown as Config; + loggingContentGenerator = new LoggingContentGenerator(wrapped, config); + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.clearAllMocks(); + vi.useRealTimers(); + }); + + describe('generateContent', () => { + it('should log request and response on success', async () => { + const req = { + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + model: 'gemini-pro', + }; + const userPromptId = 'prompt-123'; + const response: GenerateContentResponse = { + candidates: [], + usageMetadata: { + promptTokenCount: 1, + candidatesTokenCount: 2, + totalTokenCount: 3, + }, + text: undefined, + functionCalls: undefined, + executableCode: undefined, + codeExecutionResult: undefined, + data: undefined, + }; + vi.mocked(wrapped.generateContent).mockResolvedValue(response); + const startTime = new Date('2025-01-01T00:00:00.000Z'); + vi.setSystemTime(startTime); + + const promise = loggingContentGenerator.generateContent( + req, + userPromptId, + ); + + vi.advanceTimersByTime(1000); + + await promise; + + expect(wrapped.generateContent).toHaveBeenCalledWith(req, userPromptId); + expect(logApiRequest).toHaveBeenCalledWith( + config, + expect.any(ApiRequestEvent), + ); + const responseEvent = vi.mocked(logApiResponse).mock.calls[0][1]; + expect(responseEvent.duration_ms).toBe(1000); + }); + + it('should log error on failure', async () => { + const req = { + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + model: 'gemini-pro', + }; + const userPromptId = 'prompt-123'; + const error = new Error('test error'); + vi.mocked(wrapped.generateContent).mockRejectedValue(error); + const startTime = new Date('2025-01-01T00:00:00.000Z'); + vi.setSystemTime(startTime); + + const promise = loggingContentGenerator.generateContent( + req, + userPromptId, + ); + + vi.advanceTimersByTime(1000); + + await expect(promise).rejects.toThrow(error); + + expect(logApiRequest).toHaveBeenCalledWith( + config, + expect.any(ApiRequestEvent), + ); + const errorEvent = vi.mocked(logApiError).mock.calls[0][1]; + expect(errorEvent.duration_ms).toBe(1000); + }); + }); + + describe('generateContentStream', () => { + it('should log request and response on success', async () => { + const req = { + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + model: 'gemini-pro', + }; + const userPromptId = 'prompt-123'; + const response = { + candidates: [], + usageMetadata: { + promptTokenCount: 1, + candidatesTokenCount: 2, + totalTokenCount: 3, + }, + } as unknown as GenerateContentResponse; + + async function* createAsyncGenerator() { + yield response; + } + + vi.mocked(wrapped.generateContentStream).mockResolvedValue( + createAsyncGenerator(), + ); + const startTime = new Date('2025-01-01T00:00:00.000Z'); + vi.setSystemTime(startTime); + + const stream = await loggingContentGenerator.generateContentStream( + req, + userPromptId, + ); + + vi.advanceTimersByTime(1000); + + for await (const _ of stream) { + // consume stream + } + + expect(wrapped.generateContentStream).toHaveBeenCalledWith( + req, + userPromptId, + ); + expect(logApiRequest).toHaveBeenCalledWith( + config, + expect.any(ApiRequestEvent), + ); + const responseEvent = vi.mocked(logApiResponse).mock.calls[0][1]; + expect(responseEvent.duration_ms).toBe(1000); + }); + + it('should log error on failure', async () => { + const req = { + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + model: 'gemini-pro', + }; + const userPromptId = 'prompt-123'; + const error = new Error('test error'); + + async function* createAsyncGenerator() { + yield Promise.reject(error); + } + + vi.mocked(wrapped.generateContentStream).mockResolvedValue( + createAsyncGenerator(), + ); + const startTime = new Date('2025-01-01T00:00:00.000Z'); + vi.setSystemTime(startTime); + + const stream = await loggingContentGenerator.generateContentStream( + req, + userPromptId, + ); + + vi.advanceTimersByTime(1000); + + await expect(async () => { + for await (const _ of stream) { + // do nothing + } + }).rejects.toThrow(error); + + expect(logApiRequest).toHaveBeenCalledWith( + config, + expect.any(ApiRequestEvent), + ); + const errorEvent = vi.mocked(logApiError).mock.calls[0][1]; + expect(errorEvent.duration_ms).toBe(1000); + }); + }); + + describe('getWrapped', () => { + it('should return the wrapped content generator', () => { + expect(loggingContentGenerator.getWrapped()).toBe(wrapped); + }); + }); + + describe('countTokens', () => { + it('should call the wrapped countTokens method', async () => { + const req = { contents: [], model: 'gemini-pro' }; + const response = { totalTokens: 10 }; + vi.mocked(wrapped.countTokens).mockResolvedValue(response); + + const result = await loggingContentGenerator.countTokens(req); + + expect(wrapped.countTokens).toHaveBeenCalledWith(req); + expect(result).toBe(response); + }); + }); + + describe('embedContent', () => { + it('should call the wrapped embedContent method', async () => { + const req = { + contents: [{ role: 'user', parts: [] }], + model: 'gemini-pro', + }; + const response: EmbedContentResponse = { embeddings: [{ values: [] }] }; + vi.mocked(wrapped.embedContent).mockResolvedValue(response); + + const result = await loggingContentGenerator.embedContent(req); + + expect(wrapped.embedContent).toHaveBeenCalledWith(req); + expect(result).toBe(response); + }); + }); +}); diff --git a/packages/core/src/core/tokenLimits.test.ts b/packages/core/src/core/tokenLimits.test.ts new file mode 100644 index 0000000000..1bff09d315 --- /dev/null +++ b/packages/core/src/core/tokenLimits.test.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { tokenLimit, DEFAULT_TOKEN_LIMIT } from './tokenLimits.js'; + +describe('tokenLimit', () => { + it('should return the correct token limit for gemini-1.5-pro', () => { + expect(tokenLimit('gemini-1.5-pro')).toBe(2_097_152); + }); + + it('should return the correct token limit for gemini-1.5-flash', () => { + expect(tokenLimit('gemini-1.5-flash')).toBe(1_048_576); + }); + + it('should return the default token limit for an unknown model', () => { + expect(tokenLimit('unknown-model')).toBe(DEFAULT_TOKEN_LIMIT); + }); + + it('should return the default token limit if no model is provided', () => { + // @ts-expect-error testing invalid input + expect(tokenLimit(undefined)).toBe(DEFAULT_TOKEN_LIMIT); + }); + + it('should have the correct default token limit value', () => { + expect(DEFAULT_TOKEN_LIMIT).toBe(1_048_576); + }); +}); diff --git a/packages/core/src/hooks/types.test.ts b/packages/core/src/hooks/types.test.ts index d40fe32041..8853b5a103 100644 --- a/packages/core/src/hooks/types.test.ts +++ b/packages/core/src/hooks/types.test.ts @@ -4,8 +4,40 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect } from 'vitest'; -import { HookEventName, HookType } from './types.js'; +import { describe, it, expect, vi } from 'vitest'; +import { + createHookOutput, + DefaultHookOutput, + BeforeToolHookOutput, + BeforeModelHookOutput, + BeforeToolSelectionHookOutput, + AfterModelHookOutput, + HookEventName, + HookType, +} from './types.js'; +import { defaultHookTranslator } from './hookTranslator.js'; +import type { + GenerateContentParameters, + GenerateContentResponse, + ToolConfig, +} from '@google/genai'; +import type { LLMRequest, LLMResponse } from './hookTranslator.js'; +import type { HookDecision } from './types.js'; + +vi.mock('./hookTranslator.js', () => ({ + defaultHookTranslator: { + fromHookLLMResponse: vi.fn( + (response: LLMResponse) => response as unknown as GenerateContentResponse, + ), + fromHookLLMRequest: vi.fn( + (request: LLMRequest, target: GenerateContentParameters) => ({ + ...target, + ...request, + }), + ), + fromHookToolConfig: vi.fn((config: ToolConfig) => config), + }, +})); describe('Hook Types', () => { describe('HookEventName', () => { @@ -36,3 +68,318 @@ describe('Hook Types', () => { }); }); }); + +describe('Hook Output Classes', () => { + describe('createHookOutput', () => { + it('should return DefaultHookOutput for unknown event names', () => { + const output = createHookOutput('UnknownEvent', {}); + expect(output).toBeInstanceOf(DefaultHookOutput); + expect(output).not.toBeInstanceOf(BeforeModelHookOutput); + expect(output).not.toBeInstanceOf(AfterModelHookOutput); + expect(output).not.toBeInstanceOf(BeforeToolSelectionHookOutput); + }); + + it('should return BeforeModelHookOutput for BeforeModel event', () => { + const output = createHookOutput(HookEventName.BeforeModel, {}); + expect(output).toBeInstanceOf(BeforeModelHookOutput); + }); + + it('should return AfterModelHookOutput for AfterModel event', () => { + const output = createHookOutput(HookEventName.AfterModel, {}); + expect(output).toBeInstanceOf(AfterModelHookOutput); + }); + + it('should return BeforeToolSelectionHookOutput for BeforeToolSelection event', () => { + const output = createHookOutput(HookEventName.BeforeToolSelection, {}); + expect(output).toBeInstanceOf(BeforeToolSelectionHookOutput); + }); + }); + + describe('DefaultHookOutput', () => { + it('should construct with provided data', () => { + const data = { + continue: false, + stopReason: 'test stop', + suppressOutput: true, + systemMessage: 'test system message', + decision: 'block' as HookDecision, + reason: 'test reason', + hookSpecificOutput: { key: 'value' }, + }; + const output = new DefaultHookOutput(data); + expect(output.continue).toBe(data.continue); + expect(output.stopReason).toBe(data.stopReason); + expect(output.suppressOutput).toBe(data.suppressOutput); + expect(output.systemMessage).toBe(data.systemMessage); + expect(output.decision).toBe(data.decision); + expect(output.reason).toBe(data.reason); + expect(output.hookSpecificOutput).toEqual(data.hookSpecificOutput); + }); + + it('should return false for isBlockingDecision if decision is not block or deny', () => { + const output1 = new DefaultHookOutput({ decision: 'approve' }); + expect(output1.isBlockingDecision()).toBe(false); + const output2 = new DefaultHookOutput({ decision: undefined }); + expect(output2.isBlockingDecision()).toBe(false); + }); + + it('should return true for isBlockingDecision if decision is block or deny', () => { + const output1 = new DefaultHookOutput({ decision: 'block' }); + expect(output1.isBlockingDecision()).toBe(true); + const output2 = new DefaultHookOutput({ decision: 'deny' }); + expect(output2.isBlockingDecision()).toBe(true); + }); + + it('should return true for shouldStopExecution if continue is false', () => { + const output = new DefaultHookOutput({ continue: false }); + expect(output.shouldStopExecution()).toBe(true); + }); + + it('should return false for shouldStopExecution if continue is true or undefined', () => { + const output1 = new DefaultHookOutput({ continue: true }); + expect(output1.shouldStopExecution()).toBe(false); + const output2 = new DefaultHookOutput({}); + expect(output2.shouldStopExecution()).toBe(false); + }); + + it('should return reason if available', () => { + const output = new DefaultHookOutput({ reason: 'specific reason' }); + expect(output.getEffectiveReason()).toBe('specific reason'); + }); + + it('should return stopReason if reason is not available', () => { + const output = new DefaultHookOutput({ stopReason: 'stop reason' }); + expect(output.getEffectiveReason()).toBe('stop reason'); + }); + + it('should return "No reason provided" if neither reason nor stopReason are available', () => { + const output = new DefaultHookOutput({}); + expect(output.getEffectiveReason()).toBe('No reason provided'); + }); + + it('applyLLMRequestModifications should return target unchanged', () => { + const target: GenerateContentParameters = { + model: 'gemini-pro', + contents: [], + }; + const output = new DefaultHookOutput({}); + expect(output.applyLLMRequestModifications(target)).toBe(target); + }); + + it('applyToolConfigModifications should return target unchanged', () => { + const target = { toolConfig: {}, tools: [] }; + const output = new DefaultHookOutput({}); + expect(output.applyToolConfigModifications(target)).toBe(target); + }); + + it('getAdditionalContext should return additionalContext if present', () => { + const output = new DefaultHookOutput({ + hookSpecificOutput: { additionalContext: 'some context' }, + }); + expect(output.getAdditionalContext()).toBe('some context'); + }); + + it('getAdditionalContext should return undefined if additionalContext is not present', () => { + const output = new DefaultHookOutput({ + hookSpecificOutput: { other: 'value' }, + }); + expect(output.getAdditionalContext()).toBeUndefined(); + }); + + it('getAdditionalContext should return undefined if hookSpecificOutput is undefined', () => { + const output = new DefaultHookOutput({}); + expect(output.getAdditionalContext()).toBeUndefined(); + }); + + it('getBlockingError should return blocked: true and reason if blocking decision', () => { + const output = new DefaultHookOutput({ + decision: 'block', + reason: 'blocked by hook', + }); + expect(output.getBlockingError()).toEqual({ + blocked: true, + reason: 'blocked by hook', + }); + }); + + it('getBlockingError should return blocked: false if not blocking decision', () => { + const output = new DefaultHookOutput({ decision: 'approve' }); + expect(output.getBlockingError()).toEqual({ blocked: false, reason: '' }); + }); + }); + + describe('BeforeToolHookOutput', () => { + it('isBlockingDecision should use permissionDecision from hookSpecificOutput', () => { + const output1 = new BeforeToolHookOutput({ + hookSpecificOutput: { permissionDecision: 'block' }, + }); + expect(output1.isBlockingDecision()).toBe(true); + + const output2 = new BeforeToolHookOutput({ + hookSpecificOutput: { permissionDecision: 'approve' }, + }); + expect(output2.isBlockingDecision()).toBe(false); + }); + + it('getEffectiveReason should use permissionDecisionReason from hookSpecificOutput', () => { + const output1 = new BeforeToolHookOutput({ + hookSpecificOutput: { permissionDecisionReason: 'compat reason' }, + }); + expect(output1.getEffectiveReason()).toBe('compat reason'); + + const output2 = new BeforeToolHookOutput({ + reason: 'default reason', + hookSpecificOutput: { other: 'value' }, + }); + expect(output2.getEffectiveReason()).toBe('default reason'); + }); + }); + + describe('BeforeModelHookOutput', () => { + it('getSyntheticResponse should return synthetic response if llm_response is present', () => { + const mockResponse: LLMResponse = { candidates: [] }; + const output = new BeforeModelHookOutput({ + hookSpecificOutput: { llm_response: mockResponse }, + }); + expect(output.getSyntheticResponse()).toEqual(mockResponse); + expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith( + mockResponse, + ); + }); + + it('getSyntheticResponse should return undefined if llm_response is not present', () => { + const output = new BeforeModelHookOutput({}); + expect(output.getSyntheticResponse()).toBeUndefined(); + }); + + it('applyLLMRequestModifications should apply modifications if llm_request is present', () => { + const target: GenerateContentParameters = { + model: 'gemini-pro', + contents: [{ parts: [{ text: 'original' }] }], + }; + const mockRequest: Partial = { + messages: [{ role: 'user', content: 'modified' }], + }; + const output = new BeforeModelHookOutput({ + hookSpecificOutput: { llm_request: mockRequest }, + }); + const result = output.applyLLMRequestModifications(target); + expect(result).toEqual({ ...target, ...mockRequest }); + expect(defaultHookTranslator.fromHookLLMRequest).toHaveBeenCalledWith( + mockRequest, + target, + ); + }); + + it('applyLLMRequestModifications should return target unchanged if llm_request is not present', () => { + const target: GenerateContentParameters = { + model: 'gemini-pro', + contents: [], + }; + const output = new BeforeModelHookOutput({}); + expect(output.applyLLMRequestModifications(target)).toBe(target); + }); + }); + + describe('BeforeToolSelectionHookOutput', () => { + it('applyToolConfigModifications should apply modifications if toolConfig is present', () => { + const target = { tools: [{ functionDeclarations: [] }] }; + const mockToolConfig = { functionCallingConfig: { mode: 'ANY' } }; + const output = new BeforeToolSelectionHookOutput({ + hookSpecificOutput: { toolConfig: mockToolConfig }, + }); + const result = output.applyToolConfigModifications(target); + expect(result).toEqual({ ...target, toolConfig: mockToolConfig }); + expect(defaultHookTranslator.fromHookToolConfig).toHaveBeenCalledWith( + mockToolConfig, + ); + }); + + it('applyToolConfigModifications should return target unchanged if toolConfig is not present', () => { + const target = { toolConfig: {}, tools: [] }; + const output = new BeforeToolSelectionHookOutput({}); + expect(output.applyToolConfigModifications(target)).toBe(target); + }); + + it('applyToolConfigModifications should initialize tools array if not present', () => { + const target = {}; + const mockToolConfig = { functionCallingConfig: { mode: 'ANY' } }; + const output = new BeforeToolSelectionHookOutput({ + hookSpecificOutput: { toolConfig: mockToolConfig }, + }); + const result = output.applyToolConfigModifications(target); + expect(result).toEqual({ tools: [], toolConfig: mockToolConfig }); + }); + }); + + describe('AfterModelHookOutput', () => { + it('getModifiedResponse should return modified response if llm_response is present and has content', () => { + const mockResponse: LLMResponse = { + candidates: [{ content: { role: 'model', parts: ['modified'] } }], + }; + const output = new AfterModelHookOutput({ + hookSpecificOutput: { llm_response: mockResponse }, + }); + expect(output.getModifiedResponse()).toEqual(mockResponse); + expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith( + mockResponse, + ); + }); + + it('getModifiedResponse should return undefined if llm_response is present but no content', () => { + const mockResponse: LLMResponse = { + candidates: [{ content: { role: 'model', parts: [] } }], + }; + const output = new AfterModelHookOutput({ + hookSpecificOutput: { llm_response: mockResponse }, + }); + expect(output.getModifiedResponse()).toBeUndefined(); + }); + + it('getModifiedResponse should return undefined if llm_response is not present', () => { + const output = new AfterModelHookOutput({}); + expect(output.getModifiedResponse()).toBeUndefined(); + }); + + it('getModifiedResponse should return a synthetic stop response if shouldStopExecution is true', () => { + const output = new AfterModelHookOutput({ + continue: false, + stopReason: 'stopped by hook', + }); + const expectedResponse: LLMResponse = { + candidates: [ + { + content: { + role: 'model', + parts: ['stopped by hook'], + }, + finishReason: 'STOP', + }, + ], + }; + expect(output.getModifiedResponse()).toEqual(expectedResponse); + expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith( + expectedResponse, + ); + }); + + it('getModifiedResponse should return a synthetic stop response with default reason if shouldStopExecution is true and no stopReason', () => { + const output = new AfterModelHookOutput({ continue: false }); + const expectedResponse: LLMResponse = { + candidates: [ + { + content: { + role: 'model', + parts: ['No reason provided'], + }, + finishReason: 'STOP', + }, + ], + }; + expect(output.getModifiedResponse()).toEqual(expectedResponse); + expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith( + expectedResponse, + ); + }); + }); +}); diff --git a/packages/core/src/hooks/types.ts b/packages/core/src/hooks/types.ts index 5a30e2f8c8..4678b0d214 100644 --- a/packages/core/src/hooks/types.ts +++ b/packages/core/src/hooks/types.ts @@ -340,7 +340,7 @@ export class AfterModelHookOutput extends DefaultHookOutput { const hookResponse = this.hookSpecificOutput[ 'llm_response' ] as Partial; - if (hookResponse?.candidates?.[0]?.content) { + if (hookResponse?.candidates?.[0]?.content?.parts?.length) { // Convert hook format to SDK format return defaultHookTranslator.fromHookLLMResponse( hookResponse as LLMResponse, diff --git a/packages/core/src/ide/ide-client.test.ts b/packages/core/src/ide/ide-client.test.ts index 9375cd040f..1881d91122 100644 --- a/packages/core/src/ide/ide-client.test.ts +++ b/packages/core/src/ide/ide-client.test.ts @@ -647,6 +647,254 @@ describe('IdeClient', () => { }); }); + describe('resolveDiffFromCli', () => { + beforeEach(async () => { + // Ensure client is "connected" for these tests + const ideClient = await IdeClient.getInstance(); + // We need to set the client property on the instance for openDiff to work + (ideClient as unknown as { client: Client }).client = mockClient; + mockClient.request.mockResolvedValue({ + isError: false, + content: [], + }); + }); + + it("should resolve an open diff as 'accepted' and return the final content", async () => { + const ideClient = await IdeClient.getInstance(); + const closeDiffSpy = vi + .spyOn( + ideClient as unknown as { + closeDiff: () => Promise; + }, + 'closeDiff', + ) + .mockResolvedValue('final content from ide'); + + const diffPromise = ideClient.openDiff('/test.txt', 'new content'); + + // Yield to the event loop to allow the openDiff promise executor to run + await new Promise((resolve) => setImmediate(resolve)); + + await ideClient.resolveDiffFromCli('/test.txt', 'accepted'); + + const result = await diffPromise; + + expect(result).toEqual({ + status: 'accepted', + content: 'final content from ide', + }); + expect(closeDiffSpy).toHaveBeenCalledWith('/test.txt', { + suppressNotification: true, + }); + expect( + ( + ideClient as unknown as { diffResponses: Map } + ).diffResponses.has('/test.txt'), + ).toBe(false); + }); + + it("should resolve an open diff as 'rejected'", async () => { + const ideClient = await IdeClient.getInstance(); + const closeDiffSpy = vi + .spyOn( + ideClient as unknown as { + closeDiff: () => Promise; + }, + 'closeDiff', + ) + .mockResolvedValue(undefined); + + const diffPromise = ideClient.openDiff('/test.txt', 'new content'); + + // Yield to the event loop to allow the openDiff promise executor to run + await new Promise((resolve) => setImmediate(resolve)); + + await ideClient.resolveDiffFromCli('/test.txt', 'rejected'); + + const result = await diffPromise; + + expect(result).toEqual({ + status: 'rejected', + content: undefined, + }); + expect(closeDiffSpy).toHaveBeenCalledWith('/test.txt', { + suppressNotification: true, + }); + expect( + ( + ideClient as unknown as { diffResponses: Map } + ).diffResponses.has('/test.txt'), + ).toBe(false); + }); + + it('should do nothing if no diff is open for the given file path', async () => { + const ideClient = await IdeClient.getInstance(); + const closeDiffSpy = vi + .spyOn( + ideClient as unknown as { + closeDiff: () => Promise; + }, + 'closeDiff', + ) + .mockResolvedValue(undefined); + + // No call to openDiff, so no resolver will exist. + await ideClient.resolveDiffFromCli('/non-existent.txt', 'accepted'); + + expect(closeDiffSpy).toHaveBeenCalledWith('/non-existent.txt', { + suppressNotification: true, + }); + // No crash should occur, and nothing should be in the map. + expect( + ( + ideClient as unknown as { diffResponses: Map } + ).diffResponses.has('/non-existent.txt'), + ).toBe(false); + }); + }); + + describe('closeDiff', () => { + beforeEach(async () => { + const ideClient = await IdeClient.getInstance(); + (ideClient as unknown as { client: Client }).client = mockClient; + }); + + it('should return undefined if client is not connected', async () => { + const ideClient = await IdeClient.getInstance(); + (ideClient as unknown as { client: Client | undefined }).client = + undefined; + + const result = await ( + ideClient as unknown as { closeDiff: (f: string) => Promise } + ).closeDiff('/test.txt'); + expect(result).toBeUndefined(); + }); + + it('should call client.request with correct arguments', async () => { + const ideClient = await IdeClient.getInstance(); + // Return a valid, empty response as the return value is not under test here. + mockClient.request.mockResolvedValue({ isError: false, content: [] }); + + await ( + ideClient as unknown as { + closeDiff: ( + f: string, + o?: { suppressNotification?: boolean }, + ) => Promise; + } + ).closeDiff('/test.txt', { suppressNotification: true }); + + expect(mockClient.request).toHaveBeenCalledWith( + expect.objectContaining({ + params: { + name: 'closeDiff', + arguments: { + filePath: '/test.txt', + suppressNotification: true, + }, + }, + }), + expect.any(Object), // Schema + expect.any(Object), // Options + ); + }); + + it('should return content from a valid JSON response', async () => { + const ideClient = await IdeClient.getInstance(); + const response = { + isError: false, + content: [ + { type: 'text', text: JSON.stringify({ content: 'file content' }) }, + ], + }; + mockClient.request.mockResolvedValue(response); + + const result = await ( + ideClient as unknown as { closeDiff: (f: string) => Promise } + ).closeDiff('/test.txt'); + expect(result).toBe('file content'); + }); + + it('should return undefined for a valid JSON response with null content', async () => { + const ideClient = await IdeClient.getInstance(); + const response = { + isError: false, + content: [{ type: 'text', text: JSON.stringify({ content: null }) }], + }; + mockClient.request.mockResolvedValue(response); + + const result = await ( + ideClient as unknown as { closeDiff: (f: string) => Promise } + ).closeDiff('/test.txt'); + expect(result).toBeUndefined(); + }); + + it('should return undefined if response is not valid JSON', async () => { + const ideClient = await IdeClient.getInstance(); + const response = { + isError: false, + content: [{ type: 'text', text: 'not json' }], + }; + mockClient.request.mockResolvedValue(response); + + const result = await ( + ideClient as unknown as { closeDiff: (f: string) => Promise } + ).closeDiff('/test.txt'); + expect(result).toBeUndefined(); + }); + + it('should return undefined if request result has isError: true', async () => { + const ideClient = await IdeClient.getInstance(); + const response = { + isError: true, + content: [{ type: 'text', text: 'An error occurred' }], + }; + mockClient.request.mockResolvedValue(response); + + const result = await ( + ideClient as unknown as { closeDiff: (f: string) => Promise } + ).closeDiff('/test.txt'); + expect(result).toBeUndefined(); + }); + + it('should return undefined if client.request throws', async () => { + const ideClient = await IdeClient.getInstance(); + mockClient.request.mockRejectedValue(new Error('Request failed')); + + const result = await ( + ideClient as unknown as { closeDiff: (f: string) => Promise } + ).closeDiff('/test.txt'); + expect(result).toBeUndefined(); + }); + + it('should return undefined if response has no text part', async () => { + const ideClient = await IdeClient.getInstance(); + const response = { + isError: false, + content: [{ type: 'other' }], + }; + mockClient.request.mockResolvedValue(response); + + const result = await ( + ideClient as unknown as { closeDiff: (f: string) => Promise } + ).closeDiff('/test.txt'); + expect(result).toBeUndefined(); + }); + + it('should return undefined if response is falsy', async () => { + const ideClient = await IdeClient.getInstance(); + // Mocking with `null as any` to test the falsy path, as the mock + // function is strictly typed. + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockClient.request.mockResolvedValue(null as any); + + const result = await ( + ideClient as unknown as { closeDiff: (f: string) => Promise } + ).closeDiff('/test.txt'); + expect(result).toBeUndefined(); + }); + }); + describe('authentication', () => { it('should connect with an auth token if provided in the discovery file', async () => { const authToken = 'test-auth-token'; diff --git a/packages/core/src/mcp/oauth-provider.test.ts b/packages/core/src/mcp/oauth-provider.test.ts index e462cea7bd..46e9b0d732 100644 --- a/packages/core/src/mcp/oauth-provider.test.ts +++ b/packages/core/src/mcp/oauth-provider.test.ts @@ -714,6 +714,269 @@ describe('MCPOAuthProvider', () => { ).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant'); }); + it('should handle OAuth discovery failure', async () => { + const configWithoutAuth: MCPOAuthConfig = { ...mockConfig }; + delete configWithoutAuth.authorizationUrl; + delete configWithoutAuth.tokenUrl; + + mockFetch.mockResolvedValueOnce( + createMockResponse({ + ok: false, + status: 404, + }), + ); + + const authProvider = new MCPOAuthProvider(); + await expect( + authProvider.authenticate( + 'test-server', + configWithoutAuth, + 'https://api.example.com', + ), + ).rejects.toThrow( + 'Failed to discover OAuth configuration from MCP server', + ); + }); + + it('should handle authorization server metadata discovery failure', async () => { + const configWithoutClient: MCPOAuthConfig = { ...mockConfig }; + delete configWithoutClient.clientId; + + mockFetch.mockResolvedValue( + createMockResponse({ + ok: false, + status: 404, + }), + ); + + // Prevent callback server from hanging the test + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + }); + + const authProvider = new MCPOAuthProvider(); + await expect( + authProvider.authenticate('test-server', configWithoutClient), + ).rejects.toThrow( + 'Failed to fetch authorization server metadata for client registration', + ); + }); + + it('should handle invalid callback request', async () => { + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/invalid-path', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 0); + }); + + const authProvider = new MCPOAuthProvider(); + // The test will timeout if the server does not handle the invalid request correctly. + // We are testing that the server does not hang. + await Promise.race([ + authProvider.authenticate('test-server', mockConfig), + new Promise((resolve) => setTimeout(resolve, 1000)), + ]); + }); + + it('should handle token exchange failure with non-json response', async () => { + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + mockFetch.mockResolvedValueOnce( + createMockResponse({ + ok: false, + status: 500, + contentType: 'text/html', + text: 'Internal Server Error', + }), + ); + + const authProvider = new MCPOAuthProvider(); + await expect( + authProvider.authenticate('test-server', mockConfig), + ).rejects.toThrow('Token exchange failed: 500 - Internal Server Error'); + }); + + it('should handle token exchange with unexpected content type', async () => { + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + mockFetch.mockResolvedValueOnce( + createMockResponse({ + ok: true, + contentType: 'text/plain', + text: 'access_token=plain_text_token', + }), + ); + + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.authenticate('test-server', mockConfig); + expect(result.accessToken).toBe('plain_text_token'); + }); + + it('should handle refresh token failure with non-json response', async () => { + mockFetch.mockResolvedValueOnce( + createMockResponse({ + ok: false, + status: 500, + contentType: 'text/html', + text: 'Internal Server Error', + }), + ); + + const authProvider = new MCPOAuthProvider(); + await expect( + authProvider.refreshAccessToken( + mockConfig, + 'invalid_refresh_token', + 'https://auth.example.com/token', + ), + ).rejects.toThrow('Token refresh failed: 500 - Internal Server Error'); + }); + + it('should handle refresh token with unexpected content type', async () => { + mockFetch.mockResolvedValueOnce( + createMockResponse({ + ok: true, + contentType: 'text/plain', + text: 'access_token=plain_text_token', + }), + ); + + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.refreshAccessToken( + mockConfig, + 'refresh_token', + 'https://auth.example.com/token', + ); + expect(result.access_token).toBe('plain_text_token'); + }); + + it('should continue authentication when browser fails to open', async () => { + mockOpenBrowserSecurely.mockRejectedValue(new Error('Browser not found')); + + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + mockFetch.mockResolvedValueOnce( + createMockResponse({ + ok: true, + contentType: 'application/json', + text: JSON.stringify(mockTokenResponse), + json: mockTokenResponse, + }), + ); + + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.authenticate('test-server', mockConfig); + expect(result).toBeDefined(); + }); + + it('should return null when token is expired and no refresh token is available', async () => { + const expiredCredentials = { + serverName: 'test-server', + token: { + ...mockToken, + refreshToken: undefined, + expiresAt: Date.now() - 3600000, + }, + clientId: 'test-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + + const tokenStorage = new MCPOAuthTokenStorage(); + vi.mocked(tokenStorage.getCredentials).mockResolvedValue( + expiredCredentials, + ); + vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true); + + const authProvider = new MCPOAuthProvider(); + const result = await authProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBeNull(); + }); + it('should handle callback timeout', async () => { vi.mocked(http.createServer).mockImplementation( () => mockHttpServer as unknown as http.Server, diff --git a/packages/core/src/prompts/mcp-prompts.test.ts b/packages/core/src/prompts/mcp-prompts.test.ts new file mode 100644 index 0000000000..dfbfea859d --- /dev/null +++ b/packages/core/src/prompts/mcp-prompts.test.ts @@ -0,0 +1,50 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi } from 'vitest'; +import { getMCPServerPrompts } from './mcp-prompts.js'; +import type { Config } from '../config/config.js'; +import { PromptRegistry } from './prompt-registry.js'; +import type { DiscoveredMCPPrompt } from '../tools/mcp-client.js'; + +describe('getMCPServerPrompts', () => { + it('should return prompts from the registry for a given server', () => { + const mockPrompts: DiscoveredMCPPrompt[] = [ + { + name: 'prompt1', + serverName: 'server1', + tool: { name: 'p1', description: '', inputSchema: {} }, + invoke: async () => ({ + messages: [ + { role: 'assistant', content: { type: 'text', text: '' } }, + ], + }), + }, + ]; + + const mockRegistry = new PromptRegistry(); + vi.spyOn(mockRegistry, 'getPromptsByServer').mockReturnValue(mockPrompts); + + const mockConfig = { + getPromptRegistry: () => mockRegistry, + } as unknown as Config; + + const result = getMCPServerPrompts(mockConfig, 'server1'); + + expect(mockRegistry.getPromptsByServer).toHaveBeenCalledWith('server1'); + expect(result).toEqual(mockPrompts); + }); + + it('should return an empty array if there is no prompt registry', () => { + const mockConfig = { + getPromptRegistry: () => undefined, + } as unknown as Config; + + const result = getMCPServerPrompts(mockConfig, 'server1'); + + expect(result).toEqual([]); + }); +}); diff --git a/packages/core/src/prompts/prompt-registry.test.ts b/packages/core/src/prompts/prompt-registry.test.ts new file mode 100644 index 0000000000..35d27884b3 --- /dev/null +++ b/packages/core/src/prompts/prompt-registry.test.ts @@ -0,0 +1,129 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { PromptRegistry } from './prompt-registry.js'; +import type { DiscoveredMCPPrompt } from '../tools/mcp-client.js'; +import { debugLogger } from '../utils/debugLogger.js'; + +vi.mock('../utils/debugLogger.js', () => ({ + debugLogger: { + warn: vi.fn(), + }, +})); + +describe('PromptRegistry', () => { + let registry: PromptRegistry; + + const prompt1: DiscoveredMCPPrompt = { + name: 'prompt1', + serverName: 'server1', + tool: { + name: 'prompt1', + description: 'Prompt 1', + inputSchema: {}, + }, + invoke: async () => ({ + messages: [ + { role: 'assistant', content: { type: 'text', text: 'response1' } }, + ], + }), + }; + + const prompt2: DiscoveredMCPPrompt = { + name: 'prompt2', + serverName: 'server1', + tool: { + name: 'prompt2', + description: 'Prompt 2', + inputSchema: {}, + }, + invoke: async () => ({ + messages: [ + { role: 'assistant', content: { type: 'text', text: 'response2' } }, + ], + }), + }; + + const prompt3: DiscoveredMCPPrompt = { + name: 'prompt1', + serverName: 'server2', + tool: { + name: 'prompt1', + description: 'Prompt 3', + inputSchema: {}, + }, + invoke: async () => ({ + messages: [ + { role: 'assistant', content: { type: 'text', text: 'response3' } }, + ], + }), + }; + + beforeEach(() => { + registry = new PromptRegistry(); + vi.clearAllMocks(); + }); + + it('should register a prompt', () => { + registry.registerPrompt(prompt1); + expect(registry.getPrompt('prompt1')).toEqual(prompt1); + }); + + it('should get all prompts, sorted by name', () => { + registry.registerPrompt(prompt2); + registry.registerPrompt(prompt1); + expect(registry.getAllPrompts()).toEqual([prompt1, prompt2]); + }); + + it('should get a specific prompt by name', () => { + registry.registerPrompt(prompt1); + expect(registry.getPrompt('prompt1')).toEqual(prompt1); + expect(registry.getPrompt('non-existent')).toBeUndefined(); + }); + + it('should get prompts by server, sorted by name', () => { + registry.registerPrompt(prompt1); + registry.registerPrompt(prompt2); + registry.registerPrompt(prompt3); // different server + expect(registry.getPromptsByServer('server1')).toEqual([prompt1, prompt2]); + expect(registry.getPromptsByServer('server2')).toEqual([ + { ...prompt3, name: 'server2_prompt1' }, + ]); + }); + + it('should handle prompt name collision by renaming', () => { + registry.registerPrompt(prompt1); + registry.registerPrompt(prompt3); + + expect(registry.getPrompt('prompt1')).toEqual(prompt1); + const renamedPrompt = { ...prompt3, name: 'server2_prompt1' }; + expect(registry.getPrompt('server2_prompt1')).toEqual(renamedPrompt); + expect(debugLogger.warn).toHaveBeenCalledWith( + 'Prompt with name "prompt1" is already registered. Renaming to "server2_prompt1".', + ); + }); + + it('should clear all prompts', () => { + registry.registerPrompt(prompt1); + registry.registerPrompt(prompt2); + registry.clear(); + expect(registry.getAllPrompts()).toEqual([]); + }); + + it('should remove prompts by server', () => { + registry.registerPrompt(prompt1); + registry.registerPrompt(prompt2); + registry.registerPrompt(prompt3); + registry.removePromptsByServer('server1'); + + const renamedPrompt = { ...prompt3, name: 'server2_prompt1' }; + expect(registry.getAllPrompts()).toEqual([renamedPrompt]); + expect(registry.getPrompt('prompt1')).toBeUndefined(); + expect(registry.getPrompt('prompt2')).toBeUndefined(); + expect(registry.getPrompt('server2_prompt1')).toEqual(renamedPrompt); + }); +}); diff --git a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts index e754717b30..9eca914db5 100644 --- a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts +++ b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts @@ -394,6 +394,15 @@ describe('ClearcutLogger', () => { }, expected: 'devin', }, + { + name: 'unidentified', + env: { + GITHUB_SHA: undefined, + TERM_PROGRAM: undefined, + SURFACE: undefined, + }, + expected: 'SURFACE_NOT_SET', + }, ])( 'logs the current surface as $expected from $name', ({ env, expected }) => { @@ -943,6 +952,33 @@ describe('ClearcutLogger', () => { }); }); + describe('flushIfNeeded', () => { + it('should not flush if the interval has not passed', () => { + const { logger } = setup(); + const flushSpy = vi + // eslint-disable-next-line @typescript-eslint/no-explicit-any + .spyOn(logger!, 'flushToClearcut' as any) + .mockResolvedValue({ nextRequestWaitMs: 0 }); + + logger!.flushIfNeeded(); + expect(flushSpy).not.toHaveBeenCalled(); + }); + + it('should flush if the interval has passed', async () => { + const { logger } = setup(); + const flushSpy = vi + // eslint-disable-next-line @typescript-eslint/no-explicit-any + .spyOn(logger!, 'flushToClearcut' as any) + .mockResolvedValue({ nextRequestWaitMs: 0 }); + + // Advance time by more than the flush interval + await vi.advanceTimersByTimeAsync(1000 * 60 * 2); + + logger!.flushIfNeeded(); + expect(flushSpy).toHaveBeenCalled(); + }); + }); + describe('logWebFetchFallbackAttemptEvent', () => { it('logs an event with the proper name and reason', () => { const { logger } = setup(); diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 59d1ef7baf..5c741fb08d 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -556,7 +556,7 @@ export function hasCycleInSchema(schema: object): boolean { if ('$ref' in node && typeof node.$ref === 'string') { const ref = node.$ref; - if (ref === '#/' || pathRefs.has(ref)) { + if (ref === '#' || ref === '#/' || pathRefs.has(ref)) { // A ref to just '#/' is always a cycle. return true; // Cycle detected! }