From 5fe328c56a0f862368e16a777850200b99b27aae Mon Sep 17 00:00:00 2001 From: Gaurav <39389231+gsquared94@users.noreply.github.com> Date: Mon, 26 Jan 2026 06:31:19 -0800 Subject: [PATCH] Improve error messages on failed onboarding (#17357) --- packages/cli/src/core/auth.test.ts | 50 ++-- packages/cli/src/core/auth.ts | 6 + packages/cli/src/gemini.tsx | 17 +- packages/cli/src/ui/AppContainer.tsx | 7 +- packages/cli/src/ui/auth/useAuth.ts | 10 +- .../ui/components/ValidationDialog.test.tsx | 33 ++- .../src/ui/components/ValidationDialog.tsx | 6 +- .../src/ui/hooks/useQuotaAndFallback.test.ts | 30 ++- .../cli/src/ui/hooks/useQuotaAndFallback.ts | 14 +- .../core/src/code_assist/codeAssist.test.ts | 15 +- packages/core/src/code_assist/codeAssist.ts | 2 +- packages/core/src/code_assist/setup.test.ts | 220 +++++++++++++++++- packages/core/src/code_assist/setup.ts | 88 ++++++- packages/core/src/code_assist/types.ts | 6 + packages/core/src/index.ts | 1 + packages/core/src/utils/errors.ts | 7 + packages/core/src/utils/googleQuotaErrors.ts | 2 +- 17 files changed, 458 insertions(+), 56 deletions(-) diff --git a/packages/cli/src/core/auth.test.ts b/packages/cli/src/core/auth.test.ts index 366e5c9137..c844ee6f93 100644 --- a/packages/cli/src/core/auth.test.ts +++ b/packages/cli/src/core/auth.test.ts @@ -6,18 +6,20 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import { performInitialAuth } from './auth.js'; -import { type Config } from '@google/gemini-cli-core'; +import { + type Config, + ValidationRequiredError, + AuthType, +} from '@google/gemini-cli-core'; -vi.mock('@google/gemini-cli-core', () => ({ - AuthType: { - OAUTH: 'oauth', - }, - getErrorMessage: (e: unknown) => (e as Error).message, -})); - -const AuthType = { - OAUTH: 'oauth', -} as const; +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + getErrorMessage: (e: unknown) => (e as Error).message, + }; +}); describe('auth', () => { let mockConfig: Config; @@ -37,10 +39,12 @@ describe('auth', () => { it('should return null on successful auth', async () => { const result = await performInitialAuth( mockConfig, - AuthType.OAUTH as unknown as Parameters[1], + AuthType.LOGIN_WITH_GOOGLE, ); expect(result).toBeNull(); - expect(mockConfig.refreshAuth).toHaveBeenCalledWith(AuthType.OAUTH); + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.LOGIN_WITH_GOOGLE, + ); }); it('should return error message on failed auth', async () => { @@ -48,9 +52,25 @@ describe('auth', () => { vi.mocked(mockConfig.refreshAuth).mockRejectedValue(error); const result = await performInitialAuth( mockConfig, - AuthType.OAUTH as unknown as Parameters[1], + AuthType.LOGIN_WITH_GOOGLE, ); expect(result).toBe('Failed to login. Message: Auth failed'); - expect(mockConfig.refreshAuth).toHaveBeenCalledWith(AuthType.OAUTH); + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.LOGIN_WITH_GOOGLE, + ); + }); + + it('should return null if refreshAuth throws ValidationRequiredError', async () => { + vi.mocked(mockConfig.refreshAuth).mockRejectedValue( + new ValidationRequiredError('Validation required'), + ); + const result = await performInitialAuth( + mockConfig, + AuthType.LOGIN_WITH_GOOGLE, + ); + expect(result).toBeNull(); + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.LOGIN_WITH_GOOGLE, + ); }); }); diff --git a/packages/cli/src/core/auth.ts b/packages/cli/src/core/auth.ts index f4f4963bc7..7b1e8c8277 100644 --- a/packages/cli/src/core/auth.ts +++ b/packages/cli/src/core/auth.ts @@ -8,6 +8,7 @@ import { type AuthType, type Config, getErrorMessage, + ValidationRequiredError, } from '@google/gemini-cli-core'; /** @@ -29,6 +30,11 @@ export async function performInitialAuth( // The console.log is intentionally left out here. // We can add a dedicated startup message later if needed. } catch (e) { + if (e instanceof ValidationRequiredError) { + // Don't treat validation required as a fatal auth error during startup. + // This allows the React UI to load and show the ValidationDialog. + return null; + } return `Failed to login. Message: ${getErrorMessage(e)}`; } diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index ff73dcfdfa..20f022021a 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -61,6 +61,8 @@ import { SessionStartSource, SessionEndReason, getVersion, + ValidationCancelledError, + ValidationRequiredError, type FetchAdminControlsResponse, } from '@google/gemini-cli-core'; import { @@ -406,8 +408,19 @@ export async function main() { await partialConfig.refreshAuth(authType); } } catch (err) { - debugLogger.error('Error authenticating:', err); - initialAuthFailed = true; + if (err instanceof ValidationCancelledError) { + // User cancelled verification, exit immediately. + await runExitCleanup(); + process.exit(ExitCodes.SUCCESS); + } + + // If validation is required, we don't treat it as a fatal failure. + // We allow the app to start, and the React-based ValidationDialog + // will handle it. + if (!(err instanceof ValidationRequiredError)) { + debugLogger.error('Error authenticating:', err); + initialAuthFailed = true; + } } } diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index 43553efe14..0e337b7c1f 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -63,6 +63,7 @@ import { SessionStartSource, SessionEndReason, generateSummary, + ChangeAuthRequestedError, } from '@google/gemini-cli-core'; import { validateAuthMethod } from '../config/auth.js'; import process from 'node:process'; @@ -527,7 +528,7 @@ export const AppContainer = (props: AppContainerProps) => { onAuthError, apiKeyDefaultValue, reloadApiKey, - } = useAuthCommand(settings, config); + } = useAuthCommand(settings, config, initializationResult.authError); const [authContext, setAuthContext] = useState<{ requiresRestart?: boolean }>( {}, ); @@ -549,6 +550,7 @@ export const AppContainer = (props: AppContainerProps) => { historyManager, userTier, setModelSwitchedFromQuotaError, + onShowAuthSelection: () => setAuthState(AuthState.Updating), }); // Derive auth state variables for backward compatibility with UIStateContext @@ -598,6 +600,9 @@ export const AppContainer = (props: AppContainerProps) => { await config.refreshAuth(authType); setAuthState(AuthState.Authenticated); } catch (e) { + if (e instanceof ChangeAuthRequestedError) { + return; + } onAuthError( `Failed to authenticate: ${e instanceof Error ? e.message : String(e)}`, ); diff --git a/packages/cli/src/ui/auth/useAuth.ts b/packages/cli/src/ui/auth/useAuth.ts index 7b37e2d421..2b61265890 100644 --- a/packages/cli/src/ui/auth/useAuth.ts +++ b/packages/cli/src/ui/auth/useAuth.ts @@ -34,12 +34,16 @@ export function validateAuthMethodWithSettings( return validateAuthMethod(authType); } -export const useAuthCommand = (settings: LoadedSettings, config: Config) => { +export const useAuthCommand = ( + settings: LoadedSettings, + config: Config, + initialAuthError: string | null = null, +) => { const [authState, setAuthState] = useState( - AuthState.Unauthenticated, + initialAuthError ? AuthState.Updating : AuthState.Unauthenticated, ); - const [authError, setAuthError] = useState(null); + const [authError, setAuthError] = useState(initialAuthError); const [apiKeyDefaultValue, setApiKeyDefaultValue] = useState< string | undefined >(undefined); diff --git a/packages/cli/src/ui/components/ValidationDialog.test.tsx b/packages/cli/src/ui/components/ValidationDialog.test.tsx index ac938202ab..0e50781342 100644 --- a/packages/cli/src/ui/components/ValidationDialog.test.tsx +++ b/packages/cli/src/ui/components/ValidationDialog.test.tsx @@ -17,6 +17,7 @@ import { } from 'vitest'; import { ValidationDialog } from './ValidationDialog.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; +import type { Key } from '../hooks/useKeypress.js'; // Mock the child components and utilities vi.mock('./shared/RadioButtonSelect.js', () => ({ @@ -41,8 +42,15 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { }; }); +// Capture keypress handler to test it +let mockKeypressHandler: (key: Key) => void; +let mockKeypressOptions: { isActive: boolean }; + vi.mock('../hooks/useKeypress.js', () => ({ - useKeypress: vi.fn(), + useKeypress: vi.fn((handler, options) => { + mockKeypressHandler = handler; + mockKeypressOptions = options; + }), })); describe('ValidationDialog', () => { @@ -99,6 +107,29 @@ describe('ValidationDialog', () => { expect(lastFrame()).toContain('https://example.com/help'); unmount(); }); + + it('should call onChoice with cancel when ESCAPE is pressed', () => { + const { unmount } = render(); + + // Verify the keypress hook is active + expect(mockKeypressOptions.isActive).toBe(true); + + // Simulate ESCAPE key press + act(() => { + mockKeypressHandler({ + name: 'escape', + ctrl: false, + shift: false, + alt: false, + cmd: false, + insertable: false, + sequence: '\x1b', + }); + }); + + expect(mockOnChoice).toHaveBeenCalledWith('cancel'); + unmount(); + }); }); describe('onChoice handling', () => { diff --git a/packages/cli/src/ui/components/ValidationDialog.tsx b/packages/cli/src/ui/components/ValidationDialog.tsx index b7ddf2878a..9c71e93403 100644 --- a/packages/cli/src/ui/components/ValidationDialog.tsx +++ b/packages/cli/src/ui/components/ValidationDialog.tsx @@ -48,17 +48,17 @@ export function ValidationDialog({ }, ]; - // Handle keypresses during 'waiting' state (ESC to cancel, Enter to confirm completion) + // Handle keypresses globally for cancellation, and specific logic for waiting state useKeypress( (key) => { if (keyMatchers[Command.ESCAPE](key) || keyMatchers[Command.QUIT](key)) { onChoice('cancel'); - } else if (keyMatchers[Command.RETURN](key)) { + } else if (state === 'waiting' && keyMatchers[Command.RETURN](key)) { // User confirmed verification is complete - transition to 'complete' state setState('complete'); } }, - { isActive: state === 'waiting' }, + { isActive: state !== 'complete' }, ); // When state becomes 'complete', show success message briefly then proceed diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts index 61e53638ec..2a9106329e 100644 --- a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts @@ -41,6 +41,7 @@ describe('useQuotaAndFallback', () => { let mockConfig: Config; let mockHistoryManager: UseHistoryManagerReturn; let mockSetModelSwitchedFromQuotaError: Mock; + let mockOnShowAuthSelection: Mock; let setFallbackHandlerSpy: SpyInstance; let mockGoogleApiError: GoogleApiError; @@ -66,6 +67,7 @@ describe('useQuotaAndFallback', () => { loadHistory: vi.fn(), }; mockSetModelSwitchedFromQuotaError = vi.fn(); + mockOnShowAuthSelection = vi.fn(); setFallbackHandlerSpy = vi.spyOn(mockConfig, 'setFallbackModelHandler'); vi.spyOn(mockConfig, 'setQuotaErrorOccurred'); @@ -85,6 +87,7 @@ describe('useQuotaAndFallback', () => { historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -101,6 +104,7 @@ describe('useQuotaAndFallback', () => { historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); return setFallbackHandlerSpy.mock.calls[0][0] as FallbackModelHandler; @@ -127,6 +131,7 @@ describe('useQuotaAndFallback', () => { historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -178,6 +183,7 @@ describe('useQuotaAndFallback', () => { historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -243,6 +249,7 @@ describe('useQuotaAndFallback', () => { userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -297,6 +304,7 @@ describe('useQuotaAndFallback', () => { historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -345,6 +353,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -362,6 +371,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -392,6 +402,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -435,6 +446,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -470,6 +482,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -513,6 +526,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -527,6 +541,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -568,6 +583,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -602,13 +618,14 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, expect(result.current.validationRequest).toBeNull(); }); - it('should add info message when change_auth is chosen', async () => { + it('should call onShowAuthSelection when change_auth is chosen', async () => { const { result } = renderHook(() => useQuotaAndFallback({ config: mockConfig, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -628,19 +645,17 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, const intent = await promise!; expect(intent).toBe('change_auth'); - expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1); - const lastCall = (mockHistoryManager.addItem as Mock).mock.calls[0][0]; - expect(lastCall.type).toBe(MessageType.INFO); - expect(lastCall.text).toBe('Use /auth to change authentication method.'); + expect(mockOnShowAuthSelection).toHaveBeenCalledTimes(1); }); - it('should not add info message when cancel is chosen', async () => { + it('should call onShowAuthSelection when cancel is chosen', async () => { const { result } = renderHook(() => useQuotaAndFallback({ config: mockConfig, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -660,7 +675,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, const intent = await promise!; expect(intent).toBe('cancel'); - expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); + expect(mockOnShowAuthSelection).toHaveBeenCalledTimes(1); }); it('should do nothing if handleValidationChoice is called without pending request', () => { @@ -670,6 +685,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts index 7f8b8d0f0d..bc12c60907 100644 --- a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts @@ -31,6 +31,7 @@ interface UseQuotaAndFallbackArgs { historyManager: UseHistoryManagerReturn; userTier: UserTierId | undefined; setModelSwitchedFromQuotaError: (value: boolean) => void; + onShowAuthSelection: () => void; } export function useQuotaAndFallback({ @@ -38,6 +39,7 @@ export function useQuotaAndFallback({ historyManager, userTier, setModelSwitchedFromQuotaError, + onShowAuthSelection, }: UseQuotaAndFallbackArgs) { const [proQuotaRequest, setProQuotaRequest] = useState(null); @@ -197,17 +199,11 @@ export function useQuotaAndFallback({ validationRequest.resolve(choice); setValidationRequest(null); - if (choice === 'change_auth') { - historyManager.addItem( - { - type: MessageType.INFO, - text: 'Use /auth to change authentication method.', - }, - Date.now(), - ); + if (choice === 'change_auth' || choice === 'cancel') { + onShowAuthSelection(); } }, - [validationRequest, historyManager], + [validationRequest, onShowAuthSelection], ); return { diff --git a/packages/core/src/code_assist/codeAssist.test.ts b/packages/core/src/code_assist/codeAssist.test.ts index 90ebfb1d9c..6efee88d69 100644 --- a/packages/core/src/code_assist/codeAssist.test.ts +++ b/packages/core/src/code_assist/codeAssist.test.ts @@ -35,7 +35,10 @@ describe('codeAssist', () => { describe('createCodeAssistContentGenerator', () => { const httpOptions = {}; - const mockConfig = {} as Config; + const mockValidationHandler = vi.fn(); + const mockConfig = { + getValidationHandler: () => mockValidationHandler, + } as unknown as Config; const mockAuthClient = { a: 'client' }; const mockUserData = { projectId: 'test-project', @@ -57,7 +60,10 @@ describe('codeAssist', () => { AuthType.LOGIN_WITH_GOOGLE, mockConfig, ); - expect(setupUser).toHaveBeenCalledWith(mockAuthClient); + expect(setupUser).toHaveBeenCalledWith( + mockAuthClient, + mockValidationHandler, + ); expect(MockedCodeAssistServer).toHaveBeenCalledWith( mockAuthClient, 'test-project', @@ -83,7 +89,10 @@ describe('codeAssist', () => { AuthType.COMPUTE_ADC, mockConfig, ); - expect(setupUser).toHaveBeenCalledWith(mockAuthClient); + expect(setupUser).toHaveBeenCalledWith( + mockAuthClient, + mockValidationHandler, + ); expect(MockedCodeAssistServer).toHaveBeenCalledWith( mockAuthClient, 'test-project', diff --git a/packages/core/src/code_assist/codeAssist.ts b/packages/core/src/code_assist/codeAssist.ts index fee43e9c45..3b87cb03e2 100644 --- a/packages/core/src/code_assist/codeAssist.ts +++ b/packages/core/src/code_assist/codeAssist.ts @@ -24,7 +24,7 @@ export async function createCodeAssistContentGenerator( authType === AuthType.COMPUTE_ADC ) { const authClient = await getOauthClient(authType, config); - const userData = await setupUser(authClient); + const userData = await setupUser(authClient, config.getValidationHandler()); return new CodeAssistServer( authClient, userData.projectId, diff --git a/packages/core/src/code_assist/setup.test.ts b/packages/core/src/code_assist/setup.test.ts index bd43ed2e88..9559c58254 100644 --- a/packages/core/src/code_assist/setup.test.ts +++ b/packages/core/src/code_assist/setup.test.ts @@ -5,7 +5,13 @@ */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { setupUser, ProjectIdRequiredError } from './setup.js'; +import { + ProjectIdRequiredError, + setupUser, + ValidationCancelledError, +} from './setup.js'; +import { ValidationRequiredError } from '../utils/googleQuotaErrors.js'; +import { ChangeAuthRequestedError } from '../utils/errors.js'; import { CodeAssistServer } from '../code_assist/server.js'; import type { OAuth2Client } from 'google-auth-library'; import type { GeminiUserTier } from './types.js'; @@ -307,3 +313,215 @@ describe('setupUser for new user', () => { }); }); }); + +describe('setupUser validation', () => { + let mockLoad: ReturnType; + + beforeEach(() => { + vi.resetAllMocks(); + mockLoad = vi.fn(); + vi.mocked(CodeAssistServer).mockImplementation( + () => + ({ + loadCodeAssist: mockLoad, + }) as unknown as CodeAssistServer, + ); + }); + + afterEach(() => { + vi.unstubAllEnvs(); + }); + + it('should throw error if LoadCodeAssist returns ineligible tiers and no current tier', async () => { + mockLoad.mockResolvedValue({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'User is not eligible', + reasonCode: 'INELIGIBLE_ACCOUNT', + tierId: 'standard-tier', + tierName: 'standard', + }, + ], + }); + + await expect(setupUser({} as OAuth2Client)).rejects.toThrow( + 'User is not eligible', + ); + }); + + it('should retry if validation handler returns verify', async () => { + // First call fails + mockLoad.mockResolvedValueOnce({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'User is not eligible', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + validationLearnMoreUrl: 'https://example.com/learn', + }, + ], + }); + // Second call succeeds + mockLoad.mockResolvedValueOnce({ + currentTier: mockPaidTier, + cloudaicompanionProject: 'test-project', + }); + + const mockValidationHandler = vi.fn().mockResolvedValue('verify'); + + const result = await setupUser({} as OAuth2Client, mockValidationHandler); + + expect(mockValidationHandler).toHaveBeenCalledWith( + 'https://example.com/verify', + 'User is not eligible', + ); + expect(mockLoad).toHaveBeenCalledTimes(2); + expect(result).toEqual({ + projectId: 'test-project', + userTier: 'standard-tier', + userTierName: 'paid', + }); + }); + + it('should throw if validation handler returns cancel', async () => { + mockLoad.mockResolvedValue({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'User is not eligible', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + }, + ], + }); + + const mockValidationHandler = vi.fn().mockResolvedValue('cancel'); + + await expect( + setupUser({} as OAuth2Client, mockValidationHandler), + ).rejects.toThrow(ValidationCancelledError); + expect(mockValidationHandler).toHaveBeenCalled(); + expect(mockLoad).toHaveBeenCalledTimes(1); + }); + + it('should throw ChangeAuthRequestedError if validation handler returns change_auth', async () => { + mockLoad.mockResolvedValue({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'User is not eligible', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + }, + ], + }); + + const mockValidationHandler = vi.fn().mockResolvedValue('change_auth'); + + await expect( + setupUser({} as OAuth2Client, mockValidationHandler), + ).rejects.toThrow(ChangeAuthRequestedError); + expect(mockValidationHandler).toHaveBeenCalled(); + expect(mockLoad).toHaveBeenCalledTimes(1); + }); + + it('should throw ValidationRequiredError without handler', async () => { + mockLoad.mockResolvedValue({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'Please verify your account', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + }, + ], + }); + + await expect(setupUser({} as OAuth2Client)).rejects.toThrow( + ValidationRequiredError, + ); + expect(mockLoad).toHaveBeenCalledTimes(1); + }); + + it('should throw error if LoadCodeAssist returns empty response', async () => { + mockLoad.mockResolvedValue(null); + + await expect(setupUser({} as OAuth2Client)).rejects.toThrow( + 'LoadCodeAssist returned empty response', + ); + }); + + it('should retry multiple times when validation handler keeps returning verify', async () => { + // First two calls fail with validation required + mockLoad + .mockResolvedValueOnce({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'Verify 1', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + }, + ], + }) + .mockResolvedValueOnce({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'Verify 2', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + }, + ], + }) + .mockResolvedValueOnce({ + currentTier: mockPaidTier, + cloudaicompanionProject: 'test-project', + }); + + const mockValidationHandler = vi.fn().mockResolvedValue('verify'); + + const result = await setupUser({} as OAuth2Client, mockValidationHandler); + + expect(mockValidationHandler).toHaveBeenCalledTimes(2); + expect(mockLoad).toHaveBeenCalledTimes(3); + expect(result).toEqual({ + projectId: 'test-project', + userTier: 'standard-tier', + userTierName: 'paid', + }); + }); +}); + +describe('ValidationRequiredError', () => { + const error = new ValidationRequiredError( + 'Account validation required: Please verify', + undefined, + 'https://example.com/verify', + 'Please verify', + ); + + it('should be an instance of Error', () => { + expect(error).toBeInstanceOf(Error); + expect(error).toBeInstanceOf(ValidationRequiredError); + }); + + it('should have the correct properties', () => { + expect(error.validationLink).toBe('https://example.com/verify'); + expect(error.validationDescription).toBe('Please verify'); + }); +}); diff --git a/packages/core/src/code_assist/setup.ts b/packages/core/src/code_assist/setup.ts index 994bb99568..15da70fb42 100644 --- a/packages/core/src/code_assist/setup.ts +++ b/packages/core/src/code_assist/setup.ts @@ -10,9 +10,12 @@ import type { LoadCodeAssistResponse, OnboardUserRequest, } from './types.js'; -import { UserTierId } from './types.js'; +import { UserTierId, IneligibleTierReasonCode } from './types.js'; import { CodeAssistServer } from './server.js'; import type { AuthClient } from 'google-auth-library'; +import type { ValidationHandler } from '../fallback/types.js'; +import { ChangeAuthRequestedError } from '../utils/errors.js'; +import { ValidationRequiredError } from '../utils/googleQuotaErrors.js'; export class ProjectIdRequiredError extends Error { constructor() { @@ -22,6 +25,16 @@ export class ProjectIdRequiredError extends Error { } } +/** + * Error thrown when user cancels the validation process. + * This is a non-recoverable error that should result in auth failure. + */ +export class ValidationCancelledError extends Error { + constructor() { + super('User cancelled account validation'); + } +} + export interface UserData { projectId: string; userTier: UserTierId; @@ -33,7 +46,10 @@ export interface UserData { * @param projectId the user's project id, if any * @returns the user's actual project id */ -export async function setupUser(client: AuthClient): Promise { +export async function setupUser( + client: AuthClient, + validationHandler?: ValidationHandler, +): Promise { const projectId = process.env['GOOGLE_CLOUD_PROJECT'] || process.env['GOOGLE_CLOUD_PROJECT_ID'] || @@ -52,13 +68,36 @@ export async function setupUser(client: AuthClient): Promise { pluginType: 'GEMINI', }; - const loadRes = await caServer.loadCodeAssist({ - cloudaicompanionProject: projectId, - metadata: { - ...coreClientMetadata, - duetProject: projectId, - }, - }); + let loadRes: LoadCodeAssistResponse; + while (true) { + loadRes = await caServer.loadCodeAssist({ + cloudaicompanionProject: projectId, + metadata: { + ...coreClientMetadata, + duetProject: projectId, + }, + }); + + try { + validateLoadCodeAssistResponse(loadRes); + break; + } catch (e) { + if (e instanceof ValidationRequiredError && validationHandler) { + const intent = await validationHandler( + e.validationLink, + e.validationDescription, + ); + if (intent === 'verify') { + continue; + } + if (intent === 'change_auth') { + throw new ChangeAuthRequestedError(); + } + throw new ValidationCancelledError(); + } + throw e; + } + } if (loadRes.currentTier) { if (!loadRes.cloudaicompanionProject) { @@ -139,3 +178,34 @@ function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier { userDefinedCloudaicompanionProject: true, }; } + +function validateLoadCodeAssistResponse(res: LoadCodeAssistResponse): void { + if (!res) { + throw new Error('LoadCodeAssist returned empty response'); + } + if ( + !res.currentTier && + res.ineligibleTiers && + res.ineligibleTiers.length > 0 + ) { + // Check for VALIDATION_REQUIRED first - this is a recoverable state + const validationTier = res.ineligibleTiers.find( + (t) => + t.validationUrl && + t.reasonCode === IneligibleTierReasonCode.VALIDATION_REQUIRED, + ); + const validationUrl = validationTier?.validationUrl; + if (validationTier && validationUrl) { + throw new ValidationRequiredError( + `Account validation required: ${validationTier.reasonMessage}`, + undefined, + validationUrl, + validationTier.reasonMessage, + ); + } + + // For other ineligibility reasons, throw a generic error + const reasons = res.ineligibleTiers.map((t) => t.reasonMessage).join(', '); + throw new Error(reasons); + } +} diff --git a/packages/core/src/code_assist/types.ts b/packages/core/src/code_assist/types.ts index fd74d69b38..5e706cc207 100644 --- a/packages/core/src/code_assist/types.ts +++ b/packages/core/src/code_assist/types.ts @@ -82,6 +82,11 @@ export interface IneligibleTier { reasonMessage: string; tierId: UserTierId; tierName: string; + validationErrorMessage?: string; + validationUrl?: string; + validationUrlLinkText?: string; + validationLearnMoreUrl?: string; + validationLearnMoreLinkText?: string; } /** @@ -98,6 +103,7 @@ export enum IneligibleTierReasonCode { UNKNOWN = 'UNKNOWN', UNKNOWN_LOCATION = 'UNKNOWN_LOCATION', UNSUPPORTED_LOCATION = 'UNSUPPORTED_LOCATION', + VALIDATION_REQUIRED = 'VALIDATION_REQUIRED', // go/keep-sorted end } /** diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index fdd54c5150..348df878d5 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -47,6 +47,7 @@ export * from './fallback/types.js'; export * from './code_assist/codeAssist.js'; export * from './code_assist/oauth2.js'; export * from './code_assist/server.js'; +export * from './code_assist/setup.js'; export * from './code_assist/types.js'; export * from './code_assist/telemetry.js'; export * from './core/apiKeyCredentialStorage.js'; diff --git a/packages/core/src/utils/errors.ts b/packages/core/src/utils/errors.ts index 8db1153d92..86f1cc9b86 100644 --- a/packages/core/src/utils/errors.ts +++ b/packages/core/src/utils/errors.ts @@ -81,6 +81,13 @@ export class ForbiddenError extends Error {} export class UnauthorizedError extends Error {} export class BadRequestError extends Error {} +export class ChangeAuthRequestedError extends Error { + constructor() { + super('User requested to change authentication method'); + this.name = 'ChangeAuthRequestedError'; + } +} + interface ResponseData { error?: { code?: number; diff --git a/packages/core/src/utils/googleQuotaErrors.ts b/packages/core/src/utils/googleQuotaErrors.ts index f3a909a20a..0ecc14d93f 100644 --- a/packages/core/src/utils/googleQuotaErrors.ts +++ b/packages/core/src/utils/googleQuotaErrors.ts @@ -63,7 +63,7 @@ export class ValidationRequiredError extends Error { constructor( message: string, - override readonly cause: GoogleApiError, + override readonly cause?: GoogleApiError, validationLink?: string, validationDescription?: string, learnMoreUrl?: string,