mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 12:54:07 -07:00
Feat(security) - Make the OAuthTokenStorage non static (#7716)
Co-authored-by: Shi Shu <shii@google.com>
This commit is contained in:
@@ -12,7 +12,21 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({
|
||||
openBrowserSecurely: mockOpenBrowserSecurely,
|
||||
}));
|
||||
vi.mock('node:crypto');
|
||||
vi.mock('./oauth-token-storage.js');
|
||||
vi.mock('./oauth-token-storage.js', () => {
|
||||
const mockSaveToken = vi.fn();
|
||||
const mockGetToken = vi.fn();
|
||||
const mockIsTokenExpired = vi.fn();
|
||||
const mockRemoveToken = vi.fn();
|
||||
|
||||
return {
|
||||
MCPOAuthTokenStorage: vi.fn(() => ({
|
||||
saveToken: mockSaveToken,
|
||||
getToken: mockGetToken,
|
||||
isTokenExpired: mockIsTokenExpired,
|
||||
removeToken: mockRemoveToken,
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import * as http from 'node:http';
|
||||
@@ -147,8 +161,9 @@ describe('MCPOAuthProvider', () => {
|
||||
});
|
||||
|
||||
// Mock token storage
|
||||
vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined);
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.saveToken).mockResolvedValue(undefined);
|
||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -192,10 +207,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.authenticate(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.authenticate('test-server', mockConfig);
|
||||
|
||||
expect(result).toEqual({
|
||||
accessToken: 'access_token_123',
|
||||
@@ -208,7 +221,8 @@ describe('MCPOAuthProvider', () => {
|
||||
expect(mockOpenBrowserSecurely).toHaveBeenCalledWith(
|
||||
expect.stringContaining('authorize'),
|
||||
);
|
||||
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
expect.objectContaining({ accessToken: 'access_token_123' }),
|
||||
'test-client-id',
|
||||
@@ -296,7 +310,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.authenticate(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.authenticate(
|
||||
'test-server',
|
||||
configWithoutAuth,
|
||||
'https://api.example.com',
|
||||
@@ -385,7 +400,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.authenticate(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.authenticate(
|
||||
'test-server',
|
||||
configWithoutClient,
|
||||
);
|
||||
@@ -424,8 +440,9 @@ describe('MCPOAuthProvider', () => {
|
||||
}, 10);
|
||||
});
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
authProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('OAuth error: access_denied');
|
||||
});
|
||||
|
||||
@@ -453,8 +470,9 @@ describe('MCPOAuthProvider', () => {
|
||||
}, 10);
|
||||
});
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
authProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('State mismatch - possible CSRF attack');
|
||||
});
|
||||
|
||||
@@ -491,8 +509,9 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
authProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('Token exchange failed: invalid_grant - Invalid grant');
|
||||
});
|
||||
|
||||
@@ -516,8 +535,9 @@ describe('MCPOAuthProvider', () => {
|
||||
return originalSetTimeout(callback, 0);
|
||||
}) as unknown as typeof setTimeout;
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
authProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('OAuth callback timeout');
|
||||
|
||||
global.setTimeout = originalSetTimeout;
|
||||
@@ -542,7 +562,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.refreshAccessToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'old_refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
@@ -572,7 +593,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
await MCPOAuthProvider.refreshAccessToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await authProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
@@ -592,8 +614,9 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await expect(
|
||||
MCPOAuthProvider.refreshAccessToken(
|
||||
authProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'invalid_refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
@@ -614,12 +637,12 @@ describe('MCPOAuthProvider', () => {
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
validCredentials,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(validCredentials);
|
||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(false);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
@@ -636,10 +659,9 @@ describe('MCPOAuthProvider', () => {
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
expiredCredentials,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
|
||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
|
||||
const refreshResponse = {
|
||||
access_token: 'new_access_token',
|
||||
@@ -657,13 +679,14 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toBe('new_access_token');
|
||||
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
|
||||
expect(tokenStorage.saveToken).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
expect.objectContaining({ accessToken: 'new_access_token' }),
|
||||
'test-client-id',
|
||||
@@ -673,9 +696,11 @@ describe('MCPOAuthProvider', () => {
|
||||
});
|
||||
|
||||
it('should return null when no credentials exist', async () => {
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(null);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
@@ -692,11 +717,10 @@ describe('MCPOAuthProvider', () => {
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
expiredCredentials,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(expiredCredentials);
|
||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
vi.mocked(tokenStorage.removeToken).mockResolvedValue(undefined);
|
||||
|
||||
mockFetch.mockResolvedValueOnce(
|
||||
createMockResponse({
|
||||
@@ -707,15 +731,14 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(tokenStorage.removeToken).toHaveBeenCalledWith('test-server');
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to refresh token'),
|
||||
);
|
||||
@@ -734,12 +757,12 @@ describe('MCPOAuthProvider', () => {
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
tokenWithoutRefresh,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
vi.mocked(tokenStorage.getToken).mockResolvedValue(tokenWithoutRefresh);
|
||||
vi.mocked(tokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
const result = await authProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
@@ -784,7 +807,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
await MCPOAuthProvider.authenticate('test-server', mockConfig);
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await authProvider.authenticate('test-server', mockConfig);
|
||||
|
||||
expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier
|
||||
expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state
|
||||
@@ -833,7 +857,8 @@ describe('MCPOAuthProvider', () => {
|
||||
}),
|
||||
);
|
||||
|
||||
await MCPOAuthProvider.authenticate(
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await authProvider.authenticate(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
'https://auth.example.com',
|
||||
@@ -894,7 +919,8 @@ describe('MCPOAuthProvider', () => {
|
||||
authorizationUrl: 'https://auth.example.com/authorize?audience=1234',
|
||||
};
|
||||
|
||||
await MCPOAuthProvider.authenticate('test-server', configWithParamsInUrl);
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await authProvider.authenticate('test-server', configWithParamsInUrl);
|
||||
|
||||
const url = new URL(capturedUrl!);
|
||||
expect(url.searchParams.get('audience')).toBe('1234');
|
||||
@@ -947,7 +973,8 @@ describe('MCPOAuthProvider', () => {
|
||||
authorizationUrl: 'https://auth.example.com/authorize#login',
|
||||
};
|
||||
|
||||
await MCPOAuthProvider.authenticate('test-server', configWithFragment);
|
||||
const authProvider = new MCPOAuthProvider();
|
||||
await authProvider.authenticate('test-server', configWithFragment);
|
||||
|
||||
const url = new URL(capturedUrl!);
|
||||
expect(url.searchParams.get('client_id')).toBe('test-client-id');
|
||||
|
||||
Reference in New Issue
Block a user