Cleanup extension update logic (#10514)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Jacob MacDonald
2025-10-03 21:06:26 -07:00
committed by GitHub
parent 1a06282061
commit 7f8537a130
10 changed files with 310 additions and 292 deletions
@@ -53,16 +53,13 @@ export async function handleUpdate(args: UpdateArgs) {
console.log(`Extension "${args.name}" not found.`); console.log(`Extension "${args.name}" not found.`);
return; return;
} }
let updateState: ExtensionUpdateState | undefined;
if (!extension.installMetadata) { if (!extension.installMetadata) {
console.log( console.log(
`Unable to install extension "${args.name}" due to missing install metadata`, `Unable to install extension "${args.name}" due to missing install metadata`,
); );
return; return;
} }
await checkForExtensionUpdate(extension, (newState) => { const updateState = await checkForExtensionUpdate(extension);
updateState = newState;
});
if (updateState !== ExtensionUpdateState.UPDATE_AVAILABLE) { if (updateState !== ExtensionUpdateState.UPDATE_AVAILABLE) {
console.log(`Extension "${args.name}" is already up to date.`); console.log(`Extension "${args.name}" is already up to date.`);
return; return;
@@ -92,14 +89,17 @@ export async function handleUpdate(args: UpdateArgs) {
if (args.all) { if (args.all) {
try { try {
const extensionState = new Map(); const extensionState = new Map();
await checkForAllExtensionUpdates(extensions, (action) => { await checkForAllExtensionUpdates(
extensions,
(action) => {
if (action.type === 'SET_STATE') { if (action.type === 'SET_STATE') {
extensionState.set(action.payload.name, { extensionState.set(action.payload.name, {
status: action.payload.state, status: action.payload.state,
processed: true, // No need to process as we will force the update.
}); });
} }
}); },
workingDir,
);
let updateInfos = await updateAllUpdatableExtensions( let updateInfos = await updateAllUpdatableExtensions(
workingDir, workingDir,
requestConsentNonInteractive, requestConsentNonInteractive,
@@ -132,11 +132,7 @@ describe('git extension helpers', () => {
source: '', source: '',
}, },
}; };
let result: ExtensionUpdateState | undefined = undefined; const result = await checkForExtensionUpdate(extension);
await checkForExtensionUpdate(
extension,
(newState) => (result = newState),
);
expect(result).toBe(ExtensionUpdateState.NOT_UPDATABLE); expect(result).toBe(ExtensionUpdateState.NOT_UPDATABLE);
}); });
@@ -152,11 +148,7 @@ describe('git extension helpers', () => {
}, },
}; };
mockGit.getRemotes.mockResolvedValue([]); mockGit.getRemotes.mockResolvedValue([]);
let result: ExtensionUpdateState | undefined = undefined; const result = await checkForExtensionUpdate(extension);
await checkForExtensionUpdate(
extension,
(newState) => (result = newState),
);
expect(result).toBe(ExtensionUpdateState.ERROR); expect(result).toBe(ExtensionUpdateState.ERROR);
}); });
@@ -177,11 +169,7 @@ describe('git extension helpers', () => {
mockGit.listRemote.mockResolvedValue('remote-hash\tHEAD'); mockGit.listRemote.mockResolvedValue('remote-hash\tHEAD');
mockGit.revparse.mockResolvedValue('local-hash'); mockGit.revparse.mockResolvedValue('local-hash');
let result: ExtensionUpdateState | undefined = undefined; const result = await checkForExtensionUpdate(extension);
await checkForExtensionUpdate(
extension,
(newState) => (result = newState),
);
expect(result).toBe(ExtensionUpdateState.UPDATE_AVAILABLE); expect(result).toBe(ExtensionUpdateState.UPDATE_AVAILABLE);
}); });
@@ -202,11 +190,7 @@ describe('git extension helpers', () => {
mockGit.listRemote.mockResolvedValue('same-hash\tHEAD'); mockGit.listRemote.mockResolvedValue('same-hash\tHEAD');
mockGit.revparse.mockResolvedValue('same-hash'); mockGit.revparse.mockResolvedValue('same-hash');
let result: ExtensionUpdateState | undefined = undefined; const result = await checkForExtensionUpdate(extension);
await checkForExtensionUpdate(
extension,
(newState) => (result = newState),
);
expect(result).toBe(ExtensionUpdateState.UP_TO_DATE); expect(result).toBe(ExtensionUpdateState.UP_TO_DATE);
}); });
@@ -223,11 +207,7 @@ describe('git extension helpers', () => {
}; };
mockGit.getRemotes.mockRejectedValue(new Error('git error')); mockGit.getRemotes.mockRejectedValue(new Error('git error'));
let result: ExtensionUpdateState | undefined = undefined; const result = await checkForExtensionUpdate(extension);
await checkForExtensionUpdate(
extension,
(newState) => (result = newState),
);
expect(result).toBe(ExtensionUpdateState.ERROR); expect(result).toBe(ExtensionUpdateState.ERROR);
}); });
}); });
+15 -31
View File
@@ -119,10 +119,8 @@ async function fetchReleaseFromGithub(
export async function checkForExtensionUpdate( export async function checkForExtensionUpdate(
extension: GeminiCLIExtension, extension: GeminiCLIExtension,
setExtensionUpdateState: (updateState: ExtensionUpdateState) => void,
cwd: string = process.cwd(), cwd: string = process.cwd(),
): Promise<void> { ): Promise<ExtensionUpdateState> {
setExtensionUpdateState(ExtensionUpdateState.CHECKING_FOR_UPDATES);
const installMetadata = extension.installMetadata; const installMetadata = extension.installMetadata;
if (installMetadata?.type === 'local') { if (installMetadata?.type === 'local') {
const newExtension = loadExtension({ const newExtension = loadExtension({
@@ -133,23 +131,19 @@ export async function checkForExtensionUpdate(
console.error( console.error(
`Failed to check for update for local extension "${extension.name}". Could not load extension from source path: ${installMetadata.source}`, `Failed to check for update for local extension "${extension.name}". Could not load extension from source path: ${installMetadata.source}`,
); );
setExtensionUpdateState(ExtensionUpdateState.ERROR); return ExtensionUpdateState.ERROR;
return;
} }
if (newExtension.config.version !== extension.version) { if (newExtension.config.version !== extension.version) {
setExtensionUpdateState(ExtensionUpdateState.UPDATE_AVAILABLE); return ExtensionUpdateState.UPDATE_AVAILABLE;
return;
} }
setExtensionUpdateState(ExtensionUpdateState.UP_TO_DATE); return ExtensionUpdateState.UP_TO_DATE;
return;
} }
if ( if (
!installMetadata || !installMetadata ||
(installMetadata.type !== 'git' && (installMetadata.type !== 'git' &&
installMetadata.type !== 'github-release') installMetadata.type !== 'github-release')
) { ) {
setExtensionUpdateState(ExtensionUpdateState.NOT_UPDATABLE); return ExtensionUpdateState.NOT_UPDATABLE;
return;
} }
try { try {
if (installMetadata.type === 'git') { if (installMetadata.type === 'git') {
@@ -157,14 +151,12 @@ export async function checkForExtensionUpdate(
const remotes = await git.getRemotes(true); const remotes = await git.getRemotes(true);
if (remotes.length === 0) { if (remotes.length === 0) {
console.error('No git remotes found.'); console.error('No git remotes found.');
setExtensionUpdateState(ExtensionUpdateState.ERROR); return ExtensionUpdateState.ERROR;
return;
} }
const remoteUrl = remotes[0].refs.fetch; const remoteUrl = remotes[0].refs.fetch;
if (!remoteUrl) { if (!remoteUrl) {
console.error(`No fetch URL found for git remote ${remotes[0].name}.`); console.error(`No fetch URL found for git remote ${remotes[0].name}.`);
setExtensionUpdateState(ExtensionUpdateState.ERROR); return ExtensionUpdateState.ERROR;
return;
} }
// Determine the ref to check on the remote. // Determine the ref to check on the remote.
@@ -174,8 +166,7 @@ export async function checkForExtensionUpdate(
if (typeof lsRemoteOutput !== 'string' || lsRemoteOutput.trim() === '') { if (typeof lsRemoteOutput !== 'string' || lsRemoteOutput.trim() === '') {
console.error(`Git ref ${refToCheck} not found.`); console.error(`Git ref ${refToCheck} not found.`);
setExtensionUpdateState(ExtensionUpdateState.ERROR); return ExtensionUpdateState.ERROR;
return;
} }
const remoteHash = lsRemoteOutput.split('\t')[0]; const remoteHash = lsRemoteOutput.split('\t')[0];
@@ -185,21 +176,17 @@ export async function checkForExtensionUpdate(
console.error( console.error(
`Unable to parse hash from git ls-remote output "${lsRemoteOutput}"`, `Unable to parse hash from git ls-remote output "${lsRemoteOutput}"`,
); );
setExtensionUpdateState(ExtensionUpdateState.ERROR); return ExtensionUpdateState.ERROR;
return;
} }
if (remoteHash === localHash) { if (remoteHash === localHash) {
setExtensionUpdateState(ExtensionUpdateState.UP_TO_DATE); return ExtensionUpdateState.UP_TO_DATE;
return;
} }
setExtensionUpdateState(ExtensionUpdateState.UPDATE_AVAILABLE); return ExtensionUpdateState.UPDATE_AVAILABLE;
return;
} else { } else {
const { source, releaseTag } = installMetadata; const { source, releaseTag } = installMetadata;
if (!source) { if (!source) {
console.error(`No "source" provided for extension.`); console.error(`No "source" provided for extension.`);
setExtensionUpdateState(ExtensionUpdateState.ERROR); return ExtensionUpdateState.ERROR;
return;
} }
const { owner, repo } = parseGitHubRepoForReleases(source); const { owner, repo } = parseGitHubRepoForReleases(source);
@@ -209,18 +196,15 @@ export async function checkForExtensionUpdate(
installMetadata.ref, installMetadata.ref,
); );
if (releaseData.tag_name !== releaseTag) { if (releaseData.tag_name !== releaseTag) {
setExtensionUpdateState(ExtensionUpdateState.UPDATE_AVAILABLE); return ExtensionUpdateState.UPDATE_AVAILABLE;
return;
} }
setExtensionUpdateState(ExtensionUpdateState.UP_TO_DATE); return ExtensionUpdateState.UP_TO_DATE;
return;
} }
} catch (error) { } catch (error) {
console.error( console.error(
`Failed to check for updates for extension "${installMetadata.source}": ${getErrorMessage(error)}`, `Failed to check for updates for extension "${installMetadata.source}": ${getErrorMessage(error)}`,
); );
setExtensionUpdateState(ExtensionUpdateState.ERROR); return ExtensionUpdateState.ERROR;
return;
} }
} }
export interface GitHubDownloadResult { export interface GitHubDownloadResult {
@@ -302,7 +302,11 @@ describe('update tests', () => {
mockGit.revparse.mockResolvedValue('localHash'); mockGit.revparse.mockResolvedValue('localHash');
const dispatch = vi.fn(); const dispatch = vi.fn();
await checkForAllExtensionUpdates([extension], dispatch); await checkForAllExtensionUpdates(
[extension],
dispatch,
tempWorkspaceDir,
);
expect(dispatch).toHaveBeenCalledWith({ expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE', type: 'SET_STATE',
payload: { payload: {
@@ -340,7 +344,11 @@ describe('update tests', () => {
mockGit.revparse.mockResolvedValue('sameHash'); mockGit.revparse.mockResolvedValue('sameHash');
const dispatch = vi.fn(); const dispatch = vi.fn();
await checkForAllExtensionUpdates([extension], dispatch); await checkForAllExtensionUpdates(
[extension],
dispatch,
tempWorkspaceDir,
);
expect(dispatch).toHaveBeenCalledWith({ expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE', type: 'SET_STATE',
payload: { payload: {
@@ -375,7 +383,11 @@ describe('update tests', () => {
new ExtensionEnablementManager(ExtensionStorage.getUserExtensionsDir()), new ExtensionEnablementManager(ExtensionStorage.getUserExtensionsDir()),
)[0]; )[0];
const dispatch = vi.fn(); const dispatch = vi.fn();
await checkForAllExtensionUpdates([extension], dispatch); await checkForAllExtensionUpdates(
[extension],
dispatch,
tempWorkspaceDir,
);
expect(dispatch).toHaveBeenCalledWith({ expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE', type: 'SET_STATE',
payload: { payload: {
@@ -410,7 +422,11 @@ describe('update tests', () => {
new ExtensionEnablementManager(ExtensionStorage.getUserExtensionsDir()), new ExtensionEnablementManager(ExtensionStorage.getUserExtensionsDir()),
)[0]; )[0];
const dispatch = vi.fn(); const dispatch = vi.fn();
await checkForAllExtensionUpdates([extension], dispatch); await checkForAllExtensionUpdates(
[extension],
dispatch,
tempWorkspaceDir,
);
expect(dispatch).toHaveBeenCalledWith({ expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE', type: 'SET_STATE',
payload: { payload: {
@@ -444,7 +460,11 @@ describe('update tests', () => {
mockGit.getRemotes.mockRejectedValue(new Error('Git error')); mockGit.getRemotes.mockRejectedValue(new Error('Git error'));
const dispatch = vi.fn(); const dispatch = vi.fn();
await checkForAllExtensionUpdates([extension], dispatch); await checkForAllExtensionUpdates(
[extension],
dispatch,
tempWorkspaceDir,
);
expect(dispatch).toHaveBeenCalledWith({ expect(dispatch).toHaveBeenCalledWith({
type: 'SET_STATE', type: 'SET_STATE',
payload: { payload: {
+11 -3
View File
@@ -154,6 +154,7 @@ export interface ExtensionUpdateCheckResult {
export async function checkForAllExtensionUpdates( export async function checkForAllExtensionUpdates(
extensions: GeminiCLIExtension[], extensions: GeminiCLIExtension[],
dispatch: (action: ExtensionUpdateAction) => void, dispatch: (action: ExtensionUpdateAction) => void,
cwd: string = process.cwd(),
): Promise<void> { ): Promise<void> {
dispatch({ type: 'BATCH_CHECK_START' }); dispatch({ type: 'BATCH_CHECK_START' });
const promises: Array<Promise<void>> = []; const promises: Array<Promise<void>> = [];
@@ -168,13 +169,20 @@ export async function checkForAllExtensionUpdates(
}); });
continue; continue;
} }
promises.push(
checkForExtensionUpdate(extension, (updatedState) => {
dispatch({ dispatch({
type: 'SET_STATE', type: 'SET_STATE',
payload: { name: extension.name, state: updatedState }, payload: {
name: extension.name,
state: ExtensionUpdateState.CHECKING_FOR_UPDATES,
},
}); });
promises.push(
checkForExtensionUpdate(extension, cwd).then((state) =>
dispatch({
type: 'SET_STATE',
payload: { name: extension.name, state },
}), }),
),
); );
} }
await Promise.all(promises); await Promise.all(promises);
@@ -5,43 +5,23 @@
*/ */
import type { GeminiCLIExtension } from '@google/gemini-cli-core'; import type { GeminiCLIExtension } from '@google/gemini-cli-core';
import {
updateAllUpdatableExtensions,
updateExtension,
} from '../../config/extensions/update.js';
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
import { MessageType } from '../types.js'; import { MessageType } from '../types.js';
import { extensionsCommand } from './extensionsCommand.js'; import { extensionsCommand } from './extensionsCommand.js';
import { type CommandContext } from './types.js'; import { type CommandContext } from './types.js';
import { import { describe, it, expect, vi, beforeEach } from 'vitest';
describe, import { type ExtensionUpdateAction } from '../state/extensions.js';
it,
expect,
vi,
beforeEach,
type MockedFunction,
} from 'vitest';
import { ExtensionUpdateState } from '../state/extensions.js';
vi.mock('../../config/extensions/update.js', () => ({ vi.mock('../../config/extensions/update.js', () => ({
updateExtension: vi.fn(), updateExtension: vi.fn(),
updateAllUpdatableExtensions: vi.fn(),
checkForAllExtensionUpdates: vi.fn(), checkForAllExtensionUpdates: vi.fn(),
})); }));
const mockUpdateExtension = updateExtension as MockedFunction<
typeof updateExtension
>;
const mockUpdateAllUpdatableExtensions =
updateAllUpdatableExtensions as MockedFunction<
typeof updateAllUpdatableExtensions
>;
const mockGetExtensions = vi.fn(); const mockGetExtensions = vi.fn();
describe('extensionsCommand', () => { describe('extensionsCommand', () => {
let mockContext: CommandContext; let mockContext: CommandContext;
const mockDispatchExtensionState = vi.fn();
beforeEach(() => { beforeEach(() => {
vi.resetAllMocks(); vi.resetAllMocks();
@@ -53,7 +33,7 @@ describe('extensionsCommand', () => {
}, },
}, },
ui: { ui: {
dispatchExtensionStateUpdate: vi.fn(), dispatchExtensionStateUpdate: mockDispatchExtensionState,
}, },
}); });
}); });
@@ -93,7 +73,14 @@ describe('extensionsCommand', () => {
}); });
it('should inform user if there are no extensions to update with --all', async () => { it('should inform user if there are no extensions to update with --all', async () => {
mockUpdateAllUpdatableExtensions.mockResolvedValue([]); mockDispatchExtensionState.mockImplementationOnce(
(action: ExtensionUpdateAction) => {
if (action.type === 'SCHEDULE_UPDATE') {
action.payload.onComplete([]);
}
},
);
await updateAction(mockContext, '--all'); await updateAction(mockContext, '--all');
expect(mockContext.ui.addItem).toHaveBeenCalledWith( expect(mockContext.ui.addItem).toHaveBeenCalledWith(
{ {
@@ -105,7 +92,10 @@ describe('extensionsCommand', () => {
}); });
it('should call setPendingItem and addItem in a finally block on success', async () => { it('should call setPendingItem and addItem in a finally block on success', async () => {
mockUpdateAllUpdatableExtensions.mockResolvedValue([ mockDispatchExtensionState.mockImplementationOnce(
(action: ExtensionUpdateAction) => {
if (action.type === 'SCHEDULE_UPDATE') {
action.payload.onComplete([
{ {
name: 'ext-one', name: 'ext-one',
originalVersion: '1.0.0', originalVersion: '1.0.0',
@@ -117,6 +107,9 @@ describe('extensionsCommand', () => {
updatedVersion: '2.0.1', updatedVersion: '2.0.1',
}, },
]); ]);
}
},
);
await updateAction(mockContext, '--all'); await updateAction(mockContext, '--all');
expect(mockContext.ui.setPendingItem).toHaveBeenCalledWith({ expect(mockContext.ui.setPendingItem).toHaveBeenCalledWith({
type: MessageType.EXTENSIONS_LIST, type: MessageType.EXTENSIONS_LIST,
@@ -131,9 +124,9 @@ describe('extensionsCommand', () => {
}); });
it('should call setPendingItem and addItem in a finally block on failure', async () => { it('should call setPendingItem and addItem in a finally block on failure', async () => {
mockUpdateAllUpdatableExtensions.mockRejectedValue( mockDispatchExtensionState.mockImplementationOnce((_) => {
new Error('Something went wrong'), throw new Error('Something went wrong');
); });
await updateAction(mockContext, '--all'); await updateAction(mockContext, '--all');
expect(mockContext.ui.setPendingItem).toHaveBeenCalledWith({ expect(mockContext.ui.setPendingItem).toHaveBeenCalledWith({
type: MessageType.EXTENSIONS_LIST, type: MessageType.EXTENSIONS_LIST,
@@ -155,95 +148,58 @@ describe('extensionsCommand', () => {
}); });
it('should update a single extension by name', async () => { it('should update a single extension by name', async () => {
const extension: GeminiCLIExtension = { mockDispatchExtensionState.mockImplementationOnce(
name: 'ext-one', (action: ExtensionUpdateAction) => {
version: '1.0.0', if (action.type === 'SCHEDULE_UPDATE') {
isActive: true, action.payload.onComplete([
path: '/test/dir/ext-one',
installMetadata: {
type: 'git',
autoUpdate: false,
source: 'https://github.com/some/extension.git',
},
};
mockUpdateExtension.mockResolvedValue({
name: extension.name,
originalVersion: extension.version,
updatedVersion: '1.0.1',
});
mockGetExtensions.mockReturnValue([extension]);
mockContext.ui.extensionsUpdateState.set(extension.name, {
status: ExtensionUpdateState.UPDATE_AVAILABLE,
processed: false,
});
await updateAction(mockContext, 'ext-one');
expect(mockUpdateExtension).toHaveBeenCalledWith(
extension,
'/test/dir',
expect.any(Function),
ExtensionUpdateState.UPDATE_AVAILABLE,
expect.any(Function),
);
});
it('should handle errors when updating a single extension', async () => {
mockUpdateExtension.mockRejectedValue(new Error('Extension not found'));
mockGetExtensions.mockReturnValue([]);
await updateAction(mockContext, 'ext-one');
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
{ {
type: MessageType.ERROR,
text: 'Extension ext-one not found.',
},
expect.any(Number),
);
});
it('should update multiple extensions by name', async () => {
const extensionOne: GeminiCLIExtension = {
name: 'ext-one',
version: '1.0.0',
isActive: true,
path: '/test/dir/ext-one',
installMetadata: {
type: 'git',
autoUpdate: false,
source: 'https://github.com/some/extension.git',
},
};
const extensionTwo: GeminiCLIExtension = {
name: 'ext-two',
version: '1.0.0',
isActive: true,
path: '/test/dir/ext-two',
installMetadata: {
type: 'git',
autoUpdate: false,
source: 'https://github.com/some/extension.git',
},
};
mockGetExtensions.mockReturnValue([extensionOne, extensionTwo]);
mockContext.ui.extensionsUpdateState.set(
extensionOne.name,
ExtensionUpdateState.UPDATE_AVAILABLE,
);
mockContext.ui.extensionsUpdateState.set(
extensionTwo.name,
ExtensionUpdateState.UPDATE_AVAILABLE,
);
mockUpdateExtension
.mockResolvedValueOnce({
name: 'ext-one', name: 'ext-one',
originalVersion: '1.0.0', originalVersion: '1.0.0',
updatedVersion: '1.0.1', updatedVersion: '1.0.1',
}) },
.mockResolvedValueOnce({ ]);
name: 'ext-two', }
originalVersion: '2.0.0', },
updatedVersion: '2.0.1', );
await updateAction(mockContext, 'ext-one');
expect(mockDispatchExtensionState).toHaveBeenCalledWith({
type: 'SCHEDULE_UPDATE',
payload: {
all: false,
names: ['ext-one'],
onComplete: expect.any(Function),
},
}); });
});
it('should update multiple extensions by name', async () => {
mockDispatchExtensionState.mockImplementationOnce(
(action: ExtensionUpdateAction) => {
if (action.type === 'SCHEDULE_UPDATE') {
action.payload.onComplete([
{
name: 'ext-one',
originalVersion: '1.0.0',
updatedVersion: '1.0.1',
},
{
name: 'ext-two',
originalVersion: '1.0.0',
updatedVersion: '1.0.1',
},
]);
}
},
);
await updateAction(mockContext, 'ext-one ext-two'); await updateAction(mockContext, 'ext-one ext-two');
expect(mockUpdateExtension).toHaveBeenCalledTimes(2); expect(mockDispatchExtensionState).toHaveBeenCalledWith({
type: 'SCHEDULE_UPDATE',
payload: {
all: false,
names: ['ext-one', 'ext-two'],
onComplete: expect.any(Function),
},
});
expect(mockContext.ui.setPendingItem).toHaveBeenCalledWith({ expect(mockContext.ui.setPendingItem).toHaveBeenCalledWith({
type: MessageType.EXTENSIONS_LIST, type: MessageType.EXTENSIONS_LIST,
}); });
@@ -4,15 +4,8 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { requestConsentInteractive } from '../../config/extension.js'; import type { ExtensionUpdateInfo } from '../../config/extension.js';
import {
updateAllUpdatableExtensions,
type ExtensionUpdateInfo,
updateExtension,
checkForAllExtensionUpdates,
} from '../../config/extensions/update.js';
import { getErrorMessage } from '../../utils/errors.js'; import { getErrorMessage } from '../../utils/errors.js';
import { ExtensionUpdateState } from '../state/extensions.js';
import { MessageType } from '../types.js'; import { MessageType } from '../types.js';
import { import {
type CommandContext, type CommandContext,
@@ -29,11 +22,10 @@ async function listAction(context: CommandContext) {
); );
} }
async function updateAction(context: CommandContext, args: string) { function updateAction(context: CommandContext, args: string): Promise<void> {
const updateArgs = args.split(' ').filter((value) => value.length > 0); const updateArgs = args.split(' ').filter((value) => value.length > 0);
const all = updateArgs.length === 1 && updateArgs[0] === '--all'; const all = updateArgs.length === 1 && updateArgs[0] === '--all';
const names = all ? undefined : updateArgs; const names = all ? null : updateArgs;
let updateInfos: ExtensionUpdateInfo[] = [];
if (!all && names?.length === 0) { if (!all && names?.length === 0) {
context.ui.addItem( context.ui.addItem(
@@ -43,32 +35,48 @@ async function updateAction(context: CommandContext, args: string) {
}, },
Date.now(), Date.now(),
); );
return; return Promise.resolve();
} }
try { let resolveUpdateComplete: (updateInfo: ExtensionUpdateInfo[]) => void;
await checkForAllExtensionUpdates( const updateComplete = new Promise<ExtensionUpdateInfo[]>(
context.services.config!.getExtensions(), (resolve) => (resolveUpdateComplete = resolve),
context.ui.dispatchExtensionStateUpdate,
); );
updateComplete.then((updateInfos) => {
if (updateInfos.length === 0) {
context.ui.addItem(
{
type: MessageType.INFO,
text: 'No extensions to update.',
},
Date.now(),
);
}
context.ui.addItem(
{
type: MessageType.EXTENSIONS_LIST,
},
Date.now(),
);
context.ui.setPendingItem(null);
});
try {
context.ui.setPendingItem({ context.ui.setPendingItem({
type: MessageType.EXTENSIONS_LIST, type: MessageType.EXTENSIONS_LIST,
}); });
if (all) {
updateInfos = await updateAllUpdatableExtensions( context.ui.dispatchExtensionStateUpdate({
context.services.config!.getWorkingDir(), type: 'SCHEDULE_UPDATE',
// We don't have the ability to prompt for consent yet in this flow. payload: {
(description) => all,
requestConsentInteractive( names,
description, onComplete: (updateInfos) => {
context.ui.addConfirmUpdateExtensionRequest, resolveUpdateComplete(updateInfos);
), },
context.services.config!.getExtensions(), },
context.ui.extensionsUpdateState, });
context.ui.dispatchExtensionStateUpdate, if (names?.length) {
);
} else if (names?.length) {
const workingDir = context.services.config!.getWorkingDir();
const extensions = context.services.config!.getExtensions(); const extensions = context.services.config!.getExtensions();
for (const name of names) { for (const name of names) {
const extension = extensions.find( const extension = extensions.find(
@@ -84,33 +92,10 @@ async function updateAction(context: CommandContext, args: string) {
); );
continue; continue;
} }
const updateInfo = await updateExtension(
extension,
workingDir,
(description) =>
requestConsentInteractive(
description,
context.ui.addConfirmUpdateExtensionRequest,
),
context.ui.extensionsUpdateState.get(extension.name)?.status ??
ExtensionUpdateState.UNKNOWN,
context.ui.dispatchExtensionStateUpdate,
);
if (updateInfo) updateInfos.push(updateInfo);
} }
} }
if (updateInfos.length === 0) {
context.ui.addItem(
{
type: MessageType.INFO,
text: 'No extensions to update.',
},
Date.now(),
);
return;
}
} catch (error) { } catch (error) {
resolveUpdateComplete!([]);
context.ui.addItem( context.ui.addItem(
{ {
type: MessageType.ERROR, type: MessageType.ERROR,
@@ -118,15 +103,8 @@ async function updateAction(context: CommandContext, args: string) {
}, },
Date.now(), Date.now(),
); );
} finally {
context.ui.addItem(
{
type: MessageType.EXTENSIONS_LIST,
},
Date.now(),
);
context.ui.setPendingItem(null);
} }
return updateComplete.then((_) => {});
} }
const listExtensionsCommand: SlashCommand = { const listExtensionsCommand: SlashCommand = {
@@ -76,7 +76,7 @@ describe('useExtensionUpdates', () => {
const cwd = '/test/cwd'; const cwd = '/test/cwd';
vi.mocked(checkForAllExtensionUpdates).mockImplementation( vi.mocked(checkForAllExtensionUpdates).mockImplementation(
async (extensions, dispatch) => { async (_extensions, dispatch, _cwd) => {
dispatch({ dispatch({
type: 'SET_STATE', type: 'SET_STATE',
payload: { payload: {
@@ -122,7 +122,7 @@ describe('useExtensionUpdates', () => {
const addItem = vi.fn(); const addItem = vi.fn();
vi.mocked(checkForAllExtensionUpdates).mockImplementation( vi.mocked(checkForAllExtensionUpdates).mockImplementation(
async (extensions, dispatch) => { async (_extensions, dispatch, _cwd) => {
dispatch({ dispatch({
type: 'SET_STATE', type: 'SET_STATE',
payload: { payload: {
@@ -195,7 +195,7 @@ describe('useExtensionUpdates', () => {
const addItem = vi.fn(); const addItem = vi.fn();
vi.mocked(checkForAllExtensionUpdates).mockImplementation( vi.mocked(checkForAllExtensionUpdates).mockImplementation(
async (extensions, dispatch) => { async (_extensions, dispatch, _cwd) => {
dispatch({ dispatch({
type: 'SET_STATE', type: 'SET_STATE',
payload: { payload: {
@@ -280,7 +280,7 @@ describe('useExtensionUpdates', () => {
const cwd = '/test/cwd'; const cwd = '/test/cwd';
vi.mocked(checkForAllExtensionUpdates).mockImplementation( vi.mocked(checkForAllExtensionUpdates).mockImplementation(
async (extensions, dispatch) => { async (_extensions, dispatch, _cwd) => {
dispatch({ type: 'BATCH_CHECK_START' }); dispatch({ type: 'BATCH_CHECK_START' });
dispatch({ dispatch({
type: 'SET_STATE', type: 'SET_STATE',
@@ -18,7 +18,10 @@ import {
checkForAllExtensionUpdates, checkForAllExtensionUpdates,
updateExtension, updateExtension,
} from '../../config/extensions/update.js'; } from '../../config/extensions/update.js';
import { requestConsentInteractive } from '../../config/extension.js'; import {
requestConsentInteractive,
type ExtensionUpdateInfo,
} from '../../config/extension.js';
import { checkExhaustive } from '../../utils/checks.js'; import { checkExhaustive } from '../../utils/checks.js';
type ConfirmationRequestWrapper = { type ConfirmationRequestWrapper = {
@@ -41,7 +44,6 @@ function confirmationRequestsReducer(
return state.filter((r) => r !== action.request); return state.filter((r) => r !== action.request);
default: default:
checkExhaustive(action); checkExhaustive(action);
return state;
} }
} }
@@ -80,40 +82,77 @@ export const useExtensionUpdates = (
); );
useEffect(() => { useEffect(() => {
(async () => { const extensionsToCheck = extensions.filter((extension) => {
await checkForAllExtensionUpdates( const currentStatus = extensionsUpdateState.extensionStatuses.get(
extensions, extension.name,
dispatchExtensionStateUpdate,
); );
})(); if (!currentStatus) return true;
}, [extensions, extensions.length, dispatchExtensionStateUpdate]); const currentState = currentStatus.status;
return !currentState || currentState === ExtensionUpdateState.UNKNOWN;
});
if (extensionsToCheck.length === 0) return;
checkForAllExtensionUpdates(
extensionsToCheck,
dispatchExtensionStateUpdate,
cwd,
);
}, [
extensions,
extensionsUpdateState.extensionStatuses,
cwd,
dispatchExtensionStateUpdate,
]);
useEffect(() => { useEffect(() => {
if (extensionsUpdateState.batchChecksInProgress > 0) { if (extensionsUpdateState.batchChecksInProgress > 0) {
return; return;
} }
const scheduledUpdate = extensionsUpdateState.scheduledUpdate;
if (scheduledUpdate) {
dispatchExtensionStateUpdate({
type: 'CLEAR_SCHEDULED_UPDATE',
});
}
function shouldDoUpdate(extension: GeminiCLIExtension): boolean {
if (scheduledUpdate) {
if (scheduledUpdate.all) {
return true;
}
return scheduledUpdate.names?.includes(extension.name) === true;
} else {
return extension.installMetadata?.autoUpdate === true;
}
}
let extensionsWithUpdatesCount = 0; let extensionsWithUpdatesCount = 0;
// We only notify if we have unprocessed extensions in the UPDATE_AVAILABLE
// state.
let shouldNotifyOfUpdates = false;
const updatePromises: Array<Promise<ExtensionUpdateInfo | undefined>> = [];
for (const extension of extensions) { for (const extension of extensions) {
const currentState = extensionsUpdateState.extensionStatuses.get( const currentState = extensionsUpdateState.extensionStatuses.get(
extension.name, extension.name,
); );
if ( if (
!currentState || !currentState ||
currentState.processed ||
currentState.status !== ExtensionUpdateState.UPDATE_AVAILABLE currentState.status !== ExtensionUpdateState.UPDATE_AVAILABLE
) { ) {
continue; continue;
} }
const shouldUpdate = shouldDoUpdate(extension);
if (!shouldUpdate) {
extensionsWithUpdatesCount++;
if (!currentState.notified) {
// Mark as processed immediately to avoid re-triggering. // Mark as processed immediately to avoid re-triggering.
dispatchExtensionStateUpdate({ dispatchExtensionStateUpdate({
type: 'SET_PROCESSED', type: 'SET_NOTIFIED',
payload: { name: extension.name, processed: true }, payload: { name: extension.name, notified: true },
}); });
shouldNotifyOfUpdates = true;
if (extension.installMetadata?.autoUpdate) { }
updateExtension( } else {
const updatePromise = updateExtension(
extension, extension,
cwd, cwd,
(description) => (description) =>
@@ -123,7 +162,9 @@ export const useExtensionUpdates = (
), ),
currentState.status, currentState.status,
dispatchExtensionStateUpdate, dispatchExtensionStateUpdate,
) );
updatePromises.push(updatePromise);
updatePromise
.then((result) => { .then((result) => {
if (!result) return; if (!result) return;
addItem( addItem(
@@ -143,11 +184,9 @@ export const useExtensionUpdates = (
Date.now(), Date.now(),
); );
}); });
} else {
extensionsWithUpdatesCount++;
} }
} }
if (extensionsWithUpdatesCount > 0) { if (shouldNotifyOfUpdates) {
const s = extensionsWithUpdatesCount > 1 ? 's' : ''; const s = extensionsWithUpdatesCount > 1 ? 's' : '';
addItem( addItem(
{ {
@@ -157,6 +196,18 @@ export const useExtensionUpdates = (
Date.now(), Date.now(),
); );
} }
if (scheduledUpdate) {
Promise.all(updatePromises).then((results) => {
const nonNullResults = results.filter((result) => result != null);
scheduledUpdate.onCompleteCallbacks.forEach((callback) => {
try {
callback(nonNullResults);
} catch (e) {
console.error(getErrorMessage(e));
}
});
});
}
}, [ }, [
extensions, extensions,
extensionsUpdateState, extensionsUpdateState,
+49 -8
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { ExtensionUpdateInfo } from '../../config/extension.js';
import { checkExhaustive } from '../../utils/checks.js'; import { checkExhaustive } from '../../utils/checks.js';
export enum ExtensionUpdateState { export enum ExtensionUpdateState {
@@ -19,17 +20,34 @@ export enum ExtensionUpdateState {
export interface ExtensionUpdateStatus { export interface ExtensionUpdateStatus {
status: ExtensionUpdateState; status: ExtensionUpdateState;
processed: boolean; notified: boolean;
} }
export interface ExtensionUpdatesState { export interface ExtensionUpdatesState {
extensionStatuses: Map<string, ExtensionUpdateStatus>; extensionStatuses: Map<string, ExtensionUpdateStatus>;
batchChecksInProgress: number; batchChecksInProgress: number;
// Explicitly scheduled updates.
scheduledUpdate: ScheduledUpdate | null;
} }
export interface ScheduledUpdate {
names: string[] | null;
all: boolean;
onCompleteCallbacks: OnCompleteUpdate[];
}
export interface ScheduleUpdateArgs {
names: string[] | null;
all: boolean;
onComplete: OnCompleteUpdate;
}
type OnCompleteUpdate = (updateInfos: ExtensionUpdateInfo[]) => void;
export const initialExtensionUpdatesState: ExtensionUpdatesState = { export const initialExtensionUpdatesState: ExtensionUpdatesState = {
extensionStatuses: new Map(), extensionStatuses: new Map(),
batchChecksInProgress: 0, batchChecksInProgress: 0,
scheduledUpdate: null,
}; };
export type ExtensionUpdateAction = export type ExtensionUpdateAction =
@@ -38,11 +56,13 @@ export type ExtensionUpdateAction =
payload: { name: string; state: ExtensionUpdateState }; payload: { name: string; state: ExtensionUpdateState };
} }
| { | {
type: 'SET_PROCESSED'; type: 'SET_NOTIFIED';
payload: { name: string; processed: boolean }; payload: { name: string; notified: boolean };
} }
| { type: 'BATCH_CHECK_START' } | { type: 'BATCH_CHECK_START' }
| { type: 'BATCH_CHECK_END' }; | { type: 'BATCH_CHECK_END' }
| { type: 'SCHEDULE_UPDATE'; payload: ScheduleUpdateArgs }
| { type: 'CLEAR_SCHEDULED_UPDATE' };
export function extensionUpdatesReducer( export function extensionUpdatesReducer(
state: ExtensionUpdatesState, state: ExtensionUpdatesState,
@@ -57,19 +77,19 @@ export function extensionUpdatesReducer(
const newStatuses = new Map(state.extensionStatuses); const newStatuses = new Map(state.extensionStatuses);
newStatuses.set(action.payload.name, { newStatuses.set(action.payload.name, {
status: action.payload.state, status: action.payload.state,
processed: false, notified: false,
}); });
return { ...state, extensionStatuses: newStatuses }; return { ...state, extensionStatuses: newStatuses };
} }
case 'SET_PROCESSED': { case 'SET_NOTIFIED': {
const existing = state.extensionStatuses.get(action.payload.name); const existing = state.extensionStatuses.get(action.payload.name);
if (!existing || existing.processed === action.payload.processed) { if (!existing || existing.notified === action.payload.notified) {
return state; return state;
} }
const newStatuses = new Map(state.extensionStatuses); const newStatuses = new Map(state.extensionStatuses);
newStatuses.set(action.payload.name, { newStatuses.set(action.payload.name, {
...existing, ...existing,
processed: action.payload.processed, notified: action.payload.notified,
}); });
return { ...state, extensionStatuses: newStatuses }; return { ...state, extensionStatuses: newStatuses };
} }
@@ -83,6 +103,27 @@ export function extensionUpdatesReducer(
...state, ...state,
batchChecksInProgress: state.batchChecksInProgress - 1, batchChecksInProgress: state.batchChecksInProgress - 1,
}; };
case 'SCHEDULE_UPDATE':
return {
...state,
// If there is a pre-existing scheduled update, we merge them.
scheduledUpdate: {
all: state.scheduledUpdate?.all || action.payload.all,
names: [
...(state.scheduledUpdate?.names ?? []),
...(action.payload.names ?? []),
],
onCompleteCallbacks: [
...(state.scheduledUpdate?.onCompleteCallbacks ?? []),
action.payload.onComplete,
],
},
};
case 'CLEAR_SCHEDULED_UPDATE':
return {
...state,
scheduledUpdate: null,
};
default: default:
checkExhaustive(action); checkExhaustive(action);
} }