diff --git a/packages/core/src/code_assist/oauth2.test.ts b/packages/core/src/code_assist/oauth2.test.ts index f4b12f8cf2..8e37537763 100644 --- a/packages/core/src/code_assist/oauth2.test.ts +++ b/packages/core/src/code_assist/oauth2.test.ts @@ -29,6 +29,7 @@ import { FORCE_ENCRYPTED_FILE_ENV_VAR } from '../mcp/token-storage/index.js'; import { GEMINI_DIR } from '../utils/paths.js'; import { debugLogger } from '../utils/debugLogger.js'; import { writeToStdout } from '../utils/stdio.js'; +import { FatalCancellationError } from '../utils/errors.js'; vi.mock('os', async (importOriginal) => { const os = await importOriginal(); @@ -296,6 +297,7 @@ describe('oauth2', () => { generateAuthUrl: mockGenerateAuthUrl, getToken: mockGetToken, generateCodeVerifierAsync: mockGenerateCodeVerifierAsync, + getAccessToken: vi.fn().mockResolvedValue({ token: 'test-token' }), on: vi.fn(), credentials: {}, } as unknown as OAuth2Client; @@ -1100,6 +1102,139 @@ describe('oauth2', () => { }); }); + describe('cancellation', () => { + it('should cancel when SIGINT is received', async () => { + const mockAuthUrl = 'https://example.com/auth'; + const mockState = 'test-state'; + const mockOAuth2Client = { + generateAuthUrl: vi.fn().mockReturnValue(mockAuthUrl), + on: vi.fn(), + } as unknown as OAuth2Client; + vi.mocked(OAuth2Client).mockImplementation(() => mockOAuth2Client); + + vi.spyOn(crypto, 'randomBytes').mockReturnValue(mockState as never); + vi.mocked(open).mockImplementation( + async () => ({ on: vi.fn() }) as never, + ); + + // Mock createServer to return a server that doesn't do anything (keeps promise pending) + const mockHttpServer = { + listen: vi.fn(), + close: vi.fn(), + on: vi.fn(), + address: () => ({ port: 3000 }), + }; + (http.createServer as Mock).mockImplementation( + () => mockHttpServer as unknown as http.Server, + ); + + // Spy on process.on to capture the SIGINT handler + let sigIntHandler: (() => void) | undefined; + const originalOn = process.on; + const processOnSpy = vi + .spyOn(process, 'on') + .mockImplementation( + ( + event: string | symbol, + listener: (...args: unknown[]) => void, + ) => { + if (event === 'SIGINT') { + sigIntHandler = listener as () => void; + } + return originalOn.call(process, event, listener); + }, + ); + const processRemoveListenerSpy = vi.spyOn(process, 'removeListener'); + + const clientPromise = getOauthClient( + AuthType.LOGIN_WITH_GOOGLE, + mockConfig, + ); + + // Wait a tick to ensure the SIGINT handler is registered + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(sigIntHandler).toBeDefined(); + + // Trigger SIGINT + if (sigIntHandler) { + sigIntHandler(); + } + + await expect(clientPromise).rejects.toThrow(FatalCancellationError); + expect(processRemoveListenerSpy).toHaveBeenCalledWith( + 'SIGINT', + expect.any(Function), + ); + + processOnSpy.mockRestore(); + processRemoveListenerSpy.mockRestore(); + }); + + it('should cancel when Ctrl+C (0x03) is received on stdin', async () => { + const mockAuthUrl = 'https://example.com/auth'; + const mockState = 'test-state'; + const mockOAuth2Client = { + generateAuthUrl: vi.fn().mockReturnValue(mockAuthUrl), + on: vi.fn(), + } as unknown as OAuth2Client; + vi.mocked(OAuth2Client).mockImplementation(() => mockOAuth2Client); + + vi.spyOn(crypto, 'randomBytes').mockReturnValue(mockState as never); + vi.mocked(open).mockImplementation( + async () => ({ on: vi.fn() }) as never, + ); + + const mockHttpServer = { + listen: vi.fn(), + close: vi.fn(), + on: vi.fn(), + address: () => ({ port: 3000 }), + }; + (http.createServer as Mock).mockImplementation( + () => mockHttpServer as unknown as http.Server, + ); + + // Spy on process.stdin.on + let dataHandler: ((data: Buffer) => void) | undefined; + const stdinOnSpy = vi + .spyOn(process.stdin, 'on') + .mockImplementation((event: string | symbol, listener) => { + if (event === 'data') { + dataHandler = listener as (data: Buffer) => void; + } + return process.stdin; + }); + const stdinRemoveListenerSpy = vi.spyOn( + process.stdin, + 'removeListener', + ); + + const clientPromise = getOauthClient( + AuthType.LOGIN_WITH_GOOGLE, + mockConfig, + ); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(dataHandler).toBeDefined(); + + // Trigger Ctrl+C + if (dataHandler) { + dataHandler(Buffer.from([0x03])); + } + + await expect(clientPromise).rejects.toThrow(FatalCancellationError); + expect(stdinRemoveListenerSpy).toHaveBeenCalledWith( + 'data', + expect.any(Function), + ); + + stdinOnSpy.mockRestore(); + stdinRemoveListenerSpy.mockRestore(); + }); + }); + describe('clearCachedCredentialFile', () => { it('should clear cached credentials and Google account', async () => { const cachedCreds = { refresh_token: 'test-token' }; diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index 8c3cd8828f..406e054f1e 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -325,7 +325,41 @@ async function initOauthClient( }, authTimeout); }); - await Promise.race([webLogin.loginCompletePromise, timeoutPromise]); + // Listen for SIGINT to stop waiting for auth so the terminal doesn't hang + // if the user chooses not to auth. + let sigIntHandler: (() => void) | undefined; + let stdinHandler: ((data: Buffer) => void) | undefined; + const cancellationPromise = new Promise((_, reject) => { + sigIntHandler = () => + reject(new FatalCancellationError('Authentication cancelled by user.')); + process.on('SIGINT', sigIntHandler); + + // Note that SIGINT might not get raised on Ctrl+C in raw mode + // so we also need to look for Ctrl+C directly in stdin. + stdinHandler = (data) => { + if (data.includes(0x03)) { + reject( + new FatalCancellationError('Authentication cancelled by user.'), + ); + } + }; + process.stdin.on('data', stdinHandler); + }); + + try { + await Promise.race([ + webLogin.loginCompletePromise, + timeoutPromise, + cancellationPromise, + ]); + } finally { + if (sigIntHandler) { + process.removeListener('SIGINT', sigIntHandler); + } + if (stdinHandler) { + process.stdin.removeListener('data', stdinHandler); + } + } coreEvents.emit(CoreEvent.UserFeedback, { severity: 'info',