mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-15 06:12:50 -07:00
fix(core): refresh MCP OAuth token usage after re-auth (#26312)
Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
@@ -616,4 +616,74 @@ ${authUrl}
|
||||
|
||||
return null;
|
||||
}
|
||||
async getValidTokenWithMetadata(
|
||||
serverName: string,
|
||||
config: MCPOAuthConfig,
|
||||
): Promise<{
|
||||
accessToken: string;
|
||||
tokenType: string;
|
||||
expiresAt?: number;
|
||||
scope?: string;
|
||||
refreshToken?: string;
|
||||
} | null> {
|
||||
const credentials = await this.tokenStorage.getCredentials(serverName);
|
||||
if (!credentials) return null;
|
||||
|
||||
let current = credentials.token;
|
||||
|
||||
if (this.tokenStorage.isTokenExpired(current)) {
|
||||
const clientId = config.clientId ?? credentials.clientId;
|
||||
if (current.refreshToken && clientId && credentials.tokenUrl) {
|
||||
try {
|
||||
const newTokenResponse = await this.refreshAccessToken(
|
||||
config,
|
||||
current.refreshToken,
|
||||
credentials.tokenUrl,
|
||||
credentials.mcpServerUrl,
|
||||
);
|
||||
|
||||
const refreshed: OAuthToken = {
|
||||
accessToken: newTokenResponse.access_token,
|
||||
tokenType: newTokenResponse.token_type,
|
||||
refreshToken:
|
||||
newTokenResponse.refresh_token || current.refreshToken,
|
||||
scope: newTokenResponse.scope || current.scope,
|
||||
};
|
||||
|
||||
if (newTokenResponse.expires_in) {
|
||||
refreshed.expiresAt =
|
||||
Date.now() + newTokenResponse.expires_in * 1000;
|
||||
}
|
||||
|
||||
await this.tokenStorage.saveToken(
|
||||
serverName,
|
||||
refreshed,
|
||||
clientId,
|
||||
credentials.tokenUrl,
|
||||
credentials.mcpServerUrl,
|
||||
);
|
||||
|
||||
current = refreshed;
|
||||
} catch (error) {
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
'Failed to refresh auth token.',
|
||||
error,
|
||||
);
|
||||
await this.tokenStorage.deleteCredentials(serverName);
|
||||
return null;
|
||||
}
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
accessToken: current.accessToken,
|
||||
tokenType: current.tokenType || 'Bearer',
|
||||
expiresAt: current.expiresAt,
|
||||
scope: current.scope,
|
||||
refreshToken: current.refreshToken,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
import type {
|
||||
OAuthClientInformation,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
import { MCPOAuthProvider } from './oauth-provider.js';
|
||||
import { FIVE_MIN_BUFFER_MS } from './oauth-utils.js';
|
||||
|
||||
export class DynamicStoredOAuthProvider implements OAuthClientProvider {
|
||||
readonly redirectUrl = '';
|
||||
readonly clientMetadata: OAuthClientMetadata = {
|
||||
client_name: 'Gemini CLI (Stored OAuth)',
|
||||
redirect_uris: [],
|
||||
grant_types: [],
|
||||
response_types: [],
|
||||
token_endpoint_auth_method: 'none',
|
||||
};
|
||||
|
||||
private clientInfo?: OAuthClientInformation;
|
||||
private readonly oauthProvider = new MCPOAuthProvider();
|
||||
private cachedToken?: OAuthTokens;
|
||||
private tokenExpiryTime?: number;
|
||||
|
||||
constructor(
|
||||
private readonly serverName: string,
|
||||
private readonly serverConfig: MCPServerConfig,
|
||||
) {}
|
||||
|
||||
clientInformation(): OAuthClientInformation | undefined {
|
||||
return this.clientInfo;
|
||||
}
|
||||
|
||||
saveClientInformation(clientInformation: OAuthClientInformation): void {
|
||||
this.clientInfo = clientInformation;
|
||||
}
|
||||
|
||||
private isCachedTokenValid(): boolean {
|
||||
return !!(
|
||||
this.cachedToken?.access_token &&
|
||||
this.tokenExpiryTime &&
|
||||
Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS
|
||||
);
|
||||
}
|
||||
|
||||
async tokens(): Promise<OAuthTokens | undefined> {
|
||||
if (this.isCachedTokenValid()) {
|
||||
return this.cachedToken;
|
||||
}
|
||||
|
||||
const oauthConfig =
|
||||
this.serverConfig.oauth?.enabled && this.serverConfig.oauth
|
||||
? this.serverConfig.oauth
|
||||
: {};
|
||||
|
||||
const tokenMeta = await this.oauthProvider.getValidTokenWithMetadata(
|
||||
this.serverName,
|
||||
oauthConfig,
|
||||
);
|
||||
|
||||
if (!tokenMeta?.accessToken) {
|
||||
this.cachedToken = undefined;
|
||||
this.tokenExpiryTime = undefined;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const freshTokens: OAuthTokens = {
|
||||
access_token: tokenMeta.accessToken,
|
||||
token_type: tokenMeta.tokenType || 'Bearer',
|
||||
expires_in: tokenMeta.expiresAt
|
||||
? Math.max(0, Math.floor((tokenMeta.expiresAt - Date.now()) / 1000))
|
||||
: undefined,
|
||||
scope: tokenMeta.scope,
|
||||
refresh_token: tokenMeta.refreshToken,
|
||||
};
|
||||
|
||||
if (freshTokens.expires_in !== undefined) {
|
||||
this.cachedToken = freshTokens;
|
||||
this.tokenExpiryTime = Date.now() + freshTokens.expires_in * 1000;
|
||||
return this.cachedToken;
|
||||
}
|
||||
|
||||
this.cachedToken = undefined;
|
||||
this.tokenExpiryTime = undefined;
|
||||
return freshTokens;
|
||||
}
|
||||
|
||||
saveTokens(_tokens: OAuthTokens): void {}
|
||||
redirectToAuthorization(_authorizationUrl: URL): void {}
|
||||
saveCodeVerifier(_codeVerifier: string): void {}
|
||||
codeVerifier(): string {
|
||||
return '';
|
||||
}
|
||||
}
|
||||
@@ -1780,41 +1780,249 @@ describe('mcp-client', () => {
|
||||
|
||||
describe('createTransport', () => {
|
||||
describe('should connect via httpUrl', () => {
|
||||
it('without headers', async () => {
|
||||
it('uses MCP SDK authProvider token() path for oauth-enabled servers', async () => {
|
||||
const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({
|
||||
accessToken: 'fresh-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() + 10 * 60 * 1000,
|
||||
});
|
||||
vi.mocked(MCPOAuthProvider).mockReturnValue({
|
||||
getValidTokenWithMetadata: mockGetValidTokenWithMetadata,
|
||||
} as unknown as MCPOAuthProvider);
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage).mockReturnValue({
|
||||
getCredentials: vi.fn().mockResolvedValue({
|
||||
clientId: 'cid',
|
||||
token: {
|
||||
accessToken: 'fresh-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() + 10 * 60 * 1000,
|
||||
},
|
||||
}),
|
||||
} as unknown as MCPOAuthTokenStorage);
|
||||
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
httpUrl: 'http://test-server',
|
||||
url: 'http://test-server',
|
||||
type: 'http',
|
||||
oauth: { enabled: true },
|
||||
},
|
||||
false,
|
||||
MOCK_CONTEXT,
|
||||
);
|
||||
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: { headers: {} },
|
||||
});
|
||||
const testableTransport = transport as unknown as {
|
||||
_authProvider?: {
|
||||
tokens: () => Promise<{ access_token: string } | undefined>;
|
||||
};
|
||||
};
|
||||
|
||||
expect(testableTransport._authProvider).toBeDefined();
|
||||
const tokens = await testableTransport._authProvider!.tokens();
|
||||
expect(tokens?.access_token).toBe('fresh-token');
|
||||
});
|
||||
it('uses storage-backed expiry instead of long fallback cache for dynamic authProvider', async () => {
|
||||
const now = Date.now();
|
||||
const soonExpiry = now + 10 * 60 * 1000; // 10 minutes
|
||||
|
||||
const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({
|
||||
accessToken: 'fresh-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: soonExpiry,
|
||||
});
|
||||
const mockGetCredentials = vi.fn().mockImplementation(async () => ({
|
||||
clientId: 'cid',
|
||||
token: {
|
||||
accessToken: 'fresh-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: soonExpiry,
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mocked(MCPOAuthProvider).mockReturnValue({
|
||||
getValidTokenWithMetadata: mockGetValidTokenWithMetadata,
|
||||
} as unknown as MCPOAuthProvider);
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage).mockReturnValue({
|
||||
getCredentials: mockGetCredentials,
|
||||
} as unknown as MCPOAuthTokenStorage);
|
||||
|
||||
it('with headers', async () => {
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
httpUrl: 'http://test-server',
|
||||
headers: { Authorization: 'derp' },
|
||||
url: 'http://test-server',
|
||||
type: 'http',
|
||||
},
|
||||
false,
|
||||
MOCK_CONTEXT,
|
||||
);
|
||||
|
||||
expect(transport).toBeInstanceOf(StreamableHTTPClientTransport);
|
||||
expect(transport).toMatchObject({
|
||||
_url: new URL('http://test-server'),
|
||||
_requestInit: {
|
||||
headers: { Authorization: 'derp' },
|
||||
const testableTransport = transport as unknown as {
|
||||
_authProvider?: {
|
||||
tokens: () => Promise<
|
||||
{ access_token: string; expires_in?: number } | undefined
|
||||
>;
|
||||
};
|
||||
};
|
||||
|
||||
expect(testableTransport._authProvider).toBeDefined();
|
||||
|
||||
const tokens = await testableTransport._authProvider!.tokens();
|
||||
expect(tokens?.access_token).toBe('fresh-token');
|
||||
expect(tokens?.expires_in).toBeDefined();
|
||||
expect((tokens?.expires_in ?? 0) <= 10 * 60).toBe(true);
|
||||
|
||||
expect(mockGetValidTokenWithMetadata).toHaveBeenCalledTimes(1);
|
||||
expect(mockGetCredentials).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
it('uses dynamic authProvider when stored OAuth token exists', async () => {
|
||||
const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({
|
||||
accessToken: 'stored-fresh-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() + 10 * 60 * 1000,
|
||||
});
|
||||
vi.mocked(MCPOAuthProvider).mockReturnValue({
|
||||
getValidTokenWithMetadata: mockGetValidTokenWithMetadata,
|
||||
} as unknown as MCPOAuthProvider);
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage).mockReturnValue({
|
||||
getCredentials: vi.fn().mockResolvedValue({
|
||||
clientId: 'cid',
|
||||
token: {
|
||||
accessToken: 'stored-fresh-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() + 10 * 60 * 1000,
|
||||
},
|
||||
}),
|
||||
} as unknown as MCPOAuthTokenStorage);
|
||||
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
type: 'http',
|
||||
},
|
||||
false,
|
||||
MOCK_CONTEXT,
|
||||
);
|
||||
|
||||
const testableTransport = transport as unknown as {
|
||||
_authProvider?: {
|
||||
tokens: () => Promise<{ access_token: string } | undefined>;
|
||||
};
|
||||
};
|
||||
|
||||
expect(testableTransport._authProvider).toBeDefined();
|
||||
const tokens = await testableTransport._authProvider!.tokens();
|
||||
expect(tokens?.access_token).toBe('stored-fresh-token');
|
||||
});
|
||||
it('caches OAuth tokens in dynamic authProvider and avoids repeated lookups', async () => {
|
||||
const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({
|
||||
accessToken: 'cached-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() + 10 * 60 * 1000,
|
||||
});
|
||||
const mockGetCredentials = vi.fn().mockResolvedValue({
|
||||
clientId: 'cid',
|
||||
token: {
|
||||
accessToken: 'cached-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() + 10 * 60 * 1000,
|
||||
},
|
||||
});
|
||||
|
||||
vi.mocked(MCPOAuthProvider).mockReturnValue({
|
||||
getValidTokenWithMetadata: mockGetValidTokenWithMetadata,
|
||||
} as unknown as MCPOAuthProvider);
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage).mockReturnValue({
|
||||
getCredentials: mockGetCredentials,
|
||||
} as unknown as MCPOAuthTokenStorage);
|
||||
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
type: 'http',
|
||||
},
|
||||
false,
|
||||
MOCK_CONTEXT,
|
||||
);
|
||||
|
||||
const testableTransport = transport as unknown as {
|
||||
_authProvider?: {
|
||||
tokens: () => Promise<{ access_token: string } | undefined>;
|
||||
};
|
||||
};
|
||||
|
||||
expect(testableTransport._authProvider).toBeDefined();
|
||||
|
||||
const t1 = await testableTransport._authProvider!.tokens();
|
||||
const t2 = await testableTransport._authProvider!.tokens();
|
||||
|
||||
expect(t1?.access_token).toBe('cached-token');
|
||||
expect(t2?.access_token).toBe('cached-token');
|
||||
|
||||
// one call from createTransport fallback detection + one call in first tokens();
|
||||
// second tokens() should come from in-memory cache
|
||||
expect(mockGetCredentials).toHaveBeenCalledTimes(1);
|
||||
expect(mockGetValidTokenWithMetadata).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
it('does not long-cache token when metadata has no expiresAt', async () => {
|
||||
const mockGetValidTokenWithMetadata = vi.fn().mockResolvedValue({
|
||||
accessToken: 'no-exp-token',
|
||||
tokenType: 'Bearer',
|
||||
// expiresAt intentionally omitted
|
||||
});
|
||||
|
||||
const mockGetCredentials = vi.fn().mockResolvedValue({
|
||||
clientId: 'cid',
|
||||
token: {
|
||||
accessToken: 'no-exp-token',
|
||||
tokenType: 'Bearer',
|
||||
// expiresAt intentionally omitted
|
||||
},
|
||||
});
|
||||
|
||||
vi.mocked(MCPOAuthProvider).mockReturnValue({
|
||||
getValidTokenWithMetadata: mockGetValidTokenWithMetadata,
|
||||
} as unknown as MCPOAuthProvider);
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage).mockReturnValue({
|
||||
getCredentials: mockGetCredentials,
|
||||
} as unknown as MCPOAuthTokenStorage);
|
||||
|
||||
const transport = await createTransport(
|
||||
'test-server',
|
||||
{
|
||||
url: 'http://test-server',
|
||||
type: 'http',
|
||||
},
|
||||
false,
|
||||
MOCK_CONTEXT,
|
||||
);
|
||||
|
||||
const testableTransport = transport as unknown as {
|
||||
_authProvider?: {
|
||||
tokens: () => Promise<
|
||||
{ access_token: string; expires_in?: number } | undefined
|
||||
>;
|
||||
};
|
||||
};
|
||||
|
||||
expect(testableTransport._authProvider).toBeDefined();
|
||||
|
||||
const t1 = await testableTransport._authProvider!.tokens();
|
||||
const t2 = await testableTransport._authProvider!.tokens();
|
||||
|
||||
expect(t1?.access_token).toBe('no-exp-token');
|
||||
expect(t2?.access_token).toBe('no-exp-token');
|
||||
expect(t1?.expires_in).toBeUndefined();
|
||||
expect(t2?.expires_in).toBeUndefined();
|
||||
|
||||
// no-expiry tokens should not be long-cached in memory
|
||||
expect(mockGetValidTokenWithMetadata).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('wraps fetch to convert GET 404 to 405 for POST-only servers (e.g. n8n)', async () => {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
|
||||
import { AjvJsonSchemaValidator } from '@modelcontextprotocol/sdk/validation/ajv';
|
||||
import type {
|
||||
jsonSchemaValidator,
|
||||
@@ -54,9 +55,11 @@ import { basename } from 'node:path';
|
||||
import { pathToFileURL } from 'node:url';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import type { McpAuthProvider } from '../mcp/auth-provider.js';
|
||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { DynamicStoredOAuthProvider } from '../mcp/stored-token-provider.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import {
|
||||
getErrorMessage,
|
||||
@@ -82,6 +85,7 @@ import {
|
||||
type EnvironmentSanitizationConfig,
|
||||
} from '../services/environmentSanitization.js';
|
||||
import { expandEnvVars } from '../utils/envExpansion.js';
|
||||
|
||||
import {
|
||||
GEMINI_CLI_IDENTIFICATION_ENV_VAR,
|
||||
GEMINI_CLI_IDENTIFICATION_ENV_VAR_VALUE,
|
||||
@@ -1025,6 +1029,16 @@ function createAuthProvider(
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
/**
|
||||
* Creates an OAuth token provider for transports so token lookup/refresh happens
|
||||
* at request/auth time instead of freezing a single token at transport creation.
|
||||
*/
|
||||
function createDynamicOAuthTokenProvider(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
): McpAuthProvider {
|
||||
return new DynamicStoredOAuthProvider(mcpServerName, mcpServerConfig);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a transport with OAuth token for the given server configuration.
|
||||
@@ -2205,41 +2219,32 @@ export async function createTransport(
|
||||
}
|
||||
}
|
||||
if (mcpServerConfig.httpUrl || mcpServerConfig.url) {
|
||||
const authProvider = createAuthProvider(mcpServerConfig);
|
||||
let authProvider = createAuthProvider(mcpServerConfig);
|
||||
const headers: Record<string, string> =
|
||||
(await authProvider?.getRequestHeaders?.()) ?? {};
|
||||
|
||||
if (authProvider === undefined) {
|
||||
// Check if we have OAuth configuration or stored tokens
|
||||
let accessToken: string | null = null;
|
||||
if (mcpServerConfig.oauth?.enabled && mcpServerConfig.oauth) {
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const mcpAuthProvider = new MCPOAuthProvider(tokenStorage);
|
||||
accessToken = await mcpAuthProvider.getValidToken(
|
||||
mcpServerName,
|
||||
mcpServerConfig.oauth,
|
||||
);
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
const shouldUseDynamicOAuthProvider = !!credentials;
|
||||
|
||||
if (!accessToken) {
|
||||
// Emit info message (not error) since this is expected behavior
|
||||
cliConfig.emitMcpDiagnostic(
|
||||
'info',
|
||||
`MCP server '${mcpServerName}' requires authentication using: /mcp auth ${mcpServerName}`,
|
||||
undefined,
|
||||
mcpServerName,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
||||
accessToken = await getStoredOAuthToken(mcpServerName);
|
||||
if (accessToken) {
|
||||
debugLogger.log(
|
||||
`Found stored OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
if (!shouldUseDynamicOAuthProvider && mcpServerConfig.oauth?.enabled) {
|
||||
cliConfig.emitMcpDiagnostic(
|
||||
'info',
|
||||
`MCP server '${mcpServerName}' requires authentication using: /mcp auth ${mcpServerName}`,
|
||||
undefined,
|
||||
mcpServerName,
|
||||
);
|
||||
}
|
||||
if (accessToken) {
|
||||
headers['Authorization'] = `Bearer ${accessToken}`;
|
||||
|
||||
if (shouldUseDynamicOAuthProvider) {
|
||||
debugLogger.log(
|
||||
`Found stored OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
authProvider = createDynamicOAuthTokenProvider(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2339,7 +2344,6 @@ export async function createTransport(
|
||||
`Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`,
|
||||
);
|
||||
}
|
||||
|
||||
interface NamedTool {
|
||||
name?: string;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user