diff --git a/packages/core/src/code_assist/oauth-credential-storage.test.ts b/packages/core/src/code_assist/oauth-credential-storage.test.ts new file mode 100644 index 0000000000..2927d31e75 --- /dev/null +++ b/packages/core/src/code_assist/oauth-credential-storage.test.ts @@ -0,0 +1,193 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type Credentials } from 'google-auth-library'; +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { OAuthCredentialStorage } from './oauth-credential-storage.js'; +import { HybridTokenStorage } from '../mcp/token-storage/hybrid-token-storage.js'; +import type { OAuthCredentials } from '../mcp/token-storage/types.js'; + +import * as path from 'node:path'; +import * as os from 'node:os'; +import { promises as fs } from 'node:fs'; + +// Mock external dependencies +vi.mock('../mcp/token-storage/hybrid-token-storage.js'); +vi.mock('node:fs', () => ({ + promises: { + readFile: vi.fn(), + rm: vi.fn(), + }, +})); +vi.mock('node:os'); +vi.mock('node:path'); + +describe('OAuthCredentialStorage', () => { + let storage: HybridTokenStorage; + let oauthStorage: OAuthCredentialStorage; + + const mockCredentials: Credentials = { + access_token: 'mock_access_token', + refresh_token: 'mock_refresh_token', + expiry_date: Date.now() + 3600 * 1000, + token_type: 'Bearer', + scope: 'email profile', + }; + + const mockMcpCredentials: OAuthCredentials = { + serverName: 'main-account', + token: { + accessToken: 'mock_access_token', + refreshToken: 'mock_refresh_token', + tokenType: 'Bearer', + scope: 'email profile', + expiresAt: mockCredentials.expiry_date!, + }, + updatedAt: expect.any(Number), + }; + + const oldFilePath = '/mock/home/.gemini/oauth.json'; + + beforeEach(() => { + storage = new HybridTokenStorage(''); + oauthStorage = new OAuthCredentialStorage(storage); + + vi.spyOn(storage, 'getCredentials').mockResolvedValue(null); + vi.spyOn(storage, 'setCredentials').mockResolvedValue(undefined); + vi.spyOn(storage, 'deleteCredentials').mockResolvedValue(undefined); + + vi.spyOn(fs, 'readFile').mockRejectedValue(new Error('File not found')); + vi.spyOn(fs, 'rm').mockResolvedValue(undefined); + + vi.spyOn(os, 'homedir').mockReturnValue('/mock/home'); + vi.spyOn(path, 'join').mockReturnValue(oldFilePath); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('loadCredentials', () => { + it('should load credentials from HybridTokenStorage if available', async () => { + vi.spyOn(storage, 'getCredentials').mockResolvedValue(mockMcpCredentials); + + const result = await oauthStorage.loadCredentials(); + + expect(storage.getCredentials).toHaveBeenCalledWith('main-account'); + expect(result).toEqual(mockCredentials); + }); + + it('should fallback to migrateFromFileStorage if no credentials in HybridTokenStorage', async () => { + vi.spyOn(storage, 'getCredentials').mockResolvedValue(null); + vi.spyOn(fs, 'readFile').mockResolvedValue( + JSON.stringify(mockCredentials), + ); + + const result = await oauthStorage.loadCredentials(); + + expect(storage.getCredentials).toHaveBeenCalledWith('main-account'); + expect(fs.readFile).toHaveBeenCalledWith(oldFilePath, 'utf-8'); + expect(storage.setCredentials).toHaveBeenCalled(); // Verify credentials were saved + expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); // Verify old file was removed + expect(result).toEqual(mockCredentials); + }); + + it('should return null if no credentials found and no old file to migrate', async () => { + vi.spyOn(fs, 'readFile').mockRejectedValue({ + message: 'File not found', + code: 'ENOENT', + }); + + const result = await oauthStorage.loadCredentials(); + + expect(result).toBeNull(); + }); + + it('should throw an error if loading fails', async () => { + vi.spyOn(storage, 'getCredentials').mockRejectedValue( + new Error('Loading error'), + ); + + await expect(oauthStorage.loadCredentials()).rejects.toThrow( + 'Failed to load OAuth credentials', + ); + }); + + it('should throw an error if read file fails', async () => { + vi.spyOn(storage, 'getCredentials').mockResolvedValue(null); + vi.spyOn(fs, 'readFile').mockRejectedValue( + new Error('Permission denied'), + ); + + await expect(oauthStorage.loadCredentials()).rejects.toThrow( + 'Failed to load OAuth credentials', + ); + }); + + it('should not throw error if migration file removal failed', async () => { + vi.spyOn(storage, 'getCredentials').mockResolvedValue(null); + vi.spyOn(fs, 'readFile').mockResolvedValue( + JSON.stringify(mockCredentials), + ); + vi.spyOn(oauthStorage, 'saveCredentials').mockResolvedValue(undefined); + vi.spyOn(fs, 'rm').mockRejectedValue(new Error('Deletion failed')); + + const result = await oauthStorage.loadCredentials(); + + expect(result).toEqual(mockCredentials); + }); + }); + + describe('saveCredentials', () => { + it('should save credentials to HybridTokenStorage', async () => { + await oauthStorage.saveCredentials(mockCredentials); + + expect(storage.setCredentials).toHaveBeenCalledWith(mockMcpCredentials); + }); + + it('should throw an error if access_token is missing', async () => { + const invalidCredentials: Credentials = { + ...mockCredentials, + access_token: undefined, + }; + await expect( + oauthStorage.saveCredentials(invalidCredentials), + ).rejects.toThrow( + 'Attempted to save credentials without an access token.', + ); + }); + }); + + describe('clearCredentials', () => { + it('should delete credentials from HybridTokenStorage', async () => { + await oauthStorage.clearCredentials(); + + expect(storage.deleteCredentials).toHaveBeenCalledWith('main-account'); + }); + + it('should attempt to remove the old file-based storage', async () => { + await oauthStorage.clearCredentials(); + + expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); + }); + + it('should not throw an error if deleting old file fails', async () => { + vi.spyOn(fs, 'rm').mockRejectedValue(new Error('File deletion failed')); + + await expect(oauthStorage.clearCredentials()).resolves.toBeUndefined(); + }); + + it('should throw an error if clearing from HybridTokenStorage fails', async () => { + vi.spyOn(storage, 'deleteCredentials').mockRejectedValue( + new Error('Deletion error'), + ); + + await expect(oauthStorage.clearCredentials()).rejects.toThrow( + 'Failed to clear OAuth credentials', + ); + }); + }); +}); diff --git a/packages/core/src/code_assist/oauth-credential-storage.ts b/packages/core/src/code_assist/oauth-credential-storage.ts new file mode 100644 index 0000000000..9c6f085f3f --- /dev/null +++ b/packages/core/src/code_assist/oauth-credential-storage.ts @@ -0,0 +1,132 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type Credentials } from 'google-auth-library'; +import { HybridTokenStorage } from '../mcp/token-storage/hybrid-token-storage.js'; +import { OAUTH_FILE } from '../config/storage.js'; +import type { OAuthCredentials } from '../mcp/token-storage/types.js'; +import * as path from 'node:path'; +import * as os from 'node:os'; +import { promises as fs } from 'node:fs'; + +const GEMINI_DIR = '.gemini'; +const KEYCHAIN_SERVICE_NAME = 'gemini-cli-oauth'; +const MAIN_ACCOUNT_KEY = 'main-account'; + +export class OAuthCredentialStorage { + constructor( + private readonly storage: HybridTokenStorage = new HybridTokenStorage( + KEYCHAIN_SERVICE_NAME, + ), + ) {} + + /** + * Load cached OAuth credentials + */ + async loadCredentials(): Promise { + try { + const credentials = await this.storage.getCredentials(MAIN_ACCOUNT_KEY); + + if (credentials?.token) { + const { accessToken, refreshToken, expiresAt, tokenType, scope } = + credentials.token; + // Convert from OAuthCredentials format to Google Credentials format + const googleCreds: Credentials = { + access_token: accessToken, + refresh_token: refreshToken || undefined, + token_type: tokenType || undefined, + scope: scope || undefined, + }; + + if (expiresAt) { + googleCreds.expiry_date = expiresAt; + } + + return googleCreds; + } + + // Fallback: Try to migrate from old file-based storage + return await this.migrateFromFileStorage(); + } catch (error: unknown) { + console.error(error); + throw new Error('Failed to load OAuth credentials'); + } + } + + /** + * Save OAuth credentials + */ + async saveCredentials(credentials: Credentials): Promise { + if (!credentials.access_token) { + throw new Error('Attempted to save credentials without an access token.'); + } + + // Convert Google Credentials to OAuthCredentials format + const mcpCredentials: OAuthCredentials = { + serverName: MAIN_ACCOUNT_KEY, + token: { + accessToken: credentials.access_token, + refreshToken: credentials.refresh_token || undefined, + tokenType: credentials.token_type || 'Bearer', + scope: credentials.scope || undefined, + expiresAt: credentials.expiry_date || undefined, + }, + updatedAt: Date.now(), + }; + + await this.storage.setCredentials(mcpCredentials); + } + + /** + * Clear cached OAuth credentials + */ + async clearCredentials(): Promise { + try { + await this.storage.deleteCredentials(MAIN_ACCOUNT_KEY); + + // Also try to remove the old file if it exists + const oldFilePath = path.join(os.homedir(), GEMINI_DIR, OAUTH_FILE); + await fs.rm(oldFilePath, { force: true }).catch(() => {}); + } catch (error: unknown) { + console.error(error); + throw new Error('Failed to clear OAuth credentials'); + } + } + + /** + * Migrate credentials from old file-based storage to keychain + */ + private async migrateFromFileStorage(): Promise { + const oldFilePath = path.join(os.homedir(), GEMINI_DIR, OAUTH_FILE); + + let credsJson: string; + try { + credsJson = await fs.readFile(oldFilePath, 'utf-8'); + } catch (error: unknown) { + if ( + typeof error === 'object' && + error !== null && + 'code' in error && + error.code === 'ENOENT' + ) { + // File doesn't exist, so no migration. + return null; + } + // Other read errors should propagate. + throw error; + } + + const credentials = JSON.parse(credsJson) as Credentials; + + // Save to new storage + await this.saveCredentials(credentials); + + // Remove old file after successful migration + await fs.rm(oldFilePath, { force: true }).catch(() => {}); + + return credentials; + } +} diff --git a/packages/core/src/config/storage.ts b/packages/core/src/config/storage.ts index d08a15ceeb..6442b87c87 100644 --- a/packages/core/src/config/storage.ts +++ b/packages/core/src/config/storage.ts @@ -11,6 +11,7 @@ import * as fs from 'node:fs'; export const GEMINI_DIR = '.gemini'; export const GOOGLE_ACCOUNTS_FILENAME = 'google_accounts.json'; +export const OAUTH_FILE = 'oauth_creds.json'; const TMP_DIR_NAME = 'tmp'; export class Storage { @@ -71,7 +72,7 @@ export class Storage { } static getOAuthCredsPath(): string { - return path.join(Storage.getGlobalGeminiDir(), 'oauth_creds.json'); + return path.join(Storage.getGlobalGeminiDir(), OAUTH_FILE); } getProjectRoot(): string {