Enable citations by default for certain users. (#7438)

This commit is contained in:
Tommaso Sciortino
2025-09-02 09:36:24 -07:00
committed by GitHub
parent c29e44848b
commit 997136ae25
5 changed files with 87 additions and 181 deletions

View File

@@ -110,6 +110,10 @@ Settings are organized into categories. All settings should be placed within the
- **Description:** Show line numbers in the chat.
- **Default:** `false`
- **`ui.showCitations`** (boolean):
- **Description:** Show citations for generated text in the chat.
- **Default:** `false`
- **`ui.accessibility.disableLoadingPhrases`** (boolean):
- **Description:** Disable loading phrases for accessibility.
- **Default:** `false`

View File

@@ -31,6 +31,8 @@ import {
ConversationFinishedEvent,
ApprovalMode,
parseAndFormatApiError,
getCodeAssistServer,
UserTierId,
} from '@google/gemini-cli-core';
import { type Part, type PartListUnion, FinishReason } from '@google/genai';
import type {
@@ -68,6 +70,15 @@ enum StreamProcessingStatus {
Error,
}
function showCitations(settings: LoadedSettings, config: Config): boolean {
const enabled = settings?.merged?.ui?.showCitations;
if (enabled !== undefined) {
return enabled;
}
const server = getCodeAssistServer(config);
return (server && server.userTier !== UserTierId.FREE) ?? false;
}
/**
* Manages the Gemini stream, including user input, command processing,
* API interaction, and tool call lifecycle.
@@ -490,16 +501,17 @@ export const useGeminiStream = (
const handleCitationEvent = useCallback(
(text: string, userMessageTimestamp: number) => {
if (!settings?.merged?.ui?.showCitations) {
if (!showCitations(settings, config)) {
return;
}
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null);
}
addItem({ type: MessageType.INFO, text }, userMessageTimestamp);
},
[addItem, pendingHistoryItemRef, setPendingHistoryItem, settings],
[addItem, pendingHistoryItemRef, setPendingHistoryItem, settings, config],
);
const handleFinishedEvent = useCallback(

View File

@@ -8,150 +8,31 @@ import { describe, it, expect, beforeEach, vi } from 'vitest';
import { renderHook, waitFor } from '@testing-library/react';
import type {
Config,
GeminiClient,
ContentGenerator,
} from '@google/gemini-cli-core';
import {
CodeAssistServer,
LoggingContentGenerator,
UserTierId,
LoadCodeAssistResponse,
} from '@google/gemini-cli-core';
import type { OAuth2Client } from 'google-auth-library';
import { UserTierId, getCodeAssistServer } from '@google/gemini-cli-core';
import { usePrivacySettings } from './usePrivacySettings.js';
// Mock the dependencies
vi.mock('@google/gemini-cli-core', () => {
// Mock classes for instanceof checks
class MockCodeAssistServer {
projectId = 'test-project-id';
loadCodeAssist = vi.fn();
getCodeAssistGlobalUserSetting = vi.fn();
setCodeAssistGlobalUserSetting = vi.fn();
constructor(
_client?: GeminiClient,
_projectId?: string,
_httpOptions?: Record<string, unknown>,
_sessionId?: string,
_userTier?: UserTierId,
) {}
}
class MockLoggingContentGenerator {
getWrapped = vi.fn();
constructor(
_wrapped?: ContentGenerator,
_config?: Record<string, unknown>,
) {}
}
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actual =
await importOriginal<typeof import('@google/gemini-cli-core')>();
return {
Config: vi.fn(),
CodeAssistServer: MockCodeAssistServer,
LoggingContentGenerator: MockLoggingContentGenerator,
GeminiClient: vi.fn(),
UserTierId: {
FREE: 'free-tier',
LEGACY: 'legacy-tier',
STANDARD: 'standard-tier',
},
...actual,
getCodeAssistServer: vi.fn(),
};
});
describe('usePrivacySettings', () => {
let mockConfig: Config;
let mockClient: GeminiClient;
let mockCodeAssistServer: CodeAssistServer;
let mockLoggingContentGenerator: LoggingContentGenerator;
const mockConfig = {} as unknown as Config;
beforeEach(() => {
vi.clearAllMocks();
// Create mock CodeAssistServer instance
mockCodeAssistServer = new CodeAssistServer(
null as unknown as OAuth2Client,
'test-project-id',
) as unknown as CodeAssistServer;
(
mockCodeAssistServer.loadCodeAssist as ReturnType<typeof vi.fn>
).mockResolvedValue({
currentTier: { id: UserTierId.FREE },
});
(
mockCodeAssistServer.getCodeAssistGlobalUserSetting as ReturnType<
typeof vi.fn
>
).mockResolvedValue({
freeTierDataCollectionOptin: true,
});
(
mockCodeAssistServer.setCodeAssistGlobalUserSetting as ReturnType<
typeof vi.fn
>
).mockResolvedValue({
freeTierDataCollectionOptin: false,
});
// Create mock LoggingContentGenerator that wraps the CodeAssistServer
mockLoggingContentGenerator = new LoggingContentGenerator(
mockCodeAssistServer,
null as unknown as Config,
) as unknown as LoggingContentGenerator;
(
mockLoggingContentGenerator.getWrapped as ReturnType<typeof vi.fn>
).mockReturnValue(mockCodeAssistServer);
// Create mock GeminiClient
mockClient = {
getContentGenerator: vi.fn().mockReturnValue(mockLoggingContentGenerator),
} as unknown as GeminiClient;
// Create mock Config
mockConfig = {
getGeminiClient: vi.fn().mockReturnValue(mockClient),
} as unknown as Config;
});
it('should handle LoggingContentGenerator wrapper correctly and not throw "Oauth not being used" error', async () => {
const { result } = renderHook(() => usePrivacySettings(mockConfig));
// Initial state should be loading
expect(result.current.privacyState.isLoading).toBe(true);
expect(result.current.privacyState.error).toBeUndefined();
// Wait for the hook to complete
await waitFor(() => {
expect(result.current.privacyState.isLoading).toBe(false);
});
// Should not have the "Oauth not being used" error
expect(result.current.privacyState.error).toBeUndefined();
expect(result.current.privacyState.isFreeTier).toBe(true);
expect(result.current.privacyState.dataCollectionOptIn).toBe(true);
// Verify that getWrapped was called to unwrap the LoggingContentGenerator
expect(mockLoggingContentGenerator.getWrapped).toHaveBeenCalled();
});
it('should work with direct CodeAssistServer (no wrapper)', async () => {
// Test case where the content generator is directly a CodeAssistServer
const directServer = new CodeAssistServer(
null as unknown as OAuth2Client,
'test-project-id',
) as unknown as CodeAssistServer;
(directServer.loadCodeAssist as ReturnType<typeof vi.fn>).mockResolvedValue(
{
currentTier: { id: UserTierId.FREE },
},
);
(
directServer.getCodeAssistGlobalUserSetting as ReturnType<typeof vi.fn>
).mockResolvedValue({
freeTierDataCollectionOptin: true,
});
mockClient.getContentGenerator = vi.fn().mockReturnValue(directServer);
it('should throw error when content generator is not a CodeAssistServer', async () => {
vi.mocked(getCodeAssistServer).mockReturnValue(undefined);
const { result } = renderHook(() => usePrivacySettings(mockConfig));
@@ -159,18 +40,18 @@ describe('usePrivacySettings', () => {
expect(result.current.privacyState.isLoading).toBe(false);
});
expect(result.current.privacyState.error).toBeUndefined();
expect(result.current.privacyState.isFreeTier).toBe(true);
expect(result.current.privacyState.dataCollectionOptIn).toBe(true);
expect(result.current.privacyState.error).toBe('Oauth not being used');
});
it('should handle paid tier users correctly', async () => {
// Mock paid tier response
(
mockCodeAssistServer.loadCodeAssist as ReturnType<typeof vi.fn>
).mockResolvedValue({
currentTier: { id: UserTierId.STANDARD },
});
vi.mocked(getCodeAssistServer).mockReturnValue({
projectId: 'test-project-id',
loadCodeAssist: () =>
({
currentTier: { id: UserTierId.STANDARD },
}) as unknown as LoadCodeAssistResponse,
} as unknown as CodeAssistServer);
const { result } = renderHook(() => usePrivacySettings(mockConfig));
@@ -183,31 +64,13 @@ describe('usePrivacySettings', () => {
expect(result.current.privacyState.dataCollectionOptIn).toBeUndefined();
});
it('should throw error when content generator is not a CodeAssistServer', async () => {
// Mock a non-CodeAssistServer content generator
const mockOtherGenerator = { someOtherMethod: vi.fn() };
(
mockLoggingContentGenerator.getWrapped as ReturnType<typeof vi.fn>
).mockReturnValue(mockOtherGenerator);
const { result } = renderHook(() => usePrivacySettings(mockConfig));
await waitFor(() => {
expect(result.current.privacyState.isLoading).toBe(false);
});
expect(result.current.privacyState.error).toBe('Oauth not being used');
});
it('should throw error when CodeAssistServer has no projectId', async () => {
// Mock CodeAssistServer without projectId
const mockServerNoProject = {
...mockCodeAssistServer,
projectId: undefined,
};
(
mockLoggingContentGenerator.getWrapped as ReturnType<typeof vi.fn>
).mockReturnValue(mockServerNoProject);
vi.mocked(getCodeAssistServer).mockReturnValue({
loadCodeAssist: () =>
({
currentTier: { id: UserTierId.FREE },
}) as unknown as LoadCodeAssistResponse,
} as unknown as CodeAssistServer);
const { result } = renderHook(() => usePrivacySettings(mockConfig));
@@ -215,10 +78,27 @@ describe('usePrivacySettings', () => {
expect(result.current.privacyState.isLoading).toBe(false);
});
expect(result.current.privacyState.error).toBe('Oauth not being used');
expect(result.current.privacyState.error).toBe(
'CodeAssist server is missing a project ID',
);
});
it('should update data collection opt-in setting', async () => {
const mockCodeAssistServer = {
projectId: 'test-project-id',
getCodeAssistGlobalUserSetting: vi.fn().mockResolvedValue({
freeTierDataCollectionOptin: true,
}),
setCodeAssistGlobalUserSetting: vi.fn().mockResolvedValue({
freeTierDataCollectionOptin: false,
}),
loadCodeAssist: () =>
({
currentTier: { id: UserTierId.FREE },
}) as unknown as LoadCodeAssistResponse,
} as unknown as CodeAssistServer;
vi.mocked(getCodeAssistServer).mockReturnValue(mockCodeAssistServer);
const { result } = renderHook(() => usePrivacySettings(mockConfig));
// Wait for initial load

View File

@@ -5,11 +5,11 @@
*/
import { useState, useEffect, useCallback } from 'react';
import type { Config } from '@google/gemini-cli-core';
import {
CodeAssistServer,
type Config,
type CodeAssistServer,
UserTierId,
LoggingContentGenerator,
getCodeAssistServer,
} from '@google/gemini-cli-core';
export interface PrivacyState {
@@ -30,7 +30,7 @@ export const usePrivacySettings = (config: Config) => {
isLoading: true,
});
try {
const server = getCodeAssistServer(config);
const server = getCodeAssistServerOrFail(config);
const tier = await getTier(server);
if (tier !== UserTierId.FREE) {
// We don't need to fetch opt-out info since non-free tier
@@ -61,7 +61,7 @@ export const usePrivacySettings = (config: Config) => {
const updateDataCollectionOptIn = useCallback(
async (optIn: boolean) => {
try {
const server = getCodeAssistServer(config);
const server = getCodeAssistServerOrFail(config);
const updatedOptIn = await setRemoteDataCollectionOptIn(server, optIn);
setPrivacyState({
isLoading: false,
@@ -84,19 +84,12 @@ export const usePrivacySettings = (config: Config) => {
};
};
function getCodeAssistServer(config: Config): CodeAssistServer {
let server = config.getGeminiClient().getContentGenerator();
// Unwrap LoggingContentGenerator if present
if (server instanceof LoggingContentGenerator) {
server = server.getWrapped();
}
// Neither of these cases should ever happen.
if (!(server instanceof CodeAssistServer)) {
throw new Error('Oauth not being used');
} else if (!server.projectId) {
function getCodeAssistServerOrFail(config: Config): CodeAssistServer {
const server = getCodeAssistServer(config);
if (server === undefined) {
throw new Error('Oauth not being used');
} else if (server.projectId === undefined) {
throw new Error('CodeAssist server is missing a project ID');
}
return server;
}

View File

@@ -11,6 +11,7 @@ import { setupUser } from './setup.js';
import type { HttpOptions } from './server.js';
import { CodeAssistServer } from './server.js';
import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
export async function createCodeAssistContentGenerator(
httpOptions: HttpOptions,
@@ -35,3 +36,19 @@ export async function createCodeAssistContentGenerator(
throw new Error(`Unsupported authType: ${authType}`);
}
export function getCodeAssistServer(
config: Config,
): CodeAssistServer | undefined {
let server = config.getGeminiClient().getContentGenerator();
// Unwrap LoggingContentGenerator if present
if (server instanceof LoggingContentGenerator) {
server = server.getWrapped();
}
if (!(server instanceof CodeAssistServer)) {
return undefined;
}
return server;
}