mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-25 20:44:46 -07:00
Increase code coverage for core packages (#12872)
This commit is contained in:
@@ -0,0 +1,48 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { CodebaseInvestigatorAgent } from './codebase-investigator.js';
|
||||
import {
|
||||
GLOB_TOOL_NAME,
|
||||
GREP_TOOL_NAME,
|
||||
LS_TOOL_NAME,
|
||||
READ_FILE_TOOL_NAME,
|
||||
} from '../tools/tool-names.js';
|
||||
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
||||
|
||||
describe('CodebaseInvestigatorAgent', () => {
|
||||
it('should have the correct agent definition', () => {
|
||||
expect(CodebaseInvestigatorAgent.name).toBe('codebase_investigator');
|
||||
expect(CodebaseInvestigatorAgent.displayName).toBe(
|
||||
'Codebase Investigator Agent',
|
||||
);
|
||||
expect(CodebaseInvestigatorAgent.description).toBeDefined();
|
||||
expect(
|
||||
CodebaseInvestigatorAgent.inputConfig.inputs['objective'].required,
|
||||
).toBe(true);
|
||||
expect(CodebaseInvestigatorAgent.outputConfig?.outputName).toBe('report');
|
||||
expect(CodebaseInvestigatorAgent.modelConfig?.model).toBe(
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
);
|
||||
expect(CodebaseInvestigatorAgent.toolConfig?.tools).toEqual([
|
||||
LS_TOOL_NAME,
|
||||
READ_FILE_TOOL_NAME,
|
||||
GLOB_TOOL_NAME,
|
||||
GREP_TOOL_NAME,
|
||||
]);
|
||||
});
|
||||
|
||||
it('should process output to a formatted JSON string', () => {
|
||||
const report = {
|
||||
SummaryOfFindings: 'summary',
|
||||
ExplorationTrace: ['trace'],
|
||||
RelevantLocations: [],
|
||||
};
|
||||
const processed = CodebaseInvestigatorAgent.processOutput?.(report);
|
||||
expect(processed).toBe(JSON.stringify(report, null, 2));
|
||||
});
|
||||
});
|
||||
@@ -329,6 +329,47 @@ describe('AgentExecutor', () => {
|
||||
new RegExp(`^${parentId}-${definition.name}-`),
|
||||
);
|
||||
});
|
||||
|
||||
it('should correctly apply templates to initialMessages', async () => {
|
||||
const definition = createTestDefinition();
|
||||
// Override promptConfig to use initialMessages instead of systemPrompt
|
||||
definition.promptConfig = {
|
||||
initialMessages: [
|
||||
{ role: 'user', parts: [{ text: 'Goal: ${goal}' }] },
|
||||
{ role: 'model', parts: [{ text: 'OK, starting on ${goal}.' }] },
|
||||
],
|
||||
};
|
||||
const inputs = { goal: 'TestGoal' };
|
||||
|
||||
// Mock a response to prevent the loop from running forever
|
||||
mockModelResponse([
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'done' },
|
||||
id: 'call1',
|
||||
},
|
||||
]);
|
||||
|
||||
const executor = await AgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
await executor.run(inputs, signal);
|
||||
|
||||
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
|
||||
const startHistory = chatConstructorArgs[2]; // history is the 3rd arg
|
||||
|
||||
expect(startHistory).toBeDefined();
|
||||
expect(startHistory).toHaveLength(2);
|
||||
|
||||
// Perform checks on defined objects to satisfy TS
|
||||
const firstPart = startHistory?.[0]?.parts?.[0];
|
||||
expect(firstPart?.text).toBe('Goal: TestGoal');
|
||||
|
||||
const secondPart = startHistory?.[1]?.parts?.[0];
|
||||
expect(secondPart?.text).toBe('OK, starting on TestGoal.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('run (Execution Loop and Logic)', () => {
|
||||
@@ -420,9 +461,16 @@ describe('AgentExecutor', () => {
|
||||
|
||||
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
|
||||
const chatConfig = chatConstructorArgs[1];
|
||||
expect(chatConfig?.systemInstruction).toContain(
|
||||
const systemInstruction = chatConfig?.systemInstruction as string;
|
||||
|
||||
expect(systemInstruction).toContain(
|
||||
`MUST call the \`${TASK_COMPLETE_TOOL_NAME}\` tool`,
|
||||
);
|
||||
expect(systemInstruction).toContain('Mocked Environment Context');
|
||||
expect(systemInstruction).toContain(
|
||||
'You are running in a non-interactive mode',
|
||||
);
|
||||
expect(systemInstruction).toContain('Always use absolute paths');
|
||||
|
||||
const turn1Params = getMockMessageParams(0);
|
||||
|
||||
@@ -921,6 +969,203 @@ describe('AgentExecutor', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge Cases and Error Handling', () => {
|
||||
it('should report an error if complete_task output fails schema validation', async () => {
|
||||
const definition = createTestDefinition(
|
||||
[],
|
||||
{},
|
||||
'default',
|
||||
z.string().min(10), // The schema is for the output value itself
|
||||
);
|
||||
const executor = await AgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
|
||||
// Turn 1: Invalid arg (too short)
|
||||
mockModelResponse([
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'short' },
|
||||
id: 'call1',
|
||||
},
|
||||
]);
|
||||
|
||||
// Turn 2: Corrected
|
||||
mockModelResponse([
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'This is a much longer and valid result' },
|
||||
id: 'call2',
|
||||
},
|
||||
]);
|
||||
|
||||
const output = await executor.run({ goal: 'Validation test' }, signal);
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
|
||||
const expectedError =
|
||||
'Output validation failed: {"formErrors":["String must contain at least 10 character(s)"],"fieldErrors":{}}';
|
||||
|
||||
// Check that the error was reported in the activity stream
|
||||
expect(activities).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'ERROR',
|
||||
data: {
|
||||
context: 'tool_call',
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
error: expect.stringContaining('Output validation failed'),
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
// Check that the error was sent back to the model for the next turn
|
||||
const turn2Params = getMockMessageParams(1);
|
||||
const turn2Parts = turn2Params.message;
|
||||
expect(turn2Parts).toEqual([
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { error: expectedError },
|
||||
id: 'call1',
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
|
||||
// Check that the agent eventually succeeded
|
||||
expect(output.result).toContain('This is a much longer and valid result');
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
|
||||
});
|
||||
|
||||
it('should throw and log if GeminiChat creation fails', async () => {
|
||||
const definition = createTestDefinition();
|
||||
const initError = new Error('Chat creation failed');
|
||||
MockedGeminiChat.mockImplementationOnce(() => {
|
||||
throw initError;
|
||||
});
|
||||
|
||||
// We expect the error to be thrown during the run, not creation
|
||||
const executor = await AgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
|
||||
await expect(executor.run({ goal: 'test' }, signal)).rejects.toThrow(
|
||||
`Failed to create chat object: ${initError}`,
|
||||
);
|
||||
|
||||
// Ensure the error was reported via the activity callback
|
||||
expect(activities).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'ERROR',
|
||||
data: expect.objectContaining({
|
||||
error: `Error: Failed to create chat object: ${initError}`,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
// Ensure the agent run was logged as a failure
|
||||
expect(mockedLogAgentFinish).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
expect.objectContaining({
|
||||
terminate_reason: AgentTerminateMode.ERROR,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle a failed tool call and feed the error to the model', async () => {
|
||||
const definition = createTestDefinition([LS_TOOL_NAME]);
|
||||
const executor = await AgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
const toolErrorMessage = 'Tool failed spectacularly';
|
||||
|
||||
// Turn 1: Model calls a tool that will fail
|
||||
mockModelResponse([
|
||||
{ name: LS_TOOL_NAME, args: { path: '/fake' }, id: 'call1' },
|
||||
]);
|
||||
mockExecuteToolCall.mockResolvedValueOnce({
|
||||
status: 'error',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '/fake' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: '',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: { error: toolErrorMessage },
|
||||
id: 'call1',
|
||||
},
|
||||
},
|
||||
],
|
||||
error: {
|
||||
type: 'ToolError',
|
||||
message: toolErrorMessage,
|
||||
},
|
||||
errorType: 'ToolError',
|
||||
contentLength: 0,
|
||||
},
|
||||
});
|
||||
|
||||
// Turn 2: Model sees the error and completes
|
||||
mockModelResponse([
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'Aborted due to tool failure.' },
|
||||
id: 'call2',
|
||||
},
|
||||
]);
|
||||
|
||||
const output = await executor.run({ goal: 'Tool failure test' }, signal);
|
||||
|
||||
expect(mockExecuteToolCall).toHaveBeenCalledTimes(1);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Verify the error was reported in the activity stream
|
||||
expect(activities).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: 'ERROR',
|
||||
data: {
|
||||
context: 'tool_call',
|
||||
name: LS_TOOL_NAME,
|
||||
error: toolErrorMessage,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
// Verify the error was sent back to the model
|
||||
const turn2Params = getMockMessageParams(1);
|
||||
const parts = turn2Params.message;
|
||||
expect(parts).toEqual([
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
name: LS_TOOL_NAME,
|
||||
id: 'call1',
|
||||
response: {
|
||||
error: toolErrorMessage,
|
||||
},
|
||||
}),
|
||||
}),
|
||||
]);
|
||||
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
|
||||
expect(output.result).toBe('Aborted due to tool failure.');
|
||||
});
|
||||
});
|
||||
|
||||
describe('run (Termination Conditions)', () => {
|
||||
const mockWorkResponse = (id: string) => {
|
||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { getOauthClient } from './oauth2.js';
|
||||
import { setupUser } from './setup.js';
|
||||
import { CodeAssistServer } from './server.js';
|
||||
import {
|
||||
createCodeAssistContentGenerator,
|
||||
getCodeAssistServer,
|
||||
} from './codeAssist.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
|
||||
import { UserTierId } from './types.js';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('./oauth2.js');
|
||||
vi.mock('./setup.js');
|
||||
vi.mock('./server.js');
|
||||
vi.mock('../core/loggingContentGenerator.js');
|
||||
|
||||
const mockedGetOauthClient = vi.mocked(getOauthClient);
|
||||
const mockedSetupUser = vi.mocked(setupUser);
|
||||
const MockedCodeAssistServer = vi.mocked(CodeAssistServer);
|
||||
const MockedLoggingContentGenerator = vi.mocked(LoggingContentGenerator);
|
||||
|
||||
describe('codeAssist', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
describe('createCodeAssistContentGenerator', () => {
|
||||
const httpOptions = {};
|
||||
const mockConfig = {} as Config;
|
||||
const mockAuthClient = { a: 'client' };
|
||||
const mockUserData = {
|
||||
projectId: 'test-project',
|
||||
userTier: UserTierId.FREE,
|
||||
};
|
||||
|
||||
it('should create a server for LOGIN_WITH_GOOGLE', async () => {
|
||||
mockedGetOauthClient.mockResolvedValue(mockAuthClient as never);
|
||||
mockedSetupUser.mockResolvedValue(mockUserData);
|
||||
|
||||
const generator = await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
mockConfig,
|
||||
'session-123',
|
||||
);
|
||||
|
||||
expect(getOauthClient).toHaveBeenCalledWith(
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
mockConfig,
|
||||
);
|
||||
expect(setupUser).toHaveBeenCalledWith(mockAuthClient);
|
||||
expect(MockedCodeAssistServer).toHaveBeenCalledWith(
|
||||
mockAuthClient,
|
||||
'test-project',
|
||||
httpOptions,
|
||||
'session-123',
|
||||
'free-tier',
|
||||
);
|
||||
expect(generator).toBeInstanceOf(MockedCodeAssistServer);
|
||||
});
|
||||
|
||||
it('should create a server for CLOUD_SHELL', async () => {
|
||||
mockedGetOauthClient.mockResolvedValue(mockAuthClient as never);
|
||||
mockedSetupUser.mockResolvedValue(mockUserData);
|
||||
|
||||
const generator = await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
AuthType.CLOUD_SHELL,
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(getOauthClient).toHaveBeenCalledWith(
|
||||
AuthType.CLOUD_SHELL,
|
||||
mockConfig,
|
||||
);
|
||||
expect(setupUser).toHaveBeenCalledWith(mockAuthClient);
|
||||
expect(MockedCodeAssistServer).toHaveBeenCalledWith(
|
||||
mockAuthClient,
|
||||
'test-project',
|
||||
httpOptions,
|
||||
undefined, // No session ID
|
||||
'free-tier',
|
||||
);
|
||||
expect(generator).toBeInstanceOf(MockedCodeAssistServer);
|
||||
});
|
||||
|
||||
it('should throw an error for unsupported auth types', async () => {
|
||||
await expect(
|
||||
createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
'api-key' as AuthType, // Use literal string to avoid enum resolution issues
|
||||
mockConfig,
|
||||
),
|
||||
).rejects.toThrow('Unsupported authType: api-key');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCodeAssistServer', () => {
|
||||
it('should return the server if it is a CodeAssistServer', () => {
|
||||
const mockServer = new MockedCodeAssistServer({} as never, '', {});
|
||||
const mockConfig = {
|
||||
getContentGenerator: () => mockServer,
|
||||
} as unknown as Config;
|
||||
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBe(mockServer);
|
||||
});
|
||||
|
||||
it('should unwrap and return the server if it is wrapped in a LoggingContentGenerator', () => {
|
||||
const mockServer = new MockedCodeAssistServer({} as never, '', {});
|
||||
const mockLogger = new MockedLoggingContentGenerator(
|
||||
{} as never,
|
||||
{} as never,
|
||||
);
|
||||
vi.spyOn(mockLogger, 'getWrapped').mockReturnValue(mockServer);
|
||||
|
||||
const mockConfig = {
|
||||
getContentGenerator: () => mockLogger,
|
||||
} as unknown as Config;
|
||||
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBe(mockServer);
|
||||
expect(mockLogger.getWrapped).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return undefined if the content generator is not a CodeAssistServer', () => {
|
||||
const mockGenerator = { a: 'generator' }; // Not a CodeAssistServer
|
||||
const mockConfig = {
|
||||
getContentGenerator: () => mockGenerator,
|
||||
} as unknown as Config;
|
||||
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined if the wrapped generator is not a CodeAssistServer', () => {
|
||||
const mockGenerator = { a: 'generator' }; // Not a CodeAssistServer
|
||||
const mockLogger = new MockedLoggingContentGenerator(
|
||||
{} as never,
|
||||
{} as never,
|
||||
);
|
||||
vi.spyOn(mockLogger, 'getWrapped').mockReturnValue(
|
||||
mockGenerator as never,
|
||||
);
|
||||
|
||||
const mockConfig = {
|
||||
getContentGenerator: () => mockLogger,
|
||||
} as unknown as Config;
|
||||
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,112 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { ReleaseChannel, getReleaseChannel } from '../../utils/channel.js';
|
||||
|
||||
// Mock dependencies before importing the module under test
|
||||
vi.mock('../../utils/channel.js', async () => {
|
||||
const actual = await vi.importActual('../../utils/channel.js');
|
||||
return {
|
||||
...(actual as object),
|
||||
getReleaseChannel: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('client_metadata', () => {
|
||||
const originalPlatform = process.platform;
|
||||
const originalArch = process.arch;
|
||||
const originalCliVersion = process.env['CLI_VERSION'];
|
||||
const originalNodeVersion = process.version;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Reset modules to clear the cached `clientMetadataPromise`
|
||||
vi.resetModules();
|
||||
// Re-import the module to get a fresh instance
|
||||
await import('./client_metadata.js');
|
||||
// Provide a default mock implementation for each test
|
||||
vi.mocked(getReleaseChannel).mockResolvedValue(ReleaseChannel.STABLE);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Restore original process properties to avoid side-effects between tests
|
||||
Object.defineProperty(process, 'platform', { value: originalPlatform });
|
||||
Object.defineProperty(process, 'arch', { value: originalArch });
|
||||
process.env['CLI_VERSION'] = originalCliVersion;
|
||||
Object.defineProperty(process, 'version', { value: originalNodeVersion });
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('getPlatform', () => {
|
||||
const testCases = [
|
||||
{ platform: 'darwin', arch: 'x64', expected: 'DARWIN_AMD64' },
|
||||
{ platform: 'darwin', arch: 'arm64', expected: 'DARWIN_ARM64' },
|
||||
{ platform: 'linux', arch: 'x64', expected: 'LINUX_AMD64' },
|
||||
{ platform: 'linux', arch: 'arm64', expected: 'LINUX_ARM64' },
|
||||
{ platform: 'win32', arch: 'x64', expected: 'WINDOWS_AMD64' },
|
||||
{ platform: 'sunos', arch: 'x64', expected: 'PLATFORM_UNSPECIFIED' },
|
||||
{ platform: 'win32', arch: 'arm', expected: 'PLATFORM_UNSPECIFIED' },
|
||||
];
|
||||
|
||||
for (const { platform, arch, expected } of testCases) {
|
||||
it(`should return ${expected} for platform ${platform} and arch ${arch}`, async () => {
|
||||
Object.defineProperty(process, 'platform', { value: platform });
|
||||
Object.defineProperty(process, 'arch', { value: arch });
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
|
||||
const metadata = await getClientMetadata();
|
||||
expect(metadata.platform).toBe(expected);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe('getClientMetadata', () => {
|
||||
it('should use CLI_VERSION for ideVersion if set', async () => {
|
||||
process.env['CLI_VERSION'] = '1.2.3';
|
||||
Object.defineProperty(process, 'version', { value: 'v18.0.0' });
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
|
||||
const metadata = await getClientMetadata();
|
||||
expect(metadata.ideVersion).toBe('1.2.3');
|
||||
});
|
||||
|
||||
it('should use process.version for ideVersion as a fallback', async () => {
|
||||
delete process.env['CLI_VERSION'];
|
||||
Object.defineProperty(process, 'version', { value: 'v20.0.0' });
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
|
||||
const metadata = await getClientMetadata();
|
||||
expect(metadata.ideVersion).toBe('v20.0.0');
|
||||
});
|
||||
|
||||
it('should call getReleaseChannel to get the update channel', async () => {
|
||||
vi.mocked(getReleaseChannel).mockResolvedValue(ReleaseChannel.NIGHTLY);
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
|
||||
const metadata = await getClientMetadata();
|
||||
|
||||
expect(metadata.updateChannel).toBe('nightly');
|
||||
expect(getReleaseChannel).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should cache the client metadata promise', async () => {
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
|
||||
const firstCall = await getClientMetadata();
|
||||
const secondCall = await getClientMetadata();
|
||||
|
||||
expect(firstCall).toBe(secondCall);
|
||||
// Ensure the underlying functions are only called once
|
||||
expect(getReleaseChannel).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should always return the IDE name as GEMINI_CLI', async () => {
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
const metadata = await getClientMetadata();
|
||||
expect(metadata.ideName).toBe('GEMINI_CLI');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,115 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import type { CodeAssistServer } from '../server.js';
|
||||
import { getClientMetadata } from './client_metadata.js';
|
||||
import type { ListExperimentsResponse, Flag } from './types.js';
|
||||
|
||||
// Mock dependencies before importing the module under test
|
||||
vi.mock('../server.js');
|
||||
vi.mock('./client_metadata.js');
|
||||
|
||||
describe('experiments', () => {
|
||||
let mockServer: CodeAssistServer;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset modules to clear the cached `experimentsPromise`
|
||||
vi.resetModules();
|
||||
|
||||
// Mock the dependencies that `getExperiments` relies on
|
||||
vi.mocked(getClientMetadata).mockResolvedValue({
|
||||
ideName: 'GEMINI_CLI',
|
||||
ideVersion: '1.0.0',
|
||||
platform: 'LINUX_AMD64',
|
||||
updateChannel: 'stable',
|
||||
});
|
||||
|
||||
// Create a mock instance of the server for each test
|
||||
mockServer = {
|
||||
listExperiments: vi.fn(),
|
||||
} as unknown as CodeAssistServer;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should fetch and parse experiments from the server', async () => {
|
||||
const { getExperiments } = await import('./experiments.js');
|
||||
const mockApiResponse: ListExperimentsResponse = {
|
||||
flags: [
|
||||
{ name: 'flag1', boolValue: true },
|
||||
{ name: 'flag2', stringValue: 'value' },
|
||||
],
|
||||
experimentIds: [123, 456],
|
||||
};
|
||||
vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const experiments = await getExperiments(mockServer);
|
||||
|
||||
// Verify that the dependencies were called
|
||||
expect(getClientMetadata).toHaveBeenCalled();
|
||||
expect(mockServer.listExperiments).toHaveBeenCalledWith(
|
||||
await getClientMetadata(),
|
||||
);
|
||||
|
||||
// Verify that the response was parsed correctly
|
||||
expect(experiments.flags['flag1']).toEqual({
|
||||
name: 'flag1',
|
||||
boolValue: true,
|
||||
});
|
||||
expect(experiments.flags['flag2']).toEqual({
|
||||
name: 'flag2',
|
||||
stringValue: 'value',
|
||||
});
|
||||
expect(experiments.experimentIds).toEqual([123, 456]);
|
||||
});
|
||||
|
||||
it('should handle an empty or partial response from the server', async () => {
|
||||
const { getExperiments } = await import('./experiments.js');
|
||||
const mockApiResponse: ListExperimentsResponse = {}; // No flags or experimentIds
|
||||
vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const experiments = await getExperiments(mockServer);
|
||||
|
||||
expect(experiments.flags).toEqual({});
|
||||
expect(experiments.experimentIds).toEqual([]);
|
||||
});
|
||||
|
||||
it('should ignore flags that are missing a name', async () => {
|
||||
const { getExperiments } = await import('./experiments.js');
|
||||
const mockApiResponse: ListExperimentsResponse = {
|
||||
flags: [
|
||||
{ boolValue: true } as Flag, // No name
|
||||
{ name: 'flag2', stringValue: 'value' },
|
||||
],
|
||||
};
|
||||
vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const experiments = await getExperiments(mockServer);
|
||||
|
||||
expect(Object.keys(experiments.flags)).toHaveLength(1);
|
||||
expect(experiments.flags['flag2']).toBeDefined();
|
||||
expect(experiments.flags['undefined']).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should cache the experiments promise to avoid multiple fetches', async () => {
|
||||
const { getExperiments } = await import('./experiments.js');
|
||||
const mockApiResponse: ListExperimentsResponse = {
|
||||
experimentIds: [1, 2, 3],
|
||||
};
|
||||
vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const firstCall = await getExperiments(mockServer);
|
||||
const secondCall = await getExperiments(mockServer);
|
||||
|
||||
expect(firstCall).toBe(secondCall); // Should be the exact same promise object
|
||||
// Verify the underlying functions were only called once
|
||||
expect(getClientMetadata).toHaveBeenCalledTimes(1);
|
||||
expect(mockServer.listExperiments).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
@@ -173,6 +173,58 @@ describe('OAuthCredentialStorage', () => {
|
||||
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
|
||||
it('should throw an error if the migration file contains invalid JSON', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockResolvedValue('invalid json');
|
||||
|
||||
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
|
||||
'Failed to load OAuth credentials',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not delete the old file if saving migrated credentials fails', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockResolvedValue(
|
||||
JSON.stringify(mockCredentials),
|
||||
);
|
||||
vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockRejectedValue(
|
||||
new Error('Save failed'),
|
||||
);
|
||||
|
||||
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
|
||||
'Failed to load OAuth credentials',
|
||||
);
|
||||
|
||||
expect(fs.rm).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return credentials even if access_token is missing from storage', async () => {
|
||||
const partialMcpCredentials = {
|
||||
...mockMcpCredentials,
|
||||
token: {
|
||||
...mockMcpCredentials.token,
|
||||
accessToken: undefined,
|
||||
},
|
||||
};
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
partialMcpCredentials,
|
||||
);
|
||||
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(result).toEqual({
|
||||
access_token: undefined,
|
||||
refresh_token: mockCredentials.refresh_token,
|
||||
token_type: mockCredentials.token_type,
|
||||
scope: mockCredentials.scope,
|
||||
expiry_date: mockCredentials.expiry_date,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveCredentials', () => {
|
||||
@@ -195,6 +247,28 @@ describe('OAuthCredentialStorage', () => {
|
||||
'Attempted to save credentials without an access token.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle saving credentials with null or undefined optional fields', async () => {
|
||||
const partialCredentials: Credentials = {
|
||||
access_token: 'only_access_token',
|
||||
refresh_token: null, // test null
|
||||
scope: undefined, // test undefined
|
||||
};
|
||||
|
||||
await OAuthCredentialStorage.saveCredentials(partialCredentials);
|
||||
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith({
|
||||
serverName: 'main-account',
|
||||
token: {
|
||||
accessToken: 'only_access_token',
|
||||
refreshToken: undefined,
|
||||
tokenType: 'Bearer', // default
|
||||
scope: undefined,
|
||||
expiresAt: undefined,
|
||||
},
|
||||
updatedAt: expect.any(Number),
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearCredentials', () => {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { beforeEach, describe, it, expect, vi } from 'vitest';
|
||||
import { beforeEach, describe, it, expect, vi, afterEach } from 'vitest';
|
||||
import { CodeAssistServer } from './server.js';
|
||||
import { OAuth2Client } from 'google-auth-library';
|
||||
import { UserTierId } from './types.js';
|
||||
@@ -29,15 +29,16 @@ describe('CodeAssistServer', () => {
|
||||
});
|
||||
|
||||
it('should call the generateContent endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const mockRequest = vi.fn();
|
||||
const client = { request: mockRequest } as unknown as OAuth2Client;
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
{ headers: { 'x-custom-header': 'test-value' } },
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = {
|
||||
const mockResponseData = {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
@@ -52,7 +53,7 @@ describe('CodeAssistServer', () => {
|
||||
],
|
||||
},
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||
mockRequest.mockResolvedValue({ data: mockResponseData });
|
||||
|
||||
const response = await server.generateContent(
|
||||
{
|
||||
@@ -62,18 +63,59 @@ describe('CodeAssistServer', () => {
|
||||
'user-prompt-id',
|
||||
);
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'generateContent',
|
||||
expect.any(Object),
|
||||
undefined,
|
||||
);
|
||||
expect(mockRequest).toHaveBeenCalledWith({
|
||||
url: expect.stringContaining(':generateContent'),
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-custom-header': 'test-value',
|
||||
},
|
||||
responseType: 'json',
|
||||
body: expect.any(String),
|
||||
signal: undefined,
|
||||
});
|
||||
|
||||
const requestBody = JSON.parse(mockRequest.mock.calls[0][0].body);
|
||||
expect(requestBody.user_prompt_id).toBe('user-prompt-id');
|
||||
expect(requestBody.project).toBe('test-project');
|
||||
|
||||
expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
|
||||
'response',
|
||||
);
|
||||
});
|
||||
|
||||
it('should call the generateContentStream endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
describe('getMethodUrl', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset the environment variables to their original state
|
||||
process.env = { ...originalEnv };
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Restore the original environment variables
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
it('should construct the default URL correctly', () => {
|
||||
const server = new CodeAssistServer({} as never);
|
||||
const url = server.getMethodUrl('testMethod');
|
||||
expect(url).toBe(
|
||||
'https://cloudcode-pa.googleapis.com/v1internal:testMethod',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the CODE_ASSIST_ENDPOINT environment variable if set', () => {
|
||||
process.env['CODE_ASSIST_ENDPOINT'] = 'https://custom-endpoint.com';
|
||||
const server = new CodeAssistServer({} as never);
|
||||
const url = server.getMethodUrl('testMethod');
|
||||
expect(url).toBe('https://custom-endpoint.com/v1internal:testMethod');
|
||||
});
|
||||
});
|
||||
|
||||
it('should call the generateContentStream endpoint and parse SSE', async () => {
|
||||
const mockRequest = vi.fn();
|
||||
const client = { request: mockRequest } as unknown as OAuth2Client;
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
@@ -81,24 +123,21 @@ describe('CodeAssistServer', () => {
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = (async function* () {
|
||||
yield {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
index: 0,
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'response' }],
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
})();
|
||||
vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
|
||||
|
||||
// Create a mock readable stream
|
||||
const { Readable } = await import('node:stream');
|
||||
const mockStream = new Readable({
|
||||
read() {},
|
||||
});
|
||||
|
||||
const mockResponseData1 = {
|
||||
response: { candidates: [{ content: { parts: [{ text: 'Hello' }] } }] },
|
||||
};
|
||||
const mockResponseData2 = {
|
||||
response: { candidates: [{ content: { parts: [{ text: ' World' }] } }] },
|
||||
};
|
||||
|
||||
mockRequest.mockResolvedValue({ data: mockStream });
|
||||
|
||||
const stream = await server.generateContentStream(
|
||||
{
|
||||
@@ -108,14 +147,61 @@ describe('CodeAssistServer', () => {
|
||||
'user-prompt-id',
|
||||
);
|
||||
|
||||
// Push SSE data to the stream
|
||||
// Use setTimeout to ensure the stream processing has started
|
||||
setTimeout(() => {
|
||||
mockStream.push('data: ' + JSON.stringify(mockResponseData1) + '\n\n');
|
||||
mockStream.push('id: 123\n'); // Should be ignored
|
||||
mockStream.push('data: ' + JSON.stringify(mockResponseData2) + '\n\n');
|
||||
mockStream.push(null); // End the stream
|
||||
}, 0);
|
||||
|
||||
const results = [];
|
||||
for await (const res of stream) {
|
||||
expect(server.requestStreamingPost).toHaveBeenCalledWith(
|
||||
'streamGenerateContent',
|
||||
expect.any(Object),
|
||||
undefined,
|
||||
);
|
||||
expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response');
|
||||
results.push(res);
|
||||
}
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith({
|
||||
url: expect.stringContaining(':streamGenerateContent'),
|
||||
method: 'POST',
|
||||
params: { alt: 'sse' },
|
||||
responseType: 'stream',
|
||||
body: expect.any(String),
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
signal: undefined,
|
||||
});
|
||||
|
||||
expect(results).toHaveLength(2);
|
||||
expect(results[0].candidates?.[0].content?.parts?.[0].text).toBe('Hello');
|
||||
expect(results[1].candidates?.[0].content?.parts?.[0].text).toBe(' World');
|
||||
});
|
||||
|
||||
it('should ignore malformed SSE data', async () => {
|
||||
const mockRequest = vi.fn();
|
||||
const client = { request: mockRequest } as unknown as OAuth2Client;
|
||||
const server = new CodeAssistServer(client);
|
||||
|
||||
const { Readable } = await import('node:stream');
|
||||
const mockStream = new Readable({
|
||||
read() {},
|
||||
});
|
||||
|
||||
mockRequest.mockResolvedValue({ data: mockStream });
|
||||
|
||||
const stream = await server.requestStreamingPost('testStream', {});
|
||||
|
||||
setTimeout(() => {
|
||||
mockStream.push('this is a malformed line\n');
|
||||
mockStream.push(null);
|
||||
}, 0);
|
||||
|
||||
const results = [];
|
||||
for await (const res of stream) {
|
||||
results.push(res);
|
||||
}
|
||||
expect(results).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should call the onboardUser endpoint', async () => {
|
||||
@@ -253,6 +339,22 @@ describe('CodeAssistServer', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should re-throw non-VPC-SC errors from loadCodeAssist', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(client);
|
||||
const genericError = new Error('Something else went wrong');
|
||||
vi.spyOn(server, 'requestPost').mockRejectedValue(genericError);
|
||||
|
||||
await expect(server.loadCodeAssist({ metadata: {} })).rejects.toThrow(
|
||||
'Something else went wrong',
|
||||
);
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'loadCodeAssist',
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should call the listExperiments endpoint with metadata', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
|
||||
@@ -232,18 +232,16 @@ export class CodeAssistServer implements ContentGenerator {
|
||||
|
||||
let bufferedLines: string[] = [];
|
||||
for await (const line of rl) {
|
||||
// blank lines are used to separate JSON objects in the stream
|
||||
if (line === '') {
|
||||
if (line.startsWith('data: ')) {
|
||||
bufferedLines.push(line.slice(6).trim());
|
||||
} else if (line === '') {
|
||||
if (bufferedLines.length === 0) {
|
||||
continue; // no data to yield
|
||||
}
|
||||
yield JSON.parse(bufferedLines.join('\n')) as T;
|
||||
bufferedLines = []; // Reset the buffer after yielding
|
||||
} else if (line.startsWith('data: ')) {
|
||||
bufferedLines.push(line.slice(6).trim());
|
||||
} else {
|
||||
throw new Error(`Unexpected line format in response: ${line}`);
|
||||
}
|
||||
// Ignore other lines like comments or id fields
|
||||
}
|
||||
})();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,257 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
const logApiRequest = vi.hoisted(() => vi.fn());
|
||||
const logApiResponse = vi.hoisted(() => vi.fn());
|
||||
const logApiError = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('../telemetry/loggers.js', () => ({
|
||||
logApiRequest,
|
||||
logApiResponse,
|
||||
logApiError,
|
||||
}));
|
||||
|
||||
const runInDevTraceSpan = vi.hoisted(() =>
|
||||
vi.fn(async (meta, fn) => fn({ metadata: {}, endSpan: vi.fn() })),
|
||||
);
|
||||
|
||||
vi.mock('../telemetry/trace.js', () => ({
|
||||
runInDevTraceSpan,
|
||||
}));
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import type {
|
||||
GenerateContentResponse,
|
||||
EmbedContentResponse,
|
||||
} from '@google/genai';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { LoggingContentGenerator } from './loggingContentGenerator.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApiRequestEvent } from '../telemetry/types.js';
|
||||
|
||||
describe('LoggingContentGenerator', () => {
|
||||
let wrapped: ContentGenerator;
|
||||
let config: Config;
|
||||
let loggingContentGenerator: LoggingContentGenerator;
|
||||
|
||||
beforeEach(() => {
|
||||
wrapped = {
|
||||
generateContent: vi.fn(),
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
};
|
||||
config = {
|
||||
getGoogleAIConfig: vi.fn(),
|
||||
getVertexAIConfig: vi.fn(),
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({
|
||||
authType: 'API_KEY',
|
||||
}),
|
||||
} as unknown as Config;
|
||||
loggingContentGenerator = new LoggingContentGenerator(wrapped, config);
|
||||
vi.useFakeTimers();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
describe('generateContent', () => {
|
||||
it('should log request and response on success', async () => {
|
||||
const req = {
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||
model: 'gemini-pro',
|
||||
};
|
||||
const userPromptId = 'prompt-123';
|
||||
const response: GenerateContentResponse = {
|
||||
candidates: [],
|
||||
usageMetadata: {
|
||||
promptTokenCount: 1,
|
||||
candidatesTokenCount: 2,
|
||||
totalTokenCount: 3,
|
||||
},
|
||||
text: undefined,
|
||||
functionCalls: undefined,
|
||||
executableCode: undefined,
|
||||
codeExecutionResult: undefined,
|
||||
data: undefined,
|
||||
};
|
||||
vi.mocked(wrapped.generateContent).mockResolvedValue(response);
|
||||
const startTime = new Date('2025-01-01T00:00:00.000Z');
|
||||
vi.setSystemTime(startTime);
|
||||
|
||||
const promise = loggingContentGenerator.generateContent(
|
||||
req,
|
||||
userPromptId,
|
||||
);
|
||||
|
||||
vi.advanceTimersByTime(1000);
|
||||
|
||||
await promise;
|
||||
|
||||
expect(wrapped.generateContent).toHaveBeenCalledWith(req, userPromptId);
|
||||
expect(logApiRequest).toHaveBeenCalledWith(
|
||||
config,
|
||||
expect.any(ApiRequestEvent),
|
||||
);
|
||||
const responseEvent = vi.mocked(logApiResponse).mock.calls[0][1];
|
||||
expect(responseEvent.duration_ms).toBe(1000);
|
||||
});
|
||||
|
||||
it('should log error on failure', async () => {
|
||||
const req = {
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||
model: 'gemini-pro',
|
||||
};
|
||||
const userPromptId = 'prompt-123';
|
||||
const error = new Error('test error');
|
||||
vi.mocked(wrapped.generateContent).mockRejectedValue(error);
|
||||
const startTime = new Date('2025-01-01T00:00:00.000Z');
|
||||
vi.setSystemTime(startTime);
|
||||
|
||||
const promise = loggingContentGenerator.generateContent(
|
||||
req,
|
||||
userPromptId,
|
||||
);
|
||||
|
||||
vi.advanceTimersByTime(1000);
|
||||
|
||||
await expect(promise).rejects.toThrow(error);
|
||||
|
||||
expect(logApiRequest).toHaveBeenCalledWith(
|
||||
config,
|
||||
expect.any(ApiRequestEvent),
|
||||
);
|
||||
const errorEvent = vi.mocked(logApiError).mock.calls[0][1];
|
||||
expect(errorEvent.duration_ms).toBe(1000);
|
||||
});
|
||||
});
|
||||
|
||||
describe('generateContentStream', () => {
|
||||
it('should log request and response on success', async () => {
|
||||
const req = {
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||
model: 'gemini-pro',
|
||||
};
|
||||
const userPromptId = 'prompt-123';
|
||||
const response = {
|
||||
candidates: [],
|
||||
usageMetadata: {
|
||||
promptTokenCount: 1,
|
||||
candidatesTokenCount: 2,
|
||||
totalTokenCount: 3,
|
||||
},
|
||||
} as unknown as GenerateContentResponse;
|
||||
|
||||
async function* createAsyncGenerator() {
|
||||
yield response;
|
||||
}
|
||||
|
||||
vi.mocked(wrapped.generateContentStream).mockResolvedValue(
|
||||
createAsyncGenerator(),
|
||||
);
|
||||
const startTime = new Date('2025-01-01T00:00:00.000Z');
|
||||
vi.setSystemTime(startTime);
|
||||
|
||||
const stream = await loggingContentGenerator.generateContentStream(
|
||||
req,
|
||||
userPromptId,
|
||||
);
|
||||
|
||||
vi.advanceTimersByTime(1000);
|
||||
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
}
|
||||
|
||||
expect(wrapped.generateContentStream).toHaveBeenCalledWith(
|
||||
req,
|
||||
userPromptId,
|
||||
);
|
||||
expect(logApiRequest).toHaveBeenCalledWith(
|
||||
config,
|
||||
expect.any(ApiRequestEvent),
|
||||
);
|
||||
const responseEvent = vi.mocked(logApiResponse).mock.calls[0][1];
|
||||
expect(responseEvent.duration_ms).toBe(1000);
|
||||
});
|
||||
|
||||
it('should log error on failure', async () => {
|
||||
const req = {
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||
model: 'gemini-pro',
|
||||
};
|
||||
const userPromptId = 'prompt-123';
|
||||
const error = new Error('test error');
|
||||
|
||||
async function* createAsyncGenerator() {
|
||||
yield Promise.reject(error);
|
||||
}
|
||||
|
||||
vi.mocked(wrapped.generateContentStream).mockResolvedValue(
|
||||
createAsyncGenerator(),
|
||||
);
|
||||
const startTime = new Date('2025-01-01T00:00:00.000Z');
|
||||
vi.setSystemTime(startTime);
|
||||
|
||||
const stream = await loggingContentGenerator.generateContentStream(
|
||||
req,
|
||||
userPromptId,
|
||||
);
|
||||
|
||||
vi.advanceTimersByTime(1000);
|
||||
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// do nothing
|
||||
}
|
||||
}).rejects.toThrow(error);
|
||||
|
||||
expect(logApiRequest).toHaveBeenCalledWith(
|
||||
config,
|
||||
expect.any(ApiRequestEvent),
|
||||
);
|
||||
const errorEvent = vi.mocked(logApiError).mock.calls[0][1];
|
||||
expect(errorEvent.duration_ms).toBe(1000);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getWrapped', () => {
|
||||
it('should return the wrapped content generator', () => {
|
||||
expect(loggingContentGenerator.getWrapped()).toBe(wrapped);
|
||||
});
|
||||
});
|
||||
|
||||
describe('countTokens', () => {
|
||||
it('should call the wrapped countTokens method', async () => {
|
||||
const req = { contents: [], model: 'gemini-pro' };
|
||||
const response = { totalTokens: 10 };
|
||||
vi.mocked(wrapped.countTokens).mockResolvedValue(response);
|
||||
|
||||
const result = await loggingContentGenerator.countTokens(req);
|
||||
|
||||
expect(wrapped.countTokens).toHaveBeenCalledWith(req);
|
||||
expect(result).toBe(response);
|
||||
});
|
||||
});
|
||||
|
||||
describe('embedContent', () => {
|
||||
it('should call the wrapped embedContent method', async () => {
|
||||
const req = {
|
||||
contents: [{ role: 'user', parts: [] }],
|
||||
model: 'gemini-pro',
|
||||
};
|
||||
const response: EmbedContentResponse = { embeddings: [{ values: [] }] };
|
||||
vi.mocked(wrapped.embedContent).mockResolvedValue(response);
|
||||
|
||||
const result = await loggingContentGenerator.embedContent(req);
|
||||
|
||||
expect(wrapped.embedContent).toHaveBeenCalledWith(req);
|
||||
expect(result).toBe(response);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,31 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { tokenLimit, DEFAULT_TOKEN_LIMIT } from './tokenLimits.js';
|
||||
|
||||
describe('tokenLimit', () => {
|
||||
it('should return the correct token limit for gemini-1.5-pro', () => {
|
||||
expect(tokenLimit('gemini-1.5-pro')).toBe(2_097_152);
|
||||
});
|
||||
|
||||
it('should return the correct token limit for gemini-1.5-flash', () => {
|
||||
expect(tokenLimit('gemini-1.5-flash')).toBe(1_048_576);
|
||||
});
|
||||
|
||||
it('should return the default token limit for an unknown model', () => {
|
||||
expect(tokenLimit('unknown-model')).toBe(DEFAULT_TOKEN_LIMIT);
|
||||
});
|
||||
|
||||
it('should return the default token limit if no model is provided', () => {
|
||||
// @ts-expect-error testing invalid input
|
||||
expect(tokenLimit(undefined)).toBe(DEFAULT_TOKEN_LIMIT);
|
||||
});
|
||||
|
||||
it('should have the correct default token limit value', () => {
|
||||
expect(DEFAULT_TOKEN_LIMIT).toBe(1_048_576);
|
||||
});
|
||||
});
|
||||
@@ -4,8 +4,40 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { HookEventName, HookType } from './types.js';
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import {
|
||||
createHookOutput,
|
||||
DefaultHookOutput,
|
||||
BeforeToolHookOutput,
|
||||
BeforeModelHookOutput,
|
||||
BeforeToolSelectionHookOutput,
|
||||
AfterModelHookOutput,
|
||||
HookEventName,
|
||||
HookType,
|
||||
} from './types.js';
|
||||
import { defaultHookTranslator } from './hookTranslator.js';
|
||||
import type {
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
ToolConfig,
|
||||
} from '@google/genai';
|
||||
import type { LLMRequest, LLMResponse } from './hookTranslator.js';
|
||||
import type { HookDecision } from './types.js';
|
||||
|
||||
vi.mock('./hookTranslator.js', () => ({
|
||||
defaultHookTranslator: {
|
||||
fromHookLLMResponse: vi.fn(
|
||||
(response: LLMResponse) => response as unknown as GenerateContentResponse,
|
||||
),
|
||||
fromHookLLMRequest: vi.fn(
|
||||
(request: LLMRequest, target: GenerateContentParameters) => ({
|
||||
...target,
|
||||
...request,
|
||||
}),
|
||||
),
|
||||
fromHookToolConfig: vi.fn((config: ToolConfig) => config),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('Hook Types', () => {
|
||||
describe('HookEventName', () => {
|
||||
@@ -36,3 +68,318 @@ describe('Hook Types', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Hook Output Classes', () => {
|
||||
describe('createHookOutput', () => {
|
||||
it('should return DefaultHookOutput for unknown event names', () => {
|
||||
const output = createHookOutput('UnknownEvent', {});
|
||||
expect(output).toBeInstanceOf(DefaultHookOutput);
|
||||
expect(output).not.toBeInstanceOf(BeforeModelHookOutput);
|
||||
expect(output).not.toBeInstanceOf(AfterModelHookOutput);
|
||||
expect(output).not.toBeInstanceOf(BeforeToolSelectionHookOutput);
|
||||
});
|
||||
|
||||
it('should return BeforeModelHookOutput for BeforeModel event', () => {
|
||||
const output = createHookOutput(HookEventName.BeforeModel, {});
|
||||
expect(output).toBeInstanceOf(BeforeModelHookOutput);
|
||||
});
|
||||
|
||||
it('should return AfterModelHookOutput for AfterModel event', () => {
|
||||
const output = createHookOutput(HookEventName.AfterModel, {});
|
||||
expect(output).toBeInstanceOf(AfterModelHookOutput);
|
||||
});
|
||||
|
||||
it('should return BeforeToolSelectionHookOutput for BeforeToolSelection event', () => {
|
||||
const output = createHookOutput(HookEventName.BeforeToolSelection, {});
|
||||
expect(output).toBeInstanceOf(BeforeToolSelectionHookOutput);
|
||||
});
|
||||
});
|
||||
|
||||
describe('DefaultHookOutput', () => {
|
||||
it('should construct with provided data', () => {
|
||||
const data = {
|
||||
continue: false,
|
||||
stopReason: 'test stop',
|
||||
suppressOutput: true,
|
||||
systemMessage: 'test system message',
|
||||
decision: 'block' as HookDecision,
|
||||
reason: 'test reason',
|
||||
hookSpecificOutput: { key: 'value' },
|
||||
};
|
||||
const output = new DefaultHookOutput(data);
|
||||
expect(output.continue).toBe(data.continue);
|
||||
expect(output.stopReason).toBe(data.stopReason);
|
||||
expect(output.suppressOutput).toBe(data.suppressOutput);
|
||||
expect(output.systemMessage).toBe(data.systemMessage);
|
||||
expect(output.decision).toBe(data.decision);
|
||||
expect(output.reason).toBe(data.reason);
|
||||
expect(output.hookSpecificOutput).toEqual(data.hookSpecificOutput);
|
||||
});
|
||||
|
||||
it('should return false for isBlockingDecision if decision is not block or deny', () => {
|
||||
const output1 = new DefaultHookOutput({ decision: 'approve' });
|
||||
expect(output1.isBlockingDecision()).toBe(false);
|
||||
const output2 = new DefaultHookOutput({ decision: undefined });
|
||||
expect(output2.isBlockingDecision()).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true for isBlockingDecision if decision is block or deny', () => {
|
||||
const output1 = new DefaultHookOutput({ decision: 'block' });
|
||||
expect(output1.isBlockingDecision()).toBe(true);
|
||||
const output2 = new DefaultHookOutput({ decision: 'deny' });
|
||||
expect(output2.isBlockingDecision()).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for shouldStopExecution if continue is false', () => {
|
||||
const output = new DefaultHookOutput({ continue: false });
|
||||
expect(output.shouldStopExecution()).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for shouldStopExecution if continue is true or undefined', () => {
|
||||
const output1 = new DefaultHookOutput({ continue: true });
|
||||
expect(output1.shouldStopExecution()).toBe(false);
|
||||
const output2 = new DefaultHookOutput({});
|
||||
expect(output2.shouldStopExecution()).toBe(false);
|
||||
});
|
||||
|
||||
it('should return reason if available', () => {
|
||||
const output = new DefaultHookOutput({ reason: 'specific reason' });
|
||||
expect(output.getEffectiveReason()).toBe('specific reason');
|
||||
});
|
||||
|
||||
it('should return stopReason if reason is not available', () => {
|
||||
const output = new DefaultHookOutput({ stopReason: 'stop reason' });
|
||||
expect(output.getEffectiveReason()).toBe('stop reason');
|
||||
});
|
||||
|
||||
it('should return "No reason provided" if neither reason nor stopReason are available', () => {
|
||||
const output = new DefaultHookOutput({});
|
||||
expect(output.getEffectiveReason()).toBe('No reason provided');
|
||||
});
|
||||
|
||||
it('applyLLMRequestModifications should return target unchanged', () => {
|
||||
const target: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [],
|
||||
};
|
||||
const output = new DefaultHookOutput({});
|
||||
expect(output.applyLLMRequestModifications(target)).toBe(target);
|
||||
});
|
||||
|
||||
it('applyToolConfigModifications should return target unchanged', () => {
|
||||
const target = { toolConfig: {}, tools: [] };
|
||||
const output = new DefaultHookOutput({});
|
||||
expect(output.applyToolConfigModifications(target)).toBe(target);
|
||||
});
|
||||
|
||||
it('getAdditionalContext should return additionalContext if present', () => {
|
||||
const output = new DefaultHookOutput({
|
||||
hookSpecificOutput: { additionalContext: 'some context' },
|
||||
});
|
||||
expect(output.getAdditionalContext()).toBe('some context');
|
||||
});
|
||||
|
||||
it('getAdditionalContext should return undefined if additionalContext is not present', () => {
|
||||
const output = new DefaultHookOutput({
|
||||
hookSpecificOutput: { other: 'value' },
|
||||
});
|
||||
expect(output.getAdditionalContext()).toBeUndefined();
|
||||
});
|
||||
|
||||
it('getAdditionalContext should return undefined if hookSpecificOutput is undefined', () => {
|
||||
const output = new DefaultHookOutput({});
|
||||
expect(output.getAdditionalContext()).toBeUndefined();
|
||||
});
|
||||
|
||||
it('getBlockingError should return blocked: true and reason if blocking decision', () => {
|
||||
const output = new DefaultHookOutput({
|
||||
decision: 'block',
|
||||
reason: 'blocked by hook',
|
||||
});
|
||||
expect(output.getBlockingError()).toEqual({
|
||||
blocked: true,
|
||||
reason: 'blocked by hook',
|
||||
});
|
||||
});
|
||||
|
||||
it('getBlockingError should return blocked: false if not blocking decision', () => {
|
||||
const output = new DefaultHookOutput({ decision: 'approve' });
|
||||
expect(output.getBlockingError()).toEqual({ blocked: false, reason: '' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('BeforeToolHookOutput', () => {
|
||||
it('isBlockingDecision should use permissionDecision from hookSpecificOutput', () => {
|
||||
const output1 = new BeforeToolHookOutput({
|
||||
hookSpecificOutput: { permissionDecision: 'block' },
|
||||
});
|
||||
expect(output1.isBlockingDecision()).toBe(true);
|
||||
|
||||
const output2 = new BeforeToolHookOutput({
|
||||
hookSpecificOutput: { permissionDecision: 'approve' },
|
||||
});
|
||||
expect(output2.isBlockingDecision()).toBe(false);
|
||||
});
|
||||
|
||||
it('getEffectiveReason should use permissionDecisionReason from hookSpecificOutput', () => {
|
||||
const output1 = new BeforeToolHookOutput({
|
||||
hookSpecificOutput: { permissionDecisionReason: 'compat reason' },
|
||||
});
|
||||
expect(output1.getEffectiveReason()).toBe('compat reason');
|
||||
|
||||
const output2 = new BeforeToolHookOutput({
|
||||
reason: 'default reason',
|
||||
hookSpecificOutput: { other: 'value' },
|
||||
});
|
||||
expect(output2.getEffectiveReason()).toBe('default reason');
|
||||
});
|
||||
});
|
||||
|
||||
describe('BeforeModelHookOutput', () => {
|
||||
it('getSyntheticResponse should return synthetic response if llm_response is present', () => {
|
||||
const mockResponse: LLMResponse = { candidates: [] };
|
||||
const output = new BeforeModelHookOutput({
|
||||
hookSpecificOutput: { llm_response: mockResponse },
|
||||
});
|
||||
expect(output.getSyntheticResponse()).toEqual(mockResponse);
|
||||
expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith(
|
||||
mockResponse,
|
||||
);
|
||||
});
|
||||
|
||||
it('getSyntheticResponse should return undefined if llm_response is not present', () => {
|
||||
const output = new BeforeModelHookOutput({});
|
||||
expect(output.getSyntheticResponse()).toBeUndefined();
|
||||
});
|
||||
|
||||
it('applyLLMRequestModifications should apply modifications if llm_request is present', () => {
|
||||
const target: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ parts: [{ text: 'original' }] }],
|
||||
};
|
||||
const mockRequest: Partial<LLMRequest> = {
|
||||
messages: [{ role: 'user', content: 'modified' }],
|
||||
};
|
||||
const output = new BeforeModelHookOutput({
|
||||
hookSpecificOutput: { llm_request: mockRequest },
|
||||
});
|
||||
const result = output.applyLLMRequestModifications(target);
|
||||
expect(result).toEqual({ ...target, ...mockRequest });
|
||||
expect(defaultHookTranslator.fromHookLLMRequest).toHaveBeenCalledWith(
|
||||
mockRequest,
|
||||
target,
|
||||
);
|
||||
});
|
||||
|
||||
it('applyLLMRequestModifications should return target unchanged if llm_request is not present', () => {
|
||||
const target: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [],
|
||||
};
|
||||
const output = new BeforeModelHookOutput({});
|
||||
expect(output.applyLLMRequestModifications(target)).toBe(target);
|
||||
});
|
||||
});
|
||||
|
||||
describe('BeforeToolSelectionHookOutput', () => {
|
||||
it('applyToolConfigModifications should apply modifications if toolConfig is present', () => {
|
||||
const target = { tools: [{ functionDeclarations: [] }] };
|
||||
const mockToolConfig = { functionCallingConfig: { mode: 'ANY' } };
|
||||
const output = new BeforeToolSelectionHookOutput({
|
||||
hookSpecificOutput: { toolConfig: mockToolConfig },
|
||||
});
|
||||
const result = output.applyToolConfigModifications(target);
|
||||
expect(result).toEqual({ ...target, toolConfig: mockToolConfig });
|
||||
expect(defaultHookTranslator.fromHookToolConfig).toHaveBeenCalledWith(
|
||||
mockToolConfig,
|
||||
);
|
||||
});
|
||||
|
||||
it('applyToolConfigModifications should return target unchanged if toolConfig is not present', () => {
|
||||
const target = { toolConfig: {}, tools: [] };
|
||||
const output = new BeforeToolSelectionHookOutput({});
|
||||
expect(output.applyToolConfigModifications(target)).toBe(target);
|
||||
});
|
||||
|
||||
it('applyToolConfigModifications should initialize tools array if not present', () => {
|
||||
const target = {};
|
||||
const mockToolConfig = { functionCallingConfig: { mode: 'ANY' } };
|
||||
const output = new BeforeToolSelectionHookOutput({
|
||||
hookSpecificOutput: { toolConfig: mockToolConfig },
|
||||
});
|
||||
const result = output.applyToolConfigModifications(target);
|
||||
expect(result).toEqual({ tools: [], toolConfig: mockToolConfig });
|
||||
});
|
||||
});
|
||||
|
||||
describe('AfterModelHookOutput', () => {
|
||||
it('getModifiedResponse should return modified response if llm_response is present and has content', () => {
|
||||
const mockResponse: LLMResponse = {
|
||||
candidates: [{ content: { role: 'model', parts: ['modified'] } }],
|
||||
};
|
||||
const output = new AfterModelHookOutput({
|
||||
hookSpecificOutput: { llm_response: mockResponse },
|
||||
});
|
||||
expect(output.getModifiedResponse()).toEqual(mockResponse);
|
||||
expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith(
|
||||
mockResponse,
|
||||
);
|
||||
});
|
||||
|
||||
it('getModifiedResponse should return undefined if llm_response is present but no content', () => {
|
||||
const mockResponse: LLMResponse = {
|
||||
candidates: [{ content: { role: 'model', parts: [] } }],
|
||||
};
|
||||
const output = new AfterModelHookOutput({
|
||||
hookSpecificOutput: { llm_response: mockResponse },
|
||||
});
|
||||
expect(output.getModifiedResponse()).toBeUndefined();
|
||||
});
|
||||
|
||||
it('getModifiedResponse should return undefined if llm_response is not present', () => {
|
||||
const output = new AfterModelHookOutput({});
|
||||
expect(output.getModifiedResponse()).toBeUndefined();
|
||||
});
|
||||
|
||||
it('getModifiedResponse should return a synthetic stop response if shouldStopExecution is true', () => {
|
||||
const output = new AfterModelHookOutput({
|
||||
continue: false,
|
||||
stopReason: 'stopped by hook',
|
||||
});
|
||||
const expectedResponse: LLMResponse = {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: ['stopped by hook'],
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
};
|
||||
expect(output.getModifiedResponse()).toEqual(expectedResponse);
|
||||
expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith(
|
||||
expectedResponse,
|
||||
);
|
||||
});
|
||||
|
||||
it('getModifiedResponse should return a synthetic stop response with default reason if shouldStopExecution is true and no stopReason', () => {
|
||||
const output = new AfterModelHookOutput({ continue: false });
|
||||
const expectedResponse: LLMResponse = {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: ['No reason provided'],
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
};
|
||||
expect(output.getModifiedResponse()).toEqual(expectedResponse);
|
||||
expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith(
|
||||
expectedResponse,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -340,7 +340,7 @@ export class AfterModelHookOutput extends DefaultHookOutput {
|
||||
const hookResponse = this.hookSpecificOutput[
|
||||
'llm_response'
|
||||
] as Partial<LLMResponse>;
|
||||
if (hookResponse?.candidates?.[0]?.content) {
|
||||
if (hookResponse?.candidates?.[0]?.content?.parts?.length) {
|
||||
// Convert hook format to SDK format
|
||||
return defaultHookTranslator.fromHookLLMResponse(
|
||||
hookResponse as LLMResponse,
|
||||
|
||||
@@ -647,6 +647,254 @@ describe('IdeClient', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('resolveDiffFromCli', () => {
|
||||
beforeEach(async () => {
|
||||
// Ensure client is "connected" for these tests
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
// We need to set the client property on the instance for openDiff to work
|
||||
(ideClient as unknown as { client: Client }).client = mockClient;
|
||||
mockClient.request.mockResolvedValue({
|
||||
isError: false,
|
||||
content: [],
|
||||
});
|
||||
});
|
||||
|
||||
it("should resolve an open diff as 'accepted' and return the final content", async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
const closeDiffSpy = vi
|
||||
.spyOn(
|
||||
ideClient as unknown as {
|
||||
closeDiff: () => Promise<string | undefined>;
|
||||
},
|
||||
'closeDiff',
|
||||
)
|
||||
.mockResolvedValue('final content from ide');
|
||||
|
||||
const diffPromise = ideClient.openDiff('/test.txt', 'new content');
|
||||
|
||||
// Yield to the event loop to allow the openDiff promise executor to run
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
await ideClient.resolveDiffFromCli('/test.txt', 'accepted');
|
||||
|
||||
const result = await diffPromise;
|
||||
|
||||
expect(result).toEqual({
|
||||
status: 'accepted',
|
||||
content: 'final content from ide',
|
||||
});
|
||||
expect(closeDiffSpy).toHaveBeenCalledWith('/test.txt', {
|
||||
suppressNotification: true,
|
||||
});
|
||||
expect(
|
||||
(
|
||||
ideClient as unknown as { diffResponses: Map<string, unknown> }
|
||||
).diffResponses.has('/test.txt'),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("should resolve an open diff as 'rejected'", async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
const closeDiffSpy = vi
|
||||
.spyOn(
|
||||
ideClient as unknown as {
|
||||
closeDiff: () => Promise<string | undefined>;
|
||||
},
|
||||
'closeDiff',
|
||||
)
|
||||
.mockResolvedValue(undefined);
|
||||
|
||||
const diffPromise = ideClient.openDiff('/test.txt', 'new content');
|
||||
|
||||
// Yield to the event loop to allow the openDiff promise executor to run
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
await ideClient.resolveDiffFromCli('/test.txt', 'rejected');
|
||||
|
||||
const result = await diffPromise;
|
||||
|
||||
expect(result).toEqual({
|
||||
status: 'rejected',
|
||||
content: undefined,
|
||||
});
|
||||
expect(closeDiffSpy).toHaveBeenCalledWith('/test.txt', {
|
||||
suppressNotification: true,
|
||||
});
|
||||
expect(
|
||||
(
|
||||
ideClient as unknown as { diffResponses: Map<string, unknown> }
|
||||
).diffResponses.has('/test.txt'),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it('should do nothing if no diff is open for the given file path', async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
const closeDiffSpy = vi
|
||||
.spyOn(
|
||||
ideClient as unknown as {
|
||||
closeDiff: () => Promise<string | undefined>;
|
||||
},
|
||||
'closeDiff',
|
||||
)
|
||||
.mockResolvedValue(undefined);
|
||||
|
||||
// No call to openDiff, so no resolver will exist.
|
||||
await ideClient.resolveDiffFromCli('/non-existent.txt', 'accepted');
|
||||
|
||||
expect(closeDiffSpy).toHaveBeenCalledWith('/non-existent.txt', {
|
||||
suppressNotification: true,
|
||||
});
|
||||
// No crash should occur, and nothing should be in the map.
|
||||
expect(
|
||||
(
|
||||
ideClient as unknown as { diffResponses: Map<string, unknown> }
|
||||
).diffResponses.has('/non-existent.txt'),
|
||||
).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('closeDiff', () => {
|
||||
beforeEach(async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
(ideClient as unknown as { client: Client }).client = mockClient;
|
||||
});
|
||||
|
||||
it('should return undefined if client is not connected', async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
(ideClient as unknown as { client: Client | undefined }).client =
|
||||
undefined;
|
||||
|
||||
const result = await (
|
||||
ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
|
||||
).closeDiff('/test.txt');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should call client.request with correct arguments', async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
// Return a valid, empty response as the return value is not under test here.
|
||||
mockClient.request.mockResolvedValue({ isError: false, content: [] });
|
||||
|
||||
await (
|
||||
ideClient as unknown as {
|
||||
closeDiff: (
|
||||
f: string,
|
||||
o?: { suppressNotification?: boolean },
|
||||
) => Promise<void>;
|
||||
}
|
||||
).closeDiff('/test.txt', { suppressNotification: true });
|
||||
|
||||
expect(mockClient.request).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
params: {
|
||||
name: 'closeDiff',
|
||||
arguments: {
|
||||
filePath: '/test.txt',
|
||||
suppressNotification: true,
|
||||
},
|
||||
},
|
||||
}),
|
||||
expect.any(Object), // Schema
|
||||
expect.any(Object), // Options
|
||||
);
|
||||
});
|
||||
|
||||
it('should return content from a valid JSON response', async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
const response = {
|
||||
isError: false,
|
||||
content: [
|
||||
{ type: 'text', text: JSON.stringify({ content: 'file content' }) },
|
||||
],
|
||||
};
|
||||
mockClient.request.mockResolvedValue(response);
|
||||
|
||||
const result = await (
|
||||
ideClient as unknown as { closeDiff: (f: string) => Promise<string> }
|
||||
).closeDiff('/test.txt');
|
||||
expect(result).toBe('file content');
|
||||
});
|
||||
|
||||
it('should return undefined for a valid JSON response with null content', async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
const response = {
|
||||
isError: false,
|
||||
content: [{ type: 'text', text: JSON.stringify({ content: null }) }],
|
||||
};
|
||||
mockClient.request.mockResolvedValue(response);
|
||||
|
||||
const result = await (
|
||||
ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
|
||||
).closeDiff('/test.txt');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined if response is not valid JSON', async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
const response = {
|
||||
isError: false,
|
||||
content: [{ type: 'text', text: 'not json' }],
|
||||
};
|
||||
mockClient.request.mockResolvedValue(response);
|
||||
|
||||
const result = await (
|
||||
ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
|
||||
).closeDiff('/test.txt');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined if request result has isError: true', async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
const response = {
|
||||
isError: true,
|
||||
content: [{ type: 'text', text: 'An error occurred' }],
|
||||
};
|
||||
mockClient.request.mockResolvedValue(response);
|
||||
|
||||
const result = await (
|
||||
ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
|
||||
).closeDiff('/test.txt');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined if client.request throws', async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
mockClient.request.mockRejectedValue(new Error('Request failed'));
|
||||
|
||||
const result = await (
|
||||
ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
|
||||
).closeDiff('/test.txt');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined if response has no text part', async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
const response = {
|
||||
isError: false,
|
||||
content: [{ type: 'other' }],
|
||||
};
|
||||
mockClient.request.mockResolvedValue(response);
|
||||
|
||||
const result = await (
|
||||
ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
|
||||
).closeDiff('/test.txt');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined if response is falsy', async () => {
|
||||
const ideClient = await IdeClient.getInstance();
|
||||
// Mocking with `null as any` to test the falsy path, as the mock
|
||||
// function is strictly typed.
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
mockClient.request.mockResolvedValue(null as any);
|
||||
|
||||
const result = await (
|
||||
ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
|
||||
).closeDiff('/test.txt');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('authentication', () => {
|
||||
it('should connect with an auth token if provided in the discovery file', async () => {
|
||||
const authToken = 'test-auth-token';
|
||||
|
||||
@@ -714,6 +714,269 @@ describe('MCPOAuthProvider', () => {
|
||||
).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant');
|
||||
});
|
||||
|
||||
it('should handle OAuth discovery failure', async () => {
|
||||
const configWithoutAuth: MCPOAuthConfig = { ...mockConfig };
|
||||
delete configWithoutAuth.authorizationUrl;
|
||||
delete configWithoutAuth.tokenUrl;
|
||||
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: false,
|
||||
status: 404,
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
authProvider.authenticate(
|
||||
'test-server',
|
||||
configWithoutAuth,
|
||||
'https://api.example.com',
|
||||
),
|
||||
).rejects.toThrow(
|
||||
'Failed to discover OAuth configuration from MCP server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle authorization server metadata discovery failure', async () => {
|
||||
const configWithoutClient: MCPOAuthConfig = { ...mockConfig };
|
||||
delete configWithoutClient.clientId;
|
||||
|
||||
mockFetch.mockResolvedValue(
|
||||
createMockResponse({
|
||||
ok: false,
|
||||
status: 404,
|
||||
}),
|
||||
);
|
||||
|
||||
// Prevent callback server from hanging the test
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
});
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
authProvider.authenticate('test-server', configWithoutClient),
|
||||
).rejects.toThrow(
|
||||
'Failed to fetch authorization server metadata for client registration',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle invalid callback request', async () => {
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/invalid-path',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 0);
|
||||
});
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
// The test will timeout if the server does not handle the invalid request correctly.
|
||||
// We are testing that the server does not hang.
|
||||
await Promise.race([
|
||||
authProvider.authenticate('test-server', mockConfig),
|
||||
new Promise((resolve) => setTimeout(resolve, 1000)),
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle token exchange failure with non-json response', async () => {
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: false,
|
||||
status: 500,
|
||||
contentType: 'text/html',
|
||||
text: 'Internal Server Error',
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
authProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('Token exchange failed: 500 - Internal Server Error');
|
||||
});
|
||||
|
||||
it('should handle token exchange with unexpected content type', async () => {
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: true,
|
||||
contentType: 'text/plain',
|
||||
text: 'access_token=plain_text_token',
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.authenticate('test-server', mockConfig);
|
||||
expect(result.accessToken).toBe('plain_text_token');
|
||||
});
|
||||
|
||||
it('should handle refresh token failure with non-json response', async () => {
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: false,
|
||||
status: 500,
|
||||
contentType: 'text/html',
|
||||
text: 'Internal Server Error',
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
authProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'invalid_refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
),
|
||||
).rejects.toThrow('Token refresh failed: 500 - Internal Server Error');
|
||||
});
|
||||
|
||||
it('should handle refresh token with unexpected content type', async () => {
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: true,
|
||||
contentType: 'text/plain',
|
||||
text: 'access_token=plain_text_token',
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
);
|
||||
expect(result.access_token).toBe('plain_text_token');
|
||||
});
|
||||
|
||||
it('should continue authentication when browser fails to open', async () => {
|
||||
mockOpenBrowserSecurely.mockRejectedValue(new Error('Browser not found'));
|
||||
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
ok: true,
|
||||
contentType: 'application/json',
|
||||
text: JSON.stringify(mockTokenResponse),
|
||||
json: mockTokenResponse,
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.authenticate('test-server', mockConfig);
|
||||
expect(result).toBeDefined();
|
||||
});
|
||||
|
||||
it('should return null when token is expired and no refresh token is available', async () => {
|
||||
const expiredCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
...mockToken,
|
||||
refreshToken: undefined,
|
||||
expiresAt: Date.now() - 3600000,
|
||||
},
|
||||
clientId: 'test-client-id',
|
||||
tokenUrl: 'https://auth.example.com/token',
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
|
||||
expiredCredentials,
|
||||
);
|
||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle callback timeout', async () => {
|
||||
vi.mocked(http.createServer).mockImplementation(
|
||||
() => mockHttpServer as unknown as http.Server,
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { getMCPServerPrompts } from './mcp-prompts.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { PromptRegistry } from './prompt-registry.js';
|
||||
import type { DiscoveredMCPPrompt } from '../tools/mcp-client.js';
|
||||
|
||||
describe('getMCPServerPrompts', () => {
|
||||
it('should return prompts from the registry for a given server', () => {
|
||||
const mockPrompts: DiscoveredMCPPrompt[] = [
|
||||
{
|
||||
name: 'prompt1',
|
||||
serverName: 'server1',
|
||||
tool: { name: 'p1', description: '', inputSchema: {} },
|
||||
invoke: async () => ({
|
||||
messages: [
|
||||
{ role: 'assistant', content: { type: 'text', text: '' } },
|
||||
],
|
||||
}),
|
||||
},
|
||||
];
|
||||
|
||||
const mockRegistry = new PromptRegistry();
|
||||
vi.spyOn(mockRegistry, 'getPromptsByServer').mockReturnValue(mockPrompts);
|
||||
|
||||
const mockConfig = {
|
||||
getPromptRegistry: () => mockRegistry,
|
||||
} as unknown as Config;
|
||||
|
||||
const result = getMCPServerPrompts(mockConfig, 'server1');
|
||||
|
||||
expect(mockRegistry.getPromptsByServer).toHaveBeenCalledWith('server1');
|
||||
expect(result).toEqual(mockPrompts);
|
||||
});
|
||||
|
||||
it('should return an empty array if there is no prompt registry', () => {
|
||||
const mockConfig = {
|
||||
getPromptRegistry: () => undefined,
|
||||
} as unknown as Config;
|
||||
|
||||
const result = getMCPServerPrompts(mockConfig, 'server1');
|
||||
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,129 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||
import { PromptRegistry } from './prompt-registry.js';
|
||||
import type { DiscoveredMCPPrompt } from '../tools/mcp-client.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
vi.mock('../utils/debugLogger.js', () => ({
|
||||
debugLogger: {
|
||||
warn: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('PromptRegistry', () => {
|
||||
let registry: PromptRegistry;
|
||||
|
||||
const prompt1: DiscoveredMCPPrompt = {
|
||||
name: 'prompt1',
|
||||
serverName: 'server1',
|
||||
tool: {
|
||||
name: 'prompt1',
|
||||
description: 'Prompt 1',
|
||||
inputSchema: {},
|
||||
},
|
||||
invoke: async () => ({
|
||||
messages: [
|
||||
{ role: 'assistant', content: { type: 'text', text: 'response1' } },
|
||||
],
|
||||
}),
|
||||
};
|
||||
|
||||
const prompt2: DiscoveredMCPPrompt = {
|
||||
name: 'prompt2',
|
||||
serverName: 'server1',
|
||||
tool: {
|
||||
name: 'prompt2',
|
||||
description: 'Prompt 2',
|
||||
inputSchema: {},
|
||||
},
|
||||
invoke: async () => ({
|
||||
messages: [
|
||||
{ role: 'assistant', content: { type: 'text', text: 'response2' } },
|
||||
],
|
||||
}),
|
||||
};
|
||||
|
||||
const prompt3: DiscoveredMCPPrompt = {
|
||||
name: 'prompt1',
|
||||
serverName: 'server2',
|
||||
tool: {
|
||||
name: 'prompt1',
|
||||
description: 'Prompt 3',
|
||||
inputSchema: {},
|
||||
},
|
||||
invoke: async () => ({
|
||||
messages: [
|
||||
{ role: 'assistant', content: { type: 'text', text: 'response3' } },
|
||||
],
|
||||
}),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
registry = new PromptRegistry();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should register a prompt', () => {
|
||||
registry.registerPrompt(prompt1);
|
||||
expect(registry.getPrompt('prompt1')).toEqual(prompt1);
|
||||
});
|
||||
|
||||
it('should get all prompts, sorted by name', () => {
|
||||
registry.registerPrompt(prompt2);
|
||||
registry.registerPrompt(prompt1);
|
||||
expect(registry.getAllPrompts()).toEqual([prompt1, prompt2]);
|
||||
});
|
||||
|
||||
it('should get a specific prompt by name', () => {
|
||||
registry.registerPrompt(prompt1);
|
||||
expect(registry.getPrompt('prompt1')).toEqual(prompt1);
|
||||
expect(registry.getPrompt('non-existent')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should get prompts by server, sorted by name', () => {
|
||||
registry.registerPrompt(prompt1);
|
||||
registry.registerPrompt(prompt2);
|
||||
registry.registerPrompt(prompt3); // different server
|
||||
expect(registry.getPromptsByServer('server1')).toEqual([prompt1, prompt2]);
|
||||
expect(registry.getPromptsByServer('server2')).toEqual([
|
||||
{ ...prompt3, name: 'server2_prompt1' },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle prompt name collision by renaming', () => {
|
||||
registry.registerPrompt(prompt1);
|
||||
registry.registerPrompt(prompt3);
|
||||
|
||||
expect(registry.getPrompt('prompt1')).toEqual(prompt1);
|
||||
const renamedPrompt = { ...prompt3, name: 'server2_prompt1' };
|
||||
expect(registry.getPrompt('server2_prompt1')).toEqual(renamedPrompt);
|
||||
expect(debugLogger.warn).toHaveBeenCalledWith(
|
||||
'Prompt with name "prompt1" is already registered. Renaming to "server2_prompt1".',
|
||||
);
|
||||
});
|
||||
|
||||
it('should clear all prompts', () => {
|
||||
registry.registerPrompt(prompt1);
|
||||
registry.registerPrompt(prompt2);
|
||||
registry.clear();
|
||||
expect(registry.getAllPrompts()).toEqual([]);
|
||||
});
|
||||
|
||||
it('should remove prompts by server', () => {
|
||||
registry.registerPrompt(prompt1);
|
||||
registry.registerPrompt(prompt2);
|
||||
registry.registerPrompt(prompt3);
|
||||
registry.removePromptsByServer('server1');
|
||||
|
||||
const renamedPrompt = { ...prompt3, name: 'server2_prompt1' };
|
||||
expect(registry.getAllPrompts()).toEqual([renamedPrompt]);
|
||||
expect(registry.getPrompt('prompt1')).toBeUndefined();
|
||||
expect(registry.getPrompt('prompt2')).toBeUndefined();
|
||||
expect(registry.getPrompt('server2_prompt1')).toEqual(renamedPrompt);
|
||||
});
|
||||
});
|
||||
@@ -394,6 +394,15 @@ describe('ClearcutLogger', () => {
|
||||
},
|
||||
expected: 'devin',
|
||||
},
|
||||
{
|
||||
name: 'unidentified',
|
||||
env: {
|
||||
GITHUB_SHA: undefined,
|
||||
TERM_PROGRAM: undefined,
|
||||
SURFACE: undefined,
|
||||
},
|
||||
expected: 'SURFACE_NOT_SET',
|
||||
},
|
||||
])(
|
||||
'logs the current surface as $expected from $name',
|
||||
({ env, expected }) => {
|
||||
@@ -943,6 +952,33 @@ describe('ClearcutLogger', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('flushIfNeeded', () => {
|
||||
it('should not flush if the interval has not passed', () => {
|
||||
const { logger } = setup();
|
||||
const flushSpy = vi
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
.spyOn(logger!, 'flushToClearcut' as any)
|
||||
.mockResolvedValue({ nextRequestWaitMs: 0 });
|
||||
|
||||
logger!.flushIfNeeded();
|
||||
expect(flushSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should flush if the interval has passed', async () => {
|
||||
const { logger } = setup();
|
||||
const flushSpy = vi
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
.spyOn(logger!, 'flushToClearcut' as any)
|
||||
.mockResolvedValue({ nextRequestWaitMs: 0 });
|
||||
|
||||
// Advance time by more than the flush interval
|
||||
await vi.advanceTimersByTimeAsync(1000 * 60 * 2);
|
||||
|
||||
logger!.flushIfNeeded();
|
||||
expect(flushSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('logWebFetchFallbackAttemptEvent', () => {
|
||||
it('logs an event with the proper name and reason', () => {
|
||||
const { logger } = setup();
|
||||
|
||||
@@ -556,7 +556,7 @@ export function hasCycleInSchema(schema: object): boolean {
|
||||
|
||||
if ('$ref' in node && typeof node.$ref === 'string') {
|
||||
const ref = node.$ref;
|
||||
if (ref === '#/' || pathRefs.has(ref)) {
|
||||
if (ref === '#' || ref === '#/' || pathRefs.has(ref)) {
|
||||
// A ref to just '#/' is always a cycle.
|
||||
return true; // Cycle detected!
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user