Improve error messages on failed onboarding (#17357)

This commit is contained in:
Gaurav
2026-01-26 06:31:19 -08:00
committed by GitHub
parent cb772a5b7f
commit 5fe328c56a
17 changed files with 458 additions and 56 deletions
+35 -15
View File
@@ -6,18 +6,20 @@
import { describe, it, expect, vi, beforeEach } from 'vitest'; import { describe, it, expect, vi, beforeEach } from 'vitest';
import { performInitialAuth } from './auth.js'; import { performInitialAuth } from './auth.js';
import { type Config } from '@google/gemini-cli-core'; import {
type Config,
ValidationRequiredError,
AuthType,
} from '@google/gemini-cli-core';
vi.mock('@google/gemini-cli-core', () => ({ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
AuthType: { const actual =
OAUTH: 'oauth', await importOriginal<typeof import('@google/gemini-cli-core')>();
}, return {
getErrorMessage: (e: unknown) => (e as Error).message, ...actual,
})); getErrorMessage: (e: unknown) => (e as Error).message,
};
const AuthType = { });
OAUTH: 'oauth',
} as const;
describe('auth', () => { describe('auth', () => {
let mockConfig: Config; let mockConfig: Config;
@@ -37,10 +39,12 @@ describe('auth', () => {
it('should return null on successful auth', async () => { it('should return null on successful auth', async () => {
const result = await performInitialAuth( const result = await performInitialAuth(
mockConfig, mockConfig,
AuthType.OAUTH as unknown as Parameters<typeof performInitialAuth>[1], AuthType.LOGIN_WITH_GOOGLE,
); );
expect(result).toBeNull(); expect(result).toBeNull();
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(AuthType.OAUTH); expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
AuthType.LOGIN_WITH_GOOGLE,
);
}); });
it('should return error message on failed auth', async () => { it('should return error message on failed auth', async () => {
@@ -48,9 +52,25 @@ describe('auth', () => {
vi.mocked(mockConfig.refreshAuth).mockRejectedValue(error); vi.mocked(mockConfig.refreshAuth).mockRejectedValue(error);
const result = await performInitialAuth( const result = await performInitialAuth(
mockConfig, mockConfig,
AuthType.OAUTH as unknown as Parameters<typeof performInitialAuth>[1], AuthType.LOGIN_WITH_GOOGLE,
); );
expect(result).toBe('Failed to login. Message: Auth failed'); expect(result).toBe('Failed to login. Message: Auth failed');
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(AuthType.OAUTH); expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
AuthType.LOGIN_WITH_GOOGLE,
);
});
it('should return null if refreshAuth throws ValidationRequiredError', async () => {
vi.mocked(mockConfig.refreshAuth).mockRejectedValue(
new ValidationRequiredError('Validation required'),
);
const result = await performInitialAuth(
mockConfig,
AuthType.LOGIN_WITH_GOOGLE,
);
expect(result).toBeNull();
expect(mockConfig.refreshAuth).toHaveBeenCalledWith(
AuthType.LOGIN_WITH_GOOGLE,
);
}); });
}); });
+6
View File
@@ -8,6 +8,7 @@ import {
type AuthType, type AuthType,
type Config, type Config,
getErrorMessage, getErrorMessage,
ValidationRequiredError,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
/** /**
@@ -29,6 +30,11 @@ export async function performInitialAuth(
// The console.log is intentionally left out here. // The console.log is intentionally left out here.
// We can add a dedicated startup message later if needed. // We can add a dedicated startup message later if needed.
} catch (e) { } catch (e) {
if (e instanceof ValidationRequiredError) {
// Don't treat validation required as a fatal auth error during startup.
// This allows the React UI to load and show the ValidationDialog.
return null;
}
return `Failed to login. Message: ${getErrorMessage(e)}`; return `Failed to login. Message: ${getErrorMessage(e)}`;
} }
+15 -2
View File
@@ -61,6 +61,8 @@ import {
SessionStartSource, SessionStartSource,
SessionEndReason, SessionEndReason,
getVersion, getVersion,
ValidationCancelledError,
ValidationRequiredError,
type FetchAdminControlsResponse, type FetchAdminControlsResponse,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { import {
@@ -406,8 +408,19 @@ export async function main() {
await partialConfig.refreshAuth(authType); await partialConfig.refreshAuth(authType);
} }
} catch (err) { } catch (err) {
debugLogger.error('Error authenticating:', err); if (err instanceof ValidationCancelledError) {
initialAuthFailed = true; // User cancelled verification, exit immediately.
await runExitCleanup();
process.exit(ExitCodes.SUCCESS);
}
// If validation is required, we don't treat it as a fatal failure.
// We allow the app to start, and the React-based ValidationDialog
// will handle it.
if (!(err instanceof ValidationRequiredError)) {
debugLogger.error('Error authenticating:', err);
initialAuthFailed = true;
}
} }
} }
+6 -1
View File
@@ -63,6 +63,7 @@ import {
SessionStartSource, SessionStartSource,
SessionEndReason, SessionEndReason,
generateSummary, generateSummary,
ChangeAuthRequestedError,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { validateAuthMethod } from '../config/auth.js'; import { validateAuthMethod } from '../config/auth.js';
import process from 'node:process'; import process from 'node:process';
@@ -527,7 +528,7 @@ export const AppContainer = (props: AppContainerProps) => {
onAuthError, onAuthError,
apiKeyDefaultValue, apiKeyDefaultValue,
reloadApiKey, reloadApiKey,
} = useAuthCommand(settings, config); } = useAuthCommand(settings, config, initializationResult.authError);
const [authContext, setAuthContext] = useState<{ requiresRestart?: boolean }>( const [authContext, setAuthContext] = useState<{ requiresRestart?: boolean }>(
{}, {},
); );
@@ -549,6 +550,7 @@ export const AppContainer = (props: AppContainerProps) => {
historyManager, historyManager,
userTier, userTier,
setModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError,
onShowAuthSelection: () => setAuthState(AuthState.Updating),
}); });
// Derive auth state variables for backward compatibility with UIStateContext // Derive auth state variables for backward compatibility with UIStateContext
@@ -598,6 +600,9 @@ export const AppContainer = (props: AppContainerProps) => {
await config.refreshAuth(authType); await config.refreshAuth(authType);
setAuthState(AuthState.Authenticated); setAuthState(AuthState.Authenticated);
} catch (e) { } catch (e) {
if (e instanceof ChangeAuthRequestedError) {
return;
}
onAuthError( onAuthError(
`Failed to authenticate: ${e instanceof Error ? e.message : String(e)}`, `Failed to authenticate: ${e instanceof Error ? e.message : String(e)}`,
); );
+7 -3
View File
@@ -34,12 +34,16 @@ export function validateAuthMethodWithSettings(
return validateAuthMethod(authType); return validateAuthMethod(authType);
} }
export const useAuthCommand = (settings: LoadedSettings, config: Config) => { export const useAuthCommand = (
settings: LoadedSettings,
config: Config,
initialAuthError: string | null = null,
) => {
const [authState, setAuthState] = useState<AuthState>( const [authState, setAuthState] = useState<AuthState>(
AuthState.Unauthenticated, initialAuthError ? AuthState.Updating : AuthState.Unauthenticated,
); );
const [authError, setAuthError] = useState<string | null>(null); const [authError, setAuthError] = useState<string | null>(initialAuthError);
const [apiKeyDefaultValue, setApiKeyDefaultValue] = useState< const [apiKeyDefaultValue, setApiKeyDefaultValue] = useState<
string | undefined string | undefined
>(undefined); >(undefined);
@@ -17,6 +17,7 @@ import {
} from 'vitest'; } from 'vitest';
import { ValidationDialog } from './ValidationDialog.js'; import { ValidationDialog } from './ValidationDialog.js';
import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js';
import type { Key } from '../hooks/useKeypress.js';
// Mock the child components and utilities // Mock the child components and utilities
vi.mock('./shared/RadioButtonSelect.js', () => ({ vi.mock('./shared/RadioButtonSelect.js', () => ({
@@ -41,8 +42,15 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
}; };
}); });
// Capture keypress handler to test it
let mockKeypressHandler: (key: Key) => void;
let mockKeypressOptions: { isActive: boolean };
vi.mock('../hooks/useKeypress.js', () => ({ vi.mock('../hooks/useKeypress.js', () => ({
useKeypress: vi.fn(), useKeypress: vi.fn((handler, options) => {
mockKeypressHandler = handler;
mockKeypressOptions = options;
}),
})); }));
describe('ValidationDialog', () => { describe('ValidationDialog', () => {
@@ -99,6 +107,29 @@ describe('ValidationDialog', () => {
expect(lastFrame()).toContain('https://example.com/help'); expect(lastFrame()).toContain('https://example.com/help');
unmount(); unmount();
}); });
it('should call onChoice with cancel when ESCAPE is pressed', () => {
const { unmount } = render(<ValidationDialog onChoice={mockOnChoice} />);
// Verify the keypress hook is active
expect(mockKeypressOptions.isActive).toBe(true);
// Simulate ESCAPE key press
act(() => {
mockKeypressHandler({
name: 'escape',
ctrl: false,
shift: false,
alt: false,
cmd: false,
insertable: false,
sequence: '\x1b',
});
});
expect(mockOnChoice).toHaveBeenCalledWith('cancel');
unmount();
});
}); });
describe('onChoice handling', () => { describe('onChoice handling', () => {
@@ -48,17 +48,17 @@ export function ValidationDialog({
}, },
]; ];
// Handle keypresses during 'waiting' state (ESC to cancel, Enter to confirm completion) // Handle keypresses globally for cancellation, and specific logic for waiting state
useKeypress( useKeypress(
(key) => { (key) => {
if (keyMatchers[Command.ESCAPE](key) || keyMatchers[Command.QUIT](key)) { if (keyMatchers[Command.ESCAPE](key) || keyMatchers[Command.QUIT](key)) {
onChoice('cancel'); onChoice('cancel');
} else if (keyMatchers[Command.RETURN](key)) { } else if (state === 'waiting' && keyMatchers[Command.RETURN](key)) {
// User confirmed verification is complete - transition to 'complete' state // User confirmed verification is complete - transition to 'complete' state
setState('complete'); setState('complete');
} }
}, },
{ isActive: state === 'waiting' }, { isActive: state !== 'complete' },
); );
// When state becomes 'complete', show success message briefly then proceed // When state becomes 'complete', show success message briefly then proceed
@@ -41,6 +41,7 @@ describe('useQuotaAndFallback', () => {
let mockConfig: Config; let mockConfig: Config;
let mockHistoryManager: UseHistoryManagerReturn; let mockHistoryManager: UseHistoryManagerReturn;
let mockSetModelSwitchedFromQuotaError: Mock; let mockSetModelSwitchedFromQuotaError: Mock;
let mockOnShowAuthSelection: Mock;
let setFallbackHandlerSpy: SpyInstance; let setFallbackHandlerSpy: SpyInstance;
let mockGoogleApiError: GoogleApiError; let mockGoogleApiError: GoogleApiError;
@@ -66,6 +67,7 @@ describe('useQuotaAndFallback', () => {
loadHistory: vi.fn(), loadHistory: vi.fn(),
}; };
mockSetModelSwitchedFromQuotaError = vi.fn(); mockSetModelSwitchedFromQuotaError = vi.fn();
mockOnShowAuthSelection = vi.fn();
setFallbackHandlerSpy = vi.spyOn(mockConfig, 'setFallbackModelHandler'); setFallbackHandlerSpy = vi.spyOn(mockConfig, 'setFallbackModelHandler');
vi.spyOn(mockConfig, 'setQuotaErrorOccurred'); vi.spyOn(mockConfig, 'setQuotaErrorOccurred');
@@ -85,6 +87,7 @@ describe('useQuotaAndFallback', () => {
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -101,6 +104,7 @@ describe('useQuotaAndFallback', () => {
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
return setFallbackHandlerSpy.mock.calls[0][0] as FallbackModelHandler; return setFallbackHandlerSpy.mock.calls[0][0] as FallbackModelHandler;
@@ -127,6 +131,7 @@ describe('useQuotaAndFallback', () => {
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -178,6 +183,7 @@ describe('useQuotaAndFallback', () => {
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -243,6 +249,7 @@ describe('useQuotaAndFallback', () => {
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: setModelSwitchedFromQuotaError:
mockSetModelSwitchedFromQuotaError, mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -297,6 +304,7 @@ describe('useQuotaAndFallback', () => {
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -345,6 +353,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -362,6 +371,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -392,6 +402,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -435,6 +446,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -470,6 +482,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -513,6 +526,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -527,6 +541,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -568,6 +583,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -602,13 +618,14 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
expect(result.current.validationRequest).toBeNull(); expect(result.current.validationRequest).toBeNull();
}); });
it('should add info message when change_auth is chosen', async () => { it('should call onShowAuthSelection when change_auth is chosen', async () => {
const { result } = renderHook(() => const { result } = renderHook(() =>
useQuotaAndFallback({ useQuotaAndFallback({
config: mockConfig, config: mockConfig,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -628,19 +645,17 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
const intent = await promise!; const intent = await promise!;
expect(intent).toBe('change_auth'); expect(intent).toBe('change_auth');
expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1); expect(mockOnShowAuthSelection).toHaveBeenCalledTimes(1);
const lastCall = (mockHistoryManager.addItem as Mock).mock.calls[0][0];
expect(lastCall.type).toBe(MessageType.INFO);
expect(lastCall.text).toBe('Use /auth to change authentication method.');
}); });
it('should not add info message when cancel is chosen', async () => { it('should call onShowAuthSelection when cancel is chosen', async () => {
const { result } = renderHook(() => const { result } = renderHook(() =>
useQuotaAndFallback({ useQuotaAndFallback({
config: mockConfig, config: mockConfig,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -660,7 +675,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
const intent = await promise!; const intent = await promise!;
expect(intent).toBe('cancel'); expect(intent).toBe('cancel');
expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); expect(mockOnShowAuthSelection).toHaveBeenCalledTimes(1);
}); });
it('should do nothing if handleValidationChoice is called without pending request', () => { it('should do nothing if handleValidationChoice is called without pending request', () => {
@@ -670,6 +685,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
historyManager: mockHistoryManager, historyManager: mockHistoryManager,
userTier: UserTierId.FREE, userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}), }),
); );
@@ -31,6 +31,7 @@ interface UseQuotaAndFallbackArgs {
historyManager: UseHistoryManagerReturn; historyManager: UseHistoryManagerReturn;
userTier: UserTierId | undefined; userTier: UserTierId | undefined;
setModelSwitchedFromQuotaError: (value: boolean) => void; setModelSwitchedFromQuotaError: (value: boolean) => void;
onShowAuthSelection: () => void;
} }
export function useQuotaAndFallback({ export function useQuotaAndFallback({
@@ -38,6 +39,7 @@ export function useQuotaAndFallback({
historyManager, historyManager,
userTier, userTier,
setModelSwitchedFromQuotaError, setModelSwitchedFromQuotaError,
onShowAuthSelection,
}: UseQuotaAndFallbackArgs) { }: UseQuotaAndFallbackArgs) {
const [proQuotaRequest, setProQuotaRequest] = const [proQuotaRequest, setProQuotaRequest] =
useState<ProQuotaDialogRequest | null>(null); useState<ProQuotaDialogRequest | null>(null);
@@ -197,17 +199,11 @@ export function useQuotaAndFallback({
validationRequest.resolve(choice); validationRequest.resolve(choice);
setValidationRequest(null); setValidationRequest(null);
if (choice === 'change_auth') { if (choice === 'change_auth' || choice === 'cancel') {
historyManager.addItem( onShowAuthSelection();
{
type: MessageType.INFO,
text: 'Use /auth to change authentication method.',
},
Date.now(),
);
} }
}, },
[validationRequest, historyManager], [validationRequest, onShowAuthSelection],
); );
return { return {
@@ -35,7 +35,10 @@ describe('codeAssist', () => {
describe('createCodeAssistContentGenerator', () => { describe('createCodeAssistContentGenerator', () => {
const httpOptions = {}; const httpOptions = {};
const mockConfig = {} as Config; const mockValidationHandler = vi.fn();
const mockConfig = {
getValidationHandler: () => mockValidationHandler,
} as unknown as Config;
const mockAuthClient = { a: 'client' }; const mockAuthClient = { a: 'client' };
const mockUserData = { const mockUserData = {
projectId: 'test-project', projectId: 'test-project',
@@ -57,7 +60,10 @@ describe('codeAssist', () => {
AuthType.LOGIN_WITH_GOOGLE, AuthType.LOGIN_WITH_GOOGLE,
mockConfig, mockConfig,
); );
expect(setupUser).toHaveBeenCalledWith(mockAuthClient); expect(setupUser).toHaveBeenCalledWith(
mockAuthClient,
mockValidationHandler,
);
expect(MockedCodeAssistServer).toHaveBeenCalledWith( expect(MockedCodeAssistServer).toHaveBeenCalledWith(
mockAuthClient, mockAuthClient,
'test-project', 'test-project',
@@ -83,7 +89,10 @@ describe('codeAssist', () => {
AuthType.COMPUTE_ADC, AuthType.COMPUTE_ADC,
mockConfig, mockConfig,
); );
expect(setupUser).toHaveBeenCalledWith(mockAuthClient); expect(setupUser).toHaveBeenCalledWith(
mockAuthClient,
mockValidationHandler,
);
expect(MockedCodeAssistServer).toHaveBeenCalledWith( expect(MockedCodeAssistServer).toHaveBeenCalledWith(
mockAuthClient, mockAuthClient,
'test-project', 'test-project',
+1 -1
View File
@@ -24,7 +24,7 @@ export async function createCodeAssistContentGenerator(
authType === AuthType.COMPUTE_ADC authType === AuthType.COMPUTE_ADC
) { ) {
const authClient = await getOauthClient(authType, config); const authClient = await getOauthClient(authType, config);
const userData = await setupUser(authClient); const userData = await setupUser(authClient, config.getValidationHandler());
return new CodeAssistServer( return new CodeAssistServer(
authClient, authClient,
userData.projectId, userData.projectId,
+219 -1
View File
@@ -5,7 +5,13 @@
*/ */
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { setupUser, ProjectIdRequiredError } from './setup.js'; import {
ProjectIdRequiredError,
setupUser,
ValidationCancelledError,
} from './setup.js';
import { ValidationRequiredError } from '../utils/googleQuotaErrors.js';
import { ChangeAuthRequestedError } from '../utils/errors.js';
import { CodeAssistServer } from '../code_assist/server.js'; import { CodeAssistServer } from '../code_assist/server.js';
import type { OAuth2Client } from 'google-auth-library'; import type { OAuth2Client } from 'google-auth-library';
import type { GeminiUserTier } from './types.js'; import type { GeminiUserTier } from './types.js';
@@ -307,3 +313,215 @@ describe('setupUser for new user', () => {
}); });
}); });
}); });
describe('setupUser validation', () => {
let mockLoad: ReturnType<typeof vi.fn>;
beforeEach(() => {
vi.resetAllMocks();
mockLoad = vi.fn();
vi.mocked(CodeAssistServer).mockImplementation(
() =>
({
loadCodeAssist: mockLoad,
}) as unknown as CodeAssistServer,
);
});
afterEach(() => {
vi.unstubAllEnvs();
});
it('should throw error if LoadCodeAssist returns ineligible tiers and no current tier', async () => {
mockLoad.mockResolvedValue({
currentTier: null,
ineligibleTiers: [
{
reasonMessage: 'User is not eligible',
reasonCode: 'INELIGIBLE_ACCOUNT',
tierId: 'standard-tier',
tierName: 'standard',
},
],
});
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
'User is not eligible',
);
});
it('should retry if validation handler returns verify', async () => {
// First call fails
mockLoad.mockResolvedValueOnce({
currentTier: null,
ineligibleTiers: [
{
reasonMessage: 'User is not eligible',
reasonCode: 'VALIDATION_REQUIRED',
tierId: 'standard-tier',
tierName: 'standard',
validationUrl: 'https://example.com/verify',
validationLearnMoreUrl: 'https://example.com/learn',
},
],
});
// Second call succeeds
mockLoad.mockResolvedValueOnce({
currentTier: mockPaidTier,
cloudaicompanionProject: 'test-project',
});
const mockValidationHandler = vi.fn().mockResolvedValue('verify');
const result = await setupUser({} as OAuth2Client, mockValidationHandler);
expect(mockValidationHandler).toHaveBeenCalledWith(
'https://example.com/verify',
'User is not eligible',
);
expect(mockLoad).toHaveBeenCalledTimes(2);
expect(result).toEqual({
projectId: 'test-project',
userTier: 'standard-tier',
userTierName: 'paid',
});
});
it('should throw if validation handler returns cancel', async () => {
mockLoad.mockResolvedValue({
currentTier: null,
ineligibleTiers: [
{
reasonMessage: 'User is not eligible',
reasonCode: 'VALIDATION_REQUIRED',
tierId: 'standard-tier',
tierName: 'standard',
validationUrl: 'https://example.com/verify',
},
],
});
const mockValidationHandler = vi.fn().mockResolvedValue('cancel');
await expect(
setupUser({} as OAuth2Client, mockValidationHandler),
).rejects.toThrow(ValidationCancelledError);
expect(mockValidationHandler).toHaveBeenCalled();
expect(mockLoad).toHaveBeenCalledTimes(1);
});
it('should throw ChangeAuthRequestedError if validation handler returns change_auth', async () => {
mockLoad.mockResolvedValue({
currentTier: null,
ineligibleTiers: [
{
reasonMessage: 'User is not eligible',
reasonCode: 'VALIDATION_REQUIRED',
tierId: 'standard-tier',
tierName: 'standard',
validationUrl: 'https://example.com/verify',
},
],
});
const mockValidationHandler = vi.fn().mockResolvedValue('change_auth');
await expect(
setupUser({} as OAuth2Client, mockValidationHandler),
).rejects.toThrow(ChangeAuthRequestedError);
expect(mockValidationHandler).toHaveBeenCalled();
expect(mockLoad).toHaveBeenCalledTimes(1);
});
it('should throw ValidationRequiredError without handler', async () => {
mockLoad.mockResolvedValue({
currentTier: null,
ineligibleTiers: [
{
reasonMessage: 'Please verify your account',
reasonCode: 'VALIDATION_REQUIRED',
tierId: 'standard-tier',
tierName: 'standard',
validationUrl: 'https://example.com/verify',
},
],
});
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
ValidationRequiredError,
);
expect(mockLoad).toHaveBeenCalledTimes(1);
});
it('should throw error if LoadCodeAssist returns empty response', async () => {
mockLoad.mockResolvedValue(null);
await expect(setupUser({} as OAuth2Client)).rejects.toThrow(
'LoadCodeAssist returned empty response',
);
});
it('should retry multiple times when validation handler keeps returning verify', async () => {
// First two calls fail with validation required
mockLoad
.mockResolvedValueOnce({
currentTier: null,
ineligibleTiers: [
{
reasonMessage: 'Verify 1',
reasonCode: 'VALIDATION_REQUIRED',
tierId: 'standard-tier',
tierName: 'standard',
validationUrl: 'https://example.com/verify',
},
],
})
.mockResolvedValueOnce({
currentTier: null,
ineligibleTiers: [
{
reasonMessage: 'Verify 2',
reasonCode: 'VALIDATION_REQUIRED',
tierId: 'standard-tier',
tierName: 'standard',
validationUrl: 'https://example.com/verify',
},
],
})
.mockResolvedValueOnce({
currentTier: mockPaidTier,
cloudaicompanionProject: 'test-project',
});
const mockValidationHandler = vi.fn().mockResolvedValue('verify');
const result = await setupUser({} as OAuth2Client, mockValidationHandler);
expect(mockValidationHandler).toHaveBeenCalledTimes(2);
expect(mockLoad).toHaveBeenCalledTimes(3);
expect(result).toEqual({
projectId: 'test-project',
userTier: 'standard-tier',
userTierName: 'paid',
});
});
});
describe('ValidationRequiredError', () => {
const error = new ValidationRequiredError(
'Account validation required: Please verify',
undefined,
'https://example.com/verify',
'Please verify',
);
it('should be an instance of Error', () => {
expect(error).toBeInstanceOf(Error);
expect(error).toBeInstanceOf(ValidationRequiredError);
});
it('should have the correct properties', () => {
expect(error.validationLink).toBe('https://example.com/verify');
expect(error.validationDescription).toBe('Please verify');
});
});
+79 -9
View File
@@ -10,9 +10,12 @@ import type {
LoadCodeAssistResponse, LoadCodeAssistResponse,
OnboardUserRequest, OnboardUserRequest,
} from './types.js'; } from './types.js';
import { UserTierId } from './types.js'; import { UserTierId, IneligibleTierReasonCode } from './types.js';
import { CodeAssistServer } from './server.js'; import { CodeAssistServer } from './server.js';
import type { AuthClient } from 'google-auth-library'; import type { AuthClient } from 'google-auth-library';
import type { ValidationHandler } from '../fallback/types.js';
import { ChangeAuthRequestedError } from '../utils/errors.js';
import { ValidationRequiredError } from '../utils/googleQuotaErrors.js';
export class ProjectIdRequiredError extends Error { export class ProjectIdRequiredError extends Error {
constructor() { constructor() {
@@ -22,6 +25,16 @@ export class ProjectIdRequiredError extends Error {
} }
} }
/**
* Error thrown when user cancels the validation process.
* This is a non-recoverable error that should result in auth failure.
*/
export class ValidationCancelledError extends Error {
constructor() {
super('User cancelled account validation');
}
}
export interface UserData { export interface UserData {
projectId: string; projectId: string;
userTier: UserTierId; userTier: UserTierId;
@@ -33,7 +46,10 @@ export interface UserData {
* @param projectId the user's project id, if any * @param projectId the user's project id, if any
* @returns the user's actual project id * @returns the user's actual project id
*/ */
export async function setupUser(client: AuthClient): Promise<UserData> { export async function setupUser(
client: AuthClient,
validationHandler?: ValidationHandler,
): Promise<UserData> {
const projectId = const projectId =
process.env['GOOGLE_CLOUD_PROJECT'] || process.env['GOOGLE_CLOUD_PROJECT'] ||
process.env['GOOGLE_CLOUD_PROJECT_ID'] || process.env['GOOGLE_CLOUD_PROJECT_ID'] ||
@@ -52,13 +68,36 @@ export async function setupUser(client: AuthClient): Promise<UserData> {
pluginType: 'GEMINI', pluginType: 'GEMINI',
}; };
const loadRes = await caServer.loadCodeAssist({ let loadRes: LoadCodeAssistResponse;
cloudaicompanionProject: projectId, while (true) {
metadata: { loadRes = await caServer.loadCodeAssist({
...coreClientMetadata, cloudaicompanionProject: projectId,
duetProject: projectId, metadata: {
}, ...coreClientMetadata,
}); duetProject: projectId,
},
});
try {
validateLoadCodeAssistResponse(loadRes);
break;
} catch (e) {
if (e instanceof ValidationRequiredError && validationHandler) {
const intent = await validationHandler(
e.validationLink,
e.validationDescription,
);
if (intent === 'verify') {
continue;
}
if (intent === 'change_auth') {
throw new ChangeAuthRequestedError();
}
throw new ValidationCancelledError();
}
throw e;
}
}
if (loadRes.currentTier) { if (loadRes.currentTier) {
if (!loadRes.cloudaicompanionProject) { if (!loadRes.cloudaicompanionProject) {
@@ -139,3 +178,34 @@ function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier {
userDefinedCloudaicompanionProject: true, userDefinedCloudaicompanionProject: true,
}; };
} }
function validateLoadCodeAssistResponse(res: LoadCodeAssistResponse): void {
if (!res) {
throw new Error('LoadCodeAssist returned empty response');
}
if (
!res.currentTier &&
res.ineligibleTiers &&
res.ineligibleTiers.length > 0
) {
// Check for VALIDATION_REQUIRED first - this is a recoverable state
const validationTier = res.ineligibleTiers.find(
(t) =>
t.validationUrl &&
t.reasonCode === IneligibleTierReasonCode.VALIDATION_REQUIRED,
);
const validationUrl = validationTier?.validationUrl;
if (validationTier && validationUrl) {
throw new ValidationRequiredError(
`Account validation required: ${validationTier.reasonMessage}`,
undefined,
validationUrl,
validationTier.reasonMessage,
);
}
// For other ineligibility reasons, throw a generic error
const reasons = res.ineligibleTiers.map((t) => t.reasonMessage).join(', ');
throw new Error(reasons);
}
}
+6
View File
@@ -82,6 +82,11 @@ export interface IneligibleTier {
reasonMessage: string; reasonMessage: string;
tierId: UserTierId; tierId: UserTierId;
tierName: string; tierName: string;
validationErrorMessage?: string;
validationUrl?: string;
validationUrlLinkText?: string;
validationLearnMoreUrl?: string;
validationLearnMoreLinkText?: string;
} }
/** /**
@@ -98,6 +103,7 @@ export enum IneligibleTierReasonCode {
UNKNOWN = 'UNKNOWN', UNKNOWN = 'UNKNOWN',
UNKNOWN_LOCATION = 'UNKNOWN_LOCATION', UNKNOWN_LOCATION = 'UNKNOWN_LOCATION',
UNSUPPORTED_LOCATION = 'UNSUPPORTED_LOCATION', UNSUPPORTED_LOCATION = 'UNSUPPORTED_LOCATION',
VALIDATION_REQUIRED = 'VALIDATION_REQUIRED',
// go/keep-sorted end // go/keep-sorted end
} }
/** /**
+1
View File
@@ -47,6 +47,7 @@ export * from './fallback/types.js';
export * from './code_assist/codeAssist.js'; export * from './code_assist/codeAssist.js';
export * from './code_assist/oauth2.js'; export * from './code_assist/oauth2.js';
export * from './code_assist/server.js'; export * from './code_assist/server.js';
export * from './code_assist/setup.js';
export * from './code_assist/types.js'; export * from './code_assist/types.js';
export * from './code_assist/telemetry.js'; export * from './code_assist/telemetry.js';
export * from './core/apiKeyCredentialStorage.js'; export * from './core/apiKeyCredentialStorage.js';
+7
View File
@@ -81,6 +81,13 @@ export class ForbiddenError extends Error {}
export class UnauthorizedError extends Error {} export class UnauthorizedError extends Error {}
export class BadRequestError extends Error {} export class BadRequestError extends Error {}
export class ChangeAuthRequestedError extends Error {
constructor() {
super('User requested to change authentication method');
this.name = 'ChangeAuthRequestedError';
}
}
interface ResponseData { interface ResponseData {
error?: { error?: {
code?: number; code?: number;
+1 -1
View File
@@ -63,7 +63,7 @@ export class ValidationRequiredError extends Error {
constructor( constructor(
message: string, message: string,
override readonly cause: GoogleApiError, override readonly cause?: GoogleApiError,
validationLink?: string, validationLink?: string,
validationDescription?: string, validationDescription?: string,
learnMoreUrl?: string, learnMoreUrl?: string,