diff --git a/packages/core/src/code_assist/oauth-credential-storage.test.ts b/packages/core/src/code_assist/oauth-credential-storage.test.ts index b1cb460368..3ef2de997c 100644 --- a/packages/core/src/code_assist/oauth-credential-storage.test.ts +++ b/packages/core/src/code_assist/oauth-credential-storage.test.ts @@ -242,6 +242,39 @@ describe('OAuthCredentialStorage', () => { ); }); + it('should merge existing refresh token when new payload lacks one', async () => { + const oldCredentials: OAuthCredentials = { + serverName: 'main-account', + token: { + accessToken: 'old-access-token', + refreshToken: 'persistent-refresh-token', + tokenType: 'Bearer', + expiresAt: Date.now() + 3600000, + scope: 'email', + }, + updatedAt: Date.now(), + }; + vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue( + oldCredentials, + ); + + const newTokens: Credentials = { + access_token: 'new-access-token', + expiry_date: Date.now() + 3600000, + }; + + await OAuthCredentialStorage.saveCredentials(newTokens); + + expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith( + expect.objectContaining({ + token: expect.objectContaining({ + accessToken: 'new-access-token', + refreshToken: 'persistent-refresh-token', // correctly merged + }), + }), + ); + }); + it('should throw an error if access_token is missing', async () => { const invalidCredentials: Credentials = { ...mockCredentials, diff --git a/packages/core/src/code_assist/oauth-credential-storage.ts b/packages/core/src/code_assist/oauth-credential-storage.ts index c7c0209cfa..c924031d0d 100644 --- a/packages/core/src/code_assist/oauth-credential-storage.ts +++ b/packages/core/src/code_assist/oauth-credential-storage.ts @@ -66,12 +66,16 @@ export class OAuthCredentialStorage { throw new Error('Attempted to save credentials without an access token.'); } + const existing = await this.storage.getCredentials(MAIN_ACCOUNT_KEY); + const mergedRefreshToken = + credentials.refresh_token || existing?.token.refreshToken; + // Convert Google Credentials to OAuthCredentials format const mcpCredentials: OAuthCredentials = { serverName: MAIN_ACCOUNT_KEY, token: { accessToken: credentials.access_token, - refreshToken: credentials.refresh_token || undefined, + refreshToken: mergedRefreshToken || undefined, tokenType: credentials.token_type || 'Bearer', scope: credentials.scope || undefined, expiresAt: credentials.expiry_date || undefined, diff --git a/packages/core/src/mcp/oauth-token-storage.test.ts b/packages/core/src/mcp/oauth-token-storage.test.ts index 2ccce0e7e2..943e6a15f9 100644 --- a/packages/core/src/mcp/oauth-token-storage.test.ts +++ b/packages/core/src/mcp/oauth-token-storage.test.ts @@ -192,6 +192,38 @@ describe('MCPOAuthTokenStorage', () => { expect(savedData[0].serverName).toBe('existing-server'); }); + it('should merge existing refresh token when new payload lacks one', async () => { + const existingCredentials: OAuthCredentials = { + ...mockCredentials, + serverName: 'existing-server', + token: { + ...mockToken, + refreshToken: 'old-refresh-token', + }, + }; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify([existingCredentials]), + ); + vi.mocked(fs.writeFile).mockResolvedValue(undefined); + + const newToken: OAuthToken = { + accessToken: 'new_access_token', + expiresAt: Date.now() + ONE_HR_MS, + tokenType: 'Bearer', + }; // missing refreshToken + + await tokenStorage.saveToken('existing-server', newToken); + + const writeCall = vi.mocked(fs.writeFile).mock.calls[0]; + const savedData = JSON.parse( + writeCall[1] as string, + ) as OAuthCredentials[]; + + expect(savedData).toHaveLength(1); + expect(savedData[0].token.accessToken).toBe('new_access_token'); + expect(savedData[0].token.refreshToken).toBe('old-refresh-token'); // successfully merged + }); + it('should handle write errors gracefully', async () => { vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); vi.mocked(fs.mkdir).mockResolvedValue(undefined); @@ -447,6 +479,55 @@ describe('MCPOAuthTokenStorage', () => { expect(fs.mkdir).toHaveBeenCalled(); }); + it('should merge existing refresh token when new payload lacks one in encrypted storage', async () => { + const serverName = 'server1'; + const now = Date.now(); + vi.spyOn(Date, 'now').mockReturnValue(now); + + const existingCredentials: OAuthCredentials = { + serverName, + token: { + ...mockToken, + refreshToken: 'old-refresh-token', + }, + updatedAt: now, + }; + + mockHybridTokenStorage.getCredentials.mockResolvedValue( + existingCredentials, + ); + + const newToken: OAuthToken = { + accessToken: 'new_access_token', + expiresAt: Date.now() + ONE_HR_MS, + tokenType: 'Bearer', + }; + + await tokenStorage.saveToken( + serverName, + newToken, + 'clientId', + 'tokenUrl', + 'mcpUrl', + ); + + const expectedCredential: OAuthCredentials = { + serverName, + token: { + ...newToken, + refreshToken: 'old-refresh-token', + }, + clientId: 'clientId', + tokenUrl: 'tokenUrl', + mcpServerUrl: 'mcpUrl', + updatedAt: now, + }; + + expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith( + expectedCredential, + ); + }); + it('should use HybridTokenStorage to get credentials', async () => { mockHybridTokenStorage.getCredentials.mockResolvedValue(mockCredentials); const result = await tokenStorage.getCredentials('server1'); diff --git a/packages/core/src/mcp/oauth-token-storage.ts b/packages/core/src/mcp/oauth-token-storage.ts index 3b27d756e9..cd6af992e4 100644 --- a/packages/core/src/mcp/oauth-token-storage.ts +++ b/packages/core/src/mcp/oauth-token-storage.ts @@ -143,9 +143,18 @@ export class MCPOAuthTokenStorage implements TokenStorage { ): Promise { await this.ensureConfigDir(); + const existing = await this.getCredentials(serverName); + const mergedRefreshToken = + token.refreshToken || existing?.token.refreshToken; + + const mergedToken = { + ...token, + refreshToken: mergedRefreshToken, + }; + const credential: OAuthCredentials = { serverName, - token, + token: mergedToken, clientId, tokenUrl, mcpServerUrl, diff --git a/packages/core/src/mcp/token-storage/keychain-token-storage.test.ts b/packages/core/src/mcp/token-storage/keychain-token-storage.test.ts index 2192abbc45..1a326a29cb 100644 --- a/packages/core/src/mcp/token-storage/keychain-token-storage.test.ts +++ b/packages/core/src/mcp/token-storage/keychain-token-storage.test.ts @@ -72,7 +72,7 @@ describe('KeychainTokenStorage', () => { expect(retrieved?.serverName).toBe('test-server'); }); - it('should return null if no credentials are found or they are expired', async () => { + it('should return null if no credentials are found or they are expired and unrefreshable', async () => { expect(await storage.getCredentials('missing')).toBeNull(); const expiredCreds = { @@ -81,6 +81,20 @@ describe('KeychainTokenStorage', () => { }; await storage.setCredentials(expiredCreds); expect(await storage.getCredentials('test-server')).toBeNull(); + + // Ensure that if it has a refresh token, it is NOT returned as null + const expiredWithRefresh = { + ...validCredentials, + token: { + ...validCredentials.token, + expiresAt: Date.now() - 1000, + refreshToken: 'some-refresh-token', + }, + }; + await storage.setCredentials(expiredWithRefresh); + const retrieved = await storage.getCredentials('test-server'); + expect(retrieved).not.toBeNull(); + expect(retrieved?.token.refreshToken).toBe('some-refresh-token'); }); it('should throw if stored data is corrupted JSON', async () => { diff --git a/packages/core/src/mcp/token-storage/keychain-token-storage.ts b/packages/core/src/mcp/token-storage/keychain-token-storage.ts index f649b0f1c0..36adb170ec 100644 --- a/packages/core/src/mcp/token-storage/keychain-token-storage.ts +++ b/packages/core/src/mcp/token-storage/keychain-token-storage.ts @@ -36,7 +36,7 @@ export class KeychainTokenStorage // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const credentials = JSON.parse(data) as OAuthCredentials; - if (this.isTokenExpired(credentials)) { + if (this.isTokenExpired(credentials) && !credentials.token.refreshToken) { return null; } @@ -104,7 +104,7 @@ export class KeychainTokenStorage try { // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const data = JSON.parse(cred.password) as OAuthCredentials; - if (!this.isTokenExpired(data)) { + if (!this.isTokenExpired(data) || data.token.refreshToken) { result.set(cred.account, data); } } catch (error) {