diff --git a/packages/cli/src/config/settings.ts b/packages/cli/src/config/settings.ts index 4058c8a35e..6cf79312dc 100644 --- a/packages/cli/src/config/settings.ts +++ b/packages/cli/src/config/settings.ts @@ -45,6 +45,7 @@ const MIGRATION_MAP: Record = { hideFooter: 'ui.hideFooter', showMemoryUsage: 'ui.showMemoryUsage', showLineNumbers: 'ui.showLineNumbers', + showCitations: 'ui.showCitations', accessibility: 'ui.accessibility', ideMode: 'ide.enabled', hasSeenIdeIntegrationNudge: 'ide.hasSeenNudge', diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index c27d132772..9210f0b5c7 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -219,6 +219,15 @@ export const SETTINGS_SCHEMA = { description: 'Show line numbers in the chat.', showInDialog: true, }, + showCitations: { + type: 'boolean', + label: 'Show Citations', + category: 'UI', + requiresRestart: false, + default: false, + description: 'Show citations for generated text in the chat.', + showInDialog: true, + }, accessibility: { type: 'object', label: 'Accessibility', diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx index 65b5994878..55e16e0992 100644 --- a/packages/cli/src/ui/App.test.tsx +++ b/packages/cli/src/ui/App.test.tsx @@ -1609,6 +1609,7 @@ describe('App UI', () => { _history, _addItem, _config, + _settings, _onDebugMessage, _handleSlashCommand, _shellModeActive, diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index ff10918a77..5fbf632aab 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -639,6 +639,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { history, addItem, config, + settings, setDebugMessage, handleSlashCommand, shellModeActive, diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 4928b52148..09dfb713ac 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -277,6 +277,7 @@ describe('useGeminiStream', () => { props.history, props.addItem, props.config, + props.loadedSettings, props.onDebugMessage, props.handleSlashCommand, props.shellModeActive, @@ -438,6 +439,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -517,6 +519,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -625,6 +628,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -734,6 +738,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -863,6 +868,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -1174,6 +1180,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -1227,6 +1234,7 @@ describe('useGeminiStream', () => { [], mockAddItem, testConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -1277,6 +1285,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -1325,6 +1334,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -1374,6 +1384,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -1513,6 +1524,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, // shellModeActive @@ -1577,6 +1589,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -1655,6 +1668,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, @@ -1709,6 +1723,7 @@ describe('useGeminiStream', () => { [], mockAddItem, mockConfig, + mockLoadedSettings, mockOnDebugMessage, mockHandleSlashCommand, false, diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index b8004edbc7..d105a0449e 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -7,15 +7,15 @@ import { useState, useRef, useCallback, useEffect, useMemo } from 'react'; import type { Config, - GeminiClient, - ServerGeminiStreamEvent as GeminiEvent, - ServerGeminiContentEvent as ContentEvent, - ServerGeminiErrorEvent as ErrorEvent, - ServerGeminiChatCompressedEvent, - ServerGeminiFinishedEvent, - ToolCallRequestInfo, EditorType, + GeminiClient, + ServerGeminiChatCompressedEvent, + ServerGeminiContentEvent as ContentEvent, + ServerGeminiFinishedEvent, + ServerGeminiStreamEvent as GeminiEvent, ThoughtSummary, + ToolCallRequestInfo, + GeminiErrorEventValue, } from '@google/gemini-cli-core'; import { GeminiEventType as ServerGeminiEventType, @@ -60,6 +60,7 @@ import { } from './useReactToolScheduler.js'; import { useSessionStats } from '../contexts/SessionContext.js'; import { useKeypress } from './useKeypress.js'; +import type { LoadedSettings } from '../../config/settings.js'; enum StreamProcessingStatus { Completed, @@ -76,6 +77,7 @@ export const useGeminiStream = ( history: HistoryItem[], addItem: UseHistoryManagerReturn['addItem'], config: Config, + settings: LoadedSettings, onDebugMessage: (message: string) => void, handleSlashCommand: ( cmd: PartListUnion, @@ -463,7 +465,7 @@ export const useGeminiStream = ( ); const handleErrorEvent = useCallback( - (eventValue: ErrorEvent['value'], userMessageTimestamp: number) => { + (eventValue: GeminiErrorEventValue, userMessageTimestamp: number) => { if (pendingHistoryItemRef.current) { addItem(pendingHistoryItemRef.current, userMessageTimestamp); setPendingHistoryItem(null); @@ -486,6 +488,20 @@ export const useGeminiStream = ( [addItem, pendingHistoryItemRef, setPendingHistoryItem, config, setThought], ); + const handleCitationEvent = useCallback( + (text: string, userMessageTimestamp: number) => { + if (!settings?.merged?.ui?.showCitations) { + return; + } + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + setPendingHistoryItem(null); + } + addItem({ type: MessageType.INFO, text }, userMessageTimestamp); + }, + [addItem, pendingHistoryItemRef, setPendingHistoryItem, settings], + ); + const handleFinishedEvent = useCallback( (event: ServerGeminiFinishedEvent, userMessageTimestamp: number) => { const finishReason = event.value; @@ -611,6 +627,9 @@ export const useGeminiStream = ( userMessageTimestamp, ); break; + case ServerGeminiEventType.Citation: + handleCitationEvent(event.value, userMessageTimestamp); + break; case ServerGeminiEventType.LoopDetected: // handle later because we want to move pending history to history // before we add loop detected message to history @@ -636,6 +655,7 @@ export const useGeminiStream = ( handleChatCompressionEvent, handleFinishedEvent, handleMaxSessionTurnsEvent, + handleCitationEvent, ], ); diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index c749babab8..74e601e5df 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -445,6 +445,170 @@ describe('Turn', () => { ]); }); + it('should yield citation and finished events when response has citationMetadata', async () => { + const mockResponseStream = (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Some text.' }] }, + citationMetadata: { + citations: [ + { + uri: 'https://example.com/source1', + title: 'Source 1 Title', + }, + ], + }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + for await (const event of turn.run( + [{ text: 'Test citations' }], + new AbortController().signal, + )) { + events.push(event); + } + + expect(events).toEqual([ + { type: GeminiEventType.Content, value: 'Some text.' }, + { + type: GeminiEventType.Citation, + value: 'Citations:\n(Source 1 Title) https://example.com/source1', + }, + { type: GeminiEventType.Finished, value: 'STOP' }, + ]); + }); + + it('should yield a single citation event for multiple citations in one response', async () => { + const mockResponseStream = (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Some text.' }] }, + citationMetadata: { + citations: [ + { + uri: 'https://example.com/source2', + title: 'Title2', + }, + { + uri: 'https://example.com/source1', + title: 'Title1', + }, + ], + }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + for await (const event of turn.run( + [{ text: 'test' }], + new AbortController().signal, + )) { + events.push(event); + } + + expect(events).toEqual([ + { type: GeminiEventType.Content, value: 'Some text.' }, + { + type: GeminiEventType.Citation, + value: + 'Citations:\n(Title1) https://example.com/source1\n(Title2) https://example.com/source2', + }, + { type: GeminiEventType.Finished, value: 'STOP' }, + ]); + }); + + it('should not yield citation event if there is no finish reason', async () => { + const mockResponseStream = (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Some text.' }] }, + citationMetadata: { + citations: [ + { + uri: 'https://example.com/source1', + title: 'Source 1 Title', + }, + ], + }, + // No finishReason + }, + ], + } as unknown as GenerateContentResponse; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + for await (const event of turn.run( + [{ text: 'test' }], + new AbortController().signal, + )) { + events.push(event); + } + + expect(events).toEqual([ + { type: GeminiEventType.Content, value: 'Some text.' }, + ]); + // No Citation or Finished event + expect(events.some((e) => e.type === GeminiEventType.Citation)).toBe( + false, + ); + }); + + it('should ignore citations without a URI', async () => { + const mockResponseStream = (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Some text.' }] }, + citationMetadata: { + citations: [ + { + uri: 'https://example.com/source1', + title: 'Good Source', + }, + { + // uri is undefined + title: 'Bad Source', + }, + ], + }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + for await (const event of turn.run( + [{ text: 'test' }], + new AbortController().signal, + )) { + events.push(event); + } + + expect(events).toEqual([ + { type: GeminiEventType.Content, value: 'Some text.' }, + { + type: GeminiEventType.Citation, + value: 'Citations:\n(Good Source) https://example.com/source1', + }, + { type: GeminiEventType.Finished, value: 'STOP' }, + ]); + }); + it('should not crash when cancelled request has malformed error', async () => { const abortController = new AbortController(); diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index ab46ee00ab..f21138c7b4 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -54,6 +54,7 @@ export enum GeminiEventType { MaxSessionTurns = 'max_session_turns', Finished = 'finished', LoopDetected = 'loop_detected', + Citation = 'citation', } export interface StructuredError { @@ -163,34 +164,37 @@ export type ServerGeminiLoopDetectedEvent = { type: GeminiEventType.LoopDetected; }; +export type ServerGeminiCitationEvent = { + type: GeminiEventType.Citation; + value: string; +}; + // The original union type, now composed of the individual types export type ServerGeminiStreamEvent = + | ServerGeminiChatCompressedEvent + | ServerGeminiCitationEvent | ServerGeminiContentEvent + | ServerGeminiErrorEvent + | ServerGeminiFinishedEvent + | ServerGeminiLoopDetectedEvent + | ServerGeminiMaxSessionTurnsEvent + | ServerGeminiThoughtEvent + | ServerGeminiToolCallConfirmationEvent | ServerGeminiToolCallRequestEvent | ServerGeminiToolCallResponseEvent - | ServerGeminiToolCallConfirmationEvent - | ServerGeminiUserCancelledEvent - | ServerGeminiErrorEvent - | ServerGeminiChatCompressedEvent - | ServerGeminiThoughtEvent - | ServerGeminiMaxSessionTurnsEvent - | ServerGeminiFinishedEvent - | ServerGeminiLoopDetectedEvent; + | ServerGeminiUserCancelledEvent; // A turn manages the agentic loop turn within the server context. export class Turn { - readonly pendingToolCalls: ToolCallRequestInfo[]; - private debugResponses: GenerateContentResponse[]; - finishReason: FinishReason | undefined; + readonly pendingToolCalls: ToolCallRequestInfo[] = []; + private debugResponses: GenerateContentResponse[] = []; + private pendingCitations = new Set(); + finishReason: FinishReason | undefined = undefined; constructor( private readonly chat: GeminiChat, private readonly prompt_id: string, - ) { - this.pendingToolCalls = []; - this.debugResponses = []; - this.finishReason = undefined; - } + ) {} // The run method yields simpler events suitable for server logic async *run( req: PartListUnion, @@ -251,10 +255,22 @@ export class Turn { } } + for (const citation of getCitations(resp)) { + this.pendingCitations.add(citation); + } + // Check if response was truncated or stopped for various reasons const finishReason = resp.candidates?.[0]?.finishReason; if (finishReason) { + if (this.pendingCitations.size > 0) { + yield { + type: GeminiEventType.Citation, + value: `Citations:\n${[...this.pendingCitations].sort().join('\n')}`, + }; + this.pendingCitations.clear(); + } + this.finishReason = finishReason; yield { type: GeminiEventType.Finished, @@ -325,3 +341,14 @@ export class Turn { return this.debugResponses; } } + +function getCitations(resp: GenerateContentResponse): string[] { + return (resp.candidates?.[0]?.citationMetadata?.citations ?? []) + .filter((citation) => citation.uri !== undefined) + .map((citation) => { + if (citation.title) { + return `(${citation.title}) ${citation.uri}`; + } + return citation.uri!; + }); +}