mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 14:10:37 -07:00
Create BYOID auth client when detecting BYOID credentials (#11592)
This commit is contained in:
@@ -14,7 +14,7 @@ import {
|
||||
clearOauthClientCache,
|
||||
} from './oauth2.js';
|
||||
import { UserAccountManager } from '../utils/userAccountManager.js';
|
||||
import { OAuth2Client, Compute } from 'google-auth-library';
|
||||
import { OAuth2Client, Compute, GoogleAuth } from 'google-auth-library';
|
||||
import * as fs from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import http from 'node:http';
|
||||
@@ -420,6 +420,53 @@ describe('oauth2', () => {
|
||||
// Assert the correct credentials were used
|
||||
expect(mockClient.setCredentials).toHaveBeenCalledWith(envCreds);
|
||||
});
|
||||
|
||||
it('should use GoogleAuth for BYOID credentials from GOOGLE_APPLICATION_CREDENTIALS', async () => {
|
||||
// Setup BYOID credentials via environment variable
|
||||
const byoidCredentials = {
|
||||
type: 'external_account_authorized_user',
|
||||
client_id: 'mock-client-id',
|
||||
};
|
||||
const envCredsPath = path.join(tempHomeDir, 'byoid_creds.json');
|
||||
await fs.promises.writeFile(
|
||||
envCredsPath,
|
||||
JSON.stringify(byoidCredentials),
|
||||
);
|
||||
vi.stubEnv('GOOGLE_APPLICATION_CREDENTIALS', envCredsPath);
|
||||
|
||||
// Mock GoogleAuth and its chain of calls
|
||||
const mockExternalAccountClient = {
|
||||
getAccessToken: vi.fn().mockResolvedValue({ token: 'byoid-token' }),
|
||||
};
|
||||
const mockFromJSON = vi
|
||||
.fn()
|
||||
.mockResolvedValue(mockExternalAccountClient);
|
||||
const mockGoogleAuthInstance = {
|
||||
fromJSON: mockFromJSON,
|
||||
};
|
||||
(GoogleAuth as unknown as Mock).mockImplementation(
|
||||
() => mockGoogleAuthInstance,
|
||||
);
|
||||
|
||||
const mockOAuth2Client = {
|
||||
on: vi.fn(),
|
||||
};
|
||||
(OAuth2Client as unknown as Mock).mockImplementation(
|
||||
() => mockOAuth2Client,
|
||||
);
|
||||
|
||||
const client = await getOauthClient(
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
// Assert that GoogleAuth was used and the correct client was returned
|
||||
expect(GoogleAuth).toHaveBeenCalledWith({
|
||||
scopes: expect.any(Array),
|
||||
});
|
||||
expect(mockFromJSON).toHaveBeenCalledWith(byoidCredentials);
|
||||
expect(client).toBe(mockExternalAccountClient);
|
||||
});
|
||||
});
|
||||
|
||||
describe('with GCP environment variables', () => {
|
||||
|
||||
@@ -4,11 +4,12 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Credentials } from 'google-auth-library';
|
||||
import type { Credentials, AuthClient, JWTInput } from 'google-auth-library';
|
||||
import {
|
||||
OAuth2Client,
|
||||
Compute,
|
||||
CodeChallengeMethod,
|
||||
GoogleAuth,
|
||||
} from 'google-auth-library';
|
||||
import * as http from 'node:http';
|
||||
import url from 'node:url';
|
||||
@@ -64,7 +65,7 @@ export interface OauthWebLogin {
|
||||
loginCompletePromise: Promise<void>;
|
||||
}
|
||||
|
||||
const oauthClientPromises = new Map<AuthType, Promise<OAuth2Client>>();
|
||||
const oauthClientPromises = new Map<AuthType, Promise<AuthClient>>();
|
||||
|
||||
function getUseEncryptedStorageFlag() {
|
||||
return process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
|
||||
@@ -73,7 +74,28 @@ function getUseEncryptedStorageFlag() {
|
||||
async function initOauthClient(
|
||||
authType: AuthType,
|
||||
config: Config,
|
||||
): Promise<OAuth2Client> {
|
||||
): Promise<AuthClient> {
|
||||
const credentials = await fetchCachedCredentials();
|
||||
|
||||
if (
|
||||
credentials &&
|
||||
(credentials as { type?: string }).type ===
|
||||
'external_account_authorized_user'
|
||||
) {
|
||||
const auth = new GoogleAuth({
|
||||
scopes: OAUTH_SCOPE,
|
||||
});
|
||||
const byoidClient = await auth.fromJSON({
|
||||
...credentials,
|
||||
refresh_token: credentials.refresh_token ?? undefined,
|
||||
});
|
||||
const token = await byoidClient.getAccessToken();
|
||||
if (token) {
|
||||
debugLogger.debug('Created BYOID auth client.');
|
||||
return byoidClient;
|
||||
}
|
||||
}
|
||||
|
||||
const client = new OAuth2Client({
|
||||
clientId: OAUTH_CLIENT_ID,
|
||||
clientSecret: OAUTH_CLIENT_SECRET,
|
||||
@@ -102,20 +124,35 @@ async function initOauthClient(
|
||||
}
|
||||
});
|
||||
|
||||
// If there are cached creds on disk, they always take precedence
|
||||
if (await loadCachedCredentials(client)) {
|
||||
// Found valid cached credentials.
|
||||
// Check if we need to retrieve Google Account ID or Email
|
||||
if (!userAccountManager.getCachedGoogleAccount()) {
|
||||
try {
|
||||
await fetchAndCacheUserInfo(client);
|
||||
} catch (error) {
|
||||
// Non-fatal, continue with existing auth.
|
||||
debugLogger.warn('Failed to fetch user info:', getErrorMessage(error));
|
||||
if (credentials) {
|
||||
client.setCredentials(credentials as Credentials);
|
||||
try {
|
||||
// This will verify locally that the credentials look good.
|
||||
const { token } = await client.getAccessToken();
|
||||
if (token) {
|
||||
// This will check with the server to see if it hasn't been revoked.
|
||||
await client.getTokenInfo(token);
|
||||
|
||||
if (!userAccountManager.getCachedGoogleAccount()) {
|
||||
try {
|
||||
await fetchAndCacheUserInfo(client);
|
||||
} catch (error) {
|
||||
// Non-fatal, continue with existing auth.
|
||||
debugLogger.warn(
|
||||
'Failed to fetch user info:',
|
||||
getErrorMessage(error),
|
||||
);
|
||||
}
|
||||
}
|
||||
debugLogger.log('Loaded cached credentials.');
|
||||
return client;
|
||||
}
|
||||
} catch (error) {
|
||||
debugLogger.debug(
|
||||
`Cached credentials are not valid:`,
|
||||
getErrorMessage(error),
|
||||
);
|
||||
}
|
||||
debugLogger.log('Loaded cached credentials.');
|
||||
return client;
|
||||
}
|
||||
|
||||
// In Google Cloud Shell, we can use Application Default Credentials (ADC)
|
||||
@@ -218,7 +255,7 @@ async function initOauthClient(
|
||||
export async function getOauthClient(
|
||||
authType: AuthType,
|
||||
config: Config,
|
||||
): Promise<OAuth2Client> {
|
||||
): Promise<AuthClient> {
|
||||
if (!oauthClientPromises.has(authType)) {
|
||||
oauthClientPromises.set(authType, initOauthClient(authType, config));
|
||||
}
|
||||
@@ -432,15 +469,12 @@ export function getAvailablePort(): Promise<number> {
|
||||
});
|
||||
}
|
||||
|
||||
async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
|
||||
async function fetchCachedCredentials(): Promise<
|
||||
Credentials | JWTInput | null
|
||||
> {
|
||||
const useEncryptedStorage = getUseEncryptedStorageFlag();
|
||||
if (useEncryptedStorage) {
|
||||
const credentials = await OAuthCredentialStorage.loadCredentials();
|
||||
if (credentials) {
|
||||
client.setCredentials(credentials);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
return await OAuthCredentialStorage.loadCredentials();
|
||||
}
|
||||
|
||||
const pathsToTry = [
|
||||
@@ -450,19 +484,8 @@ async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
|
||||
|
||||
for (const keyFile of pathsToTry) {
|
||||
try {
|
||||
const creds = await fs.readFile(keyFile, 'utf-8');
|
||||
client.setCredentials(JSON.parse(creds));
|
||||
|
||||
// This will verify locally that the credentials look good.
|
||||
const { token } = await client.getAccessToken();
|
||||
if (!token) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// This will check with the server to see if it hasn't been revoked.
|
||||
await client.getTokenInfo(token);
|
||||
|
||||
return true;
|
||||
const keyFileString = await fs.readFile(keyFile, 'utf-8');
|
||||
return JSON.parse(keyFileString);
|
||||
} catch (error) {
|
||||
// Log specific error for debugging, but continue trying other paths
|
||||
debugLogger.debug(
|
||||
@@ -472,7 +495,7 @@ async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
return null;
|
||||
}
|
||||
|
||||
async function cacheCredentials(credentials: Credentials) {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { OAuth2Client } from 'google-auth-library';
|
||||
import type { AuthClient } from 'google-auth-library';
|
||||
import type {
|
||||
CodeAssistGlobalUserSettingResponse,
|
||||
GoogleRpcResponse,
|
||||
@@ -47,7 +47,7 @@ export const CODE_ASSIST_API_VERSION = 'v1internal';
|
||||
|
||||
export class CodeAssistServer implements ContentGenerator {
|
||||
constructor(
|
||||
readonly client: OAuth2Client,
|
||||
readonly client: AuthClient,
|
||||
readonly projectId?: string,
|
||||
readonly httpOptions: HttpOptions = {},
|
||||
readonly sessionId?: string,
|
||||
|
||||
@@ -12,7 +12,7 @@ import type {
|
||||
} from './types.js';
|
||||
import { UserTierId } from './types.js';
|
||||
import { CodeAssistServer } from './server.js';
|
||||
import type { OAuth2Client } from 'google-auth-library';
|
||||
import type { AuthClient } from 'google-auth-library';
|
||||
|
||||
export class ProjectIdRequiredError extends Error {
|
||||
constructor() {
|
||||
@@ -32,7 +32,7 @@ export interface UserData {
|
||||
* @param projectId the user's project id, if any
|
||||
* @returns the user's actual project id
|
||||
*/
|
||||
export async function setupUser(client: OAuth2Client): Promise<UserData> {
|
||||
export async function setupUser(client: AuthClient): Promise<UserData> {
|
||||
const projectId =
|
||||
process.env['GOOGLE_CLOUD_PROJECT'] ||
|
||||
process.env['GOOGLE_CLOUD_PROJECT_ID'] ||
|
||||
|
||||
Reference in New Issue
Block a user