fix(core): preserve OAuth refresh tokens during rotation and retrieval (#26924)

This commit is contained in:
Coco Sheng
2026-05-13 13:19:05 -04:00
committed by GitHub
parent 749657cbf9
commit 297d3a3067
6 changed files with 146 additions and 5 deletions
@@ -242,6 +242,39 @@ describe('OAuthCredentialStorage', () => {
);
});
it('should merge existing refresh token when new payload lacks one', async () => {
const oldCredentials: OAuthCredentials = {
serverName: 'main-account',
token: {
accessToken: 'old-access-token',
refreshToken: 'persistent-refresh-token',
tokenType: 'Bearer',
expiresAt: Date.now() + 3600000,
scope: 'email',
},
updatedAt: Date.now(),
};
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
oldCredentials,
);
const newTokens: Credentials = {
access_token: 'new-access-token',
expiry_date: Date.now() + 3600000,
};
await OAuthCredentialStorage.saveCredentials(newTokens);
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
expect.objectContaining({
token: expect.objectContaining({
accessToken: 'new-access-token',
refreshToken: 'persistent-refresh-token', // correctly merged
}),
}),
);
});
it('should throw an error if access_token is missing', async () => {
const invalidCredentials: Credentials = {
...mockCredentials,
@@ -66,12 +66,16 @@ export class OAuthCredentialStorage {
throw new Error('Attempted to save credentials without an access token.');
}
const existing = await this.storage.getCredentials(MAIN_ACCOUNT_KEY);
const mergedRefreshToken =
credentials.refresh_token || existing?.token.refreshToken;
// Convert Google Credentials to OAuthCredentials format
const mcpCredentials: OAuthCredentials = {
serverName: MAIN_ACCOUNT_KEY,
token: {
accessToken: credentials.access_token,
refreshToken: credentials.refresh_token || undefined,
refreshToken: mergedRefreshToken || undefined,
tokenType: credentials.token_type || 'Bearer',
scope: credentials.scope || undefined,
expiresAt: credentials.expiry_date || undefined,
@@ -192,6 +192,38 @@ describe('MCPOAuthTokenStorage', () => {
expect(savedData[0].serverName).toBe('existing-server');
});
it('should merge existing refresh token when new payload lacks one', async () => {
const existingCredentials: OAuthCredentials = {
...mockCredentials,
serverName: 'existing-server',
token: {
...mockToken,
refreshToken: 'old-refresh-token',
},
};
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([existingCredentials]),
);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
const newToken: OAuthToken = {
accessToken: 'new_access_token',
expiresAt: Date.now() + ONE_HR_MS,
tokenType: 'Bearer',
}; // missing refreshToken
await tokenStorage.saveToken('existing-server', newToken);
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(
writeCall[1] as string,
) as OAuthCredentials[];
expect(savedData).toHaveLength(1);
expect(savedData[0].token.accessToken).toBe('new_access_token');
expect(savedData[0].token.refreshToken).toBe('old-refresh-token'); // successfully merged
});
it('should handle write errors gracefully', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
@@ -447,6 +479,55 @@ describe('MCPOAuthTokenStorage', () => {
expect(fs.mkdir).toHaveBeenCalled();
});
it('should merge existing refresh token when new payload lacks one in encrypted storage', async () => {
const serverName = 'server1';
const now = Date.now();
vi.spyOn(Date, 'now').mockReturnValue(now);
const existingCredentials: OAuthCredentials = {
serverName,
token: {
...mockToken,
refreshToken: 'old-refresh-token',
},
updatedAt: now,
};
mockHybridTokenStorage.getCredentials.mockResolvedValue(
existingCredentials,
);
const newToken: OAuthToken = {
accessToken: 'new_access_token',
expiresAt: Date.now() + ONE_HR_MS,
tokenType: 'Bearer',
};
await tokenStorage.saveToken(
serverName,
newToken,
'clientId',
'tokenUrl',
'mcpUrl',
);
const expectedCredential: OAuthCredentials = {
serverName,
token: {
...newToken,
refreshToken: 'old-refresh-token',
},
clientId: 'clientId',
tokenUrl: 'tokenUrl',
mcpServerUrl: 'mcpUrl',
updatedAt: now,
};
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
expectedCredential,
);
});
it('should use HybridTokenStorage to get credentials', async () => {
mockHybridTokenStorage.getCredentials.mockResolvedValue(mockCredentials);
const result = await tokenStorage.getCredentials('server1');
+10 -1
View File
@@ -143,9 +143,18 @@ export class MCPOAuthTokenStorage implements TokenStorage {
): Promise<void> {
await this.ensureConfigDir();
const existing = await this.getCredentials(serverName);
const mergedRefreshToken =
token.refreshToken || existing?.token.refreshToken;
const mergedToken = {
...token,
refreshToken: mergedRefreshToken,
};
const credential: OAuthCredentials = {
serverName,
token,
token: mergedToken,
clientId,
tokenUrl,
mcpServerUrl,
@@ -72,7 +72,7 @@ describe('KeychainTokenStorage', () => {
expect(retrieved?.serverName).toBe('test-server');
});
it('should return null if no credentials are found or they are expired', async () => {
it('should return null if no credentials are found or they are expired and unrefreshable', async () => {
expect(await storage.getCredentials('missing')).toBeNull();
const expiredCreds = {
@@ -81,6 +81,20 @@ describe('KeychainTokenStorage', () => {
};
await storage.setCredentials(expiredCreds);
expect(await storage.getCredentials('test-server')).toBeNull();
// Ensure that if it has a refresh token, it is NOT returned as null
const expiredWithRefresh = {
...validCredentials,
token: {
...validCredentials.token,
expiresAt: Date.now() - 1000,
refreshToken: 'some-refresh-token',
},
};
await storage.setCredentials(expiredWithRefresh);
const retrieved = await storage.getCredentials('test-server');
expect(retrieved).not.toBeNull();
expect(retrieved?.token.refreshToken).toBe('some-refresh-token');
});
it('should throw if stored data is corrupted JSON', async () => {
@@ -36,7 +36,7 @@ export class KeychainTokenStorage
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const credentials = JSON.parse(data) as OAuthCredentials;
if (this.isTokenExpired(credentials)) {
if (this.isTokenExpired(credentials) && !credentials.token.refreshToken) {
return null;
}
@@ -104,7 +104,7 @@ export class KeychainTokenStorage
try {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const data = JSON.parse(cred.password) as OAuthCredentials;
if (!this.isTokenExpired(data)) {
if (!this.isTokenExpired(data) || data.token.refreshToken) {
result.set(cred.account, data);
}
} catch (error) {