Feat(security) - Make the OAuthTokenStorage non static (#7716)

Co-authored-by: Shi Shu <shii@google.com>
This commit is contained in:
shishu314
2025-09-04 16:42:47 -04:00
committed by GitHub
parent e088c06a9a
commit 35a841f71a
7 changed files with 188 additions and 149 deletions
+75 -48
View File
@@ -12,7 +12,21 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({
openBrowserSecurely: mockOpenBrowserSecurely,
}));
vi.mock('node:crypto');
vi.mock('./oauth-token-storage.js');
vi.mock('./oauth-token-storage.js', () => {
const mockSaveToken = vi.fn();
const mockGetToken = vi.fn();
const mockIsTokenExpired = vi.fn();
const mockRemoveToken = vi.fn();
return {
MCPOAuthTokenStorage: vi.fn(() => ({
saveToken: mockSaveToken,
getToken: mockGetToken,
isTokenExpired: mockIsTokenExpired,
removeToken: mockRemoveToken,
})),
};
});
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import * as http from 'node:http';
@@ -147,8 +161,9 @@ describe('MCPOAuthProvider', () => {
});
// Mock token storage
vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined);
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined);
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
});
afterEach(() => {
@@ -192,10 +207,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.authenticate(
'test-server',
mockConfig,
);
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate('test-server', mockConfig);
expect(result).toEqual({
accessToken: 'access_token_123',
@@ -208,7 +221,8 @@ describe('MCPOAuthProvider', () => {
expect(mockOpenBrowserSecurely).toHaveBeenCalledWith(
expect.stringContaining('authorize'),
);
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
const tokenStorage = new MCPOAuthTokenStorage();
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
'test-server',
expect.objectContaining({ accessToken: 'access_token_123' }),
'test-client-id',
@@ -296,7 +310,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.authenticate(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate(
'test-server',
configWithoutAuth,
'https://api.example.com',
@@ -385,7 +400,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.authenticate(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.authenticate(
'test-server',
configWithoutClient,
);
@@ -424,8 +440,9 @@ describe('MCPOAuthProvider', () => {
}, 10);
});
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('OAuth error: access_denied');
});
@@ -453,8 +470,9 @@ describe('MCPOAuthProvider', () => {
}, 10);
});
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('State mismatch - possible CSRF attack');
});
@@ -491,8 +509,9 @@ describe('MCPOAuthProvider', () => {
}),
);
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant');
});
@@ -516,8 +535,9 @@ describe('MCPOAuthProvider', () => {
return originalSetTimeout(callback, 0);
}) as unknown as typeof setTimeout;
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
authProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('OAuth callback timeout');
global.setTimeout = originalSetTimeout;
@@ -542,7 +562,8 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.refreshAccessToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.refreshAccessToken(
mockConfig,
'old_refresh_token',
'https://auth.example.com/token',
@@ -572,7 +593,8 @@ describe('MCPOAuthProvider', () => {
}),
);
await MCPOAuthProvider.refreshAccessToken(
const authProvider = new MCPOAuthProvider();
await authProvider.refreshAccessToken(
mockConfig,
'refresh_token',
'https://auth.example.com/token',
@@ -592,8 +614,9 @@ describe('MCPOAuthProvider', () => {
}),
);
const authProvider = new MCPOAuthProvider();
await expect(
MCPOAuthProvider.refreshAccessToken(
authProvider.refreshAccessToken(
mockConfig,
'invalid_refresh_token',
'https://auth.example.com/token',
@@ -614,12 +637,12 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
validCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(validCredentials);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
@@ -636,10 +659,9 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
expiredCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
const refreshResponse = {
access_token: 'new_access_token',
@@ -657,13 +679,14 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBe('new_access_token');
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
'test-server',
expect.objectContaining({ accessToken: 'new_access_token' }),
'test-client-id',
@@ -673,9 +696,11 @@ describe('MCPOAuthProvider', () => {
});
it('should return null when no credentials exist', async () => {
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
@@ -692,11 +717,10 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
expiredCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(tokenStorage.removeToken).mockResolvedValue(undefined);
mockFetch.mockResolvedValueOnce(
createMockResponse({
@@ -707,15 +731,14 @@ describe('MCPOAuthProvider', () => {
}),
);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBeNull();
expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith(
'test-server',
);
expect(tokenStorage.removeToken).toHaveBeenCalledWith('test-server');
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to refresh token'),
);
@@ -734,12 +757,12 @@ describe('MCPOAuthProvider', () => {
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
tokenWithoutRefresh,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
const tokenStorage = new MCPOAuthTokenStorage();
vi.mocked(tokenStorage.getToken).mockResolvedValue(tokenWithoutRefresh);
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
const result = await MCPOAuthProvider.getValidToken(
const authProvider = new MCPOAuthProvider();
const result = await authProvider.getValidToken(
'test-server',
mockConfig,
);
@@ -784,7 +807,8 @@ describe('MCPOAuthProvider', () => {
}),
);
await MCPOAuthProvider.authenticate('test-server', mockConfig);
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', mockConfig);
expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier
expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state
@@ -833,7 +857,8 @@ describe('MCPOAuthProvider', () => {
}),
);
await MCPOAuthProvider.authenticate(
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate(
'test-server',
mockConfig,
'https://auth.example.com',
@@ -894,7 +919,8 @@ describe('MCPOAuthProvider', () => {
authorizationUrl: 'https://auth.example.com/authorize?audience=1234',
};
await MCPOAuthProvider.authenticate('test-server', configWithParamsInUrl);
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', configWithParamsInUrl);
const url = new URL(capturedUrl!);
expect(url.searchParams.get('audience')).toBe('1234');
@@ -947,7 +973,8 @@ describe('MCPOAuthProvider', () => {
authorizationUrl: 'https://auth.example.com/authorize#login',
};
await MCPOAuthProvider.authenticate('test-server', configWithFragment);
const authProvider = new MCPOAuthProvider();
await authProvider.authenticate('test-server', configWithFragment);
const url = new URL(capturedUrl!);
expect(url.searchParams.get('client_id')).toBe('test-client-id');