Files
gemini-cli/packages/core/src/code_assist/server.ts

567 lines
15 KiB
TypeScript

/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { AuthClient } from 'google-auth-library';
import {
UserTierId,
type CodeAssistGlobalUserSettingResponse,
type LoadCodeAssistRequest,
type LoadCodeAssistResponse,
type LongRunningOperationResponse,
type OnboardUserRequest,
type SetCodeAssistGlobalUserSettingRequest,
type ClientMetadata,
type RetrieveUserQuotaRequest,
type RetrieveUserQuotaResponse,
type FetchAdminControlsRequest,
type FetchAdminControlsResponse,
type ConversationOffered,
type ConversationInteraction,
type StreamingLatency,
type RecordCodeAssistMetricsRequest,
type GeminiUserTier,
type Credits,
} from './types.js';
import type {
ListExperimentsRequest,
ListExperimentsResponse,
} from './experiments/types.js';
import type {
CountTokensParameters,
CountTokensResponse,
EmbedContentParameters,
EmbedContentResponse,
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import * as readline from 'node:readline';
import { Readable } from 'node:stream';
import type { ContentGenerator } from '../core/contentGenerator.js';
import type { Config } from '../config/config.js';
import {
G1_CREDIT_TYPE,
getG1CreditBalance,
isOverageEligibleModel,
shouldAutoUseCredits,
} from '../billing/billing.js';
import { logBillingEvent, logInvalidChunk } from '../telemetry/loggers.js';
import { CreditsUsedEvent } from '../telemetry/billingEvents.js';
import {
fromCountTokenResponse,
fromGenerateContentResponse,
toCountTokenRequest,
toGenerateContentRequest,
type CaCountTokenResponse,
type CaGenerateContentResponse,
} from './converter.js';
import {
formatProtoJsonDuration,
recordConversationOffered,
} from './telemetry.js';
import { getClientMetadata } from './experiments/client_metadata.js';
import { InvalidChunkEvent, type LlmRole } from '../telemetry/types.js';
/** Default Code Assist host; overridable via the `CODE_ASSIST_ENDPOINT` environment variable. */
export const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com';

/** Default API version path segment; overridable via `CODE_ASSIST_API_VERSION`. */
export const CODE_ASSIST_API_VERSION = 'v1internal';

/** Retry delay, in milliseconds, used for non-streaming `generateContent` calls. */
const GENERATE_CONTENT_RETRY_DELAY_IN_MILLISECONDS = 1_000;

/** Per-request HTTP settings shared by every call the server makes. */
export interface HttpOptions {
  /** Extra headers merged into the headers of each outgoing request. */
  headers?: Record<string, string>;
}
/**
 * `ContentGenerator` backed by the Code Assist API.
 *
 * Every request is sent through the supplied `AuthClient`. URLs are built from
 * `CODE_ASSIST_ENDPOINT` / `CODE_ASSIST_API_VERSION`, which the environment
 * variables of the same name override when they are set to a non-empty value.
 */
export class CodeAssistServer implements ContentGenerator {
  /**
   * @param client Authenticated client used for every HTTP request.
   * @param projectId Project sent with requests; metrics recording is skipped
   *   and `listExperiments` throws when it is absent.
   * @param httpOptions Extra headers merged into every request.
   * @param sessionId Session identifier forwarded into generation requests.
   * @param userTier User tier id; not read within this class.
   * @param userTierName User tier display name; not read within this class.
   * @param paidTier Paid-tier info whose `availableCredits` is refreshed from
   *   server responses.
   * @param config Optional config used for billing settings and telemetry.
   */
  constructor(
    readonly client: AuthClient,
    readonly projectId?: string,
    readonly httpOptions: HttpOptions = {},
    readonly sessionId?: string,
    readonly userTier?: UserTierId,
    readonly userTierName?: string,
    readonly paidTier?: GeminiUserTier,
    readonly config?: Config,
  ) {}

  /**
   * Streams a generation from the `streamGenerateContent` SSE endpoint.
   *
   * G1 credits are enabled on the request only when a `Config` is present,
   * its overage strategy allows automatic credit use for the current balance,
   * and the model is overage-eligible. Each chunk is reported through
   * `recordConversationOffered`, and remaining-credit balances returned by the
   * server are written back into `paidTier`. A `CreditsUsedEvent` is logged
   * once the generator finishes. This also happens if the consumer stops early
   * or the stream fails part-way.
   *
   * @param req Generation parameters; `req.config.abortSignal` cancels the call.
   * @param userPromptId Prompt identifier forwarded into the request.
   * @param role Unused by this implementation.
   * @returns An async generator of translated responses.
   */
  async generateContentStream(
    req: GenerateContentParameters,
    userPromptId: string,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    role: LlmRole,
  ): Promise<AsyncGenerator<GenerateContentResponse>> {
    const autoUse = this.config
      ? shouldAutoUseCredits(
          this.config.getBillingSettings().overageStrategy,
          getG1CreditBalance(this.paidTier),
        )
      : false;
    const modelIsEligible = isOverageEligibleModel(req.model);
    const shouldEnableCredits = modelIsEligible && autoUse;
    const enabledCreditTypes: string[] | undefined = shouldEnableCredits
      ? [G1_CREDIT_TYPE]
      : undefined;

    const responses =
      await this.requestStreamingPost<CaGenerateContentResponse>(
        'streamGenerateContent',
        toGenerateContentRequest(
          req,
          userPromptId,
          this.projectId,
          this.sessionId,
          enabledCreditTypes,
        ),
        req.config?.abortSignal,
      );

    const streamingLatency: StreamingLatency = {};
    const start = Date.now();
    let isFirst = true;

    return (async function* (
      server: CodeAssistServer,
    ): AsyncGenerator<GenerateContentResponse> {
      let totalConsumed = 0;
      let lastRemaining = 0;
      try {
        for await (const response of responses) {
          if (isFirst) {
            streamingLatency.firstMessageLatency = formatProtoJsonDuration(
              Date.now() - start,
            );
            isFirst = false;
          }
          streamingLatency.totalLatency = formatProtoJsonDuration(
            Date.now() - start,
          );

          const translatedResponse = fromGenerateContentResponse(response);
          await recordConversationOffered(
            server,
            response.traceId,
            translatedResponse,
            streamingLatency,
            req.config?.abortSignal,
          );

          if (response.consumedCredits) {
            totalConsumed += CodeAssistServer.sumG1Credits(
              response.consumedCredits,
            );
          }
          if (response.remainingCredits) {
            // Sum all G1 entries for consistency with getG1CreditBalance.
            lastRemaining = CodeAssistServer.sumG1Credits(
              response.remainingCredits,
            );
            server.updateCredits(response.remainingCredits);
          }

          yield translatedResponse;
        }
      } finally {
        // Runs on normal completion, early `return()` from the consumer, and
        // stream errors alike, so credits the server already reported as
        // consumed are always surfaced in telemetry.
        if (totalConsumed > 0 && server.config) {
          logBillingEvent(
            server.config,
            new CreditsUsedEvent(
              req.model ?? 'unknown',
              totalConsumed,
              lastRemaining,
            ),
          );
        }
      }
    })(this);
  }

  /**
   * Performs a single non-streaming `generateContent` call.
   *
   * No credit types are enabled on this request. The response is reported
   * through `recordConversationOffered`. Total and first-message latency are
   * both the full round-trip time. Any remaining-credit balance is written
   * back into `paidTier`.
   *
   * @param req Generation parameters; `req.config.abortSignal` cancels the call.
   * @param userPromptId Prompt identifier forwarded into the request.
   * @param role Unused by this implementation.
   */
  async generateContent(
    req: GenerateContentParameters,
    userPromptId: string,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    role: LlmRole,
  ): Promise<GenerateContentResponse> {
    const start = Date.now();
    const response = await this.requestPost<CaGenerateContentResponse>(
      'generateContent',
      toGenerateContentRequest(
        req,
        userPromptId,
        this.projectId,
        this.sessionId,
        undefined,
      ),
      req.config?.abortSignal,
      GENERATE_CONTENT_RETRY_DELAY_IN_MILLISECONDS,
    );
    const duration = formatProtoJsonDuration(Date.now() - start);
    const streamingLatency: StreamingLatency = {
      totalLatency: duration,
      firstMessageLatency: duration,
    };
    const translatedResponse = fromGenerateContentResponse(response);
    await recordConversationOffered(
      this,
      response.traceId,
      translatedResponse,
      streamingLatency,
      req.config?.abortSignal,
    );
    if (response.remainingCredits) {
      this.updateCredits(response.remainingCredits);
    }
    return translatedResponse;
  }

  /**
   * Sums the amounts of all G1 entries in `credits`.
   * Amounts that are missing or fail to parse count as 0.
   */
  private static sumG1Credits(credits: readonly Credits[]): number {
    let total = 0;
    for (const credit of credits) {
      if (credit.creditType === G1_CREDIT_TYPE && credit.creditAmount) {
        total += parseInt(credit.creditAmount, 10) || 0;
      }
    }
    return total;
  }

  /**
   * Replaces the G1 entries in `paidTier.availableCredits` with the entries
   * from `remainingCredits`. Does nothing when there is no `paidTier`.
   */
  private updateCredits(remainingCredits: Credits[]): void {
    if (!this.paidTier) {
      return;
    }
    // Replace the G1 credits entries with the latest remaining amounts.
    // Non-G1 credits are preserved as-is.
    const nonG1Credits = (this.paidTier.availableCredits ?? []).filter(
      (c) => c.creditType !== G1_CREDIT_TYPE,
    );
    const updatedG1Credits = remainingCredits.filter(
      (c) => c.creditType === G1_CREDIT_TYPE,
    );
    this.paidTier.availableCredits = [...nonG1Credits, ...updatedG1Credits];
  }

  /** Starts user onboarding; returns the long-running operation. */
  async onboardUser(
    req: OnboardUserRequest,
  ): Promise<LongRunningOperationResponse> {
    return this.requestPost<LongRunningOperationResponse>('onboardUser', req);
  }

  /** Fetches the current state of a long-running operation by name. */
  async getOperation(name: string): Promise<LongRunningOperationResponse> {
    return this.requestGetOperation<LongRunningOperationResponse>(name);
  }

  /**
   * Calls `loadCodeAssist`.
   *
   * If the request fails with an error whose details report
   * `SECURITY_POLICY_VIOLATED` (VPC-SC), a STANDARD-tier response is returned
   * instead of throwing. Every other error is rethrown.
   */
  async loadCodeAssist(
    req: LoadCodeAssistRequest,
  ): Promise<LoadCodeAssistResponse> {
    try {
      return await this.requestPost<LoadCodeAssistResponse>(
        'loadCodeAssist',
        req,
      );
    } catch (e) {
      if (isVpcScAffectedUser(e)) {
        return {
          currentTier: { id: UserTierId.STANDARD },
        };
      }
      throw e;
    }
  }

  /**
   * Re-reads the credit balance via a HEALTH_CHECK `loadCodeAssist` call.
   * When the server returns credits, it replaces the whole
   * `paidTier.availableCredits` list. Does nothing when there is no `paidTier`.
   */
  async refreshAvailableCredits(): Promise<void> {
    if (!this.paidTier) {
      return;
    }
    const res = await this.loadCodeAssist({
      cloudaicompanionProject: this.projectId,
      metadata: {
        ideType: 'IDE_UNSPECIFIED',
        platform: 'PLATFORM_UNSPECIFIED',
        pluginType: 'GEMINI',
        duetProject: this.projectId,
      },
      mode: 'HEALTH_CHECK',
    });
    if (res.paidTier?.availableCredits) {
      this.paidTier.availableCredits = res.paidTier.availableCredits;
    }
  }

  /** Fetches admin controls for the request's scope. */
  async fetchAdminControls(
    req: FetchAdminControlsRequest,
  ): Promise<FetchAdminControlsResponse> {
    return this.requestPost<FetchAdminControlsResponse>(
      'fetchAdminControls',
      req,
    );
  }

  /** Reads the user's global Code Assist settings. */
  async getCodeAssistGlobalUserSetting(): Promise<CodeAssistGlobalUserSettingResponse> {
    return this.requestGet<CodeAssistGlobalUserSettingResponse>(
      'getCodeAssistGlobalUserSetting',
    );
  }

  /** Writes the user's global Code Assist settings. */
  async setCodeAssistGlobalUserSetting(
    req: SetCodeAssistGlobalUserSettingRequest,
  ): Promise<CodeAssistGlobalUserSettingResponse> {
    return this.requestPost<CodeAssistGlobalUserSettingResponse>(
      'setCodeAssistGlobalUserSetting',
      req,
    );
  }

  /** Counts tokens by converting to and from the Code Assist wire format. */
  async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
    const resp = await this.requestPost<CaCountTokenResponse>(
      'countTokens',
      toCountTokenRequest(req),
    );
    return fromCountTokenResponse(resp);
  }

  /**
   * Embeddings are not offered through Code Assist.
   * @throws Error Always.
   */
  async embedContent(
    _req: EmbedContentParameters,
  ): Promise<EmbedContentResponse> {
    throw new Error('embedContent is not supported by CodeAssistServer.');
  }

  /**
   * Lists experiments for the configured project.
   * @throws Error When `projectId` is not set.
   */
  async listExperiments(
    metadata: ClientMetadata,
  ): Promise<ListExperimentsResponse> {
    if (!this.projectId) {
      throw new Error('projectId is not defined for CodeAssistServer.');
    }
    const projectId = this.projectId;
    const req: ListExperimentsRequest = {
      project: projectId,
      metadata: { ...metadata, duetProject: projectId },
    };
    return this.requestPost<ListExperimentsResponse>('listExperiments', req);
  }

  /** Retrieves the user's quota information. */
  async retrieveUserQuota(
    req: RetrieveUserQuotaRequest,
  ): Promise<RetrieveUserQuotaResponse> {
    return this.requestPost<RetrieveUserQuotaResponse>(
      'retrieveUserQuota',
      req,
    );
  }

  /** Records a single `conversationOffered` metric; no-op without a project. */
  async recordConversationOffered(
    conversationOffered: ConversationOffered,
  ): Promise<void> {
    if (!this.projectId) {
      return;
    }
    await this.recordCodeAssistMetrics({
      project: this.projectId,
      metadata: await getClientMetadata(),
      metrics: [{ conversationOffered, timestamp: new Date().toISOString() }],
    });
  }

  /** Records a single `conversationInteraction` metric; no-op without a project. */
  async recordConversationInteraction(
    interaction: ConversationInteraction,
  ): Promise<void> {
    if (!this.projectId) {
      return;
    }
    await this.recordCodeAssistMetrics({
      project: this.projectId,
      metadata: await getClientMetadata(),
      metrics: [
        {
          conversationInteraction: interaction,
          timestamp: new Date().toISOString(),
        },
      ],
    });
  }

  /** Sends a raw `recordCodeAssistMetrics` request. */
  async recordCodeAssistMetrics(
    request: RecordCodeAssistMetricsRequest,
  ): Promise<void> {
    return this.requestPost<void>('recordCodeAssistMetrics', request);
  }

  /**
   * POSTs `req` as JSON to `method` and returns the parsed response body.
   *
   * Retries up to 3 times, including when no response was received. Retries
   * apply to HTTP 429, 499 and any 5xx status.
   *
   * @param retryDelay Delay between retries in milliseconds (default 100).
   */
  async requestPost<T>(
    method: string,
    req: object,
    signal?: AbortSignal,
    retryDelay: number = 100,
  ): Promise<T> {
    const res = await this.client.request<T>({
      url: this.getMethodUrl(method),
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        ...this.httpOptions.headers,
      },
      responseType: 'json',
      body: JSON.stringify(req),
      signal,
      retryConfig: {
        retryDelay,
        retry: 3,
        noResponseRetries: 3,
        statusCodesToRetry: [
          [429, 429],
          [499, 499],
          [500, 599],
        ],
      },
    });
    return res.data;
  }

  /** Issues a JSON GET to `url` and returns the parsed response body. */
  private async makeGetRequest<T>(
    url: string,
    signal?: AbortSignal,
  ): Promise<T> {
    const res = await this.client.request<T>({
      url,
      method: 'GET',
      headers: {
        'Content-Type': 'application/json',
        ...this.httpOptions.headers,
      },
      responseType: 'json',
      signal,
    });
    return res.data;
  }

  /** GETs a `:method` endpoint. */
  async requestGet<T>(method: string, signal?: AbortSignal): Promise<T> {
    return this.makeGetRequest<T>(this.getMethodUrl(method), signal);
  }

  /** GETs a named operation resource. */
  async requestGetOperation<T>(name: string, signal?: AbortSignal): Promise<T> {
    return this.makeGetRequest<T>(this.getOperationUrl(name), signal);
  }

  /**
   * POSTs `req` to `method` with `alt=sse` and yields each server-sent event's
   * JSON payload. The request itself is not retried.
   *
   * Lines with a `data:` field are buffered, and a blank line dispatches the
   * buffered payload. Chunks that are not valid JSON are skipped and reported
   * through `logInvalidChunk`. Other SSE fields and comments are ignored. An
   * event that is not terminated by a blank line before end-of-stream is
   * discarded.
   */
  async requestStreamingPost<T>(
    method: string,
    req: object,
    signal?: AbortSignal,
  ): Promise<AsyncGenerator<T>> {
    const res = await this.client.request<AsyncIterable<unknown>>({
      url: this.getMethodUrl(method),
      method: 'POST',
      params: {
        alt: 'sse',
      },
      headers: {
        'Content-Type': 'application/json',
        ...this.httpOptions.headers,
      },
      responseType: 'stream',
      body: JSON.stringify(req),
      signal,
      retry: false,
    });

    return (async function* (server: CodeAssistServer): AsyncGenerator<T> {
      const rl = readline.createInterface({
        input: Readable.from(res.data),
        crlfDelay: Infinity, // Recognizes '\r\n' and '\n' as line breaks
      });

      let bufferedLines: string[] = [];
      for await (const line of rl) {
        if (line.startsWith('data:')) {
          // The space after the colon is optional in SSE; trim() covers both forms.
          bufferedLines.push(line.slice('data:'.length).trim());
          continue;
        }
        if (line !== '') {
          // Ignore comments and other fields such as `id:` or `event:`.
          continue;
        }
        if (bufferedLines.length === 0) {
          continue; // blank line with no pending data: nothing to dispatch
        }

        const chunk = bufferedLines.join('\n');
        bufferedLines = [];

        let parsed: T;
        try {
          parsed = JSON.parse(chunk);
        } catch (_e) {
          if (server.config) {
            logInvalidChunk(
              server.config,
              // Don't include the chunk content in the log for security/privacy reasons.
              new InvalidChunkEvent('Malformed JSON chunk'),
            );
          }
          continue;
        }
        // Yield outside the try so an error thrown into this generator by the
        // consumer is not misreported as a malformed chunk.
        yield parsed;
      }
    })(this);
  }

  /**
   * Resolves `<endpoint>/<version>`. Empty environment overrides fall back to
   * the defaults, so a blank variable cannot yield a URL without a host or
   * version.
   */
  private getBaseUrl(): string {
    const endpoint =
      process.env['CODE_ASSIST_ENDPOINT'] || CODE_ASSIST_ENDPOINT;
    const version =
      process.env['CODE_ASSIST_API_VERSION'] || CODE_ASSIST_API_VERSION;
    return `${endpoint}/${version}`;
  }

  /** URL for an RPC-style method, e.g. `<base>:generateContent`. */
  getMethodUrl(method: string): string {
    return `${this.getBaseUrl()}:${method}`;
  }

  /** URL for a resource path such as an operation name, e.g. `<base>/<name>`. */
  getOperationUrl(name: string): string {
    return `${this.getBaseUrl()}/${name}`;
  }
}
/**
 * The error shape inspected by the VPC-SC check: an error object carrying
 * `response.data.error.details` as an array. The errors presumably come from
 * the HTTP client behind `AuthClient.request`, but only these fields are
 * relied upon.
 */
