From 898382ad61ff054d6d7743160b3757744cc027ec Mon Sep 17 00:00:00 2001 From: Christian Gunderman Date: Mon, 9 Feb 2026 16:07:22 -0800 Subject: [PATCH] Switch code assist client to Zod. --- packages/core/src/code_assist/converter.ts | 38 ++-- .../core/src/code_assist/experiments/types.ts | 58 ++--- packages/core/src/code_assist/server.test.ts | 25 ++- packages/core/src/code_assist/server.ts | 47 +++- packages/core/src/code_assist/types.ts | 204 ++++++++++-------- 5 files changed, 223 insertions(+), 149 deletions(-) diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts index 1f2b4417ac..c077aa39bb 100644 --- a/packages/core/src/code_assist/converter.ts +++ b/packages/core/src/code_assist/converter.ts @@ -14,10 +14,7 @@ import type { CountTokensResponse, GenerationConfigRoutingConfig, MediaResolution, - Candidate, ModelSelectionConfig, - GenerateContentResponsePromptFeedback, - GenerateContentResponseUsageMetadata, Part, SafetySetting, PartUnion, @@ -27,6 +24,7 @@ import type { ToolConfig, } from '@google/genai'; import { GenerateContentResponse } from '@google/genai'; +import { z } from 'zod'; export interface CAGenerateContentRequest { model: string; @@ -71,18 +69,22 @@ interface VertexGenerationConfig { thinkingConfig?: ThinkingConfig; } -export interface CaGenerateContentResponse { - response: VertexGenerateContentResponse; - traceId?: string; -} +const AnySchema = z.any(); -interface VertexGenerateContentResponse { - candidates: Candidate[]; - automaticFunctionCallingHistory?: Content[]; - promptFeedback?: GenerateContentResponsePromptFeedback; - usageMetadata?: GenerateContentResponseUsageMetadata; - modelVersion?: string; -} +export const CaGenerateContentResponseSchema = z.object({ + response: z.object({ + candidates: z.array(AnySchema), + automaticFunctionCallingHistory: z.array(AnySchema).optional(), + promptFeedback: AnySchema.optional(), + usageMetadata: AnySchema.optional(), + modelVersion: z.string().optional(), + }), + traceId: z.string().optional(), +}); + +export type CaGenerateContentResponse = z.infer< + typeof CaGenerateContentResponseSchema +>; export interface CaCountTokenRequest { request: VertexCountTokenRequest; @@ -93,9 +95,11 @@ interface VertexCountTokenRequest { contents: Content[]; } -export interface CaCountTokenResponse { - totalTokens: number; -} +export const CaCountTokenResponseSchema = z.object({ + totalTokens: z.number(), +}); + +export type CaCountTokenResponse = z.infer; export function toCountTokenRequest( req: CountTokensParameters, diff --git a/packages/core/src/code_assist/experiments/types.ts b/packages/core/src/code_assist/experiments/types.ts index 510f3a7cbe..aa1c8fbd55 100644 --- a/packages/core/src/code_assist/experiments/types.ts +++ b/packages/core/src/code_assist/experiments/types.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { z } from 'zod'; import type { ClientMetadata } from '../types.js'; export interface ListExperimentsRequest { @@ -11,32 +12,39 @@ export interface ListExperimentsRequest { metadata?: ClientMetadata; } -export interface ListExperimentsResponse { - experimentIds?: number[]; - flags?: Flag[]; - filteredFlags?: FilteredFlag[]; - debugString?: string; -} +export const Int32ListSchema = z.object({ + values: z.array(z.number()).optional(), +}); +export type Int32List = z.infer; -export interface Flag { - flagId?: number; - boolValue?: boolean; - floatValue?: number; - intValue?: string; // int64 - stringValue?: string; - int32ListValue?: Int32List; - stringListValue?: StringList; -} +export const StringListSchema = z.object({ + values: z.array(z.string()).optional(), +}); +export type StringList = z.infer; -export interface Int32List { - values?: number[]; -} +export const FlagSchema = z.object({ + flagId: z.number().optional(), + boolValue: z.boolean().optional(), + floatValue: z.number().optional(), + intValue: z.string().optional(), + stringValue: z.string().optional(), + int32ListValue: Int32ListSchema.optional(), + stringListValue: StringListSchema.optional(), +}); +export type Flag = z.infer; -export interface StringList { - values?: string[]; -} +export const FilteredFlagSchema = z.object({ + name: z.string().optional(), + reason: z.string().optional(), +}); +export type FilteredFlag = z.infer; -export interface FilteredFlag { - name?: string; - reason?: string; -} +export const ListExperimentsResponseSchema = z.object({ + experimentIds: z.array(z.number()).optional(), + flags: z.array(FlagSchema).optional(), + filteredFlags: z.array(FilteredFlagSchema).optional(), + debugString: z.string().optional(), +}); +export type ListExperimentsResponse = z.infer< + typeof ListExperimentsResponseSchema +>; diff --git a/packages/core/src/code_assist/server.test.ts b/packages/core/src/code_assist/server.test.ts index 35b91fd1c5..42c1d89922 100644 --- a/packages/core/src/code_assist/server.test.ts +++ b/packages/core/src/code_assist/server.test.ts @@ -5,6 +5,7 @@ */ import { beforeEach, describe, it, expect, vi, afterEach } from 'vitest'; +import { z } from 'zod'; import { CodeAssistServer } from './server.js'; import { OAuth2Client } from 'google-auth-library'; import { UserTierId, ActionStatus } from './types.js'; @@ -412,7 +413,7 @@ describe('CodeAssistServer', () => { mockRequest.mockResolvedValue({ data: mockStream }); - const stream = await server.requestStreamingPost('testStream', {}); + const stream = await server.requestStreamingPost('testStream', {}, z.any()); setTimeout(() => { mockStream.push('this is a malformed line\n'); @@ -444,6 +445,7 @@ describe('CodeAssistServer', () => { expect(server.requestPost).toHaveBeenCalledWith( 'onboardUser', expect.any(Object), + expect.anything(), ); expect(response.name).toBe('operations/123'); }); @@ -494,6 +496,7 @@ describe('CodeAssistServer', () => { expect(server.requestPost).toHaveBeenCalledWith( 'loadCodeAssist', expect.any(Object), + expect.anything(), ); expect(response).toEqual(mockResponse); }); @@ -546,6 +549,7 @@ describe('CodeAssistServer', () => { expect(server.requestPost).toHaveBeenCalledWith( 'loadCodeAssist', expect.any(Object), + expect.anything(), ); expect(response).toEqual({ currentTier: { id: UserTierId.STANDARD }, @@ -564,6 +568,7 @@ describe('CodeAssistServer', () => { expect(server.requestPost).toHaveBeenCalledWith( 'loadCodeAssist', expect.any(Object), + expect.anything(), ); }); @@ -579,10 +584,14 @@ describe('CodeAssistServer', () => { }; const response = await server.listExperiments(metadata); - expect(server.requestPost).toHaveBeenCalledWith('listExperiments', { - project: 'test-project', - metadata: { ideVersion: 'v0.1.0', duetProject: 'test-project' }, - }); + expect(server.requestPost).toHaveBeenCalledWith( + 'listExperiments', + { + project: 'test-project', + metadata: { ideVersion: 'v0.1.0', duetProject: 'test-project' }, + }, + expect.anything(), + ); expect(response).toEqual(mockResponse); }); @@ -609,7 +618,11 @@ describe('CodeAssistServer', () => { const response = await server.retrieveUserQuota(req); - expect(requestPostSpy).toHaveBeenCalledWith('retrieveUserQuota', req); + expect(requestPostSpy).toHaveBeenCalledWith( + 'retrieveUserQuota', + req, + expect.anything(), + ); expect(response).toEqual(mockResponse); }); }); diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index 055c041d2b..c85788e665 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -23,10 +23,20 @@ import type { StreamingLatency, RecordCodeAssistMetricsRequest, } from './types.js'; +import { + LongRunningOperationResponseSchema, + LoadCodeAssistResponseSchema, + FetchAdminControlsResponseSchema, + CodeAssistGlobalUserSettingResponseSchema, + RetrieveUserQuotaResponseSchema, + RecordCodeAssistMetricsResponseSchema, + UserTierId, +} from './types.js'; import type { ListExperimentsRequest, ListExperimentsResponse, } from './experiments/types.js'; +import { ListExperimentsResponseSchema } from './experiments/types.js'; import type { CountTokensParameters, CountTokensResponse, @@ -36,13 +46,15 @@ import type { GenerateContentResponse, } from '@google/genai'; import * as readline from 'node:readline'; +import type { z } from 'zod'; import type { ContentGenerator } from '../core/contentGenerator.js'; -import { UserTierId } from './types.js'; import type { CaCountTokenResponse, CaGenerateContentResponse, } from './converter.js'; import { + CaGenerateContentResponseSchema, + CaCountTokenResponseSchema, fromCountTokenResponse, fromGenerateContentResponse, toCountTokenRequest, @@ -85,6 +97,7 @@ export class CodeAssistServer implements ContentGenerator { this.projectId, this.sessionId, ), + CaGenerateContentResponseSchema, req.config?.abortSignal, ); @@ -135,6 +148,7 @@ export class CodeAssistServer implements ContentGenerator { this.projectId, this.sessionId, ), + CaGenerateContentResponseSchema, req.config?.abortSignal, ); const duration = formatProtoJsonDuration(Date.now() - start); @@ -159,7 +173,11 @@ export class CodeAssistServer implements ContentGenerator { async onboardUser( req: OnboardUserRequest, ): Promise { - return this.requestPost('onboardUser', req); + return this.requestPost( + 'onboardUser', + req, + LongRunningOperationResponseSchema, + ); } async getOperation(name: string): Promise { @@ -173,6 +191,7 @@ export class CodeAssistServer implements ContentGenerator { return await this.requestPost( 'loadCodeAssist', req, + LoadCodeAssistResponseSchema, ); } catch (e) { if (isVpcScAffectedUser(e)) { @@ -191,6 +210,7 @@ export class CodeAssistServer implements ContentGenerator { return this.requestPost( 'fetchAdminControls', req, + FetchAdminControlsResponseSchema, ); } @@ -206,6 +226,7 @@ export class CodeAssistServer implements ContentGenerator { return this.requestPost( 'setCodeAssistGlobalUserSetting', req, + CodeAssistGlobalUserSettingResponseSchema, ); } @@ -213,6 +234,7 @@ export class CodeAssistServer implements ContentGenerator { const resp = await this.requestPost( 'countTokens', toCountTokenRequest(req), + CaCountTokenResponseSchema, ); return fromCountTokenResponse(resp); } @@ -234,7 +256,11 @@ export class CodeAssistServer implements ContentGenerator { project: projectId, metadata: { ...metadata, duetProject: projectId }, }; - return this.requestPost('listExperiments', req); + return this.requestPost( + 'listExperiments', + req, + ListExperimentsResponseSchema, + ); } async retrieveUserQuota( @@ -243,6 +269,7 @@ export class CodeAssistServer implements ContentGenerator { return this.requestPost( 'retrieveUserQuota', req, + RetrieveUserQuotaResponseSchema, ); } @@ -282,12 +309,17 @@ export class CodeAssistServer implements ContentGenerator { async recordCodeAssistMetrics( request: RecordCodeAssistMetricsRequest, ): Promise { - return this.requestPost('recordCodeAssistMetrics', request); + return this.requestPost( + 'recordCodeAssistMetrics', + request, + RecordCodeAssistMetricsResponseSchema, + ); } async requestPost( method: string, req: object, + schema: z.ZodType, signal?: AbortSignal, ): Promise { const res = await this.client.request({ @@ -301,8 +333,7 @@ export class CodeAssistServer implements ContentGenerator { body: JSON.stringify(req), signal, }); - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - return res.data as T; + return schema.parse(res.data); } private async makeGetRequest( @@ -334,6 +365,7 @@ export class CodeAssistServer implements ContentGenerator { async requestStreamingPost( method: string, req: object, + schema: z.ZodType, signal?: AbortSignal, ): Promise> { const res = await this.client.request({ @@ -366,8 +398,7 @@ export class CodeAssistServer implements ContentGenerator { if (bufferedLines.length === 0) { continue; // no data to yield } - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - yield JSON.parse(bufferedLines.join('\n')) as T; + yield schema.parse(JSON.parse(bufferedLines.join('\n'))); bufferedLines = []; // Reset the buffer after yielding } // Ignore other lines like comments or id fields diff --git a/packages/core/src/code_assist/types.ts b/packages/core/src/code_assist/types.ts index 3f9bd9fa7e..a2447575c4 100644 --- a/packages/core/src/code_assist/types.ts +++ b/packages/core/src/code_assist/types.ts @@ -45,50 +45,43 @@ export interface LoadCodeAssistRequest { } /** - * Represents LoadCodeAssistResponse proto json field - * http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=224 + * UserTierId represents IDs returned from the Cloud Code Private API representing a user's tier + * + * This is a subset of all available tiers. Since the source list is frequently updated, + * only add a tierId here if specific client-side handling is required. */ -export interface LoadCodeAssistResponse { - currentTier?: GeminiUserTier | null; - allowedTiers?: GeminiUserTier[] | null; - ineligibleTiers?: IneligibleTier[] | null; - cloudaicompanionProject?: string | null; - paidTier?: GeminiUserTier | null; -} +export const UserTierId = { + FREE: 'free-tier', + LEGACY: 'legacy-tier', + STANDARD: 'standard-tier', +} as const; + +export type UserTierId = (typeof UserTierId)[keyof typeof UserTierId] | string; + +/** + * PrivacyNotice reflects the structure received from the CodeAssist in regards to a tier + * privacy notice. + */ +export const PrivacyNoticeSchema = z.object({ + showNotice: z.boolean(), + noticeText: z.string().optional(), +}); +export type PrivacyNotice = z.infer; /** * GeminiUserTier reflects the structure received from the CodeAssist when calling LoadCodeAssist. */ -export interface GeminiUserTier { - id: UserTierId; - name?: string; - description?: string; - // This value is used to declare whether a given tier requires the user to configure the project setting on the IDE settings or not. - userDefinedCloudaicompanionProject?: boolean | null; - isDefault?: boolean; - privacyNotice?: PrivacyNotice; - hasAcceptedTos?: boolean; - hasOnboardedPreviously?: boolean; -} - -/** - * Includes information specifying the reasons for a user's ineligibility for a specific tier. - * @param reasonCode mnemonic code representing the reason for in-eligibility. - * @param reasonMessage message to display to the user. - * @param tierId id of the tier. - * @param tierName name of the tier. - */ -export interface IneligibleTier { - reasonCode: IneligibleTierReasonCode; - reasonMessage: string; - tierId: UserTierId; - tierName: string; - validationErrorMessage?: string; - validationUrl?: string; - validationUrlLinkText?: string; - validationLearnMoreUrl?: string; - validationLearnMoreLinkText?: string; -} +export const GeminiUserTierSchema = z.object({ + id: z.string(), + name: z.string().optional(), + description: z.string().optional(), + userDefinedCloudaicompanionProject: z.boolean().nullable().optional(), + isDefault: z.boolean().optional(), + privacyNotice: PrivacyNoticeSchema.optional(), + hasAcceptedTos: z.boolean().optional(), + hasOnboardedPreviously: z.boolean().optional(), +}); +export type GeminiUserTier = z.infer; /** * List of predefined reason codes when a tier is blocked from a specific tier. @@ -107,29 +100,40 @@ export enum IneligibleTierReasonCode { VALIDATION_REQUIRED = 'VALIDATION_REQUIRED', // go/keep-sorted end } -/** - * UserTierId represents IDs returned from the Cloud Code Private API representing a user's tier - * - * http://google3/cloud/developer_experience/codeassist/shared/usertier/tiers.go - * This is a subset of all available tiers. Since the source list is frequently updated, - * only add a tierId here if specific client-side handling is required. - */ -export const UserTierId = { - FREE: 'free-tier', - LEGACY: 'legacy-tier', - STANDARD: 'standard-tier', -} as const; - -export type UserTierId = (typeof UserTierId)[keyof typeof UserTierId] | string; /** - * PrivacyNotice reflects the structure received from the CodeAssist in regards to a tier - * privacy notice. + * Includes information specifying the reasons for a user's ineligibility for a specific tier. + * @param reasonCode mnemonic code representing the reason for in-eligibility. + * @param reasonMessage message to display to the user. + * @param tierId id of the tier. + * @param tierName name of the tier. */ -export interface PrivacyNotice { - showNotice: boolean; - noticeText?: string; -} +export const IneligibleTierSchema = z.object({ + reasonCode: z.nativeEnum(IneligibleTierReasonCode), + reasonMessage: z.string(), + tierId: z.string(), + tierName: z.string(), + validationErrorMessage: z.string().optional(), + validationUrl: z.string().optional(), + validationUrlLinkText: z.string().optional(), + validationLearnMoreUrl: z.string().optional(), + validationLearnMoreLinkText: z.string().optional(), +}); +export type IneligibleTier = z.infer; + +/** + * Represents LoadCodeAssistResponse proto json field + */ +export const LoadCodeAssistResponseSchema = z.object({ + currentTier: GeminiUserTierSchema.nullable().optional(), + allowedTiers: z.array(GeminiUserTierSchema).nullable().optional(), + ineligibleTiers: z.array(IneligibleTierSchema).nullable().optional(), + cloudaicompanionProject: z.string().nullable().optional(), + paidTier: GeminiUserTierSchema.nullable().optional(), +}); +export type LoadCodeAssistResponse = z.infer< + typeof LoadCodeAssistResponseSchema +>; /** * Proto signature of OnboardUserRequest as payload to OnboardUser call @@ -140,27 +144,32 @@ export interface OnboardUserRequest { metadata: ClientMetadata | undefined; } -/** - * Represents LongRunningOperation proto - * http://google3/google/longrunning/operations.proto;rcl=698857719;l=107 - */ -export interface LongRunningOperationResponse { - name: string; - done?: boolean; - response?: OnboardUserResponse; -} - /** * Represents OnboardUserResponse proto * http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=215 */ -export interface OnboardUserResponse { +export const OnboardUserResponseSchema = z.object({ // tslint:disable-next-line:enforce-name-casing This is the name of the field in the proto. - cloudaicompanionProject?: { - id: string; - name: string; - }; -} + cloudaicompanionProject: z + .object({ + id: z.string(), + name: z.string(), + }) + .optional(), +}); +export type OnboardUserResponse = z.infer; + +/** + * Represents LongRunningOperation proto + */ +export const LongRunningOperationResponseSchema = z.object({ + name: z.string(), + done: z.boolean().optional(), + response: OnboardUserResponseSchema.optional(), +}); +export type LongRunningOperationResponse = z.infer< + typeof LongRunningOperationResponseSchema +>; /** * Status code of user license status @@ -193,10 +202,13 @@ export interface SetCodeAssistGlobalUserSettingRequest { freeTierDataCollectionOptin?: boolean; } -export interface CodeAssistGlobalUserSettingResponse { - cloudaicompanionProject?: string; - freeTierDataCollectionOptin: boolean; -} +export const CodeAssistGlobalUserSettingResponseSchema = z.object({ + cloudaicompanionProject: z.string().optional(), + freeTierDataCollectionOptin: z.boolean(), +}); +export type CodeAssistGlobalUserSettingResponse = z.infer< + typeof CodeAssistGlobalUserSettingResponseSchema +>; /** * Relevant fields that can be returned from a Google RPC response @@ -219,17 +231,21 @@ export interface RetrieveUserQuotaRequest { userAgent?: string; } -export interface BucketInfo { - remainingAmount?: string; - remainingFraction?: number; - resetTime?: string; - tokenType?: string; - modelId?: string; -} +export const BucketInfoSchema = z.object({ + remainingAmount: z.string().optional(), + remainingFraction: z.number().optional(), + resetTime: z.string().optional(), + tokenType: z.string().optional(), + modelId: z.string().optional(), +}); +export type BucketInfo = z.infer; -export interface RetrieveUserQuotaResponse { - buckets?: BucketInfo[]; -} +export const RetrieveUserQuotaResponseSchema = z.object({ + buckets: z.array(BucketInfoSchema).optional(), +}); +export type RetrieveUserQuotaResponse = z.infer< + typeof RetrieveUserQuotaResponseSchema +>; export interface RecordCodeAssistMetricsRequest { project: string; @@ -303,10 +319,6 @@ export interface FetchAdminControlsRequest { project: string; } -export type FetchAdminControlsResponse = z.infer< - typeof FetchAdminControlsResponseSchema ->; - const ExtensionsSettingSchema = z.object({ extensionsEnabled: z.boolean().optional(), }); @@ -356,3 +368,9 @@ export const FetchAdminControlsResponseSchema = z.object({ mcpSetting: McpSettingSchema.optional(), cliFeatureSetting: CliFeatureSettingSchema.optional(), }); + +export type FetchAdminControlsResponse = z.infer< + typeof FetchAdminControlsResponseSchema +>; + +export const RecordCodeAssistMetricsResponseSchema = z.any();