fix(core): refresh MCP OAuth token usage after re-auth (#26312)

Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
Sahil Kirad
2026-05-14 00:31:27 +05:30
committed by GitHub
parent fc4054446f
commit fd01cc03bf
5 changed files with 430 additions and 47 deletions
+1 -1
View File
@@ -172,7 +172,7 @@ describe('file-system', () => {
).toBeDefined(); ).toBeDefined();
const newFileContent = rig.readFile(fileName); const newFileContent = rig.readFile(fileName);
expect(newFileContent).toBe('1.0.1'); expect(newFileContent.trimEnd()).toBe('1.0.1');
}); });
it.skip('should replace multiple instances of a string', async () => { it.skip('should replace multiple instances of a string', async () => {
+70
View File
@@ -616,4 +616,74 @@ ${authUrl}
return null; return null;
} }
async getValidTokenWithMetadata(
serverName: string,
config: MCPOAuthConfig,
): Promise<{
accessToken: string;
tokenType: string;
expiresAt?: number;
scope?: string;
refreshToken?: string;
} | null> {
const credentials = await this.tokenStorage.getCredentials(serverName);
if (!credentials) return null;
let current = credentials.token;
if (this.tokenStorage.isTokenExpired(current)) {
const clientId = config.clientId ?? credentials.clientId;
if (current.refreshToken && clientId && credentials.tokenUrl) {
try {
const newTokenResponse = await this.refreshAccessToken(
config,
current.refreshToken,
credentials.tokenUrl,
credentials.mcpServerUrl,
);
const refreshed: OAuthToken = {
accessToken: newTokenResponse.access_token,
tokenType: newTokenResponse.token_type,
refreshToken:
newTokenResponse.refresh_token || current.refreshToken,
scope: newTokenResponse.scope || current.scope,
};
if (newTokenResponse.expires_in) {
refreshed.expiresAt =
Date.now() + newTokenResponse.expires_in * 1000;
}
await this.tokenStorage.saveToken(
serverName,
refreshed,
clientId,
credentials.tokenUrl,
credentials.mcpServerUrl,
);
current = refreshed;
} catch (error) {
coreEvents.emitFeedback(
'error',
'Failed to refresh auth token.',
error,
);
await this.tokenStorage.deleteCredentials(serverName);
return null;
}
} else {
return null;
}
}
return {
accessToken: current.accessToken,
tokenType: current.tokenType || 'Bearer',
expiresAt: current.expiresAt,
scope: current.scope,
refreshToken: current.refreshToken,
};
}
} }
@@ -0,0 +1,101 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
import type {
OAuthClientInformation,
OAuthClientMetadata,
OAuthTokens,
} from '@modelcontextprotocol/sdk/shared/auth.js';
import type { MCPServerConfig } from '../config/config.js';
import { MCPOAuthProvider } from './oauth-provider.js';
import { FIVE_MIN_BUFFER_MS } from './oauth-utils.js';
export class DynamicStoredOAuthProvider implements OAuthClientProvider {
readonly redirectUrl = '';
readonly clientMetadata: OAuthClientMetadata = {
client_name: 'Gemini CLI (Stored OAuth)',
redirect_uris: [],
grant_types: [],
response_types: [],
token_endpoint_auth_method: 'none',
};
private clientInfo?: OAuthClientInformation;
private readonly oauthProvider = new MCPOAuthProvider();
private cachedToken?: OAuthTokens;
private tokenExpiryTime?: number;
constructor(
private readonly serverName: string,
private readonly serverConfig: MCPServerConfig,
) {}
clientInformation(): OAuthClientInformation | undefined {
return this.clientInfo;
}
saveClientInformation(clientInformation: OAuthClientInformation): void {
this.clientInfo = clientInformation;
}
private isCachedTokenValid(): boolean {
return !!(
this.cachedToken?.access_token &&
this.tokenExpiryTime &&
Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS
);
}
async tokens(): Promise<OAuthTokens | undefined> {
if (this.isCachedTokenValid()) {
return this.cachedToken;
}
const oauthConfig =
this.serverConfig.oauth?.enabled && this.serverConfig.oauth
? this.serverConfig.oauth
: {};
const tokenMeta = await this.oauthProvider.getValidTokenWithMetadata(
this.serverName,
oauthConfig,
);
if (!tokenMeta?.accessToken) {
this.cachedToken = undefined;
this.tokenExpiryTime = undefined;
return undefined;
}
const freshTokens: OAuthTokens = {
access_token: tokenMeta.accessToken,
token_type: tokenMeta.tokenType || 'Bearer',
expires_in: tokenMeta.expiresAt
? Math.max(0, Math.floor((tokenMeta.expiresAt - Date.now()) / 1000))
: undefined,
scope: tokenMeta.scope,
refresh_token: tokenMeta.refreshToken,
};
if (freshTokens.expires_in !== undefined) {
this.cachedToken = freshTokens;
this.tokenExpiryTime = Date.now() + freshTokens.expires_in * 1000;
return this.cachedToken;
}
this.cachedToken = undefined;
this.tokenExpiryTime = undefined;
return freshTokens;
}
saveTokens(_tokens: OAuthTokens): void {}
redirectToAuthorization(_authorizationUrl: URL): void {}
saveCodeVerifier(_codeVerifier: string): void {}
codeVerifier(): string {
return '';
}
}
+223 -15
View File
@@ -1780,41 +1780,249 @@ describe('mcp-client', () => {
describe('createTransport', () => { describe('createTransport', () => {
describe('should connect via httpUrl', () => { describe('should connect via httpUrl', () => {
it('without headers', async () => { it('uses MCP SDK authProvider token() path for oauth-enabled servers', async () => {
const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({
accessToken: 'fresh-token',
tokenType: 'Bearer',
expiresAt: Date.now() + 10 * 60 * 1000,
});
vi.mocked(MCPOAuthProvider).mockReturnValue({
getValidTokenWithMetadata: mockGetValidTokenWithMetadata,
} as unknown as MCPOAuthProvider);
vi.mocked(MCPOAuthTokenStorage).mockReturnValue({
getCredentials: vi.fn().mockResolvedValue({
clientId: 'cid',
token: {
accessToken: 'fresh-token',
tokenType: 'Bearer',
expiresAt: Date.now() + 10 * 60 * 1000,
},
}),
} as unknown as MCPOAuthTokenStorage);
const transport = await createTransport( const transport = await createTransport(
'test-server', 'test-server',
{ {
httpUrl: 'http://test-server', url: 'http://test-server',
type: 'http',
oauth: { enabled: true },
}, },
false, false,
MOCK_CONTEXT, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); const testableTransport = transport as unknown as {
expect(transport).toMatchObject({ _authProvider?: {
_url: new URL('http://test-server'), tokens: () => Promise<{ access_token: string } | undefined>;
_requestInit: { headers: {} }, };
}); };
expect(testableTransport._authProvider).toBeDefined();
const tokens = await testableTransport._authProvider!.tokens();
expect(tokens?.access_token).toBe('fresh-token');
}); });
it('uses storage-backed expiry instead of long fallback cache for dynamic authProvider', async () => {
const now = Date.now();
const soonExpiry = now + 10 * 60 * 1000; // 10 minutes
const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({
accessToken: 'fresh-token',
tokenType: 'Bearer',
expiresAt: soonExpiry,
});
const mockGetCredentials = vi.fn().mockImplementation(async () => ({
clientId: 'cid',
token: {
accessToken: 'fresh-token',
tokenType: 'Bearer',
expiresAt: soonExpiry,
},
}));
vi.mocked(MCPOAuthProvider).mockReturnValue({
getValidTokenWithMetadata: mockGetValidTokenWithMetadata,
} as unknown as MCPOAuthProvider);
vi.mocked(MCPOAuthTokenStorage).mockReturnValue({
getCredentials: mockGetCredentials,
} as unknown as MCPOAuthTokenStorage);
it('with headers', async () => {
const transport = await createTransport( const transport = await createTransport(
'test-server', 'test-server',
{ {
httpUrl: 'http://test-server', url: 'http://test-server',
headers: { Authorization: 'derp' }, type: 'http',
}, },
false, false,
MOCK_CONTEXT, MOCK_CONTEXT,
); );
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); const testableTransport = transport as unknown as {
expect(transport).toMatchObject({ _authProvider?: {
_url: new URL('http://test-server'), tokens: () => Promise<
_requestInit: { { access_token: string; expires_in?: number } | undefined
headers: { Authorization: 'derp' }, >;
};
};
expect(testableTransport._authProvider).toBeDefined();
const tokens = await testableTransport._authProvider!.tokens();
expect(tokens?.access_token).toBe('fresh-token');
expect(tokens?.expires_in).toBeDefined();
expect((tokens?.expires_in ?? 0) <= 10 * 60).toBe(true);
expect(mockGetValidTokenWithMetadata).toHaveBeenCalledTimes(1);
expect(mockGetCredentials).toHaveBeenCalledTimes(1);
});
it('uses dynamic authProvider when stored OAuth token exists', async () => {
const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({
accessToken: 'stored-fresh-token',
tokenType: 'Bearer',
expiresAt: Date.now() + 10 * 60 * 1000,
});
vi.mocked(MCPOAuthProvider).mockReturnValue({
getValidTokenWithMetadata: mockGetValidTokenWithMetadata,
} as unknown as MCPOAuthProvider);
vi.mocked(MCPOAuthTokenStorage).mockReturnValue({
getCredentials: vi.fn().mockResolvedValue({
clientId: 'cid',
token: {
accessToken: 'stored-fresh-token',
tokenType: 'Bearer',
expiresAt: Date.now() + 10 * 60 * 1000,
},
}),
} as unknown as MCPOAuthTokenStorage);
const transport = await createTransport(
'test-server',
{
url: 'http://test-server',
type: 'http',
},
false,
MOCK_CONTEXT,
);
const testableTransport = transport as unknown as {
_authProvider?: {
tokens: () => Promise<{ access_token: string } | undefined>;
};
};
expect(testableTransport._authProvider).toBeDefined();
const tokens = await testableTransport._authProvider!.tokens();
expect(tokens?.access_token).toBe('stored-fresh-token');
});
it('caches OAuth tokens in dynamic authProvider and avoids repeated lookups', async () => {
const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({
accessToken: 'cached-token',
tokenType: 'Bearer',
expiresAt: Date.now() + 10 * 60 * 1000,
});
const mockGetCredentials = vi.fn().mockResolvedValue({
clientId: 'cid',
token: {
accessToken: 'cached-token',
tokenType: 'Bearer',
expiresAt: Date.now() + 10 * 60 * 1000,
}, },
}); });
vi.mocked(MCPOAuthProvider).mockReturnValue({
getValidTokenWithMetadata: mockGetValidTokenWithMetadata,
} as unknown as MCPOAuthProvider);
vi.mocked(MCPOAuthTokenStorage).mockReturnValue({
getCredentials: mockGetCredentials,
} as unknown as MCPOAuthTokenStorage);
const transport = await createTransport(
'test-server',
{
url: 'http://test-server',
type: 'http',
},
false,
MOCK_CONTEXT,
);
const testableTransport = transport as unknown as {
_authProvider?: {
tokens: () => Promise<{ access_token: string } | undefined>;
};
};
expect(testableTransport._authProvider).toBeDefined();
const t1 = await testableTransport._authProvider!.tokens();
const t2 = await testableTransport._authProvider!.tokens();
expect(t1?.access_token).toBe('cached-token');
expect(t2?.access_token).toBe('cached-token');
// one call from createTransport fallback detection + one call in first tokens();
// second tokens() should come from in-memory cache
expect(mockGetCredentials).toHaveBeenCalledTimes(1);
expect(mockGetValidTokenWithMetadata).toHaveBeenCalledTimes(1);
});
it('does not long-cache token when metadata has no expiresAt', async () => {
const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({
accessToken: 'no-exp-token',
tokenType: 'Bearer',
// expiresAt intentionally omitted
});
const mockGetCredentials = vi.fn().mockResolvedValue({
clientId: 'cid',
token: {
accessToken: 'no-exp-token',
tokenType: 'Bearer',
// expiresAt intentionally omitted
},
});
vi.mocked(MCPOAuthProvider).mockReturnValue({
getValidTokenWithMetadata: mockGetValidTokenWithMetadata,
} as unknown as MCPOAuthProvider);
vi.mocked(MCPOAuthTokenStorage).mockReturnValue({
getCredentials: mockGetCredentials,
} as unknown as MCPOAuthTokenStorage);
const transport = await createTransport(
'test-server',
{
url: 'http://test-server',
type: 'http',
},
false,
MOCK_CONTEXT,
);
const testableTransport = transport as unknown as {
_authProvider?: {
tokens: () => Promise<
{ access_token: string; expires_in?: number } | undefined
>;
};
};
expect(testableTransport._authProvider).toBeDefined();
const t1 = await testableTransport._authProvider!.tokens();
const t2 = await testableTransport._authProvider!.tokens();
expect(t1?.access_token).toBe('no-exp-token');
expect(t2?.access_token).toBe('no-exp-token');
expect(t1?.expires_in).toBeUndefined();
expect(t2?.expires_in).toBeUndefined();
// no-expiry tokens should not be long-cached in memory
expect(mockGetValidTokenWithMetadata).toHaveBeenCalledTimes(2);
}); });
it('wraps fetch to convert GET 404 to 405 for POST-only servers (e.g. n8n)', async () => { it('wraps fetch to convert GET 404 to 405 for POST-only servers (e.g. n8n)', async () => {
+35 -31
View File
@@ -5,6 +5,7 @@
*/ */
import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { AjvJsonSchemaValidator } from '@modelcontextprotocol/sdk/validation/ajv'; import { AjvJsonSchemaValidator } from '@modelcontextprotocol/sdk/validation/ajv';
import type { import type {
jsonSchemaValidator, jsonSchemaValidator,
@@ -54,9 +55,11 @@ import { basename } from 'node:path';
import { pathToFileURL } from 'node:url'; import { pathToFileURL } from 'node:url';
import { randomUUID } from 'node:crypto'; import { randomUUID } from 'node:crypto';
import type { McpAuthProvider } from '../mcp/auth-provider.js'; import type { McpAuthProvider } from '../mcp/auth-provider.js';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { DynamicStoredOAuthProvider } from '../mcp/stored-token-provider.js';
import { OAuthUtils } from '../mcp/oauth-utils.js'; import { OAuthUtils } from '../mcp/oauth-utils.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js';
import { import {
getErrorMessage, getErrorMessage,
@@ -82,6 +85,7 @@ import {
type EnvironmentSanitizationConfig, type EnvironmentSanitizationConfig,
} from '../services/environmentSanitization.js'; } from '../services/environmentSanitization.js';
import { expandEnvVars } from '../utils/envExpansion.js'; import { expandEnvVars } from '../utils/envExpansion.js';
import { import {
GEMINI_CLI_IDENTIFICATION_ENV_VAR, GEMINI_CLI_IDENTIFICATION_ENV_VAR,
GEMINI_CLI_IDENTIFICATION_ENV_VAR_VALUE, GEMINI_CLI_IDENTIFICATION_ENV_VAR_VALUE,
@@ -1025,6 +1029,16 @@ function createAuthProvider(
} }
return undefined; return undefined;
} }
/**
* Creates an OAuth token provider for transports so token lookup/refresh happens
* at request/auth time instead of freezing a single token at transport creation.
*/
function createDynamicOAuthTokenProvider(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
): McpAuthProvider {
return new DynamicStoredOAuthProvider(mcpServerName, mcpServerConfig);
}
/** /**
* Create a transport with OAuth token for the given server configuration. * Create a transport with OAuth token for the given server configuration.
@@ -2205,41 +2219,32 @@ export async function createTransport(
} }
} }
if (mcpServerConfig.httpUrl || mcpServerConfig.url) { if (mcpServerConfig.httpUrl || mcpServerConfig.url) {
const authProvider = createAuthProvider(mcpServerConfig); let authProvider = createAuthProvider(mcpServerConfig);
const headers: Record<string, string> = const headers: Record<string, string> =
(await authProvider?.getRequestHeaders?.()) ?? {}; (await authProvider?.getRequestHeaders?.()) ?? {};
if (authProvider === undefined) { if (authProvider === undefined) {
// Check if we have OAuth configuration or stored tokens const tokenStorage = new MCPOAuthTokenStorage();
let accessToken: string | null = null; const credentials = await tokenStorage.getCredentials(mcpServerName);
if (mcpServerConfig.oauth?.enabled && mcpServerConfig.oauth) { const shouldUseDynamicOAuthProvider = !!credentials;
const tokenStorage = new MCPOAuthTokenStorage();
const mcpAuthProvider = new MCPOAuthProvider(tokenStorage);
accessToken = await mcpAuthProvider.getValidToken(
mcpServerName,
mcpServerConfig.oauth,
);
if (!accessToken) { if (!shouldUseDynamicOAuthProvider && mcpServerConfig.oauth?.enabled) {
// Emit info message (not error) since this is expected behavior cliConfig.emitMcpDiagnostic(
cliConfig.emitMcpDiagnostic( 'info',
'info', `MCP server '${mcpServerName}' requires authentication using: /mcp auth ${mcpServerName}`,
`MCP server '${mcpServerName}' requires authentication using: /mcp auth ${mcpServerName}`, undefined,
undefined, mcpServerName,
mcpServerName, );
);
}
} else {
// Check if we have stored OAuth tokens for this server (from previous authentication)
accessToken = await getStoredOAuthToken(mcpServerName);
if (accessToken) {
debugLogger.log(
`Found stored OAuth token for server '${mcpServerName}'`,
);
}
} }
if (accessToken) {
headers['Authorization'] = `Bearer ${accessToken}`; if (shouldUseDynamicOAuthProvider) {
debugLogger.log(
`Found stored OAuth token for server '${mcpServerName}'`,
);
authProvider = createDynamicOAuthTokenProvider(
mcpServerName,
mcpServerConfig,
);
} }
} }
@@ -2339,7 +2344,6 @@ export async function createTransport(
`Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`, `Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`,
); );
} }
interface NamedTool { interface NamedTool {
name?: string; name?: string;
} }