interface VpcScErrorResponse {
  response: {
    data: {
      error: {
        details: unknown[];
      };
    };
  };
}

/**
 * Narrows `error` to `VpcScErrorResponse`. Each level of
 * `response.data.error` must be a non-null object, and `details` must be an
 * array.
 */
function isVpcScErrorResponse(error: unknown): error is VpcScErrorResponse {
  if (!error || typeof error !== 'object' || !('response' in error)) {
    return false;
  }
  const response = error.response;
  if (!response || typeof response !== 'object' || !('data' in response)) {
    return false;
  }
  const data = response.data;
  if (!data || typeof data !== 'object' || !('error' in data)) {
    return false;
  }
  const apiError = data.error;
  if (!apiError || typeof apiError !== 'object' || !('details' in apiError)) {
    return false;
  }
  return Array.isArray(apiError.details);
}

/**
 * Returns true when an API error's details include an entry whose `reason`
 * is `SECURITY_POLICY_VIOLATED`.
 */
function isVpcScAffectedUser(error: unknown): boolean {
  if (!isVpcScErrorResponse(error)) {
    return false;
  }
  const { details } = error.response.data.error;
  return details.some(
    (detail: unknown) =>
      typeof detail === 'object' &&
      detail !== null &&
      'reason' in detail &&
      detail.reason === 'SECURITY_POLICY_VIOLATED',
  );
}