Files
gemini-cli/packages/core/src/code_assist/server.test.ts

1119 lines
32 KiB
TypeScript

/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { beforeEach, describe, it, expect, vi, afterEach } from 'vitest';
import { CodeAssistServer } from './server.js';
import { OAuth2Client } from 'google-auth-library';
import {
UserTierId,
ActionStatus,
type LoadCodeAssistResponse,
type GeminiUserTier,
type SetCodeAssistGlobalUserSettingRequest,
type CodeAssistGlobalUserSettingResponse,
} from './types.js';
import { FinishReason } from '@google/genai';
import { LlmRole } from '../telemetry/types.js';
import { logInvalidChunk } from '../telemetry/loggers.js';
import { makeFakeConfig } from '../test-utils/config.js';
vi.mock('google-auth-library');
vi.mock('../telemetry/loggers.js', () => ({
logBillingEvent: vi.fn(),
logInvalidChunk: vi.fn(),
}));
function createTestServer(headers: Record<string, string> = {}) {
const mockRequest = vi.fn();
const client = { request: mockRequest } as unknown as OAuth2Client;
const server = new CodeAssistServer(
client,
'test-project',
{ headers },
'test-session',
UserTierId.FREE,
);
return { server, mockRequest, client };
}
describe('CodeAssistServer', () => {
beforeEach(() => {
vi.resetAllMocks();
});
it('should be able to be constructed', () => {
const auth = new OAuth2Client();
const server = new CodeAssistServer(
auth,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
expect(server).toBeInstanceOf(CodeAssistServer);
});
it('should call the generateContent endpoint', async () => {
const { server, mockRequest } = createTestServer({
'x-custom-header': 'test-value',
});
const mockResponseData = {
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'response' }],
},
finishReason: FinishReason.STOP,
safetyRatings: [],
},
],
},
};
mockRequest.mockResolvedValue({ data: mockResponseData });
const response = await server.generateContent(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
LlmRole.MAIN,
);
expect(mockRequest).toHaveBeenCalledWith({
url: expect.stringContaining(':generateContent'),
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-custom-header': 'test-value',
},
responseType: 'json',
body: expect.any(String),
signal: undefined,
retryConfig: {
retryDelay: 1000,
retry: 3,
noResponseRetries: 3,
statusCodesToRetry: [
[429, 429],
[499, 499],
[500, 599],
],
},
});
const requestBody = JSON.parse(mockRequest.mock.calls[0][0].body);
expect(requestBody.user_prompt_id).toBe('user-prompt-id');
expect(requestBody.project).toBe('test-project');
expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
'response',
);
});
it('should detect error in generateContent response', async () => {
const { server, mockRequest } = createTestServer();
const mockResponseData = {
traceId: 'test-trace-id',
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [
{ text: 'response' },
{ functionCall: { name: 'replace', args: {} } },
],
},
finishReason: FinishReason.SAFETY,
safetyRatings: [],
},
],
},
};
mockRequest.mockResolvedValue({ data: mockResponseData });
const recordConversationOfferedSpy = vi.spyOn(
server,
'recordConversationOffered',
);
await server.generateContent(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
LlmRole.MAIN,
);
expect(recordConversationOfferedSpy).toHaveBeenCalledWith(
expect.objectContaining({
status: ActionStatus.ACTION_STATUS_ERROR_UNKNOWN,
}),
);
});
it('should record conversation offered on successful generateContent', async () => {
const { server, mockRequest } = createTestServer();
const mockResponseData = {
traceId: 'test-trace-id',
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [
{ text: 'response' },
{ functionCall: { name: 'replace', args: {} } },
],
},
finishReason: FinishReason.STOP,
safetyRatings: [],
},
],
sdkHttpResponse: {
responseInternal: {
ok: true,
},
},
},
};
mockRequest.mockResolvedValue({ data: mockResponseData });
vi.spyOn(server, 'recordCodeAssistMetrics').mockResolvedValue(undefined);
await server.generateContent(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
LlmRole.MAIN,
);
expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith(
expect.objectContaining({
metrics: expect.arrayContaining([
expect.objectContaining({
conversationOffered: expect.objectContaining({
traceId: 'test-trace-id',
status: ActionStatus.ACTION_STATUS_NO_ERROR,
streamingLatency: expect.objectContaining({
totalLatency: expect.stringMatching(/\d+s/),
firstMessageLatency: expect.stringMatching(/\d+s/),
}),
}),
timestamp: expect.stringMatching(
/\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.\d{3}Z/,
),
}),
]),
}),
);
});
it('should record conversation offered on generateContentStream', async () => {
const { server, mockRequest } = createTestServer();
const { Readable } = await import('node:stream');
const mockStream = new Readable({ read() {} });
mockRequest.mockResolvedValue({ data: mockStream });
vi.spyOn(server, 'recordCodeAssistMetrics').mockResolvedValue(undefined);
const stream = await server.generateContentStream(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
LlmRole.MAIN,
);
const mockResponseData = {
traceId: 'stream-trace-id',
response: {
candidates: [
{
content: {
parts: [
{ text: 'chunk' },
{ functionCall: { name: 'replace', args: {} } },
],
},
},
],
sdkHttpResponse: {
responseInternal: {
ok: true,
},
},
},
};
setTimeout(() => {
mockStream.push('data: ' + JSON.stringify(mockResponseData) + '\n\n');
mockStream.push(null);
}, 0);
for await (const _ of stream) {
// Consume stream
}
expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith(
expect.objectContaining({
metrics: expect.arrayContaining([
expect.objectContaining({
conversationOffered: expect.objectContaining({
traceId: 'stream-trace-id',
}),
timestamp: expect.stringMatching(
/\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.\d{3}Z/,
),
}),
]),
}),
);
});
it('should record conversation interaction', async () => {
const { server } = createTestServer();
vi.spyOn(server, 'recordCodeAssistMetrics').mockResolvedValue(undefined);
const interaction = {
traceId: 'test-trace-id',
};
await server.recordConversationInteraction(interaction);
expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith(
expect.objectContaining({
project: 'test-project',
metrics: expect.arrayContaining([
expect.objectContaining({
conversationInteraction: interaction,
timestamp: expect.stringMatching(
/\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.\d{3}Z/,
),
}),
]),
}),
);
});
it('should call recordCodeAssistMetrics endpoint', async () => {
const { server, mockRequest } = createTestServer();
mockRequest.mockResolvedValue({ data: {} });
const req = {
project: 'test-project',
metrics: [],
};
await server.recordCodeAssistMetrics(req);
expect(mockRequest).toHaveBeenCalledWith(
expect.objectContaining({
url: expect.stringContaining(':recordCodeAssistMetrics'),
method: 'POST',
body: expect.any(String),
}),
);
});
describe('getMethodUrl', () => {
const originalEnv = process.env;
beforeEach(() => {
// Reset the environment variables to their original state
process.env = { ...originalEnv };
});
afterEach(() => {
// Restore the original environment variables
process.env = originalEnv;
});
it('should construct the default URL correctly', () => {
const server = new CodeAssistServer({} as never);
const url = server.getMethodUrl('testMethod');
expect(url).toBe(
'https://cloudcode-pa.googleapis.com/v1internal:testMethod',
);
});
it('should use the CODE_ASSIST_ENDPOINT environment variable if set', () => {
process.env['CODE_ASSIST_ENDPOINT'] = 'https://custom-endpoint.com';
const server = new CodeAssistServer({} as never);
const url = server.getMethodUrl('testMethod');
expect(url).toBe('https://custom-endpoint.com/v1internal:testMethod');
});
it('should use the CODE_ASSIST_API_VERSION environment variable if set', () => {
process.env['CODE_ASSIST_API_VERSION'] = 'v2beta';
const server = new CodeAssistServer({} as never);
const url = server.getMethodUrl('testMethod');
expect(url).toBe('https://cloudcode-pa.googleapis.com/v2beta:testMethod');
});
it('should use default value if CODE_ASSIST_API_VERSION env var is empty', () => {
process.env['CODE_ASSIST_API_VERSION'] = '';
const server = new CodeAssistServer({} as never);
const url = server.getMethodUrl('testMethod');
expect(url).toBe(
'https://cloudcode-pa.googleapis.com/v1internal:testMethod',
);
});
});
it('should call the generateContentStream endpoint and parse SSE', async () => {
const { server, mockRequest } = createTestServer();
// Create a mock readable stream
const { Readable } = await import('node:stream');
const mockStream = new Readable({
read() {},
});
const mockResponseData1 = {
response: { candidates: [{ content: { parts: [{ text: 'Hello' }] } }] },
};
const mockResponseData2 = {
response: { candidates: [{ content: { parts: [{ text: ' World' }] } }] },
};
mockRequest.mockResolvedValue({ data: mockStream });
const stream = await server.generateContentStream(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
LlmRole.MAIN,
);
// Push SSE data to the stream
// Use setTimeout to ensure the stream processing has started
setTimeout(() => {
mockStream.push('data: ' + JSON.stringify(mockResponseData1) + '\n\n');
mockStream.push('id: 123\n'); // Should be ignored
mockStream.push('data: ' + JSON.stringify(mockResponseData2) + '\n\n');
mockStream.push(null); // End the stream
}, 0);
const results = [];
for await (const res of stream) {
results.push(res);
}
expect(mockRequest).toHaveBeenCalledWith({
url: expect.stringContaining(':streamGenerateContent'),
method: 'POST',
params: { alt: 'sse' },
responseType: 'stream',
body: expect.any(String),
headers: {
'Content-Type': 'application/json',
},
signal: undefined,
retry: false,
});
expect(results).toHaveLength(2);
expect(results[0].candidates?.[0].content?.parts?.[0].text).toBe('Hello');
expect(results[1].candidates?.[0].content?.parts?.[0].text).toBe(' World');
});
it('should handle Web ReadableStream in generateContentStream', async () => {
const { server, mockRequest } = createTestServer();
// Create a mock Web ReadableStream
const mockWebStream = new ReadableStream({
start(controller) {
const mockResponseData = {
response: {
candidates: [{ content: { parts: [{ text: 'Hello Web' }] } }],
},
};
controller.enqueue(
new TextEncoder().encode(
'data: ' + JSON.stringify(mockResponseData) + '\n\n',
),
);
controller.close();
},
});
mockRequest.mockResolvedValue({ data: mockWebStream });
const stream = await server.generateContentStream(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
LlmRole.MAIN,
);
const results = [];
for await (const res of stream) {
results.push(res);
}
expect(results).toHaveLength(1);
expect(results[0].candidates?.[0].content?.parts?.[0].text).toBe(
'Hello Web',
);
});
it('should ignore malformed SSE data', async () => {
const { server, mockRequest } = createTestServer();
const { Readable } = await import('node:stream');
const mockStream = new Readable({
read() {},
});
mockRequest.mockResolvedValue({ data: mockStream });
const stream = await server.requestStreamingPost('testStream', {});
setTimeout(() => {
mockStream.push('this is a malformed line\n');
mockStream.push(null);
}, 0);
const results = [];
for await (const res of stream) {
results.push(res);
}
expect(results).toHaveLength(0);
});
it('should call the onboardUser endpoint', async () => {
const { server } = createTestServer();
const mockResponse = {
name: 'operations/123',
done: true,
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.onboardUser({
tierId: 'test-tier',
cloudaicompanionProject: 'test-project',
metadata: {},
});
expect(server.requestPost).toHaveBeenCalledWith(
'onboardUser',
expect.any(Object),
);
expect(response.name).toBe('operations/123');
});
it('should call the getOperation endpoint', async () => {
const { server } = createTestServer();
const mockResponse = {
name: 'operations/123',
done: true,
response: {
cloudaicompanionProject: {
id: 'test-project',
name: 'projects/test-project',
},
},
};
vi.spyOn(server, 'requestGetOperation').mockResolvedValue(mockResponse);
const response = await server.getOperation('operations/123');
expect(server.requestGetOperation).toHaveBeenCalledWith('operations/123');
expect(response.name).toBe('operations/123');
expect(response.response?.cloudaicompanionProject?.id).toBe('test-project');
expect(response.response?.cloudaicompanionProject?.name).toBe(
'projects/test-project',
);
});
it('should call the loadCodeAssist endpoint', async () => {
const { server } = createTestServer();
const mockResponse = {
currentTier: {
id: UserTierId.FREE,
name: 'Free',
description: 'free tier',
},
allowedTiers: [],
ineligibleTiers: [],
cloudaicompanionProject: 'projects/test',
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.loadCodeAssist({
metadata: {},
});
expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
);
expect(response).toEqual(mockResponse);
});
it('should return 0 for countTokens', async () => {
const { server } = createTestServer();
const mockResponse = {
totalTokens: 100,
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const response = await server.countTokens({
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
});
expect(response.totalTokens).toBe(100);
});
it('should throw an error for embedContent', async () => {
const { server } = createTestServer();
await expect(
server.embedContent({
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
}),
).rejects.toThrow();
});
it('should handle VPC-SC errors when calling loadCodeAssist', async () => {
const { server } = createTestServer();
const mockVpcScError = {
response: {
data: {
error: {
details: [
{
reason: 'SECURITY_POLICY_VIOLATED',
},
],
},
},
},
};
vi.spyOn(server, 'requestPost').mockRejectedValue(mockVpcScError);
const response = await server.loadCodeAssist({
metadata: {},
});
expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
);
expect(response).toEqual({
currentTier: { id: UserTierId.STANDARD },
});
});
it('should re-throw non-VPC-SC errors from loadCodeAssist', async () => {
const { server } = createTestServer();
const genericError = new Error('Something else went wrong');
vi.spyOn(server, 'requestPost').mockRejectedValue(genericError);
await expect(server.loadCodeAssist({ metadata: {} })).rejects.toThrow(
'Something else went wrong',
);
expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
);
});
it('should call the listExperiments endpoint with metadata', async () => {
const { server } = createTestServer();
const mockResponse = {
experiments: [],
};
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
const metadata = {
ideVersion: 'v0.1.0',
};
const response = await server.listExperiments(metadata);
expect(server.requestPost).toHaveBeenCalledWith('listExperiments', {
project: 'test-project',
metadata: { ideVersion: 'v0.1.0', duetProject: 'test-project' },
});
expect(response).toEqual(mockResponse);
});
it('should call the retrieveUserQuota endpoint', async () => {
const { server } = createTestServer();
const mockResponse = {
buckets: [
{
modelId: 'gemini-2.5-pro',
tokenType: 'REQUESTS',
remainingFraction: 0.75,
resetTime: '2025-10-22T16:01:15Z',
},
],
};
const requestPostSpy = vi
.spyOn(server, 'requestPost')
.mockResolvedValue(mockResponse);
const req = {
project: 'projects/my-cloudcode-project',
userAgent: 'CloudCodePlugin/1.0 (gaghosh)',
};
const response = await server.retrieveUserQuota(req);
expect(requestPostSpy).toHaveBeenCalledWith('retrieveUserQuota', req);
expect(response).toEqual(mockResponse);
});
it('should call fetchAdminControls endpoint', async () => {
const { server } = createTestServer();
const mockResponse = { adminControlsApplicable: true };
const requestPostSpy = vi
.spyOn(server, 'requestPost')
.mockResolvedValue(mockResponse);
const req = { project: 'test-project' };
const response = await server.fetchAdminControls(req);
expect(requestPostSpy).toHaveBeenCalledWith('fetchAdminControls', req);
expect(response).toEqual(mockResponse);
});
it('should call getCodeAssistGlobalUserSetting endpoint', async () => {
const { server } = createTestServer();
const mockResponse: CodeAssistGlobalUserSettingResponse = {
freeTierDataCollectionOptin: true,
};
const requestGetSpy = vi
.spyOn(server, 'requestGet')
.mockResolvedValue(mockResponse);
const response = await server.getCodeAssistGlobalUserSetting();
expect(requestGetSpy).toHaveBeenCalledWith(
'getCodeAssistGlobalUserSetting',
);
expect(response).toEqual(mockResponse);
});
it('should call setCodeAssistGlobalUserSetting endpoint', async () => {
const { server } = createTestServer();
const mockResponse: CodeAssistGlobalUserSettingResponse = {
freeTierDataCollectionOptin: true,
};
const requestPostSpy = vi
.spyOn(server, 'requestPost')
.mockResolvedValue(mockResponse);
const req: SetCodeAssistGlobalUserSettingRequest = {
freeTierDataCollectionOptin: true,
};
const response = await server.setCodeAssistGlobalUserSetting(req);
expect(requestPostSpy).toHaveBeenCalledWith(
'setCodeAssistGlobalUserSetting',
req,
);
expect(response).toEqual(mockResponse);
});
it('should call loadCodeAssist during refreshAvailableCredits', async () => {
const { server } = createTestServer();
const mockPaidTier = {
id: 'test-tier',
name: 'tier',
availableCredits: [{ creditType: 'G1', creditAmount: '50' }],
};
const mockResponse = { paidTier: mockPaidTier };
vi.spyOn(server, 'loadCodeAssist').mockResolvedValue(
mockResponse as unknown as LoadCodeAssistResponse,
);
// Initial state: server has a paidTier without availableCredits
(server as unknown as { paidTier: GeminiUserTier }).paidTier = {
id: 'test-tier',
name: 'tier',
};
await server.refreshAvailableCredits();
expect(server.loadCodeAssist).toHaveBeenCalled();
expect(server.paidTier?.availableCredits).toEqual(
mockPaidTier.availableCredits,
);
});
describe('robustness testing', () => {
it('should not crash on random error objects in loadCodeAssist (isVpcScAffectedUser)', async () => {
const { server } = createTestServer();
const errors = [
null,
undefined,
'string error',
123,
{ some: 'object' },
new Error('standard error'),
{ response: {} },
{ response: { data: {} } },
];
for (const err of errors) {
vi.spyOn(server, 'requestPost').mockRejectedValueOnce(err);
try {
await server.loadCodeAssist({ metadata: {} });
} catch (e) {
expect(e).toBe(err);
}
}
});
it('should handle randomly fragmented SSE streams gracefully', async () => {
const { server, mockRequest } = createTestServer();
const { Readable } = await import('node:stream');
const fragmentedCases = [
{
chunks: ['d', 'ata: {"foo":', ' "bar"}\n\n'],
expected: [{ foo: 'bar' }],
},
{
chunks: ['data: {"foo": "bar"}\n', '\n'],
expected: [{ foo: 'bar' }],
},
{
chunks: ['data: ', '{"foo": "bar"}', '\n\n'],
expected: [{ foo: 'bar' }],
},
{
chunks: ['data: {"foo": "bar"}\n\n', 'data: {"baz": 1}\n\n'],
expected: [{ foo: 'bar' }, { baz: 1 }],
},
];
for (const { chunks, expected } of fragmentedCases) {
const mockStream = new Readable({
read() {
for (const chunk of chunks) {
this.push(chunk);
}
this.push(null);
},
});
mockRequest.mockResolvedValueOnce({ data: mockStream });
const stream = await server.requestStreamingPost('testStream', {});
const results = [];
for await (const res of stream) {
results.push(res);
}
expect(results).toEqual(expected);
}
});
it('should correctly parse valid JSON split across multiple data lines', async () => {
const { server, mockRequest } = createTestServer();
const { Readable } = await import('node:stream');
const jsonObj = {
complex: { structure: [1, 2, 3] },
bool: true,
str: 'value',
};
const jsonString = JSON.stringify(jsonObj, null, 2);
const lines = jsonString.split('\n');
const ssePayload = lines.map((line) => `data: ${line}\n`).join('') + '\n';
const mockStream = new Readable({
read() {
this.push(ssePayload);
this.push(null);
},
});
mockRequest.mockResolvedValueOnce({ data: mockStream });
const stream = await server.requestStreamingPost('testStream', {});
const results = [];
for await (const res of stream) {
results.push(res);
}
expect(results).toHaveLength(1);
expect(results[0]).toEqual(jsonObj);
});
it('should not crash on objects partially matching VPC SC error structure', async () => {
const { server } = createTestServer();
const partialErrors = [
{ response: { data: { error: { details: [{ reason: 'OTHER' }] } } } },
{ response: { data: { error: { details: [] } } } },
{ response: { data: { error: {} } } },
{ response: { data: {} } },
];
for (const err of partialErrors) {
vi.spyOn(server, 'requestPost').mockRejectedValueOnce(err);
try {
await server.loadCodeAssist({ metadata: {} });
} catch (e) {
expect(e).toBe(err);
}
}
});
it('should correctly ignore arbitrary SSE comments and ID lines and empty lines before data', async () => {
const { server, mockRequest } = createTestServer();
const { Readable } = await import('node:stream');
const jsonObj = { foo: 'bar' };
const jsonString = JSON.stringify(jsonObj);
const ssePayload = `id: 123
:comment
retry: 100
data: ${jsonString}
`;
const mockStream = new Readable({
read() {
this.push(ssePayload);
this.push(null);
},
});
mockRequest.mockResolvedValueOnce({ data: mockStream });
const stream = await server.requestStreamingPost('testStream', {});
const results = [];
for await (const res of stream) {
results.push(res);
}
expect(results).toHaveLength(1);
expect(results[0]).toEqual(jsonObj);
});
it('should log InvalidChunkEvent when SSE chunk is not valid JSON', async () => {
const config = makeFakeConfig();
const mockRequest = vi.fn();
const client = { request: mockRequest } as unknown as OAuth2Client;
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
undefined,
undefined,
config,
);
const { Readable } = await import('node:stream');
const mockStream = new Readable({
read() {},
});
mockRequest.mockResolvedValue({ data: mockStream });
const stream = await server.requestStreamingPost('testStream', {});
setTimeout(() => {
mockStream.push('data: { "invalid": json }\n\n');
mockStream.push(null);
}, 0);
const results = [];
for await (const res of stream) {
results.push(res);
}
expect(results).toHaveLength(0);
expect(logInvalidChunk).toHaveBeenCalledWith(
config,
expect.objectContaining({
error_message: 'Malformed JSON chunk',
}),
);
});
it('should handle malformed JSON within a multi-line data block', async () => {
const config = makeFakeConfig();
const mockRequest = vi.fn();
const client = { request: mockRequest } as unknown as OAuth2Client;
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
undefined,
undefined,
config,
);
const { Readable } = await import('node:stream');
const mockStream = new Readable({
read() {},
});
mockRequest.mockResolvedValue({ data: mockStream });
const stream = await server.requestStreamingPost('testStream', {});
setTimeout(() => {
mockStream.push('data: {\n');
mockStream.push('data: "invalid": json\n');
mockStream.push('data: }\n\n');
mockStream.push(null);
}, 0);
const results = [];
for await (const res of stream) {
results.push(res);
}
expect(results).toHaveLength(0);
expect(logInvalidChunk).toHaveBeenCalled();
});
it('should safely process random response streams in generateContentStream (consumed/remaining credits)', async () => {
const { mockRequest, client } = createTestServer();
const testServer = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
undefined,
{ id: 'test-tier', name: 'tier', availableCredits: [] },
);
const { Readable } = await import('node:stream');
const streamResponses = [
{
traceId: '1',
consumedCredits: [{ creditType: 'A', creditAmount: '10' }],
},
{ traceId: '2', remainingCredits: [{ creditType: 'B' }] },
{ traceId: '3' },
{ traceId: '4', consumedCredits: null, remainingCredits: undefined },
];
const mockStream = new Readable({
read() {
for (const resp of streamResponses) {
this.push(`data: ${JSON.stringify(resp)}\n\n`);
}
this.push(null);
},
});
mockRequest.mockResolvedValueOnce({ data: mockStream });
vi.spyOn(testServer, 'recordCodeAssistMetrics').mockResolvedValue(
undefined,
);
const stream = await testServer.generateContentStream(
{ model: 'test-model', contents: [] },
'user-prompt-id',
LlmRole.MAIN,
);
for await (const _ of stream) {
// Drain stream
}
// Should not crash
});
it('should be resilient to metadata-only chunks without candidates in generateContentStream', async () => {
const { mockRequest, client } = createTestServer();
const testServer = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const { Readable } = await import('node:stream');
// Chunk 2 is metadata-only, no candidates
const streamResponses = [
{
traceId: '1',
response: {
candidates: [{ content: { parts: [{ text: 'Hello' }] }, index: 0 }],
},
},
{
traceId: '2',
consumedCredits: [{ creditType: 'GOOGLE_ONE_AI', creditAmount: '5' }],
response: {
usageMetadata: { promptTokenCount: 10, totalTokenCount: 15 },
},
},
{
traceId: '3',
response: {
candidates: [
{ content: { parts: [{ text: ' World' }] }, index: 0 },
],
},
},
];
const mockStream = new Readable({
read() {
for (const resp of streamResponses) {
this.push(`data: ${JSON.stringify(resp)}\n\n`);
}
this.push(null);
},
});
mockRequest.mockResolvedValueOnce({ data: mockStream });
vi.spyOn(testServer, 'recordCodeAssistMetrics').mockResolvedValue(
undefined,
);
const stream = await testServer.generateContentStream(
{ model: 'test-model', contents: [] },
'user-prompt-id',
LlmRole.MAIN,
);
const results = [];
for await (const res of stream) {
results.push(res);
}
expect(results).toHaveLength(3);
expect(results[0].candidates).toHaveLength(1);
expect(results[0].candidates?.[0].content?.parts?.[0].text).toBe('Hello');
// Chunk 2 (metadata-only) should still be yielded but with empty candidates
expect(results[1].candidates).toHaveLength(0);
expect(results[1].usageMetadata?.promptTokenCount).toBe(10);
expect(results[2].candidates).toHaveLength(1);
expect(results[2].candidates?.[0].content?.parts?.[0].text).toBe(
' World',
);
});
});
});