mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-14 05:42:54 -07:00
fix(core): preserve OAuth refresh tokens during rotation and retrieval (#26924)
This commit is contained in:
@@ -242,6 +242,39 @@ describe('OAuthCredentialStorage', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should merge existing refresh token when new payload lacks one', async () => {
|
||||
const oldCredentials: OAuthCredentials = {
|
||||
serverName: 'main-account',
|
||||
token: {
|
||||
accessToken: 'old-access-token',
|
||||
refreshToken: 'persistent-refresh-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() + 3600000,
|
||||
scope: 'email',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
vi.spyOn(mockHybridTokenStorage, 'getCredentials').mockResolvedValue(
|
||||
oldCredentials,
|
||||
);
|
||||
|
||||
const newTokens: Credentials = {
|
||||
access_token: 'new-access-token',
|
||||
expiry_date: Date.now() + 3600000,
|
||||
};
|
||||
|
||||
await OAuthCredentialStorage.saveCredentials(newTokens);
|
||||
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
token: expect.objectContaining({
|
||||
accessToken: 'new-access-token',
|
||||
refreshToken: 'persistent-refresh-token', // correctly merged
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if access_token is missing', async () => {
|
||||
const invalidCredentials: Credentials = {
|
||||
...mockCredentials,
|
||||
|
||||
@@ -66,12 +66,16 @@ export class OAuthCredentialStorage {
|
||||
throw new Error('Attempted to save credentials without an access token.');
|
||||
}
|
||||
|
||||
const existing = await this.storage.getCredentials(MAIN_ACCOUNT_KEY);
|
||||
const mergedRefreshToken =
|
||||
credentials.refresh_token || existing?.token.refreshToken;
|
||||
|
||||
// Convert Google Credentials to OAuthCredentials format
|
||||
const mcpCredentials: OAuthCredentials = {
|
||||
serverName: MAIN_ACCOUNT_KEY,
|
||||
token: {
|
||||
accessToken: credentials.access_token,
|
||||
refreshToken: credentials.refresh_token || undefined,
|
||||
refreshToken: mergedRefreshToken || undefined,
|
||||
tokenType: credentials.token_type || 'Bearer',
|
||||
scope: credentials.scope || undefined,
|
||||
expiresAt: credentials.expiry_date || undefined,
|
||||
|
||||
@@ -192,6 +192,38 @@ describe('MCPOAuthTokenStorage', () => {
|
||||
expect(savedData[0].serverName).toBe('existing-server');
|
||||
});
|
||||
|
||||
it('should merge existing refresh token when new payload lacks one', async () => {
|
||||
const existingCredentials: OAuthCredentials = {
|
||||
...mockCredentials,
|
||||
serverName: 'existing-server',
|
||||
token: {
|
||||
...mockToken,
|
||||
refreshToken: 'old-refresh-token',
|
||||
},
|
||||
};
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([existingCredentials]),
|
||||
);
|
||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||
|
||||
const newToken: OAuthToken = {
|
||||
accessToken: 'new_access_token',
|
||||
expiresAt: Date.now() + ONE_HR_MS,
|
||||
tokenType: 'Bearer',
|
||||
}; // missing refreshToken
|
||||
|
||||
await tokenStorage.saveToken('existing-server', newToken);
|
||||
|
||||
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
|
||||
const savedData = JSON.parse(
|
||||
writeCall[1] as string,
|
||||
) as OAuthCredentials[];
|
||||
|
||||
expect(savedData).toHaveLength(1);
|
||||
expect(savedData[0].token.accessToken).toBe('new_access_token');
|
||||
expect(savedData[0].token.refreshToken).toBe('old-refresh-token'); // successfully merged
|
||||
});
|
||||
|
||||
it('should handle write errors gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
|
||||
@@ -447,6 +479,55 @@ describe('MCPOAuthTokenStorage', () => {
|
||||
expect(fs.mkdir).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should merge existing refresh token when new payload lacks one in encrypted storage', async () => {
|
||||
const serverName = 'server1';
|
||||
const now = Date.now();
|
||||
vi.spyOn(Date, 'now').mockReturnValue(now);
|
||||
|
||||
const existingCredentials: OAuthCredentials = {
|
||||
serverName,
|
||||
token: {
|
||||
...mockToken,
|
||||
refreshToken: 'old-refresh-token',
|
||||
},
|
||||
updatedAt: now,
|
||||
};
|
||||
|
||||
mockHybridTokenStorage.getCredentials.mockResolvedValue(
|
||||
existingCredentials,
|
||||
);
|
||||
|
||||
const newToken: OAuthToken = {
|
||||
accessToken: 'new_access_token',
|
||||
expiresAt: Date.now() + ONE_HR_MS,
|
||||
tokenType: 'Bearer',
|
||||
};
|
||||
|
||||
await tokenStorage.saveToken(
|
||||
serverName,
|
||||
newToken,
|
||||
'clientId',
|
||||
'tokenUrl',
|
||||
'mcpUrl',
|
||||
);
|
||||
|
||||
const expectedCredential: OAuthCredentials = {
|
||||
serverName,
|
||||
token: {
|
||||
...newToken,
|
||||
refreshToken: 'old-refresh-token',
|
||||
},
|
||||
clientId: 'clientId',
|
||||
tokenUrl: 'tokenUrl',
|
||||
mcpServerUrl: 'mcpUrl',
|
||||
updatedAt: now,
|
||||
};
|
||||
|
||||
expect(mockHybridTokenStorage.setCredentials).toHaveBeenCalledWith(
|
||||
expectedCredential,
|
||||
);
|
||||
});
|
||||
|
||||
it('should use HybridTokenStorage to get credentials', async () => {
|
||||
mockHybridTokenStorage.getCredentials.mockResolvedValue(mockCredentials);
|
||||
const result = await tokenStorage.getCredentials('server1');
|
||||
|
||||
@@ -143,9 +143,18 @@ export class MCPOAuthTokenStorage implements TokenStorage {
|
||||
): Promise<void> {
|
||||
await this.ensureConfigDir();
|
||||
|
||||
const existing = await this.getCredentials(serverName);
|
||||
const mergedRefreshToken =
|
||||
token.refreshToken || existing?.token.refreshToken;
|
||||
|
||||
const mergedToken = {
|
||||
...token,
|
||||
refreshToken: mergedRefreshToken,
|
||||
};
|
||||
|
||||
const credential: OAuthCredentials = {
|
||||
serverName,
|
||||
token,
|
||||
token: mergedToken,
|
||||
clientId,
|
||||
tokenUrl,
|
||||
mcpServerUrl,
|
||||
|
||||
@@ -72,7 +72,7 @@ describe('KeychainTokenStorage', () => {
|
||||
expect(retrieved?.serverName).toBe('test-server');
|
||||
});
|
||||
|
||||
it('should return null if no credentials are found or they are expired', async () => {
|
||||
it('should return null if no credentials are found or they are expired and unrefreshable', async () => {
|
||||
expect(await storage.getCredentials('missing')).toBeNull();
|
||||
|
||||
const expiredCreds = {
|
||||
@@ -81,6 +81,20 @@ describe('KeychainTokenStorage', () => {
|
||||
};
|
||||
await storage.setCredentials(expiredCreds);
|
||||
expect(await storage.getCredentials('test-server')).toBeNull();
|
||||
|
||||
// Ensure that if it has a refresh token, it is NOT returned as null
|
||||
const expiredWithRefresh = {
|
||||
...validCredentials,
|
||||
token: {
|
||||
...validCredentials.token,
|
||||
expiresAt: Date.now() - 1000,
|
||||
refreshToken: 'some-refresh-token',
|
||||
},
|
||||
};
|
||||
await storage.setCredentials(expiredWithRefresh);
|
||||
const retrieved = await storage.getCredentials('test-server');
|
||||
expect(retrieved).not.toBeNull();
|
||||
expect(retrieved?.token.refreshToken).toBe('some-refresh-token');
|
||||
});
|
||||
|
||||
it('should throw if stored data is corrupted JSON', async () => {
|
||||
|
||||
@@ -36,7 +36,7 @@ export class KeychainTokenStorage
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const credentials = JSON.parse(data) as OAuthCredentials;
|
||||
|
||||
if (this.isTokenExpired(credentials)) {
|
||||
if (this.isTokenExpired(credentials) && !credentials.token.refreshToken) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ export class KeychainTokenStorage
|
||||
try {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const data = JSON.parse(cred.password) as OAuthCredentials;
|
||||
if (!this.isTokenExpired(data)) {
|
||||
if (!this.isTokenExpired(data) || data.token.refreshToken) {
|
||||
result.set(cred.account, data);
|
||||
}
|
||||
} catch (error) {
|
||||
|
||||
Reference in New Issue
Block a user