diff --git a/packages/cli/src/ui/components/AppHeader.test.tsx b/packages/cli/src/ui/components/AppHeader.test.tsx index 4dbdbc0052..2fab08c8e6 100644 --- a/packages/cli/src/ui/components/AppHeader.test.tsx +++ b/packages/cli/src/ui/components/AppHeader.test.tsx @@ -10,15 +10,20 @@ import { } from '../../test-utils/render.js'; import type { LoadedSettings } from '../../config/settings.js'; import { AppHeader } from './AppHeader.js'; -import { describe, it, expect, vi } from 'vitest'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; import { makeFakeConfig } from '@google/gemini-cli-core'; import crypto from 'node:crypto'; +import { _clearSessionBannersForTest } from '../hooks/useBanner.js'; vi.mock('../utils/terminalSetup.js', () => ({ getTerminalProgram: () => null, })); describe('', () => { + beforeEach(() => { + _clearSessionBannersForTest(); + }); + it('should render the banner with default text', async () => { const uiState = { history: [], diff --git a/packages/cli/src/ui/components/__snapshots__/InputPrompt.test.tsx.snap b/packages/cli/src/ui/components/__snapshots__/InputPrompt.test.tsx.snap index 59050691d2..caa270d8c4 100644 --- a/packages/cli/src/ui/components/__snapshots__/InputPrompt.test.tsx.snap +++ b/packages/cli/src/ui/components/__snapshots__/InputPrompt.test.tsx.snap @@ -168,13 +168,6 @@ exports[`InputPrompt > mouse interaction > should toggle paste expansion on doub " `; -exports[`InputPrompt > mouse interaction > should toggle paste expansion on double-click 4`] = ` -"▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀ - > [Pasted Text: 10 lines] -▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄ -" -`; - exports[`InputPrompt > multiline rendering > should correctly render multiline input including blank lines 1`] = ` "──────────────────────────────────────────────────────────────────────────────────────────────────── │ > hello │ diff --git a/packages/cli/src/ui/hooks/useBanner.test.ts b/packages/cli/src/ui/hooks/useBanner.test.ts index ad2c3ce0d5..5712aecc91 100644 --- a/packages/cli/src/ui/hooks/useBanner.test.ts +++ b/packages/cli/src/ui/hooks/useBanner.test.ts @@ -13,7 +13,7 @@ import { type MockedFunction, } from 'vitest'; import { renderHook } from '../../test-utils/render.js'; -import { useBanner } from './useBanner.js'; +import { useBanner, _clearSessionBannersForTest } from './useBanner.js'; import { persistentState } from '../../utils/persistentState.js'; import crypto from 'node:crypto'; @@ -56,6 +56,7 @@ describe('useBanner', () => { beforeEach(() => { vi.resetAllMocks(); + _clearSessionBannersForTest(); // Default persistentState behavior: return empty object (no counts) mockedPersistentStateGet.mockReturnValue({}); @@ -101,13 +102,18 @@ describe('useBanner', () => { ); }); - it('should NOT increment count if warning text is shown instead', async () => { + it('should increment count if warning text is shown instead', async () => { const data = { defaultText: 'Standard', warningText: 'Warning' }; await renderHook(() => useBanner(data)); - // Since warning text takes precedence, default banner logic (and increment) is skipped - expect(mockedPersistentStateSet).not.toHaveBeenCalled(); + // Warning text now also gets counted + expect(mockedPersistentStateSet).toHaveBeenCalledWith( + 'defaultBannerShownCount', + { + [crypto.createHash('sha256').update(data.warningText).digest('hex')]: 1, + }, + ); }); it('should handle newline replacements', async () => { diff --git a/packages/cli/src/ui/hooks/useBanner.ts b/packages/cli/src/ui/hooks/useBanner.ts index ab6d0b6a51..ddd739ecd2 100644 --- a/packages/cli/src/ui/hooks/useBanner.ts +++ b/packages/cli/src/ui/hooks/useBanner.ts @@ -4,12 +4,21 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useState, useEffect, useRef } from 'react'; +import { useState, useEffect } from 'react'; import { persistentState } from '../../utils/persistentState.js'; import crypto from 'node:crypto'; const DEFAULT_MAX_BANNER_SHOWN_COUNT = 5; +// Track banners incremented during this session to prevent multiple increments +// on React unmounts/remounts +const sessionIncrementedBanners = new Set(); + +// For testing purposes +export function _clearSessionBannersForTest() { + sessionIncrementedBanners.clear(); +} + interface BannerData { defaultText: string; warningText: string; @@ -22,25 +31,25 @@ export function useBanner(bannerData: BannerData) { () => persistentState.get('defaultBannerShownCount') || {}, ); + const activeText = warningText ? warningText : defaultText; + const hashedText = crypto .createHash('sha256') - .update(defaultText) + .update(activeText) .digest('hex'); const currentBannerCount = bannerCounts[hashedText] || 0; - const showDefaultBanner = - warningText === '' && currentBannerCount < DEFAULT_MAX_BANNER_SHOWN_COUNT; + const showBanner = + activeText !== '' && currentBannerCount < DEFAULT_MAX_BANNER_SHOWN_COUNT; - const rawBannerText = showDefaultBanner ? defaultText : warningText; + const rawBannerText = showBanner ? activeText : ''; const bannerText = rawBannerText.replace(/\\n/g, '\n'); - const lastIncrementedKey = useRef(null); - useEffect(() => { - if (showDefaultBanner && defaultText) { - if (lastIncrementedKey.current !== defaultText) { - lastIncrementedKey.current = defaultText; + if (showBanner && activeText) { + if (!sessionIncrementedBanners.has(activeText)) { + sessionIncrementedBanners.add(activeText); const allCounts = persistentState.get('defaultBannerShownCount') || {}; const current = allCounts[hashedText] || 0; @@ -51,7 +60,7 @@ export function useBanner(bannerData: BannerData) { }); } } - }, [showDefaultBanner, defaultText, hashedText]); + }, [showBanner, activeText, hashedText]); return { bannerText,