Switch to a reducer for tracking update state fixing flicker issues due to continuous renders (#10280)

This commit is contained in:
Jacob Richman
2025-10-01 14:53:15 -07:00
committed by GitHub
parent ef76a801c4
commit a404fb8d2e
13 changed files with 599 additions and 361 deletions

View File

@@ -91,11 +91,20 @@ export async function handleUpdate(args: UpdateArgs) {
}
if (args.all) {
try {
const extensionState = new Map();
await checkForAllExtensionUpdates(extensions, (action) => {
if (action.type === 'SET_STATE') {
extensionState.set(action.payload.name, {
status: action.payload.state,
processed: true, // No need to process as we will force the update.
});
}
});
let updateInfos = await updateAllUpdatableExtensions(
workingDir,
requestConsentNonInteractive,
extensions,
await checkForAllExtensionUpdates(extensions, new Map(), (_) => {}),
extensionState,
() => {},
);
updateInfos = updateInfos.filter(

View File

@@ -185,8 +185,7 @@ describe('update tests', () => {
});
mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]);
const setExtensionUpdateState = vi.fn();
const dispatch = vi.fn();
const extension = annotateActiveExtensions(
[
loadExtension({
@@ -202,15 +201,23 @@ describe('update tests', () => {
tempHomeDir,
async (_) => true,
ExtensionUpdateState.UPDATE_AVAILABLE,
setExtensionUpdateState,
dispatch,
);
expect(setExtensionUpdateState).toHaveBeenCalledWith(
ExtensionUpdateState.UPDATING,
);
expect(setExtensionUpdateState).toHaveBeenCalledWith(
ExtensionUpdateState.UPDATED_NEEDS_RESTART,
);
expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE',
payload: {
name: extensionName,
state: ExtensionUpdateState.UPDATING,
},
});
expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE',
payload: {
name: extensionName,
state: ExtensionUpdateState.UPDATED_NEEDS_RESTART,
},
});
});
it('should call setExtensionUpdateState with ERROR on failure', async () => {
@@ -228,7 +235,7 @@ describe('update tests', () => {
mockGit.clone.mockRejectedValue(new Error('Git clone failed'));
mockGit.getRemotes.mockResolvedValue([{ name: 'origin' }]);
const setExtensionUpdateState = vi.fn();
const dispatch = vi.fn();
const extension = annotateActiveExtensions(
[
loadExtension({
@@ -245,16 +252,24 @@ describe('update tests', () => {
tempHomeDir,
async (_) => true,
ExtensionUpdateState.UPDATE_AVAILABLE,
setExtensionUpdateState,
dispatch,
),
).rejects.toThrow();
expect(setExtensionUpdateState).toHaveBeenCalledWith(
ExtensionUpdateState.UPDATING,
);
expect(setExtensionUpdateState).toHaveBeenCalledWith(
ExtensionUpdateState.ERROR,
);
expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE',
payload: {
name: extensionName,
state: ExtensionUpdateState.UPDATING,
},
});
expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE',
payload: {
name: extensionName,
state: ExtensionUpdateState.ERROR,
},
});
});
});
@@ -286,20 +301,15 @@ describe('update tests', () => {
mockGit.listRemote.mockResolvedValue('remoteHash HEAD');
mockGit.revparse.mockResolvedValue('localHash');
let extensionState = new Map();
const results = await checkForAllExtensionUpdates(
[extension],
extensionState,
(newState) => {
if (typeof newState === 'function') {
newState(extensionState);
} else {
extensionState = newState;
}
const dispatch = vi.fn();
await checkForAllExtensionUpdates([extension], dispatch);
expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE',
payload: {
name: 'test-extension',
state: ExtensionUpdateState.UPDATE_AVAILABLE,
},
);
const result = results.get('test-extension');
expect(result).toBe(ExtensionUpdateState.UPDATE_AVAILABLE);
});
});
it('should return UpToDate for a git extension with no updates', async () => {
@@ -329,20 +339,15 @@ describe('update tests', () => {
mockGit.listRemote.mockResolvedValue('sameHash HEAD');
mockGit.revparse.mockResolvedValue('sameHash');
let extensionState = new Map();
const results = await checkForAllExtensionUpdates(
[extension],
extensionState,
(newState) => {
if (typeof newState === 'function') {
newState(extensionState);
} else {
extensionState = newState;
}
const dispatch = vi.fn();
await checkForAllExtensionUpdates([extension], dispatch);
expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE',
payload: {
name: 'test-extension',
state: ExtensionUpdateState.UP_TO_DATE,
},
);
const result = results.get('test-extension');
expect(result).toBe(ExtensionUpdateState.UP_TO_DATE);
});
});
it('should return UpToDate for a local extension with no updates', async () => {
@@ -369,21 +374,15 @@ describe('update tests', () => {
process.cwd(),
new ExtensionEnablementManager(ExtensionStorage.getUserExtensionsDir()),
)[0];
let extensionState = new Map();
const results = await checkForAllExtensionUpdates(
[extension],
extensionState,
(newState) => {
if (typeof newState === 'function') {
newState(extensionState);
} else {
extensionState = newState;
}
const dispatch = vi.fn();
await checkForAllExtensionUpdates([extension], dispatch);
expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE',
payload: {
name: 'local-extension',
state: ExtensionUpdateState.UP_TO_DATE,
},
tempWorkspaceDir,
);
const result = results.get('local-extension');
expect(result).toBe(ExtensionUpdateState.UP_TO_DATE);
});
});
it('should return UpdateAvailable for a local extension with updates', async () => {
@@ -410,21 +409,15 @@ describe('update tests', () => {
process.cwd(),
new ExtensionEnablementManager(ExtensionStorage.getUserExtensionsDir()),
)[0];
let extensionState = new Map();
const results = await checkForAllExtensionUpdates(
[extension],
extensionState,
(newState) => {
if (typeof newState === 'function') {
newState(extensionState);
} else {
extensionState = newState;
}
const dispatch = vi.fn();
await checkForAllExtensionUpdates([extension], dispatch);
expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE',
payload: {
name: 'local-extension',
state: ExtensionUpdateState.UPDATE_AVAILABLE,
},
tempWorkspaceDir,
);
const result = results.get('local-extension');
expect(result).toBe(ExtensionUpdateState.UPDATE_AVAILABLE);
});
});
it('should return Error when git check fails', async () => {
@@ -450,20 +443,15 @@ describe('update tests', () => {
mockGit.getRemotes.mockRejectedValue(new Error('Git error'));
let extensionState = new Map();
const results = await checkForAllExtensionUpdates(
[extension],
extensionState,
(newState) => {
if (typeof newState === 'function') {
newState(extensionState);
} else {
extensionState = newState;
}
const dispatch = vi.fn();
await checkForAllExtensionUpdates([extension], dispatch);
expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE',
payload: {
name: 'error-extension',
state: ExtensionUpdateState.ERROR,
},
);
const result = results.get('error-extension');
expect(result).toBe(ExtensionUpdateState.ERROR);
});
});
});
});

