From fd01cc03bf593c2fb94e97939d8865fa625d3600 Mon Sep 17 00:00:00 2001 From: Sahil Kirad <167863755+sahilkirad@users.noreply.github.com> Date: Thu, 14 May 2026 00:31:27 +0530 Subject: [PATCH] fix(core): refresh MCP OAuth token usage after re-auth (#26312) Co-authored-by: Tommaso Sciortino --- integration-tests/file-system.test.ts | 2 +- packages/core/src/mcp/oauth-provider.ts | 70 ++++++ .../core/src/mcp/stored-token-provider.ts | 101 ++++++++ packages/core/src/tools/mcp-client.test.ts | 238 ++++++++++++++++-- packages/core/src/tools/mcp-client.ts | 66 ++--- 5 files changed, 430 insertions(+), 47 deletions(-) create mode 100644 packages/core/src/mcp/stored-token-provider.ts diff --git a/integration-tests/file-system.test.ts b/integration-tests/file-system.test.ts index aa50000ef6..6a733b9875 100644 --- a/integration-tests/file-system.test.ts +++ b/integration-tests/file-system.test.ts @@ -172,7 +172,7 @@ describe('file-system', () => { ).toBeDefined(); const newFileContent = rig.readFile(fileName); - expect(newFileContent).toBe('1.0.1'); + expect(newFileContent.trimEnd()).toBe('1.0.1'); }); it.skip('should replace multiple instances of a string', async () => { diff --git a/packages/core/src/mcp/oauth-provider.ts b/packages/core/src/mcp/oauth-provider.ts index 6aaafa6054..8fd44183af 100644 --- a/packages/core/src/mcp/oauth-provider.ts +++ b/packages/core/src/mcp/oauth-provider.ts @@ -616,4 +616,74 @@ ${authUrl} return null; } + async getValidTokenWithMetadata( + serverName: string, + config: MCPOAuthConfig, + ): Promise<{ + accessToken: string; + tokenType: string; + expiresAt?: number; + scope?: string; + refreshToken?: string; + } | null> { + const credentials = await this.tokenStorage.getCredentials(serverName); + if (!credentials) return null; + + let current = credentials.token; + + if (this.tokenStorage.isTokenExpired(current)) { + const clientId = config.clientId ?? credentials.clientId; + if (current.refreshToken && clientId && credentials.tokenUrl) { + try { + const newTokenResponse = await this.refreshAccessToken( + config, + current.refreshToken, + credentials.tokenUrl, + credentials.mcpServerUrl, + ); + + const refreshed: OAuthToken = { + accessToken: newTokenResponse.access_token, + tokenType: newTokenResponse.token_type, + refreshToken: + newTokenResponse.refresh_token || current.refreshToken, + scope: newTokenResponse.scope || current.scope, + }; + + if (newTokenResponse.expires_in) { + refreshed.expiresAt = + Date.now() + newTokenResponse.expires_in * 1000; + } + + await this.tokenStorage.saveToken( + serverName, + refreshed, + clientId, + credentials.tokenUrl, + credentials.mcpServerUrl, + ); + + current = refreshed; + } catch (error) { + coreEvents.emitFeedback( + 'error', + 'Failed to refresh auth token.', + error, + ); + await this.tokenStorage.deleteCredentials(serverName); + return null; + } + } else { + return null; + } + } + + return { + accessToken: current.accessToken, + tokenType: current.tokenType || 'Bearer', + expiresAt: current.expiresAt, + scope: current.scope, + refreshToken: current.refreshToken, + }; + } } diff --git a/packages/core/src/mcp/stored-token-provider.ts b/packages/core/src/mcp/stored-token-provider.ts new file mode 100644 index 0000000000..5c2bfc939f --- /dev/null +++ b/packages/core/src/mcp/stored-token-provider.ts @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'; +import type { + OAuthClientInformation, + OAuthClientMetadata, + OAuthTokens, +} from '@modelcontextprotocol/sdk/shared/auth.js'; +import type { MCPServerConfig } from '../config/config.js'; +import { MCPOAuthProvider } from './oauth-provider.js'; +import { FIVE_MIN_BUFFER_MS } from './oauth-utils.js'; + +export class DynamicStoredOAuthProvider implements OAuthClientProvider { + readonly redirectUrl = ''; + readonly clientMetadata: OAuthClientMetadata = { + client_name: 'Gemini CLI (Stored OAuth)', + redirect_uris: [], + grant_types: [], + response_types: [], + token_endpoint_auth_method: 'none', + }; + + private clientInfo?: OAuthClientInformation; + private readonly oauthProvider = new MCPOAuthProvider(); + private cachedToken?: OAuthTokens; + private tokenExpiryTime?: number; + + constructor( + private readonly serverName: string, + private readonly serverConfig: MCPServerConfig, + ) {} + + clientInformation(): OAuthClientInformation | undefined { + return this.clientInfo; + } + + saveClientInformation(clientInformation: OAuthClientInformation): void { + this.clientInfo = clientInformation; + } + + private isCachedTokenValid(): boolean { + return !!( + this.cachedToken?.access_token && + this.tokenExpiryTime && + Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS + ); + } + + async tokens(): Promise { + if (this.isCachedTokenValid()) { + return this.cachedToken; + } + + const oauthConfig = + this.serverConfig.oauth?.enabled && this.serverConfig.oauth + ? this.serverConfig.oauth + : {}; + + const tokenMeta = await this.oauthProvider.getValidTokenWithMetadata( + this.serverName, + oauthConfig, + ); + + if (!tokenMeta?.accessToken) { + this.cachedToken = undefined; + this.tokenExpiryTime = undefined; + return undefined; + } + + const freshTokens: OAuthTokens = { + access_token: tokenMeta.accessToken, + token_type: tokenMeta.tokenType || 'Bearer', + expires_in: tokenMeta.expiresAt + ? Math.max(0, Math.floor((tokenMeta.expiresAt - Date.now()) / 1000)) + : undefined, + scope: tokenMeta.scope, + refresh_token: tokenMeta.refreshToken, + }; + + if (freshTokens.expires_in !== undefined) { + this.cachedToken = freshTokens; + this.tokenExpiryTime = Date.now() + freshTokens.expires_in * 1000; + return this.cachedToken; + } + + this.cachedToken = undefined; + this.tokenExpiryTime = undefined; + return freshTokens; + } + + saveTokens(_tokens: OAuthTokens): void {} + redirectToAuthorization(_authorizationUrl: URL): void {} + saveCodeVerifier(_codeVerifier: string): void {} + codeVerifier(): string { + return ''; + } +} diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index d330a67fe0..be137f548e 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -1780,41 +1780,249 @@ describe('mcp-client', () => { describe('createTransport', () => { describe('should connect via httpUrl', () => { - it('without headers', async () => { + it('uses MCP SDK authProvider token() path for oauth-enabled servers', async () => { + const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({ + accessToken: 'fresh-token', + tokenType: 'Bearer', + expiresAt: Date.now() + 10 * 60 * 1000, + }); + vi.mocked(MCPOAuthProvider).mockReturnValue({ + getValidTokenWithMetadata: mockGetValidTokenWithMetadata, + } as unknown as MCPOAuthProvider); + + vi.mocked(MCPOAuthTokenStorage).mockReturnValue({ + getCredentials: vi.fn().mockResolvedValue({ + clientId: 'cid', + token: { + accessToken: 'fresh-token', + tokenType: 'Bearer', + expiresAt: Date.now() + 10 * 60 * 1000, + }, + }), + } as unknown as MCPOAuthTokenStorage); + const transport = await createTransport( 'test-server', { - httpUrl: 'http://test-server', + url: 'http://test-server', + type: 'http', + oauth: { enabled: true }, }, false, MOCK_CONTEXT, ); - expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); - expect(transport).toMatchObject({ - _url: new URL('http://test-server'), - _requestInit: { headers: {} }, - }); + const testableTransport = transport as unknown as { + _authProvider?: { + tokens: () => Promise<{ access_token: string } | undefined>; + }; + }; + + expect(testableTransport._authProvider).toBeDefined(); + const tokens = await testableTransport._authProvider!.tokens(); + expect(tokens?.access_token).toBe('fresh-token'); }); + it('uses storage-backed expiry instead of long fallback cache for dynamic authProvider', async () => { + const now = Date.now(); + const soonExpiry = now + 10 * 60 * 1000; // 10 minutes + + const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({ + accessToken: 'fresh-token', + tokenType: 'Bearer', + expiresAt: soonExpiry, + }); + const mockGetCredentials = vi.fn().mockImplementation(async () => ({ + clientId: 'cid', + token: { + accessToken: 'fresh-token', + tokenType: 'Bearer', + expiresAt: soonExpiry, + }, + })); + + vi.mocked(MCPOAuthProvider).mockReturnValue({ + getValidTokenWithMetadata: mockGetValidTokenWithMetadata, + } as unknown as MCPOAuthProvider); + + vi.mocked(MCPOAuthTokenStorage).mockReturnValue({ + getCredentials: mockGetCredentials, + } as unknown as MCPOAuthTokenStorage); - it('with headers', async () => { const transport = await createTransport( 'test-server', { - httpUrl: 'http://test-server', - headers: { Authorization: 'derp' }, + url: 'http://test-server', + type: 'http', }, false, MOCK_CONTEXT, ); - expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); - expect(transport).toMatchObject({ - _url: new URL('http://test-server'), - _requestInit: { - headers: { Authorization: 'derp' }, + const testableTransport = transport as unknown as { + _authProvider?: { + tokens: () => Promise< + { access_token: string; expires_in?: number } | undefined + >; + }; + }; + + expect(testableTransport._authProvider).toBeDefined(); + + const tokens = await testableTransport._authProvider!.tokens(); + expect(tokens?.access_token).toBe('fresh-token'); + expect(tokens?.expires_in).toBeDefined(); + expect((tokens?.expires_in ?? 0) <= 10 * 60).toBe(true); + + expect(mockGetValidTokenWithMetadata).toHaveBeenCalledTimes(1); + expect(mockGetCredentials).toHaveBeenCalledTimes(1); + }); + it('uses dynamic authProvider when stored OAuth token exists', async () => { + const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({ + accessToken: 'stored-fresh-token', + tokenType: 'Bearer', + expiresAt: Date.now() + 10 * 60 * 1000, + }); + vi.mocked(MCPOAuthProvider).mockReturnValue({ + getValidTokenWithMetadata: mockGetValidTokenWithMetadata, + } as unknown as MCPOAuthProvider); + + vi.mocked(MCPOAuthTokenStorage).mockReturnValue({ + getCredentials: vi.fn().mockResolvedValue({ + clientId: 'cid', + token: { + accessToken: 'stored-fresh-token', + tokenType: 'Bearer', + expiresAt: Date.now() + 10 * 60 * 1000, + }, + }), + } as unknown as MCPOAuthTokenStorage); + + const transport = await createTransport( + 'test-server', + { + url: 'http://test-server', + type: 'http', + }, + false, + MOCK_CONTEXT, + ); + + const testableTransport = transport as unknown as { + _authProvider?: { + tokens: () => Promise<{ access_token: string } | undefined>; + }; + }; + + expect(testableTransport._authProvider).toBeDefined(); + const tokens = await testableTransport._authProvider!.tokens(); + expect(tokens?.access_token).toBe('stored-fresh-token'); + }); + it('caches OAuth tokens in dynamic authProvider and avoids repeated lookups', async () => { + const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({ + accessToken: 'cached-token', + tokenType: 'Bearer', + expiresAt: Date.now() + 10 * 60 * 1000, + }); + const mockGetCredentials = vi.fn().mockResolvedValue({ + clientId: 'cid', + token: { + accessToken: 'cached-token', + tokenType: 'Bearer', + expiresAt: Date.now() + 10 * 60 * 1000, }, }); + + vi.mocked(MCPOAuthProvider).mockReturnValue({ + getValidTokenWithMetadata: mockGetValidTokenWithMetadata, + } as unknown as MCPOAuthProvider); + + vi.mocked(MCPOAuthTokenStorage).mockReturnValue({ + getCredentials: mockGetCredentials, + } as unknown as MCPOAuthTokenStorage); + + const transport = await createTransport( + 'test-server', + { + url: 'http://test-server', + type: 'http', + }, + false, + MOCK_CONTEXT, + ); + + const testableTransport = transport as unknown as { + _authProvider?: { + tokens: () => Promise<{ access_token: string } | undefined>; + }; + }; + + expect(testableTransport._authProvider).toBeDefined(); + + const t1 = await testableTransport._authProvider!.tokens(); + const t2 = await testableTransport._authProvider!.tokens(); + + expect(t1?.access_token).toBe('cached-token'); + expect(t2?.access_token).toBe('cached-token'); + + // one call from createTransport fallback detection + one call in first tokens(); + // second tokens() should come from in-memory cache + expect(mockGetCredentials).toHaveBeenCalledTimes(1); + expect(mockGetValidTokenWithMetadata).toHaveBeenCalledTimes(1); + }); + it('does not long-cache token when metadata has no expiresAt', async () => { + const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({ + accessToken: 'no-exp-token', + tokenType: 'Bearer', + // expiresAt intentionally omitted + }); + + const mockGetCredentials = vi.fn().mockResolvedValue({ + clientId: 'cid', + token: { + accessToken: 'no-exp-token', + tokenType: 'Bearer', + // expiresAt intentionally omitted + }, + }); + + vi.mocked(MCPOAuthProvider).mockReturnValue({ + getValidTokenWithMetadata: mockGetValidTokenWithMetadata, + } as unknown as MCPOAuthProvider); + + vi.mocked(MCPOAuthTokenStorage).mockReturnValue({ + getCredentials: mockGetCredentials, + } as unknown as MCPOAuthTokenStorage); + + const transport = await createTransport( + 'test-server', + { + url: 'http://test-server', + type: 'http', + }, + false, + MOCK_CONTEXT, + ); + + const testableTransport = transport as unknown as { + _authProvider?: { + tokens: () => Promise< + { access_token: string; expires_in?: number } | undefined + >; + }; + }; + + expect(testableTransport._authProvider).toBeDefined(); + + const t1 = await testableTransport._authProvider!.tokens(); + const t2 = await testableTransport._authProvider!.tokens(); + + expect(t1?.access_token).toBe('no-exp-token'); + expect(t2?.access_token).toBe('no-exp-token'); + expect(t1?.expires_in).toBeUndefined(); + expect(t2?.expires_in).toBeUndefined(); + + // no-expiry tokens should not be long-cached in memory + expect(mockGetValidTokenWithMetadata).toHaveBeenCalledTimes(2); }); it('wraps fetch to convert GET 404 to 405 for POST-only servers (e.g. n8n)', async () => { diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 439e24fb71..61e2db17a8 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -5,6 +5,7 @@ */ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; + import { AjvJsonSchemaValidator } from '@modelcontextprotocol/sdk/validation/ajv'; import type { jsonSchemaValidator, @@ -54,9 +55,11 @@ import { basename } from 'node:path'; import { pathToFileURL } from 'node:url'; import { randomUUID } from 'node:crypto'; import type { McpAuthProvider } from '../mcp/auth-provider.js'; -import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; +import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; +import { DynamicStoredOAuthProvider } from '../mcp/stored-token-provider.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; + import type { PromptRegistry } from '../prompts/prompt-registry.js'; import { getErrorMessage, @@ -82,6 +85,7 @@ import { type EnvironmentSanitizationConfig, } from '../services/environmentSanitization.js'; import { expandEnvVars } from '../utils/envExpansion.js'; + import { GEMINI_CLI_IDENTIFICATION_ENV_VAR, GEMINI_CLI_IDENTIFICATION_ENV_VAR_VALUE, @@ -1025,6 +1029,16 @@ function createAuthProvider( } return undefined; } +/** + * Creates an OAuth token provider for transports so token lookup/refresh happens + * at request/auth time instead of freezing a single token at transport creation. + */ +function createDynamicOAuthTokenProvider( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, +): McpAuthProvider { + return new DynamicStoredOAuthProvider(mcpServerName, mcpServerConfig); +} /** * Create a transport with OAuth token for the given server configuration. @@ -2205,41 +2219,32 @@ export async function createTransport( } } if (mcpServerConfig.httpUrl || mcpServerConfig.url) { - const authProvider = createAuthProvider(mcpServerConfig); + let authProvider = createAuthProvider(mcpServerConfig); const headers: Record = (await authProvider?.getRequestHeaders?.()) ?? {}; if (authProvider === undefined) { - // Check if we have OAuth configuration or stored tokens - let accessToken: string | null = null; - if (mcpServerConfig.oauth?.enabled && mcpServerConfig.oauth) { - const tokenStorage = new MCPOAuthTokenStorage(); - const mcpAuthProvider = new MCPOAuthProvider(tokenStorage); - accessToken = await mcpAuthProvider.getValidToken( - mcpServerName, - mcpServerConfig.oauth, - ); + const tokenStorage = new MCPOAuthTokenStorage(); + const credentials = await tokenStorage.getCredentials(mcpServerName); + const shouldUseDynamicOAuthProvider = !!credentials; - if (!accessToken) { - // Emit info message (not error) since this is expected behavior - cliConfig.emitMcpDiagnostic( - 'info', - `MCP server '${mcpServerName}' requires authentication using: /mcp auth ${mcpServerName}`, - undefined, - mcpServerName, - ); - } - } else { - // Check if we have stored OAuth tokens for this server (from previous authentication) - accessToken = await getStoredOAuthToken(mcpServerName); - if (accessToken) { - debugLogger.log( - `Found stored OAuth token for server '${mcpServerName}'`, - ); - } + if (!shouldUseDynamicOAuthProvider && mcpServerConfig.oauth?.enabled) { + cliConfig.emitMcpDiagnostic( + 'info', + `MCP server '${mcpServerName}' requires authentication using: /mcp auth ${mcpServerName}`, + undefined, + mcpServerName, + ); } - if (accessToken) { - headers['Authorization'] = `Bearer ${accessToken}`; + + if (shouldUseDynamicOAuthProvider) { + debugLogger.log( + `Found stored OAuth token for server '${mcpServerName}'`, + ); + authProvider = createDynamicOAuthTokenProvider( + mcpServerName, + mcpServerConfig, + ); } } @@ -2339,7 +2344,6 @@ export async function createTransport( `Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`, ); } - interface NamedTool { name?: string; }