Switch code assist client to Zod.

This commit is contained in:
Christian Gunderman
2026-02-09 16:07:22 -08:00
parent fd65416a2f
commit 898382ad61
5 changed files with 223 additions and 149 deletions

View File

@@ -14,10 +14,7 @@ import type {
CountTokensResponse,
GenerationConfigRoutingConfig,
MediaResolution,
Candidate,
ModelSelectionConfig,
GenerateContentResponsePromptFeedback,
GenerateContentResponseUsageMetadata,
Part,
SafetySetting,
PartUnion,
@@ -27,6 +24,7 @@ import type {
ToolConfig,
} from '@google/genai';
import { GenerateContentResponse } from '@google/genai';
import { z } from 'zod';
export interface CAGenerateContentRequest {
model: string;
@@ -71,18 +69,22 @@ interface VertexGenerationConfig {
thinkingConfig?: ThinkingConfig;
}
export interface CaGenerateContentResponse {
response: VertexGenerateContentResponse;
traceId?: string;
}
const AnySchema = z.any();
interface VertexGenerateContentResponse {
candidates: Candidate[];
automaticFunctionCallingHistory?: Content[];
promptFeedback?: GenerateContentResponsePromptFeedback;
usageMetadata?: GenerateContentResponseUsageMetadata;
modelVersion?: string;
}
export const CaGenerateContentResponseSchema = z.object({
response: z.object({
candidates: z.array(AnySchema),
automaticFunctionCallingHistory: z.array(AnySchema).optional(),
promptFeedback: AnySchema.optional(),
usageMetadata: AnySchema.optional(),
modelVersion: z.string().optional(),
}),
traceId: z.string().optional(),
});
export type CaGenerateContentResponse = z.infer<
typeof CaGenerateContentResponseSchema
>;
export interface CaCountTokenRequest {
request: VertexCountTokenRequest;
@@ -93,9 +95,11 @@ interface VertexCountTokenRequest {
contents: Content[];
}
export interface CaCountTokenResponse {
totalTokens: number;
}
export const CaCountTokenResponseSchema = z.object({
totalTokens: z.number(),
});
export type CaCountTokenResponse = z.infer<typeof CaCountTokenResponseSchema>;
export function toCountTokenRequest(
req: CountTokensParameters,

View File

@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { z } from 'zod';
import type { ClientMetadata } from '../types.js';
export interface ListExperimentsRequest {
@@ -11,32 +12,39 @@ export interface ListExperimentsRequest {
metadata?: ClientMetadata;
}
export interface ListExperimentsResponse {
experimentIds?: number[];
flags?: Flag[];
filteredFlags?: FilteredFlag[];
debugString?: string;
}
export const Int32ListSchema = z.object({
values: z.array(z.number()).optional(),
});
export type Int32List = z.infer<typeof Int32ListSchema>;
export interface Flag {
flagId?: number;
boolValue?: boolean;
floatValue?: number;
intValue?: string; // int64
stringValue?: string;
int32ListValue?: Int32List;
stringListValue?: StringList;
}
export const StringListSchema = z.object({
values: z.array(z.string()).optional(),
});
export type StringList = z.infer<typeof StringListSchema>;
export interface Int32List {
values?: number[];
}
export const FlagSchema = z.object({
flagId: z.number().optional(),
boolValue: z.boolean().optional(),
floatValue: z.number().optional(),
intValue: z.string().optional(),
stringValue: z.string().optional(),
int32ListValue: Int32ListSchema.optional(),
stringListValue: StringListSchema.optional(),
});
export type Flag = z.infer<typeof FlagSchema>;
export interface StringList {
values?: string[];
}
export const FilteredFlagSchema = z.object({
name: z.string().optional(),
reason: z.string().optional(),
});
export type FilteredFlag = z.infer<typeof FilteredFlagSchema>;
export interface FilteredFlag {
name?: string;
reason?: string;
}
export const ListExperimentsResponseSchema = z.object({
experimentIds: z.array(z.number()).optional(),
flags: z.array(FlagSchema).optional(),
filteredFlags: z.array(FilteredFlagSchema).optional(),
debugString: z.string().optional(),
});
export type ListExperimentsResponse = z.infer<
typeof ListExperimentsResponseSchema
>;

View File

@@ -5,6 +5,7 @@
*/
import { beforeEach, describe, it, expect, vi, afterEach } from 'vitest';
import { z } from 'zod';
import { CodeAssistServer } from './server.js';
import { OAuth2Client } from 'google-auth-library';
import { UserTierId, ActionStatus } from './types.js';
@@ -412,7 +413,7 @@ describe('CodeAssistServer', () => {
mockRequest.mockResolvedValue({ data: mockStream });
const stream = await server.requestStreamingPost('testStream', {});
const stream = await server.requestStreamingPost('testStream', {}, z.any());
setTimeout(() => {
mockStream.push('this is a malformed line\n');
@@ -444,6 +445,7 @@ describe('CodeAssistServer', () => {
expect(server.requestPost).toHaveBeenCalledWith(
'onboardUser',
expect.any(Object),
expect.anything(),
);
expect(response.name).toBe('operations/123');
});
@@ -494,6 +496,7 @@ describe('CodeAssistServer', () => {
expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
expect.anything(),
);
expect(response).toEqual(mockResponse);
});
@@ -546,6 +549,7 @@ describe('CodeAssistServer', () => {
expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
expect.anything(),
);
expect(response).toEqual({
currentTier: { id: UserTierId.STANDARD },
@@ -564,6 +568,7 @@ describe('CodeAssistServer', () => {
expect(server.requestPost).toHaveBeenCalledWith(
'loadCodeAssist',
expect.any(Object),
expect.anything(),
);
});
@@ -579,10 +584,14 @@ describe('CodeAssistServer', () => {
};
const response = await server.listExperiments(metadata);
expect(server.requestPost).toHaveBeenCalledWith('listExperiments', {
project: 'test-project',
metadata: { ideVersion: 'v0.1.0', duetProject: 'test-project' },
});
expect(server.requestPost).toHaveBeenCalledWith(
'listExperiments',
{
project: 'test-project',
metadata: { ideVersion: 'v0.1.0', duetProject: 'test-project' },
},
expect.anything(),
);
expect(response).toEqual(mockResponse);
});
@@ -609,7 +618,11 @@ describe('CodeAssistServer', () => {
const response = await server.retrieveUserQuota(req);
expect(requestPostSpy).toHaveBeenCalledWith('retrieveUserQuota', req);
expect(requestPostSpy).toHaveBeenCalledWith(
'retrieveUserQuota',
req,
expect.anything(),
);
expect(response).toEqual(mockResponse);
});
});

View File

@@ -23,10 +23,20 @@ import type {
StreamingLatency,
RecordCodeAssistMetricsRequest,
} from './types.js';
import {
LongRunningOperationResponseSchema,
LoadCodeAssistResponseSchema,
FetchAdminControlsResponseSchema,
CodeAssistGlobalUserSettingResponseSchema,
RetrieveUserQuotaResponseSchema,
RecordCodeAssistMetricsResponseSchema,
UserTierId,
} from './types.js';
import type {
ListExperimentsRequest,
ListExperimentsResponse,
} from './experiments/types.js';
import { ListExperimentsResponseSchema } from './experiments/types.js';
import type {
CountTokensParameters,
CountTokensResponse,
@@ -36,13 +46,15 @@ import type {
GenerateContentResponse,
} from '@google/genai';
import * as readline from 'node:readline';
import type { z } from 'zod';
import type { ContentGenerator } from '../core/contentGenerator.js';
import { UserTierId } from './types.js';
import type {
CaCountTokenResponse,
CaGenerateContentResponse,
} from './converter.js';
import {
CaGenerateContentResponseSchema,
CaCountTokenResponseSchema,
fromCountTokenResponse,
fromGenerateContentResponse,
toCountTokenRequest,
@@ -85,6 +97,7 @@ export class CodeAssistServer implements ContentGenerator {
this.projectId,
this.sessionId,
),
CaGenerateContentResponseSchema,
req.config?.abortSignal,
);
@@ -135,6 +148,7 @@ export class CodeAssistServer implements ContentGenerator {
this.projectId,
this.sessionId,
),
CaGenerateContentResponseSchema,
req.config?.abortSignal,
);
const duration = formatProtoJsonDuration(Date.now() - start);
@@ -159,7 +173,11 @@ export class CodeAssistServer implements ContentGenerator {
async onboardUser(
req: OnboardUserRequest,
): Promise<LongRunningOperationResponse> {
return this.requestPost<LongRunningOperationResponse>('onboardUser', req);
return this.requestPost<LongRunningOperationResponse>(
'onboardUser',
req,
LongRunningOperationResponseSchema,
);
}
async getOperation(name: string): Promise<LongRunningOperationResponse> {
@@ -173,6 +191,7 @@ export class CodeAssistServer implements ContentGenerator {
return await this.requestPost<LoadCodeAssistResponse>(
'loadCodeAssist',
req,
LoadCodeAssistResponseSchema,
);
} catch (e) {
if (isVpcScAffectedUser(e)) {
@@ -191,6 +210,7 @@ export class CodeAssistServer implements ContentGenerator {
return this.requestPost<FetchAdminControlsResponse>(
'fetchAdminControls',
req,
FetchAdminControlsResponseSchema,
);
}
@@ -206,6 +226,7 @@ export class CodeAssistServer implements ContentGenerator {
return this.requestPost<CodeAssistGlobalUserSettingResponse>(
'setCodeAssistGlobalUserSetting',
req,
CodeAssistGlobalUserSettingResponseSchema,
);
}
@@ -213,6 +234,7 @@ export class CodeAssistServer implements ContentGenerator {
const resp = await this.requestPost<CaCountTokenResponse>(
'countTokens',
toCountTokenRequest(req),
CaCountTokenResponseSchema,
);
return fromCountTokenResponse(resp);
}
@@ -234,7 +256,11 @@ export class CodeAssistServer implements ContentGenerator {
project: projectId,
metadata: { ...metadata, duetProject: projectId },
};
return this.requestPost<ListExperimentsResponse>('listExperiments', req);
return this.requestPost<ListExperimentsResponse>(
'listExperiments',
req,
ListExperimentsResponseSchema,
);
}
async retrieveUserQuota(
@@ -243,6 +269,7 @@ export class CodeAssistServer implements ContentGenerator {
return this.requestPost<RetrieveUserQuotaResponse>(
'retrieveUserQuota',
req,
RetrieveUserQuotaResponseSchema,
);
}
@@ -282,12 +309,17 @@ export class CodeAssistServer implements ContentGenerator {
async recordCodeAssistMetrics(
request: RecordCodeAssistMetricsRequest,
): Promise<void> {
return this.requestPost<void>('recordCodeAssistMetrics', request);
return this.requestPost<void>(
'recordCodeAssistMetrics',
request,
RecordCodeAssistMetricsResponseSchema,
);
}
async requestPost<T>(
method: string,
req: object,
schema: z.ZodType<T>,
signal?: AbortSignal,
): Promise<T> {
const res = await this.client.request({
@@ -301,8 +333,7 @@ export class CodeAssistServer implements ContentGenerator {
body: JSON.stringify(req),
signal,
});
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
return res.data as T;
return schema.parse(res.data);
}
private async makeGetRequest<T>(
@@ -334,6 +365,7 @@ export class CodeAssistServer implements ContentGenerator {
async requestStreamingPost<T>(
method: string,
req: object,
schema: z.ZodType<T>,
signal?: AbortSignal,
): Promise<AsyncGenerator<T>> {
const res = await this.client.request({
@@ -366,8 +398,7 @@ export class CodeAssistServer implements ContentGenerator {
if (bufferedLines.length === 0) {
continue; // no data to yield
}
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
yield JSON.parse(bufferedLines.join('\n')) as T;
yield schema.parse(JSON.parse(bufferedLines.join('\n')));
bufferedLines = []; // Reset the buffer after yielding
}
// Ignore other lines like comments or id fields

View File

@@ -45,50 +45,43 @@ export interface LoadCodeAssistRequest {
}
/**
* Represents LoadCodeAssistResponse proto json field
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=224
* UserTierId represents IDs returned from the Cloud Code Private API representing a user's tier
*
* This is a subset of all available tiers. Since the source list is frequently updated,
* only add a tierId here if specific client-side handling is required.
*/
export interface LoadCodeAssistResponse {
currentTier?: GeminiUserTier | null;
allowedTiers?: GeminiUserTier[] | null;
ineligibleTiers?: IneligibleTier[] | null;
cloudaicompanionProject?: string | null;
paidTier?: GeminiUserTier | null;
}
export const UserTierId = {
FREE: 'free-tier',
LEGACY: 'legacy-tier',
STANDARD: 'standard-tier',
} as const;
export type UserTierId = (typeof UserTierId)[keyof typeof UserTierId] | string;
/**
* PrivacyNotice reflects the structure received from the CodeAssist in regards to a tier
* privacy notice.
*/
export const PrivacyNoticeSchema = z.object({
showNotice: z.boolean(),
noticeText: z.string().optional(),
});
export type PrivacyNotice = z.infer<typeof PrivacyNoticeSchema>;
/**
* GeminiUserTier reflects the structure received from the CodeAssist when calling LoadCodeAssist.
*/
export interface GeminiUserTier {
id: UserTierId;
name?: string;
description?: string;
// This value is used to declare whether a given tier requires the user to configure the project setting on the IDE settings or not.
userDefinedCloudaicompanionProject?: boolean | null;
isDefault?: boolean;
privacyNotice?: PrivacyNotice;
hasAcceptedTos?: boolean;
hasOnboardedPreviously?: boolean;
}
/**
* Includes information specifying the reasons for a user's ineligibility for a specific tier.
* @param reasonCode mnemonic code representing the reason for in-eligibility.
* @param reasonMessage message to display to the user.
* @param tierId id of the tier.
* @param tierName name of the tier.
*/
export interface IneligibleTier {
reasonCode: IneligibleTierReasonCode;
reasonMessage: string;
tierId: UserTierId;
tierName: string;
validationErrorMessage?: string;
validationUrl?: string;
validationUrlLinkText?: string;
validationLearnMoreUrl?: string;
validationLearnMoreLinkText?: string;
}
export const GeminiUserTierSchema = z.object({
id: z.string(),
name: z.string().optional(),
description: z.string().optional(),
userDefinedCloudaicompanionProject: z.boolean().nullable().optional(),
isDefault: z.boolean().optional(),
privacyNotice: PrivacyNoticeSchema.optional(),
hasAcceptedTos: z.boolean().optional(),
hasOnboardedPreviously: z.boolean().optional(),
});
export type GeminiUserTier = z.infer<typeof GeminiUserTierSchema>;
/**
* List of predefined reason codes when a tier is blocked from a specific tier.
@@ -107,29 +100,40 @@ export enum IneligibleTierReasonCode {
VALIDATION_REQUIRED = 'VALIDATION_REQUIRED',
// go/keep-sorted end
}
/**
* UserTierId represents IDs returned from the Cloud Code Private API representing a user's tier
*
* http://google3/cloud/developer_experience/codeassist/shared/usertier/tiers.go
* This is a subset of all available tiers. Since the source list is frequently updated,
* only add a tierId here if specific client-side handling is required.
*/
export const UserTierId = {
FREE: 'free-tier',
LEGACY: 'legacy-tier',
STANDARD: 'standard-tier',
} as const;
export type UserTierId = (typeof UserTierId)[keyof typeof UserTierId] | string;
/**
* PrivacyNotice reflects the structure received from the CodeAssist in regards to a tier
* privacy notice.
* Includes information specifying the reasons for a user's ineligibility for a specific tier.
* @param reasonCode mnemonic code representing the reason for in-eligibility.
* @param reasonMessage message to display to the user.
* @param tierId id of the tier.
* @param tierName name of the tier.
*/
export interface PrivacyNotice {
showNotice: boolean;
noticeText?: string;
}
export const IneligibleTierSchema = z.object({
reasonCode: z.nativeEnum(IneligibleTierReasonCode),
reasonMessage: z.string(),
tierId: z.string(),
tierName: z.string(),
validationErrorMessage: z.string().optional(),
validationUrl: z.string().optional(),
validationUrlLinkText: z.string().optional(),
validationLearnMoreUrl: z.string().optional(),
validationLearnMoreLinkText: z.string().optional(),
});
export type IneligibleTier = z.infer<typeof IneligibleTierSchema>;
/**
* Represents LoadCodeAssistResponse proto json field
*/
export const LoadCodeAssistResponseSchema = z.object({
currentTier: GeminiUserTierSchema.nullable().optional(),
allowedTiers: z.array(GeminiUserTierSchema).nullable().optional(),
ineligibleTiers: z.array(IneligibleTierSchema).nullable().optional(),
cloudaicompanionProject: z.string().nullable().optional(),
paidTier: GeminiUserTierSchema.nullable().optional(),
});
export type LoadCodeAssistResponse = z.infer<
typeof LoadCodeAssistResponseSchema
>;
/**
* Proto signature of OnboardUserRequest as payload to OnboardUser call
@@ -140,27 +144,32 @@ export interface OnboardUserRequest {
metadata: ClientMetadata | undefined;
}
/**
* Represents LongRunningOperation proto
* http://google3/google/longrunning/operations.proto;rcl=698857719;l=107
*/
export interface LongRunningOperationResponse {
name: string;
done?: boolean;
response?: OnboardUserResponse;
}
/**
* Represents OnboardUserResponse proto
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=215
*/
export interface OnboardUserResponse {
export const OnboardUserResponseSchema = z.object({
// tslint:disable-next-line:enforce-name-casing This is the name of the field in the proto.
cloudaicompanionProject?: {
id: string;
name: string;
};
}
cloudaicompanionProject: z
.object({
id: z.string(),
name: z.string(),
})
.optional(),
});
export type OnboardUserResponse = z.infer<typeof OnboardUserResponseSchema>;
/**
* Represents LongRunningOperation proto
*/
export const LongRunningOperationResponseSchema = z.object({
name: z.string(),
done: z.boolean().optional(),
response: OnboardUserResponseSchema.optional(),
});
export type LongRunningOperationResponse = z.infer<
typeof LongRunningOperationResponseSchema
>;
/**
* Status code of user license status
@@ -193,10 +202,13 @@ export interface SetCodeAssistGlobalUserSettingRequest {
freeTierDataCollectionOptin?: boolean;
}
export interface CodeAssistGlobalUserSettingResponse {
cloudaicompanionProject?: string;
freeTierDataCollectionOptin: boolean;
}
export const CodeAssistGlobalUserSettingResponseSchema = z.object({
cloudaicompanionProject: z.string().optional(),
freeTierDataCollectionOptin: z.boolean(),
});
export type CodeAssistGlobalUserSettingResponse = z.infer<
typeof CodeAssistGlobalUserSettingResponseSchema
>;
/**
* Relevant fields that can be returned from a Google RPC response
@@ -219,17 +231,21 @@ export interface RetrieveUserQuotaRequest {
userAgent?: string;
}
export interface BucketInfo {
remainingAmount?: string;
remainingFraction?: number;
resetTime?: string;
tokenType?: string;
modelId?: string;
}
export const BucketInfoSchema = z.object({
remainingAmount: z.string().optional(),
remainingFraction: z.number().optional(),
resetTime: z.string().optional(),
tokenType: z.string().optional(),
modelId: z.string().optional(),
});
export type BucketInfo = z.infer<typeof BucketInfoSchema>;
export interface RetrieveUserQuotaResponse {
buckets?: BucketInfo[];
}
export const RetrieveUserQuotaResponseSchema = z.object({
buckets: z.array(BucketInfoSchema).optional(),
});
export type RetrieveUserQuotaResponse = z.infer<
typeof RetrieveUserQuotaResponseSchema
>;
export interface RecordCodeAssistMetricsRequest {
project: string;
@@ -303,10 +319,6 @@ export interface FetchAdminControlsRequest {
project: string;
}
export type FetchAdminControlsResponse = z.infer<
typeof FetchAdminControlsResponseSchema
>;
const ExtensionsSettingSchema = z.object({
extensionsEnabled: z.boolean().optional(),
});
@@ -356,3 +368,9 @@ export const FetchAdminControlsResponseSchema = z.object({
mcpSetting: McpSettingSchema.optional(),
cliFeatureSetting: CliFeatureSettingSchema.optional(),
});
export type FetchAdminControlsResponse = z.infer<
typeof FetchAdminControlsResponseSchema
>;
export const RecordCodeAssistMetricsResponseSchema = z.any();