Code assist service metrics. (#15024)

This commit is contained in:
Christian Gunderman
2025-12-16 09:34:05 -08:00
committed by GitHub
parent 5ea5107d05
commit 5e21c8c03c
9 changed files with 731 additions and 116 deletions

View File

@@ -7,10 +7,24 @@
import { beforeEach, describe, it, expect, vi, afterEach } from 'vitest';
import { CodeAssistServer } from './server.js';
import { OAuth2Client } from 'google-auth-library';
import { UserTierId } from './types.js';
import { UserTierId, ActionStatus } from './types.js';
import { FinishReason } from '@google/genai';
vi.mock('google-auth-library');
function createTestServer(headers: Record<string, string> = {}) {
const mockRequest = vi.fn();
const client = { request: mockRequest } as unknown as OAuth2Client;
const server = new CodeAssistServer(
client,
'test-project',
{ headers },
'test-session',
UserTierId.FREE,
);
return { server, mockRequest, client };
}
describe('CodeAssistServer', () => {
beforeEach(() => {
vi.resetAllMocks();
@@ -29,15 +43,9 @@ describe('CodeAssistServer', () => {
});
it('should call the generateContent endpoint', async () => {
const mockRequest = vi.fn();
const client = { request: mockRequest } as unknown as OAuth2Client;
const server = new CodeAssistServer(
client,
'test-project',
{ headers: { 'x-custom-header': 'test-value' } },
'test-session',
UserTierId.FREE,
);
const { server, mockRequest } = createTestServer({
'x-custom-header': 'test-value',
});
const mockResponseData = {
response: {
candidates: [
@@ -47,7 +55,7 @@ describe('CodeAssistServer', () => {
role: 'model',
parts: [{ text: 'response' }],
},
finishReason: 'STOP',
finishReason: FinishReason.STOP,
safetyRatings: [],
},
],
@@ -84,6 +92,190 @@ describe('CodeAssistServer', () => {
);
});
it('should detect error in generateContent response', async () => {
const { server, mockRequest } = createTestServer();
const mockResponseData = {
traceId: 'test-trace-id',
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'response' }],
},
finishReason: FinishReason.SAFETY,
safetyRatings: [],
},
],
},
};
mockRequest.mockResolvedValue({ data: mockResponseData });
const recordConversationOfferedSpy = vi.spyOn(
server,
'recordConversationOffered',
);
await server.generateContent(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
);
expect(recordConversationOfferedSpy).toHaveBeenCalledWith(
expect.objectContaining({
status: ActionStatus.ACTION_STATUS_ERROR_UNKNOWN,
}),
);
});
it('should record conversation offered on successful generateContent', async () => {
const { server, mockRequest } = createTestServer();
const mockResponseData = {
traceId: 'test-trace-id',
response: {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'response' }],
},
finishReason: FinishReason.STOP,
safetyRatings: [],
},
],
sdkHttpResponse: {
responseInternal: {
ok: true,
},
},
},
};
mockRequest.mockResolvedValue({ data: mockResponseData });
vi.spyOn(server, 'recordCodeAssistMetrics').mockResolvedValue(undefined);
await server.generateContent(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
);
expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith(
expect.objectContaining({
metrics: expect.arrayContaining([
expect.objectContaining({
conversationOffered: expect.objectContaining({
traceId: 'test-trace-id',
status: ActionStatus.ACTION_STATUS_NO_ERROR,
streamingLatency: expect.objectContaining({
totalLatency: expect.stringMatching(/\d+s/),
firstMessageLatency: expect.stringMatching(/\d+s/),
}),
}),
}),
]),
}),
);
});
it('should record conversation offered on generateContentStream', async () => {
const { server, mockRequest } = createTestServer();
const { Readable } = await import('node:stream');
const mockStream = new Readable({ read() {} });
mockRequest.mockResolvedValue({ data: mockStream });
vi.spyOn(server, 'recordCodeAssistMetrics').mockResolvedValue(undefined);
const stream = await server.generateContentStream(
{
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
},
'user-prompt-id',
);
const mockResponseData = {
traceId: 'stream-trace-id',
response: {
candidates: [{ content: { parts: [{ text: 'chunk' }] } }],
sdkHttpResponse: {
responseInternal: {
ok: true,
},
},
},
};
setTimeout(() => {
mockStream.push('data: ' + JSON.stringify(mockResponseData) + '\n\n');
mockStream.push(null);
}, 0);
for await (const _ of stream) {
// Consume stream
}
expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith(
expect.objectContaining({
metrics: expect.arrayContaining([
expect.objectContaining({
conversationOffered: expect.objectContaining({
traceId: 'stream-trace-id',
}),
}),
]),
}),
);
});
it('should record conversation interaction', async () => {
const { server } = createTestServer();
vi.spyOn(server, 'recordCodeAssistMetrics').mockResolvedValue(undefined);
const interaction = {
traceId: 'test-trace-id',
};
await server.recordConversationInteraction(interaction);
expect(server.recordCodeAssistMetrics).toHaveBeenCalledWith(
expect.objectContaining({
project: 'test-project',
metrics: expect.arrayContaining([
expect.objectContaining({
conversationInteraction: interaction,
}),
]),
}),
);
});
it('should call recordCodeAssistMetrics endpoint', async () => {
const { server, mockRequest } = createTestServer();
mockRequest.mockResolvedValue({ data: {} });
const req = {
project: 'test-project',
metrics: [],
};
await server.recordCodeAssistMetrics(req);
expect(mockRequest).toHaveBeenCalledWith(
expect.objectContaining({
url: expect.stringContaining(':recordCodeAssistMetrics'),
method: 'POST',
body: expect.any(String),
}),
);
});
describe('getMethodUrl', () => {
const originalEnv = process.env;
@@ -114,15 +306,7 @@ describe('CodeAssistServer', () => {
});
it('should call the generateContentStream endpoint and parse SSE', async () => {
const mockRequest = vi.fn();
const client = { request: mockRequest } as unknown as OAuth2Client;
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const { server, mockRequest } = createTestServer();
// Create a mock readable stream
const { Readable } = await import('node:stream');
@@ -179,9 +363,7 @@ describe('CodeAssistServer', () => {
});
it('should ignore malformed SSE data', async () => {
const mockRequest = vi.fn();
const client = { request: mockRequest } as unknown as OAuth2Client;
const server = new CodeAssistServer(client);
const { server, mockRequest } = createTestServer();
const { Readable } = await import('node:stream');
const mockStream = new Readable({
@@ -205,14 +387,8 @@ describe('CodeAssistServer', () => {
});
it('should call the onboardUser endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const { server } = createTestServer();
const mockResponse = {
name: 'operations/123',
done: true,
@@ -233,14 +409,7 @@ describe('CodeAssistServer', () => {
});
it('should call the loadCodeAssist endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const { server } = createTestServer();
const mockResponse = {
currentTier: {
id: UserTierId.FREE,
@@ -265,14 +434,7 @@ describe('CodeAssistServer', () => {
});
it('should return 0 for countTokens', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const { server } = createTestServer();
const mockResponse = {
totalTokens: 100,
};
@@ -286,14 +448,7 @@ describe('CodeAssistServer', () => {
});
it('should throw an error for embedContent', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const { server } = createTestServer();
await expect(
server.embedContent({
model: 'test-model',
@@ -303,14 +458,7 @@ describe('CodeAssistServer', () => {
});
it('should handle VPC-SC errors when calling loadCodeAssist', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const { server } = createTestServer();
const mockVpcScError = {
response: {
data: {
@@ -340,8 +488,7 @@ describe('CodeAssistServer', () => {
});
it('should re-throw non-VPC-SC errors from loadCodeAssist', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(client);
const { server } = createTestServer();
const genericError = new Error('Something else went wrong');
vi.spyOn(server, 'requestPost').mockRejectedValue(genericError);
@@ -356,14 +503,7 @@ describe('CodeAssistServer', () => {
});
it('should call the listExperiments endpoint with metadata', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const { server } = createTestServer();
const mockResponse = {
experiments: [],
};
@@ -382,14 +522,7 @@ describe('CodeAssistServer', () => {
});
it('should call the retrieveUserQuota endpoint', async () => {
const client = new OAuth2Client();
const server = new CodeAssistServer(
client,
'test-project',
{},
'test-session',
UserTierId.FREE,
);
const { server } = createTestServer();
const mockResponse = {
buckets: [
{

View File

@@ -16,6 +16,10 @@ import type {
ClientMetadata,
RetrieveUserQuotaRequest,
RetrieveUserQuotaResponse,
ConversationOffered,
ConversationInteraction,
StreamingLatency,
RecordCodeAssistMetricsRequest,
} from './types.js';
import type {
ListExperimentsRequest,
@@ -42,6 +46,11 @@ import {
toCountTokenRequest,
toGenerateContentRequest,
} from './converter.js';
import {
createConversationOffered,
formatProtoJsonDuration,
} from './telemetry.js';
import { getClientMetadata } from './experiments/client_metadata.js';
/** HTTP options to be used in each of the requests. */
export interface HttpOptions {
@@ -65,28 +74,60 @@ export class CodeAssistServer implements ContentGenerator {
req: GenerateContentParameters,
userPromptId: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
'streamGenerateContent',
toGenerateContentRequest(
req,
userPromptId,
this.projectId,
this.sessionId,
),
req.config?.abortSignal,
);
return (async function* (): AsyncGenerator<GenerateContentResponse> {
for await (const resp of resps) {
yield fromGenerateContentResponse(resp);
const responses =
await this.requestStreamingPost<CaGenerateContentResponse>(
'streamGenerateContent',
toGenerateContentRequest(
req,
userPromptId,
this.projectId,
this.sessionId,
),
req.config?.abortSignal,
);
const streamingLatency: StreamingLatency = {};
const start = Date.now();
let isFirst = true;
return (async function* (
server: CodeAssistServer,
): AsyncGenerator<GenerateContentResponse> {
for await (const response of responses) {
if (isFirst) {
streamingLatency.firstMessageLatency = formatProtoJsonDuration(
Date.now() - start,
);
isFirst = false;
}
streamingLatency.totalLatency = formatProtoJsonDuration(
Date.now() - start,
);
const translatedResponse = fromGenerateContentResponse(response);
if (response.traceId) {
const offered = createConversationOffered(
translatedResponse,
response.traceId,
req.config?.abortSignal,
streamingLatency,
);
await server.recordConversationOffered(offered);
}
yield translatedResponse;
}
})();
})(this);
}
async generateContent(
req: GenerateContentParameters,
userPromptId: string,
): Promise<GenerateContentResponse> {
const resp = await this.requestPost<CaGenerateContentResponse>(
const start = Date.now();
const response = await this.requestPost<CaGenerateContentResponse>(
'generateContent',
toGenerateContentRequest(
req,
@@ -96,7 +137,25 @@ export class CodeAssistServer implements ContentGenerator {
),
req.config?.abortSignal,
);
return fromGenerateContentResponse(resp);
const duration = formatProtoJsonDuration(Date.now() - start);
const streamingLatency: StreamingLatency = {
totalLatency: duration,
firstMessageLatency: duration,
};
const translatedResponse = fromGenerateContentResponse(response);
if (response.traceId) {
const offered = createConversationOffered(
translatedResponse,
response.traceId,
req.config?.abortSignal,
streamingLatency,
);
await this.recordConversationOffered(offered);
}
return translatedResponse;
}
async onboardUser(
@@ -176,6 +235,40 @@ export class CodeAssistServer implements ContentGenerator {
);
}
async recordConversationOffered(
conversationOffered: ConversationOffered,
): Promise<void> {
if (!this.projectId) {
return;
}
await this.recordCodeAssistMetrics({
project: this.projectId,
metadata: await getClientMetadata(),
metrics: [{ conversationOffered }],
});
}
async recordConversationInteraction(
interaction: ConversationInteraction,
): Promise<void> {
if (!this.projectId) {
return;
}
await this.recordCodeAssistMetrics({
project: this.projectId,
metadata: await getClientMetadata(),
metrics: [{ conversationInteraction: interaction }],
});
}
async recordCodeAssistMetrics(
request: RecordCodeAssistMetricsRequest,
): Promise<void> {
return this.requestPost<void>('recordCodeAssistMetrics', request);
}
async requestPost<T>(
method: string,
req: object,

View File

@@ -0,0 +1,162 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import {
createConversationOffered,
formatProtoJsonDuration,
} from './telemetry.js';
import { ActionStatus, type StreamingLatency } from './types.js';
import { FinishReason, GenerateContentResponse } from '@google/genai';
function createMockResponse(
candidates: GenerateContentResponse['candidates'] = [],
ok = true,
) {
const response = new GenerateContentResponse();
response.candidates = candidates;
response.sdkHttpResponse = {
responseInternal: {
ok,
} as unknown as Response,
json: async () => ({}),
};
return response;
}
describe('telemetry', () => {
describe('createConversationOffered', () => {
it('should create a ConversationOffered object with correct values', () => {
const response = createMockResponse([
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'response with ```code```' }],
},
citationMetadata: {
citations: [
{ uri: 'https://example.com', startIndex: 0, endIndex: 10 },
],
},
finishReason: FinishReason.STOP,
},
]);
const traceId = 'test-trace-id';
const streamingLatency: StreamingLatency = { totalLatency: '1s' };
const result = createConversationOffered(
response,
traceId,
undefined,
streamingLatency,
);
expect(result).toEqual({
citationCount: '1',
includedCode: true,
status: ActionStatus.ACTION_STATUS_NO_ERROR,
traceId,
streamingLatency,
isAgentic: true,
});
});
it('should set status to CANCELLED if signal is aborted', () => {
const response = createMockResponse();
const signal = new AbortController().signal;
vi.spyOn(signal, 'aborted', 'get').mockReturnValue(true);
const result = createConversationOffered(
response,
'trace-id',
signal,
{},
);
expect(result.status).toBe(ActionStatus.ACTION_STATUS_CANCELLED);
});
it('should set status to ERROR_UNKNOWN if response has error (non-OK SDK response)', () => {
const response = createMockResponse([], false);
const result = createConversationOffered(
response,
'trace-id',
undefined,
{},
);
expect(result.status).toBe(ActionStatus.ACTION_STATUS_ERROR_UNKNOWN);
});
it('should set status to ERROR_UNKNOWN if finishReason is not STOP or MAX_TOKENS', () => {
const response = createMockResponse([
{
index: 0,
finishReason: FinishReason.SAFETY,
},
]);
const result = createConversationOffered(
response,
'trace-id',
undefined,
{},
);
expect(result.status).toBe(ActionStatus.ACTION_STATUS_ERROR_UNKNOWN);
});
it('should set status to EMPTY if candidates is empty', () => {
const response = createMockResponse();
const result = createConversationOffered(
response,
'trace-id',
undefined,
{},
);
expect(result.status).toBe(ActionStatus.ACTION_STATUS_EMPTY);
});
it('should detect code in response', () => {
const response = createMockResponse([
{
index: 0,
content: {
parts: [
{ text: 'Here is some code:\n```js\nconsole.log("hi")\n```' },
],
},
},
]);
const result = createConversationOffered(response, 'id', undefined, {});
expect(result.includedCode).toBe(true);
});
it('should not detect code if no backticks', () => {
const response = createMockResponse([
{
index: 0,
content: {
parts: [{ text: 'Here is some text.' }],
},
},
]);
const result = createConversationOffered(response, 'id', undefined, {});
expect(result.includedCode).toBe(false);
});
});
describe('formatProtoJsonDuration', () => {
it('should format milliseconds to seconds string', () => {
expect(formatProtoJsonDuration(1500)).toBe('1.5s');
expect(formatProtoJsonDuration(100)).toBe('0.1s');
});
});
});

View File

@@ -0,0 +1,93 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { FinishReason, type GenerateContentResponse } from '@google/genai';
import { getCitations } from '../utils/generateContentResponseUtilities.js';
import {
ActionStatus,
type ConversationOffered,
type StreamingLatency,
} from './types.js';
export function createConversationOffered(
response: GenerateContentResponse,
traceId: string,
signal: AbortSignal | undefined,
streamingLatency: StreamingLatency,
): ConversationOffered {
const actionStatus = getStatus(response, signal);
return {
citationCount: String(getCitations(response).length),
includedCode: includesCode(response),
status: actionStatus,
traceId,
streamingLatency,
isAgentic: true,
};
}
function includesCode(resp: GenerateContentResponse): boolean {
if (!resp.candidates) {
return false;
}
for (const candidate of resp.candidates) {
if (!candidate.content || !candidate.content.parts) {
continue;
}
for (const part of candidate.content.parts) {
if ('text' in part && part?.text?.includes('```')) {
return true;
}
}
}
return false;
}
function getStatus(
response: GenerateContentResponse,
signal: AbortSignal | undefined,
): ActionStatus {
if (signal?.aborted) {
return ActionStatus.ACTION_STATUS_CANCELLED;
}
if (hasError(response)) {
return ActionStatus.ACTION_STATUS_ERROR_UNKNOWN;
}
if ((response.candidates?.length ?? 0) <= 0) {
return ActionStatus.ACTION_STATUS_EMPTY;
}
return ActionStatus.ACTION_STATUS_NO_ERROR;
}
export function formatProtoJsonDuration(milliseconds: number): string {
return `${milliseconds / 1000}s`;
}
function hasError(response: GenerateContentResponse): boolean {
// Non-OK SDK results should be considered an error.
if (
response.sdkHttpResponse &&
!response.sdkHttpResponse?.responseInternal?.ok
) {
return true;
}
for (const candidate of response.candidates || []) {
// Treat sanitization, SPII, recitation, and forbidden terms as an error.
if (
candidate.finishReason &&
candidate.finishReason !== FinishReason.STOP &&
candidate.finishReason !== FinishReason.MAX_TOKENS
) {
return true;
}
}
return false;
}

View File

@@ -177,7 +177,7 @@ export interface HelpLinkUrl {
export interface SetCodeAssistGlobalUserSettingRequest {
cloudaicompanionProject?: string;
freeTierDataCollectionOptin: boolean;
freeTierDataCollectionOptin?: boolean;
}
export interface CodeAssistGlobalUserSettingResponse {
@@ -217,3 +217,63 @@ export interface BucketInfo {
export interface RetrieveUserQuotaResponse {
buckets?: BucketInfo[];
}
export interface RecordCodeAssistMetricsRequest {
project: string;
requestId?: string;
metadata?: ClientMetadata;
metrics?: CodeAssistMetric[];
}
export interface CodeAssistMetric {
timestamp?: string;
metricMetadata?: Map<string, string>;
// The event tied to this metric. Only one of these should be set.
conversationOffered?: ConversationOffered;
conversationInteraction?: ConversationInteraction;
}
export enum ConversationInteractionInteraction {
UNKNOWN = 0,
THUMBSUP = 1,
THUMBSDOWN = 2,
COPY = 3,
INSERT = 4,
ACCEPT_CODE_BLOCK = 5,
ACCEPT_ALL = 6,
ACCEPT_FILE = 7,
DIFF = 8,
ACCEPT_RANGE = 9,
}
export enum ActionStatus {
ACTION_STATUS_UNSPECIFIED = 0,
ACTION_STATUS_NO_ERROR = 1,
ACTION_STATUS_ERROR_UNKNOWN = 2,
ACTION_STATUS_CANCELLED = 3,
ACTION_STATUS_EMPTY = 4,
}
export interface ConversationOffered {
citationCount?: string;
includedCode?: boolean;
status?: ActionStatus;
traceId?: string;
streamingLatency?: StreamingLatency;
isAgentic?: boolean;
}
export interface StreamingLatency {
firstMessageLatency?: string;
totalLatency?: string;
}
export interface ConversationInteraction {
traceId: string;
status?: ActionStatus;
interaction?: ConversationInteractionInteraction;
acceptedLines?: string;
language?: string;
isAgentic?: boolean;
}

View File

@@ -36,13 +36,6 @@ vi.mock('../utils/errorReporting', () => ({
reportError: vi.fn(),
}));
// Use the actual implementation from partUtils now that it's provided.
vi.mock('../utils/generateContentResponseUtilities', () => ({
getResponseText: (resp: GenerateContentResponse) =>
resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
undefined,
}));
describe('Turn', () => {
let turn: Turn;
// Define a type for the mocked Chat instance for clarity

View File

@@ -31,6 +31,7 @@ import { InvalidStreamError } from './geminiChat.js';
import { parseThought, type ThoughtSummary } from '../utils/thoughtUtils.js';
import { createUserContent } from '@google/genai';
import type { ModelConfigKey } from '../services/modelConfigService.js';
import { getCitations } from '../utils/generateContentResponseUtilities.js';
// Define a structure for tools passed to the server
export interface ServerTool {
@@ -405,14 +406,3 @@ export class Turn {
.join(' ');
}
}
function getCitations(resp: GenerateContentResponse): string[] {
return (resp.candidates?.[0]?.citationMetadata?.citations ?? [])
.filter((citation) => citation.uri !== undefined)
.map((citation) => {
if (citation.title) {
return `(${citation.title}) ${citation.uri}`;
}
return citation.uri!;
});
}

View File

@@ -13,11 +13,13 @@ import {
getFunctionCallsFromPartsAsJson,
getStructuredResponse,
getStructuredResponseFromParts,
getCitations,
} from './generateContentResponseUtilities.js';
import type {
GenerateContentResponse,
Part,
SafetyRating,
CitationMetadata,
} from '@google/genai';
import { FinishReason } from '@google/genai';
@@ -33,6 +35,7 @@ const mockResponse = (
parts: Part[],
finishReason: FinishReason = FinishReason.STOP,
safetyRatings: SafetyRating[] = [],
citationMetadata?: CitationMetadata,
): GenerateContentResponse => ({
candidates: [
{
@@ -43,6 +46,7 @@ const mockResponse = (
index: 0,
finishReason,
safetyRatings,
citationMetadata,
},
],
promptFeedback: {
@@ -68,6 +72,82 @@ const minimalMockResponse = (
});
describe('generateContentResponseUtilities', () => {
describe('getCitations', () => {
it('should return empty array for no candidates', () => {
expect(getCitations(minimalMockResponse(undefined))).toEqual([]);
});
it('should return empty array if no citationMetadata', () => {
const response = mockResponse([mockTextPart('Hello')]);
expect(getCitations(response)).toEqual([]);
});
it('should return citations with title and uri', () => {
const citationMetadata: CitationMetadata = {
citations: [
{
startIndex: 0,
endIndex: 10,
uri: 'https://example.com',
title: 'Example Title',
},
],
};
const response = mockResponse(
[mockTextPart('Hello')],
undefined,
undefined,
citationMetadata,
);
expect(getCitations(response)).toEqual([
'(Example Title) https://example.com',
]);
});
it('should return citations with uri only if no title', () => {
const citationMetadata: CitationMetadata = {
citations: [
{
startIndex: 0,
endIndex: 10,
uri: 'https://example.com',
},
],
};
const response = mockResponse(
[mockTextPart('Hello')],
undefined,
undefined,
citationMetadata,
);
expect(getCitations(response)).toEqual(['https://example.com']);
});
it('should filter out citations without uri', () => {
const citationMetadata: CitationMetadata = {
citations: [
{
startIndex: 0,
endIndex: 10,
title: 'No URI',
},
{
startIndex: 10,
endIndex: 20,
uri: 'https://valid.com',
},
],
};
const response = mockResponse(
[mockTextPart('Hello')],
undefined,
undefined,
citationMetadata,
);
expect(getCitations(response)).toEqual(['https://valid.com']);
});
});
describe('getResponseTextFromParts', () => {
it('should return undefined for no parts', () => {
expect(getResponseTextFromParts([])).toBeUndefined();

View File

@@ -105,3 +105,14 @@ export function getStructuredResponseFromParts(
}
return undefined;
}
export function getCitations(resp: GenerateContentResponse): string[] {
return (resp.candidates?.[0]?.citationMetadata?.citations ?? [])
.filter((citation) => citation.uri !== undefined)
.map((citation) => {
if (citation.title) {
return `(${citation.title}) ${citation.uri}`;
}
return citation.uri!;
});
}