mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-22 02:54:31 -07:00
feat(security) - Make oauth token storage implement the shared interface (#7802)
Co-authored-by: Shi Shu <shii@google.com>
This commit is contained in:
@@ -143,7 +143,7 @@ const getMcpStatus = async (
|
|||||||
'@google/gemini-cli-core'
|
'@google/gemini-cli-core'
|
||||||
);
|
);
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
const hasToken = await tokenStorage.getToken(serverName);
|
const hasToken = await tokenStorage.getCredentials(serverName);
|
||||||
if (hasToken) {
|
if (hasToken) {
|
||||||
const isExpired = tokenStorage.isTokenExpired(hasToken.token);
|
const isExpired = tokenStorage.isTokenExpired(hasToken.token);
|
||||||
if (isExpired) {
|
if (isExpired) {
|
||||||
|
|||||||
@@ -14,16 +14,16 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({
|
|||||||
vi.mock('node:crypto');
|
vi.mock('node:crypto');
|
||||||
vi.mock('./oauth-token-storage.js', () => {
|
vi.mock('./oauth-token-storage.js', () => {
|
||||||
const mockSaveToken = vi.fn();
|
const mockSaveToken = vi.fn();
|
||||||
const mockGetToken = vi.fn();
|
const mockGetCredentials = vi.fn();
|
||||||
const mockIsTokenExpired = vi.fn();
|
const mockIsTokenExpired = vi.fn();
|
||||||
const mockRemoveToken = vi.fn();
|
const mockdeleteCredentials = vi.fn();
|
||||||
|
|
||||||
return {
|
return {
|
||||||
MCPOAuthTokenStorage: vi.fn(() => ({
|
MCPOAuthTokenStorage: vi.fn(() => ({
|
||||||
saveToken: mockSaveToken,
|
saveToken: mockSaveToken,
|
||||||
getToken: mockGetToken,
|
getCredentials: mockGetCredentials,
|
||||||
isTokenExpired: mockIsTokenExpired,
|
isTokenExpired: mockIsTokenExpired,
|
||||||
removeToken: mockRemoveToken,
|
deleteCredentials: mockdeleteCredentials,
|
||||||
})),
|
})),
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
@@ -163,7 +163,7 @@ describe('MCPOAuthProvider', () => {
|
|||||||
// Mock token storage
|
// Mock token storage
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined);
|
vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined);
|
||||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
|
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(null);
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
@@ -638,7 +638,9 @@ describe('MCPOAuthProvider', () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(validCredentials);
|
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
|
||||||
|
validCredentials,
|
||||||
|
);
|
||||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false);
|
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false);
|
||||||
|
|
||||||
const authProvider = new MCPOAuthProvider();
|
const authProvider = new MCPOAuthProvider();
|
||||||
@@ -660,7 +662,9 @@ describe('MCPOAuthProvider', () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
|
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
|
||||||
|
expiredCredentials,
|
||||||
|
);
|
||||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
||||||
|
|
||||||
const refreshResponse = {
|
const refreshResponse = {
|
||||||
@@ -697,7 +701,7 @@ describe('MCPOAuthProvider', () => {
|
|||||||
|
|
||||||
it('should return null when no credentials exist', async () => {
|
it('should return null when no credentials exist', async () => {
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
|
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(null);
|
||||||
|
|
||||||
const authProvider = new MCPOAuthProvider();
|
const authProvider = new MCPOAuthProvider();
|
||||||
const result = await authProvider.getValidToken(
|
const result = await authProvider.getValidToken(
|
||||||
@@ -718,9 +722,11 @@ describe('MCPOAuthProvider', () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
|
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
|
||||||
|
expiredCredentials,
|
||||||
|
);
|
||||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
||||||
vi.mocked(tokenStorage.removeToken).mockResolvedValue(undefined);
|
vi.mocked(tokenStorage.deleteCredentials).mockResolvedValue(undefined);
|
||||||
|
|
||||||
mockFetch.mockResolvedValueOnce(
|
mockFetch.mockResolvedValueOnce(
|
||||||
createMockResponse({
|
createMockResponse({
|
||||||
@@ -738,7 +744,9 @@ describe('MCPOAuthProvider', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(result).toBeNull();
|
expect(result).toBeNull();
|
||||||
expect(tokenStorage.removeToken).toHaveBeenCalledWith('test-server');
|
expect(tokenStorage.deleteCredentials).toHaveBeenCalledWith(
|
||||||
|
'test-server',
|
||||||
|
);
|
||||||
expect(console.error).toHaveBeenCalledWith(
|
expect(console.error).toHaveBeenCalledWith(
|
||||||
expect.stringContaining('Failed to refresh token'),
|
expect.stringContaining('Failed to refresh token'),
|
||||||
);
|
);
|
||||||
@@ -758,7 +766,9 @@ describe('MCPOAuthProvider', () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(tokenWithoutRefresh);
|
vi.mocked(tokenStorage.getCredentials).mockResolvedValue(
|
||||||
|
tokenWithoutRefresh,
|
||||||
|
);
|
||||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
||||||
|
|
||||||
const authProvider = new MCPOAuthProvider();
|
const authProvider = new MCPOAuthProvider();
|
||||||
|
|||||||
@@ -798,7 +798,7 @@ export class MCPOAuthProvider {
|
|||||||
console.log('Authentication successful! Token saved.');
|
console.log('Authentication successful! Token saved.');
|
||||||
|
|
||||||
// Verify token was saved
|
// Verify token was saved
|
||||||
const savedToken = await this.tokenStorage.getToken(serverName);
|
const savedToken = await this.tokenStorage.getCredentials(serverName);
|
||||||
if (savedToken && savedToken.token && savedToken.token.accessToken) {
|
if (savedToken && savedToken.token && savedToken.token.accessToken) {
|
||||||
const tokenPreview =
|
const tokenPreview =
|
||||||
savedToken.token.accessToken.length > 20
|
savedToken.token.accessToken.length > 20
|
||||||
@@ -830,7 +830,7 @@ export class MCPOAuthProvider {
|
|||||||
config: MCPOAuthConfig,
|
config: MCPOAuthConfig,
|
||||||
): Promise<string | null> {
|
): Promise<string | null> {
|
||||||
console.debug(`Getting valid token for server: ${serverName}`);
|
console.debug(`Getting valid token for server: ${serverName}`);
|
||||||
const credentials = await this.tokenStorage.getToken(serverName);
|
const credentials = await this.tokenStorage.getCredentials(serverName);
|
||||||
|
|
||||||
if (!credentials) {
|
if (!credentials) {
|
||||||
console.debug(`No credentials found for server: ${serverName}`);
|
console.debug(`No credentials found for server: ${serverName}`);
|
||||||
@@ -884,7 +884,7 @@ export class MCPOAuthProvider {
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Failed to refresh token: ${getErrorMessage(error)}`);
|
console.error(`Failed to refresh token: ${getErrorMessage(error)}`);
|
||||||
// Remove invalid token
|
// Remove invalid token
|
||||||
await this.tokenStorage.removeToken(serverName);
|
await this.tokenStorage.deleteCredentials(serverName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -54,11 +54,11 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
vi.restoreAllMocks();
|
vi.restoreAllMocks();
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('loadTokens', () => {
|
describe('getAllCredentials', () => {
|
||||||
it('should return empty map when token file does not exist', async () => {
|
it('should return empty map when token file does not exist', async () => {
|
||||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||||
|
|
||||||
const tokens = await tokenStorage.loadTokens();
|
const tokens = await tokenStorage.getAllCredentials();
|
||||||
|
|
||||||
expect(tokens.size).toBe(0);
|
expect(tokens.size).toBe(0);
|
||||||
expect(console.error).not.toHaveBeenCalled();
|
expect(console.error).not.toHaveBeenCalled();
|
||||||
@@ -68,7 +68,7 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
const tokensArray = [mockCredentials];
|
const tokensArray = [mockCredentials];
|
||||||
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
|
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
|
||||||
|
|
||||||
const tokens = await tokenStorage.loadTokens();
|
const tokens = await tokenStorage.getAllCredentials();
|
||||||
|
|
||||||
expect(tokens.size).toBe(1);
|
expect(tokens.size).toBe(1);
|
||||||
expect(tokens.get('test-server')).toEqual(mockCredentials);
|
expect(tokens.get('test-server')).toEqual(mockCredentials);
|
||||||
@@ -81,7 +81,7 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
it('should handle corrupted token file gracefully', async () => {
|
it('should handle corrupted token file gracefully', async () => {
|
||||||
vi.mocked(fs.readFile).mockResolvedValue('invalid json');
|
vi.mocked(fs.readFile).mockResolvedValue('invalid json');
|
||||||
|
|
||||||
const tokens = await tokenStorage.loadTokens();
|
const tokens = await tokenStorage.getAllCredentials();
|
||||||
|
|
||||||
expect(tokens.size).toBe(0);
|
expect(tokens.size).toBe(0);
|
||||||
expect(console.error).toHaveBeenCalledWith(
|
expect(console.error).toHaveBeenCalledWith(
|
||||||
@@ -93,7 +93,7 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
const error = new Error('Permission denied');
|
const error = new Error('Permission denied');
|
||||||
vi.mocked(fs.readFile).mockRejectedValue(error);
|
vi.mocked(fs.readFile).mockRejectedValue(error);
|
||||||
|
|
||||||
const tokens = await tokenStorage.loadTokens();
|
const tokens = await tokenStorage.getAllCredentials();
|
||||||
|
|
||||||
expect(tokens.size).toBe(0);
|
expect(tokens.size).toBe(0);
|
||||||
expect(console.error).toHaveBeenCalledWith(
|
expect(console.error).toHaveBeenCalledWith(
|
||||||
@@ -163,13 +163,13 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('getToken', () => {
|
describe('getCredentials', () => {
|
||||||
it('should return token for existing server', async () => {
|
it('should return token for existing server', async () => {
|
||||||
vi.mocked(fs.readFile).mockResolvedValue(
|
vi.mocked(fs.readFile).mockResolvedValue(
|
||||||
JSON.stringify([mockCredentials]),
|
JSON.stringify([mockCredentials]),
|
||||||
);
|
);
|
||||||
|
|
||||||
const result = await tokenStorage.getToken('test-server');
|
const result = await tokenStorage.getCredentials('test-server');
|
||||||
|
|
||||||
expect(result).toEqual(mockCredentials);
|
expect(result).toEqual(mockCredentials);
|
||||||
});
|
});
|
||||||
@@ -179,7 +179,7 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
JSON.stringify([mockCredentials]),
|
JSON.stringify([mockCredentials]),
|
||||||
);
|
);
|
||||||
|
|
||||||
const result = await tokenStorage.getToken('non-existent');
|
const result = await tokenStorage.getCredentials('non-existent');
|
||||||
|
|
||||||
expect(result).toBeNull();
|
expect(result).toBeNull();
|
||||||
});
|
});
|
||||||
@@ -187,13 +187,13 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
it('should return null when no tokens file exists', async () => {
|
it('should return null when no tokens file exists', async () => {
|
||||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||||
|
|
||||||
const result = await tokenStorage.getToken('test-server');
|
const result = await tokenStorage.getCredentials('test-server');
|
||||||
|
|
||||||
expect(result).toBeNull();
|
expect(result).toBeNull();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('removeToken', () => {
|
describe('deleteCredentials', () => {
|
||||||
it('should remove token for specific server', async () => {
|
it('should remove token for specific server', async () => {
|
||||||
const credentials1 = { ...mockCredentials, serverName: 'server1' };
|
const credentials1 = { ...mockCredentials, serverName: 'server1' };
|
||||||
const credentials2 = { ...mockCredentials, serverName: 'server2' };
|
const credentials2 = { ...mockCredentials, serverName: 'server2' };
|
||||||
@@ -202,7 +202,7 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
);
|
);
|
||||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||||
|
|
||||||
await tokenStorage.removeToken('server1');
|
await tokenStorage.deleteCredentials('server1');
|
||||||
|
|
||||||
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
|
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
|
||||||
const savedData = JSON.parse(writeCall[1] as string);
|
const savedData = JSON.parse(writeCall[1] as string);
|
||||||
@@ -217,7 +217,7 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
);
|
);
|
||||||
vi.mocked(fs.unlink).mockResolvedValue(undefined);
|
vi.mocked(fs.unlink).mockResolvedValue(undefined);
|
||||||
|
|
||||||
await tokenStorage.removeToken('test-server');
|
await tokenStorage.deleteCredentials('test-server');
|
||||||
|
|
||||||
expect(fs.unlink).toHaveBeenCalledWith(
|
expect(fs.unlink).toHaveBeenCalledWith(
|
||||||
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
|
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
|
||||||
@@ -230,7 +230,7 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
JSON.stringify([mockCredentials]),
|
JSON.stringify([mockCredentials]),
|
||||||
);
|
);
|
||||||
|
|
||||||
await tokenStorage.removeToken('non-existent');
|
await tokenStorage.deleteCredentials('non-existent');
|
||||||
|
|
||||||
expect(fs.writeFile).not.toHaveBeenCalled();
|
expect(fs.writeFile).not.toHaveBeenCalled();
|
||||||
expect(fs.unlink).not.toHaveBeenCalled();
|
expect(fs.unlink).not.toHaveBeenCalled();
|
||||||
@@ -242,7 +242,7 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
);
|
);
|
||||||
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
|
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
|
||||||
|
|
||||||
await tokenStorage.removeToken('test-server');
|
await tokenStorage.deleteCredentials('test-server');
|
||||||
|
|
||||||
expect(console.error).toHaveBeenCalledWith(
|
expect(console.error).toHaveBeenCalledWith(
|
||||||
expect.stringContaining('Failed to remove MCP OAuth token'),
|
expect.stringContaining('Failed to remove MCP OAuth token'),
|
||||||
@@ -294,11 +294,11 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('clearAllTokens', () => {
|
describe('clearAll', () => {
|
||||||
it('should remove token file successfully', async () => {
|
it('should remove token file successfully', async () => {
|
||||||
vi.mocked(fs.unlink).mockResolvedValue(undefined);
|
vi.mocked(fs.unlink).mockResolvedValue(undefined);
|
||||||
|
|
||||||
await tokenStorage.clearAllTokens();
|
await tokenStorage.clearAll();
|
||||||
|
|
||||||
expect(fs.unlink).toHaveBeenCalledWith(
|
expect(fs.unlink).toHaveBeenCalledWith(
|
||||||
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
|
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
|
||||||
@@ -308,7 +308,7 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
it('should handle non-existent file gracefully', async () => {
|
it('should handle non-existent file gracefully', async () => {
|
||||||
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
|
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
|
||||||
|
|
||||||
await tokenStorage.clearAllTokens();
|
await tokenStorage.clearAll();
|
||||||
|
|
||||||
expect(console.error).not.toHaveBeenCalled();
|
expect(console.error).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
@@ -316,7 +316,7 @@ describe('MCPOAuthTokenStorage', () => {
|
|||||||
it('should handle other file errors gracefully', async () => {
|
it('should handle other file errors gracefully', async () => {
|
||||||
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
|
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
|
||||||
|
|
||||||
await tokenStorage.clearAllTokens();
|
await tokenStorage.clearAll();
|
||||||
|
|
||||||
expect(console.error).toHaveBeenCalledWith(
|
expect(console.error).toHaveBeenCalledWith(
|
||||||
expect.stringContaining('Failed to clear MCP OAuth tokens'),
|
expect.stringContaining('Failed to clear MCP OAuth tokens'),
|
||||||
|
|||||||
@@ -8,12 +8,16 @@ import { promises as fs } from 'node:fs';
|
|||||||
import * as path from 'node:path';
|
import * as path from 'node:path';
|
||||||
import { Storage } from '../config/storage.js';
|
import { Storage } from '../config/storage.js';
|
||||||
import { getErrorMessage } from '../utils/errors.js';
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
import type { OAuthToken, OAuthCredentials } from './token-storage/types.js';
|
import type {
|
||||||
|
OAuthToken,
|
||||||
|
OAuthCredentials,
|
||||||
|
TokenStorage,
|
||||||
|
} from './token-storage/types.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Class for managing MCP OAuth token storage and retrieval.
|
* Class for managing MCP OAuth token storage and retrieval.
|
||||||
*/
|
*/
|
||||||
export class MCPOAuthTokenStorage {
|
export class MCPOAuthTokenStorage implements TokenStorage {
|
||||||
/**
|
/**
|
||||||
* Get the path to the token storage file.
|
* Get the path to the token storage file.
|
||||||
*
|
*
|
||||||
@@ -36,7 +40,7 @@ export class MCPOAuthTokenStorage {
|
|||||||
*
|
*
|
||||||
* @returns A map of server names to credentials
|
* @returns A map of server names to credentials
|
||||||
*/
|
*/
|
||||||
async loadTokens(): Promise<Map<string, OAuthCredentials>> {
|
async getAllCredentials(): Promise<Map<string, OAuthCredentials>> {
|
||||||
const tokenMap = new Map<string, OAuthCredentials>();
|
const tokenMap = new Map<string, OAuthCredentials>();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -59,36 +63,14 @@ export class MCPOAuthTokenStorage {
|
|||||||
return tokenMap;
|
return tokenMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
async listServers(): Promise<string[]> {
|
||||||
* Save a token for a specific MCP server.
|
const tokens = await this.getAllCredentials();
|
||||||
*
|
return Array.from(tokens.keys());
|
||||||
* @param serverName The name of the MCP server
|
}
|
||||||
* @param token The OAuth token to save
|
|
||||||
* @param clientId Optional client ID used for this token
|
|
||||||
* @param tokenUrl Optional token URL used for this token
|
|
||||||
* @param mcpServerUrl Optional MCP server URL
|
|
||||||
*/
|
|
||||||
async saveToken(
|
|
||||||
serverName: string,
|
|
||||||
token: OAuthToken,
|
|
||||||
clientId?: string,
|
|
||||||
tokenUrl?: string,
|
|
||||||
mcpServerUrl?: string,
|
|
||||||
): Promise<void> {
|
|
||||||
await this.ensureConfigDir();
|
|
||||||
|
|
||||||
const tokens = await this.loadTokens();
|
async setCredentials(credentials: OAuthCredentials): Promise<void> {
|
||||||
|
const tokens = await this.getAllCredentials();
|
||||||
const credential: OAuthCredentials = {
|
tokens.set(credentials.serverName, credentials);
|
||||||
serverName,
|
|
||||||
token,
|
|
||||||
clientId,
|
|
||||||
tokenUrl,
|
|
||||||
mcpServerUrl,
|
|
||||||
updatedAt: Date.now(),
|
|
||||||
};
|
|
||||||
|
|
||||||
tokens.set(serverName, credential);
|
|
||||||
|
|
||||||
const tokenArray = Array.from(tokens.values());
|
const tokenArray = Array.from(tokens.values());
|
||||||
const tokenFile = this.getTokenFilePath();
|
const tokenFile = this.getTokenFilePath();
|
||||||
@@ -107,14 +89,44 @@ export class MCPOAuthTokenStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Save a token for a specific MCP server.
|
||||||
|
*
|
||||||
|
* @param serverName The name of the MCP server
|
||||||
|
* @param token The OAuth token to save
|
||||||
|
* @param clientId Optional client ID used for this token
|
||||||
|
* @param tokenUrl Optional token URL used for this token
|
||||||
|
* @param mcpServerUrl Optional MCP server URL
|
||||||
|
*/
|
||||||
|
async saveToken(
|
||||||
|
serverName: string,
|
||||||
|
token: OAuthToken,
|
||||||
|
clientId?: string,
|
||||||
|
tokenUrl?: string,
|
||||||
|
mcpServerUrl?: string,
|
||||||
|
): Promise<void> {
|
||||||
|
await this.ensureConfigDir();
|
||||||
|
|
||||||
|
const credential: OAuthCredentials = {
|
||||||
|
serverName,
|
||||||
|
token,
|
||||||
|
clientId,
|
||||||
|
tokenUrl,
|
||||||
|
mcpServerUrl,
|
||||||
|
updatedAt: Date.now(),
|
||||||
|
};
|
||||||
|
|
||||||
|
await this.setCredentials(credential);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get a token for a specific MCP server.
|
* Get a token for a specific MCP server.
|
||||||
*
|
*
|
||||||
* @param serverName The name of the MCP server
|
* @param serverName The name of the MCP server
|
||||||
* @returns The stored credentials or null if not found
|
* @returns The stored credentials or null if not found
|
||||||
*/
|
*/
|
||||||
async getToken(serverName: string): Promise<OAuthCredentials | null> {
|
async getCredentials(serverName: string): Promise<OAuthCredentials | null> {
|
||||||
const tokens = await this.loadTokens();
|
const tokens = await this.getAllCredentials();
|
||||||
return tokens.get(serverName) || null;
|
return tokens.get(serverName) || null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,8 +135,8 @@ export class MCPOAuthTokenStorage {
|
|||||||
*
|
*
|
||||||
* @param serverName The name of the MCP server
|
* @param serverName The name of the MCP server
|
||||||
*/
|
*/
|
||||||
async removeToken(serverName: string): Promise<void> {
|
async deleteCredentials(serverName: string): Promise<void> {
|
||||||
const tokens = await this.loadTokens();
|
const tokens = await this.getAllCredentials();
|
||||||
|
|
||||||
if (tokens.delete(serverName)) {
|
if (tokens.delete(serverName)) {
|
||||||
const tokenArray = Array.from(tokens.values());
|
const tokenArray = Array.from(tokens.values());
|
||||||
@@ -166,7 +178,7 @@ export class MCPOAuthTokenStorage {
|
|||||||
/**
|
/**
|
||||||
* Clear all stored MCP OAuth tokens.
|
* Clear all stored MCP OAuth tokens.
|
||||||
*/
|
*/
|
||||||
async clearAllTokens(): Promise<void> {
|
async clearAll(): Promise<void> {
|
||||||
try {
|
try {
|
||||||
const tokenFile = this.getTokenFilePath();
|
const tokenFile = this.getTokenFilePath();
|
||||||
await fs.unlink(tokenFile);
|
await fs.unlink(tokenFile);
|
||||||
|
|||||||
@@ -897,7 +897,7 @@ export async function connectToMcpServer(
|
|||||||
if (!shouldTriggerOAuth) {
|
if (!shouldTriggerOAuth) {
|
||||||
// For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately.
|
// For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately.
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
const credentials = await tokenStorage.getToken(mcpServerName);
|
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||||
if (credentials) {
|
if (credentials) {
|
||||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||||
const hasStoredTokens = await authProvider.getValidToken(
|
const hasStoredTokens = await authProvider.getValidToken(
|
||||||
@@ -982,7 +982,7 @@ export async function connectToMcpServer(
|
|||||||
// Get the valid token - we need to create a proper OAuth config
|
// Get the valid token - we need to create a proper OAuth config
|
||||||
// The token should already be available from the authentication process
|
// The token should already be available from the authentication process
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
const credentials = await tokenStorage.getToken(mcpServerName);
|
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||||
if (credentials) {
|
if (credentials) {
|
||||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||||
const accessToken = await authProvider.getValidToken(
|
const accessToken = await authProvider.getValidToken(
|
||||||
@@ -1057,7 +1057,7 @@ export async function connectToMcpServer(
|
|||||||
|
|
||||||
if (!shouldTryDiscovery) {
|
if (!shouldTryDiscovery) {
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
const credentials = await tokenStorage.getToken(mcpServerName);
|
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||||
if (credentials) {
|
if (credentials) {
|
||||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||||
const hasStoredTokens = await authProvider.getValidToken(
|
const hasStoredTokens = await authProvider.getValidToken(
|
||||||
@@ -1128,7 +1128,8 @@ export async function connectToMcpServer(
|
|||||||
|
|
||||||
// Retry connection with OAuth token
|
// Retry connection with OAuth token
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
const credentials = await tokenStorage.getToken(mcpServerName);
|
const credentials =
|
||||||
|
await tokenStorage.getCredentials(mcpServerName);
|
||||||
if (credentials) {
|
if (credentials) {
|
||||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||||
const accessToken = await authProvider.getValidToken(
|
const accessToken = await authProvider.getValidToken(
|
||||||
@@ -1286,7 +1287,7 @@ export async function createTransport(
|
|||||||
} else {
|
} else {
|
||||||
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
const credentials = await tokenStorage.getToken(mcpServerName);
|
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||||
if (credentials) {
|
if (credentials) {
|
||||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||||
accessToken = await authProvider.getValidToken(mcpServerName, {
|
accessToken = await authProvider.getValidToken(mcpServerName, {
|
||||||
|
|||||||
Reference in New Issue
Block a user