mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-20 18:14:29 -07:00
feat(security) - Encrypted oauth flag (#8101)
Co-authored-by: Shi Shu <shii@google.com>
This commit is contained in:
@@ -7,7 +7,6 @@
|
||||
import { type Credentials } from 'google-auth-library';
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
|
||||
import { HybridTokenStorage } from '../mcp/token-storage/hybrid-token-storage.js';
|
||||
import type { OAuthCredentials } from '../mcp/token-storage/types.js';
|
||||
|
||||
import * as path from 'node:path';
|
||||
@@ -15,7 +14,14 @@ import * as os from 'node:os';
|
||||
import { promises as fs } from 'node:fs';
|
||||
|
||||
// Mock external dependencies
|
||||
vi.mock('../mcp/token-storage/hybrid-token-storage.js');
|
||||
const mockHybridTokenStorage = vi.hoisted(() => ({
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
}));
|
||||
vi.mock('../mcp/token-storage/hybrid-token-storage.js', () => ({
|
||||
HybridTokenStorage: vi.fn(() => mockHybridTokenStorage),
|
||||
}));
|
||||
vi.mock('node:fs', () => ({
|
||||
promises: {
|
||||
readFile: vi.fn(),
|
||||
@@ -26,9 +32,6 @@ vi.mock('node:os');
|
||||
vi.mock('node:path');
|
||||
|
||||
describe('OAuthCredentialStorage', () => {
|
||||
let storage: HybridTokenStorage;
|
||||
let oauthStorage: OAuthCredentialStorage;
|
||||
|
||||
const mockCredentials: Credentials = {
|
||||
access_token: 'mock_access_token',
|
||||
refresh_token: 'mock_refresh_token',
|
||||
@@ -52,12 +55,13 @@ describe('OAuthCredentialStorage', () => {
|
||||
const oldFilePath = '/mock/home/.gemini/oauth.json';
|
||||
|
||||
beforeEach(() => {
|
||||
storage = new HybridTokenStorage('');
|
||||
oauthStorage = new OAuthCredentialStorage(storage);
|
||||
|
||||
vi.spyOn(storage, 'getCredentials').mockResolvedValue(null);
|
||||
vi.spyOn(storage, 'setCredentials').mockResolvedValue(undefined);
|
||||
vi.spyOn(storage, 'deleteCredentials').mockResolvedValue(undefined);
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(null);
|
||||
vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
|
||||
vi.spyOn(fs, 'readFile').mockRejectedValue(new Error('File not found'));
|
||||
vi.spyOn(fs, 'rm').mockResolvedValue(undefined);
|
||||
@@ -72,25 +76,33 @@ describe('OAuthCredentialStorage', () => {
|
||||
|
||||
describe('loadCredentials', () => {
|
||||
it('should load credentials from HybridTokenStorage if available', async () => {
|
||||
vi.spyOn(storage, 'getCredentials').mockResolvedValue(mockMcpCredentials);
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
mockMcpCredentials,
|
||||
);
|
||||
|
||||
const result = await oauthStorage.loadCredentials();
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(storage.getCredentials).toHaveBeenCalledWith('main-account');
|
||||
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'main-account',
|
||||
);
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
|
||||
it('should fallback to migrateFromFileStorage if no credentials in HybridTokenStorage', async () => {
|
||||
vi.spyOn(storage, 'getCredentials').mockResolvedValue(null);
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockResolvedValue(
|
||||
JSON.stringify(mockCredentials),
|
||||
);
|
||||
|
||||
const result = await oauthStorage.loadCredentials();
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(storage.getCredentials).toHaveBeenCalledWith('main-account');
|
||||
expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'main-account',
|
||||
);
|
||||
expect(fs.readFile).toHaveBeenCalledWith(oldFilePath, 'utf-8');
|
||||
expect(storage.setCredentials).toHaveBeenCalled(); // Verify credentials were saved
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalled(); // Verify credentials were saved
|
||||
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); // Verify old file was removed
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
@@ -101,41 +113,47 @@ describe('OAuthCredentialStorage', () => {
|
||||
code: 'ENOENT',
|
||||
});
|
||||
|
||||
const result = await oauthStorage.loadCredentials();
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should throw an error if loading fails', async () => {
|
||||
vi.spyOn(storage, 'getCredentials').mockRejectedValue(
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockRejectedValue(
|
||||
new Error('Loading error'),
|
||||
);
|
||||
|
||||
await expect(oauthStorage.loadCredentials()).rejects.toThrow(
|
||||
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
|
||||
'Failed to load OAuth credentials',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if read file fails', async () => {
|
||||
vi.spyOn(storage, 'getCredentials').mockResolvedValue(null);
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockRejectedValue(
|
||||
new Error('Permission denied'),
|
||||
);
|
||||
|
||||
await expect(oauthStorage.loadCredentials()).rejects.toThrow(
|
||||
await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
|
||||
'Failed to load OAuth credentials',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not throw error if migration file removal failed', async () => {
|
||||
vi.spyOn(storage, 'getCredentials').mockResolvedValue(null);
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
null,
|
||||
);
|
||||
vi.spyOn(fs, 'readFile').mockResolvedValue(
|
||||
JSON.stringify(mockCredentials),
|
||||
);
|
||||
vi.spyOn(oauthStorage, 'saveCredentials').mockResolvedValue(undefined);
|
||||
vi.spyOn(OAuthCredentialStorage, 'saveCredentials').mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('Deletion failed'));
|
||||
|
||||
const result = await oauthStorage.loadCredentials();
|
||||
const result = await OAuthCredentialStorage.loadCredentials();
|
||||
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
@@ -143,9 +161,11 @@ describe('OAuthCredentialStorage', () => {
|
||||
|
||||
describe('saveCredentials', () => {
|
||||
it('should save credentials to HybridTokenStorage', async () => {
|
||||
await oauthStorage.saveCredentials(mockCredentials);
|
||||
await OAuthCredentialStorage.saveCredentials(mockCredentials);
|
||||
|
||||
expect(storage.setCredentials).toHaveBeenCalledWith(mockMcpCredentials);
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
|
||||
mockMcpCredentials,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if access_token is missing', async () => {
|
||||
@@ -154,7 +174,7 @@ describe('OAuthCredentialStorage', () => {
|
||||
access_token: undefined,
|
||||
};
|
||||
await expect(
|
||||
oauthStorage.saveCredentials(invalidCredentials),
|
||||
OAuthCredentialStorage.saveCredentials(invalidCredentials),
|
||||
).rejects.toThrow(
|
||||
'Attempted to save credentials without an access token.',
|
||||
);
|
||||
@@ -163,13 +183,15 @@ describe('OAuthCredentialStorage', () => {
|
||||
|
||||
describe('clearCredentials', () => {
|
||||
it('should delete credentials from HybridTokenStorage', async () => {
|
||||
await oauthStorage.clearCredentials();
|
||||
await OAuthCredentialStorage.clearCredentials();
|
||||
|
||||
expect(storage.deleteCredentials).toHaveBeenCalledWith('main-account');
|
||||
expect(mockHybridTokenStorage.deleteCredentials).toHaveBeenCalledWith(
|
||||
'main-account',
|
||||
);
|
||||
});
|
||||
|
||||
it('should attempt to remove the old file-based storage', async () => {
|
||||
await oauthStorage.clearCredentials();
|
||||
await OAuthCredentialStorage.clearCredentials();
|
||||
|
||||
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true });
|
||||
});
|
||||
@@ -177,15 +199,17 @@ describe('OAuthCredentialStorage', () => {
|
||||
it('should not throw an error if deleting old file fails', async () => {
|
||||
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('File deletion failed'));
|
||||
|
||||
await expect(oauthStorage.clearCredentials()).resolves.toBeUndefined();
|
||||
await expect(
|
||||
OAuthCredentialStorage.clearCredentials(),
|
||||
).resolves.toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw an error if clearing from HybridTokenStorage fails', async () => {
|
||||
vi.spyOn(storage, 'deleteCredentials').mockRejectedValue(
|
||||
vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockRejectedValue(
|
||||
new Error('Deletion error'),
|
||||
);
|
||||
|
||||
await expect(oauthStorage.clearCredentials()).rejects.toThrow(
|
||||
await expect(OAuthCredentialStorage.clearCredentials()).rejects.toThrow(
|
||||
'Failed to clear OAuth credentials',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -17,16 +17,14 @@ const KEYCHAIN_SERVICE_NAME = 'gemini-cli-oauth';
|
||||
const MAIN_ACCOUNT_KEY = 'main-account';
|
||||
|
||||
export class OAuthCredentialStorage {
|
||||
constructor(
|
||||
private readonly storage: HybridTokenStorage = new HybridTokenStorage(
|
||||
KEYCHAIN_SERVICE_NAME,
|
||||
),
|
||||
) {}
|
||||
private static storage: HybridTokenStorage = new HybridTokenStorage(
|
||||
KEYCHAIN_SERVICE_NAME,
|
||||
);
|
||||
|
||||
/**
|
||||
* Load cached OAuth credentials
|
||||
*/
|
||||
async loadCredentials(): Promise<Credentials | null> {
|
||||
static async loadCredentials(): Promise<Credentials | null> {
|
||||
try {
|
||||
const credentials = await this.storage.getCredentials(MAIN_ACCOUNT_KEY);
|
||||
|
||||
@@ -59,7 +57,7 @@ export class OAuthCredentialStorage {
|
||||
/**
|
||||
* Save OAuth credentials
|
||||
*/
|
||||
async saveCredentials(credentials: Credentials): Promise<void> {
|
||||
static async saveCredentials(credentials: Credentials): Promise<void> {
|
||||
if (!credentials.access_token) {
|
||||
throw new Error('Attempted to save credentials without an access token.');
|
||||
}
|
||||
@@ -83,7 +81,7 @@ export class OAuthCredentialStorage {
|
||||
/**
|
||||
* Clear cached OAuth credentials
|
||||
*/
|
||||
async clearCredentials(): Promise<void> {
|
||||
static async clearCredentials(): Promise<void> {
|
||||
try {
|
||||
await this.storage.deleteCredentials(MAIN_ACCOUNT_KEY);
|
||||
|
||||
@@ -99,7 +97,7 @@ export class OAuthCredentialStorage {
|
||||
/**
|
||||
* Migrate credentials from old file-based storage to keychain
|
||||
*/
|
||||
private async migrateFromFileStorage(): Promise<Credentials | null> {
|
||||
private static async migrateFromFileStorage(): Promise<Credentials | null> {
|
||||
const oldFilePath = path.join(os.homedir(), GEMINI_DIR, OAUTH_FILE);
|
||||
|
||||
let credsJson: string;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,6 +23,8 @@ import { UserAccountManager } from '../utils/userAccountManager.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import readline from 'node:readline';
|
||||
import { Storage } from '../config/storage.js';
|
||||
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
|
||||
import { FORCE_ENCRYPTED_FILE_ENV_VAR } from '../mcp/token-storage/index.js';
|
||||
|
||||
const userAccountManager = new UserAccountManager();
|
||||
|
||||
@@ -63,6 +65,10 @@ export interface OauthWebLogin {
|
||||
|
||||
const oauthClientPromises = new Map<AuthType, Promise<OAuth2Client>>();
|
||||
|
||||
function getUseEncryptedStorageFlag() {
|
||||
return process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
|
||||
}
|
||||
|
||||
async function initOauthClient(
|
||||
authType: AuthType,
|
||||
config: Config,
|
||||
@@ -74,6 +80,7 @@ async function initOauthClient(
|
||||
proxy: config.getProxy(),
|
||||
},
|
||||
});
|
||||
const useEncryptedStorage = getUseEncryptedStorageFlag();
|
||||
|
||||
if (
|
||||
process.env['GOOGLE_GENAI_USE_GCA'] &&
|
||||
@@ -87,7 +94,11 @@ async function initOauthClient(
|
||||
}
|
||||
|
||||
client.on('tokens', async (tokens: Credentials) => {
|
||||
await cacheCredentials(tokens);
|
||||
if (useEncryptedStorage) {
|
||||
await OAuthCredentialStorage.saveCredentials(tokens);
|
||||
} else {
|
||||
await cacheCredentials(tokens);
|
||||
}
|
||||
});
|
||||
|
||||
// If there are cached creds on disk, they always take precedence
|
||||
@@ -419,6 +430,16 @@ export function getAvailablePort(): Promise<number> {
|
||||
}
|
||||
|
||||
async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
|
||||
const useEncryptedStorage = getUseEncryptedStorageFlag();
|
||||
if (useEncryptedStorage) {
|
||||
const credentials = await OAuthCredentialStorage.loadCredentials();
|
||||
if (credentials) {
|
||||
client.setCredentials(credentials);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const pathsToTry = [
|
||||
Storage.getOAuthCredsPath(),
|
||||
process.env['GOOGLE_APPLICATION_CREDENTIALS'],
|
||||
@@ -470,7 +491,12 @@ export function clearOauthClientCache() {
|
||||
|
||||
export async function clearCachedCredentialFile() {
|
||||
try {
|
||||
await fs.rm(Storage.getOAuthCredsPath(), { force: true });
|
||||
const useEncryptedStorage = getUseEncryptedStorageFlag();
|
||||
if (useEncryptedStorage) {
|
||||
await OAuthCredentialStorage.clearCredentials();
|
||||
} else {
|
||||
await fs.rm(Storage.getOAuthCredsPath(), { force: true });
|
||||
}
|
||||
// Clear the Google Account ID cache when credentials are cleared
|
||||
await userAccountManager.clearCachedGoogleAccount();
|
||||
// Clear the in-memory OAuth client cache to force re-authentication
|
||||
|
||||
Reference in New Issue
Block a user