From 5e21c8c03c0a74cdeb8be6e1dd756bfd1cf1cc78 Mon Sep 17 00:00:00 2001 From: Christian Gunderman Date: Tue, 16 Dec 2025 09:34:05 -0800 Subject: [PATCH] Code assist service metrics. (#15024) --- packages/core/src/code_assist/server.test.ts | 295 +++++++++++++----- packages/core/src/code_assist/server.ts | 125 +++++++- .../core/src/code_assist/telemetry.test.ts | 162 ++++++++++ packages/core/src/code_assist/telemetry.ts | 93 ++++++ packages/core/src/code_assist/types.ts | 62 +++- packages/core/src/core/turn.test.ts | 7 - packages/core/src/core/turn.ts | 12 +- .../generateContentResponseUtilities.test.ts | 80 +++++ .../utils/generateContentResponseUtilities.ts | 11 + 9 files changed, 731 insertions(+), 116 deletions(-) create mode 100644 packages/core/src/code_assist/telemetry.test.ts create mode 100644 packages/core/src/code_assist/telemetry.ts diff --git a/packages/core/src/code_assist/server.test.ts b/packages/core/src/code_assist/server.test.ts index 281b46549d..321239df96 100644 --- a/packages/core/src/code_assist/server.test.ts +++ b/packages/core/src/code_assist/server.test.ts @@ -7,10 +7,24 @@ import { beforeEach, describe, it, expect, vi, afterEach } from 'vitest'; import { CodeAssistServer } from './server.js'; import { OAuth2Client } from 'google-auth-library'; -import { UserTierId } from './types.js'; +import { UserTierId, ActionStatus } from './types.js'; +import { FinishReason } from '@google/genai'; vi.mock('google-auth-library'); +function createTestServer(headers: Record = {}) { + const mockRequest = vi.fn(); + const client = { request: mockRequest } as unknown as OAuth2Client; + const server = new CodeAssistServer( + client, + 'test-project', + { headers }, + 'test-session', + UserTierId.FREE, + ); + return { server, mockRequest, client }; +} + describe('CodeAssistServer', () => { beforeEach(() => { vi.resetAllMocks(); @@ -29,15 +43,9 @@ describe('CodeAssistServer', () => { }); it('should call the generateContent endpoint', async () => { - const mockRequest = vi.fn(); - const client = { request: mockRequest } as unknown as OAuth2Client; - const server = new CodeAssistServer( - client, - 'test-project', - { headers: { 'x-custom-header': 'test-value' } }, - 'test-session', - UserTierId.FREE, - ); + const { server, mockRequest } = createTestServer({ + 'x-custom-header': 'test-value', + }); const mockResponseData = { response: { candidates: [ @@ -47,7 +55,7 @@ describe('CodeAssistServer', () => { role: 'model', parts: [{ text: 'response' }], }, - finishReason: 'STOP', + finishReason: FinishReason.STOP, safetyRatings: [], }, ], @@ -84,6 +92,190 @@ describe('CodeAssistServer', () => { ); }); + it('should detect error in generateContent response', async () => { + const { server, mockRequest } = createTestServer(); + const mockResponseData = { + traceId: 'test-trace-id', + response: { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [{ text: 'response' }], + }, + finishReason: FinishReason.SAFETY, + safetyRatings: [], + }, + ], + }, + }; + mockRequest.mockResolvedValue({ data: mockResponseData }); + + const recordConversationOfferedSpy = vi.spyOn( + server, + 'recordConversationOffered', + ); + + await server.generateContent( + { + model: 'test-model', + contents: [{ role: 'user', parts: [{ text: 'request' }] }], + }, + 'user-prompt-id', + ); + + expect(recordConversationOfferedSpy).toHaveBeenCalledWith( + expect.objectContaining({ + status: ActionStatus.ACTION_STATUS_ERROR_UNKNOWN, + }), + ); + }); + + it('should record conversation offered on successful generateContent', async () => { + const { server, mockRequest } = createTestServer(); + const mockResponseData = { + traceId: 'test-trace-id', + response: { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [{ text: 'response' }], + }, + finishReason: FinishReason.STOP, + safetyRatings: [], + }, + ], + sdkHttpResponse: { + responseInternal: { + ok: true, + }, + }, + }, + }; + mockRequest.mockResolvedValue({ data: mockResponseData }); + vi.spyOn(server, 'recordCodeAssistMetrics').mockResolvedValue(undefined); + + await server.generateContent( + { + model: 'test-model', + contents: [{ role: 'user', parts: [{ text: 'request' }] }], + }, + 'user-prompt-id', + ); + + expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith( + expect.objectContaining({ + metrics: expect.arrayContaining([ + expect.objectContaining({ + conversationOffered: expect.objectContaining({ + traceId: 'test-trace-id', + status: ActionStatus.ACTION_STATUS_NO_ERROR, + streamingLatency: expect.objectContaining({ + totalLatency: expect.stringMatching(/\d+s/), + firstMessageLatency: expect.stringMatching(/\d+s/), + }), + }), + }), + ]), + }), + ); + }); + + it('should record conversation offered on generateContentStream', async () => { + const { server, mockRequest } = createTestServer(); + + const { Readable } = await import('node:stream'); + const mockStream = new Readable({ read() {} }); + mockRequest.mockResolvedValue({ data: mockStream }); + + vi.spyOn(server, 'recordCodeAssistMetrics').mockResolvedValue(undefined); + + const stream = await server.generateContentStream( + { + model: 'test-model', + contents: [{ role: 'user', parts: [{ text: 'request' }] }], + }, + 'user-prompt-id', + ); + + const mockResponseData = { + traceId: 'stream-trace-id', + response: { + candidates: [{ content: { parts: [{ text: 'chunk' }] } }], + sdkHttpResponse: { + responseInternal: { + ok: true, + }, + }, + }, + }; + + setTimeout(() => { + mockStream.push('data: ' + JSON.stringify(mockResponseData) + '\n\n'); + mockStream.push(null); + }, 0); + + for await (const _ of stream) { + // Consume stream + } + + expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith( + expect.objectContaining({ + metrics: expect.arrayContaining([ + expect.objectContaining({ + conversationOffered: expect.objectContaining({ + traceId: 'stream-trace-id', + }), + }), + ]), + }), + ); + }); + + it('should record conversation interaction', async () => { + const { server } = createTestServer(); + vi.spyOn(server, 'recordCodeAssistMetrics').mockResolvedValue(undefined); + + const interaction = { + traceId: 'test-trace-id', + }; + + await server.recordConversationInteraction(interaction); + + expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith( + expect.objectContaining({ + project: 'test-project', + metrics: expect.arrayContaining([ + expect.objectContaining({ + conversationInteraction: interaction, + }), + ]), + }), + ); + }); + + it('should call recordCodeAssistMetrics endpoint', async () => { + const { server, mockRequest } = createTestServer(); + mockRequest.mockResolvedValue({ data: {} }); + + const req = { + project: 'test-project', + metrics: [], + }; + await server.recordCodeAssistMetrics(req); + + expect(mockRequest).toHaveBeenCalledWith( + expect.objectContaining({ + url: expect.stringContaining(':recordCodeAssistMetrics'), + method: 'POST', + body: expect.any(String), + }), + ); + }); + describe('getMethodUrl', () => { const originalEnv = process.env; @@ -114,15 +306,7 @@ describe('CodeAssistServer', () => { }); it('should call the generateContentStream endpoint and parse SSE', async () => { - const mockRequest = vi.fn(); - const client = { request: mockRequest } as unknown as OAuth2Client; - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const { server, mockRequest } = createTestServer(); // Create a mock readable stream const { Readable } = await import('node:stream'); @@ -179,9 +363,7 @@ describe('CodeAssistServer', () => { }); it('should ignore malformed SSE data', async () => { - const mockRequest = vi.fn(); - const client = { request: mockRequest } as unknown as OAuth2Client; - const server = new CodeAssistServer(client); + const { server, mockRequest } = createTestServer(); const { Readable } = await import('node:stream'); const mockStream = new Readable({ @@ -205,14 +387,8 @@ describe('CodeAssistServer', () => { }); it('should call the onboardUser endpoint', async () => { - const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const { server } = createTestServer(); + const mockResponse = { name: 'operations/123', done: true, @@ -233,14 +409,7 @@ describe('CodeAssistServer', () => { }); it('should call the loadCodeAssist endpoint', async () => { - const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const { server } = createTestServer(); const mockResponse = { currentTier: { id: UserTierId.FREE, @@ -265,14 +434,7 @@ describe('CodeAssistServer', () => { }); it('should return 0 for countTokens', async () => { - const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const { server } = createTestServer(); const mockResponse = { totalTokens: 100, }; @@ -286,14 +448,7 @@ describe('CodeAssistServer', () => { }); it('should throw an error for embedContent', async () => { - const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const { server } = createTestServer(); await expect( server.embedContent({ model: 'test-model', @@ -303,14 +458,7 @@ describe('CodeAssistServer', () => { }); it('should handle VPC-SC errors when calling loadCodeAssist', async () => { - const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const { server } = createTestServer(); const mockVpcScError = { response: { data: { @@ -340,8 +488,7 @@ describe('CodeAssistServer', () => { }); it('should re-throw non-VPC-SC errors from loadCodeAssist', async () => { - const client = new OAuth2Client(); - const server = new CodeAssistServer(client); + const { server } = createTestServer(); const genericError = new Error('Something else went wrong'); vi.spyOn(server, 'requestPost').mockRejectedValue(genericError); @@ -356,14 +503,7 @@ describe('CodeAssistServer', () => { }); it('should call the listExperiments endpoint with metadata', async () => { - const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const { server } = createTestServer(); const mockResponse = { experiments: [], }; @@ -382,14 +522,7 @@ describe('CodeAssistServer', () => { }); it('should call the retrieveUserQuota endpoint', async () => { - const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const { server } = createTestServer(); const mockResponse = { buckets: [ { diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index f345ff0fcc..e6f7009f7b 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -16,6 +16,10 @@ import type { ClientMetadata, RetrieveUserQuotaRequest, RetrieveUserQuotaResponse, + ConversationOffered, + ConversationInteraction, + StreamingLatency, + RecordCodeAssistMetricsRequest, } from './types.js'; import type { ListExperimentsRequest, @@ -42,6 +46,11 @@ import { toCountTokenRequest, toGenerateContentRequest, } from './converter.js'; +import { + createConversationOffered, + formatProtoJsonDuration, +} from './telemetry.js'; +import { getClientMetadata } from './experiments/client_metadata.js'; /** HTTP options to be used in each of the requests. */ export interface HttpOptions { @@ -65,28 +74,60 @@ export class CodeAssistServer implements ContentGenerator { req: GenerateContentParameters, userPromptId: string, ): Promise> { - const resps = await this.requestStreamingPost( - 'streamGenerateContent', - toGenerateContentRequest( - req, - userPromptId, - this.projectId, - this.sessionId, - ), - req.config?.abortSignal, - ); - return (async function* (): AsyncGenerator { - for await (const resp of resps) { - yield fromGenerateContentResponse(resp); + const responses = + await this.requestStreamingPost( + 'streamGenerateContent', + toGenerateContentRequest( + req, + userPromptId, + this.projectId, + this.sessionId, + ), + req.config?.abortSignal, + ); + + const streamingLatency: StreamingLatency = {}; + const start = Date.now(); + let isFirst = true; + + return (async function* ( + server: CodeAssistServer, + ): AsyncGenerator { + for await (const response of responses) { + if (isFirst) { + streamingLatency.firstMessageLatency = formatProtoJsonDuration( + Date.now() - start, + ); + isFirst = false; + } + + streamingLatency.totalLatency = formatProtoJsonDuration( + Date.now() - start, + ); + + const translatedResponse = fromGenerateContentResponse(response); + + if (response.traceId) { + const offered = createConversationOffered( + translatedResponse, + response.traceId, + req.config?.abortSignal, + streamingLatency, + ); + await server.recordConversationOffered(offered); + } + + yield translatedResponse; } - })(); + })(this); } async generateContent( req: GenerateContentParameters, userPromptId: string, ): Promise { - const resp = await this.requestPost( + const start = Date.now(); + const response = await this.requestPost( 'generateContent', toGenerateContentRequest( req, @@ -96,7 +137,25 @@ export class CodeAssistServer implements ContentGenerator { ), req.config?.abortSignal, ); - return fromGenerateContentResponse(resp); + const duration = formatProtoJsonDuration(Date.now() - start); + const streamingLatency: StreamingLatency = { + totalLatency: duration, + firstMessageLatency: duration, + }; + + const translatedResponse = fromGenerateContentResponse(response); + + if (response.traceId) { + const offered = createConversationOffered( + translatedResponse, + response.traceId, + req.config?.abortSignal, + streamingLatency, + ); + await this.recordConversationOffered(offered); + } + + return translatedResponse; } async onboardUser( @@ -176,6 +235,40 @@ export class CodeAssistServer implements ContentGenerator { ); } + async recordConversationOffered( + conversationOffered: ConversationOffered, + ): Promise { + if (!this.projectId) { + return; + } + + await this.recordCodeAssistMetrics({ + project: this.projectId, + metadata: await getClientMetadata(), + metrics: [{ conversationOffered }], + }); + } + + async recordConversationInteraction( + interaction: ConversationInteraction, + ): Promise { + if (!this.projectId) { + return; + } + + await this.recordCodeAssistMetrics({ + project: this.projectId, + metadata: await getClientMetadata(), + metrics: [{ conversationInteraction: interaction }], + }); + } + + async recordCodeAssistMetrics( + request: RecordCodeAssistMetricsRequest, + ): Promise { + return this.requestPost('recordCodeAssistMetrics', request); + } + async requestPost( method: string, req: object, diff --git a/packages/core/src/code_assist/telemetry.test.ts b/packages/core/src/code_assist/telemetry.test.ts new file mode 100644 index 0000000000..1f6eaff152 --- /dev/null +++ b/packages/core/src/code_assist/telemetry.test.ts @@ -0,0 +1,162 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi } from 'vitest'; +import { + createConversationOffered, + formatProtoJsonDuration, +} from './telemetry.js'; +import { ActionStatus, type StreamingLatency } from './types.js'; +import { FinishReason, GenerateContentResponse } from '@google/genai'; + +function createMockResponse( + candidates: GenerateContentResponse['candidates'] = [], + ok = true, +) { + const response = new GenerateContentResponse(); + response.candidates = candidates; + response.sdkHttpResponse = { + responseInternal: { + ok, + } as unknown as Response, + json: async () => ({}), + }; + return response; +} + +describe('telemetry', () => { + describe('createConversationOffered', () => { + it('should create a ConversationOffered object with correct values', () => { + const response = createMockResponse([ + { + index: 0, + content: { + role: 'model', + parts: [{ text: 'response with ```code```' }], + }, + citationMetadata: { + citations: [ + { uri: 'https://example.com', startIndex: 0, endIndex: 10 }, + ], + }, + finishReason: FinishReason.STOP, + }, + ]); + const traceId = 'test-trace-id'; + const streamingLatency: StreamingLatency = { totalLatency: '1s' }; + + const result = createConversationOffered( + response, + traceId, + undefined, + streamingLatency, + ); + + expect(result).toEqual({ + citationCount: '1', + includedCode: true, + status: ActionStatus.ACTION_STATUS_NO_ERROR, + traceId, + streamingLatency, + isAgentic: true, + }); + }); + + it('should set status to CANCELLED if signal is aborted', () => { + const response = createMockResponse(); + const signal = new AbortController().signal; + vi.spyOn(signal, 'aborted', 'get').mockReturnValue(true); + + const result = createConversationOffered( + response, + 'trace-id', + signal, + {}, + ); + + expect(result.status).toBe(ActionStatus.ACTION_STATUS_CANCELLED); + }); + + it('should set status to ERROR_UNKNOWN if response has error (non-OK SDK response)', () => { + const response = createMockResponse([], false); + + const result = createConversationOffered( + response, + 'trace-id', + undefined, + {}, + ); + + expect(result.status).toBe(ActionStatus.ACTION_STATUS_ERROR_UNKNOWN); + }); + + it('should set status to ERROR_UNKNOWN if finishReason is not STOP or MAX_TOKENS', () => { + const response = createMockResponse([ + { + index: 0, + finishReason: FinishReason.SAFETY, + }, + ]); + + const result = createConversationOffered( + response, + 'trace-id', + undefined, + {}, + ); + + expect(result.status).toBe(ActionStatus.ACTION_STATUS_ERROR_UNKNOWN); + }); + + it('should set status to EMPTY if candidates is empty', () => { + const response = createMockResponse(); + + const result = createConversationOffered( + response, + 'trace-id', + undefined, + {}, + ); + + expect(result.status).toBe(ActionStatus.ACTION_STATUS_EMPTY); + }); + + it('should detect code in response', () => { + const response = createMockResponse([ + { + index: 0, + content: { + parts: [ + { text: 'Here is some code:\n```js\nconsole.log("hi")\n```' }, + ], + }, + }, + ]); + const result = createConversationOffered(response, 'id', undefined, {}); + expect(result.includedCode).toBe(true); + }); + + it('should not detect code if no backticks', () => { + const response = createMockResponse([ + { + index: 0, + content: { + parts: [{ text: 'Here is some text.' }], + }, + }, + ]); + const result = createConversationOffered(response, 'id', undefined, {}); + expect(result.includedCode).toBe(false); + }); + }); + + describe('formatProtoJsonDuration', () => { + it('should format milliseconds to seconds string', () => { + expect(formatProtoJsonDuration(1500)).toBe('1.5s'); + expect(formatProtoJsonDuration(100)).toBe('0.1s'); + }); + }); +}); diff --git a/packages/core/src/code_assist/telemetry.ts b/packages/core/src/code_assist/telemetry.ts new file mode 100644 index 0000000000..bda72a4da1 --- /dev/null +++ b/packages/core/src/code_assist/telemetry.ts @@ -0,0 +1,93 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { FinishReason, type GenerateContentResponse } from '@google/genai'; +import { getCitations } from '../utils/generateContentResponseUtilities.js'; +import { + ActionStatus, + type ConversationOffered, + type StreamingLatency, +} from './types.js'; + +export function createConversationOffered( + response: GenerateContentResponse, + traceId: string, + signal: AbortSignal | undefined, + streamingLatency: StreamingLatency, +): ConversationOffered { + const actionStatus = getStatus(response, signal); + + return { + citationCount: String(getCitations(response).length), + includedCode: includesCode(response), + status: actionStatus, + traceId, + streamingLatency, + isAgentic: true, + }; +} + +function includesCode(resp: GenerateContentResponse): boolean { + if (!resp.candidates) { + return false; + } + for (const candidate of resp.candidates) { + if (!candidate.content || !candidate.content.parts) { + continue; + } + for (const part of candidate.content.parts) { + if ('text' in part && part?.text?.includes('```')) { + return true; + } + } + } + return false; +} + +function getStatus( + response: GenerateContentResponse, + signal: AbortSignal | undefined, +): ActionStatus { + if (signal?.aborted) { + return ActionStatus.ACTION_STATUS_CANCELLED; + } + + if (hasError(response)) { + return ActionStatus.ACTION_STATUS_ERROR_UNKNOWN; + } + + if ((response.candidates?.length ?? 0) <= 0) { + return ActionStatus.ACTION_STATUS_EMPTY; + } + + return ActionStatus.ACTION_STATUS_NO_ERROR; +} + +export function formatProtoJsonDuration(milliseconds: number): string { + return `${milliseconds / 1000}s`; +} + +function hasError(response: GenerateContentResponse): boolean { + // Non-OK SDK results should be considered an error. + if ( + response.sdkHttpResponse && + !response.sdkHttpResponse?.responseInternal?.ok + ) { + return true; + } + + for (const candidate of response.candidates || []) { + // Treat sanitization, SPII, recitation, and forbidden terms as an error. + if ( + candidate.finishReason && + candidate.finishReason !== FinishReason.STOP && + candidate.finishReason !== FinishReason.MAX_TOKENS + ) { + return true; + } + } + return false; +} diff --git a/packages/core/src/code_assist/types.ts b/packages/core/src/code_assist/types.ts index 36e2f3f2fb..824f6ff530 100644 --- a/packages/core/src/code_assist/types.ts +++ b/packages/core/src/code_assist/types.ts @@ -177,7 +177,7 @@ export interface HelpLinkUrl { export interface SetCodeAssistGlobalUserSettingRequest { cloudaicompanionProject?: string; - freeTierDataCollectionOptin: boolean; + freeTierDataCollectionOptin?: boolean; } export interface CodeAssistGlobalUserSettingResponse { @@ -217,3 +217,63 @@ export interface BucketInfo { export interface RetrieveUserQuotaResponse { buckets?: BucketInfo[]; } + +export interface RecordCodeAssistMetricsRequest { + project: string; + requestId?: string; + metadata?: ClientMetadata; + metrics?: CodeAssistMetric[]; +} + +export interface CodeAssistMetric { + timestamp?: string; + metricMetadata?: Map; + + // The event tied to this metric. Only one of these should be set. + conversationOffered?: ConversationOffered; + conversationInteraction?: ConversationInteraction; +} + +export enum ConversationInteractionInteraction { + UNKNOWN = 0, + THUMBSUP = 1, + THUMBSDOWN = 2, + COPY = 3, + INSERT = 4, + ACCEPT_CODE_BLOCK = 5, + ACCEPT_ALL = 6, + ACCEPT_FILE = 7, + DIFF = 8, + ACCEPT_RANGE = 9, +} + +export enum ActionStatus { + ACTION_STATUS_UNSPECIFIED = 0, + ACTION_STATUS_NO_ERROR = 1, + ACTION_STATUS_ERROR_UNKNOWN = 2, + ACTION_STATUS_CANCELLED = 3, + ACTION_STATUS_EMPTY = 4, +} + +export interface ConversationOffered { + citationCount?: string; + includedCode?: boolean; + status?: ActionStatus; + traceId?: string; + streamingLatency?: StreamingLatency; + isAgentic?: boolean; +} + +export interface StreamingLatency { + firstMessageLatency?: string; + totalLatency?: string; +} + +export interface ConversationInteraction { + traceId: string; + status?: ActionStatus; + interaction?: ConversationInteractionInteraction; + acceptedLines?: string; + language?: string; + isAgentic?: boolean; +} diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index 3627ee927b..e951d80933 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -36,13 +36,6 @@ vi.mock('../utils/errorReporting', () => ({ reportError: vi.fn(), })); -// Use the actual implementation from partUtils now that it's provided. -vi.mock('../utils/generateContentResponseUtilities', () => ({ - getResponseText: (resp: GenerateContentResponse) => - resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') || - undefined, -})); - describe('Turn', () => { let turn: Turn; // Define a type for the mocked Chat instance for clarity diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 4c802d8362..e7ba0d8bb7 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -31,6 +31,7 @@ import { InvalidStreamError } from './geminiChat.js'; import { parseThought, type ThoughtSummary } from '../utils/thoughtUtils.js'; import { createUserContent } from '@google/genai'; import type { ModelConfigKey } from '../services/modelConfigService.js'; +import { getCitations } from '../utils/generateContentResponseUtilities.js'; // Define a structure for tools passed to the server export interface ServerTool { @@ -405,14 +406,3 @@ export class Turn { .join(' '); } } - -function getCitations(resp: GenerateContentResponse): string[] { - return (resp.candidates?.[0]?.citationMetadata?.citations ?? []) - .filter((citation) => citation.uri !== undefined) - .map((citation) => { - if (citation.title) { - return `(${citation.title}) ${citation.uri}`; - } - return citation.uri!; - }); -} diff --git a/packages/core/src/utils/generateContentResponseUtilities.test.ts b/packages/core/src/utils/generateContentResponseUtilities.test.ts index 5f99fa8d42..b4c413fcbf 100644 --- a/packages/core/src/utils/generateContentResponseUtilities.test.ts +++ b/packages/core/src/utils/generateContentResponseUtilities.test.ts @@ -13,11 +13,13 @@ import { getFunctionCallsFromPartsAsJson, getStructuredResponse, getStructuredResponseFromParts, + getCitations, } from './generateContentResponseUtilities.js'; import type { GenerateContentResponse, Part, SafetyRating, + CitationMetadata, } from '@google/genai'; import { FinishReason } from '@google/genai'; @@ -33,6 +35,7 @@ const mockResponse = ( parts: Part[], finishReason: FinishReason = FinishReason.STOP, safetyRatings: SafetyRating[] = [], + citationMetadata?: CitationMetadata, ): GenerateContentResponse => ({ candidates: [ { @@ -43,6 +46,7 @@ const mockResponse = ( index: 0, finishReason, safetyRatings, + citationMetadata, }, ], promptFeedback: { @@ -68,6 +72,82 @@ const minimalMockResponse = ( }); describe('generateContentResponseUtilities', () => { + describe('getCitations', () => { + it('should return empty array for no candidates', () => { + expect(getCitations(minimalMockResponse(undefined))).toEqual([]); + }); + + it('should return empty array if no citationMetadata', () => { + const response = mockResponse([mockTextPart('Hello')]); + expect(getCitations(response)).toEqual([]); + }); + + it('should return citations with title and uri', () => { + const citationMetadata: CitationMetadata = { + citations: [ + { + startIndex: 0, + endIndex: 10, + uri: 'https://example.com', + title: 'Example Title', + }, + ], + }; + const response = mockResponse( + [mockTextPart('Hello')], + undefined, + undefined, + citationMetadata, + ); + expect(getCitations(response)).toEqual([ + '(Example Title) https://example.com', + ]); + }); + + it('should return citations with uri only if no title', () => { + const citationMetadata: CitationMetadata = { + citations: [ + { + startIndex: 0, + endIndex: 10, + uri: 'https://example.com', + }, + ], + }; + const response = mockResponse( + [mockTextPart('Hello')], + undefined, + undefined, + citationMetadata, + ); + expect(getCitations(response)).toEqual(['https://example.com']); + }); + + it('should filter out citations without uri', () => { + const citationMetadata: CitationMetadata = { + citations: [ + { + startIndex: 0, + endIndex: 10, + title: 'No URI', + }, + { + startIndex: 10, + endIndex: 20, + uri: 'https://valid.com', + }, + ], + }; + const response = mockResponse( + [mockTextPart('Hello')], + undefined, + undefined, + citationMetadata, + ); + expect(getCitations(response)).toEqual(['https://valid.com']); + }); + }); + describe('getResponseTextFromParts', () => { it('should return undefined for no parts', () => { expect(getResponseTextFromParts([])).toBeUndefined(); diff --git a/packages/core/src/utils/generateContentResponseUtilities.ts b/packages/core/src/utils/generateContentResponseUtilities.ts index 8d5bec9f42..2532988533 100644 --- a/packages/core/src/utils/generateContentResponseUtilities.ts +++ b/packages/core/src/utils/generateContentResponseUtilities.ts @@ -105,3 +105,14 @@ export function getStructuredResponseFromParts( } return undefined; } + +export function getCitations(resp: GenerateContentResponse): string[] { + return (resp.candidates?.[0]?.citationMetadata?.citations ?? []) + .filter((citation) => citation.uri !== undefined) + .map((citation) => { + if (citation.title) { + return `(${citation.title}) ${citation.uri}`; + } + return citation.uri!; + }); +}