mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-13 15:40:57 -07:00
Refactor createTransport to duplicate less code (#13010)
Co-authored-by: David McWherter <davidmcw@gmail.com>
This commit is contained in:
@@ -487,7 +487,9 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
expect(transport).toEqual(
|
||||
new StreamableHTTPClientTransport(new URL('http://test-server'), {}),
|
||||
new StreamableHTTPClientTransport(new URL('http://test-server'), {
|
||||
requestInit: { headers: {} },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -521,7 +523,9 @@ describe('mcp-client', () => {
|
||||
false,
|
||||
);
|
||||
expect(transport).toEqual(
|
||||
new SSEClientTransport(new URL('http://test-server'), {}),
|
||||
new SSEClientTransport(new URL('http://test-server'), {
|
||||
requestInit: { headers: {} },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -410,6 +410,53 @@ function createTransportRequestInit(
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an AuthProvider for the MCP Transport.
|
||||
*
|
||||
* @param mcpServerConfig The MCP server configuration
|
||||
*/
|
||||
function createAuthProvider(mcpServerConfig: MCPServerConfig) {
|
||||
if (
|
||||
mcpServerConfig.authProviderType ===
|
||||
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
|
||||
) {
|
||||
return new ServiceAccountImpersonationProvider(mcpServerConfig);
|
||||
}
|
||||
if (
|
||||
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
|
||||
) {
|
||||
return new GoogleCredentialProvider(mcpServerConfig);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a transport for URL based servers (remote servers).
|
||||
*
|
||||
* @param mcpServerConfig The MCP server configuration
|
||||
* @param transportOptions The transport options
|
||||
*/
|
||||
function createUrlTransport(
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
transportOptions:
|
||||
| StreamableHTTPClientTransportOptions
|
||||
| SSEClientTransportOptions,
|
||||
): StreamableHTTPClientTransport | SSEClientTransport {
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
if (mcpServerConfig.url) {
|
||||
return new SSEClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
throw new Error('No URL configured for MCP Server');
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a transport with OAuth token for the given server configuration.
|
||||
*
|
||||
@@ -424,28 +471,16 @@ async function createTransportWithOAuth(
|
||||
accessToken: string,
|
||||
): Promise<StreamableHTTPClientTransport | SSEClientTransport | null> {
|
||||
try {
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
// Create HTTP transport with OAuth token
|
||||
const oauthTransportOptions: StreamableHTTPClientTransportOptions = {
|
||||
requestInit: createTransportRequestInit(mcpServerConfig, {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
}),
|
||||
};
|
||||
const headers: Record<string, string> = {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
};
|
||||
const transportOptions:
|
||||
| StreamableHTTPClientTransportOptions
|
||||
| SSEClientTransportOptions = {
|
||||
requestInit: createTransportRequestInit(mcpServerConfig, headers),
|
||||
};
|
||||
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
oauthTransportOptions,
|
||||
);
|
||||
} else if (mcpServerConfig.url) {
|
||||
// Create SSE transport with OAuth token in Authorization header
|
||||
return new SSEClientTransport(new URL(mcpServerConfig.url), {
|
||||
requestInit: createTransportRequestInit(mcpServerConfig, {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
return null;
|
||||
return createUrlTransport(mcpServerConfig, transportOptions);
|
||||
} catch (error) {
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
@@ -1223,148 +1258,85 @@ export async function connectToMcpServer(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Visible for Testing */
|
||||
export async function createTransport(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
debugMode: boolean,
|
||||
): Promise<Transport> {
|
||||
if (
|
||||
mcpServerConfig.authProviderType ===
|
||||
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
|
||||
) {
|
||||
const provider = new ServiceAccountImpersonationProvider(mcpServerConfig);
|
||||
const transportOptions:
|
||||
| StreamableHTTPClientTransportOptions
|
||||
| SSEClientTransportOptions = {
|
||||
requestInit: createTransportRequestInit(mcpServerConfig, {}),
|
||||
authProvider: provider,
|
||||
};
|
||||
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
);
|
||||
} else if (mcpServerConfig.url) {
|
||||
// Default to SSE if only url is provided
|
||||
return new SSEClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
throw new Error(
|
||||
'No URL configured for ServiceAccountImpersonation MCP Server',
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
|
||||
) {
|
||||
const provider = new GoogleCredentialProvider(mcpServerConfig);
|
||||
const transportOptions:
|
||||
| StreamableHTTPClientTransportOptions
|
||||
| SSEClientTransportOptions = {
|
||||
requestInit: createTransportRequestInit(mcpServerConfig, {}),
|
||||
authProvider: provider,
|
||||
};
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
);
|
||||
} else if (mcpServerConfig.url) {
|
||||
return new SSEClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
throw new Error('No URL configured for Google Credentials MCP server');
|
||||
}
|
||||
|
||||
// Check if we have OAuth configuration or stored tokens
|
||||
let accessToken: string | null = null;
|
||||
let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
|
||||
|
||||
if (hasOAuthConfig && mcpServerConfig.oauth) {
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
accessToken = await authProvider.getValidToken(
|
||||
mcpServerName,
|
||||
mcpServerConfig.oauth,
|
||||
);
|
||||
|
||||
if (!accessToken) {
|
||||
const noUrl = !mcpServerConfig.url && !mcpServerConfig.httpUrl;
|
||||
if (noUrl) {
|
||||
if (
|
||||
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
|
||||
) {
|
||||
throw new Error(
|
||||
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
|
||||
`Please authenticate using the /mcp auth command.`,
|
||||
`URL must be provided in the config for Google Credentials provider`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
||||
accessToken = await authProvider.getValidToken(mcpServerName, {
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
});
|
||||
if (
|
||||
mcpServerConfig.authProviderType ===
|
||||
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
|
||||
) {
|
||||
throw new Error(
|
||||
`No URL configured for ServiceAccountImpersonation MCP Server`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (accessToken) {
|
||||
hasOAuthConfig = true;
|
||||
debugLogger.log(
|
||||
`Found stored OAuth token for server '${mcpServerName}'`,
|
||||
if (mcpServerConfig.httpUrl || mcpServerConfig.url) {
|
||||
const authProvider = createAuthProvider(mcpServerConfig);
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (authProvider === undefined) {
|
||||
// Check if we have OAuth configuration or stored tokens
|
||||
let accessToken: string | null = null;
|
||||
let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
|
||||
if (hasOAuthConfig && mcpServerConfig.oauth) {
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const mcpAuthProvider = new MCPOAuthProvider(tokenStorage);
|
||||
accessToken = await mcpAuthProvider.getValidToken(
|
||||
mcpServerName,
|
||||
mcpServerConfig.oauth,
|
||||
);
|
||||
|
||||
if (!accessToken) {
|
||||
throw new Error(
|
||||
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
|
||||
`Please authenticate using the /mcp auth command.`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
||||
const tokenStorage = new MCPOAuthTokenStorage();
|
||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||
if (credentials) {
|
||||
const mcpAuthProvider = new MCPOAuthProvider(tokenStorage);
|
||||
accessToken = await mcpAuthProvider.getValidToken(mcpServerName, {
|
||||
// Pass client ID if available
|
||||
clientId: credentials.clientId,
|
||||
});
|
||||
|
||||
if (accessToken) {
|
||||
hasOAuthConfig = true;
|
||||
debugLogger.log(
|
||||
`Found stored OAuth token for server '${mcpServerName}'`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (hasOAuthConfig && accessToken) {
|
||||
headers['Authorization'] = `Bearer ${accessToken}`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
const transportOptions: StreamableHTTPClientTransportOptions = {};
|
||||
const transportOptions:
|
||||
| StreamableHTTPClientTransportOptions
|
||||
| SSEClientTransportOptions = {
|
||||
requestInit: createTransportRequestInit(mcpServerConfig, headers),
|
||||
authProvider,
|
||||
};
|
||||
|
||||
// Set up headers with OAuth token if available
|
||||
if (hasOAuthConfig && accessToken) {
|
||||
transportOptions.requestInit = {
|
||||
headers: {
|
||||
...mcpServerConfig.headers,
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
};
|
||||
} else if (mcpServerConfig.headers) {
|
||||
transportOptions.requestInit = {
|
||||
headers: mcpServerConfig.headers,
|
||||
};
|
||||
}
|
||||
|
||||
return new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
);
|
||||
}
|
||||
|
||||
if (mcpServerConfig.url) {
|
||||
const transportOptions: SSEClientTransportOptions = {};
|
||||
|
||||
// Set up headers with OAuth token if available
|
||||
if (hasOAuthConfig && accessToken) {
|
||||
transportOptions.requestInit = {
|
||||
headers: {
|
||||
...mcpServerConfig.headers,
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
};
|
||||
} else if (mcpServerConfig.headers) {
|
||||
transportOptions.requestInit = {
|
||||
headers: mcpServerConfig.headers,
|
||||
};
|
||||
}
|
||||
|
||||
return new SSEClientTransport(
|
||||
new URL(mcpServerConfig.url),
|
||||
transportOptions,
|
||||
);
|
||||
return createUrlTransport(mcpServerConfig, transportOptions);
|
||||
}
|
||||
|
||||
if (mcpServerConfig.command) {
|
||||
|
||||
Reference in New Issue
Block a user