mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
feat(oauth) - Create hybrid storage class (#7610)
Co-authored-by: Shi Shu <shii@google.com>
This commit is contained in:
274
packages/core/src/mcp/token-storage/hybrid-token-storage.test.ts
Normal file
274
packages/core/src/mcp/token-storage/hybrid-token-storage.test.ts
Normal file
@@ -0,0 +1,274 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import { HybridTokenStorage } from './hybrid-token-storage.js';
|
||||
import { KeychainTokenStorage } from './keychain-token-storage.js';
|
||||
import { FileTokenStorage } from './file-token-storage.js';
|
||||
import { type OAuthCredentials, TokenStorageType } from './types.js';
|
||||
|
||||
vi.mock('./keychain-token-storage.js', () => ({
|
||||
KeychainTokenStorage: vi.fn().mockImplementation(() => ({
|
||||
isAvailable: vi.fn(),
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
listServers: vi.fn(),
|
||||
getAllCredentials: vi.fn(),
|
||||
clearAll: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('./file-token-storage.js', () => ({
|
||||
FileTokenStorage: vi.fn().mockImplementation(() => ({
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
listServers: vi.fn(),
|
||||
getAllCredentials: vi.fn(),
|
||||
clearAll: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
interface MockStorage {
|
||||
isAvailable?: ReturnType<typeof vi.fn>;
|
||||
getCredentials: ReturnType<typeof vi.fn>;
|
||||
setCredentials: ReturnType<typeof vi.fn>;
|
||||
deleteCredentials: ReturnType<typeof vi.fn>;
|
||||
listServers: ReturnType<typeof vi.fn>;
|
||||
getAllCredentials: ReturnType<typeof vi.fn>;
|
||||
clearAll: ReturnType<typeof vi.fn>;
|
||||
}
|
||||
|
||||
describe('HybridTokenStorage', () => {
|
||||
let storage: HybridTokenStorage;
|
||||
let mockKeychainStorage: MockStorage;
|
||||
let mockFileStorage: MockStorage;
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
process.env = { ...originalEnv };
|
||||
|
||||
// Create mock instances before creating HybridTokenStorage
|
||||
mockKeychainStorage = {
|
||||
isAvailable: vi.fn(),
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
listServers: vi.fn(),
|
||||
getAllCredentials: vi.fn(),
|
||||
clearAll: vi.fn(),
|
||||
};
|
||||
|
||||
mockFileStorage = {
|
||||
getCredentials: vi.fn(),
|
||||
setCredentials: vi.fn(),
|
||||
deleteCredentials: vi.fn(),
|
||||
listServers: vi.fn(),
|
||||
getAllCredentials: vi.fn(),
|
||||
clearAll: vi.fn(),
|
||||
};
|
||||
|
||||
(
|
||||
KeychainTokenStorage as unknown as ReturnType<typeof vi.fn>
|
||||
).mockImplementation(() => mockKeychainStorage);
|
||||
(
|
||||
FileTokenStorage as unknown as ReturnType<typeof vi.fn>
|
||||
).mockImplementation(() => mockFileStorage);
|
||||
|
||||
storage = new HybridTokenStorage('test-service');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
describe('storage selection', () => {
|
||||
it('should use keychain when available', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await storage.getCredentials('test-server');
|
||||
|
||||
expect(mockKeychainStorage.isAvailable).toHaveBeenCalled();
|
||||
expect(mockKeychainStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(await storage.getStorageType()).toBe(TokenStorageType.KEYCHAIN);
|
||||
});
|
||||
|
||||
it('should use file storage when GEMINI_FORCE_FILE_STORAGE is set', async () => {
|
||||
process.env['GEMINI_FORCE_FILE_STORAGE'] = 'true';
|
||||
mockFileStorage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await storage.getCredentials('test-server');
|
||||
|
||||
expect(mockKeychainStorage.isAvailable).not.toHaveBeenCalled();
|
||||
expect(mockFileStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(await storage.getStorageType()).toBe(
|
||||
TokenStorageType.ENCRYPTED_FILE,
|
||||
);
|
||||
});
|
||||
|
||||
it('should fall back to file storage when keychain is unavailable', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(false);
|
||||
mockFileStorage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await storage.getCredentials('test-server');
|
||||
|
||||
expect(mockKeychainStorage.isAvailable).toHaveBeenCalled();
|
||||
expect(mockFileStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(await storage.getStorageType()).toBe(
|
||||
TokenStorageType.ENCRYPTED_FILE,
|
||||
);
|
||||
});
|
||||
|
||||
it('should fall back to file storage when keychain throws error', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockRejectedValue(
|
||||
new Error('Keychain error'),
|
||||
);
|
||||
mockFileStorage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await storage.getCredentials('test-server');
|
||||
|
||||
expect(mockKeychainStorage.isAvailable).toHaveBeenCalled();
|
||||
expect(mockFileStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(await storage.getStorageType()).toBe(
|
||||
TokenStorageType.ENCRYPTED_FILE,
|
||||
);
|
||||
});
|
||||
|
||||
it('should cache storage selection', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await storage.getCredentials('test-server');
|
||||
await storage.getCredentials('another-server');
|
||||
|
||||
expect(mockKeychainStorage.isAvailable).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCredentials', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
const credentials: OAuthCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'access-token',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.getCredentials.mockResolvedValue(credentials);
|
||||
|
||||
const result = await storage.getCredentials('test-server');
|
||||
|
||||
expect(result).toEqual(credentials);
|
||||
expect(mockKeychainStorage.getCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setCredentials', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
const credentials: OAuthCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'access-token',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.setCredentials.mockResolvedValue(undefined);
|
||||
|
||||
await storage.setCredentials(credentials);
|
||||
|
||||
expect(mockKeychainStorage.setCredentials).toHaveBeenCalledWith(
|
||||
credentials,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteCredentials', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.deleteCredentials.mockResolvedValue(undefined);
|
||||
|
||||
await storage.deleteCredentials('test-server');
|
||||
|
||||
expect(mockKeychainStorage.deleteCredentials).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('listServers', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
const servers = ['server1', 'server2'];
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.listServers.mockResolvedValue(servers);
|
||||
|
||||
const result = await storage.listServers();
|
||||
|
||||
expect(result).toEqual(servers);
|
||||
expect(mockKeychainStorage.listServers).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllCredentials', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
const credentialsMap = new Map([
|
||||
[
|
||||
'server1',
|
||||
{
|
||||
serverName: 'server1',
|
||||
token: { accessToken: 'token1', tokenType: 'Bearer' },
|
||||
updatedAt: Date.now(),
|
||||
},
|
||||
],
|
||||
[
|
||||
'server2',
|
||||
{
|
||||
serverName: 'server2',
|
||||
token: { accessToken: 'token2', tokenType: 'Bearer' },
|
||||
updatedAt: Date.now(),
|
||||
},
|
||||
],
|
||||
]);
|
||||
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.getAllCredentials.mockResolvedValue(credentialsMap);
|
||||
|
||||
const result = await storage.getAllCredentials();
|
||||
|
||||
expect(result).toEqual(credentialsMap);
|
||||
expect(mockKeychainStorage.getAllCredentials).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearAll', () => {
|
||||
it('should delegate to selected storage', async () => {
|
||||
mockKeychainStorage.isAvailable!.mockResolvedValue(true);
|
||||
mockKeychainStorage.clearAll.mockResolvedValue(undefined);
|
||||
|
||||
await storage.clearAll();
|
||||
|
||||
expect(mockKeychainStorage.clearAll).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
97
packages/core/src/mcp/token-storage/hybrid-token-storage.ts
Normal file
97
packages/core/src/mcp/token-storage/hybrid-token-storage.ts
Normal file
@@ -0,0 +1,97 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { BaseTokenStorage } from './base-token-storage.js';
|
||||
import { FileTokenStorage } from './file-token-storage.js';
|
||||
import type { TokenStorage, OAuthCredentials } from './types.js';
|
||||
import { TokenStorageType } from './types.js';
|
||||
|
||||
const FORCE_FILE_STORAGE_ENV_VAR = 'GEMINI_FORCE_FILE_STORAGE';
|
||||
|
||||
export class HybridTokenStorage extends BaseTokenStorage {
|
||||
private storage: TokenStorage | null = null;
|
||||
private storageType: TokenStorageType | null = null;
|
||||
private storageInitPromise: Promise<TokenStorage> | null = null;
|
||||
|
||||
constructor(serviceName: string) {
|
||||
super(serviceName);
|
||||
}
|
||||
|
||||
private async initializeStorage(): Promise<TokenStorage> {
|
||||
const forceFileStorage = process.env[FORCE_FILE_STORAGE_ENV_VAR] === 'true';
|
||||
|
||||
if (!forceFileStorage) {
|
||||
try {
|
||||
const { KeychainTokenStorage } = await import(
|
||||
'./keychain-token-storage.js'
|
||||
);
|
||||
const keychainStorage = new KeychainTokenStorage(this.serviceName);
|
||||
|
||||
const isAvailable = await keychainStorage.isAvailable();
|
||||
if (isAvailable) {
|
||||
this.storage = keychainStorage;
|
||||
this.storageType = TokenStorageType.KEYCHAIN;
|
||||
return this.storage;
|
||||
}
|
||||
} catch (_e) {
|
||||
// Fallback to file storage if keychain fails to initialize
|
||||
}
|
||||
}
|
||||
|
||||
this.storage = new FileTokenStorage(this.serviceName);
|
||||
this.storageType = TokenStorageType.ENCRYPTED_FILE;
|
||||
return this.storage;
|
||||
}
|
||||
|
||||
private async getStorage(): Promise<TokenStorage> {
|
||||
if (this.storage !== null) {
|
||||
return this.storage;
|
||||
}
|
||||
|
||||
// Use a single initialization promise to avoid race conditions
|
||||
if (!this.storageInitPromise) {
|
||||
this.storageInitPromise = this.initializeStorage();
|
||||
}
|
||||
|
||||
// Wait for initialization to complete
|
||||
return await this.storageInitPromise;
|
||||
}
|
||||
|
||||
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
|
||||
const storage = await this.getStorage();
|
||||
return storage.getCredentials(serverName);
|
||||
}
|
||||
|
||||
async setCredentials(credentials: OAuthCredentials): Promise<void> {
|
||||
const storage = await this.getStorage();
|
||||
await storage.setCredentials(credentials);
|
||||
}
|
||||
|
||||
async deleteCredentials(serverName: string): Promise<void> {
|
||||
const storage = await this.getStorage();
|
||||
await storage.deleteCredentials(serverName);
|
||||
}
|
||||
|
||||
async listServers(): Promise<string[]> {
|
||||
const storage = await this.getStorage();
|
||||
return storage.listServers();
|
||||
}
|
||||
|
||||
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
|
||||
const storage = await this.getStorage();
|
||||
return storage.getAllCredentials();
|
||||
}
|
||||
|
||||
async clearAll(): Promise<void> {
|
||||
const storage = await this.getStorage();
|
||||
await storage.clearAll();
|
||||
}
|
||||
|
||||
async getStorageType(): Promise<TokenStorageType> {
|
||||
await this.getStorage();
|
||||
return this.storageType!;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,352 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||
import type { KeychainTokenStorage } from './keychain-token-storage.js';
|
||||
import type { OAuthCredentials } from './types.js';
|
||||
|
||||
// Hoist the mock to be available in the vi.mock factory
|
||||
const mockKeytar = vi.hoisted(() => ({
|
||||
getPassword: vi.fn(),
|
||||
setPassword: vi.fn(),
|
||||
deletePassword: vi.fn(),
|
||||
findCredentials: vi.fn(),
|
||||
}));
|
||||
|
||||
const mockServiceName = 'service-name';
|
||||
const mockCryptoRandomBytesString = 'random-string';
|
||||
|
||||
// Mock the dynamic import of 'keytar'
|
||||
vi.mock('keytar', () => ({
|
||||
default: mockKeytar,
|
||||
}));
|
||||
|
||||
vi.mock('node:crypto', () => ({
|
||||
randomBytes: vi.fn(() => ({
|
||||
toString: vi.fn(() => mockCryptoRandomBytesString),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('KeychainTokenStorage', () => {
|
||||
let storage: KeychainTokenStorage;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetAllMocks();
|
||||
// Reset the internal state of the keychain-token-storage module
|
||||
vi.resetModules();
|
||||
const { KeychainTokenStorage } = await import(
|
||||
'./keychain-token-storage.js'
|
||||
);
|
||||
storage = new KeychainTokenStorage(mockServiceName);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
const validCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
accessToken: 'access-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() + 3600000,
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
} as OAuthCredentials;
|
||||
|
||||
describe('checkKeychainAvailability', () => {
|
||||
it('should return true if keytar is available and functional', async () => {
|
||||
mockKeytar.setPassword.mockResolvedValue(undefined);
|
||||
mockKeytar.getPassword.mockResolvedValue('test');
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
|
||||
const isAvailable = await storage.checkKeychainAvailability();
|
||||
expect(isAvailable).toBe(true);
|
||||
expect(mockKeytar.setPassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
`__keychain_test__${mockCryptoRandomBytesString}`,
|
||||
'test',
|
||||
);
|
||||
expect(mockKeytar.getPassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
`__keychain_test__${mockCryptoRandomBytesString}`,
|
||||
);
|
||||
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
`__keychain_test__${mockCryptoRandomBytesString}`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return false if keytar fails to set password', async () => {
|
||||
mockKeytar.setPassword.mockRejectedValue(new Error('write error'));
|
||||
const isAvailable = await storage.checkKeychainAvailability();
|
||||
expect(isAvailable).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false if retrieved password does not match', async () => {
|
||||
mockKeytar.setPassword.mockResolvedValue(undefined);
|
||||
mockKeytar.getPassword.mockResolvedValue('wrong-password');
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
const isAvailable = await storage.checkKeychainAvailability();
|
||||
expect(isAvailable).toBe(false);
|
||||
});
|
||||
|
||||
it('should cache the availability result', async () => {
|
||||
mockKeytar.setPassword.mockResolvedValue(undefined);
|
||||
mockKeytar.getPassword.mockResolvedValue('test');
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
|
||||
await storage.checkKeychainAvailability();
|
||||
await storage.checkKeychainAvailability();
|
||||
|
||||
expect(mockKeytar.setPassword).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('with keychain unavailable', () => {
|
||||
beforeEach(async () => {
|
||||
// Force keychain to be unavailable
|
||||
mockKeytar.setPassword.mockRejectedValue(new Error('keychain error'));
|
||||
await storage.checkKeychainAvailability();
|
||||
});
|
||||
|
||||
it('getCredentials should throw', async () => {
|
||||
await expect(storage.getCredentials('server')).rejects.toThrow(
|
||||
'Keychain is not available',
|
||||
);
|
||||
});
|
||||
|
||||
it('setCredentials should throw', async () => {
|
||||
await expect(storage.setCredentials(validCredentials)).rejects.toThrow(
|
||||
'Keychain is not available',
|
||||
);
|
||||
});
|
||||
|
||||
it('deleteCredentials should throw', async () => {
|
||||
await expect(storage.deleteCredentials('server')).rejects.toThrow(
|
||||
'Keychain is not available',
|
||||
);
|
||||
});
|
||||
|
||||
it('listServers should throw', async () => {
|
||||
await expect(storage.listServers()).rejects.toThrow(
|
||||
'Keychain is not available',
|
||||
);
|
||||
});
|
||||
|
||||
it('getAllCredentials should throw', async () => {
|
||||
await expect(storage.getAllCredentials()).rejects.toThrow(
|
||||
'Keychain is not available',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('with keychain available', () => {
|
||||
beforeEach(async () => {
|
||||
mockKeytar.setPassword.mockResolvedValue(undefined);
|
||||
mockKeytar.getPassword.mockResolvedValue('test');
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
await storage.checkKeychainAvailability();
|
||||
// Reset mocks after availability check
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
describe('getCredentials', () => {
|
||||
it('should return null if no credentials are found', async () => {
|
||||
mockKeytar.getPassword.mockResolvedValue(null);
|
||||
const result = await storage.getCredentials('test-server');
|
||||
expect(result).toBeNull();
|
||||
expect(mockKeytar.getPassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
'test-server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return credentials if found and not expired', async () => {
|
||||
mockKeytar.getPassword.mockResolvedValue(
|
||||
JSON.stringify(validCredentials),
|
||||
);
|
||||
const result = await storage.getCredentials('test-server');
|
||||
expect(result).toEqual(validCredentials);
|
||||
});
|
||||
|
||||
it('should return null if credentials have expired', async () => {
|
||||
const expiredCreds = {
|
||||
...validCredentials,
|
||||
token: { ...validCredentials.token, expiresAt: Date.now() - 1000 },
|
||||
};
|
||||
mockKeytar.getPassword.mockResolvedValue(JSON.stringify(expiredCreds));
|
||||
const result = await storage.getCredentials('test-server');
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should throw if stored data is corrupted JSON', async () => {
|
||||
mockKeytar.getPassword.mockResolvedValue('not-json');
|
||||
await expect(storage.getCredentials('test-server')).rejects.toThrow(
|
||||
'Failed to parse stored credentials for test-server',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setCredentials', () => {
|
||||
it('should save credentials to keychain', async () => {
|
||||
vi.useFakeTimers();
|
||||
mockKeytar.setPassword.mockResolvedValue(undefined);
|
||||
await storage.setCredentials(validCredentials);
|
||||
expect(mockKeytar.setPassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
'test-server',
|
||||
JSON.stringify({ ...validCredentials, updatedAt: Date.now() }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if saving to keychain fails', async () => {
|
||||
mockKeytar.setPassword.mockRejectedValue(
|
||||
new Error('keychain write error'),
|
||||
);
|
||||
await expect(storage.setCredentials(validCredentials)).rejects.toThrow(
|
||||
'keychain write error',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteCredentials', () => {
|
||||
it('should delete credentials from keychain', async () => {
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
await storage.deleteCredentials('test-server');
|
||||
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
'test-server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if no credentials were found to delete', async () => {
|
||||
mockKeytar.deletePassword.mockResolvedValue(false);
|
||||
await expect(storage.deleteCredentials('test-server')).rejects.toThrow(
|
||||
'No credentials found for test-server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if deleting from keychain fails', async () => {
|
||||
mockKeytar.deletePassword.mockRejectedValue(
|
||||
new Error('keychain delete error'),
|
||||
);
|
||||
await expect(storage.deleteCredentials('test-server')).rejects.toThrow(
|
||||
'keychain delete error',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('listServers', () => {
|
||||
it('should return a list of server names', async () => {
|
||||
mockKeytar.findCredentials.mockResolvedValue([
|
||||
{ account: 'server1', password: '' },
|
||||
{ account: 'server2', password: '' },
|
||||
]);
|
||||
const result = await storage.listServers();
|
||||
expect(result).toEqual(['server1', 'server2']);
|
||||
});
|
||||
|
||||
it('should not include internal test keys in the server list', async () => {
|
||||
mockKeytar.findCredentials.mockResolvedValue([
|
||||
{ account: 'server1', password: '' },
|
||||
{
|
||||
account: `__keychain_test__${mockCryptoRandomBytesString}`,
|
||||
password: '',
|
||||
},
|
||||
{ account: 'server2', password: '' },
|
||||
]);
|
||||
const result = await storage.listServers();
|
||||
expect(result).toEqual(['server1', 'server2']);
|
||||
});
|
||||
|
||||
it('should return an empty array on error', async () => {
|
||||
mockKeytar.findCredentials.mockRejectedValue(new Error('find error'));
|
||||
const result = await storage.listServers();
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllCredentials', () => {
|
||||
it('should return a map of all valid credentials', async () => {
|
||||
const creds2 = {
|
||||
...validCredentials,
|
||||
serverName: 'server2',
|
||||
};
|
||||
const expiredCreds = {
|
||||
...validCredentials,
|
||||
serverName: 'expired-server',
|
||||
token: { ...validCredentials.token, expiresAt: Date.now() - 1000 },
|
||||
};
|
||||
const structurallyInvalidCreds = {
|
||||
serverName: 'invalid-server',
|
||||
};
|
||||
|
||||
mockKeytar.findCredentials.mockResolvedValue([
|
||||
{
|
||||
account: 'test-server',
|
||||
password: JSON.stringify(validCredentials),
|
||||
},
|
||||
{ account: 'server2', password: JSON.stringify(creds2) },
|
||||
{
|
||||
account: 'expired-server',
|
||||
password: JSON.stringify(expiredCreds),
|
||||
},
|
||||
{ account: 'bad-server', password: 'not-json' },
|
||||
{
|
||||
account: 'invalid-server',
|
||||
password: JSON.stringify(structurallyInvalidCreds),
|
||||
},
|
||||
]);
|
||||
|
||||
const result = await storage.getAllCredentials();
|
||||
expect(result.size).toBe(2);
|
||||
expect(result.get('test-server')).toEqual(validCredentials);
|
||||
expect(result.get('server2')).toEqual(creds2);
|
||||
expect(result.has('expired-server')).toBe(false);
|
||||
expect(result.has('bad-server')).toBe(false);
|
||||
expect(result.has('invalid-server')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearAll', () => {
|
||||
it('should delete all credentials for the service', async () => {
|
||||
mockKeytar.findCredentials.mockResolvedValue([
|
||||
{ account: 'server1', password: '' },
|
||||
{ account: 'server2', password: '' },
|
||||
]);
|
||||
mockKeytar.deletePassword.mockResolvedValue(true);
|
||||
|
||||
await storage.clearAll();
|
||||
|
||||
expect(mockKeytar.deletePassword).toHaveBeenCalledTimes(2);
|
||||
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
'server1',
|
||||
);
|
||||
expect(mockKeytar.deletePassword).toHaveBeenCalledWith(
|
||||
mockServiceName,
|
||||
'server2',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an aggregated error if deletions fail', async () => {
|
||||
mockKeytar.findCredentials.mockResolvedValue([
|
||||
{ account: 'server1', password: '' },
|
||||
{ account: 'server2', password: '' },
|
||||
]);
|
||||
mockKeytar.deletePassword
|
||||
.mockResolvedValueOnce(true)
|
||||
.mockRejectedValueOnce(new Error('delete failed'));
|
||||
|
||||
await expect(storage.clearAll()).rejects.toThrow(
|
||||
'Failed to clear some credentials: delete failed',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
251
packages/core/src/mcp/token-storage/keychain-token-storage.ts
Normal file
251
packages/core/src/mcp/token-storage/keychain-token-storage.ts
Normal file
@@ -0,0 +1,251 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as crypto from 'node:crypto';
|
||||
import { BaseTokenStorage } from './base-token-storage.js';
|
||||
import type { OAuthCredentials } from './types.js';
|
||||
|
||||
interface Keytar {
|
||||
getPassword(service: string, account: string): Promise<string | null>;
|
||||
setPassword(
|
||||
service: string,
|
||||
account: string,
|
||||
password: string,
|
||||
): Promise<void>;
|
||||
deletePassword(service: string, account: string): Promise<boolean>;
|
||||
findCredentials(
|
||||
service: string,
|
||||
): Promise<Array<{ account: string; password: string }>>;
|
||||
}
|
||||
|
||||
const KEYCHAIN_TEST_PREFIX = '__keychain_test__';
|
||||
|
||||
export class KeychainTokenStorage extends BaseTokenStorage {
|
||||
private keychainAvailable: boolean | null = null;
|
||||
private keytarModule: Keytar | null = null;
|
||||
private keytarLoadAttempted = false;
|
||||
|
||||
async getKeytar(): Promise<Keytar | null> {
|
||||
// If we've already tried loading (successfully or not), return the result
|
||||
if (this.keytarLoadAttempted) {
|
||||
return this.keytarModule;
|
||||
}
|
||||
|
||||
this.keytarLoadAttempted = true;
|
||||
|
||||
try {
|
||||
// Try to import keytar without any timeout - let the OS handle it
|
||||
const moduleName = 'keytar';
|
||||
const module = await import(moduleName);
|
||||
this.keytarModule = module.default || module;
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
return this.keytarModule;
|
||||
}
|
||||
|
||||
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
throw new Error('Keytar module not available');
|
||||
}
|
||||
|
||||
try {
|
||||
const sanitizedName = this.sanitizeServerName(serverName);
|
||||
const data = await keytar.getPassword(this.serviceName, sanitizedName);
|
||||
|
||||
if (!data) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const credentials = JSON.parse(data) as OAuthCredentials;
|
||||
|
||||
if (this.isTokenExpired(credentials)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return credentials;
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
throw new Error(`Failed to parse stored credentials for ${serverName}`);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async setCredentials(credentials: OAuthCredentials): Promise<void> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
throw new Error('Keytar module not available');
|
||||
}
|
||||
|
||||
this.validateCredentials(credentials);
|
||||
|
||||
const sanitizedName = this.sanitizeServerName(credentials.serverName);
|
||||
const updatedCredentials: OAuthCredentials = {
|
||||
...credentials,
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
const data = JSON.stringify(updatedCredentials);
|
||||
await keytar.setPassword(this.serviceName, sanitizedName, data);
|
||||
}
|
||||
|
||||
async deleteCredentials(serverName: string): Promise<void> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
throw new Error('Keytar module not available');
|
||||
}
|
||||
|
||||
const sanitizedName = this.sanitizeServerName(serverName);
|
||||
const deleted = await keytar.deletePassword(
|
||||
this.serviceName,
|
||||
sanitizedName,
|
||||
);
|
||||
|
||||
if (!deleted) {
|
||||
throw new Error(`No credentials found for ${serverName}`);
|
||||
}
|
||||
}
|
||||
|
||||
async listServers(): Promise<string[]> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
throw new Error('Keytar module not available');
|
||||
}
|
||||
|
||||
try {
|
||||
const credentials = await keytar.findCredentials(this.serviceName);
|
||||
return credentials
|
||||
.filter((cred) => !cred.account.startsWith(KEYCHAIN_TEST_PREFIX))
|
||||
.map((cred: { account: string }) => cred.account);
|
||||
} catch (error) {
|
||||
console.error('Failed to list servers from keychain:', error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
throw new Error('Keytar module not available');
|
||||
}
|
||||
|
||||
const result = new Map<string, OAuthCredentials>();
|
||||
try {
|
||||
const credentials = (
|
||||
await keytar.findCredentials(this.serviceName)
|
||||
).filter((c) => !c.account.startsWith(KEYCHAIN_TEST_PREFIX));
|
||||
|
||||
for (const cred of credentials) {
|
||||
try {
|
||||
const data = JSON.parse(cred.password) as OAuthCredentials;
|
||||
if (!this.isTokenExpired(data)) {
|
||||
result.set(cred.account, data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to parse credentials for ${cred.account}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to get all credentials from keychain:', error);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async clearAll(): Promise<void> {
|
||||
if (!(await this.checkKeychainAvailability())) {
|
||||
throw new Error('Keychain is not available');
|
||||
}
|
||||
|
||||
const servers = this.keytarModule
|
||||
? await this.keytarModule
|
||||
.findCredentials(this.serviceName)
|
||||
.then((creds) => creds.map((c) => c.account))
|
||||
.catch((error: Error) => {
|
||||
throw new Error(
|
||||
`Failed to list servers for clearing: ${error.message}`,
|
||||
);
|
||||
})
|
||||
: [];
|
||||
const errors: Error[] = [];
|
||||
|
||||
for (const server of servers) {
|
||||
try {
|
||||
await this.deleteCredentials(server);
|
||||
} catch (error) {
|
||||
errors.push(error as Error);
|
||||
}
|
||||
}
|
||||
|
||||
if (errors.length > 0) {
|
||||
throw new Error(
|
||||
`Failed to clear some credentials: ${errors.map((e) => e.message).join(', ')}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks whether or not a set-get-delete cycle with the keychain works.
|
||||
// Returns false if any operation fails.
|
||||
async checkKeychainAvailability(): Promise<boolean> {
|
||||
if (this.keychainAvailable !== null) {
|
||||
return this.keychainAvailable;
|
||||
}
|
||||
|
||||
try {
|
||||
const keytar = await this.getKeytar();
|
||||
if (!keytar) {
|
||||
this.keychainAvailable = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
const testAccount = `${KEYCHAIN_TEST_PREFIX}${crypto.randomBytes(8).toString('hex')}`;
|
||||
const testPassword = 'test';
|
||||
|
||||
await keytar.setPassword(this.serviceName, testAccount, testPassword);
|
||||
const retrieved = await keytar.getPassword(this.serviceName, testAccount);
|
||||
const deleted = await keytar.deletePassword(
|
||||
this.serviceName,
|
||||
testAccount,
|
||||
);
|
||||
|
||||
const success = deleted && retrieved === testPassword;
|
||||
this.keychainAvailable = success;
|
||||
return success;
|
||||
} catch (_error) {
|
||||
this.keychainAvailable = false;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
async isAvailable(): Promise<boolean> {
|
||||
return this.checkKeychainAvailability();
|
||||
}
|
||||
}
|
||||
@@ -35,3 +35,8 @@ export interface TokenStorage {
|
||||
getAllCredentials(): Promise<Map<string, OAuthCredentials>>;
|
||||
clearAll(): Promise<void>;
|
||||
}
|
||||
|
||||
export enum TokenStorageType {
|
||||
KEYCHAIN = 'keychain',
|
||||
ENCRYPTED_FILE = 'encrypted_file',
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user