feat(security) - Make oauth token storage implement the shared interface (#7802)

Co-authored-by: Shi Shu <shii@google.com>
This commit is contained in:
shishu314
2025-09-05 12:08:50 -04:00
committed by GitHub
parent af52b04e6f
commit 918ab3c2ec
6 changed files with 99 additions and 76 deletions

View File

@@ -143,7 +143,7 @@ const getMcpStatus = async (
'@google/gemini-cli-core'
);
const tokenStorage = new MCPOAuthTokenStorage();
const hasToken = await tokenStorage.getToken(serverName);
const hasToken = await tokenStorage.getCredentials(serverName);
if (hasToken) {
const isExpired = tokenStorage.isTokenExpired(hasToken.token);
if (isExpired) {

View File

@@ -14,16 +14,16 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({
vi.mock('node:crypto');
vi.mock('./oauth-token-storage.js', () => {
const mockSaveToken = vi.fn();
const mockGetToken = vi.fn();
const mockGetCredentials = vi.fn();
const mockIsTokenExpired = vi.fn();
const mockRemoveToken = vi.fn();
const mockdeleteCredentials = vi.fn();
return {
MCPOAuthTokenStorage: vi.fn(() => ({
saveToken: mockSaveToken,
getToken: mockGetToken,
getCredentials: mockGetCredentials,
isTokenExpired: mockIsTokenExpired,
removeToken: mockRemoveToken,
deleteCredentials: mockdeleteCredentials,
})),
};
});
@@ -163,7 +163,7 @@ describe('MCPOAuthProvider', () => {
// Mock token storage
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined);
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(null);
});
afterEach(() => {
@@ -638,7 +638,9 @@ describe('MCPOAuthProvider', () => {
};
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(validCredentials);
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
validCredentials,
);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false);
const authProvider = new MCPOAuthProvider();
@@ -660,7 +662,9 @@ describe('MCPOAuthProvider', () => {
};
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
expiredCredentials,
);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
const refreshResponse = {
@@ -697,7 +701,7 @@ describe('MCPOAuthProvider', () => {
it('should return null when no credentials exist', async () => {
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(null);
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
@@ -718,9 +722,11 @@ describe('MCPOAuthProvider', () => {
};
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
expiredCredentials,
);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(tokenStorage.removeToken).mockResolvedValue(undefined);
vi.mocked(tokenStorage.deleteCredentials).mockResolvedValue(undefined);
mockFetch.mockResolvedValueOnce(
createMockResponse({
@@ -738,7 +744,9 @@ describe('MCPOAuthProvider', () => {
);
expect(result).toBeNull();
expect(tokenStorage.removeToken).toHaveBeenCalledWith('test-server');
expect(tokenStorage.deleteCredentials).toHaveBeenCalledWith(
'test-server',
);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to refresh token'),
);
@@ -758,7 +766,9 @@ describe('MCPOAuthProvider', () => {
};
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(tokenWithoutRefresh);
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
tokenWithoutRefresh,
);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
const authProvider = new MCPOAuthProvider();

View File

@@ -798,7 +798,7 @@ export class MCPOAuthProvider {
console.log('Authentication successful! Token saved.');
// Verify token was saved
const savedToken = await this.tokenStorage.getToken(serverName);
const savedToken = await this.tokenStorage.getCredentials(serverName);
if (savedToken && savedToken.token && savedToken.token.accessToken) {
const tokenPreview =
savedToken.token.accessToken.length > 20
@@ -830,7 +830,7 @@ export class MCPOAuthProvider {
config: MCPOAuthConfig,
): Promise<string | null> {
console.debug(`Getting valid token for server: ${serverName}`);
const credentials = await this.tokenStorage.getToken(serverName);
const credentials = await this.tokenStorage.getCredentials(serverName);
if (!credentials) {
console.debug(`No credentials found for server: ${serverName}`);
@@ -884,7 +884,7 @@ export class MCPOAuthProvider {
} catch (error) {
console.error(`Failed to refresh token: ${getErrorMessage(error)}`);
// Remove invalid token
await this.tokenStorage.removeToken(serverName);
await this.tokenStorage.deleteCredentials(serverName);
}
}

View File

@@ -54,11 +54,11 @@ describe('MCPOAuthTokenStorage', () => {
vi.restoreAllMocks();
});
describe('loadTokens', () => {
describe('getAllCredentials', () => {
it('should return empty map when token file does not exist', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const tokens = await tokenStorage.loadTokens();
const tokens = await tokenStorage.getAllCredentials();
expect(tokens.size).toBe(0);
expect(console.error).not.toHaveBeenCalled();
@@ -68,7 +68,7 @@ describe('MCPOAuthTokenStorage', () => {
const tokensArray = [mockCredentials];
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
const tokens = await tokenStorage.loadTokens();
const tokens = await tokenStorage.getAllCredentials();
expect(tokens.size).toBe(1);
expect(tokens.get('test-server')).toEqual(mockCredentials);
@@ -81,7 +81,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should handle corrupted token file gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue('invalid json');
const tokens = await tokenStorage.loadTokens();
const tokens = await tokenStorage.getAllCredentials();
expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith(
@@ -93,7 +93,7 @@ describe('MCPOAuthTokenStorage', () => {
const error = new Error('Permission denied');
vi.mocked(fs.readFile).mockRejectedValue(error);
const tokens = await tokenStorage.loadTokens();
const tokens = await tokenStorage.getAllCredentials();
expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith(
@@ -163,13 +163,13 @@ describe('MCPOAuthTokenStorage', () => {
});
});
describe('getToken', () => {
describe('getCredentials', () => {
it('should return token for existing server', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
const result = await tokenStorage.getToken('test-server');
const result = await tokenStorage.getCredentials('test-server');
expect(result).toEqual(mockCredentials);
});
@@ -179,7 +179,7 @@ describe('MCPOAuthTokenStorage', () => {
JSON.stringify([mockCredentials]),
);
const result = await tokenStorage.getToken('non-existent');
const result = await tokenStorage.getCredentials('non-existent');
expect(result).toBeNull();
});
@@ -187,13 +187,13 @@ describe('MCPOAuthTokenStorage', () => {
it('should return null when no tokens file exists', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const result = await tokenStorage.getToken('test-server');
const result = await tokenStorage.getCredentials('test-server');
expect(result).toBeNull();
});
});
describe('removeToken', () => {
describe('deleteCredentials', () => {
it('should remove token for specific server', async () => {
const credentials1 = { ...mockCredentials, serverName: 'server1' };
const credentials2 = { ...mockCredentials, serverName: 'server2' };
@@ -202,7 +202,7 @@ describe('MCPOAuthTokenStorage', () => {
);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
await tokenStorage.removeToken('server1');
await tokenStorage.deleteCredentials('server1');
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(writeCall[1] as string);
@@ -217,7 +217,7 @@ describe('MCPOAuthTokenStorage', () => {
);
vi.mocked(fs.unlink).mockResolvedValue(undefined);
await tokenStorage.removeToken('test-server');
await tokenStorage.deleteCredentials('test-server');
expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
@@ -230,7 +230,7 @@ describe('MCPOAuthTokenStorage', () => {
JSON.stringify([mockCredentials]),
);
await tokenStorage.removeToken('non-existent');
await tokenStorage.deleteCredentials('non-existent');
expect(fs.writeFile).not.toHaveBeenCalled();
expect(fs.unlink).not.toHaveBeenCalled();
@@ -242,7 +242,7 @@ describe('MCPOAuthTokenStorage', () => {
);
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await tokenStorage.removeToken('test-server');
await tokenStorage.deleteCredentials('test-server');
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to remove MCP OAuth token'),
@@ -294,11 +294,11 @@ describe('MCPOAuthTokenStorage', () => {
});
});
describe('clearAllTokens', () => {
describe('clearAll', () => {
it('should remove token file successfully', async () => {
vi.mocked(fs.unlink).mockResolvedValue(undefined);
await tokenStorage.clearAllTokens();
await tokenStorage.clearAll();
expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
@@ -308,7 +308,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should handle non-existent file gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
await tokenStorage.clearAllTokens();
await tokenStorage.clearAll();
expect(console.error).not.toHaveBeenCalled();
});
@@ -316,7 +316,7 @@ describe('MCPOAuthTokenStorage', () => {
it('should handle other file errors gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await tokenStorage.clearAllTokens();
await tokenStorage.clearAll();
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to clear MCP OAuth tokens'),

View File

@@ -8,12 +8,16 @@ import { promises as fs } from 'node:fs';
import * as path from 'node:path';
import { Storage } from '../config/storage.js';
import { getErrorMessage } from '../utils/errors.js';
import type { OAuthToken, OAuthCredentials } from './token-storage/types.js';
import type {
OAuthToken,
OAuthCredentials,
TokenStorage,
} from './token-storage/types.js';
/**
* Class for managing MCP OAuth token storage and retrieval.
*/
export class MCPOAuthTokenStorage {
export class MCPOAuthTokenStorage implements TokenStorage {
/**
* Get the path to the token storage file.
*
@@ -36,7 +40,7 @@ export class MCPOAuthTokenStorage {
*
* @returns A map of server names to credentials
*/
async loadTokens(): Promise<Map<string, OAuthCredentials>> {
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
const tokenMap = new Map<string, OAuthCredentials>();
try {
@@ -59,36 +63,14 @@ export class MCPOAuthTokenStorage {
return tokenMap;
}
/**
* Save a token for a specific MCP server.
*
* @param serverName The name of the MCP server
* @param token The OAuth token to save
* @param clientId Optional client ID used for this token
* @param tokenUrl Optional token URL used for this token
* @param mcpServerUrl Optional MCP server URL
*/
async saveToken(
serverName: string,
token: OAuthToken,
clientId?: string,
tokenUrl?: string,
mcpServerUrl?: string,
): Promise<void> {
await this.ensureConfigDir();
async listServers(): Promise<string[]> {
const tokens = await this.getAllCredentials();
return Array.from(tokens.keys());
}
const tokens = await this.loadTokens();
const credential: OAuthCredentials = {
serverName,
token,
clientId,
tokenUrl,
mcpServerUrl,
updatedAt: Date.now(),
};
tokens.set(serverName, credential);
async setCredentials(credentials: OAuthCredentials): Promise<void> {
const tokens = await this.getAllCredentials();
tokens.set(credentials.serverName, credentials);
const tokenArray = Array.from(tokens.values());
const tokenFile = this.getTokenFilePath();
@@ -107,14 +89,44 @@ export class MCPOAuthTokenStorage {
}
}
/**
* Save a token for a specific MCP server.
*
* @param serverName The name of the MCP server
* @param token The OAuth token to save
* @param clientId Optional client ID used for this token
* @param tokenUrl Optional token URL used for this token
* @param mcpServerUrl Optional MCP server URL
*/
async saveToken(
serverName: string,
token: OAuthToken,
clientId?: string,
tokenUrl?: string,
mcpServerUrl?: string,
): Promise<void> {
await this.ensureConfigDir();
const credential: OAuthCredentials = {
serverName,
token,
clientId,
tokenUrl,
mcpServerUrl,
updatedAt: Date.now(),
};
await this.setCredentials(credential);
}
/**
* Get a token for a specific MCP server.
*
* @param serverName The name of the MCP server
* @returns The stored credentials or null if not found
*/
async getToken(serverName: string): Promise<OAuthCredentials | null> {
const tokens = await this.loadTokens();
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
const tokens = await this.getAllCredentials();
return tokens.get(serverName) || null;
}
@@ -123,8 +135,8 @@ export class MCPOAuthTokenStorage {
*
* @param serverName The name of the MCP server
*/
async removeToken(serverName: string): Promise<void> {
const tokens = await this.loadTokens();
async deleteCredentials(serverName: string): Promise<void> {
const tokens = await this.getAllCredentials();
if (tokens.delete(serverName)) {
const tokenArray = Array.from(tokens.values());
@@ -166,7 +178,7 @@ export class MCPOAuthTokenStorage {
/**
* Clear all stored MCP OAuth tokens.
*/
async clearAllTokens(): Promise<void> {
async clearAll(): Promise<void> {
try {
const tokenFile = this.getTokenFilePath();
await fs.unlink(tokenFile);

View File

@@ -897,7 +897,7 @@ export async function connectToMcpServer(
if (!shouldTriggerOAuth) {
// For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately.
const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
const credentials = await tokenStorage.getCredentials(mcpServerName);
if (credentials) {
const authProvider = new MCPOAuthProvider(tokenStorage);
const hasStoredTokens = await authProvider.getValidToken(
@@ -982,7 +982,7 @@ export async function connectToMcpServer(
// Get the valid token - we need to create a proper OAuth config
// The token should already be available from the authentication process
const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
const credentials = await tokenStorage.getCredentials(mcpServerName);
if (credentials) {
const authProvider = new MCPOAuthProvider(tokenStorage);
const accessToken = await authProvider.getValidToken(
@@ -1057,7 +1057,7 @@ export async function connectToMcpServer(
if (!shouldTryDiscovery) {
const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
const credentials = await tokenStorage.getCredentials(mcpServerName);
if (credentials) {
const authProvider = new MCPOAuthProvider(tokenStorage);
const hasStoredTokens = await authProvider.getValidToken(
@@ -1128,7 +1128,8 @@ export async function connectToMcpServer(
// Retry connection with OAuth token
const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
const credentials =
await tokenStorage.getCredentials(mcpServerName);
if (credentials) {
const authProvider = new MCPOAuthProvider(tokenStorage);
const accessToken = await authProvider.getValidToken(
@@ -1286,7 +1287,7 @@ export async function createTransport(
} else {
// Check if we have stored OAuth tokens for this server (from previous authentication)
const tokenStorage = new MCPOAuthTokenStorage();
const credentials = await tokenStorage.getToken(mcpServerName);
const credentials = await tokenStorage.getCredentials(mcpServerName);
if (credentials) {
const authProvider = new MCPOAuthProvider(tokenStorage);
accessToken = await authProvider.getValidToken(mcpServerName, {