diff --git a/packages/cli/src/commands/extensions/update.ts b/packages/cli/src/commands/extensions/update.ts index bb200e58f9..bf28bbc90f 100644 --- a/packages/cli/src/commands/extensions/update.ts +++ b/packages/cli/src/commands/extensions/update.ts @@ -53,16 +53,13 @@ export async function handleUpdate(args: UpdateArgs) { console.log(`Extension "${args.name}" not found.`); return; } - let updateState: ExtensionUpdateState | undefined; if (!extension.installMetadata) { console.log( `Unable to install extension "${args.name}" due to missing install metadata`, ); return; } - await checkForExtensionUpdate(extension, (newState) => { - updateState = newState; - }); + const updateState = await checkForExtensionUpdate(extension); if (updateState !== ExtensionUpdateState.UPDATE_AVAILABLE) { console.log(`Extension "${args.name}" is already up to date.`); return; @@ -92,14 +89,17 @@ export async function handleUpdate(args: UpdateArgs) { if (args.all) { try { const extensionState = new Map(); - await checkForAllExtensionUpdates(extensions, (action) => { - if (action.type === 'SET_STATE') { - extensionState.set(action.payload.name, { - status: action.payload.state, - processed: true, // No need to process as we will force the update. - }); - } - }); + await checkForAllExtensionUpdates( + extensions, + (action) => { + if (action.type === 'SET_STATE') { + extensionState.set(action.payload.name, { + status: action.payload.state, + }); + } + }, + workingDir, + ); let updateInfos = await updateAllUpdatableExtensions( workingDir, requestConsentNonInteractive, diff --git a/packages/cli/src/config/extensions/github.test.ts b/packages/cli/src/config/extensions/github.test.ts index a23e330612..232b0db0ff 100644 --- a/packages/cli/src/config/extensions/github.test.ts +++ b/packages/cli/src/config/extensions/github.test.ts @@ -132,11 +132,7 @@ describe('git extension helpers', () => { source: '', }, }; - let result: ExtensionUpdateState | undefined = undefined; - await checkForExtensionUpdate( - extension, - (newState) => (result = newState), - ); + const result = await checkForExtensionUpdate(extension); expect(result).toBe(ExtensionUpdateState.NOT_UPDATABLE); }); @@ -152,11 +148,7 @@ describe('git extension helpers', () => { }, }; mockGit.getRemotes.mockResolvedValue([]); - let result: ExtensionUpdateState | undefined = undefined; - await checkForExtensionUpdate( - extension, - (newState) => (result = newState), - ); + const result = await checkForExtensionUpdate(extension); expect(result).toBe(ExtensionUpdateState.ERROR); }); @@ -177,11 +169,7 @@ describe('git extension helpers', () => { mockGit.listRemote.mockResolvedValue('remote-hash\tHEAD'); mockGit.revparse.mockResolvedValue('local-hash'); - let result: ExtensionUpdateState | undefined = undefined; - await checkForExtensionUpdate( - extension, - (newState) => (result = newState), - ); + const result = await checkForExtensionUpdate(extension); expect(result).toBe(ExtensionUpdateState.UPDATE_AVAILABLE); }); @@ -202,11 +190,7 @@ describe('git extension helpers', () => { mockGit.listRemote.mockResolvedValue('same-hash\tHEAD'); mockGit.revparse.mockResolvedValue('same-hash'); - let result: ExtensionUpdateState | undefined = undefined; - await checkForExtensionUpdate( - extension, - (newState) => (result = newState), - ); + const result = await checkForExtensionUpdate(extension); expect(result).toBe(ExtensionUpdateState.UP_TO_DATE); }); @@ -223,11 +207,7 @@ describe('git extension helpers', () => { }; mockGit.getRemotes.mockRejectedValue(new Error('git error')); - let result: ExtensionUpdateState | undefined = undefined; - await checkForExtensionUpdate( - extension, - (newState) => (result = newState), - ); + const result = await checkForExtensionUpdate(extension); expect(result).toBe(ExtensionUpdateState.ERROR); }); }); diff --git a/packages/cli/src/config/extensions/github.ts b/packages/cli/src/config/extensions/github.ts index 9a54a29f9d..787fc570e6 100644 --- a/packages/cli/src/config/extensions/github.ts +++ b/packages/cli/src/config/extensions/github.ts @@ -119,10 +119,8 @@ async function fetchReleaseFromGithub( export async function checkForExtensionUpdate( extension: GeminiCLIExtension, - setExtensionUpdateState: (updateState: ExtensionUpdateState) => void, cwd: string = process.cwd(), -): Promise { - setExtensionUpdateState(ExtensionUpdateState.CHECKING_FOR_UPDATES); +): Promise { const installMetadata = extension.installMetadata; if (installMetadata?.type === 'local') { const newExtension = loadExtension({ @@ -133,23 +131,19 @@ export async function checkForExtensionUpdate( console.error( `Failed to check for update for local extension "${extension.name}". Could not load extension from source path: ${installMetadata.source}`, ); - setExtensionUpdateState(ExtensionUpdateState.ERROR); - return; + return ExtensionUpdateState.ERROR; } if (newExtension.config.version !== extension.version) { - setExtensionUpdateState(ExtensionUpdateState.UPDATE_AVAILABLE); - return; + return ExtensionUpdateState.UPDATE_AVAILABLE; } - setExtensionUpdateState(ExtensionUpdateState.UP_TO_DATE); - return; + return ExtensionUpdateState.UP_TO_DATE; } if ( !installMetadata || (installMetadata.type !== 'git' && installMetadata.type !== 'github-release') ) { - setExtensionUpdateState(ExtensionUpdateState.NOT_UPDATABLE); - return; + return ExtensionUpdateState.NOT_UPDATABLE; } try { if (installMetadata.type === 'git') { @@ -157,14 +151,12 @@ export async function checkForExtensionUpdate( const remotes = await git.getRemotes(true); if (remotes.length === 0) { console.error('No git remotes found.'); - setExtensionUpdateState(ExtensionUpdateState.ERROR); - return; + return ExtensionUpdateState.ERROR; } const remoteUrl = remotes[0].refs.fetch; if (!remoteUrl) { console.error(`No fetch URL found for git remote ${remotes[0].name}.`); - setExtensionUpdateState(ExtensionUpdateState.ERROR); - return; + return ExtensionUpdateState.ERROR; } // Determine the ref to check on the remote. @@ -174,8 +166,7 @@ export async function checkForExtensionUpdate( if (typeof lsRemoteOutput !== 'string' || lsRemoteOutput.trim() === '') { console.error(`Git ref ${refToCheck} not found.`); - setExtensionUpdateState(ExtensionUpdateState.ERROR); - return; + return ExtensionUpdateState.ERROR; } const remoteHash = lsRemoteOutput.split('\t')[0]; @@ -185,21 +176,17 @@ export async function checkForExtensionUpdate( console.error( `Unable to parse hash from git ls-remote output "${lsRemoteOutput}"`, ); - setExtensionUpdateState(ExtensionUpdateState.ERROR); - return; + return ExtensionUpdateState.ERROR; } if (remoteHash === localHash) { - setExtensionUpdateState(ExtensionUpdateState.UP_TO_DATE); - return; + return ExtensionUpdateState.UP_TO_DATE; } - setExtensionUpdateState(ExtensionUpdateState.UPDATE_AVAILABLE); - return; + return ExtensionUpdateState.UPDATE_AVAILABLE; } else { const { source, releaseTag } = installMetadata; if (!source) { console.error(`No "source" provided for extension.`); - setExtensionUpdateState(ExtensionUpdateState.ERROR); - return; + return ExtensionUpdateState.ERROR; } const { owner, repo } = parseGitHubRepoForReleases(source); @@ -209,18 +196,15 @@ export async function checkForExtensionUpdate( installMetadata.ref, ); if (releaseData.tag_name !== releaseTag) { - setExtensionUpdateState(ExtensionUpdateState.UPDATE_AVAILABLE); - return; + return ExtensionUpdateState.UPDATE_AVAILABLE; } - setExtensionUpdateState(ExtensionUpdateState.UP_TO_DATE); - return; + return ExtensionUpdateState.UP_TO_DATE; } } catch (error) { console.error( `Failed to check for updates for extension "${installMetadata.source}": ${getErrorMessage(error)}`, ); - setExtensionUpdateState(ExtensionUpdateState.ERROR); - return; + return ExtensionUpdateState.ERROR; } } export interface GitHubDownloadResult { diff --git a/packages/cli/src/config/extensions/update.test.ts b/packages/cli/src/config/extensions/update.test.ts index 54f6353340..82624868a3 100644 --- a/packages/cli/src/config/extensions/update.test.ts +++ b/packages/cli/src/config/extensions/update.test.ts @@ -302,7 +302,11 @@ describe('update tests', () => { mockGit.revparse.mockResolvedValue('localHash'); const dispatch = vi.fn(); - await checkForAllExtensionUpdates([extension], dispatch); + await checkForAllExtensionUpdates( + [extension], + dispatch, + tempWorkspaceDir, + ); expect(dispatch).toHaveBeenCalledWith({ type: 'SET_STATE', payload: { @@ -340,7 +344,11 @@ describe('update tests', () => { mockGit.revparse.mockResolvedValue('sameHash'); const dispatch = vi.fn(); - await checkForAllExtensionUpdates([extension], dispatch); + await checkForAllExtensionUpdates( + [extension], + dispatch, + tempWorkspaceDir, + ); expect(dispatch).toHaveBeenCalledWith({ type: 'SET_STATE', payload: { @@ -375,7 +383,11 @@ describe('update tests', () => { new ExtensionEnablementManager(ExtensionStorage.getUserExtensionsDir()), )[0]; const dispatch = vi.fn(); - await checkForAllExtensionUpdates([extension], dispatch); + await checkForAllExtensionUpdates( + [extension], + dispatch, + tempWorkspaceDir, + ); expect(dispatch).toHaveBeenCalledWith({ type: 'SET_STATE', payload: { @@ -410,7 +422,11 @@ describe('update tests', () => { new ExtensionEnablementManager(ExtensionStorage.getUserExtensionsDir()), )[0]; const dispatch = vi.fn(); - await checkForAllExtensionUpdates([extension], dispatch); + await checkForAllExtensionUpdates( + [extension], + dispatch, + tempWorkspaceDir, + ); expect(dispatch).toHaveBeenCalledWith({ type: 'SET_STATE', payload: { @@ -444,7 +460,11 @@ describe('update tests', () => { mockGit.getRemotes.mockRejectedValue(new Error('Git error')); const dispatch = vi.fn(); - await checkForAllExtensionUpdates([extension], dispatch); + await checkForAllExtensionUpdates( + [extension], + dispatch, + tempWorkspaceDir, + ); expect(dispatch).toHaveBeenCalledWith({ type: 'SET_STATE', payload: { diff --git a/packages/cli/src/config/extensions/update.ts b/packages/cli/src/config/extensions/update.ts index af52000ef5..cb1fa70e03 100644 --- a/packages/cli/src/config/extensions/update.ts +++ b/packages/cli/src/config/extensions/update.ts @@ -154,6 +154,7 @@ export interface ExtensionUpdateCheckResult { export async function checkForAllExtensionUpdates( extensions: GeminiCLIExtension[], dispatch: (action: ExtensionUpdateAction) => void, + cwd: string = process.cwd(), ): Promise { dispatch({ type: 'BATCH_CHECK_START' }); const promises: Array> = []; @@ -168,13 +169,20 @@ export async function checkForAllExtensionUpdates( }); continue; } + dispatch({ + type: 'SET_STATE', + payload: { + name: extension.name, + state: ExtensionUpdateState.CHECKING_FOR_UPDATES, + }, + }); promises.push( - checkForExtensionUpdate(extension, (updatedState) => { + checkForExtensionUpdate(extension, cwd).then((state) => dispatch({ type: 'SET_STATE', - payload: { name: extension.name, state: updatedState }, - }); - }), + payload: { name: extension.name, state }, + }), + ), ); } await Promise.all(promises); diff --git a/packages/cli/src/ui/commands/extensionsCommand.test.ts b/packages/cli/src/ui/commands/extensionsCommand.test.ts index 332f7a1099..ee41668a74 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.test.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.test.ts @@ -5,43 +5,23 @@ */ import type { GeminiCLIExtension } from '@google/gemini-cli-core'; -import { - updateAllUpdatableExtensions, - updateExtension, -} from '../../config/extensions/update.js'; import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; import { MessageType } from '../types.js'; import { extensionsCommand } from './extensionsCommand.js'; import { type CommandContext } from './types.js'; -import { - describe, - it, - expect, - vi, - beforeEach, - type MockedFunction, -} from 'vitest'; -import { ExtensionUpdateState } from '../state/extensions.js'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { type ExtensionUpdateAction } from '../state/extensions.js'; vi.mock('../../config/extensions/update.js', () => ({ updateExtension: vi.fn(), - updateAllUpdatableExtensions: vi.fn(), checkForAllExtensionUpdates: vi.fn(), })); -const mockUpdateExtension = updateExtension as MockedFunction< - typeof updateExtension ->; - -const mockUpdateAllUpdatableExtensions = - updateAllUpdatableExtensions as MockedFunction< - typeof updateAllUpdatableExtensions - >; - const mockGetExtensions = vi.fn(); describe('extensionsCommand', () => { let mockContext: CommandContext; + const mockDispatchExtensionState = vi.fn(); beforeEach(() => { vi.resetAllMocks(); @@ -53,7 +33,7 @@ describe('extensionsCommand', () => { }, }, ui: { - dispatchExtensionStateUpdate: vi.fn(), + dispatchExtensionStateUpdate: mockDispatchExtensionState, }, }); }); @@ -93,7 +73,14 @@ describe('extensionsCommand', () => { }); it('should inform user if there are no extensions to update with --all', async () => { - mockUpdateAllUpdatableExtensions.mockResolvedValue([]); + mockDispatchExtensionState.mockImplementationOnce( + (action: ExtensionUpdateAction) => { + if (action.type === 'SCHEDULE_UPDATE') { + action.payload.onComplete([]); + } + }, + ); + await updateAction(mockContext, '--all'); expect(mockContext.ui.addItem).toHaveBeenCalledWith( { @@ -105,18 +92,24 @@ describe('extensionsCommand', () => { }); it('should call setPendingItem and addItem in a finally block on success', async () => { - mockUpdateAllUpdatableExtensions.mockResolvedValue([ - { - name: 'ext-one', - originalVersion: '1.0.0', - updatedVersion: '1.0.1', + mockDispatchExtensionState.mockImplementationOnce( + (action: ExtensionUpdateAction) => { + if (action.type === 'SCHEDULE_UPDATE') { + action.payload.onComplete([ + { + name: 'ext-one', + originalVersion: '1.0.0', + updatedVersion: '1.0.1', + }, + { + name: 'ext-two', + originalVersion: '2.0.0', + updatedVersion: '2.0.1', + }, + ]); + } }, - { - name: 'ext-two', - originalVersion: '2.0.0', - updatedVersion: '2.0.1', - }, - ]); + ); await updateAction(mockContext, '--all'); expect(mockContext.ui.setPendingItem).toHaveBeenCalledWith({ type: MessageType.EXTENSIONS_LIST, @@ -131,9 +124,9 @@ describe('extensionsCommand', () => { }); it('should call setPendingItem and addItem in a finally block on failure', async () => { - mockUpdateAllUpdatableExtensions.mockRejectedValue( - new Error('Something went wrong'), - ); + mockDispatchExtensionState.mockImplementationOnce((_) => { + throw new Error('Something went wrong'); + }); await updateAction(mockContext, '--all'); expect(mockContext.ui.setPendingItem).toHaveBeenCalledWith({ type: MessageType.EXTENSIONS_LIST, @@ -155,95 +148,58 @@ describe('extensionsCommand', () => { }); it('should update a single extension by name', async () => { - const extension: GeminiCLIExtension = { - name: 'ext-one', - version: '1.0.0', - isActive: true, - path: '/test/dir/ext-one', - installMetadata: { - type: 'git', - autoUpdate: false, - source: 'https://github.com/some/extension.git', + mockDispatchExtensionState.mockImplementationOnce( + (action: ExtensionUpdateAction) => { + if (action.type === 'SCHEDULE_UPDATE') { + action.payload.onComplete([ + { + name: 'ext-one', + originalVersion: '1.0.0', + updatedVersion: '1.0.1', + }, + ]); + } }, - }; - mockUpdateExtension.mockResolvedValue({ - name: extension.name, - originalVersion: extension.version, - updatedVersion: '1.0.1', - }); - mockGetExtensions.mockReturnValue([extension]); - mockContext.ui.extensionsUpdateState.set(extension.name, { - status: ExtensionUpdateState.UPDATE_AVAILABLE, - processed: false, - }); - await updateAction(mockContext, 'ext-one'); - expect(mockUpdateExtension).toHaveBeenCalledWith( - extension, - '/test/dir', - expect.any(Function), - ExtensionUpdateState.UPDATE_AVAILABLE, - expect.any(Function), ); - }); - - it('should handle errors when updating a single extension', async () => { - mockUpdateExtension.mockRejectedValue(new Error('Extension not found')); - mockGetExtensions.mockReturnValue([]); await updateAction(mockContext, 'ext-one'); - expect(mockContext.ui.addItem).toHaveBeenCalledWith( - { - type: MessageType.ERROR, - text: 'Extension ext-one not found.', + expect(mockDispatchExtensionState).toHaveBeenCalledWith({ + type: 'SCHEDULE_UPDATE', + payload: { + all: false, + names: ['ext-one'], + onComplete: expect.any(Function), }, - expect.any(Number), - ); + }); }); it('should update multiple extensions by name', async () => { - const extensionOne: GeminiCLIExtension = { - name: 'ext-one', - version: '1.0.0', - isActive: true, - path: '/test/dir/ext-one', - installMetadata: { - type: 'git', - autoUpdate: false, - source: 'https://github.com/some/extension.git', + mockDispatchExtensionState.mockImplementationOnce( + (action: ExtensionUpdateAction) => { + if (action.type === 'SCHEDULE_UPDATE') { + action.payload.onComplete([ + { + name: 'ext-one', + originalVersion: '1.0.0', + updatedVersion: '1.0.1', + }, + { + name: 'ext-two', + originalVersion: '1.0.0', + updatedVersion: '1.0.1', + }, + ]); + } }, - }; - const extensionTwo: GeminiCLIExtension = { - name: 'ext-two', - version: '1.0.0', - isActive: true, - path: '/test/dir/ext-two', - installMetadata: { - type: 'git', - autoUpdate: false, - source: 'https://github.com/some/extension.git', - }, - }; - mockGetExtensions.mockReturnValue([extensionOne, extensionTwo]); - mockContext.ui.extensionsUpdateState.set( - extensionOne.name, - ExtensionUpdateState.UPDATE_AVAILABLE, ); - mockContext.ui.extensionsUpdateState.set( - extensionTwo.name, - ExtensionUpdateState.UPDATE_AVAILABLE, - ); - mockUpdateExtension - .mockResolvedValueOnce({ - name: 'ext-one', - originalVersion: '1.0.0', - updatedVersion: '1.0.1', - }) - .mockResolvedValueOnce({ - name: 'ext-two', - originalVersion: '2.0.0', - updatedVersion: '2.0.1', - }); await updateAction(mockContext, 'ext-one ext-two'); - expect(mockUpdateExtension).toHaveBeenCalledTimes(2); + expect(mockDispatchExtensionState).toHaveBeenCalledWith({ + type: 'SCHEDULE_UPDATE', + payload: { + all: false, + names: ['ext-one', 'ext-two'], + onComplete: expect.any(Function), + }, + }); expect(mockContext.ui.setPendingItem).toHaveBeenCalledWith({ type: MessageType.EXTENSIONS_LIST, }); diff --git a/packages/cli/src/ui/commands/extensionsCommand.ts b/packages/cli/src/ui/commands/extensionsCommand.ts index e4f2c8fbe3..87f126b161 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.ts @@ -4,15 +4,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { requestConsentInteractive } from '../../config/extension.js'; -import { - updateAllUpdatableExtensions, - type ExtensionUpdateInfo, - updateExtension, - checkForAllExtensionUpdates, -} from '../../config/extensions/update.js'; +import type { ExtensionUpdateInfo } from '../../config/extension.js'; import { getErrorMessage } from '../../utils/errors.js'; -import { ExtensionUpdateState } from '../state/extensions.js'; import { MessageType } from '../types.js'; import { type CommandContext, @@ -29,11 +22,10 @@ async function listAction(context: CommandContext) { ); } -async function updateAction(context: CommandContext, args: string) { +function updateAction(context: CommandContext, args: string): Promise { const updateArgs = args.split(' ').filter((value) => value.length > 0); const all = updateArgs.length === 1 && updateArgs[0] === '--all'; - const names = all ? undefined : updateArgs; - let updateInfos: ExtensionUpdateInfo[] = []; + const names = all ? null : updateArgs; if (!all && names?.length === 0) { context.ui.addItem( @@ -43,32 +35,48 @@ async function updateAction(context: CommandContext, args: string) { }, Date.now(), ); - return; + return Promise.resolve(); } - try { - await checkForAllExtensionUpdates( - context.services.config!.getExtensions(), - context.ui.dispatchExtensionStateUpdate, + let resolveUpdateComplete: (updateInfo: ExtensionUpdateInfo[]) => void; + const updateComplete = new Promise( + (resolve) => (resolveUpdateComplete = resolve), + ); + updateComplete.then((updateInfos) => { + if (updateInfos.length === 0) { + context.ui.addItem( + { + type: MessageType.INFO, + text: 'No extensions to update.', + }, + Date.now(), + ); + } + context.ui.addItem( + { + type: MessageType.EXTENSIONS_LIST, + }, + Date.now(), ); + context.ui.setPendingItem(null); + }); + + try { context.ui.setPendingItem({ type: MessageType.EXTENSIONS_LIST, }); - if (all) { - updateInfos = await updateAllUpdatableExtensions( - context.services.config!.getWorkingDir(), - // We don't have the ability to prompt for consent yet in this flow. - (description) => - requestConsentInteractive( - description, - context.ui.addConfirmUpdateExtensionRequest, - ), - context.services.config!.getExtensions(), - context.ui.extensionsUpdateState, - context.ui.dispatchExtensionStateUpdate, - ); - } else if (names?.length) { - const workingDir = context.services.config!.getWorkingDir(); + + context.ui.dispatchExtensionStateUpdate({ + type: 'SCHEDULE_UPDATE', + payload: { + all, + names, + onComplete: (updateInfos) => { + resolveUpdateComplete(updateInfos); + }, + }, + }); + if (names?.length) { const extensions = context.services.config!.getExtensions(); for (const name of names) { const extension = extensions.find( @@ -84,33 +92,10 @@ async function updateAction(context: CommandContext, args: string) { ); continue; } - const updateInfo = await updateExtension( - extension, - workingDir, - (description) => - requestConsentInteractive( - description, - context.ui.addConfirmUpdateExtensionRequest, - ), - context.ui.extensionsUpdateState.get(extension.name)?.status ?? - ExtensionUpdateState.UNKNOWN, - context.ui.dispatchExtensionStateUpdate, - ); - if (updateInfo) updateInfos.push(updateInfo); } } - - if (updateInfos.length === 0) { - context.ui.addItem( - { - type: MessageType.INFO, - text: 'No extensions to update.', - }, - Date.now(), - ); - return; - } } catch (error) { + resolveUpdateComplete!([]); context.ui.addItem( { type: MessageType.ERROR, @@ -118,15 +103,8 @@ async function updateAction(context: CommandContext, args: string) { }, Date.now(), ); - } finally { - context.ui.addItem( - { - type: MessageType.EXTENSIONS_LIST, - }, - Date.now(), - ); - context.ui.setPendingItem(null); } + return updateComplete.then((_) => {}); } const listExtensionsCommand: SlashCommand = { diff --git a/packages/cli/src/ui/hooks/useExtensionUpdates.test.ts b/packages/cli/src/ui/hooks/useExtensionUpdates.test.ts index 56506f58ee..c7a1fd0cfd 100644 --- a/packages/cli/src/ui/hooks/useExtensionUpdates.test.ts +++ b/packages/cli/src/ui/hooks/useExtensionUpdates.test.ts @@ -76,7 +76,7 @@ describe('useExtensionUpdates', () => { const cwd = '/test/cwd'; vi.mocked(checkForAllExtensionUpdates).mockImplementation( - async (extensions, dispatch) => { + async (_extensions, dispatch, _cwd) => { dispatch({ type: 'SET_STATE', payload: { @@ -122,7 +122,7 @@ describe('useExtensionUpdates', () => { const addItem = vi.fn(); vi.mocked(checkForAllExtensionUpdates).mockImplementation( - async (extensions, dispatch) => { + async (_extensions, dispatch, _cwd) => { dispatch({ type: 'SET_STATE', payload: { @@ -195,7 +195,7 @@ describe('useExtensionUpdates', () => { const addItem = vi.fn(); vi.mocked(checkForAllExtensionUpdates).mockImplementation( - async (extensions, dispatch) => { + async (_extensions, dispatch, _cwd) => { dispatch({ type: 'SET_STATE', payload: { @@ -280,7 +280,7 @@ describe('useExtensionUpdates', () => { const cwd = '/test/cwd'; vi.mocked(checkForAllExtensionUpdates).mockImplementation( - async (extensions, dispatch) => { + async (_extensions, dispatch, _cwd) => { dispatch({ type: 'BATCH_CHECK_START' }); dispatch({ type: 'SET_STATE', diff --git a/packages/cli/src/ui/hooks/useExtensionUpdates.ts b/packages/cli/src/ui/hooks/useExtensionUpdates.ts index 2967fdb324..5908e298b1 100644 --- a/packages/cli/src/ui/hooks/useExtensionUpdates.ts +++ b/packages/cli/src/ui/hooks/useExtensionUpdates.ts @@ -18,7 +18,10 @@ import { checkForAllExtensionUpdates, updateExtension, } from '../../config/extensions/update.js'; -import { requestConsentInteractive } from '../../config/extension.js'; +import { + requestConsentInteractive, + type ExtensionUpdateInfo, +} from '../../config/extension.js'; import { checkExhaustive } from '../../utils/checks.js'; type ConfirmationRequestWrapper = { @@ -41,7 +44,6 @@ function confirmationRequestsReducer( return state.filter((r) => r !== action.request); default: checkExhaustive(action); - return state; } } @@ -80,40 +82,77 @@ export const useExtensionUpdates = ( ); useEffect(() => { - (async () => { - await checkForAllExtensionUpdates( - extensions, - dispatchExtensionStateUpdate, + const extensionsToCheck = extensions.filter((extension) => { + const currentStatus = extensionsUpdateState.extensionStatuses.get( + extension.name, ); - })(); - }, [extensions, extensions.length, dispatchExtensionStateUpdate]); + if (!currentStatus) return true; + const currentState = currentStatus.status; + return !currentState || currentState === ExtensionUpdateState.UNKNOWN; + }); + if (extensionsToCheck.length === 0) return; + checkForAllExtensionUpdates( + extensionsToCheck, + dispatchExtensionStateUpdate, + cwd, + ); + }, [ + extensions, + extensionsUpdateState.extensionStatuses, + cwd, + dispatchExtensionStateUpdate, + ]); useEffect(() => { if (extensionsUpdateState.batchChecksInProgress > 0) { return; } + const scheduledUpdate = extensionsUpdateState.scheduledUpdate; + if (scheduledUpdate) { + dispatchExtensionStateUpdate({ + type: 'CLEAR_SCHEDULED_UPDATE', + }); + } + + function shouldDoUpdate(extension: GeminiCLIExtension): boolean { + if (scheduledUpdate) { + if (scheduledUpdate.all) { + return true; + } + return scheduledUpdate.names?.includes(extension.name) === true; + } else { + return extension.installMetadata?.autoUpdate === true; + } + } let extensionsWithUpdatesCount = 0; + // We only notify if we have unprocessed extensions in the UPDATE_AVAILABLE + // state. + let shouldNotifyOfUpdates = false; + const updatePromises: Array> = []; for (const extension of extensions) { const currentState = extensionsUpdateState.extensionStatuses.get( extension.name, ); if ( !currentState || - currentState.processed || currentState.status !== ExtensionUpdateState.UPDATE_AVAILABLE ) { continue; } - - // Mark as processed immediately to avoid re-triggering. - dispatchExtensionStateUpdate({ - type: 'SET_PROCESSED', - payload: { name: extension.name, processed: true }, - }); - - if (extension.installMetadata?.autoUpdate) { - updateExtension( + const shouldUpdate = shouldDoUpdate(extension); + if (!shouldUpdate) { + extensionsWithUpdatesCount++; + if (!currentState.notified) { + // Mark as processed immediately to avoid re-triggering. + dispatchExtensionStateUpdate({ + type: 'SET_NOTIFIED', + payload: { name: extension.name, notified: true }, + }); + shouldNotifyOfUpdates = true; + } + } else { + const updatePromise = updateExtension( extension, cwd, (description) => @@ -123,7 +162,9 @@ export const useExtensionUpdates = ( ), currentState.status, dispatchExtensionStateUpdate, - ) + ); + updatePromises.push(updatePromise); + updatePromise .then((result) => { if (!result) return; addItem( @@ -143,11 +184,9 @@ export const useExtensionUpdates = ( Date.now(), ); }); - } else { - extensionsWithUpdatesCount++; } } - if (extensionsWithUpdatesCount > 0) { + if (shouldNotifyOfUpdates) { const s = extensionsWithUpdatesCount > 1 ? 's' : ''; addItem( { @@ -157,6 +196,18 @@ export const useExtensionUpdates = ( Date.now(), ); } + if (scheduledUpdate) { + Promise.all(updatePromises).then((results) => { + const nonNullResults = results.filter((result) => result != null); + scheduledUpdate.onCompleteCallbacks.forEach((callback) => { + try { + callback(nonNullResults); + } catch (e) { + console.error(getErrorMessage(e)); + } + }); + }); + } }, [ extensions, extensionsUpdateState, diff --git a/packages/cli/src/ui/state/extensions.ts b/packages/cli/src/ui/state/extensions.ts index a83745b5cb..49295f5c15 100644 --- a/packages/cli/src/ui/state/extensions.ts +++ b/packages/cli/src/ui/state/extensions.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { ExtensionUpdateInfo } from '../../config/extension.js'; import { checkExhaustive } from '../../utils/checks.js'; export enum ExtensionUpdateState { @@ -19,17 +20,34 @@ export enum ExtensionUpdateState { export interface ExtensionUpdateStatus { status: ExtensionUpdateState; - processed: boolean; + notified: boolean; } export interface ExtensionUpdatesState { extensionStatuses: Map; batchChecksInProgress: number; + // Explicitly scheduled updates. + scheduledUpdate: ScheduledUpdate | null; } +export interface ScheduledUpdate { + names: string[] | null; + all: boolean; + onCompleteCallbacks: OnCompleteUpdate[]; +} + +export interface ScheduleUpdateArgs { + names: string[] | null; + all: boolean; + onComplete: OnCompleteUpdate; +} + +type OnCompleteUpdate = (updateInfos: ExtensionUpdateInfo[]) => void; + export const initialExtensionUpdatesState: ExtensionUpdatesState = { extensionStatuses: new Map(), batchChecksInProgress: 0, + scheduledUpdate: null, }; export type ExtensionUpdateAction = @@ -38,11 +56,13 @@ export type ExtensionUpdateAction = payload: { name: string; state: ExtensionUpdateState }; } | { - type: 'SET_PROCESSED'; - payload: { name: string; processed: boolean }; + type: 'SET_NOTIFIED'; + payload: { name: string; notified: boolean }; } | { type: 'BATCH_CHECK_START' } - | { type: 'BATCH_CHECK_END' }; + | { type: 'BATCH_CHECK_END' } + | { type: 'SCHEDULE_UPDATE'; payload: ScheduleUpdateArgs } + | { type: 'CLEAR_SCHEDULED_UPDATE' }; export function extensionUpdatesReducer( state: ExtensionUpdatesState, @@ -57,19 +77,19 @@ export function extensionUpdatesReducer( const newStatuses = new Map(state.extensionStatuses); newStatuses.set(action.payload.name, { status: action.payload.state, - processed: false, + notified: false, }); return { ...state, extensionStatuses: newStatuses }; } - case 'SET_PROCESSED': { + case 'SET_NOTIFIED': { const existing = state.extensionStatuses.get(action.payload.name); - if (!existing || existing.processed === action.payload.processed) { + if (!existing || existing.notified === action.payload.notified) { return state; } const newStatuses = new Map(state.extensionStatuses); newStatuses.set(action.payload.name, { ...existing, - processed: action.payload.processed, + notified: action.payload.notified, }); return { ...state, extensionStatuses: newStatuses }; } @@ -83,6 +103,27 @@ export function extensionUpdatesReducer( ...state, batchChecksInProgress: state.batchChecksInProgress - 1, }; + case 'SCHEDULE_UPDATE': + return { + ...state, + // If there is a pre-existing scheduled update, we merge them. + scheduledUpdate: { + all: state.scheduledUpdate?.all || action.payload.all, + names: [ + ...(state.scheduledUpdate?.names ?? []), + ...(action.payload.names ?? []), + ], + onCompleteCallbacks: [ + ...(state.scheduledUpdate?.onCompleteCallbacks ?? []), + action.payload.onComplete, + ], + }, + }; + case 'CLEAR_SCHEDULED_UPDATE': + return { + ...state, + scheduledUpdate: null, + }; default: checkExhaustive(action); }