Enable citations by default for certain users. (#7438)

This commit is contained in:
Tommaso Sciortino
2025-09-02 09:36:24 -07:00
committed by GitHub
parent c29e44848b
commit 997136ae25
5 changed files with 87 additions and 181 deletions
+4
View File
@@ -110,6 +110,10 @@ Settings are organized into categories. All settings should be placed within the
- **Description:** Show line numbers in the chat. - **Description:** Show line numbers in the chat.
- **Default:** `false` - **Default:** `false`
- **`ui.showCitations`** (boolean):
- **Description:** Show citations for generated text in the chat.
- **Default:** `false`
- **`ui.accessibility.disableLoadingPhrases`** (boolean): - **`ui.accessibility.disableLoadingPhrases`** (boolean):
- **Description:** Disable loading phrases for accessibility. - **Description:** Disable loading phrases for accessibility.
- **Default:** `false` - **Default:** `false`
+14 -2
View File
@@ -31,6 +31,8 @@ import {
ConversationFinishedEvent, ConversationFinishedEvent,
ApprovalMode, ApprovalMode,
parseAndFormatApiError, parseAndFormatApiError,
getCodeAssistServer,
UserTierId,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { type Part, type PartListUnion, FinishReason } from '@google/genai'; import { type Part, type PartListUnion, FinishReason } from '@google/genai';
import type { import type {
@@ -68,6 +70,15 @@ enum StreamProcessingStatus {
Error, Error,
} }
function showCitations(settings: LoadedSettings, config: Config): boolean {
const enabled = settings?.merged?.ui?.showCitations;
if (enabled !== undefined) {
return enabled;
}
const server = getCodeAssistServer(config);
return (server && server.userTier !== UserTierId.FREE) ?? false;
}
/** /**
* Manages the Gemini stream, including user input, command processing, * Manages the Gemini stream, including user input, command processing,
* API interaction, and tool call lifecycle. * API interaction, and tool call lifecycle.
@@ -490,16 +501,17 @@ export const useGeminiStream = (
const handleCitationEvent = useCallback( const handleCitationEvent = useCallback(
(text: string, userMessageTimestamp: number) => { (text: string, userMessageTimestamp: number) => {
if (!settings?.merged?.ui?.showCitations) { if (!showCitations(settings, config)) {
return; return;
} }
if (pendingHistoryItemRef.current) { if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp); addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null); setPendingHistoryItem(null);
} }
addItem({ type: MessageType.INFO, text }, userMessageTimestamp); addItem({ type: MessageType.INFO, text }, userMessageTimestamp);
}, },
[addItem, pendingHistoryItemRef, setPendingHistoryItem, settings], [addItem, pendingHistoryItemRef, setPendingHistoryItem, settings, config],
); );
const handleFinishedEvent = useCallback( const handleFinishedEvent = useCallback(
@@ -8,150 +8,31 @@ import { describe, it, expect, beforeEach, vi } from 'vitest';
import { renderHook, waitFor } from '@testing-library/react'; import { renderHook, waitFor } from '@testing-library/react';
import type { import type {
Config, Config,
GeminiClient,
ContentGenerator,
} from '@google/gemini-cli-core';
import {
CodeAssistServer, CodeAssistServer,
LoggingContentGenerator, LoadCodeAssistResponse,
UserTierId,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import type { OAuth2Client } from 'google-auth-library'; import { UserTierId, getCodeAssistServer } from '@google/gemini-cli-core';
import { usePrivacySettings } from './usePrivacySettings.js'; import { usePrivacySettings } from './usePrivacySettings.js';
// Mock the dependencies // Mock the dependencies
vi.mock('@google/gemini-cli-core', () => { vi.mock('@google/gemini-cli-core', async (importOriginal) => {
// Mock classes for instanceof checks const actual =
class MockCodeAssistServer { await importOriginal<typeof import('@google/gemini-cli-core')>();
projectId = 'test-project-id';
loadCodeAssist = vi.fn();
getCodeAssistGlobalUserSetting = vi.fn();
setCodeAssistGlobalUserSetting = vi.fn();
constructor(
_client?: GeminiClient,
_projectId?: string,
_httpOptions?: Record<string, unknown>,
_sessionId?: string,
_userTier?: UserTierId,
) {}
}
class MockLoggingContentGenerator {
getWrapped = vi.fn();
constructor(
_wrapped?: ContentGenerator,
_config?: Record<string, unknown>,
) {}
}
return { return {
Config: vi.fn(), ...actual,
CodeAssistServer: MockCodeAssistServer, getCodeAssistServer: vi.fn(),
LoggingContentGenerator: MockLoggingContentGenerator,
GeminiClient: vi.fn(),
UserTierId: {
FREE: 'free-tier',
LEGACY: 'legacy-tier',
STANDARD: 'standard-tier',
},
}; };
}); });
describe('usePrivacySettings', () => { describe('usePrivacySettings', () => {
let mockConfig: Config; const mockConfig = {} as unknown as Config;
let mockClient: GeminiClient;
let mockCodeAssistServer: CodeAssistServer;
let mockLoggingContentGenerator: LoggingContentGenerator;
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks(); vi.clearAllMocks();
// Create mock CodeAssistServer instance
mockCodeAssistServer = new CodeAssistServer(
null as unknown as OAuth2Client,
'test-project-id',
) as unknown as CodeAssistServer;
(
mockCodeAssistServer.loadCodeAssist as ReturnType<typeof vi.fn>
).mockResolvedValue({
currentTier: { id: UserTierId.FREE },
});
(
mockCodeAssistServer.getCodeAssistGlobalUserSetting as ReturnType<
typeof vi.fn
>
).mockResolvedValue({
freeTierDataCollectionOptin: true,
});
(
mockCodeAssistServer.setCodeAssistGlobalUserSetting as ReturnType<
typeof vi.fn
>
).mockResolvedValue({
freeTierDataCollectionOptin: false,
});
// Create mock LoggingContentGenerator that wraps the CodeAssistServer
mockLoggingContentGenerator = new LoggingContentGenerator(
mockCodeAssistServer,
null as unknown as Config,
) as unknown as LoggingContentGenerator;
(
mockLoggingContentGenerator.getWrapped as ReturnType<typeof vi.fn>
).mockReturnValue(mockCodeAssistServer);
// Create mock GeminiClient
mockClient = {
getContentGenerator: vi.fn().mockReturnValue(mockLoggingContentGenerator),
} as unknown as GeminiClient;
// Create mock Config
mockConfig = {
getGeminiClient: vi.fn().mockReturnValue(mockClient),
} as unknown as Config;
}); });
it('should handle LoggingContentGenerator wrapper correctly and not throw "Oauth not being used" error', async () => { it('should throw error when content generator is not a CodeAssistServer', async () => {
const { result } = renderHook(() => usePrivacySettings(mockConfig)); vi.mocked(getCodeAssistServer).mockReturnValue(undefined);
// Initial state should be loading
expect(result.current.privacyState.isLoading).toBe(true);
expect(result.current.privacyState.error).toBeUndefined();
// Wait for the hook to complete
await waitFor(() => {
expect(result.current.privacyState.isLoading).toBe(false);
});
// Should not have the "Oauth not being used" error
expect(result.current.privacyState.error).toBeUndefined();
expect(result.current.privacyState.isFreeTier).toBe(true);
expect(result.current.privacyState.dataCollectionOptIn).toBe(true);
// Verify that getWrapped was called to unwrap the LoggingContentGenerator
expect(mockLoggingContentGenerator.getWrapped).toHaveBeenCalled();
});
it('should work with direct CodeAssistServer (no wrapper)', async () => {
// Test case where the content generator is directly a CodeAssistServer
const directServer = new CodeAssistServer(
null as unknown as OAuth2Client,
'test-project-id',
) as unknown as CodeAssistServer;
(directServer.loadCodeAssist as ReturnType<typeof vi.fn>).mockResolvedValue(
{
currentTier: { id: UserTierId.FREE },
},
);
(
directServer.getCodeAssistGlobalUserSetting as ReturnType<typeof vi.fn>
).mockResolvedValue({
freeTierDataCollectionOptin: true,
});
mockClient.getContentGenerator = vi.fn().mockReturnValue(directServer);
const { result } = renderHook(() => usePrivacySettings(mockConfig)); const { result } = renderHook(() => usePrivacySettings(mockConfig));
@@ -159,18 +40,18 @@ describe('usePrivacySettings', () => {
expect(result.current.privacyState.isLoading).toBe(false); expect(result.current.privacyState.isLoading).toBe(false);
}); });
expect(result.current.privacyState.error).toBeUndefined(); expect(result.current.privacyState.error).toBe('Oauth not being used');
expect(result.current.privacyState.isFreeTier).toBe(true);
expect(result.current.privacyState.dataCollectionOptIn).toBe(true);
}); });
it('should handle paid tier users correctly', async () => { it('should handle paid tier users correctly', async () => {
// Mock paid tier response // Mock paid tier response
( vi.mocked(getCodeAssistServer).mockReturnValue({
mockCodeAssistServer.loadCodeAssist as ReturnType<typeof vi.fn> projectId: 'test-project-id',
).mockResolvedValue({ loadCodeAssist: () =>
currentTier: { id: UserTierId.STANDARD }, ({
}); currentTier: { id: UserTierId.STANDARD },
}) as unknown as LoadCodeAssistResponse,
} as unknown as CodeAssistServer);
const { result } = renderHook(() => usePrivacySettings(mockConfig)); const { result } = renderHook(() => usePrivacySettings(mockConfig));
@@ -183,31 +64,13 @@ describe('usePrivacySettings', () => {
expect(result.current.privacyState.dataCollectionOptIn).toBeUndefined(); expect(result.current.privacyState.dataCollectionOptIn).toBeUndefined();
}); });
it('should throw error when content generator is not a CodeAssistServer', async () => {
// Mock a non-CodeAssistServer content generator
const mockOtherGenerator = { someOtherMethod: vi.fn() };
(
mockLoggingContentGenerator.getWrapped as ReturnType<typeof vi.fn>
).mockReturnValue(mockOtherGenerator);
const { result } = renderHook(() => usePrivacySettings(mockConfig));
await waitFor(() => {
expect(result.current.privacyState.isLoading).toBe(false);
});
expect(result.current.privacyState.error).toBe('Oauth not being used');
});
it('should throw error when CodeAssistServer has no projectId', async () => { it('should throw error when CodeAssistServer has no projectId', async () => {
// Mock CodeAssistServer without projectId vi.mocked(getCodeAssistServer).mockReturnValue({
const mockServerNoProject = { loadCodeAssist: () =>
...mockCodeAssistServer, ({
projectId: undefined, currentTier: { id: UserTierId.FREE },
}; }) as unknown as LoadCodeAssistResponse,
( } as unknown as CodeAssistServer);
mockLoggingContentGenerator.getWrapped as ReturnType<typeof vi.fn>
).mockReturnValue(mockServerNoProject);
const { result } = renderHook(() => usePrivacySettings(mockConfig)); const { result } = renderHook(() => usePrivacySettings(mockConfig));
@@ -215,10 +78,27 @@ describe('usePrivacySettings', () => {
expect(result.current.privacyState.isLoading).toBe(false); expect(result.current.privacyState.isLoading).toBe(false);
}); });
expect(result.current.privacyState.error).toBe('Oauth not being used'); expect(result.current.privacyState.error).toBe(
'CodeAssist server is missing a project ID',
);
}); });
it('should update data collection opt-in setting', async () => { it('should update data collection opt-in setting', async () => {
const mockCodeAssistServer = {
projectId: 'test-project-id',
getCodeAssistGlobalUserSetting: vi.fn().mockResolvedValue({
freeTierDataCollectionOptin: true,
}),
setCodeAssistGlobalUserSetting: vi.fn().mockResolvedValue({
freeTierDataCollectionOptin: false,
}),
loadCodeAssist: () =>
({
currentTier: { id: UserTierId.FREE },
}) as unknown as LoadCodeAssistResponse,
} as unknown as CodeAssistServer;
vi.mocked(getCodeAssistServer).mockReturnValue(mockCodeAssistServer);
const { result } = renderHook(() => usePrivacySettings(mockConfig)); const { result } = renderHook(() => usePrivacySettings(mockConfig));
// Wait for initial load // Wait for initial load
+10 -17
View File
@@ -5,11 +5,11 @@
*/ */
import { useState, useEffect, useCallback } from 'react'; import { useState, useEffect, useCallback } from 'react';
import type { Config } from '@google/gemini-cli-core';
import { import {
CodeAssistServer, type Config,
type CodeAssistServer,
UserTierId, UserTierId,
LoggingContentGenerator, getCodeAssistServer,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
export interface PrivacyState { export interface PrivacyState {
@@ -30,7 +30,7 @@ export const usePrivacySettings = (config: Config) => {
isLoading: true, isLoading: true,
}); });
try { try {
const server = getCodeAssistServer(config); const server = getCodeAssistServerOrFail(config);
const tier = await getTier(server); const tier = await getTier(server);
if (tier !== UserTierId.FREE) { if (tier !== UserTierId.FREE) {
// We don't need to fetch opt-out info since non-free tier // We don't need to fetch opt-out info since non-free tier
@@ -61,7 +61,7 @@ export const usePrivacySettings = (config: Config) => {
const updateDataCollectionOptIn = useCallback( const updateDataCollectionOptIn = useCallback(
async (optIn: boolean) => { async (optIn: boolean) => {
try { try {
const server = getCodeAssistServer(config); const server = getCodeAssistServerOrFail(config);
const updatedOptIn = await setRemoteDataCollectionOptIn(server, optIn); const updatedOptIn = await setRemoteDataCollectionOptIn(server, optIn);
setPrivacyState({ setPrivacyState({
isLoading: false, isLoading: false,
@@ -84,19 +84,12 @@ export const usePrivacySettings = (config: Config) => {
}; };
}; };
function getCodeAssistServer(config: Config): CodeAssistServer { function getCodeAssistServerOrFail(config: Config): CodeAssistServer {
let server = config.getGeminiClient().getContentGenerator(); const server = getCodeAssistServer(config);
if (server === undefined) {
// Unwrap LoggingContentGenerator if present
if (server instanceof LoggingContentGenerator) {
server = server.getWrapped();
}
// Neither of these cases should ever happen.
if (!(server instanceof CodeAssistServer)) {
throw new Error('Oauth not being used');
} else if (!server.projectId) {
throw new Error('Oauth not being used'); throw new Error('Oauth not being used');
} else if (server.projectId === undefined) {
throw new Error('CodeAssist server is missing a project ID');
} }
return server; return server;
} }
@@ -11,6 +11,7 @@ import { setupUser } from './setup.js';
import type { HttpOptions } from './server.js'; import type { HttpOptions } from './server.js';
import { CodeAssistServer } from './server.js'; import { CodeAssistServer } from './server.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import { LoggingContentGenerator } from '../core/loggingContentGenerator.js';
export async function createCodeAssistContentGenerator( export async function createCodeAssistContentGenerator(
httpOptions: HttpOptions, httpOptions: HttpOptions,
@@ -35,3 +36,19 @@ export async function createCodeAssistContentGenerator(
throw new Error(`Unsupported authType: ${authType}`); throw new Error(`Unsupported authType: ${authType}`);
} }
export function getCodeAssistServer(
config: Config,
): CodeAssistServer | undefined {
let server = config.getGeminiClient().getContentGenerator();
// Unwrap LoggingContentGenerator if present
if (server instanceof LoggingContentGenerator) {
server = server.getWrapped();
}
if (!(server instanceof CodeAssistServer)) {
return undefined;
}
return server;
}