From 1e08b150f74e55c027c8da156dddf53b77a75fb5 Mon Sep 17 00:00:00 2001 From: Shreya Keshive Date: Tue, 30 Dec 2025 11:09:00 -0500 Subject: [PATCH] refactor(auth): Refactor non-interactive mode auth validation & refresh (#15679) --- packages/cli/src/gemini.test.tsx | 3 + packages/cli/src/gemini.tsx | 35 +++++--- packages/cli/src/gemini_cleanup.test.tsx | 1 + .../src/validateNonInterActiveAuth.test.ts | 79 +++++++------------ .../cli/src/validateNonInterActiveAuth.ts | 3 +- 5 files changed, 56 insertions(+), 65 deletions(-) diff --git a/packages/cli/src/gemini.test.tsx b/packages/cli/src/gemini.test.tsx index 632b15b9a8..a6905f736a 100644 --- a/packages/cli/src/gemini.test.tsx +++ b/packages/cli/src/gemini.test.tsx @@ -281,6 +281,7 @@ describe('gemini.tsx main function', () => { getOutputFormat: () => 'text', getExtensions: () => [], getUsageStatisticsEnabled: () => false, + refreshAuth: vi.fn(), setTerminalBackground: vi.fn(), } as unknown as Config; }); @@ -783,6 +784,7 @@ describe('gemini.tsx main function kitty protocol', () => { getFileFilteringRespectGitIgnore: () => true, getOutputFormat: () => 'text', getUsageStatisticsEnabled: () => false, + refreshAuth: vi.fn(), setTerminalBackground: vi.fn(), } as any); // eslint-disable-line @typescript-eslint/no-explicit-any @@ -1019,6 +1021,7 @@ describe('gemini.tsx main function kitty protocol', () => { getFileFilteringRespectGitIgnore: () => true, getOutputFormat: () => 'text', getUsageStatisticsEnabled: () => false, + refreshAuth: vi.fn(), setTerminalBackground: vi.fn(), } as any); // eslint-disable-line @typescript-eslint/no-explicit-any diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 3279155d1a..eacef49cb3 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -401,18 +401,28 @@ export async function main() { settings.merged.security?.auth?.selectedType && !settings.merged.security?.auth?.useExternal ) { - // Validate authentication here because the sandbox will interfere with the Oauth2 web redirect. try { - const err = validateAuthMethod( - settings.merged.security.auth.selectedType, - ); - if (err) { - throw new Error(err); - } + if (partialConfig.isInteractive()) { + // Validate authentication here because the sandbox will interfere with the Oauth2 web redirect. + const err = validateAuthMethod( + settings.merged.security.auth.selectedType, + ); + if (err) { + throw new Error(err); + } - await partialConfig.refreshAuth( - settings.merged.security.auth.selectedType, - ); + await partialConfig.refreshAuth( + settings.merged.security.auth.selectedType, + ); + } else { + const authType = await validateNonInteractiveAuth( + settings.merged.security?.auth?.selectedType, + settings.merged.security?.auth?.useExternal, + partialConfig, + settings, + ); + await partialConfig.refreshAuth(authType); + } } catch (err) { debugLogger.error('Error authenticating:', err); await runExitCleanup(); @@ -667,12 +677,13 @@ export async function main() { ), ); - const nonInteractiveConfig = await validateNonInteractiveAuth( + const authType = await validateNonInteractiveAuth( settings.merged.security?.auth?.selectedType, settings.merged.security?.auth?.useExternal, config, settings, ); + await config.refreshAuth(authType); if (config.getDebugMode()) { debugLogger.log('Session ID: %s', sessionId); @@ -684,7 +695,7 @@ export async function main() { initializeOutputListenersAndFlush(); await runNonInteractive({ - config: nonInteractiveConfig, + config, settings, input, prompt_id, diff --git a/packages/cli/src/gemini_cleanup.test.tsx b/packages/cli/src/gemini_cleanup.test.tsx index ca146a9181..95471ef031 100644 --- a/packages/cli/src/gemini_cleanup.test.tsx +++ b/packages/cli/src/gemini_cleanup.test.tsx @@ -213,6 +213,7 @@ describe('gemini.tsx main function cleanup', () => { getOutputFormat: vi.fn(() => 'text'), getUsageStatisticsEnabled: vi.fn(() => false), setTerminalBackground: vi.fn(), + refreshAuth: vi.fn(), } as any); // eslint-disable-line @typescript-eslint/no-explicit-any try { diff --git a/packages/cli/src/validateNonInterActiveAuth.test.ts b/packages/cli/src/validateNonInterActiveAuth.test.ts index e3cc33ea23..06f1067730 100644 --- a/packages/cli/src/validateNonInterActiveAuth.test.ts +++ b/packages/cli/src/validateNonInterActiveAuth.test.ts @@ -12,7 +12,6 @@ import { beforeEach, afterEach, type MockInstance, - type Mock, } from 'vitest'; import { validateNonInteractiveAuth } from './validateNonInterActiveAuth.js'; import { @@ -40,7 +39,6 @@ describe('validateNonInterActiveAuth', () => { let debugLoggerErrorSpy: ReturnType; let coreEventsEmitFeedbackSpy: MockInstance; let processExitSpy: MockInstance; - let refreshAuthMock: Mock; let mockSettings: LoadedSettings; beforeEach(() => { @@ -62,7 +60,6 @@ describe('validateNonInterActiveAuth', () => { throw new Error(`process.exit(${code}) called`); }); vi.spyOn(auth, 'validateAuthMethod').mockReturnValue(null); - refreshAuthMock = vi.fn().mockImplementation(async () => 'refreshed'); mockSettings = { system: { path: '', settings: {} }, systemDefaults: { path: '', settings: {} }, @@ -105,7 +102,6 @@ describe('validateNonInterActiveAuth', () => { it('exits if no auth type is configured or env vars set', async () => { const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, getOutputFormat: vi.fn().mockReturnValue(OutputFormat.TEXT), getContentGeneratorConfig: vi .fn() @@ -134,61 +130,57 @@ describe('validateNonInterActiveAuth', () => { it('uses LOGIN_WITH_GOOGLE if GOOGLE_GENAI_USE_GCA is set', async () => { process.env['GOOGLE_GENAI_USE_GCA'] = 'true'; - const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, - }); + const nonInteractiveConfig = createLocalMockConfig({}); await validateNonInteractiveAuth( undefined, undefined, nonInteractiveConfig, mockSettings, ); - expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.LOGIN_WITH_GOOGLE); + expect(processExitSpy).not.toHaveBeenCalled(); + expect(debugLoggerErrorSpy).not.toHaveBeenCalled(); }); it('uses USE_GEMINI if GEMINI_API_KEY is set', async () => { process.env['GEMINI_API_KEY'] = 'fake-key'; - const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, - }); + const nonInteractiveConfig = createLocalMockConfig({}); await validateNonInteractiveAuth( undefined, undefined, nonInteractiveConfig, mockSettings, ); - expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.USE_GEMINI); + expect(processExitSpy).not.toHaveBeenCalled(); + expect(debugLoggerErrorSpy).not.toHaveBeenCalled(); }); it('uses USE_VERTEX_AI if GOOGLE_GENAI_USE_VERTEXAI is true (with GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION)', async () => { process.env['GOOGLE_GENAI_USE_VERTEXAI'] = 'true'; process.env['GOOGLE_CLOUD_PROJECT'] = 'test-project'; process.env['GOOGLE_CLOUD_LOCATION'] = 'us-central1'; - const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, - }); + const nonInteractiveConfig = createLocalMockConfig({}); await validateNonInteractiveAuth( undefined, undefined, nonInteractiveConfig, mockSettings, ); - expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.USE_VERTEX_AI); + expect(processExitSpy).not.toHaveBeenCalled(); + expect(debugLoggerErrorSpy).not.toHaveBeenCalled(); }); it('uses USE_VERTEX_AI if GOOGLE_GENAI_USE_VERTEXAI is true and GOOGLE_API_KEY is set', async () => { process.env['GOOGLE_GENAI_USE_VERTEXAI'] = 'true'; process.env['GOOGLE_API_KEY'] = 'vertex-api-key'; - const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, - }); + const nonInteractiveConfig = createLocalMockConfig({}); await validateNonInteractiveAuth( undefined, undefined, nonInteractiveConfig, mockSettings, ); - expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.USE_VERTEX_AI); + expect(processExitSpy).not.toHaveBeenCalled(); + expect(debugLoggerErrorSpy).not.toHaveBeenCalled(); }); it('uses LOGIN_WITH_GOOGLE if GOOGLE_GENAI_USE_GCA is set, even with other env vars', async () => { @@ -197,16 +189,15 @@ describe('validateNonInterActiveAuth', () => { process.env['GOOGLE_GENAI_USE_VERTEXAI'] = 'true'; process.env['GOOGLE_CLOUD_PROJECT'] = 'test-project'; process.env['GOOGLE_CLOUD_LOCATION'] = 'us-central1'; - const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, - }); + const nonInteractiveConfig = createLocalMockConfig({}); await validateNonInteractiveAuth( undefined, undefined, nonInteractiveConfig, mockSettings, ); - expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.LOGIN_WITH_GOOGLE); + expect(processExitSpy).not.toHaveBeenCalled(); + expect(debugLoggerErrorSpy).not.toHaveBeenCalled(); }); it('uses USE_VERTEX_AI if both GEMINI_API_KEY and GOOGLE_GENAI_USE_VERTEXAI are set', async () => { @@ -214,16 +205,15 @@ describe('validateNonInterActiveAuth', () => { process.env['GOOGLE_GENAI_USE_VERTEXAI'] = 'true'; process.env['GOOGLE_CLOUD_PROJECT'] = 'test-project'; process.env['GOOGLE_CLOUD_LOCATION'] = 'us-central1'; - const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, - }); + const nonInteractiveConfig = createLocalMockConfig({}); await validateNonInteractiveAuth( undefined, undefined, nonInteractiveConfig, mockSettings, ); - expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.USE_VERTEX_AI); + expect(processExitSpy).not.toHaveBeenCalled(); + expect(debugLoggerErrorSpy).not.toHaveBeenCalled(); }); it('uses USE_GEMINI if GOOGLE_GENAI_USE_VERTEXAI is false, GEMINI_API_KEY is set, and project/location are available', async () => { @@ -231,37 +221,34 @@ describe('validateNonInterActiveAuth', () => { process.env['GEMINI_API_KEY'] = 'fake-key'; process.env['GOOGLE_CLOUD_PROJECT'] = 'test-project'; process.env['GOOGLE_CLOUD_LOCATION'] = 'us-central1'; - const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, - }); + const nonInteractiveConfig = createLocalMockConfig({}); await validateNonInteractiveAuth( undefined, undefined, nonInteractiveConfig, mockSettings, ); - expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.USE_GEMINI); + expect(processExitSpy).not.toHaveBeenCalled(); + expect(debugLoggerErrorSpy).not.toHaveBeenCalled(); }); it('uses configuredAuthType over environment variables', async () => { process.env['GEMINI_API_KEY'] = 'fake-key'; - const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, - }); + const nonInteractiveConfig = createLocalMockConfig({}); await validateNonInteractiveAuth( AuthType.LOGIN_WITH_GOOGLE, undefined, nonInteractiveConfig, mockSettings, ); - expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.LOGIN_WITH_GOOGLE); + expect(processExitSpy).not.toHaveBeenCalled(); + expect(debugLoggerErrorSpy).not.toHaveBeenCalled(); }); it('exits if validateAuthMethod returns error', async () => { // Mock validateAuthMethod to return error vi.spyOn(auth, 'validateAuthMethod').mockReturnValue('Auth error!'); const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, getOutputFormat: vi.fn().mockReturnValue(OutputFormat.TEXT), getContentGeneratorConfig: vi .fn() @@ -291,9 +278,7 @@ describe('validateNonInterActiveAuth', () => { const validateAuthMethodSpy = vi .spyOn(auth, 'validateAuthMethod') .mockReturnValue('Auth error!'); - const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, - }); + const nonInteractiveConfig = createLocalMockConfig({}); // Even with an invalid auth type, it should not exit // because validation is skipped. await validateNonInteractiveAuth( @@ -307,30 +292,26 @@ describe('validateNonInterActiveAuth', () => { expect(debugLoggerErrorSpy).not.toHaveBeenCalled(); expect(coreEventsEmitFeedbackSpy).not.toHaveBeenCalled(); expect(processExitSpy).not.toHaveBeenCalled(); - // We still expect refreshAuth to be called with the (invalid) type - expect(refreshAuthMock).toHaveBeenCalledWith('invalid-auth-type'); }); it('succeeds if effectiveAuthType matches enforcedAuthType', async () => { mockSettings.merged.security!.auth!.enforcedType = AuthType.USE_GEMINI; process.env['GEMINI_API_KEY'] = 'fake-key'; - const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, - }); + const nonInteractiveConfig = createLocalMockConfig({}); await validateNonInteractiveAuth( undefined, undefined, nonInteractiveConfig, mockSettings, ); - expect(refreshAuthMock).toHaveBeenCalledWith(AuthType.USE_GEMINI); + expect(processExitSpy).not.toHaveBeenCalled(); + expect(debugLoggerErrorSpy).not.toHaveBeenCalled(); }); it('exits if configuredAuthType does not match enforcedAuthType', async () => { mockSettings.merged.security!.auth!.enforcedType = AuthType.LOGIN_WITH_GOOGLE; const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, getOutputFormat: vi.fn().mockReturnValue(OutputFormat.TEXT), }); try { @@ -359,7 +340,6 @@ describe('validateNonInterActiveAuth', () => { AuthType.LOGIN_WITH_GOOGLE; process.env['GEMINI_API_KEY'] = 'fake-key'; const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, getOutputFormat: vi.fn().mockReturnValue(OutputFormat.TEXT), }); try { @@ -386,7 +366,6 @@ describe('validateNonInterActiveAuth', () => { describe('JSON output mode', () => { it(`prints JSON error when no auth is configured and exits with code ${ExitCodes.FATAL_AUTHENTICATION_ERROR}`, async () => { const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, getOutputFormat: vi.fn().mockReturnValue(OutputFormat.JSON), getContentGeneratorConfig: vi .fn() @@ -421,7 +400,6 @@ describe('validateNonInterActiveAuth', () => { it(`prints JSON error when enforced auth mismatches current auth and exits with code ${ExitCodes.FATAL_AUTHENTICATION_ERROR}`, async () => { mockSettings.merged.security!.auth!.enforcedType = AuthType.USE_GEMINI; const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, getOutputFormat: vi.fn().mockReturnValue(OutputFormat.JSON), getContentGeneratorConfig: vi .fn() @@ -460,7 +438,6 @@ describe('validateNonInterActiveAuth', () => { process.env['GEMINI_API_KEY'] = 'fake-key'; const nonInteractiveConfig = createLocalMockConfig({ - refreshAuth: refreshAuthMock, getOutputFormat: vi.fn().mockReturnValue(OutputFormat.JSON), getContentGeneratorConfig: vi .fn() diff --git a/packages/cli/src/validateNonInterActiveAuth.ts b/packages/cli/src/validateNonInterActiveAuth.ts index e3a63acbfe..82fb7cd725 100644 --- a/packages/cli/src/validateNonInterActiveAuth.ts +++ b/packages/cli/src/validateNonInterActiveAuth.ts @@ -61,8 +61,7 @@ export async function validateNonInteractiveAuth( } } - await nonInteractiveConfig.refreshAuth(authType); - return nonInteractiveConfig; + return authType; } catch (error) { if (nonInteractiveConfig.getOutputFormat() === OutputFormat.JSON) { handleError(