Increase code coverage for core packages (#12872)

This commit is contained in:
Megha Bansal
2025-11-11 20:06:43 -08:00
committed by GitHub
parent e8038c727f
commit 11a0a9b911
18 changed files with 2265 additions and 47 deletions
@@ -0,0 +1,48 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { CodebaseInvestigatorAgent } from './codebase-investigator.js';
import {
GLOB_TOOL_NAME,
GREP_TOOL_NAME,
LS_TOOL_NAME,
READ_FILE_TOOL_NAME,
} from '../tools/tool-names.js';
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
// Unit tests for the CodebaseInvestigatorAgent definition: identity, required
// inputs, output/model/tool wiring, and report post-processing.
describe('CodebaseInvestigatorAgent', () => {
  const agent = CodebaseInvestigatorAgent;

  it('should have the correct agent definition', () => {
    // Identity and metadata.
    expect(agent.name).toBe('codebase_investigator');
    expect(agent.displayName).toBe('Codebase Investigator Agent');
    expect(agent.description).toBeDefined();

    // The 'objective' input must be mandatory.
    expect(agent.inputConfig.inputs['objective'].required).toBe(true);

    // Output, model, and tool configuration.
    expect(agent.outputConfig?.outputName).toBe('report');
    expect(agent.modelConfig?.model).toBe(DEFAULT_GEMINI_MODEL);
    const expectedTools = [
      LS_TOOL_NAME,
      READ_FILE_TOOL_NAME,
      GLOB_TOOL_NAME,
      GREP_TOOL_NAME,
    ];
    expect(agent.toolConfig?.tools).toEqual(expectedTools);
  });

  it('should process output to a formatted JSON string', () => {
    const report = {
      SummaryOfFindings: 'summary',
      ExplorationTrace: ['trace'],
      RelevantLocations: [],
    };
    // processOutput is expected to pretty-print the report (2-space indent).
    const processed = agent.processOutput?.(report);
    expect(processed).toBe(JSON.stringify(report, null, 2));
  });
});
+246 -1
View File
@@ -329,6 +329,47 @@ describe('AgentExecutor', () => {
new RegExp(`^${parentId}-${definition.name}-`),
);
});
// Verifies that ${...} placeholders in promptConfig.initialMessages are
// substituted with the run inputs before the chat history is built.
// NOTE(review): relies on out-of-view helpers (createTestDefinition,
// mockModelResponse, MockedGeminiChat) defined earlier in this suite.
it('should correctly apply templates to initialMessages', async () => {
  const definition = createTestDefinition();
  // Override promptConfig to use initialMessages instead of systemPrompt
  definition.promptConfig = {
    initialMessages: [
      { role: 'user', parts: [{ text: 'Goal: ${goal}' }] },
      { role: 'model', parts: [{ text: 'OK, starting on ${goal}.' }] },
    ],
  };
  const inputs = { goal: 'TestGoal' };
  // Mock a response to prevent the loop from running forever
  mockModelResponse([
    {
      name: TASK_COMPLETE_TOOL_NAME,
      args: { finalResult: 'done' },
      id: 'call1',
    },
  ]);
  const executor = await AgentExecutor.create(
    definition,
    mockConfig,
    onActivity,
  );
  await executor.run(inputs, signal);
  // Inspect the history passed to the GeminiChat constructor.
  const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
  const startHistory = chatConstructorArgs[2]; // history is the 3rd arg
  expect(startHistory).toBeDefined();
  expect(startHistory).toHaveLength(2);
  // Perform checks on defined objects to satisfy TS
  const firstPart = startHistory?.[0]?.parts?.[0];
  expect(firstPart?.text).toBe('Goal: TestGoal');
  const secondPart = startHistory?.[1]?.parts?.[0];
  expect(secondPart?.text).toBe('OK, starting on TestGoal.');
});
});
describe('run (Execution Loop and Logic)', () => {
@@ -420,9 +461,16 @@ describe('AgentExecutor', () => {
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
const chatConfig = chatConstructorArgs[1];
expect(chatConfig?.systemInstruction).toContain(
const systemInstruction = chatConfig?.systemInstruction as string;
expect(systemInstruction).toContain(
`MUST call the \`${TASK_COMPLETE_TOOL_NAME}\` tool`,
);
expect(systemInstruction).toContain('Mocked Environment Context');
expect(systemInstruction).toContain(
'You are running in a non-interactive mode',
);
expect(systemInstruction).toContain('Always use absolute paths');
const turn1Params = getMockMessageParams(0);
@@ -921,6 +969,203 @@ describe('AgentExecutor', () => {
});
});
// Error paths of the AgentExecutor: schema-validation failure of the
// complete_task output, GeminiChat construction failure, and a failing tool
// call whose error is fed back to the model for a retry turn.
describe('Edge Cases and Error Handling', () => {
  it('should report an error if complete_task output fails schema validation', async () => {
    const definition = createTestDefinition(
      [],
      {},
      'default',
      z.string().min(10), // The schema is for the output value itself
    );
    const executor = await AgentExecutor.create(
      definition,
      mockConfig,
      onActivity,
    );
    // Turn 1: Invalid arg (too short)
    mockModelResponse([
      {
        name: TASK_COMPLETE_TOOL_NAME,
        args: { finalResult: 'short' },
        id: 'call1',
      },
    ]);
    // Turn 2: Corrected
    mockModelResponse([
      {
        name: TASK_COMPLETE_TOOL_NAME,
        args: { finalResult: 'This is a much longer and valid result' },
        id: 'call2',
      },
    ]);
    const output = await executor.run({ goal: 'Validation test' }, signal);
    // Two model turns: one rejected, one accepted.
    expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
    // NOTE(review): this string mirrors zod's flattened-error JSON — keep in
    // sync if the executor changes how it serializes validation failures.
    const expectedError =
      'Output validation failed: {"formErrors":["String must contain at least 10 character(s)"],"fieldErrors":{}}';
    // Check that the error was reported in the activity stream
    expect(activities).toContainEqual(
      expect.objectContaining({
        type: 'ERROR',
        data: {
          context: 'tool_call',
          name: TASK_COMPLETE_TOOL_NAME,
          error: expect.stringContaining('Output validation failed'),
        },
      }),
    );
    // Check that the error was sent back to the model for the next turn
    const turn2Params = getMockMessageParams(1);
    const turn2Parts = turn2Params.message;
    expect(turn2Parts).toEqual([
      expect.objectContaining({
        functionResponse: expect.objectContaining({
          name: TASK_COMPLETE_TOOL_NAME,
          response: { error: expectedError },
          id: 'call1',
        }),
      }),
    ]);
    // Check that the agent eventually succeeded
    expect(output.result).toContain('This is a much longer and valid result');
    expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
  });
  it('should throw and log if GeminiChat creation fails', async () => {
    const definition = createTestDefinition();
    const initError = new Error('Chat creation failed');
    // Make the next GeminiChat construction throw.
    MockedGeminiChat.mockImplementationOnce(() => {
      throw initError;
    });
    // We expect the error to be thrown during the run, not creation
    const executor = await AgentExecutor.create(
      definition,
      mockConfig,
      onActivity,
    );
    await expect(executor.run({ goal: 'test' }, signal)).rejects.toThrow(
      `Failed to create chat object: ${initError}`,
    );
    // Ensure the error was reported via the activity callback
    expect(activities).toContainEqual(
      expect.objectContaining({
        type: 'ERROR',
        data: expect.objectContaining({
          error: `Error: Failed to create chat object: ${initError}`,
        }),
      }),
    );
    // Ensure the agent run was logged as a failure
    expect(mockedLogAgentFinish).toHaveBeenCalledWith(
      mockConfig,
      expect.objectContaining({
        terminate_reason: AgentTerminateMode.ERROR,
      }),
    );
  });
  it('should handle a failed tool call and feed the error to the model', async () => {
    const definition = createTestDefinition([LS_TOOL_NAME]);
    const executor = await AgentExecutor.create(
      definition,
      mockConfig,
      onActivity,
    );
    const toolErrorMessage = 'Tool failed spectacularly';
    // Turn 1: Model calls a tool that will fail
    mockModelResponse([
      { name: LS_TOOL_NAME, args: { path: '/fake' }, id: 'call1' },
    ]);
    // Simulate the tool erroring out; responseParts carry the error payload
    // the executor should forward to the model.
    mockExecuteToolCall.mockResolvedValueOnce({
      status: 'error',
      request: {
        callId: 'call1',
        name: LS_TOOL_NAME,
        args: { path: '/fake' },
        isClientInitiated: false,
        prompt_id: 'test-prompt',
      },
      tool: {} as AnyDeclarativeTool,
      invocation: {} as AnyToolInvocation,
      response: {
        callId: 'call1',
        resultDisplay: '',
        responseParts: [
          {
            functionResponse: {
              name: LS_TOOL_NAME,
              response: { error: toolErrorMessage },
              id: 'call1',
            },
          },
        ],
        error: {
          type: 'ToolError',
          message: toolErrorMessage,
        },
        errorType: 'ToolError',
        contentLength: 0,
      },
    });
    // Turn 2: Model sees the error and completes
    mockModelResponse([
      {
        name: TASK_COMPLETE_TOOL_NAME,
        args: { finalResult: 'Aborted due to tool failure.' },
        id: 'call2',
      },
    ]);
    const output = await executor.run({ goal: 'Tool failure test' }, signal);
    expect(mockExecuteToolCall).toHaveBeenCalledTimes(1);
    expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
    // Verify the error was reported in the activity stream
    expect(activities).toContainEqual(
      expect.objectContaining({
        type: 'ERROR',
        data: {
          context: 'tool_call',
          name: LS_TOOL_NAME,
          error: toolErrorMessage,
        },
      }),
    );
    // Verify the error was sent back to the model
    const turn2Params = getMockMessageParams(1);
    const parts = turn2Params.message;
    expect(parts).toEqual([
      expect.objectContaining({
        functionResponse: expect.objectContaining({
          name: LS_TOOL_NAME,
          id: 'call1',
          response: {
            error: toolErrorMessage,
          },
        }),
      }),
    ]);
    // A failed tool call is not fatal: the run still terminates via GOAL.
    expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
    expect(output.result).toBe('Aborted due to tool failure.');
  });
});
describe('run (Termination Conditions)', () => {
const mockWorkResponse = (id: string) => {
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
@@ -0,0 +1,163 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { AuthType } from '../core/contentGenerator.js';
import { getOauthClient } from './oauth2.js';
import { setupUser } from './setup.js';
import { CodeAssistServer } from './server.js';
import {
createCodeAssistContentGenerator,
getCodeAssistServer,
} from './codeAssist.js';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
import { UserTierId } from './types.js';
// Mock dependencies
vi.mock('./oauth2.js');
vi.mock('./setup.js');
vi.mock('./server.js');
vi.mock('../core/loggingContentGenerator.js');
const mockedGetOauthClient = vi.mocked(getOauthClient);
const mockedSetupUser = vi.mocked(setupUser);
const MockedCodeAssistServer = vi.mocked(CodeAssistServer);
const MockedLoggingContentGenerator = vi.mocked(LoggingContentGenerator);
// Tests for the codeAssist factory helpers: creating a content generator for
// OAuth-based auth types, and retrieving the underlying CodeAssistServer from
// a Config (including unwrapping a LoggingContentGenerator).
describe('codeAssist', () => {
  beforeEach(() => {
    vi.resetAllMocks();
  });
  describe('createCodeAssistContentGenerator', () => {
    const httpOptions = {};
    const mockConfig = {} as Config;
    const mockAuthClient = { a: 'client' };
    // setupUser resolves to the project/tier used to construct the server.
    const mockUserData = {
      projectId: 'test-project',
      userTier: UserTierId.FREE,
    };
    it('should create a server for LOGIN_WITH_GOOGLE', async () => {
      mockedGetOauthClient.mockResolvedValue(mockAuthClient as never);
      mockedSetupUser.mockResolvedValue(mockUserData);
      const generator = await createCodeAssistContentGenerator(
        httpOptions,
        AuthType.LOGIN_WITH_GOOGLE,
        mockConfig,
        'session-123',
      );
      expect(getOauthClient).toHaveBeenCalledWith(
        AuthType.LOGIN_WITH_GOOGLE,
        mockConfig,
      );
      expect(setupUser).toHaveBeenCalledWith(mockAuthClient);
      // 'free-tier' is presumably the string value of UserTierId.FREE —
      // confirm against types.js if this assertion breaks.
      expect(MockedCodeAssistServer).toHaveBeenCalledWith(
        mockAuthClient,
        'test-project',
        httpOptions,
        'session-123',
        'free-tier',
      );
      expect(generator).toBeInstanceOf(MockedCodeAssistServer);
    });
    it('should create a server for CLOUD_SHELL', async () => {
      mockedGetOauthClient.mockResolvedValue(mockAuthClient as never);
      mockedSetupUser.mockResolvedValue(mockUserData);
      // No session ID argument this time; it should flow through as undefined.
      const generator = await createCodeAssistContentGenerator(
        httpOptions,
        AuthType.CLOUD_SHELL,
        mockConfig,
      );
      expect(getOauthClient).toHaveBeenCalledWith(
        AuthType.CLOUD_SHELL,
        mockConfig,
      );
      expect(setupUser).toHaveBeenCalledWith(mockAuthClient);
      expect(MockedCodeAssistServer).toHaveBeenCalledWith(
        mockAuthClient,
        'test-project',
        httpOptions,
        undefined, // No session ID
        'free-tier',
      );
      expect(generator).toBeInstanceOf(MockedCodeAssistServer);
    });
    it('should throw an error for unsupported auth types', async () => {
      await expect(
        createCodeAssistContentGenerator(
          httpOptions,
          'api-key' as AuthType, // Use literal string to avoid enum resolution issues
          mockConfig,
        ),
      ).rejects.toThrow('Unsupported authType: api-key');
    });
  });
  describe('getCodeAssistServer', () => {
    it('should return the server if it is a CodeAssistServer', () => {
      const mockServer = new MockedCodeAssistServer({} as never, '', {});
      const mockConfig = {
        getContentGenerator: () => mockServer,
      } as unknown as Config;
      const server = getCodeAssistServer(mockConfig);
      expect(server).toBe(mockServer);
    });
    it('should unwrap and return the server if it is wrapped in a LoggingContentGenerator', () => {
      const mockServer = new MockedCodeAssistServer({} as never, '', {});
      const mockLogger = new MockedLoggingContentGenerator(
        {} as never,
        {} as never,
      );
      // The logger's getWrapped() exposes the inner generator.
      vi.spyOn(mockLogger, 'getWrapped').mockReturnValue(mockServer);
      const mockConfig = {
        getContentGenerator: () => mockLogger,
      } as unknown as Config;
      const server = getCodeAssistServer(mockConfig);
      expect(server).toBe(mockServer);
      expect(mockLogger.getWrapped).toHaveBeenCalled();
    });
    it('should return undefined if the content generator is not a CodeAssistServer', () => {
      const mockGenerator = { a: 'generator' }; // Not a CodeAssistServer
      const mockConfig = {
        getContentGenerator: () => mockGenerator,
      } as unknown as Config;
      const server = getCodeAssistServer(mockConfig);
      expect(server).toBeUndefined();
    });
    it('should return undefined if the wrapped generator is not a CodeAssistServer', () => {
      const mockGenerator = { a: 'generator' }; // Not a CodeAssistServer
      const mockLogger = new MockedLoggingContentGenerator(
        {} as never,
        {} as never,
      );
      vi.spyOn(mockLogger, 'getWrapped').mockReturnValue(
        mockGenerator as never,
      );
      const mockConfig = {
        getContentGenerator: () => mockLogger,
      } as unknown as Config;
      const server = getCodeAssistServer(mockConfig);
      expect(server).toBeUndefined();
    });
  });
});
@@ -0,0 +1,112 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { ReleaseChannel, getReleaseChannel } from '../../utils/channel.js';
// Mock dependencies before importing the module under test
vi.mock('../../utils/channel.js', async () => {
const actual = await vi.importActual('../../utils/channel.js');
return {
...(actual as object),
getReleaseChannel: vi.fn(),
};
});
// Tests for client_metadata: platform/arch mapping, ideVersion resolution,
// update-channel lookup, and promise caching. Uses vi.resetModules plus
// dynamic imports because the module caches its metadata promise at module
// scope — the order of reset/import/mock here is deliberate.
describe('client_metadata', () => {
  const originalPlatform = process.platform;
  const originalArch = process.arch;
  const originalCliVersion = process.env['CLI_VERSION'];
  const originalNodeVersion = process.version;
  beforeEach(async () => {
    // Reset modules to clear the cached `clientMetadataPromise`
    vi.resetModules();
    // Re-import the module to get a fresh instance
    await import('./client_metadata.js');
    // Provide a default mock implementation for each test
    vi.mocked(getReleaseChannel).mockResolvedValue(ReleaseChannel.STABLE);
  });
  afterEach(() => {
    // Restore original process properties to avoid side-effects between tests
    Object.defineProperty(process, 'platform', { value: originalPlatform });
    Object.defineProperty(process, 'arch', { value: originalArch });
    process.env['CLI_VERSION'] = originalCliVersion;
    Object.defineProperty(process, 'version', { value: originalNodeVersion });
    vi.clearAllMocks();
  });
  describe('getPlatform', () => {
    // Each case overrides process.platform/arch before importing the module,
    // so the mapping is evaluated with the faked values.
    const testCases = [
      { platform: 'darwin', arch: 'x64', expected: 'DARWIN_AMD64' },
      { platform: 'darwin', arch: 'arm64', expected: 'DARWIN_ARM64' },
      { platform: 'linux', arch: 'x64', expected: 'LINUX_AMD64' },
      { platform: 'linux', arch: 'arm64', expected: 'LINUX_ARM64' },
      { platform: 'win32', arch: 'x64', expected: 'WINDOWS_AMD64' },
      { platform: 'sunos', arch: 'x64', expected: 'PLATFORM_UNSPECIFIED' },
      { platform: 'win32', arch: 'arm', expected: 'PLATFORM_UNSPECIFIED' },
    ];
    for (const { platform, arch, expected } of testCases) {
      it(`should return ${expected} for platform ${platform} and arch ${arch}`, async () => {
        Object.defineProperty(process, 'platform', { value: platform });
        Object.defineProperty(process, 'arch', { value: arch });
        const { getClientMetadata } = await import('./client_metadata.js');
        const metadata = await getClientMetadata();
        expect(metadata.platform).toBe(expected);
      });
    }
  });
  describe('getClientMetadata', () => {
    it('should use CLI_VERSION for ideVersion if set', async () => {
      process.env['CLI_VERSION'] = '1.2.3';
      Object.defineProperty(process, 'version', { value: 'v18.0.0' });
      const { getClientMetadata } = await import('./client_metadata.js');
      const metadata = await getClientMetadata();
      expect(metadata.ideVersion).toBe('1.2.3');
    });
    it('should use process.version for ideVersion as a fallback', async () => {
      delete process.env['CLI_VERSION'];
      Object.defineProperty(process, 'version', { value: 'v20.0.0' });
      const { getClientMetadata } = await import('./client_metadata.js');
      const metadata = await getClientMetadata();
      expect(metadata.ideVersion).toBe('v20.0.0');
    });
    it('should call getReleaseChannel to get the update channel', async () => {
      vi.mocked(getReleaseChannel).mockResolvedValue(ReleaseChannel.NIGHTLY);
      const { getClientMetadata } = await import('./client_metadata.js');
      const metadata = await getClientMetadata();
      // NOTE(review): 'nightly' presumably equals ReleaseChannel.NIGHTLY's
      // string value — confirm against channel.js.
      expect(metadata.updateChannel).toBe('nightly');
      expect(getReleaseChannel).toHaveBeenCalled();
    });
    it('should cache the client metadata promise', async () => {
      const { getClientMetadata } = await import('./client_metadata.js');
      const firstCall = await getClientMetadata();
      const secondCall = await getClientMetadata();
      // Same object identity proves the promise (not just the value) is cached.
      expect(firstCall).toBe(secondCall);
      // Ensure the underlying functions are only called once
      expect(getReleaseChannel).toHaveBeenCalledTimes(1);
    });
    it('should always return the IDE name as GEMINI_CLI', async () => {
      const { getClientMetadata } = await import('./client_metadata.js');
      const metadata = await getClientMetadata();
      expect(metadata.ideName).toBe('GEMINI_CLI');
    });
  });
});
@@ -0,0 +1,115 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import type { CodeAssistServer } from '../server.js';
import { getClientMetadata } from './client_metadata.js';
import type { ListExperimentsResponse, Flag } from './types.js';
// Mock dependencies before importing the module under test
vi.mock('../server.js');
vi.mock('./client_metadata.js');
// Tests for getExperiments: fetching/parsing flags from the server, tolerating
// empty or partial responses, skipping unnamed flags, and caching the promise.
// Uses vi.resetModules + dynamic import because the module caches its
// `experimentsPromise` at module scope.
describe('experiments', () => {
  let mockServer: CodeAssistServer;
  beforeEach(() => {
    // Reset modules to clear the cached `experimentsPromise`
    vi.resetModules();
    // Mock the dependencies that `getExperiments` relies on
    vi.mocked(getClientMetadata).mockResolvedValue({
      ideName: 'GEMINI_CLI',
      ideVersion: '1.0.0',
      platform: 'LINUX_AMD64',
      updateChannel: 'stable',
    });
    // Create a mock instance of the server for each test
    mockServer = {
      listExperiments: vi.fn(),
    } as unknown as CodeAssistServer;
  });
  afterEach(() => {
    vi.clearAllMocks();
  });
  it('should fetch and parse experiments from the server', async () => {
    const { getExperiments } = await import('./experiments.js');
    const mockApiResponse: ListExperimentsResponse = {
      flags: [
        { name: 'flag1', boolValue: true },
        { name: 'flag2', stringValue: 'value' },
      ],
      experimentIds: [123, 456],
    };
    vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
    const experiments = await getExperiments(mockServer);
    // Verify that the dependencies were called
    expect(getClientMetadata).toHaveBeenCalled();
    expect(mockServer.listExperiments).toHaveBeenCalledWith(
      await getClientMetadata(),
    );
    // Verify that the response was parsed correctly: the flags array is
    // re-keyed into a map by flag name.
    expect(experiments.flags['flag1']).toEqual({
      name: 'flag1',
      boolValue: true,
    });
    expect(experiments.flags['flag2']).toEqual({
      name: 'flag2',
      stringValue: 'value',
    });
    expect(experiments.experimentIds).toEqual([123, 456]);
  });
  it('should handle an empty or partial response from the server', async () => {
    const { getExperiments } = await import('./experiments.js');
    const mockApiResponse: ListExperimentsResponse = {}; // No flags or experimentIds
    vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
    const experiments = await getExperiments(mockServer);
    // Missing fields default to empty containers rather than undefined.
    expect(experiments.flags).toEqual({});
    expect(experiments.experimentIds).toEqual([]);
  });
  it('should ignore flags that are missing a name', async () => {
    const { getExperiments } = await import('./experiments.js');
    const mockApiResponse: ListExperimentsResponse = {
      flags: [
        { boolValue: true } as Flag, // No name
        { name: 'flag2', stringValue: 'value' },
      ],
    };
    vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
    const experiments = await getExperiments(mockServer);
    expect(Object.keys(experiments.flags)).toHaveLength(1);
    expect(experiments.flags['flag2']).toBeDefined();
    // Guard against an unnamed flag being keyed under the string "undefined".
    expect(experiments.flags['undefined']).toBeUndefined();
  });
  it('should cache the experiments promise to avoid multiple fetches', async () => {
    const { getExperiments } = await import('./experiments.js');
    const mockApiResponse: ListExperimentsResponse = {
      experimentIds: [1, 2, 3],
    };
    vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
    const firstCall = await getExperiments(mockServer);
    const secondCall = await getExperiments(mockServer);
    expect(firstCall).toBe(secondCall); // Should be the exact same promise object
    // Verify the underlying functions were only called once
    expect(getClientMetadata).toHaveBeenCalledTimes(1);
    expect(mockServer.listExperiments).toHaveBeenCalledTimes(1);
  });
});
@@ -173,6 +173,58 @@ describe('OAuthCredentialStorage', () => {
expect(result).toEqual(mockCredentials);
});
// Migration failure paths for loadCredentials: corrupt legacy file, failed
// save of migrated credentials (must not delete the old file), and tolerance
// of a missing access token in stored credentials.
it('should throw an error if the migration file contains invalid JSON', async () => {
  // No credentials in the new storage forces the legacy-file migration path.
  vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
    null,
  );
  vi.spyOn(fs, 'readFile').mockResolvedValue('invalid json');
  await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
    'Failed to load OAuth credentials',
  );
});
it('should not delete the old file if saving migrated credentials fails', async () => {
  vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
    null,
  );
  vi.spyOn(fs, 'readFile').mockResolvedValue(
    JSON.stringify(mockCredentials),
  );
  // Migration read succeeds, but persisting to the new storage fails.
  vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockRejectedValue(
    new Error('Save failed'),
  );
  await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
    'Failed to load OAuth credentials',
  );
  // The legacy file must survive so migration can be retried.
  expect(fs.rm).not.toHaveBeenCalled();
});
it('should return credentials even if access_token is missing from storage', async () => {
  const partialMcpCredentials = {
    ...mockMcpCredentials,
    token: {
      ...mockMcpCredentials.token,
      accessToken: undefined,
    },
  };
  vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
    partialMcpCredentials,
  );
  const result = await OAuthCredentialStorage.loadCredentials();
  // The mapping back to google-auth Credentials preserves the other fields.
  expect(result).toEqual({
    access_token: undefined,
    refresh_token: mockCredentials.refresh_token,
    token_type: mockCredentials.token_type,
    scope: mockCredentials.scope,
    expiry_date: mockCredentials.expiry_date,
  });
});
});
describe('saveCredentials', () => {
@@ -195,6 +247,28 @@ describe('OAuthCredentialStorage', () => {
'Attempted to save credentials without an access token.',
);
});
// Optional fields may arrive as null or undefined; both should normalize to
// undefined in the stored token, with the token type defaulting to 'Bearer'.
it('should handle saving credentials with null or undefined optional fields', async () => {
  const partialCredentials: Credentials = {
    access_token: 'only_access_token',
    refresh_token: null, // test null
    scope: undefined, // test undefined
  };
  await OAuthCredentialStorage.saveCredentials(partialCredentials);
  expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith({
    serverName: 'main-account',
    token: {
      accessToken: 'only_access_token',
      refreshToken: undefined,
      tokenType: 'Bearer', // default
      scope: undefined,
      expiresAt: undefined,
    },
    updatedAt: expect.any(Number),
  });
});
});
describe('clearCredentials', () => {
+138 -36
View File
@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { beforeEach, describe, it, expect, vi } from 'vitest';
import { beforeEach, describe, it, expect, vi, afterEach } from 'vitest';
import { CodeAssistServer } from './server.js';
import { OAuth2Client } from 'google-auth-library';
import { UserTierId } from './types.js';
@@ -29,15 +29,16 @@ describe('CodeAssistServer', () => {
});
it('should call the generateContent endpoint', async () => {
const client = new OAuth2Client();
const mockRequest = vi.fn();
const client = { request: mockRequest } as unknown as OAuth2Client;
const server = new CodeAssistServer(
client,
'test-project',
{},
{ headers: { 'x-custom-header': 'test-value' } },
'test-session',
UserTierId.FREE,
);
const mockResponse = {
const mockResponseData = {
response: {
candidates: [
{
@@ -52,7 +53,7 @@ describe('CodeAssistServer', () => {
],
},
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
mockRequest.mockResolvedValue({ data: mockResponseData });
const response = await server.generateContent(
{
@@ -62,18 +63,59 @@ describe('CodeAssistServer', () => {
'user-prompt-id',
);
expect(server.requestPost).toHaveBeenCalledWith(
'generateContent',
expect.any(Object),
undefined,
);
expect(mockRequest).toHaveBeenCalledWith({
url: expect.stringContaining(':generateContent'),
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-custom-header': 'test-value',
},
responseType: 'json',
body: expect.any(String),
signal: undefined,
});
const requestBody = JSON.parse(mockRequest.mock.calls[0][0].body);
expect(requestBody.user_prompt_id).toBe('user-prompt-id');
expect(requestBody.project).toBe('test-project');
expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
'response',
);
});
it('should call the generateContentStream endpoint', async () => {
const client = new OAuth2Client();
// URL construction for API methods, including the CODE_ASSIST_ENDPOINT
// environment-variable override.
describe('getMethodUrl', () => {
  const originalEnv = process.env;
  beforeEach(() => {
    // Reset the environment variables to their original state
    process.env = { ...originalEnv };
  });
  afterEach(() => {
    // Restore the original environment variables
    process.env = originalEnv;
  });
  it('should construct the default URL correctly', () => {
    // Server constructed with a bare client; only URL logic is exercised.
    const server = new CodeAssistServer({} as never);
    const url = server.getMethodUrl('testMethod');
    expect(url).toBe(
      'https://cloudcode-pa.googleapis.com/v1internal:testMethod',
    );
  });
  it('should use the CODE_ASSIST_ENDPOINT environment variable if set', () => {
    process.env['CODE_ASSIST_ENDPOINT'] = 'https://custom-endpoint.com';
    const server = new CodeAssistServer({} as never);
    const url = server.getMethodUrl('testMethod');
    expect(url).toBe('https://custom-endpoint.com/v1internal:testMethod');
  });
});
it('should call the generateContentStream endpoint and parse SSE', async () => {
const mockRequest = vi.fn();
const client = { request: mockRequest } as unknown as OAuth2Client;
const server = new CodeAssistServer(
client,
'test-project',
@@ -81,24 +123,21 @@ describe('CodeAssistServer', () => {
'test-session',
UserTierId.FREE,
);
const mockResponse = (async function* () {
yield {
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'response' }],
},
finishReason: 'STOP',
safetyRatings: [],
},
],
},
};
})();
vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
// Create a mock readable stream
const { Readable } = await import('node:stream');
const mockStream = new Readable({
read() {},
});
const mockResponseData1 = {
response: { candidates: [{ content: { parts: [{ text: 'Hello' }] } }] },
};
const mockResponseData2 = {
response: { candidates: [{ content: { parts: [{ text: ' World' }] } }] },
};
mockRequest.mockResolvedValue({ data: mockStream });
const stream = await server.generateContentStream(
{
@@ -108,14 +147,61 @@ describe('CodeAssistServer', () => {
'user-prompt-id',
);
// Push SSE data to the stream
// Use setTimeout to ensure the stream processing has started
setTimeout(() => {
mockStream.push('data: ' + JSON.stringify(mockResponseData1) + '\n\n');
mockStream.push('id: 123\n'); // Should be ignored
mockStream.push('data: ' + JSON.stringify(mockResponseData2) + '\n\n');
mockStream.push(null); // End the stream
}, 0);
const results = [];
for await (const res of stream) {
expect(server.requestStreamingPost).toHaveBeenCalledWith(
'streamGenerateContent',
expect.any(Object),
undefined,
);
expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response');
results.push(res);
}
expect(mockRequest).toHaveBeenCalledWith({
url: expect.stringContaining(':streamGenerateContent'),
method: 'POST',
params: { alt: 'sse' },
responseType: 'stream',
body: expect.any(String),
headers: {
'Content-Type': 'application/json',
},
signal: undefined,
});
expect(results).toHaveLength(2);
expect(results[0].candidates?.[0].content?.parts?.[0].text).toBe('Hello');
expect(results[1].candidates?.[0].content?.parts?.[0].text).toBe(' World');
});
// SSE lines that are neither 'data: ...' nor blank must be silently skipped,
// yielding no results rather than throwing.
it('should ignore malformed SSE data', async () => {
  const mockRequest = vi.fn();
  const client = { request: mockRequest } as unknown as OAuth2Client;
  const server = new CodeAssistServer(client);
  const { Readable } = await import('node:stream');
  // A push-mode readable we feed manually after the request starts.
  const mockStream = new Readable({
    read() {},
  });
  mockRequest.mockResolvedValue({ data: mockStream });
  const stream = await server.requestStreamingPost('testStream', {});
  // Defer pushes so stream consumption has started before data arrives.
  setTimeout(() => {
    mockStream.push('this is a malformed line\n');
    mockStream.push(null);
  }, 0);
  const results = [];
  for await (const res of stream) {
    results.push(res);
  }
  expect(results).toHaveLength(0);
});
it('should call the onboardUser endpoint', async () => {
@@ -253,6 +339,22 @@ describe('CodeAssistServer', () => {
});
});
// Errors other than the specially-handled VPC-SC case must propagate
// unchanged from loadCodeAssist.
it('should re-throw non-VPC-SC errors from loadCodeAssist', async () => {
  const client = new OAuth2Client();
  const server = new CodeAssistServer(client);
  const genericError = new Error('Something else went wrong');
  vi.spyOn(server, 'requestPost').mockRejectedValue(genericError);
  await expect(server.loadCodeAssist({ metadata: {} })).rejects.toThrow(
    'Something else went wrong',
  );
  // The request was still attempted against the loadCodeAssist endpoint.
  expect(server.requestPost).toHaveBeenCalledWith(
    'loadCodeAssist',
    expect.any(Object),
  );
});
it('should call the listExperiments endpoint with metadata', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
+4 -6
View File
@@ -232,18 +232,16 @@ export class CodeAssistServer implements ContentGenerator {
let bufferedLines: string[] = [];
for await (const line of rl) {
// blank lines are used to separate JSON objects in the stream
if (line === '') {
if (line.startsWith('data: ')) {
bufferedLines.push(line.slice(6).trim());
} else if (line === '') {
if (bufferedLines.length === 0) {
continue; // no data to yield
}
yield JSON.parse(bufferedLines.join('\n')) as T;
bufferedLines = []; // Reset the buffer after yielding
} else if (line.startsWith('data: ')) {
bufferedLines.push(line.slice(6).trim());
} else {
throw new Error(`Unexpected line format in response: ${line}`);
}
// Ignore other lines like comments or id fields
}
})();
}
@@ -0,0 +1,257 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
const logApiRequest = vi.hoisted(() => vi.fn());
const logApiResponse = vi.hoisted(() => vi.fn());
const logApiError = vi.hoisted(() => vi.fn());
vi.mock('../telemetry/loggers.js', () => ({
logApiRequest,
logApiResponse,
logApiError,
}));
const runInDevTraceSpan = vi.hoisted(() =>
vi.fn(async (meta, fn) => fn({ metadata: {}, endSpan: vi.fn() })),
);
vi.mock('../telemetry/trace.js', () => ({
runInDevTraceSpan,
}));
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import type {
GenerateContentResponse,
EmbedContentResponse,
} from '@google/genai';
import type { ContentGenerator } from './contentGenerator.js';
import { LoggingContentGenerator } from './loggingContentGenerator.js';
import type { Config } from '../config/config.js';
import { ApiRequestEvent } from '../telemetry/types.js';
// Unit tests for LoggingContentGenerator, a decorator that forwards every
// call to a wrapped ContentGenerator while emitting telemetry events
// (ApiRequestEvent plus response/error events) through the mocked loggers.
describe('LoggingContentGenerator', () => {
  let wrapped: ContentGenerator;
  let config: Config;
  let loggingContentGenerator: LoggingContentGenerator;
  beforeEach(() => {
    // Fresh mock for the wrapped generator so per-test call assertions are
    // isolated from each other.
    wrapped = {
      generateContent: vi.fn(),
      generateContentStream: vi.fn(),
      countTokens: vi.fn(),
      embedContent: vi.fn(),
    };
    config = {
      getGoogleAIConfig: vi.fn(),
      getVertexAIConfig: vi.fn(),
      getContentGeneratorConfig: vi.fn().mockReturnValue({
        authType: 'API_KEY',
      }),
    } as unknown as Config;
    loggingContentGenerator = new LoggingContentGenerator(wrapped, config);
    // Fake timers let the tests control the measured duration_ms exactly.
    vi.useFakeTimers();
  });
  afterEach(() => {
    vi.clearAllMocks();
    vi.useRealTimers();
  });
  describe('generateContent', () => {
    it('should log request and response on success', async () => {
      const req = {
        contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
        model: 'gemini-pro',
      };
      const userPromptId = 'prompt-123';
      const response: GenerateContentResponse = {
        candidates: [],
        usageMetadata: {
          promptTokenCount: 1,
          candidatesTokenCount: 2,
          totalTokenCount: 3,
        },
        text: undefined,
        functionCalls: undefined,
        executableCode: undefined,
        codeExecutionResult: undefined,
        data: undefined,
      };
      vi.mocked(wrapped.generateContent).mockResolvedValue(response);
      const startTime = new Date('2025-01-01T00:00:00.000Z');
      vi.setSystemTime(startTime);
      const promise = loggingContentGenerator.generateContent(
        req,
        userPromptId,
      );
      // Advance the fake clock before awaiting so the logged duration is
      // deterministic (1000 ms).
      vi.advanceTimersByTime(1000);
      await promise;
      expect(wrapped.generateContent).toHaveBeenCalledWith(req, userPromptId);
      expect(logApiRequest).toHaveBeenCalledWith(
        config,
        expect.any(ApiRequestEvent),
      );
      // The event is the second argument of the first logApiResponse call.
      const responseEvent = vi.mocked(logApiResponse).mock.calls[0][1];
      expect(responseEvent.duration_ms).toBe(1000);
    });
    it('should log error on failure', async () => {
      const req = {
        contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
        model: 'gemini-pro',
      };
      const userPromptId = 'prompt-123';
      const error = new Error('test error');
      vi.mocked(wrapped.generateContent).mockRejectedValue(error);
      const startTime = new Date('2025-01-01T00:00:00.000Z');
      vi.setSystemTime(startTime);
      const promise = loggingContentGenerator.generateContent(
        req,
        userPromptId,
      );
      vi.advanceTimersByTime(1000);
      // The wrapped rejection must propagate to the caller unchanged.
      await expect(promise).rejects.toThrow(error);
      expect(logApiRequest).toHaveBeenCalledWith(
        config,
        expect.any(ApiRequestEvent),
      );
      const errorEvent = vi.mocked(logApiError).mock.calls[0][1];
      expect(errorEvent.duration_ms).toBe(1000);
    });
  });
  describe('generateContentStream', () => {
    it('should log request and response on success', async () => {
      const req = {
        contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
        model: 'gemini-pro',
      };
      const userPromptId = 'prompt-123';
      const response = {
        candidates: [],
        usageMetadata: {
          promptTokenCount: 1,
          candidatesTokenCount: 2,
          totalTokenCount: 3,
        },
      } as unknown as GenerateContentResponse;
      async function* createAsyncGenerator() {
        yield response;
      }
      vi.mocked(wrapped.generateContentStream).mockResolvedValue(
        createAsyncGenerator(),
      );
      const startTime = new Date('2025-01-01T00:00:00.000Z');
      vi.setSystemTime(startTime);
      const stream = await loggingContentGenerator.generateContentStream(
        req,
        userPromptId,
      );
      vi.advanceTimersByTime(1000);
      // The response event is only logged once the stream is fully consumed.
      for await (const _ of stream) {
        // consume stream
      }
      expect(wrapped.generateContentStream).toHaveBeenCalledWith(
        req,
        userPromptId,
      );
      expect(logApiRequest).toHaveBeenCalledWith(
        config,
        expect.any(ApiRequestEvent),
      );
      const responseEvent = vi.mocked(logApiResponse).mock.calls[0][1];
      expect(responseEvent.duration_ms).toBe(1000);
    });
    it('should log error on failure', async () => {
      const req = {
        contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
        model: 'gemini-pro',
      };
      const userPromptId = 'prompt-123';
      const error = new Error('test error');
      // Yielding a rejected promise inside an async generator makes the
      // generator itself throw when the consumer awaits that item.
      async function* createAsyncGenerator() {
        yield Promise.reject(error);
      }
      vi.mocked(wrapped.generateContentStream).mockResolvedValue(
        createAsyncGenerator(),
      );
      const startTime = new Date('2025-01-01T00:00:00.000Z');
      vi.setSystemTime(startTime);
      const stream = await loggingContentGenerator.generateContentStream(
        req,
        userPromptId,
      );
      vi.advanceTimersByTime(1000);
      await expect(async () => {
        for await (const _ of stream) {
          // do nothing
        }
      }).rejects.toThrow(error);
      expect(logApiRequest).toHaveBeenCalledWith(
        config,
        expect.any(ApiRequestEvent),
      );
      const errorEvent = vi.mocked(logApiError).mock.calls[0][1];
      expect(errorEvent.duration_ms).toBe(1000);
    });
  });
  describe('getWrapped', () => {
    it('should return the wrapped content generator', () => {
      expect(loggingContentGenerator.getWrapped()).toBe(wrapped);
    });
  });
  // countTokens and embedContent are pure pass-throughs: no telemetry is
  // asserted for them, only delegation to the wrapped generator.
  describe('countTokens', () => {
    it('should call the wrapped countTokens method', async () => {
      const req = { contents: [], model: 'gemini-pro' };
      const response = { totalTokens: 10 };
      vi.mocked(wrapped.countTokens).mockResolvedValue(response);
      const result = await loggingContentGenerator.countTokens(req);
      expect(wrapped.countTokens).toHaveBeenCalledWith(req);
      expect(result).toBe(response);
    });
  });
  describe('embedContent', () => {
    it('should call the wrapped embedContent method', async () => {
      const req = {
        contents: [{ role: 'user', parts: [] }],
        model: 'gemini-pro',
      };
      const response: EmbedContentResponse = { embeddings: [{ values: [] }] };
      vi.mocked(wrapped.embedContent).mockResolvedValue(response);
      const result = await loggingContentGenerator.embedContent(req);
      expect(wrapped.embedContent).toHaveBeenCalledWith(req);
      expect(result).toBe(response);
    });
  });
});
@@ -0,0 +1,31 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { tokenLimit, DEFAULT_TOKEN_LIMIT } from './tokenLimits.js';
// Unit tests for tokenLimit(): known models resolve to their published
// context-window sizes; anything unrecognized (including a missing model)
// falls back to DEFAULT_TOKEN_LIMIT.
describe('tokenLimit', () => {
  it('should return the correct token limit for gemini-1.5-pro', () => {
    const limit = tokenLimit('gemini-1.5-pro');
    expect(limit).toBe(2_097_152);
  });
  it('should return the correct token limit for gemini-1.5-flash', () => {
    const limit = tokenLimit('gemini-1.5-flash');
    expect(limit).toBe(1_048_576);
  });
  it('should return the default token limit for an unknown model', () => {
    const limit = tokenLimit('unknown-model');
    expect(limit).toBe(DEFAULT_TOKEN_LIMIT);
  });
  it('should return the default token limit if no model is provided', () => {
    // @ts-expect-error testing invalid input
    const limit = tokenLimit(undefined);
    expect(limit).toBe(DEFAULT_TOKEN_LIMIT);
  });
  it('should have the correct default token limit value', () => {
    expect(DEFAULT_TOKEN_LIMIT).toBe(1_048_576);
  });
});
+349 -2
View File
@@ -4,8 +4,40 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { HookEventName, HookType } from './types.js';
import { describe, it, expect, vi } from 'vitest';
import {
createHookOutput,
DefaultHookOutput,
BeforeToolHookOutput,
BeforeModelHookOutput,
BeforeToolSelectionHookOutput,
AfterModelHookOutput,
HookEventName,
HookType,
} from './types.js';
import { defaultHookTranslator } from './hookTranslator.js';
import type {
GenerateContentParameters,
GenerateContentResponse,
ToolConfig,
} from '@google/genai';
import type { LLMRequest, LLMResponse } from './hookTranslator.js';
import type { HookDecision } from './types.js';
vi.mock('./hookTranslator.js', () => ({
defaultHookTranslator: {
fromHookLLMResponse: vi.fn(
(response: LLMResponse) => response as unknown as GenerateContentResponse,
),
fromHookLLMRequest: vi.fn(
(request: LLMRequest, target: GenerateContentParameters) => ({
...target,
...request,
}),
),
fromHookToolConfig: vi.fn((config: ToolConfig) => config),
},
}));
describe('Hook Types', () => {
describe('HookEventName', () => {
@@ -36,3 +68,318 @@ describe('Hook Types', () => {
});
});
});
// Tests for the hook-output class hierarchy: the createHookOutput factory,
// the DefaultHookOutput base class, and the event-specific subclasses.
// defaultHookTranslator is mocked above as an identity-style pass-through so
// these tests can assert on its inputs without real translation logic.
describe('Hook Output Classes', () => {
  describe('createHookOutput', () => {
    it('should return DefaultHookOutput for unknown event names', () => {
      const output = createHookOutput('UnknownEvent', {});
      expect(output).toBeInstanceOf(DefaultHookOutput);
      expect(output).not.toBeInstanceOf(BeforeModelHookOutput);
      expect(output).not.toBeInstanceOf(AfterModelHookOutput);
      expect(output).not.toBeInstanceOf(BeforeToolSelectionHookOutput);
    });
    it('should return BeforeModelHookOutput for BeforeModel event', () => {
      const output = createHookOutput(HookEventName.BeforeModel, {});
      expect(output).toBeInstanceOf(BeforeModelHookOutput);
    });
    it('should return AfterModelHookOutput for AfterModel event', () => {
      const output = createHookOutput(HookEventName.AfterModel, {});
      expect(output).toBeInstanceOf(AfterModelHookOutput);
    });
    it('should return BeforeToolSelectionHookOutput for BeforeToolSelection event', () => {
      const output = createHookOutput(HookEventName.BeforeToolSelection, {});
      expect(output).toBeInstanceOf(BeforeToolSelectionHookOutput);
    });
  });
  describe('DefaultHookOutput', () => {
    it('should construct with provided data', () => {
      const data = {
        continue: false,
        stopReason: 'test stop',
        suppressOutput: true,
        systemMessage: 'test system message',
        decision: 'block' as HookDecision,
        reason: 'test reason',
        hookSpecificOutput: { key: 'value' },
      };
      const output = new DefaultHookOutput(data);
      expect(output.continue).toBe(data.continue);
      expect(output.stopReason).toBe(data.stopReason);
      expect(output.suppressOutput).toBe(data.suppressOutput);
      expect(output.systemMessage).toBe(data.systemMessage);
      expect(output.decision).toBe(data.decision);
      expect(output.reason).toBe(data.reason);
      expect(output.hookSpecificOutput).toEqual(data.hookSpecificOutput);
    });
    // Only 'block' and 'deny' count as blocking decisions.
    it('should return false for isBlockingDecision if decision is not block or deny', () => {
      const output1 = new DefaultHookOutput({ decision: 'approve' });
      expect(output1.isBlockingDecision()).toBe(false);
      const output2 = new DefaultHookOutput({ decision: undefined });
      expect(output2.isBlockingDecision()).toBe(false);
    });
    it('should return true for isBlockingDecision if decision is block or deny', () => {
      const output1 = new DefaultHookOutput({ decision: 'block' });
      expect(output1.isBlockingDecision()).toBe(true);
      const output2 = new DefaultHookOutput({ decision: 'deny' });
      expect(output2.isBlockingDecision()).toBe(true);
    });
    it('should return true for shouldStopExecution if continue is false', () => {
      const output = new DefaultHookOutput({ continue: false });
      expect(output.shouldStopExecution()).toBe(true);
    });
    it('should return false for shouldStopExecution if continue is true or undefined', () => {
      const output1 = new DefaultHookOutput({ continue: true });
      expect(output1.shouldStopExecution()).toBe(false);
      const output2 = new DefaultHookOutput({});
      expect(output2.shouldStopExecution()).toBe(false);
    });
    // getEffectiveReason precedence: reason > stopReason > fallback string.
    it('should return reason if available', () => {
      const output = new DefaultHookOutput({ reason: 'specific reason' });
      expect(output.getEffectiveReason()).toBe('specific reason');
    });
    it('should return stopReason if reason is not available', () => {
      const output = new DefaultHookOutput({ stopReason: 'stop reason' });
      expect(output.getEffectiveReason()).toBe('stop reason');
    });
    it('should return "No reason provided" if neither reason nor stopReason are available', () => {
      const output = new DefaultHookOutput({});
      expect(output.getEffectiveReason()).toBe('No reason provided');
    });
    it('applyLLMRequestModifications should return target unchanged', () => {
      const target: GenerateContentParameters = {
        model: 'gemini-pro',
        contents: [],
      };
      const output = new DefaultHookOutput({});
      expect(output.applyLLMRequestModifications(target)).toBe(target);
    });
    it('applyToolConfigModifications should return target unchanged', () => {
      const target = { toolConfig: {}, tools: [] };
      const output = new DefaultHookOutput({});
      expect(output.applyToolConfigModifications(target)).toBe(target);
    });
    it('getAdditionalContext should return additionalContext if present', () => {
      const output = new DefaultHookOutput({
        hookSpecificOutput: { additionalContext: 'some context' },
      });
      expect(output.getAdditionalContext()).toBe('some context');
    });
    it('getAdditionalContext should return undefined if additionalContext is not present', () => {
      const output = new DefaultHookOutput({
        hookSpecificOutput: { other: 'value' },
      });
      expect(output.getAdditionalContext()).toBeUndefined();
    });
    it('getAdditionalContext should return undefined if hookSpecificOutput is undefined', () => {
      const output = new DefaultHookOutput({});
      expect(output.getAdditionalContext()).toBeUndefined();
    });
    it('getBlockingError should return blocked: true and reason if blocking decision', () => {
      const output = new DefaultHookOutput({
        decision: 'block',
        reason: 'blocked by hook',
      });
      expect(output.getBlockingError()).toEqual({
        blocked: true,
        reason: 'blocked by hook',
      });
    });
    it('getBlockingError should return blocked: false if not blocking decision', () => {
      const output = new DefaultHookOutput({ decision: 'approve' });
      expect(output.getBlockingError()).toEqual({ blocked: false, reason: '' });
    });
  });
  // BeforeToolHookOutput supports legacy-style permission fields nested in
  // hookSpecificOutput, overriding the base-class decision/reason handling.
  describe('BeforeToolHookOutput', () => {
    it('isBlockingDecision should use permissionDecision from hookSpecificOutput', () => {
      const output1 = new BeforeToolHookOutput({
        hookSpecificOutput: { permissionDecision: 'block' },
      });
      expect(output1.isBlockingDecision()).toBe(true);
      const output2 = new BeforeToolHookOutput({
        hookSpecificOutput: { permissionDecision: 'approve' },
      });
      expect(output2.isBlockingDecision()).toBe(false);
    });
    it('getEffectiveReason should use permissionDecisionReason from hookSpecificOutput', () => {
      const output1 = new BeforeToolHookOutput({
        hookSpecificOutput: { permissionDecisionReason: 'compat reason' },
      });
      expect(output1.getEffectiveReason()).toBe('compat reason');
      const output2 = new BeforeToolHookOutput({
        reason: 'default reason',
        hookSpecificOutput: { other: 'value' },
      });
      expect(output2.getEffectiveReason()).toBe('default reason');
    });
  });
  describe('BeforeModelHookOutput', () => {
    it('getSyntheticResponse should return synthetic response if llm_response is present', () => {
      const mockResponse: LLMResponse = { candidates: [] };
      const output = new BeforeModelHookOutput({
        hookSpecificOutput: { llm_response: mockResponse },
      });
      expect(output.getSyntheticResponse()).toEqual(mockResponse);
      // The hook-format response must pass through the translator.
      expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith(
        mockResponse,
      );
    });
    it('getSyntheticResponse should return undefined if llm_response is not present', () => {
      const output = new BeforeModelHookOutput({});
      expect(output.getSyntheticResponse()).toBeUndefined();
    });
    it('applyLLMRequestModifications should apply modifications if llm_request is present', () => {
      const target: GenerateContentParameters = {
        model: 'gemini-pro',
        contents: [{ parts: [{ text: 'original' }] }],
      };
      const mockRequest: Partial<LLMRequest> = {
        messages: [{ role: 'user', content: 'modified' }],
      };
      const output = new BeforeModelHookOutput({
        hookSpecificOutput: { llm_request: mockRequest },
      });
      const result = output.applyLLMRequestModifications(target);
      // The mocked translator merges request over target, so the expected
      // value mirrors that spread.
      expect(result).toEqual({ ...target, ...mockRequest });
      expect(defaultHookTranslator.fromHookLLMRequest).toHaveBeenCalledWith(
        mockRequest,
        target,
      );
    });
    it('applyLLMRequestModifications should return target unchanged if llm_request is not present', () => {
      const target: GenerateContentParameters = {
        model: 'gemini-pro',
        contents: [],
      };
      const output = new BeforeModelHookOutput({});
      expect(output.applyLLMRequestModifications(target)).toBe(target);
    });
  });
  describe('BeforeToolSelectionHookOutput', () => {
    it('applyToolConfigModifications should apply modifications if toolConfig is present', () => {
      const target = { tools: [{ functionDeclarations: [] }] };
      const mockToolConfig = { functionCallingConfig: { mode: 'ANY' } };
      const output = new BeforeToolSelectionHookOutput({
        hookSpecificOutput: { toolConfig: mockToolConfig },
      });
      const result = output.applyToolConfigModifications(target);
      expect(result).toEqual({ ...target, toolConfig: mockToolConfig });
      expect(defaultHookTranslator.fromHookToolConfig).toHaveBeenCalledWith(
        mockToolConfig,
      );
    });
    it('applyToolConfigModifications should return target unchanged if toolConfig is not present', () => {
      const target = { toolConfig: {}, tools: [] };
      const output = new BeforeToolSelectionHookOutput({});
      expect(output.applyToolConfigModifications(target)).toBe(target);
    });
    it('applyToolConfigModifications should initialize tools array if not present', () => {
      const target = {};
      const mockToolConfig = { functionCallingConfig: { mode: 'ANY' } };
      const output = new BeforeToolSelectionHookOutput({
        hookSpecificOutput: { toolConfig: mockToolConfig },
      });
      const result = output.applyToolConfigModifications(target);
      expect(result).toEqual({ tools: [], toolConfig: mockToolConfig });
    });
  });
  describe('AfterModelHookOutput', () => {
    it('getModifiedResponse should return modified response if llm_response is present and has content', () => {
      const mockResponse: LLMResponse = {
        candidates: [{ content: { role: 'model', parts: ['modified'] } }],
      };
      const output = new AfterModelHookOutput({
        hookSpecificOutput: { llm_response: mockResponse },
      });
      expect(output.getModifiedResponse()).toEqual(mockResponse);
      expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith(
        mockResponse,
      );
    });
    // Covers the parts?.length guard: an empty parts array is treated as
    // "no content" and the modification is ignored.
    it('getModifiedResponse should return undefined if llm_response is present but no content', () => {
      const mockResponse: LLMResponse = {
        candidates: [{ content: { role: 'model', parts: [] } }],
      };
      const output = new AfterModelHookOutput({
        hookSpecificOutput: { llm_response: mockResponse },
      });
      expect(output.getModifiedResponse()).toBeUndefined();
    });
    it('getModifiedResponse should return undefined if llm_response is not present', () => {
      const output = new AfterModelHookOutput({});
      expect(output.getModifiedResponse()).toBeUndefined();
    });
    it('getModifiedResponse should return a synthetic stop response if shouldStopExecution is true', () => {
      const output = new AfterModelHookOutput({
        continue: false,
        stopReason: 'stopped by hook',
      });
      const expectedResponse: LLMResponse = {
        candidates: [
          {
            content: {
              role: 'model',
              parts: ['stopped by hook'],
            },
            finishReason: 'STOP',
          },
        ],
      };
      expect(output.getModifiedResponse()).toEqual(expectedResponse);
      expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith(
        expectedResponse,
      );
    });
    it('getModifiedResponse should return a synthetic stop response with default reason if shouldStopExecution is true and no stopReason', () => {
      const output = new AfterModelHookOutput({ continue: false });
      const expectedResponse: LLMResponse = {
        candidates: [
          {
            content: {
              role: 'model',
              parts: ['No reason provided'],
            },
            finishReason: 'STOP',
          },
        ],
      };
      expect(output.getModifiedResponse()).toEqual(expectedResponse);
      expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith(
        expectedResponse,
      );
    });
  });
});
+1 -1
View File
@@ -340,7 +340,7 @@ export class AfterModelHookOutput extends DefaultHookOutput {
const hookResponse = this.hookSpecificOutput[
'llm_response'
] as Partial<LLMResponse>;
if (hookResponse?.candidates?.[0]?.content) {
if (hookResponse?.candidates?.[0]?.content?.parts?.length) {
// Convert hook format to SDK format
return defaultHookTranslator.fromHookLLMResponse(
hookResponse as LLMResponse,
+248
View File
@@ -647,6 +647,254 @@ describe('IdeClient', () => {
});
});
  // Tests for resolveDiffFromCli: resolving an in-flight openDiff promise
  // from the CLI side, closing the IDE diff view, and cleaning up the
  // per-file resolver map (diffResponses). Private members are reached via
  // structural casts since they are not part of the public API.
  describe('resolveDiffFromCli', () => {
    beforeEach(async () => {
      // Ensure client is "connected" for these tests
      const ideClient = await IdeClient.getInstance();
      // We need to set the client property on the instance for openDiff to work
      (ideClient as unknown as { client: Client }).client = mockClient;
      mockClient.request.mockResolvedValue({
        isError: false,
        content: [],
      });
    });
    it("should resolve an open diff as 'accepted' and return the final content", async () => {
      const ideClient = await IdeClient.getInstance();
      const closeDiffSpy = vi
        .spyOn(
          ideClient as unknown as {
            closeDiff: () => Promise<string | undefined>;
          },
          'closeDiff',
        )
        .mockResolvedValue('final content from ide');
      const diffPromise = ideClient.openDiff('/test.txt', 'new content');
      // Yield to the event loop to allow the openDiff promise executor to run
      await new Promise((resolve) => setImmediate(resolve));
      await ideClient.resolveDiffFromCli('/test.txt', 'accepted');
      const result = await diffPromise;
      expect(result).toEqual({
        status: 'accepted',
        content: 'final content from ide',
      });
      expect(closeDiffSpy).toHaveBeenCalledWith('/test.txt', {
        suppressNotification: true,
      });
      // The resolver entry must be removed once the diff is settled.
      expect(
        (
          ideClient as unknown as { diffResponses: Map<string, unknown> }
        ).diffResponses.has('/test.txt'),
      ).toBe(false);
    });
    it("should resolve an open diff as 'rejected'", async () => {
      const ideClient = await IdeClient.getInstance();
      const closeDiffSpy = vi
        .spyOn(
          ideClient as unknown as {
            closeDiff: () => Promise<string | undefined>;
          },
          'closeDiff',
        )
        .mockResolvedValue(undefined);
      const diffPromise = ideClient.openDiff('/test.txt', 'new content');
      // Yield to the event loop to allow the openDiff promise executor to run
      await new Promise((resolve) => setImmediate(resolve));
      await ideClient.resolveDiffFromCli('/test.txt', 'rejected');
      const result = await diffPromise;
      expect(result).toEqual({
        status: 'rejected',
        content: undefined,
      });
      expect(closeDiffSpy).toHaveBeenCalledWith('/test.txt', {
        suppressNotification: true,
      });
      expect(
        (
          ideClient as unknown as { diffResponses: Map<string, unknown> }
        ).diffResponses.has('/test.txt'),
      ).toBe(false);
    });
    it('should do nothing if no diff is open for the given file path', async () => {
      const ideClient = await IdeClient.getInstance();
      const closeDiffSpy = vi
        .spyOn(
          ideClient as unknown as {
            closeDiff: () => Promise<string | undefined>;
          },
          'closeDiff',
        )
        .mockResolvedValue(undefined);
      // No call to openDiff, so no resolver will exist.
      await ideClient.resolveDiffFromCli('/non-existent.txt', 'accepted');
      expect(closeDiffSpy).toHaveBeenCalledWith('/non-existent.txt', {
        suppressNotification: true,
      });
      // No crash should occur, and nothing should be in the map.
      expect(
        (
          ideClient as unknown as { diffResponses: Map<string, unknown> }
        ).diffResponses.has('/non-existent.txt'),
      ).toBe(false);
    });
  });
  // Tests for the private closeDiff helper: it issues a 'closeDiff' MCP tool
  // request and parses the JSON text payload of the response, returning the
  // final file content or undefined on any failure/odd-shaped response.
  describe('closeDiff', () => {
    beforeEach(async () => {
      const ideClient = await IdeClient.getInstance();
      (ideClient as unknown as { client: Client }).client = mockClient;
    });
    it('should return undefined if client is not connected', async () => {
      const ideClient = await IdeClient.getInstance();
      (ideClient as unknown as { client: Client | undefined }).client =
        undefined;
      const result = await (
        ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
      ).closeDiff('/test.txt');
      expect(result).toBeUndefined();
    });
    it('should call client.request with correct arguments', async () => {
      const ideClient = await IdeClient.getInstance();
      // Return a valid, empty response as the return value is not under test here.
      mockClient.request.mockResolvedValue({ isError: false, content: [] });
      await (
        ideClient as unknown as {
          closeDiff: (
            f: string,
            o?: { suppressNotification?: boolean },
          ) => Promise<void>;
        }
      ).closeDiff('/test.txt', { suppressNotification: true });
      expect(mockClient.request).toHaveBeenCalledWith(
        expect.objectContaining({
          params: {
            name: 'closeDiff',
            arguments: {
              filePath: '/test.txt',
              suppressNotification: true,
            },
          },
        }),
        expect.any(Object), // Schema
        expect.any(Object), // Options
      );
    });
    it('should return content from a valid JSON response', async () => {
      const ideClient = await IdeClient.getInstance();
      const response = {
        isError: false,
        content: [
          { type: 'text', text: JSON.stringify({ content: 'file content' }) },
        ],
      };
      mockClient.request.mockResolvedValue(response);
      const result = await (
        ideClient as unknown as { closeDiff: (f: string) => Promise<string> }
      ).closeDiff('/test.txt');
      expect(result).toBe('file content');
    });
    it('should return undefined for a valid JSON response with null content', async () => {
      const ideClient = await IdeClient.getInstance();
      const response = {
        isError: false,
        content: [{ type: 'text', text: JSON.stringify({ content: null }) }],
      };
      mockClient.request.mockResolvedValue(response);
      const result = await (
        ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
      ).closeDiff('/test.txt');
      expect(result).toBeUndefined();
    });
    it('should return undefined if response is not valid JSON', async () => {
      const ideClient = await IdeClient.getInstance();
      const response = {
        isError: false,
        content: [{ type: 'text', text: 'not json' }],
      };
      mockClient.request.mockResolvedValue(response);
      const result = await (
        ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
      ).closeDiff('/test.txt');
      expect(result).toBeUndefined();
    });
    it('should return undefined if request result has isError: true', async () => {
      const ideClient = await IdeClient.getInstance();
      const response = {
        isError: true,
        content: [{ type: 'text', text: 'An error occurred' }],
      };
      mockClient.request.mockResolvedValue(response);
      const result = await (
        ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
      ).closeDiff('/test.txt');
      expect(result).toBeUndefined();
    });
    it('should return undefined if client.request throws', async () => {
      const ideClient = await IdeClient.getInstance();
      mockClient.request.mockRejectedValue(new Error('Request failed'));
      const result = await (
        ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
      ).closeDiff('/test.txt');
      expect(result).toBeUndefined();
    });
    it('should return undefined if response has no text part', async () => {
      const ideClient = await IdeClient.getInstance();
      const response = {
        isError: false,
        content: [{ type: 'other' }],
      };
      mockClient.request.mockResolvedValue(response);
      const result = await (
        ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
      ).closeDiff('/test.txt');
      expect(result).toBeUndefined();
    });
    it('should return undefined if response is falsy', async () => {
      const ideClient = await IdeClient.getInstance();
      // Mocking with `null as any` to test the falsy path, as the mock
      // function is strictly typed.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      mockClient.request.mockResolvedValue(null as any);
      const result = await (
        ideClient as unknown as { closeDiff: (f: string) => Promise<void> }
      ).closeDiff('/test.txt');
      expect(result).toBeUndefined();
    });
  });
describe('authentication', () => {
it('should connect with an auth token if provided in the discovery file', async () => {
const authToken = 'test-auth-token';
@@ -714,6 +714,269 @@ describe('MCPOAuthProvider', () => {
).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant');
});
  // Failure-path and edge-case tests for MCPOAuthProvider.authenticate /
  // refreshAccessToken / getValidToken. These run inside the enclosing
  // describe('MCPOAuthProvider') whose setup (mockFetch, mockHttpServer,
  // mockConfig, mockTokenResponse, etc.) is defined above this hunk.
  it('should handle OAuth discovery failure', async () => {
    // No explicit endpoints in config forces discovery, which fails (404).
    const configWithoutAuth: MCPOAuthConfig = { ...mockConfig };
    delete configWithoutAuth.authorizationUrl;
    delete configWithoutAuth.tokenUrl;
    mockFetch.mockResolvedValueOnce(
      createMockResponse({
        ok: false,
        status: 404,
      }),
    );
    const authProvider = new MCPOAuthProvider();
    await expect(
      authProvider.authenticate(
        'test-server',
        configWithoutAuth,
        'https://api.example.com',
      ),
    ).rejects.toThrow(
      'Failed to discover OAuth configuration from MCP server',
    );
  });
  it('should handle authorization server metadata discovery failure', async () => {
    // Missing clientId triggers dynamic client registration, which needs
    // server metadata that never becomes available (all fetches 404).
    const configWithoutClient: MCPOAuthConfig = { ...mockConfig };
    delete configWithoutClient.clientId;
    mockFetch.mockResolvedValue(
      createMockResponse({
        ok: false,
        status: 404,
      }),
    );
    // Prevent callback server from hanging the test
    mockHttpServer.listen.mockImplementation((port, callback) => {
      callback?.();
    });
    const authProvider = new MCPOAuthProvider();
    await expect(
      authProvider.authenticate('test-server', configWithoutClient),
    ).rejects.toThrow(
      'Failed to fetch authorization server metadata for client registration',
    );
  });
  it('should handle invalid callback request', async () => {
    // Capture the HTTP handler so the test can simulate a request to a
    // path other than the expected OAuth callback.
    let callbackHandler: unknown;
    vi.mocked(http.createServer).mockImplementation((handler) => {
      callbackHandler = handler;
      return mockHttpServer as unknown as http.Server;
    });
    mockHttpServer.listen.mockImplementation((port, callback) => {
      callback?.();
      setTimeout(() => {
        const mockReq = {
          url: '/invalid-path',
        };
        const mockRes = {
          writeHead: vi.fn(),
          end: vi.fn(),
        };
        (callbackHandler as (req: unknown, res: unknown) => void)(
          mockReq,
          mockRes,
        );
      }, 0);
    });
    const authProvider = new MCPOAuthProvider();
    // The test will timeout if the server does not handle the invalid request correctly.
    // We are testing that the server does not hang.
    await Promise.race([
      authProvider.authenticate('test-server', mockConfig),
      new Promise((resolve) => setTimeout(resolve, 1000)),
    ]);
  });
  it('should handle token exchange failure with non-json response', async () => {
    let callbackHandler: unknown;
    vi.mocked(http.createServer).mockImplementation((handler) => {
      callbackHandler = handler;
      return mockHttpServer as unknown as http.Server;
    });
    mockHttpServer.listen.mockImplementation((port, callback) => {
      callback?.();
      setTimeout(() => {
        const mockReq = {
          url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
        };
        const mockRes = {
          writeHead: vi.fn(),
          end: vi.fn(),
        };
        (callbackHandler as (req: unknown, res: unknown) => void)(
          mockReq,
          mockRes,
        );
      }, 10);
    });
    // Token endpoint replies with an HTML error page instead of JSON.
    mockFetch.mockResolvedValueOnce(
      createMockResponse({
        ok: false,
        status: 500,
        contentType: 'text/html',
        text: 'Internal Server Error',
      }),
    );
    const authProvider = new MCPOAuthProvider();
    await expect(
      authProvider.authenticate('test-server', mockConfig),
    ).rejects.toThrow('Token exchange failed: 500 - Internal Server Error');
  });
  it('should handle token exchange with unexpected content type', async () => {
    let callbackHandler: unknown;
    vi.mocked(http.createServer).mockImplementation((handler) => {
      callbackHandler = handler;
      return mockHttpServer as unknown as http.Server;
    });
    mockHttpServer.listen.mockImplementation((port, callback) => {
      callback?.();
      setTimeout(() => {
        const mockReq = {
          url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
        };
        const mockRes = {
          writeHead: vi.fn(),
          end: vi.fn(),
        };
        (callbackHandler as (req: unknown, res: unknown) => void)(
          mockReq,
          mockRes,
        );
      }, 10);
    });
    // A text/plain form-encoded token body must still be parsed.
    mockFetch.mockResolvedValueOnce(
      createMockResponse({
        ok: true,
        contentType: 'text/plain',
        text: 'access_token=plain_text_token',
      }),
    );
    const authProvider = new MCPOAuthProvider();
    const result = await authProvider.authenticate('test-server', mockConfig);
    expect(result.accessToken).toBe('plain_text_token');
  });
  it('should handle refresh token failure with non-json response', async () => {
    mockFetch.mockResolvedValueOnce(
      createMockResponse({
        ok: false,
        status: 500,
        contentType: 'text/html',
        text: 'Internal Server Error',
      }),
    );
    const authProvider = new MCPOAuthProvider();
    await expect(
      authProvider.refreshAccessToken(
        mockConfig,
        'invalid_refresh_token',
        'https://auth.example.com/token',
      ),
    ).rejects.toThrow('Token refresh failed: 500 - Internal Server Error');
  });
  it('should handle refresh token with unexpected content type', async () => {
    mockFetch.mockResolvedValueOnce(
      createMockResponse({
        ok: true,
        contentType: 'text/plain',
        text: 'access_token=plain_text_token',
      }),
    );
    const authProvider = new MCPOAuthProvider();
    const result = await authProvider.refreshAccessToken(
      mockConfig,
      'refresh_token',
      'https://auth.example.com/token',
    );
    expect(result.access_token).toBe('plain_text_token');
  });
  it('should continue authentication when browser fails to open', async () => {
    // Browser launch failure should not abort the flow; the callback
    // server still receives the redirect and token exchange proceeds.
    mockOpenBrowserSecurely.mockRejectedValue(new Error('Browser not found'));
    let callbackHandler: unknown;
    vi.mocked(http.createServer).mockImplementation((handler) => {
      callbackHandler = handler;
      return mockHttpServer as unknown as http.Server;
    });
    mockHttpServer.listen.mockImplementation((port, callback) => {
      callback?.();
      setTimeout(() => {
        const mockReq = {
          url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
        };
        const mockRes = {
          writeHead: vi.fn(),
          end: vi.fn(),
        };
        (callbackHandler as (req: unknown, res: unknown) => void)(
          mockReq,
          mockRes,
        );
      }, 10);
    });
    mockFetch.mockResolvedValueOnce(
      createMockResponse({
        ok: true,
        contentType: 'application/json',
        text: JSON.stringify(mockTokenResponse),
        json: mockTokenResponse,
      }),
    );
    const authProvider = new MCPOAuthProvider();
    const result = await authProvider.authenticate('test-server', mockConfig);
    expect(result).toBeDefined();
  });
  it('should return null when token is expired and no refresh token is available', async () => {
    const expiredCredentials = {
      serverName: 'test-server',
      token: {
        ...mockToken,
        refreshToken: undefined,
        expiresAt: Date.now() - 3600000,
      },
      clientId: 'test-client-id',
      tokenUrl: 'https://auth.example.com/token',
      updatedAt: Date.now(),
    };
    const tokenStorage = new MCPOAuthTokenStorage();
    vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
      expiredCredentials,
    );
    vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
    const authProvider = new MCPOAuthProvider();
    const result = await authProvider.getValidToken(
      'test-server',
      mockConfig,
    );
    expect(result).toBeNull();
  });
it('should handle callback timeout', async () => {
vi.mocked(http.createServer).mockImplementation(
() => mockHttpServer as unknown as http.Server,
@@ -0,0 +1,50 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import { getMCPServerPrompts } from './mcp-prompts.js';
import type { Config } from '../config/config.js';
import { PromptRegistry } from './prompt-registry.js';
import type { DiscoveredMCPPrompt } from '../tools/mcp-client.js';
// Tests for getMCPServerPrompts(): it delegates lookup to the config's
// prompt registry, and degrades to an empty list when no registry exists.
describe('getMCPServerPrompts', () => {
  it('should return prompts from the registry for a given server', () => {
    const registry = new PromptRegistry();
    const discovered: DiscoveredMCPPrompt[] = [
      {
        name: 'prompt1',
        serverName: 'server1',
        tool: { name: 'p1', description: '', inputSchema: {} },
        invoke: async () => ({
          messages: [
            { role: 'assistant', content: { type: 'text', text: '' } },
          ],
        }),
      },
    ];
    const getPromptsByServer = vi
      .spyOn(registry, 'getPromptsByServer')
      .mockReturnValue(discovered);
    const config = {
      getPromptRegistry: () => registry,
    } as unknown as Config;
    const prompts = getMCPServerPrompts(config, 'server1');
    expect(getPromptsByServer).toHaveBeenCalledWith('server1');
    expect(prompts).toEqual(discovered);
  });
  it('should return an empty array if there is no prompt registry', () => {
    const config = {
      getPromptRegistry: () => undefined,
    } as unknown as Config;
    expect(getMCPServerPrompts(config, 'server1')).toEqual([]);
  });
});
@@ -0,0 +1,129 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, vi } from 'vitest';
import { PromptRegistry } from './prompt-registry.js';
import type { DiscoveredMCPPrompt } from '../tools/mcp-client.js';
import { debugLogger } from '../utils/debugLogger.js';
// Replace the real debug logger with a spy so tests can assert on the
// registry's collision warning without producing console output. Vitest
// hoists vi.mock calls, so this takes effect before the module imports run.
vi.mock('../utils/debugLogger.js', () => ({
  debugLogger: {
    warn: vi.fn(),
  },
}));
describe('PromptRegistry', () => {
  let registry: PromptRegistry;

  // Builds a minimal discovered prompt whose tool name mirrors the prompt
  // name and whose invocation yields a single assistant text message.
  const makePrompt = (
    name: string,
    serverName: string,
    description: string,
    text: string,
  ): DiscoveredMCPPrompt => ({
    name,
    serverName,
    tool: {
      name,
      description,
      inputSchema: {},
    },
    invoke: async () => ({
      messages: [{ role: 'assistant', content: { type: 'text', text } }],
    }),
  });

  const prompt1 = makePrompt('prompt1', 'server1', 'Prompt 1', 'response1');
  const prompt2 = makePrompt('prompt2', 'server1', 'Prompt 2', 'response2');
  // Shares prompt1's name but lives on another server, which forces the
  // registry's rename-on-collision path in several tests below.
  const prompt3 = makePrompt('prompt1', 'server2', 'Prompt 3', 'response3');

  beforeEach(() => {
    vi.clearAllMocks();
    registry = new PromptRegistry();
  });

  it('should register a prompt', () => {
    registry.registerPrompt(prompt1);
    expect(registry.getPrompt('prompt1')).toEqual(prompt1);
  });

  it('should get all prompts, sorted by name', () => {
    // Register out of name order to prove getAllPrompts sorts.
    registry.registerPrompt(prompt2);
    registry.registerPrompt(prompt1);
    expect(registry.getAllPrompts()).toEqual([prompt1, prompt2]);
  });

  it('should get a specific prompt by name', () => {
    registry.registerPrompt(prompt1);
    expect(registry.getPrompt('non-existent')).toBeUndefined();
    expect(registry.getPrompt('prompt1')).toEqual(prompt1);
  });

  it('should get prompts by server, sorted by name', () => {
    for (const prompt of [prompt1, prompt2, prompt3]) {
      registry.registerPrompt(prompt);
    }
    expect(registry.getPromptsByServer('server1')).toEqual([prompt1, prompt2]);
    // prompt3 collided with prompt1 at registration, so it comes back renamed.
    expect(registry.getPromptsByServer('server2')).toEqual([
      { ...prompt3, name: 'server2_prompt1' },
    ]);
  });

  it('should handle prompt name collision by renaming', () => {
    registry.registerPrompt(prompt1);
    registry.registerPrompt(prompt3);
    const renamed = { ...prompt3, name: 'server2_prompt1' };
    expect(registry.getPrompt('prompt1')).toEqual(prompt1);
    expect(registry.getPrompt('server2_prompt1')).toEqual(renamed);
    expect(debugLogger.warn).toHaveBeenCalledWith(
      'Prompt with name "prompt1" is already registered. Renaming to "server2_prompt1".',
    );
  });

  it('should clear all prompts', () => {
    registry.registerPrompt(prompt1);
    registry.registerPrompt(prompt2);
    registry.clear();
    expect(registry.getAllPrompts()).toEqual([]);
  });

  it('should remove prompts by server', () => {
    for (const prompt of [prompt1, prompt2, prompt3]) {
      registry.registerPrompt(prompt);
    }
    registry.removePromptsByServer('server1');
    // Only the (renamed) server2 prompt should survive the removal.
    const renamed = { ...prompt3, name: 'server2_prompt1' };
    expect(registry.getAllPrompts()).toEqual([renamed]);
    expect(registry.getPrompt('prompt1')).toBeUndefined();
    expect(registry.getPrompt('prompt2')).toBeUndefined();
    expect(registry.getPrompt('server2_prompt1')).toEqual(renamed);
  });
});
@@ -394,6 +394,15 @@ describe('ClearcutLogger', () => {
},
expected: 'devin',
},
{
name: 'unidentified',
env: {
GITHUB_SHA: undefined,
TERM_PROGRAM: undefined,
SURFACE: undefined,
},
expected: 'SURFACE_NOT_SET',
},
])(
'logs the current surface as $expected from $name',
({ env, expected }) => {
@@ -943,6 +952,33 @@ describe('ClearcutLogger', () => {
});
});
describe('flushIfNeeded', () => {
  it('should not flush if the interval has not passed', () => {
    const { logger } = setup();
    // Private method, so the spy target has to be cast.
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const spy = vi.spyOn(logger!, 'flushToClearcut' as any).mockResolvedValue({
      nextRequestWaitMs: 0,
    });
    logger!.flushIfNeeded();
    expect(spy).not.toHaveBeenCalled();
  });

  it('should flush if the interval has passed', async () => {
    const { logger } = setup();
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const spy = vi.spyOn(logger!, 'flushToClearcut' as any).mockResolvedValue({
      nextRequestWaitMs: 0,
    });
    // Move fake time past the flush interval before asking for a flush.
    await vi.advanceTimersByTimeAsync(2 * 60 * 1000);
    logger!.flushIfNeeded();
    expect(spy).toHaveBeenCalled();
  });
});
describe('logWebFetchFallbackAttemptEvent', () => {
it('logs an event with the proper name and reason', () => {
const { logger } = setup();
+1 -1
View File
@@ -556,7 +556,7 @@ export function hasCycleInSchema(schema: object): boolean {
if ('$ref' in node && typeof node.$ref === 'string') {
const ref = node.$ref;
if (ref === '#/' || pathRefs.has(ref)) {
if (ref === '#' || ref === '#/' || pathRefs.has(ref)) {
// A ref to the document root ('#' or '#/') is always a cycle.
return true; // Cycle detected!
}