mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-12 22:31:12 -07:00
Increase code coverage for core packages (#12872)
This commit is contained in:
163
packages/core/src/code_assist/codeAssist.test.ts
Normal file
163
packages/core/src/code_assist/codeAssist.test.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { getOauthClient } from './oauth2.js';
|
||||
import { setupUser } from './setup.js';
|
||||
import { CodeAssistServer } from './server.js';
|
||||
import {
|
||||
createCodeAssistContentGenerator,
|
||||
getCodeAssistServer,
|
||||
} from './codeAssist.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
|
||||
import { UserTierId } from './types.js';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('./oauth2.js');
|
||||
vi.mock('./setup.js');
|
||||
vi.mock('./server.js');
|
||||
vi.mock('../core/loggingContentGenerator.js');
|
||||
|
||||
const mockedGetOauthClient = vi.mocked(getOauthClient);
|
||||
const mockedSetupUser = vi.mocked(setupUser);
|
||||
const MockedCodeAssistServer = vi.mocked(CodeAssistServer);
|
||||
const MockedLoggingContentGenerator = vi.mocked(LoggingContentGenerator);
|
||||
|
||||
describe('codeAssist', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
describe('createCodeAssistContentGenerator', () => {
|
||||
const httpOptions = {};
|
||||
const mockConfig = {} as Config;
|
||||
const mockAuthClient = { a: 'client' };
|
||||
const mockUserData = {
|
||||
projectId: 'test-project',
|
||||
userTier: UserTierId.FREE,
|
||||
};
|
||||
|
||||
it('should create a server for LOGIN_WITH_GOOGLE', async () => {
|
||||
mockedGetOauthClient.mockResolvedValue(mockAuthClient as never);
|
||||
mockedSetupUser.mockResolvedValue(mockUserData);
|
||||
|
||||
const generator = await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
mockConfig,
|
||||
'session-123',
|
||||
);
|
||||
|
||||
expect(getOauthClient).toHaveBeenCalledWith(
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
mockConfig,
|
||||
);
|
||||
expect(setupUser).toHaveBeenCalledWith(mockAuthClient);
|
||||
expect(MockedCodeAssistServer).toHaveBeenCalledWith(
|
||||
mockAuthClient,
|
||||
'test-project',
|
||||
httpOptions,
|
||||
'session-123',
|
||||
'free-tier',
|
||||
);
|
||||
expect(generator).toBeInstanceOf(MockedCodeAssistServer);
|
||||
});
|
||||
|
||||
it('should create a server for CLOUD_SHELL', async () => {
|
||||
mockedGetOauthClient.mockResolvedValue(mockAuthClient as never);
|
||||
mockedSetupUser.mockResolvedValue(mockUserData);
|
||||
|
||||
const generator = await createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
AuthType.CLOUD_SHELL,
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(getOauthClient).toHaveBeenCalledWith(
|
||||
AuthType.CLOUD_SHELL,
|
||||
mockConfig,
|
||||
);
|
||||
expect(setupUser).toHaveBeenCalledWith(mockAuthClient);
|
||||
expect(MockedCodeAssistServer).toHaveBeenCalledWith(
|
||||
mockAuthClient,
|
||||
'test-project',
|
||||
httpOptions,
|
||||
undefined, // No session ID
|
||||
'free-tier',
|
||||
);
|
||||
expect(generator).toBeInstanceOf(MockedCodeAssistServer);
|
||||
});
|
||||
|
||||
it('should throw an error for unsupported auth types', async () => {
|
||||
await expect(
|
||||
createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
'api-key' as AuthType, // Use literal string to avoid enum resolution issues
|
||||
mockConfig,
|
||||
),
|
||||
).rejects.toThrow('Unsupported authType: api-key');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCodeAssistServer', () => {
|
||||
it('should return the server if it is a CodeAssistServer', () => {
|
||||
const mockServer = new MockedCodeAssistServer({} as never, '', {});
|
||||
const mockConfig = {
|
||||
getContentGenerator: () => mockServer,
|
||||
} as unknown as Config;
|
||||
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBe(mockServer);
|
||||
});
|
||||
|
||||
it('should unwrap and return the server if it is wrapped in a LoggingContentGenerator', () => {
|
||||
const mockServer = new MockedCodeAssistServer({} as never, '', {});
|
||||
const mockLogger = new MockedLoggingContentGenerator(
|
||||
{} as never,
|
||||
{} as never,
|
||||
);
|
||||
vi.spyOn(mockLogger, 'getWrapped').mockReturnValue(mockServer);
|
||||
|
||||
const mockConfig = {
|
||||
getContentGenerator: () => mockLogger,
|
||||
} as unknown as Config;
|
||||
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBe(mockServer);
|
||||
expect(mockLogger.getWrapped).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return undefined if the content generator is not a CodeAssistServer', () => {
|
||||
const mockGenerator = { a: 'generator' }; // Not a CodeAssistServer
|
||||
const mockConfig = {
|
||||
getContentGenerator: () => mockGenerator,
|
||||
} as unknown as Config;
|
||||
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return undefined if the wrapped generator is not a CodeAssistServer', () => {
|
||||
const mockGenerator = { a: 'generator' }; // Not a CodeAssistServer
|
||||
const mockLogger = new MockedLoggingContentGenerator(
|
||||
{} as never,
|
||||
{} as never,
|
||||
);
|
||||
vi.spyOn(mockLogger, 'getWrapped').mockReturnValue(
|
||||
mockGenerator as never,
|
||||
);
|
||||
|
||||
const mockConfig = {
|
||||
getContentGenerator: () => mockLogger,
|
||||
} as unknown as Config;
|
||||
|
||||
const server = getCodeAssistServer(mockConfig);
|
||||
expect(server).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,112 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { ReleaseChannel, getReleaseChannel } from '../../utils/channel.js';
|
||||
|
||||
// Mock dependencies before importing the module under test
|
||||
vi.mock('../../utils/channel.js', async () => {
|
||||
const actual = await vi.importActual('../../utils/channel.js');
|
||||
return {
|
||||
...(actual as object),
|
||||
getReleaseChannel: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
describe('client_metadata', () => {
|
||||
const originalPlatform = process.platform;
|
||||
const originalArch = process.arch;
|
||||
const originalCliVersion = process.env['CLI_VERSION'];
|
||||
const originalNodeVersion = process.version;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Reset modules to clear the cached `clientMetadataPromise`
|
||||
vi.resetModules();
|
||||
// Re-import the module to get a fresh instance
|
||||
await import('./client_metadata.js');
|
||||
// Provide a default mock implementation for each test
|
||||
vi.mocked(getReleaseChannel).mockResolvedValue(ReleaseChannel.STABLE);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Restore original process properties to avoid side-effects between tests
|
||||
Object.defineProperty(process, 'platform', { value: originalPlatform });
|
||||
Object.defineProperty(process, 'arch', { value: originalArch });
|
||||
process.env['CLI_VERSION'] = originalCliVersion;
|
||||
Object.defineProperty(process, 'version', { value: originalNodeVersion });
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('getPlatform', () => {
|
||||
const testCases = [
|
||||
{ platform: 'darwin', arch: 'x64', expected: 'DARWIN_AMD64' },
|
||||
{ platform: 'darwin', arch: 'arm64', expected: 'DARWIN_ARM64' },
|
||||
{ platform: 'linux', arch: 'x64', expected: 'LINUX_AMD64' },
|
||||
{ platform: 'linux', arch: 'arm64', expected: 'LINUX_ARM64' },
|
||||
{ platform: 'win32', arch: 'x64', expected: 'WINDOWS_AMD64' },
|
||||
{ platform: 'sunos', arch: 'x64', expected: 'PLATFORM_UNSPECIFIED' },
|
||||
{ platform: 'win32', arch: 'arm', expected: 'PLATFORM_UNSPECIFIED' },
|
||||
];
|
||||
|
||||
for (const { platform, arch, expected } of testCases) {
|
||||
it(`should return ${expected} for platform ${platform} and arch ${arch}`, async () => {
|
||||
Object.defineProperty(process, 'platform', { value: platform });
|
||||
Object.defineProperty(process, 'arch', { value: arch });
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
|
||||
const metadata = await getClientMetadata();
|
||||
expect(metadata.platform).toBe(expected);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe('getClientMetadata', () => {
|
||||
it('should use CLI_VERSION for ideVersion if set', async () => {
|
||||
process.env['CLI_VERSION'] = '1.2.3';
|
||||
Object.defineProperty(process, 'version', { value: 'v18.0.0' });
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
|
||||
const metadata = await getClientMetadata();
|
||||
expect(metadata.ideVersion).toBe('1.2.3');
|
||||
});
|
||||
|
||||
it('should use process.version for ideVersion as a fallback', async () => {
|
||||
delete process.env['CLI_VERSION'];
|
||||
Object.defineProperty(process, 'version', { value: 'v20.0.0' });
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
|
||||
const metadata = await getClientMetadata();
|
||||
expect(metadata.ideVersion).toBe('v20.0.0');
|
||||
});
|
||||
|
||||
it('should call getReleaseChannel to get the update channel', async () => {
|
||||
vi.mocked(getReleaseChannel).mockResolvedValue(ReleaseChannel.NIGHTLY);
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
|
||||
const metadata = await getClientMetadata();
|
||||
|
||||
expect(metadata.updateChannel).toBe('nightly');
|
||||
expect(getReleaseChannel).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should cache the client metadata promise', async () => {
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
|
||||
const firstCall = await getClientMetadata();
|
||||
const secondCall = await getClientMetadata();
|
||||
|
||||
expect(firstCall).toBe(secondCall);
|
||||
// Ensure the underlying functions are only called once
|
||||
expect(getReleaseChannel).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should always return the IDE name as GEMINI_CLI', async () => {
|
||||
const { getClientMetadata } = await import('./client_metadata.js');
|
||||
const metadata = await getClientMetadata();
|
||||
expect(metadata.ideName).toBe('GEMINI_CLI');
|
||||
});
|
||||
});
|
||||
});
|
||||
115
packages/core/src/code_assist/experiments/experiments.test.ts
Normal file
115
packages/core/src/code_assist/experiments/experiments.test.ts
Normal file
@@ -0,0 +1,115 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import type { CodeAssistServer } from '../server.js';
|
||||
import { getClientMetadata } from './client_metadata.js';
|
||||
import type { ListExperimentsResponse, Flag } from './types.js';
|
||||
|
||||
// Mock dependencies before importing the module under test
|
||||
vi.mock('../server.js');
|
||||
vi.mock('./client_metadata.js');
|
||||
|
||||
describe('experiments', () => {
|
||||
let mockServer: CodeAssistServer;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset modules to clear the cached `experimentsPromise`
|
||||
vi.resetModules();
|
||||
|
||||
// Mock the dependencies that `getExperiments` relies on
|
||||
vi.mocked(getClientMetadata).mockResolvedValue({
|
||||
ideName: 'GEMINI_CLI',
|
||||
ideVersion: '1.0.0',
|
||||
platform: 'LINUX_AMD64',
|
||||
updateChannel: 'stable',
|
||||
});
|
||||
|
||||
// Create a mock instance of the server for each test
|
||||
mockServer = {
|
||||
listExperiments: vi.fn(),
|
||||
} as unknown as CodeAssistServer;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should fetch and parse experiments from the server', async () => {
|
||||
const { getExperiments } = await import('./experiments.js');
|
||||
const mockApiResponse: ListExperimentsResponse = {
|
||||
flags: [
|
||||
{ name: 'flag1', boolValue: true },
|
||||
{ name: 'flag2', stringValue: 'value' },
|
||||
],
|
||||
experimentIds: [123, 456],
|
||||
};
|
||||
vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const experiments = await getExperiments(mockServer);
|
||||
|
||||
// Verify that the dependencies were called
|
||||
expect(getClientMetadata).toHaveBeenCalled();
|
||||
expect(mockServer.listExperiments).toHaveBeenCalledWith(
|
||||
await getClientMetadata(),
|
||||
);
|
||||
|
||||
// Verify that the response was parsed correctly
|
||||
expect(experiments.flags['flag1']).toEqual({
|
||||
name: 'flag1',
|
||||
boolValue: true,
|
||||
});
|
||||
expect(experiments.flags['flag2']).toEqual({
|
||||
name: 'flag2',
|
||||
stringValue: 'value',
|
||||
});
|
||||
expect(experiments.experimentIds).toEqual([123, 456]);
|
||||
});
|
||||
|
||||
it('should handle an empty or partial response from the server', async () => {
|
||||
const { getExperiments } = await import('./experiments.js');
|
||||
const mockApiResponse: ListExperimentsResponse = {}; // No flags or experimentIds
|
||||
vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const experiments = await getExperiments(mockServer);
|
||||
|
||||
expect(experiments.flags).toEqual({});
|
||||
expect(experiments.experimentIds).toEqual([]);
|
||||
});
|
||||
|
||||
it('should ignore flags that are missing a name', async () => {
|
||||
const { getExperiments } = await import('./experiments.js');
|
||||
const mockApiResponse: ListExperimentsResponse = {
|
||||
flags: [
|
||||
{ boolValue: true } as Flag, // No name
|
||||
{ name: 'flag2', stringValue: 'value' },
|
||||
],
|
||||
};
|
||||
vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const experiments = await getExperiments(mockServer);
|
||||
|
||||
expect(Object.keys(experiments.flags)).toHaveLength(1);
|
||||
expect(experiments.flags['flag2']).toBeDefined();
|
||||
expect(experiments.flags['undefined']).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should cache the experiments promise to avoid multiple fetches', async () => {
|
||||
const { getExperiments } = await import('./experiments.js');
|
||||
const mockApiResponse: ListExperimentsResponse = {
|
||||
experimentIds: [1, 2, 3],
|
||||
};
|
||||
vi.mocked(mockServer.listExperiments).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const firstCall = await getExperiments(mockServer);
|
||||
const secondCall = await getExperiments(mockServer);
|
||||
|
||||
expect(firstCall).toBe(secondCall); // Should be the exact same promise object
|
||||
// Verify the underlying functions were only called once
|
||||
expect(getClientMetadata).toHaveBeenCalledTimes(1);
|
||||
expect(mockServer.listExperiments).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
@@ -173,6 +173,58 @@ describe('OAuthCredentialStorage', () => {
|
||||
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
|
||||
it('should throw an error if the migration file contains invalid JSON', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockResolvedValue('invalid json');
|
||||
|
||||
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
|
||||
'Failed to load OAuth credentials',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not delete the old file if saving migrated credentials fails', async () => {
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockResolvedValue(
|
||||
JSON.stringify(mockCredentials),
|
||||
);
|
||||
vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockRejectedValue(
|
||||
new Error('Save failed'),
|
||||
);
|
||||
|
||||
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
|
||||
'Failed to load OAuth credentials',
|
||||
);
|
||||
|
||||
expect(fs.rm).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return credentials even if access_token is missing from storage', async () => {
|
||||
const partialMcpCredentials = {
|
||||
...mockMcpCredentials,
|
||||
token: {
|
||||
...mockMcpCredentials.token,
|
||||
accessToken: undefined,
|
||||
},
|
||||
};
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
partialMcpCredentials,
|
||||
);
|
||||
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(result).toEqual({
|
||||
access_token: undefined,
|
||||
refresh_token: mockCredentials.refresh_token,
|
||||
token_type: mockCredentials.token_type,
|
||||
scope: mockCredentials.scope,
|
||||
expiry_date: mockCredentials.expiry_date,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveCredentials', () => {
|
||||
@@ -195,6 +247,28 @@ describe('OAuthCredentialStorage', () => {
|
||||
'Attempted to save credentials without an access token.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle saving credentials with null or undefined optional fields', async () => {
|
||||
const partialCredentials: Credentials = {
|
||||
access_token: 'only_access_token',
|
||||
refresh_token: null, // test null
|
||||
scope: undefined, // test undefined
|
||||
};
|
||||
|
||||
await OAuthCredentialStorage.saveCredentials(partialCredentials);
|
||||
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith({
|
||||
serverName: 'main-account',
|
||||
token: {
|
||||
accessToken: 'only_access_token',
|
||||
refreshToken: undefined,
|
||||
tokenType: 'Bearer', // default
|
||||
scope: undefined,
|
||||
expiresAt: undefined,
|
||||
},
|
||||
updatedAt: expect.any(Number),
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearCredentials', () => {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { beforeEach, describe, it, expect, vi } from 'vitest';
|
||||
import { beforeEach, describe, it, expect, vi, afterEach } from 'vitest';
|
||||
import { CodeAssistServer } from './server.js';
|
||||
import { OAuth2Client } from 'google-auth-library';
|
||||
import { UserTierId } from './types.js';
|
||||
@@ -29,15 +29,16 @@ describe('CodeAssistServer', () => {
|
||||
});
|
||||
|
||||
it('should call the generateContent endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const mockRequest = vi.fn();
|
||||
const client = { request: mockRequest } as unknown as OAuth2Client;
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
{},
|
||||
{ headers: { 'x-custom-header': 'test-value' } },
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = {
|
||||
const mockResponseData = {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
@@ -52,7 +53,7 @@ describe('CodeAssistServer', () => {
|
||||
],
|
||||
},
|
||||
};
|
||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||
mockRequest.mockResolvedValue({ data: mockResponseData });
|
||||
|
||||
const response = await server.generateContent(
|
||||
{
|
||||
@@ -62,18 +63,59 @@ describe('CodeAssistServer', () => {
|
||||
'user-prompt-id',
|
||||
);
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'generateContent',
|
||||
expect.any(Object),
|
||||
undefined,
|
||||
);
|
||||
expect(mockRequest).toHaveBeenCalledWith({
|
||||
url: expect.stringContaining(':generateContent'),
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-custom-header': 'test-value',
|
||||
},
|
||||
responseType: 'json',
|
||||
body: expect.any(String),
|
||||
signal: undefined,
|
||||
});
|
||||
|
||||
const requestBody = JSON.parse(mockRequest.mock.calls[0][0].body);
|
||||
expect(requestBody.user_prompt_id).toBe('user-prompt-id');
|
||||
expect(requestBody.project).toBe('test-project');
|
||||
|
||||
expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
|
||||
'response',
|
||||
);
|
||||
});
|
||||
|
||||
it('should call the generateContentStream endpoint', async () => {
|
||||
const client = new OAuth2Client();
|
||||
describe('getMethodUrl', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset the environment variables to their original state
|
||||
process.env = { ...originalEnv };
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Restore the original environment variables
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
it('should construct the default URL correctly', () => {
|
||||
const server = new CodeAssistServer({} as never);
|
||||
const url = server.getMethodUrl('testMethod');
|
||||
expect(url).toBe(
|
||||
'https://cloudcode-pa.googleapis.com/v1internal:testMethod',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the CODE_ASSIST_ENDPOINT environment variable if set', () => {
|
||||
process.env['CODE_ASSIST_ENDPOINT'] = 'https://custom-endpoint.com';
|
||||
const server = new CodeAssistServer({} as never);
|
||||
const url = server.getMethodUrl('testMethod');
|
||||
expect(url).toBe('https://custom-endpoint.com/v1internal:testMethod');
|
||||
});
|
||||
});
|
||||
|
||||
it('should call the generateContentStream endpoint and parse SSE', async () => {
|
||||
const mockRequest = vi.fn();
|
||||
const client = { request: mockRequest } as unknown as OAuth2Client;
|
||||
const server = new CodeAssistServer(
|
||||
client,
|
||||
'test-project',
|
||||
@@ -81,24 +123,21 @@ describe('CodeAssistServer', () => {
|
||||
'test-session',
|
||||
UserTierId.FREE,
|
||||
);
|
||||
const mockResponse = (async function* () {
|
||||
yield {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
index: 0,
|
||||
content: {
|
||||
role: 'model',
|
||||
parts: [{ text: 'response' }],
|
||||
},
|
||||
finishReason: 'STOP',
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
})();
|
||||
vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
|
||||
|
||||
// Create a mock readable stream
|
||||
const { Readable } = await import('node:stream');
|
||||
const mockStream = new Readable({
|
||||
read() {},
|
||||
});
|
||||
|
||||
const mockResponseData1 = {
|
||||
response: { candidates: [{ content: { parts: [{ text: 'Hello' }] } }] },
|
||||
};
|
||||
const mockResponseData2 = {
|
||||
response: { candidates: [{ content: { parts: [{ text: ' World' }] } }] },
|
||||
};
|
||||
|
||||
mockRequest.mockResolvedValue({ data: mockStream });
|
||||
|
||||
const stream = await server.generateContentStream(
|
||||
{
|
||||
@@ -108,14 +147,61 @@ describe('CodeAssistServer', () => {
|
||||
'user-prompt-id',
|
||||
);
|
||||
|
||||
// Push SSE data to the stream
|
||||
// Use setTimeout to ensure the stream processing has started
|
||||
setTimeout(() => {
|
||||
mockStream.push('data: ' + JSON.stringify(mockResponseData1) + '\n\n');
|
||||
mockStream.push('id: 123\n'); // Should be ignored
|
||||
mockStream.push('data: ' + JSON.stringify(mockResponseData2) + '\n\n');
|
||||
mockStream.push(null); // End the stream
|
||||
}, 0);
|
||||
|
||||
const results = [];
|
||||
for await (const res of stream) {
|
||||
expect(server.requestStreamingPost).toHaveBeenCalledWith(
|
||||
'streamGenerateContent',
|
||||
expect.any(Object),
|
||||
undefined,
|
||||
);
|
||||
expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response');
|
||||
results.push(res);
|
||||
}
|
||||
|
||||
expect(mockRequest).toHaveBeenCalledWith({
|
||||
url: expect.stringContaining(':streamGenerateContent'),
|
||||
method: 'POST',
|
||||
params: { alt: 'sse' },
|
||||
responseType: 'stream',
|
||||
body: expect.any(String),
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
signal: undefined,
|
||||
});
|
||||
|
||||
expect(results).toHaveLength(2);
|
||||
expect(results[0].candidates?.[0].content?.parts?.[0].text).toBe('Hello');
|
||||
expect(results[1].candidates?.[0].content?.parts?.[0].text).toBe(' World');
|
||||
});
|
||||
|
||||
it('should ignore malformed SSE data', async () => {
|
||||
const mockRequest = vi.fn();
|
||||
const client = { request: mockRequest } as unknown as OAuth2Client;
|
||||
const server = new CodeAssistServer(client);
|
||||
|
||||
const { Readable } = await import('node:stream');
|
||||
const mockStream = new Readable({
|
||||
read() {},
|
||||
});
|
||||
|
||||
mockRequest.mockResolvedValue({ data: mockStream });
|
||||
|
||||
const stream = await server.requestStreamingPost('testStream', {});
|
||||
|
||||
setTimeout(() => {
|
||||
mockStream.push('this is a malformed line\n');
|
||||
mockStream.push(null);
|
||||
}, 0);
|
||||
|
||||
const results = [];
|
||||
for await (const res of stream) {
|
||||
results.push(res);
|
||||
}
|
||||
expect(results).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should call the onboardUser endpoint', async () => {
|
||||
@@ -253,6 +339,22 @@ describe('CodeAssistServer', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should re-throw non-VPC-SC errors from loadCodeAssist', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(client);
|
||||
const genericError = new Error('Something else went wrong');
|
||||
vi.spyOn(server, 'requestPost').mockRejectedValue(genericError);
|
||||
|
||||
await expect(server.loadCodeAssist({ metadata: {} })).rejects.toThrow(
|
||||
'Something else went wrong',
|
||||
);
|
||||
|
||||
expect(server.requestPost).toHaveBeenCalledWith(
|
||||
'loadCodeAssist',
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should call the listExperiments endpoint with metadata', async () => {
|
||||
const client = new OAuth2Client();
|
||||
const server = new CodeAssistServer(
|
||||
|
||||
@@ -232,18 +232,16 @@ export class CodeAssistServer implements ContentGenerator {
|
||||
|
||||
let bufferedLines: string[] = [];
|
||||
for await (const line of rl) {
|
||||
// blank lines are used to separate JSON objects in the stream
|
||||
if (line === '') {
|
||||
if (line.startsWith('data: ')) {
|
||||
bufferedLines.push(line.slice(6).trim());
|
||||
} else if (line === '') {
|
||||
if (bufferedLines.length === 0) {
|
||||
continue; // no data to yield
|
||||
}
|
||||
yield JSON.parse(bufferedLines.join('\n')) as T;
|
||||
bufferedLines = []; // Reset the buffer after yielding
|
||||
} else if (line.startsWith('data: ')) {
|
||||
bufferedLines.push(line.slice(6).trim());
|
||||
} else {
|
||||
throw new Error(`Unexpected line format in response: ${line}`);
|
||||
}
|
||||
// Ignore other lines like comments or id fields
|
||||
}
|
||||
})();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user