feat(security) - Encrypted oauth flag (#8101)

Co-authored-by: Shi Shu <shii@google.com>
This commit is contained in:
shishu314
2025-09-16 10:05:29 -04:00
committed by GitHub
parent 1634d5fcca
commit c999b7e354
4 changed files with 1100 additions and 841 deletions
@@ -7,7 +7,6 @@
import { type Credentials } from 'google-auth-library'; import { type Credentials } from 'google-auth-library';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OAuthCredentialStorage } from './oauth-credential-storage.js'; import { OAuthCredentialStorage } from './oauth-credential-storage.js';
import { HybridTokenStorage } from '../mcp/token-storage/hybrid-token-storage.js';
import type { OAuthCredentials } from '../mcp/token-storage/types.js'; import type { OAuthCredentials } from '../mcp/token-storage/types.js';
import * as path from 'node:path'; import * as path from 'node:path';
@@ -15,7 +14,14 @@ import * as os from 'node:os';
import { promises as fs } from 'node:fs'; import { promises as fs } from 'node:fs';
// Mock external dependencies // Mock external dependencies
vi.mock('../mcp/token-storage/hybrid-token-storage.js'); const mockHybridTokenStorage = vi.hoisted(() => ({
getCredentials: vi.fn(),
setCredentials: vi.fn(),
deleteCredentials: vi.fn(),
}));
vi.mock('../mcp/token-storage/hybrid-token-storage.js', () => ({
HybridTokenStorage: vi.fn(() => mockHybridTokenStorage),
}));
vi.mock('node:fs', () => ({ vi.mock('node:fs', () => ({
promises: { promises: {
readFile: vi.fn(), readFile: vi.fn(),
@@ -26,9 +32,6 @@ vi.mock('node:os');
vi.mock('node:path'); vi.mock('node:path');
describe('OAuthCredentialStorage', () => { describe('OAuthCredentialStorage', () => {
let storage: HybridTokenStorage;
let oauthStorage: OAuthCredentialStorage;
const mockCredentials: Credentials = { const mockCredentials: Credentials = {
access_token: 'mock_access_token', access_token: 'mock_access_token',
refresh_token: 'mock_refresh_token', refresh_token: 'mock_refresh_token',
@@ -52,12 +55,13 @@ describe('OAuthCredentialStorage', () => {
const oldFilePath = '/mock/home/.gemini/oauth.json'; const oldFilePath = '/mock/home/.gemini/oauth.json';
beforeEach(() => { beforeEach(() => {
storage = new HybridTokenStorage(''); vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(null);
oauthStorage = new OAuthCredentialStorage(storage); vi.spyOn(mockHybridTokenStorage, 'setCredentials').mockResolvedValue(
undefined,
vi.spyOn(storage, 'getCredentials').mockResolvedValue(null); );
vi.spyOn(storage, 'setCredentials').mockResolvedValue(undefined); vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockResolvedValue(
vi.spyOn(storage, 'deleteCredentials').mockResolvedValue(undefined); undefined,
);
vi.spyOn(fs, 'readFile').mockRejectedValue(new Error('File not found')); vi.spyOn(fs, 'readFile').mockRejectedValue(new Error('File not found'));
vi.spyOn(fs, 'rm').mockResolvedValue(undefined); vi.spyOn(fs, 'rm').mockResolvedValue(undefined);
@@ -72,25 +76,33 @@ describe('OAuthCredentialStorage', () => {
describe('loadCredentials', () => { describe('loadCredentials', () => {
it('should load credentials from HybridTokenStorage if available', async () => { it('should load credentials from HybridTokenStorage if available', async () => {
vi.spyOn(storage, 'getCredentials').mockResolvedValue(mockMcpCredentials); vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
mockMcpCredentials,
);
const result = await oauthStorage.loadCredentials(); const result = await OAuthCredentialStorage.loadCredentials();
expect(storage.getCredentials).toHaveBeenCalledWith('main-account'); expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
'main-account',
);
expect(result).toEqual(mockCredentials); expect(result).toEqual(mockCredentials);
}); });
it('should fallback to migrateFromFileStorage if no credentials in HybridTokenStorage', async () => { it('should fallback to migrateFromFileStorage if no credentials in HybridTokenStorage', async () => {
vi.spyOn(storage, 'getCredentials').mockResolvedValue(null); vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockResolvedValue( vi.spyOn(fs, 'readFile').mockResolvedValue(
JSON.stringify(mockCredentials), JSON.stringify(mockCredentials),
); );
const result = await oauthStorage.loadCredentials(); const result = await OAuthCredentialStorage.loadCredentials();
expect(storage.getCredentials).toHaveBeenCalledWith('main-account'); expect(mockHybridTokenStorage.getCredentials).toHaveBeenCalledWith(
'main-account',
);
expect(fs.readFile).toHaveBeenCalledWith(oldFilePath, 'utf-8'); expect(fs.readFile).toHaveBeenCalledWith(oldFilePath, 'utf-8');
expect(storage.setCredentials).toHaveBeenCalled(); // Verify credentials were saved expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalled(); // Verify credentials were saved
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); // Verify old file was removed expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); // Verify old file was removed
expect(result).toEqual(mockCredentials); expect(result).toEqual(mockCredentials);
}); });
@@ -101,41 +113,47 @@ describe('OAuthCredentialStorage', () => {
code: 'ENOENT', code: 'ENOENT',
}); });
const result = await oauthStorage.loadCredentials(); const result = await OAuthCredentialStorage.loadCredentials();
expect(result).toBeNull(); expect(result).toBeNull();
}); });
it('should throw an error if loading fails', async () => { it('should throw an error if loading fails', async () => {
vi.spyOn(storage, 'getCredentials').mockRejectedValue( vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockRejectedValue(
new Error('Loading error'), new Error('Loading error'),
); );
await expect(oauthStorage.loadCredentials()).rejects.toThrow( await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
'Failed to load OAuth credentials', 'Failed to load OAuth credentials',
); );
}); });
it('should throw an error if read file fails', async () => { it('should throw an error if read file fails', async () => {
vi.spyOn(storage, 'getCredentials').mockResolvedValue(null); vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockRejectedValue( vi.spyOn(fs, 'readFile').mockRejectedValue(
new Error('Permission denied'), new Error('Permission denied'),
); );
await expect(oauthStorage.loadCredentials()).rejects.toThrow( await expect(OAuthCredentialStorage.loadCredentials()).rejects.toThrow(
'Failed to load OAuth credentials', 'Failed to load OAuth credentials',
); );
}); });
it('should not throw error if migration file removal failed', async () => { it('should not throw error if migration file removal failed', async () => {
vi.spyOn(storage, 'getCredentials').mockResolvedValue(null); vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
null,
);
vi.spyOn(fs, 'readFile').mockResolvedValue( vi.spyOn(fs, 'readFile').mockResolvedValue(
JSON.stringify(mockCredentials), JSON.stringify(mockCredentials),
); );
vi.spyOn(oauthStorage, 'saveCredentials').mockResolvedValue(undefined); vi.spyOn(OAuthCredentialStorage, 'saveCredentials').mockResolvedValue(
undefined,
);
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('Deletion failed')); vi.spyOn(fs, 'rm').mockRejectedValue(new Error('Deletion failed'));
const result = await oauthStorage.loadCredentials(); const result = await OAuthCredentialStorage.loadCredentials();
expect(result).toEqual(mockCredentials); expect(result).toEqual(mockCredentials);
}); });
@@ -143,9 +161,11 @@ describe('OAuthCredentialStorage', () => {
describe('saveCredentials', () => { describe('saveCredentials', () => {
it('should save credentials to HybridTokenStorage', async () => { it('should save credentials to HybridTokenStorage', async () => {
await oauthStorage.saveCredentials(mockCredentials); await OAuthCredentialStorage.saveCredentials(mockCredentials);
expect(storage.setCredentials).toHaveBeenCalledWith(mockMcpCredentials); expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
mockMcpCredentials,
);
}); });
it('should throw an error if access_token is missing', async () => { it('should throw an error if access_token is missing', async () => {
@@ -154,7 +174,7 @@ describe('OAuthCredentialStorage', () => {
access_token: undefined, access_token: undefined,
}; };
await expect( await expect(
oauthStorage.saveCredentials(invalidCredentials), OAuthCredentialStorage.saveCredentials(invalidCredentials),
).rejects.toThrow( ).rejects.toThrow(
'Attempted to save credentials without an access token.', 'Attempted to save credentials without an access token.',
); );
@@ -163,13 +183,15 @@ describe('OAuthCredentialStorage', () => {
describe('clearCredentials', () => { describe('clearCredentials', () => {
it('should delete credentials from HybridTokenStorage', async () => { it('should delete credentials from HybridTokenStorage', async () => {
await oauthStorage.clearCredentials(); await OAuthCredentialStorage.clearCredentials();
expect(storage.deleteCredentials).toHaveBeenCalledWith('main-account'); expect(mockHybridTokenStorage.deleteCredentials).toHaveBeenCalledWith(
'main-account',
);
}); });
it('should attempt to remove the old file-based storage', async () => { it('should attempt to remove the old file-based storage', async () => {
await oauthStorage.clearCredentials(); await OAuthCredentialStorage.clearCredentials();
expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true }); expect(fs.rm).toHaveBeenCalledWith(oldFilePath, { force: true });
}); });
@@ -177,15 +199,17 @@ describe('OAuthCredentialStorage', () => {
it('should not throw an error if deleting old file fails', async () => { it('should not throw an error if deleting old file fails', async () => {
vi.spyOn(fs, 'rm').mockRejectedValue(new Error('File deletion failed')); vi.spyOn(fs, 'rm').mockRejectedValue(new Error('File deletion failed'));
await expect(oauthStorage.clearCredentials()).resolves.toBeUndefined(); await expect(
OAuthCredentialStorage.clearCredentials(),
).resolves.toBeUndefined();
}); });
it('should throw an error if clearing from HybridTokenStorage fails', async () => { it('should throw an error if clearing from HybridTokenStorage fails', async () => {
vi.spyOn(storage, 'deleteCredentials').mockRejectedValue( vi.spyOn(mockHybridTokenStorage, 'deleteCredentials').mockRejectedValue(
new Error('Deletion error'), new Error('Deletion error'),
); );
await expect(oauthStorage.clearCredentials()).rejects.toThrow( await expect(OAuthCredentialStorage.clearCredentials()).rejects.toThrow(
'Failed to clear OAuth credentials', 'Failed to clear OAuth credentials',
); );
}); });
@@ -17,16 +17,14 @@ const KEYCHAIN_SERVICE_NAME = 'gemini-cli-oauth';
const MAIN_ACCOUNT_KEY = 'main-account'; const MAIN_ACCOUNT_KEY = 'main-account';
export class OAuthCredentialStorage { export class OAuthCredentialStorage {
constructor( private static storage: HybridTokenStorage = new HybridTokenStorage(
private readonly storage: HybridTokenStorage = new HybridTokenStorage(
KEYCHAIN_SERVICE_NAME, KEYCHAIN_SERVICE_NAME,
), );
) {}
/** /**
* Load cached OAuth credentials * Load cached OAuth credentials
*/ */
async loadCredentials(): Promise<Credentials | null> { static async loadCredentials(): Promise<Credentials | null> {
try { try {
const credentials = await this.storage.getCredentials(MAIN_ACCOUNT_KEY); const credentials = await this.storage.getCredentials(MAIN_ACCOUNT_KEY);
@@ -59,7 +57,7 @@ export class OAuthCredentialStorage {
/** /**
* Save OAuth credentials * Save OAuth credentials
*/ */
async saveCredentials(credentials: Credentials): Promise<void> { static async saveCredentials(credentials: Credentials): Promise<void> {
if (!credentials.access_token) { if (!credentials.access_token) {
throw new Error('Attempted to save credentials without an access token.'); throw new Error('Attempted to save credentials without an access token.');
} }
@@ -83,7 +81,7 @@ export class OAuthCredentialStorage {
/** /**
* Clear cached OAuth credentials * Clear cached OAuth credentials
*/ */
async clearCredentials(): Promise<void> { static async clearCredentials(): Promise<void> {
try { try {
await this.storage.deleteCredentials(MAIN_ACCOUNT_KEY); await this.storage.deleteCredentials(MAIN_ACCOUNT_KEY);
@@ -99,7 +97,7 @@ export class OAuthCredentialStorage {
/** /**
* Migrate credentials from old file-based storage to keychain * Migrate credentials from old file-based storage to keychain
*/ */
private async migrateFromFileStorage(): Promise<Credentials | null> { private static async migrateFromFileStorage(): Promise<Credentials | null> {
const oldFilePath = path.join(os.homedir(), GEMINI_DIR, OAUTH_FILE); const oldFilePath = path.join(os.homedir(), GEMINI_DIR, OAUTH_FILE);
let credsJson: string; let credsJson: string;
+227 -16
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import type { Credentials } from 'google-auth-library';
import type { Mock } from 'vitest'; import type { Mock } from 'vitest';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { import {
@@ -23,6 +24,7 @@ import * as os from 'node:os';
import { AuthType } from '../core/contentGenerator.js'; import { AuthType } from '../core/contentGenerator.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import readline from 'node:readline'; import readline from 'node:readline';
import { FORCE_ENCRYPTED_FILE_ENV_VAR } from '../mcp/token-storage/index.js';
vi.mock('os', async (importOriginal) => { vi.mock('os', async (importOriginal) => {
const os = await importOriginal<typeof import('os')>(); const os = await importOriginal<typeof import('os')>();
@@ -41,6 +43,14 @@ vi.mock('../utils/browser.js', () => ({
shouldAttemptBrowserLaunch: () => true, shouldAttemptBrowserLaunch: () => true,
})); }));
vi.mock('./oauth-credential-storage.js', () => ({
OAuthCredentialStorage: {
saveCredentials: vi.fn(),
loadCredentials: vi.fn(),
clearCredentials: vi.fn(),
},
}));
const mockConfig = { const mockConfig = {
getNoBrowser: () => false, getNoBrowser: () => false,
getProxy: () => 'http://test.proxy.com:8080', getProxy: () => 'http://test.proxy.com:8080',
@@ -51,9 +61,11 @@ const mockConfig = {
global.fetch = vi.fn(); global.fetch = vi.fn();
describe('oauth2', () => { describe('oauth2', () => {
describe('with encrypted flag false', () => {
let tempHomeDir: string; let tempHomeDir: string;
beforeEach(() => { beforeEach(() => {
process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] = 'false';
tempHomeDir = fs.mkdtempSync( tempHomeDir = fs.mkdtempSync(
path.join(os.tmpdir(), 'gemini-cli-test-home-'), path.join(os.tmpdir(), 'gemini-cli-test-home-'),
); );
@@ -229,7 +241,9 @@ describe('oauth2', () => {
}; };
(readline.createInterface as Mock).mockReturnValue(mockReadline); (readline.createInterface as Mock).mockReturnValue(mockReadline);
const consoleLogSpy = vi.spyOn(console, 'log').mockImplementation(() => {}); const consoleLogSpy = vi
.spyOn(console, 'log')
.mockImplementation(() => {});
const client = await getOauthClient( const client = await getOauthClient(
AuthType.LOGIN_WITH_GOOGLE, AuthType.LOGIN_WITH_GOOGLE,
@@ -269,7 +283,9 @@ describe('oauth2', () => {
getAccessToken: mockGetAccessToken, getAccessToken: mockGetAccessToken,
} as unknown as Compute; } as unknown as Compute;
(Compute as unknown as Mock).mockImplementation(() => mockComputeClient); (Compute as unknown as Mock).mockImplementation(
() => mockComputeClient,
);
}); });
it('should attempt to load cached credentials first', async () => { it('should attempt to load cached credentials first', async () => {
@@ -548,7 +564,9 @@ describe('oauth2', () => {
() => mockOAuth2Client, () => mockOAuth2Client,
); );
(open as Mock).mockImplementation(async () => ({ on: vi.fn() }) as never); (open as Mock).mockImplementation(
async () => ({ on: vi.fn() }) as never,
);
const mockHttpServer = { const mockHttpServer = {
listen: vi.fn(), listen: vi.fn(),
@@ -585,7 +603,9 @@ describe('oauth2', () => {
() => mockOAuth2Client, () => mockOAuth2Client,
); );
(open as Mock).mockImplementation(async () => ({ on: vi.fn() }) as never); (open as Mock).mockImplementation(
async () => ({ on: vi.fn() }) as never,
);
let requestCallback!: http.RequestListener; let requestCallback!: http.RequestListener;
let serverListeningCallback: (value: unknown) => void; let serverListeningCallback: (value: unknown) => void;
@@ -594,10 +614,12 @@ describe('oauth2', () => {
); );
const mockHttpServer = { const mockHttpServer = {
listen: vi.fn((_port: number, _host: string, callback?: () => void) => { listen: vi.fn(
(_port: number, _host: string, callback?: () => void) => {
if (callback) callback(); if (callback) callback();
serverListeningCallback(undefined); serverListeningCallback(undefined);
}), },
),
close: vi.fn(), close: vi.fn(),
on: vi.fn(), on: vi.fn(),
address: () => ({ port: 3000 }), address: () => ({ port: 3000 }),
@@ -640,7 +662,9 @@ describe('oauth2', () => {
() => mockOAuth2Client, () => mockOAuth2Client,
); );
(open as Mock).mockImplementation(async () => ({ on: vi.fn() }) as never); (open as Mock).mockImplementation(
async () => ({ on: vi.fn() }) as never,
);
let requestCallback!: http.RequestListener; let requestCallback!: http.RequestListener;
let serverListeningCallback: (value: unknown) => void; let serverListeningCallback: (value: unknown) => void;
@@ -649,10 +673,12 @@ describe('oauth2', () => {
); );
const mockHttpServer = { const mockHttpServer = {
listen: vi.fn((_port: number, _host: string, callback?: () => void) => { listen: vi.fn(
(_port: number, _host: string, callback?: () => void) => {
if (callback) callback(); if (callback) callback();
serverListeningCallback(undefined); serverListeningCallback(undefined);
}), },
),
close: vi.fn(), close: vi.fn(),
on: vi.fn(), on: vi.fn(),
address: () => ({ port: 3000 }), address: () => ({ port: 3000 }),
@@ -692,7 +718,9 @@ describe('oauth2', () => {
const mockOAuth2Client = { const mockOAuth2Client = {
generateAuthUrl: vi.fn().mockReturnValue(mockAuthUrl), generateAuthUrl: vi.fn().mockReturnValue(mockAuthUrl),
getToken: vi.fn().mockRejectedValue(new Error('Token exchange failed')), getToken: vi
.fn()
.mockRejectedValue(new Error('Token exchange failed')),
on: vi.fn(), on: vi.fn(),
} as unknown as OAuth2Client; } as unknown as OAuth2Client;
(OAuth2Client as unknown as Mock).mockImplementation( (OAuth2Client as unknown as Mock).mockImplementation(
@@ -700,7 +728,9 @@ describe('oauth2', () => {
); );
vi.spyOn(crypto, 'randomBytes').mockReturnValue(mockState as never); vi.spyOn(crypto, 'randomBytes').mockReturnValue(mockState as never);
(open as Mock).mockImplementation(async () => ({ on: vi.fn() }) as never); (open as Mock).mockImplementation(
async () => ({ on: vi.fn() }) as never,
);
let requestCallback!: http.RequestListener; let requestCallback!: http.RequestListener;
let serverListeningCallback: (value: unknown) => void; let serverListeningCallback: (value: unknown) => void;
@@ -709,10 +739,12 @@ describe('oauth2', () => {
); );
const mockHttpServer = { const mockHttpServer = {
listen: vi.fn((_port: number, _host: string, callback?: () => void) => { listen: vi.fn(
(_port: number, _host: string, callback?: () => void) => {
if (callback) callback(); if (callback) callback();
serverListeningCallback(undefined); serverListeningCallback(undefined);
}), },
),
close: vi.fn(), close: vi.fn(),
on: vi.fn(), on: vi.fn(),
address: () => ({ port: 3000 }), address: () => ({ port: 3000 }),
@@ -767,7 +799,9 @@ describe('oauth2', () => {
); );
vi.spyOn(crypto, 'randomBytes').mockReturnValue(mockState as never); vi.spyOn(crypto, 'randomBytes').mockReturnValue(mockState as never);
(open as Mock).mockImplementation(async () => ({ on: vi.fn() }) as never); (open as Mock).mockImplementation(
async () => ({ on: vi.fn() }) as never,
);
// Mock fetch to fail // Mock fetch to fail
(global.fetch as Mock).mockResolvedValue({ (global.fetch as Mock).mockResolvedValue({
@@ -787,10 +821,12 @@ describe('oauth2', () => {
); );
const mockHttpServer = { const mockHttpServer = {
listen: vi.fn((_port: number, _host: string, callback?: () => void) => { listen: vi.fn(
(_port: number, _host: string, callback?: () => void) => {
if (callback) callback(); if (callback) callback();
serverListeningCallback(undefined); serverListeningCallback(undefined);
}), },
),
close: vi.fn(), close: vi.fn(),
on: vi.fn(), on: vi.fn(),
address: () => ({ port: 3000 }), address: () => ({ port: 3000 }),
@@ -951,4 +987,179 @@ describe('oauth2', () => {
expect(OAuth2Client).toHaveBeenCalledTimes(2); expect(OAuth2Client).toHaveBeenCalledTimes(2);
}); });
}); });
});
describe('with encrypted flag true', () => {
let tempHomeDir: string;
beforeEach(() => {
process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] = 'true';
tempHomeDir = fs.mkdtempSync(
path.join(os.tmpdir(), 'gemini-cli-test-home-'),
);
(os.homedir as Mock).mockReturnValue(tempHomeDir);
});
afterEach(() => {
fs.rmSync(tempHomeDir, { recursive: true, force: true });
vi.clearAllMocks();
resetOauthClientForTesting();
vi.unstubAllEnvs();
});
it('should save credentials using OAuthCredentialStorage during web login', async () => {
const { OAuthCredentialStorage } = await import(
'./oauth-credential-storage.js'
);
const mockAuthUrl = 'https://example.com/auth';
const mockCode = 'test-code';
const mockState = 'test-state';
const mockTokens = {
access_token: 'test-access-token',
refresh_token: 'test-refresh-token',
};
let onTokensCallback: (tokens: Credentials) => void = () => {};
const mockOn = vi.fn((event, callback) => {
if (event === 'tokens') {
onTokensCallback = callback;
}
});
const mockGetToken = vi.fn().mockImplementation(async () => {
onTokensCallback(mockTokens);
return { tokens: mockTokens };
});
const mockOAuth2Client = {
generateAuthUrl: vi.fn().mockReturnValue(mockAuthUrl),
getToken: mockGetToken,
setCredentials: vi.fn(),
getAccessToken: vi
.fn()
.mockResolvedValue({ token: 'mock-access-token' }),
on: mockOn,
credentials: mockTokens,
} as unknown as OAuth2Client;
(OAuth2Client as unknown as Mock).mockImplementation(
() => mockOAuth2Client,
);
vi.spyOn(crypto, 'randomBytes').mockReturnValue(mockState as never);
(open as Mock).mockImplementation(async () => ({ on: vi.fn() }) as never);
(global.fetch as Mock).mockResolvedValue({
ok: true,
json: vi
.fn()
.mockResolvedValue({ email: 'test-google-account@gmail.com' }),
} as unknown as Response);
let requestCallback!: http.RequestListener;
let serverListeningCallback: (value: unknown) => void;
const serverListeningPromise = new Promise(
(resolve) => (serverListeningCallback = resolve),
);
let capturedPort = 0;
const mockHttpServer = {
listen: vi.fn((port: number, _host: string, callback?: () => void) => {
capturedPort = port;
if (callback) {
callback();
}
serverListeningCallback(undefined);
}),
close: vi.fn((callback?: () => void) => {
if (callback) {
callback();
}
}),
on: vi.fn(),
address: () => ({ port: capturedPort }),
};
(http.createServer as Mock).mockImplementation((cb) => {
requestCallback = cb as http.RequestListener;
return mockHttpServer as unknown as http.Server;
});
const clientPromise = getOauthClient(
AuthType.LOGIN_WITH_GOOGLE,
mockConfig,
);
await serverListeningPromise;
const mockReq = {
url: `/oauth2callback?code=${mockCode}&state=${mockState}`,
} as http.IncomingMessage;
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
} as unknown as http.ServerResponse;
requestCallback(mockReq, mockRes);
await clientPromise;
expect(
OAuthCredentialStorage.saveCredentials as Mock,
).toHaveBeenCalledWith(mockTokens);
const credsPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json');
expect(fs.existsSync(credsPath)).toBe(false);
});
it('should load credentials using OAuthCredentialStorage and not from file', async () => {
const { OAuthCredentialStorage } = await import(
'./oauth-credential-storage.js'
);
const cachedCreds = { refresh_token: 'cached-encrypted-token' };
(OAuthCredentialStorage.loadCredentials as Mock).mockResolvedValue(
cachedCreds,
);
// Create a dummy unencrypted credential file.
// If the logic is correct, this file should be ignored.
const unencryptedCreds = { refresh_token: 'unencrypted-token' };
const credsPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json');
await fs.promises.mkdir(path.dirname(credsPath), { recursive: true });
await fs.promises.writeFile(credsPath, JSON.stringify(unencryptedCreds));
const mockClient = {
setCredentials: vi.fn(),
getAccessToken: vi.fn().mockResolvedValue({ token: 'test-token' }),
getTokenInfo: vi.fn().mockResolvedValue({}),
on: vi.fn(),
};
(OAuth2Client as unknown as Mock).mockImplementation(
() => mockClient as unknown as OAuth2Client,
);
await getOauthClient(AuthType.LOGIN_WITH_GOOGLE, mockConfig);
expect(OAuthCredentialStorage.loadCredentials as Mock).toHaveBeenCalled();
expect(mockClient.setCredentials).toHaveBeenCalledWith(cachedCreds);
expect(mockClient.setCredentials).not.toHaveBeenCalledWith(
unencryptedCreds,
);
});
it('should clear credentials using OAuthCredentialStorage', async () => {
const { OAuthCredentialStorage } = await import(
'./oauth-credential-storage.js'
);
// Create a dummy unencrypted credential file. It should not be deleted.
const credsPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json');
await fs.promises.mkdir(path.dirname(credsPath), { recursive: true });
await fs.promises.writeFile(credsPath, '{}');
await clearCachedCredentialFile();
expect(
OAuthCredentialStorage.clearCredentials as Mock,
).toHaveBeenCalled();
expect(fs.existsSync(credsPath)).toBe(true); // The unencrypted file should remain
});
});
}); });
+26
View File
@@ -23,6 +23,8 @@ import { UserAccountManager } from '../utils/userAccountManager.js';
import { AuthType } from '../core/contentGenerator.js'; import { AuthType } from '../core/contentGenerator.js';
import readline from 'node:readline'; import readline from 'node:readline';
import { Storage } from '../config/storage.js'; import { Storage } from '../config/storage.js';
import { OAuthCredentialStorage } from './oauth-credential-storage.js';
import { FORCE_ENCRYPTED_FILE_ENV_VAR } from '../mcp/token-storage/index.js';
const userAccountManager = new UserAccountManager(); const userAccountManager = new UserAccountManager();
@@ -63,6 +65,10 @@ export interface OauthWebLogin {
const oauthClientPromises = new Map<AuthType, Promise<OAuth2Client>>(); const oauthClientPromises = new Map<AuthType, Promise<OAuth2Client>>();
function getUseEncryptedStorageFlag() {
return process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
}
async function initOauthClient( async function initOauthClient(
authType: AuthType, authType: AuthType,
config: Config, config: Config,
@@ -74,6 +80,7 @@ async function initOauthClient(
proxy: config.getProxy(), proxy: config.getProxy(),
}, },
}); });
const useEncryptedStorage = getUseEncryptedStorageFlag();
if ( if (
process.env['GOOGLE_GENAI_USE_GCA'] && process.env['GOOGLE_GENAI_USE_GCA'] &&
@@ -87,7 +94,11 @@ async function initOauthClient(
} }
client.on('tokens', async (tokens: Credentials) => { client.on('tokens', async (tokens: Credentials) => {
if (useEncryptedStorage) {
await OAuthCredentialStorage.saveCredentials(tokens);
} else {
await cacheCredentials(tokens); await cacheCredentials(tokens);
}
}); });
// If there are cached creds on disk, they always take precedence // If there are cached creds on disk, they always take precedence
@@ -419,6 +430,16 @@ export function getAvailablePort(): Promise<number> {
} }
async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> { async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
const useEncryptedStorage = getUseEncryptedStorageFlag();
if (useEncryptedStorage) {
const credentials = await OAuthCredentialStorage.loadCredentials();
if (credentials) {
client.setCredentials(credentials);
return true;
}
return false;
}
const pathsToTry = [ const pathsToTry = [
Storage.getOAuthCredsPath(), Storage.getOAuthCredsPath(),
process.env['GOOGLE_APPLICATION_CREDENTIALS'], process.env['GOOGLE_APPLICATION_CREDENTIALS'],
@@ -470,7 +491,12 @@ export function clearOauthClientCache() {
export async function clearCachedCredentialFile() { export async function clearCachedCredentialFile() {
try { try {
const useEncryptedStorage = getUseEncryptedStorageFlag();
if (useEncryptedStorage) {
await OAuthCredentialStorage.clearCredentials();
} else {
await fs.rm(Storage.getOAuthCredsPath(), { force: true }); await fs.rm(Storage.getOAuthCredsPath(), { force: true });
}
// Clear the Google Account ID cache when credentials are cleared // Clear the Google Account ID cache when credentials are cleared
await userAccountManager.clearCachedGoogleAccount(); await userAccountManager.clearCachedGoogleAccount();
// Clear the in-memory OAuth client cache to force re-authentication // Clear the in-memory OAuth client cache to force re-authentication