View File

@@ -4,11 +4,11 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { GeminiCLIExtension } from '@google/gemini-cli-core';
import * as fs from 'node:fs';
import { getErrorMessage } from '../../utils/errors.js';
import { ExtensionUpdateState } from '../../ui/state/extensions.js';
import { type Dispatch, type SetStateAction } from 'react';
import {
type ExtensionUpdateAction,
ExtensionUpdateState,
type ExtensionUpdateStatus,
} from '../../ui/state/extensions.js';
import {
copyExtension,
installExtension,
@@ -19,6 +19,9 @@ import {
loadExtensionConfig,
} from '../extension.js';
import { checkForExtensionUpdate } from './github.js';
import type { GeminiCLIExtension } from '@google/gemini-cli-core';
import * as fs from 'node:fs';
import { getErrorMessage } from '../../utils/errors.js';
export interface ExtensionUpdateInfo {
name: string;
@@ -31,22 +34,31 @@ export async function updateExtension(
cwd: string = process.cwd(),
requestConsent: (consent: string) => Promise<boolean>,
currentState: ExtensionUpdateState,
setExtensionUpdateState: (updateState: ExtensionUpdateState) => void,
dispatchExtensionStateUpdate: (action: ExtensionUpdateAction) => void,
): Promise<ExtensionUpdateInfo | undefined> {
if (currentState === ExtensionUpdateState.UPDATING) {
return undefined;
}
setExtensionUpdateState(ExtensionUpdateState.UPDATING);
dispatchExtensionStateUpdate({
type: 'SET_STATE',
payload: { name: extension.name, state: ExtensionUpdateState.UPDATING },
});
const installMetadata = loadInstallMetadata(extension.path);
if (!installMetadata?.type) {
setExtensionUpdateState(ExtensionUpdateState.ERROR);
dispatchExtensionStateUpdate({
type: 'SET_STATE',
payload: { name: extension.name, state: ExtensionUpdateState.ERROR },
});
throw new Error(
`Extension ${extension.name} cannot be updated, type is unknown.`,
);
}
if (installMetadata?.type === 'link') {
setExtensionUpdateState(ExtensionUpdateState.UP_TO_DATE);
dispatchExtensionStateUpdate({
type: 'SET_STATE',
payload: { name: extension.name, state: ExtensionUpdateState.UP_TO_DATE },
});
throw new Error(`Extension is linked so does not need to be updated`);
}
const originalVersion = extension.version;
@@ -72,11 +84,20 @@ export async function updateExtension(
workspaceDir: cwd,
});
if (!updatedExtension) {
setExtensionUpdateState(ExtensionUpdateState.ERROR);
dispatchExtensionStateUpdate({
type: 'SET_STATE',
payload: { name: extension.name, state: ExtensionUpdateState.ERROR },
});
throw new Error('Updated extension not found after installation.');
}
const updatedVersion = updatedExtension.config.version;
setExtensionUpdateState(ExtensionUpdateState.UPDATED_NEEDS_RESTART);
dispatchExtensionStateUpdate({
type: 'SET_STATE',
payload: {
name: extension.name,
state: ExtensionUpdateState.UPDATED_NEEDS_RESTART,
},
});
return {
name: extension.name,
originalVersion,
@@ -86,7 +107,10 @@ export async function updateExtension(
console.error(
`Error updating extension, rolling back. ${getErrorMessage(e)}`,
);
setExtensionUpdateState(ExtensionUpdateState.ERROR);
dispatchExtensionStateUpdate({
type: 'SET_STATE',
payload: { name: extension.name, state: ExtensionUpdateState.ERROR },
});
await copyExtension(tempDir, extension.path);
throw e;
} finally {
@@ -98,17 +122,15 @@ export async function updateAllUpdatableExtensions(
cwd: string = process.cwd(),
requestConsent: (consent: string) => Promise<boolean>,
extensions: GeminiCLIExtension[],
extensionsState: Map<string, ExtensionUpdateState>,
setExtensionsUpdateState: Dispatch<
SetStateAction<Map<string, ExtensionUpdateState>>
>,
extensionsState: Map<string, ExtensionUpdateStatus>,
dispatch: (action: ExtensionUpdateAction) => void,
): Promise<ExtensionUpdateInfo[]> {
return (
await Promise.all(
extensions
.filter(
(extension) =>
extensionsState.get(extension.name) ===
extensionsState.get(extension.name)?.status ===
ExtensionUpdateState.UPDATE_AVAILABLE,
)
.map((extension) =>
@@ -116,14 +138,8 @@ export async function updateAllUpdatableExtensions(
extension,
cwd,
requestConsent,
extensionsState.get(extension.name)!,
(updateState) => {
setExtensionsUpdateState((prev) => {
const finalState = new Map(prev);
finalState.set(extension.name, updateState);
return finalState;
});
},
extensionsState.get(extension.name)!.status,
dispatch,
),
),
)
@@ -137,38 +153,30 @@ export interface ExtensionUpdateCheckResult {
export async function checkForAllExtensionUpdates(
extensions: GeminiCLIExtension[],
extensionsUpdateState: Map<string, ExtensionUpdateState>,
setExtensionsUpdateState: Dispatch<
SetStateAction<Map<string, ExtensionUpdateState>>
>,
cwd: string = process.cwd(),
): Promise<Map<string, ExtensionUpdateState>> {
let newStates: Map<string, ExtensionUpdateState> = new Map(
extensionsUpdateState,
);
dispatch: (action: ExtensionUpdateAction) => void,
): Promise<void> {
dispatch({ type: 'BATCH_CHECK_START' });
const promises: Array<Promise<void>> = [];
for (const extension of extensions) {
const initialState = extensionsUpdateState.get(extension.name);
if (initialState === undefined) {
if (!extension.installMetadata) {
setExtensionsUpdateState((prev) => {
newStates = new Map(prev);
newStates.set(extension.name, ExtensionUpdateState.NOT_UPDATABLE);
return newStates;
});
continue;
}
await checkForExtensionUpdate(
extension,
(updatedState) => {
setExtensionsUpdateState((prev) => {
newStates = new Map(prev);
newStates.set(extension.name, updatedState);
return newStates;
});
if (!extension.installMetadata) {
dispatch({
type: 'SET_STATE',
payload: {
name: extension.name,
state: ExtensionUpdateState.NOT_UPDATABLE,
},
cwd,
);
});
continue;
}
promises.push(
checkForExtensionUpdate(extension, (updatedState) => {
dispatch({
type: 'SET_STATE',
payload: { name: extension.name, state: updatedState },
});
}),
);
}
return newStates;
await Promise.all(promises);
dispatch({ type: 'BATCH_CHECK_END' });
}

View File

@@ -155,7 +155,8 @@ export const AppContainer = (props: AppContainerProps) => {
const extensions = config.getExtensions();
const {
extensionsUpdateState,
setExtensionsUpdateState,
extensionsUpdateStateInternal,
dispatchExtensionStateUpdate,
confirmUpdateExtensionRequests,
addConfirmUpdateExtensionRequest,
} = useExtensionUpdates(
@@ -459,7 +460,7 @@ Logging in with Google... Please restart Gemini CLI to continue.
},
setDebugMessage,
toggleCorgiMode: () => setCorgiMode((prev) => !prev),
setExtensionsUpdateState,
dispatchExtensionStateUpdate,
addConfirmUpdateExtensionRequest,
}),
[
@@ -472,7 +473,7 @@ Logging in with Google... Please restart Gemini CLI to continue.
setDebugMessage,
setShowPrivacyNotice,
setCorgiMode,
setExtensionsUpdateState,
dispatchExtensionStateUpdate,
openPermissionsDialog,
addConfirmUpdateExtensionRequest,
],
@@ -496,7 +497,7 @@ Logging in with Google... Please restart Gemini CLI to continue.
setIsProcessing,
setGeminiMdFileCount,
slashCommandActions,
extensionsUpdateState,
extensionsUpdateStateInternal,
isConfigInitialized,
);

View File

@@ -26,6 +26,7 @@ import { ExtensionUpdateState } from '../state/extensions.js';
vi.mock('../../config/extensions/update.js', () => ({
updateExtension: vi.fn(),
updateAllUpdatableExtensions: vi.fn(),
checkForAllExtensionUpdates: vi.fn(),
}));
const mockUpdateExtension = updateExtension as MockedFunction<
@@ -51,6 +52,9 @@ describe('extensionsCommand', () => {
getWorkingDir: () => '/test/dir',
},
},
ui: {
dispatchExtensionStateUpdate: vi.fn(),
},
});
});
@@ -168,10 +172,10 @@ describe('extensionsCommand', () => {
updatedVersion: '1.0.1',
});
mockGetExtensions.mockReturnValue([extension]);
mockContext.ui.extensionsUpdateState.set(
extension.name,
ExtensionUpdateState.UPDATE_AVAILABLE,
);
mockContext.ui.extensionsUpdateState.set(extension.name, {
status: ExtensionUpdateState.UPDATE_AVAILABLE,
processed: false,
});
await updateAction(mockContext, 'ext-one');
expect(mockUpdateExtension).toHaveBeenCalledWith(
extension,

View File

@@ -9,6 +9,7 @@ import {
updateAllUpdatableExtensions,
type ExtensionUpdateInfo,
updateExtension,
checkForAllExtensionUpdates,
} from '../../config/extensions/update.js';
import { getErrorMessage } from '../../utils/errors.js';
import { ExtensionUpdateState } from '../state/extensions.js';
@@ -46,6 +47,10 @@ async function updateAction(context: CommandContext, args: string) {
}
try {
await checkForAllExtensionUpdates(
context.services.config!.getExtensions(),
context.ui.dispatchExtensionStateUpdate,
);
context.ui.setPendingItem({
type: MessageType.EXTENSIONS_LIST,
});
@@ -60,7 +65,7 @@ async function updateAction(context: CommandContext, args: string) {
),
context.services.config!.getExtensions(),
context.ui.extensionsUpdateState,
context.ui.setExtensionsUpdateState,
context.ui.dispatchExtensionStateUpdate,
);
} else if (names?.length) {
const workingDir = context.services.config!.getWorkingDir();
@@ -87,15 +92,9 @@ async function updateAction(context: CommandContext, args: string) {
description,
context.ui.addConfirmUpdateExtensionRequest,
),
context.ui.extensionsUpdateState.get(extension.name) ??
context.ui.extensionsUpdateState.get(extension.name)?.status ??
ExtensionUpdateState.UNKNOWN,
(updateState) => {
context.ui.setExtensionsUpdateState((prev) => {
const newState = new Map(prev);
newState.set(name, updateState);
return newState;
});
},
context.ui.dispatchExtensionStateUpdate,
);
if (updateInfo) updateInfos.push(updateInfo);
}

View File

@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { Dispatch, ReactNode, SetStateAction } from 'react';
import type { ReactNode } from 'react';
import type { Content, PartListUnion } from '@google/genai';
import type {
HistoryItemWithoutId,
@@ -15,7 +15,10 @@ import type { Config, GitService, Logger } from '@google/gemini-cli-core';
import type { LoadedSettings } from '../../config/settings.js';
import type { UseHistoryManagerReturn } from '../hooks/useHistoryManager.js';
import type { SessionStatsState } from '../contexts/SessionContext.js';
import type { ExtensionUpdateState } from '../state/extensions.js';
import type {
ExtensionUpdateAction,
ExtensionUpdateStatus,
} from '../state/extensions.js';
// Grouped dependencies for clarity and easier mocking
export interface CommandContext {
@@ -66,10 +69,8 @@ export interface CommandContext {
toggleVimEnabled: () => Promise<boolean>;
setGeminiMdFileCount: (count: number) => void;
reloadCommands: () => void;
extensionsUpdateState: Map<string, ExtensionUpdateState>;
setExtensionsUpdateState: Dispatch<
SetStateAction<Map<string, ExtensionUpdateState>>
>;
extensionsUpdateState: Map<string, ExtensionUpdateStatus>;
dispatchExtensionStateUpdate: (action: ExtensionUpdateAction) => void;
addConfirmUpdateExtensionRequest: (value: ConfirmationRequest) => void;
};
// Session-specific data

View File

@@ -26,8 +26,8 @@ import type {
} from '@google/gemini-cli-core';
import type { DOMElement } from 'ink';
import type { SessionStatsState } from '../contexts/SessionContext.js';
import type { UpdateObject } from '../utils/updateCheck.js';
import type { ExtensionUpdateState } from '../state/extensions.js';
import type { UpdateObject } from '../utils/updateCheck.js';
export interface ProQuotaDialogRequest {
failedModel: string;

View File

@@ -4,14 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import {
useCallback,
useMemo,
useEffect,
useState,
type Dispatch,
type SetStateAction,
} from 'react';
import { useCallback, useMemo, useEffect, useState } from 'react';
import { type PartListUnion } from '@google/genai';
import process from 'node:process';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
@@ -42,7 +35,10 @@ import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js';
import { FileCommandLoader } from '../../services/FileCommandLoader.js';
import { McpPromptLoader } from '../../services/McpPromptLoader.js';
import { parseSlashCommand } from '../../utils/commands.js';
import type { ExtensionUpdateState } from '../state/extensions.js';
import {
type ExtensionUpdateAction,
type ExtensionUpdateStatus,
} from '../state/extensions.js';
interface SlashCommandProcessorActions {
openAuthDialog: () => void;
@@ -55,9 +51,7 @@ interface SlashCommandProcessorActions {
quit: (messages: HistoryItem[]) => void;
setDebugMessage: (message: string) => void;
toggleCorgiMode: () => void;
setExtensionsUpdateState: Dispatch<
SetStateAction<Map<string, ExtensionUpdateState>>
>;
dispatchExtensionStateUpdate: (action: ExtensionUpdateAction) => void;
addConfirmUpdateExtensionRequest: (request: ConfirmationRequest) => void;
}
@@ -75,7 +69,7 @@ export const useSlashCommandProcessor = (
setIsProcessing: (isProcessing: boolean) => void,
setGeminiMdFileCount: (count: number) => void,
actions: SlashCommandProcessorActions,
extensionsUpdateState: Map<string, ExtensionUpdateState>,
extensionsUpdateState: Map<string, ExtensionUpdateStatus>,
isConfigInitialized: boolean,
) => {
const session = useSessionStats();
@@ -207,7 +201,7 @@ export const useSlashCommandProcessor = (
setGeminiMdFileCount,
reloadCommands,
extensionsUpdateState,
setExtensionsUpdateState: actions.setExtensionsUpdateState,
dispatchExtensionStateUpdate: actions.dispatchExtensionStateUpdate,
addConfirmUpdateExtensionRequest:
actions.addConfirmUpdateExtensionRequest,
},

View File

@@ -9,7 +9,6 @@ import * as fs from 'node:fs';
import * as os from 'node:os';
import * as path from 'node:path';
import {
EXTENSIONS_CONFIG_FILENAME,
ExtensionStorage,
annotateActiveExtensions,
loadExtension,
@@ -17,29 +16,14 @@ import {
import { createExtension } from '../../test-utils/createExtension.js';
import { useExtensionUpdates } from './useExtensionUpdates.js';
import { GEMINI_DIR, type GeminiCLIExtension } from '@google/gemini-cli-core';
import { isWorkspaceTrusted } from '../../config/trustedFolders.js';
import { renderHook, waitFor } from '@testing-library/react';
import { MessageType } from '../types.js';
import { ExtensionEnablementManager } from '../../config/extensions/extensionEnablement.js';
const mockGit = {
clone: vi.fn(),
getRemotes: vi.fn(),
fetch: vi.fn(),
checkout: vi.fn(),
listRemote: vi.fn(),
revparse: vi.fn(),
// Not a part of the actual API, but we need to use this to do the correct
// file system interactions.
path: vi.fn(),
};
vi.mock('simple-git', () => ({
simpleGit: vi.fn((path: string) => {
mockGit.path.mockReturnValue(path);
return mockGit;
}),
}));
import {
checkForAllExtensionUpdates,
updateExtension,
} from '../../config/extensions/update.js';
import { ExtensionUpdateState } from '../state/extensions.js';
vi.mock('os', async (importOriginal) => {
const mockedOs = await importOriginal<typeof os>();
@@ -49,45 +33,9 @@ vi.mock('os', async (importOriginal) => {
};
});
vi.mock('../../config/trustedFolders.js', async (importOriginal) => {
const actual =
await importOriginal<typeof import('../../config/trustedFolders.js')>();
return {
...actual,
isWorkspaceTrusted: vi.fn(),
};
});
const mockLogExtensionInstallEvent = vi.hoisted(() => vi.fn());
const mockLogExtensionUninstall = vi.hoisted(() => vi.fn());
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actual =
await importOriginal<typeof import('@google/gemini-cli-core')>();
return {
...actual,
logExtensionInstallEvent: mockLogExtensionInstallEvent,
logExtensionUninstall: mockLogExtensionUninstall,
ExtensionInstallEvent: vi.fn(),
ExtensionUninstallEvent: vi.fn(),
};
});
vi.mock('child_process', async (importOriginal) => {
const actual = await importOriginal<typeof import('child_process')>();
return {
...actual,
execSync: vi.fn(),
};
});
const mockQuestion = vi.hoisted(() => vi.fn());
const mockClose = vi.hoisted(() => vi.fn());
vi.mock('node:readline', () => ({
createInterface: vi.fn(() => ({
question: mockQuestion,
close: mockClose,
})),
vi.mock('../../config/extensions/update.js', () => ({
checkForAllExtensionUpdates: vi.fn(),
updateExtension: vi.fn(),
}));
describe('useExtensionUpdates', () => {
@@ -101,7 +49,8 @@ describe('useExtensionUpdates', () => {
vi.mocked(os.homedir).mockReturnValue(tempHomeDir);
userExtensionsDir = path.join(tempHomeDir, GEMINI_DIR, 'extensions');
fs.mkdirSync(userExtensionsDir, { recursive: true });
Object.values(mockGit).forEach((fn) => fn.mockReset());
vi.mocked(checkForAllExtensionUpdates).mockReset();
vi.mocked(updateExtension).mockReset();
});
afterEach(() => {
@@ -126,16 +75,17 @@ describe('useExtensionUpdates', () => {
const addItem = vi.fn();
const cwd = '/test/cwd';
mockGit.getRemotes.mockResolvedValue([
{
name: 'origin',
refs: {
fetch: 'https://github.com/google/gemini-cli.git',
},
vi.mocked(checkForAllExtensionUpdates).mockImplementation(
async (extensions, dispatch) => {
dispatch({
type: 'SET_STATE',
payload: {
name: 'test-extension',
state: ExtensionUpdateState.UPDATE_AVAILABLE,
},
});
},
]);
mockGit.revparse.mockResolvedValue('local-hash');
mockGit.listRemote.mockResolvedValue('remote-hash\tHEAD');
);
renderHook(() =>
useExtensionUpdates(extensions as GeminiCLIExtension[], addItem, cwd),
@@ -170,28 +120,23 @@ describe('useExtensionUpdates', () => {
)[0];
const addItem = vi.fn();
mockGit.getRemotes.mockResolvedValue([
{
name: 'origin',
refs: {
fetch: 'https://github.com/google/gemini-cli.git',
},
vi.mocked(checkForAllExtensionUpdates).mockImplementation(
async (extensions, dispatch) => {
dispatch({
type: 'SET_STATE',
payload: {
name: 'test-extension',
state: ExtensionUpdateState.UPDATE_AVAILABLE,
},
});
},
]);
mockGit.revparse.mockResolvedValue('local-hash');
mockGit.listRemote.mockResolvedValue('remote-hash\tHEAD');
mockGit.clone.mockImplementation(async (_, destination) => {
fs.mkdirSync(path.join(mockGit.path(), destination), {
recursive: true,
});
fs.writeFileSync(
path.join(mockGit.path(), destination, EXTENSIONS_CONFIG_FILENAME),
JSON.stringify({ name: 'test-extension', version: '1.1.0' }),
);
});
vi.mocked(isWorkspaceTrusted).mockReturnValue({
isTrusted: true,
source: 'file',
);
vi.mocked(updateExtension).mockResolvedValue({
originalVersion: '1.0.0',
updatedVersion: '1.1.0',
name: '',
});
renderHook(() => useExtensionUpdates([extension], addItem, tempHomeDir));
@@ -206,7 +151,169 @@ describe('useExtensionUpdates', () => {
expect.any(Number),
);
},
{ timeout: 2000 },
{ timeout: 4000 },
);
});
it('should batch update notifications for multiple extensions', async () => {
const extensionDir1 = createExtension({
extensionsDir: userExtensionsDir,
name: 'test-extension-1',
version: '1.0.0',
installMetadata: {
source: 'https://some.git/repo1',
type: 'git',
autoUpdate: true,
},
});
const extensionDir2 = createExtension({
extensionsDir: userExtensionsDir,
name: 'test-extension-2',
version: '2.0.0',
installMetadata: {
source: 'https://some.git/repo2',
type: 'git',
autoUpdate: true,
},
});
const extensions = annotateActiveExtensions(
[
loadExtension({
extensionDir: extensionDir1,
workspaceDir: tempHomeDir,
})!,
loadExtension({
extensionDir: extensionDir2,
workspaceDir: tempHomeDir,
})!,
],
tempHomeDir,
new ExtensionEnablementManager(ExtensionStorage.getUserExtensionsDir()),
);
const addItem = vi.fn();
vi.mocked(checkForAllExtensionUpdates).mockImplementation(
async (extensions, dispatch) => {
dispatch({
type: 'SET_STATE',
payload: {
name: 'test-extension-1',
state: ExtensionUpdateState.UPDATE_AVAILABLE,
},
});
dispatch({
type: 'SET_STATE',
payload: {
name: 'test-extension-2',
state: ExtensionUpdateState.UPDATE_AVAILABLE,
},
});
},
);
vi.mocked(updateExtension)
.mockResolvedValueOnce({
originalVersion: '1.0.0',
updatedVersion: '1.1.0',
name: '',
})
.mockResolvedValueOnce({
originalVersion: '2.0.0',
updatedVersion: '2.1.0',
name: '',
});
renderHook(() => useExtensionUpdates(extensions, addItem, tempHomeDir));
await waitFor(
() => {
expect(addItem).toHaveBeenCalledTimes(2);
expect(addItem).toHaveBeenCalledWith(
{
type: MessageType.INFO,
text: 'Extension "test-extension-1" successfully updated: 1.0.0 → 1.1.0.',
},
expect.any(Number),
);
expect(addItem).toHaveBeenCalledWith(
{
type: MessageType.INFO,
text: 'Extension "test-extension-2" successfully updated: 2.0.0 → 2.1.0.',
},
expect.any(Number),
);
},
{ timeout: 4000 },
);
});
it('should batch update notifications for multiple extensions with autoUpdate: false', async () => {
const extensions = [
{
name: 'test-extension-1',
type: 'git',
version: '1.0.0',
path: '/some/path1',
isActive: true,
installMetadata: {
type: 'git',
source: 'https://some/repo1',
autoUpdate: false,
},
},
{
name: 'test-extension-2',
type: 'git',
version: '2.0.0',
path: '/some/path2',
isActive: true,
installMetadata: {
type: 'git',
source: 'https://some/repo2',
autoUpdate: false,
},
},
];
const addItem = vi.fn();
const cwd = '/test/cwd';
vi.mocked(checkForAllExtensionUpdates).mockImplementation(
async (extensions, dispatch) => {
dispatch({ type: 'BATCH_CHECK_START' });
dispatch({
type: 'SET_STATE',
payload: {
name: 'test-extension-1',
state: ExtensionUpdateState.UPDATE_AVAILABLE,
},
});
await new Promise((r) => setTimeout(r, 50));
dispatch({
type: 'SET_STATE',
payload: {
name: 'test-extension-2',
state: ExtensionUpdateState.UPDATE_AVAILABLE,
},
});
dispatch({ type: 'BATCH_CHECK_END' });
},
);
renderHook(() =>
useExtensionUpdates(extensions as GeminiCLIExtension[], addItem, cwd),
);
await waitFor(() => {
expect(addItem).toHaveBeenCalledTimes(1);
expect(addItem).toHaveBeenCalledWith(
{
type: MessageType.INFO,
text: 'You have 2 extensions with an update available, run "/extensions list" for more information.',
},
expect.any(Number),
);
});
});
});

View File

@@ -6,8 +6,12 @@
import type { GeminiCLIExtension } from '@google/gemini-cli-core';
import { getErrorMessage } from '../../utils/errors.js';
import { ExtensionUpdateState } from '../state/extensions.js';
import { useCallback, useState } from 'react';
import {
ExtensionUpdateState,
extensionUpdatesReducer,
initialExtensionUpdatesState,
} from '../state/extensions.js';
import { useCallback, useEffect, useMemo, useReducer } from 'react';
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
import { MessageType, type ConfirmationRequest } from '../types.js';
import {
@@ -15,118 +19,167 @@ import {
updateExtension,
} from '../../config/extensions/update.js';
import { requestConsentInteractive } from '../../config/extension.js';
import { checkExhaustive } from '../../utils/checks.js';
type ConfirmationRequestWrapper = {
prompt: React.ReactNode;
onConfirm: (confirmed: boolean) => void;
};
type ConfirmationRequestAction =
| { type: 'add'; request: ConfirmationRequestWrapper }
| { type: 'remove'; request: ConfirmationRequestWrapper };
function confirmationRequestsReducer(
state: ConfirmationRequestWrapper[],
action: ConfirmationRequestAction,
): ConfirmationRequestWrapper[] {
switch (action.type) {
case 'add':
return [...state, action.request];
case 'remove':
return state.filter((r) => r !== action.request);
default:
checkExhaustive(action);
return state;
}
}
export const useExtensionUpdates = (
extensions: GeminiCLIExtension[],
addItem: UseHistoryManagerReturn['addItem'],
cwd: string,
) => {
const [extensionsUpdateState, setExtensionsUpdateState] = useState(
new Map<string, ExtensionUpdateState>(),
const [extensionsUpdateState, dispatchExtensionStateUpdate] = useReducer(
extensionUpdatesReducer,
initialExtensionUpdatesState,
);
const [isChecking, setIsChecking] = useState(false);
const [confirmUpdateExtensionRequests, setConfirmUpdateExtensionRequests] =
useState<
Array<{
prompt: React.ReactNode;
onConfirm: (confirmed: boolean) => void;
}>
>([]);
const [
confirmUpdateExtensionRequests,
dispatchConfirmUpdateExtensionRequests,
] = useReducer(confirmationRequestsReducer, []);
const addConfirmUpdateExtensionRequest = useCallback(
(original: ConfirmationRequest) => {
const wrappedRequest = {
prompt: original.prompt,
onConfirm: (confirmed: boolean) => {
// Remove it from the outstanding list of requests by identity.
setConfirmUpdateExtensionRequests((prev) =>
prev.filter((r) => r !== wrappedRequest),
);
dispatchConfirmUpdateExtensionRequests({
type: 'remove',
request: wrappedRequest,
});
original.onConfirm(confirmed);
},
};
setConfirmUpdateExtensionRequests((prev) => [...prev, wrappedRequest]);
dispatchConfirmUpdateExtensionRequests({
type: 'add',
request: wrappedRequest,
});
},
[setConfirmUpdateExtensionRequests],
[dispatchConfirmUpdateExtensionRequests],
);
(async () => {
if (isChecking) return;
setIsChecking(true);
try {
const updateState = await checkForAllExtensionUpdates(
useEffect(() => {
(async () => {
await checkForAllExtensionUpdates(
extensions,
extensionsUpdateState,
setExtensionsUpdateState,
dispatchExtensionStateUpdate,
);
let extensionsWithUpdatesCount = 0;
for (const extension of extensions) {
const prevState = extensionsUpdateState.get(extension.name);
const currentState = updateState.get(extension.name);
if (
prevState === currentState ||
currentState !== ExtensionUpdateState.UPDATE_AVAILABLE
) {
continue;
}
if (extension.installMetadata?.autoUpdate) {
updateExtension(
extension,
cwd,
(description) =>
requestConsentInteractive(
description,
addConfirmUpdateExtensionRequest,
),
currentState,
(newState) => {
setExtensionsUpdateState((prev) => {
const finalState = new Map(prev);
finalState.set(extension.name, newState);
return finalState;
});
},
)
.then((result) => {
if (!result) return;
addItem(
{
type: MessageType.INFO,
text: `Extension "${extension.name}" successfully updated: ${result.originalVersion}${result.updatedVersion}.`,
},
Date.now(),
);
})
.catch((error) => {
addItem(
{
type: MessageType.ERROR,
text: getErrorMessage(error),
},
Date.now(),
);
});
} else {
extensionsWithUpdatesCount++;
}
}
if (extensionsWithUpdatesCount > 0) {
const s = extensionsWithUpdatesCount > 1 ? 's' : '';
addItem(
{
type: MessageType.INFO,
text: `You have ${extensionsWithUpdatesCount} extension${s} with an update available, run "/extensions list" for more information.`,
},
Date.now(),
);
}
} finally {
setIsChecking(false);
})();
}, [extensions, extensions.length, dispatchExtensionStateUpdate]);
useEffect(() => {
if (extensionsUpdateState.batchChecksInProgress > 0) {
return;
}
})();
let extensionsWithUpdatesCount = 0;
for (const extension of extensions) {
const currentState = extensionsUpdateState.extensionStatuses.get(
extension.name,
);
if (
!currentState ||
currentState.processed ||
currentState.status !== ExtensionUpdateState.UPDATE_AVAILABLE
) {
continue;
}
// Mark as processed immediately to avoid re-triggering.
dispatchExtensionStateUpdate({
type: 'SET_PROCESSED',
payload: { name: extension.name, processed: true },
});
if (extension.installMetadata?.autoUpdate) {
updateExtension(
extension,
cwd,
(description) =>
requestConsentInteractive(
description,
addConfirmUpdateExtensionRequest,
),
currentState.status,
dispatchExtensionStateUpdate,
)
.then((result) => {
if (!result) return;
addItem(
{
type: MessageType.INFO,
text: `Extension "${extension.name}" successfully updated: ${result.originalVersion}${result.updatedVersion}.`,
},
Date.now(),
);
})
.catch((error) => {
addItem(
{
type: MessageType.ERROR,
text: getErrorMessage(error),
},
Date.now(),
);
});
} else {
extensionsWithUpdatesCount++;
}
}
if (extensionsWithUpdatesCount > 0) {
const s = extensionsWithUpdatesCount > 1 ? 's' : '';
addItem(
{
type: MessageType.INFO,
text: `You have ${extensionsWithUpdatesCount} extension${s} with an update available, run "/extensions list" for more information.`,
},
Date.now(),
);
}
}, [
extensions,
extensionsUpdateState,
addConfirmUpdateExtensionRequest,
addItem,
cwd,
]);
const extensionsUpdateStateComputed = useMemo(() => {
const result = new Map<string, ExtensionUpdateState>();
for (const [
key,
value,
] of extensionsUpdateState.extensionStatuses.entries()) {
result.set(key, value.status);
}
return result;
}, [extensionsUpdateState]);
return {
extensionsUpdateState,
setExtensionsUpdateState,
extensionsUpdateState: extensionsUpdateStateComputed,
extensionsUpdateStateInternal: extensionsUpdateState.extensionStatuses,
dispatchExtensionStateUpdate,
confirmUpdateExtensionRequests,
addConfirmUpdateExtensionRequest,
};

View File

@@ -5,6 +5,7 @@
*/
import type { CommandContext } from '../commands/types.js';
import type { ExtensionUpdateAction } from '../state/extensions.js';
/**
* Creates a UI context object with no-op functions.
@@ -24,7 +25,7 @@ export function createNonInteractiveUI(): CommandContext['ui'] {
setGeminiMdFileCount: (_count) => {},
reloadCommands: () => {},
extensionsUpdateState: new Map(),
setExtensionsUpdateState: (_updateState) => {},
dispatchExtensionStateUpdate: (_action: ExtensionUpdateAction) => {},
addConfirmUpdateExtensionRequest: (_request) => {},
};
}

View File

@@ -4,6 +4,8 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { checkExhaustive } from '../../utils/checks.js';
export enum ExtensionUpdateState {
CHECKING_FOR_UPDATES = 'checking for updates',
UPDATED_NEEDS_RESTART = 'updated, needs restart',
@@ -14,3 +16,74 @@ export enum ExtensionUpdateState {
NOT_UPDATABLE = 'not updatable',
UNKNOWN = 'unknown',
}
export interface ExtensionUpdateStatus {
status: ExtensionUpdateState;
processed: boolean;
}
export interface ExtensionUpdatesState {
extensionStatuses: Map<string, ExtensionUpdateStatus>;
batchChecksInProgress: number;
}
export const initialExtensionUpdatesState: ExtensionUpdatesState = {
extensionStatuses: new Map(),
batchChecksInProgress: 0,
};
export type ExtensionUpdateAction =
| {
type: 'SET_STATE';
payload: { name: string; state: ExtensionUpdateState };
}
| {
type: 'SET_PROCESSED';
payload: { name: string; processed: boolean };
}
| { type: 'BATCH_CHECK_START' }
| { type: 'BATCH_CHECK_END' };
export function extensionUpdatesReducer(
state: ExtensionUpdatesState,
action: ExtensionUpdateAction,
): ExtensionUpdatesState {
switch (action.type) {
case 'SET_STATE': {
const existing = state.extensionStatuses.get(action.payload.name);
if (existing?.status === action.payload.state) {
return state;
}
const newStatuses = new Map(state.extensionStatuses);
newStatuses.set(action.payload.name, {
status: action.payload.state,
processed: false,
});
return { ...state, extensionStatuses: newStatuses };
}
case 'SET_PROCESSED': {
const existing = state.extensionStatuses.get(action.payload.name);
if (!existing || existing.processed === action.payload.processed) {
return state;
}
const newStatuses = new Map(state.extensionStatuses);
newStatuses.set(action.payload.name, {
...existing,
processed: action.payload.processed,
});
return { ...state, extensionStatuses: newStatuses };
}
case 'BATCH_CHECK_START':
return {
...state,
batchChecksInProgress: state.batchChecksInProgress + 1,
};
case 'BATCH_CHECK_END':
return {
...state,
batchChecksInProgress: state.batchChecksInProgress - 1,
};
default:
checkExhaustive(action);
}
}