mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 21:03:05 -07:00
Refactor createTransport to duplicate less code (#13010)
Co-authored-by: David McWherter <davidmcw@gmail.com>
This commit is contained in:
@@ -487,7 +487,9 @@ describe('mcp-client', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(transport).toEqual(
|
expect(transport).toEqual(
|
||||||
new StreamableHTTPClientTransport(new URL('http://test-server'), {}),
|
new StreamableHTTPClientTransport(new URL('http://test-server'), {
|
||||||
|
requestInit: { headers: {} },
|
||||||
|
}),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -521,7 +523,9 @@ describe('mcp-client', () => {
|
|||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
expect(transport).toEqual(
|
expect(transport).toEqual(
|
||||||
new SSEClientTransport(new URL('http://test-server'), {}),
|
new SSEClientTransport(new URL('http://test-server'), {
|
||||||
|
requestInit: { headers: {} },
|
||||||
|
}),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -410,6 +410,53 @@ function createTransportRequestInit(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create an AuthProvider for the MCP Transport.
|
||||||
|
*
|
||||||
|
* @param mcpServerConfig The MCP server configuration
|
||||||
|
*/
|
||||||
|
function createAuthProvider(mcpServerConfig: MCPServerConfig) {
|
||||||
|
if (
|
||||||
|
mcpServerConfig.authProviderType ===
|
||||||
|
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
|
||||||
|
) {
|
||||||
|
return new ServiceAccountImpersonationProvider(mcpServerConfig);
|
||||||
|
}
|
||||||
|
if (
|
||||||
|
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
|
||||||
|
) {
|
||||||
|
return new GoogleCredentialProvider(mcpServerConfig);
|
||||||
|
}
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a transport for URL based servers (remote servers).
|
||||||
|
*
|
||||||
|
* @param mcpServerConfig The MCP server configuration
|
||||||
|
* @param transportOptions The transport options
|
||||||
|
*/
|
||||||
|
function createUrlTransport(
|
||||||
|
mcpServerConfig: MCPServerConfig,
|
||||||
|
transportOptions:
|
||||||
|
| StreamableHTTPClientTransportOptions
|
||||||
|
| SSEClientTransportOptions,
|
||||||
|
): StreamableHTTPClientTransport | SSEClientTransport {
|
||||||
|
if (mcpServerConfig.httpUrl) {
|
||||||
|
return new StreamableHTTPClientTransport(
|
||||||
|
new URL(mcpServerConfig.httpUrl),
|
||||||
|
transportOptions,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (mcpServerConfig.url) {
|
||||||
|
return new SSEClientTransport(
|
||||||
|
new URL(mcpServerConfig.url),
|
||||||
|
transportOptions,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
throw new Error('No URL configured for MCP Server');
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a transport with OAuth token for the given server configuration.
|
* Create a transport with OAuth token for the given server configuration.
|
||||||
*
|
*
|
||||||
@@ -424,28 +471,16 @@ async function createTransportWithOAuth(
|
|||||||
accessToken: string,
|
accessToken: string,
|
||||||
): Promise<StreamableHTTPClientTransport | SSEClientTransport | null> {
|
): Promise<StreamableHTTPClientTransport | SSEClientTransport | null> {
|
||||||
try {
|
try {
|
||||||
if (mcpServerConfig.httpUrl) {
|
const headers: Record<string, string> = {
|
||||||
// Create HTTP transport with OAuth token
|
Authorization: `Bearer ${accessToken}`,
|
||||||
const oauthTransportOptions: StreamableHTTPClientTransportOptions = {
|
};
|
||||||
requestInit: createTransportRequestInit(mcpServerConfig, {
|
const transportOptions:
|
||||||
Authorization: `Bearer ${accessToken}`,
|
| StreamableHTTPClientTransportOptions
|
||||||
}),
|
| SSEClientTransportOptions = {
|
||||||
};
|
requestInit: createTransportRequestInit(mcpServerConfig, headers),
|
||||||
|
};
|
||||||
|
|
||||||
return new StreamableHTTPClientTransport(
|
return createUrlTransport(mcpServerConfig, transportOptions);
|
||||||
new URL(mcpServerConfig.httpUrl),
|
|
||||||
oauthTransportOptions,
|
|
||||||
);
|
|
||||||
} else if (mcpServerConfig.url) {
|
|
||||||
// Create SSE transport with OAuth token in Authorization header
|
|
||||||
return new SSEClientTransport(new URL(mcpServerConfig.url), {
|
|
||||||
requestInit: createTransportRequestInit(mcpServerConfig, {
|
|
||||||
Authorization: `Bearer ${accessToken}`,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
coreEvents.emitFeedback(
|
coreEvents.emitFeedback(
|
||||||
'error',
|
'error',
|
||||||
@@ -1223,148 +1258,85 @@ export async function connectToMcpServer(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Visible for Testing */
|
/** Visible for Testing */
|
||||||
export async function createTransport(
|
export async function createTransport(
|
||||||
mcpServerName: string,
|
mcpServerName: string,
|
||||||
mcpServerConfig: MCPServerConfig,
|
mcpServerConfig: MCPServerConfig,
|
||||||
debugMode: boolean,
|
debugMode: boolean,
|
||||||
): Promise<Transport> {
|
): Promise<Transport> {
|
||||||
if (
|
const noUrl = !mcpServerConfig.url && !mcpServerConfig.httpUrl;
|
||||||
mcpServerConfig.authProviderType ===
|
if (noUrl) {
|
||||||
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
|
if (
|
||||||
) {
|
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
|
||||||
const provider = new ServiceAccountImpersonationProvider(mcpServerConfig);
|
) {
|
||||||
const transportOptions:
|
|
||||||
| StreamableHTTPClientTransportOptions
|
|
||||||
| SSEClientTransportOptions = {
|
|
||||||
requestInit: createTransportRequestInit(mcpServerConfig, {}),
|
|
||||||
authProvider: provider,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (mcpServerConfig.httpUrl) {
|
|
||||||
return new StreamableHTTPClientTransport(
|
|
||||||
new URL(mcpServerConfig.httpUrl),
|
|
||||||
transportOptions,
|
|
||||||
);
|
|
||||||
} else if (mcpServerConfig.url) {
|
|
||||||
// Default to SSE if only url is provided
|
|
||||||
return new SSEClientTransport(
|
|
||||||
new URL(mcpServerConfig.url),
|
|
||||||
transportOptions,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
throw new Error(
|
|
||||||
'No URL configured for ServiceAccountImpersonation MCP Server',
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
mcpServerConfig.authProviderType === AuthProviderType.GOOGLE_CREDENTIALS
|
|
||||||
) {
|
|
||||||
const provider = new GoogleCredentialProvider(mcpServerConfig);
|
|
||||||
const transportOptions:
|
|
||||||
| StreamableHTTPClientTransportOptions
|
|
||||||
| SSEClientTransportOptions = {
|
|
||||||
requestInit: createTransportRequestInit(mcpServerConfig, {}),
|
|
||||||
authProvider: provider,
|
|
||||||
};
|
|
||||||
if (mcpServerConfig.httpUrl) {
|
|
||||||
return new StreamableHTTPClientTransport(
|
|
||||||
new URL(mcpServerConfig.httpUrl),
|
|
||||||
transportOptions,
|
|
||||||
);
|
|
||||||
} else if (mcpServerConfig.url) {
|
|
||||||
return new SSEClientTransport(
|
|
||||||
new URL(mcpServerConfig.url),
|
|
||||||
transportOptions,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
throw new Error('No URL configured for Google Credentials MCP server');
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have OAuth configuration or stored tokens
|
|
||||||
let accessToken: string | null = null;
|
|
||||||
let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
|
|
||||||
|
|
||||||
if (hasOAuthConfig && mcpServerConfig.oauth) {
|
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
|
||||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
|
||||||
accessToken = await authProvider.getValidToken(
|
|
||||||
mcpServerName,
|
|
||||||
mcpServerConfig.oauth,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!accessToken) {
|
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
|
`URL must be provided in the config for Google Credentials provider`,
|
||||||
`Please authenticate using the /mcp auth command.`,
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else {
|
if (
|
||||||
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
mcpServerConfig.authProviderType ===
|
||||||
const tokenStorage = new MCPOAuthTokenStorage();
|
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
|
||||||
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
) {
|
||||||
if (credentials) {
|
throw new Error(
|
||||||
const authProvider = new MCPOAuthProvider(tokenStorage);
|
`No URL configured for ServiceAccountImpersonation MCP Server`,
|
||||||
accessToken = await authProvider.getValidToken(mcpServerName, {
|
);
|
||||||
// Pass client ID if available
|
}
|
||||||
clientId: credentials.clientId,
|
}
|
||||||
});
|
|
||||||
|
|
||||||
if (accessToken) {
|
if (mcpServerConfig.httpUrl || mcpServerConfig.url) {
|
||||||
hasOAuthConfig = true;
|
const authProvider = createAuthProvider(mcpServerConfig);
|
||||||
debugLogger.log(
|
|
||||||
`Found stored OAuth token for server '${mcpServerName}'`,
|
const headers: Record<string, string> = {};
|
||||||
|
if (authProvider === undefined) {
|
||||||
|
// Check if we have OAuth configuration or stored tokens
|
||||||
|
let accessToken: string | null = null;
|
||||||
|
let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
|
||||||
|
if (hasOAuthConfig && mcpServerConfig.oauth) {
|
||||||
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
|
const mcpAuthProvider = new MCPOAuthProvider(tokenStorage);
|
||||||
|
accessToken = await mcpAuthProvider.getValidToken(
|
||||||
|
mcpServerName,
|
||||||
|
mcpServerConfig.oauth,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (!accessToken) {
|
||||||
|
throw new Error(
|
||||||
|
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
|
||||||
|
`Please authenticate using the /mcp auth command.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
||||||
|
const tokenStorage = new MCPOAuthTokenStorage();
|
||||||
|
const credentials = await tokenStorage.getCredentials(mcpServerName);
|
||||||
|
if (credentials) {
|
||||||
|
const mcpAuthProvider = new MCPOAuthProvider(tokenStorage);
|
||||||
|
accessToken = await mcpAuthProvider.getValidToken(mcpServerName, {
|
||||||
|
// Pass client ID if available
|
||||||
|
clientId: credentials.clientId,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (accessToken) {
|
||||||
|
hasOAuthConfig = true;
|
||||||
|
debugLogger.log(
|
||||||
|
`Found stored OAuth token for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (hasOAuthConfig && accessToken) {
|
||||||
|
headers['Authorization'] = `Bearer ${accessToken}`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (mcpServerConfig.httpUrl) {
|
const transportOptions:
|
||||||
const transportOptions: StreamableHTTPClientTransportOptions = {};
|
| StreamableHTTPClientTransportOptions
|
||||||
|
| SSEClientTransportOptions = {
|
||||||
|
requestInit: createTransportRequestInit(mcpServerConfig, headers),
|
||||||
|
authProvider,
|
||||||
|
};
|
||||||
|
|
||||||
// Set up headers with OAuth token if available
|
return createUrlTransport(mcpServerConfig, transportOptions);
|
||||||
if (hasOAuthConfig && accessToken) {
|
|
||||||
transportOptions.requestInit = {
|
|
||||||
headers: {
|
|
||||||
...mcpServerConfig.headers,
|
|
||||||
Authorization: `Bearer ${accessToken}`,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
} else if (mcpServerConfig.headers) {
|
|
||||||
transportOptions.requestInit = {
|
|
||||||
headers: mcpServerConfig.headers,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return new StreamableHTTPClientTransport(
|
|
||||||
new URL(mcpServerConfig.httpUrl),
|
|
||||||
transportOptions,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mcpServerConfig.url) {
|
|
||||||
const transportOptions: SSEClientTransportOptions = {};
|
|
||||||
|
|
||||||
// Set up headers with OAuth token if available
|
|
||||||
if (hasOAuthConfig && accessToken) {
|
|
||||||
transportOptions.requestInit = {
|
|
||||||
headers: {
|
|
||||||
...mcpServerConfig.headers,
|
|
||||||
Authorization: `Bearer ${accessToken}`,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
} else if (mcpServerConfig.headers) {
|
|
||||||
transportOptions.requestInit = {
|
|
||||||
headers: mcpServerConfig.headers,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return new SSEClientTransport(
|
|
||||||
new URL(mcpServerConfig.url),
|
|
||||||
transportOptions,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mcpServerConfig.command) {
|
if (mcpServerConfig.command) {
|
||||||
|
|||||||
Reference in New Issue
Block a user