Show citations at the end of each turn (#7350)

This commit is contained in:
Tommaso Sciortino
2025-08-28 16:42:54 -07:00
committed by GitHub
parent e02cb6c6a3
commit 95db24f234
8 changed files with 262 additions and 24 deletions
+1
View File
@@ -45,6 +45,7 @@ const MIGRATION_MAP: Record<string, string> = {
hideFooter: 'ui.hideFooter',
showMemoryUsage: 'ui.showMemoryUsage',
showLineNumbers: 'ui.showLineNumbers',
showCitations: 'ui.showCitations',
accessibility: 'ui.accessibility',
ideMode: 'ide.enabled',
hasSeenIdeIntegrationNudge: 'ide.hasSeenNudge',
@@ -219,6 +219,15 @@ export const SETTINGS_SCHEMA = {
description: 'Show line numbers in the chat.',
showInDialog: true,
},
showCitations: {
type: 'boolean',
label: 'Show Citations',
category: 'UI',
requiresRestart: false,
default: false,
description: 'Show citations for generated text in the chat.',
showInDialog: true,
},
accessibility: {
type: 'object',
label: 'Accessibility',
+1
View File
@@ -1609,6 +1609,7 @@ describe('App UI', () => {
_history,
_addItem,
_config,
_settings,
_onDebugMessage,
_handleSlashCommand,
_shellModeActive,
+1
View File
@@ -639,6 +639,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
history,
addItem,
config,
settings,
setDebugMessage,
handleSlashCommand,
shellModeActive,
@@ -277,6 +277,7 @@ describe('useGeminiStream', () => {
props.history,
props.addItem,
props.config,
props.loadedSettings,
props.onDebugMessage,
props.handleSlashCommand,
props.shellModeActive,
@@ -438,6 +439,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -517,6 +519,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -625,6 +628,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -734,6 +738,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -863,6 +868,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -1174,6 +1180,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -1227,6 +1234,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
testConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -1277,6 +1285,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -1325,6 +1334,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -1374,6 +1384,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -1513,6 +1524,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false, // shellModeActive
@@ -1577,6 +1589,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -1655,6 +1668,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
@@ -1709,6 +1723,7 @@ describe('useGeminiStream', () => {
[],
mockAddItem,
mockConfig,
mockLoadedSettings,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
+28 -8
View File
@@ -7,15 +7,15 @@
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import type {
Config,
GeminiClient,
ServerGeminiStreamEvent as GeminiEvent,
ServerGeminiContentEvent as ContentEvent,
ServerGeminiErrorEvent as ErrorEvent,
ServerGeminiChatCompressedEvent,
ServerGeminiFinishedEvent,
ToolCallRequestInfo,
EditorType,
GeminiClient,
ServerGeminiChatCompressedEvent,
ServerGeminiContentEvent as ContentEvent,
ServerGeminiFinishedEvent,
ServerGeminiStreamEvent as GeminiEvent,
ThoughtSummary,
ToolCallRequestInfo,
GeminiErrorEventValue,
} from '@google/gemini-cli-core';
import {
GeminiEventType as ServerGeminiEventType,
@@ -60,6 +60,7 @@ import {
} from './useReactToolScheduler.js';
import { useSessionStats } from '../contexts/SessionContext.js';
import { useKeypress } from './useKeypress.js';
import type { LoadedSettings } from '../../config/settings.js';
enum StreamProcessingStatus {
Completed,
@@ -76,6 +77,7 @@ export const useGeminiStream = (
history: HistoryItem[],
addItem: UseHistoryManagerReturn['addItem'],
config: Config,
settings: LoadedSettings,
onDebugMessage: (message: string) => void,
handleSlashCommand: (
cmd: PartListUnion,
@@ -463,7 +465,7 @@ export const useGeminiStream = (
);
const handleErrorEvent = useCallback(
(eventValue: ErrorEvent['value'], userMessageTimestamp: number) => {
(eventValue: GeminiErrorEventValue, userMessageTimestamp: number) => {
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null);
@@ -486,6 +488,20 @@ export const useGeminiStream = (
[addItem, pendingHistoryItemRef, setPendingHistoryItem, config, setThought],
);
const handleCitationEvent = useCallback(
(text: string, userMessageTimestamp: number) => {
if (!settings?.merged?.ui?.showCitations) {
return;
}
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null);
}
addItem({ type: MessageType.INFO, text }, userMessageTimestamp);
},
[addItem, pendingHistoryItemRef, setPendingHistoryItem, settings],
);
const handleFinishedEvent = useCallback(
(event: ServerGeminiFinishedEvent, userMessageTimestamp: number) => {
const finishReason = event.value;
@@ -611,6 +627,9 @@ export const useGeminiStream = (
userMessageTimestamp,
);
break;
case ServerGeminiEventType.Citation:
handleCitationEvent(event.value, userMessageTimestamp);
break;
case ServerGeminiEventType.LoopDetected:
// handle later because we want to move pending history to history
// before we add loop detected message to history
@@ -636,6 +655,7 @@ export const useGeminiStream = (
handleChatCompressionEvent,
handleFinishedEvent,
handleMaxSessionTurnsEvent,
handleCitationEvent,
],
);
+164
View File
@@ -445,6 +445,170 @@ describe('Turn', () => {
]);
});
it('should yield citation and finished events when response has citationMetadata', async () => {
const mockResponseStream = (async function* () {
yield {
candidates: [
{
content: { parts: [{ text: 'Some text.' }] },
citationMetadata: {
citations: [
{
uri: 'https://example.com/source1',
title: 'Source 1 Title',
},
],
},
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
for await (const event of turn.run(
[{ text: 'Test citations' }],
new AbortController().signal,
)) {
events.push(event);
}
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Some text.' },
{
type: GeminiEventType.Citation,
value: 'Citations:\n(Source 1 Title) https://example.com/source1',
},
{ type: GeminiEventType.Finished, value: 'STOP' },
]);
});
it('should yield a single citation event for multiple citations in one response', async () => {
const mockResponseStream = (async function* () {
yield {
candidates: [
{
content: { parts: [{ text: 'Some text.' }] },
citationMetadata: {
citations: [
{
uri: 'https://example.com/source2',
title: 'Title2',
},
{
uri: 'https://example.com/source1',
title: 'Title1',
},
],
},
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
for await (const event of turn.run(
[{ text: 'test' }],
new AbortController().signal,
)) {
events.push(event);
}
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Some text.' },
{
type: GeminiEventType.Citation,
value:
'Citations:\n(Title1) https://example.com/source1\n(Title2) https://example.com/source2',
},
{ type: GeminiEventType.Finished, value: 'STOP' },
]);
});
it('should not yield citation event if there is no finish reason', async () => {
const mockResponseStream = (async function* () {
yield {
candidates: [
{
content: { parts: [{ text: 'Some text.' }] },
citationMetadata: {
citations: [
{
uri: 'https://example.com/source1',
title: 'Source 1 Title',
},
],
},
// No finishReason
},
],
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
for await (const event of turn.run(
[{ text: 'test' }],
new AbortController().signal,
)) {
events.push(event);
}
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Some text.' },
]);
// No Citation or Finished event
expect(events.some((e) => e.type === GeminiEventType.Citation)).toBe(
false,
);
});
it('should ignore citations without a URI', async () => {
const mockResponseStream = (async function* () {
yield {
candidates: [
{
content: { parts: [{ text: 'Some text.' }] },
citationMetadata: {
citations: [
{
uri: 'https://example.com/source1',
title: 'Good Source',
},
{
// uri is undefined
title: 'Bad Source',
},
],
},
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
for await (const event of turn.run(
[{ text: 'test' }],
new AbortController().signal,
)) {
events.push(event);
}
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Some text.' },
{
type: GeminiEventType.Citation,
value: 'Citations:\n(Good Source) https://example.com/source1',
},
{ type: GeminiEventType.Finished, value: 'STOP' },
]);
});
it('should not crash when cancelled request has malformed error', async () => {
const abortController = new AbortController();
+43 -16
View File
@@ -54,6 +54,7 @@ export enum GeminiEventType {
MaxSessionTurns = 'max_session_turns',
Finished = 'finished',
LoopDetected = 'loop_detected',
Citation = 'citation',
}
export interface StructuredError {
@@ -163,34 +164,37 @@ export type ServerGeminiLoopDetectedEvent = {
type: GeminiEventType.LoopDetected;
};
export type ServerGeminiCitationEvent = {
type: GeminiEventType.Citation;
value: string;
};
// The original union type, now composed of the individual types
export type ServerGeminiStreamEvent =
| ServerGeminiChatCompressedEvent
| ServerGeminiCitationEvent
| ServerGeminiContentEvent
| ServerGeminiErrorEvent
| ServerGeminiFinishedEvent
| ServerGeminiLoopDetectedEvent
| ServerGeminiMaxSessionTurnsEvent
| ServerGeminiThoughtEvent
| ServerGeminiToolCallConfirmationEvent
| ServerGeminiToolCallRequestEvent
| ServerGeminiToolCallResponseEvent
| ServerGeminiToolCallConfirmationEvent
| ServerGeminiUserCancelledEvent
| ServerGeminiErrorEvent
| ServerGeminiChatCompressedEvent
| ServerGeminiThoughtEvent
| ServerGeminiMaxSessionTurnsEvent
| ServerGeminiFinishedEvent
| ServerGeminiLoopDetectedEvent;
| ServerGeminiUserCancelledEvent;
// A turn manages the agentic loop turn within the server context.
export class Turn {
readonly pendingToolCalls: ToolCallRequestInfo[];
private debugResponses: GenerateContentResponse[];
finishReason: FinishReason | undefined;
readonly pendingToolCalls: ToolCallRequestInfo[] = [];
private debugResponses: GenerateContentResponse[] = [];
private pendingCitations = new Set<string>();
finishReason: FinishReason | undefined = undefined;
constructor(
private readonly chat: GeminiChat,
private readonly prompt_id: string,
) {
this.pendingToolCalls = [];
this.debugResponses = [];
this.finishReason = undefined;
}
) {}
// The run method yields simpler events suitable for server logic
async *run(
req: PartListUnion,
@@ -251,10 +255,22 @@ export class Turn {
}
}
for (const citation of getCitations(resp)) {
this.pendingCitations.add(citation);
}
// Check if response was truncated or stopped for various reasons
const finishReason = resp.candidates?.[0]?.finishReason;
if (finishReason) {
if (this.pendingCitations.size > 0) {
yield {
type: GeminiEventType.Citation,
value: `Citations:\n${[...this.pendingCitations].sort().join('\n')}`,
};
this.pendingCitations.clear();
}
this.finishReason = finishReason;
yield {
type: GeminiEventType.Finished,
@@ -325,3 +341,14 @@ export class Turn {
return this.debugResponses;
}
}
function getCitations(resp: GenerateContentResponse): string[] {
return (resp.candidates?.[0]?.citationMetadata?.citations ?? [])
.filter((citation) => citation.uri !== undefined)
.map((citation) => {
if (citation.title) {
return `(${citation.title}) ${citation.uri}`;
}
return citation.uri!;
});
}