diff --git a/packages/core/src/code_assist/oauth2.test.ts b/packages/core/src/code_assist/oauth2.test.ts index d089440e16..2210c695f9 100644 --- a/packages/core/src/code_assist/oauth2.test.ts +++ b/packages/core/src/code_assist/oauth2.test.ts @@ -14,7 +14,7 @@ import { clearOauthClientCache, } from './oauth2.js'; import { UserAccountManager } from '../utils/userAccountManager.js'; -import { OAuth2Client, Compute } from 'google-auth-library'; +import { OAuth2Client, Compute, GoogleAuth } from 'google-auth-library'; import * as fs from 'node:fs'; import * as path from 'node:path'; import http from 'node:http'; @@ -420,6 +420,53 @@ describe('oauth2', () => { // Assert the correct credentials were used expect(mockClient.setCredentials).toHaveBeenCalledWith(envCreds); }); + + it('should use GoogleAuth for BYOID credentials from GOOGLE_APPLICATION_CREDENTIALS', async () => { + // Setup BYOID credentials via environment variable + const byoidCredentials = { + type: 'external_account_authorized_user', + client_id: 'mock-client-id', + }; + const envCredsPath = path.join(tempHomeDir, 'byoid_creds.json'); + await fs.promises.writeFile( + envCredsPath, + JSON.stringify(byoidCredentials), + ); + vi.stubEnv('GOOGLE_APPLICATION_CREDENTIALS', envCredsPath); + + // Mock GoogleAuth and its chain of calls + const mockExternalAccountClient = { + getAccessToken: vi.fn().mockResolvedValue({ token: 'byoid-token' }), + }; + const mockFromJSON = vi + .fn() + .mockResolvedValue(mockExternalAccountClient); + const mockGoogleAuthInstance = { + fromJSON: mockFromJSON, + }; + (GoogleAuth as unknown as Mock).mockImplementation( + () => mockGoogleAuthInstance, + ); + + const mockOAuth2Client = { + on: vi.fn(), + }; + (OAuth2Client as unknown as Mock).mockImplementation( + () => mockOAuth2Client, + ); + + const client = await getOauthClient( + AuthType.LOGIN_WITH_GOOGLE, + mockConfig, + ); + + // Assert that GoogleAuth was used and the correct client was returned + expect(GoogleAuth).toHaveBeenCalledWith({ + scopes: expect.any(Array), + }); + expect(mockFromJSON).toHaveBeenCalledWith(byoidCredentials); + expect(client).toBe(mockExternalAccountClient); + }); }); describe('with GCP environment variables', () => { diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index fac45172e9..ef0be547f0 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -4,11 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { Credentials } from 'google-auth-library'; +import type { Credentials, AuthClient, JWTInput } from 'google-auth-library'; import { OAuth2Client, Compute, CodeChallengeMethod, + GoogleAuth, } from 'google-auth-library'; import * as http from 'node:http'; import url from 'node:url'; @@ -64,7 +65,7 @@ export interface OauthWebLogin { loginCompletePromise: Promise; } -const oauthClientPromises = new Map>(); +const oauthClientPromises = new Map>(); function getUseEncryptedStorageFlag() { return process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true'; @@ -73,7 +74,28 @@ function getUseEncryptedStorageFlag() { async function initOauthClient( authType: AuthType, config: Config, -): Promise { +): Promise { + const credentials = await fetchCachedCredentials(); + + if ( + credentials && + (credentials as { type?: string }).type === + 'external_account_authorized_user' + ) { + const auth = new GoogleAuth({ + scopes: OAUTH_SCOPE, + }); + const byoidClient = await auth.fromJSON({ + ...credentials, + refresh_token: credentials.refresh_token ?? undefined, + }); + const token = await byoidClient.getAccessToken(); + if (token) { + debugLogger.debug('Created BYOID auth client.'); + return byoidClient; + } + } + const client = new OAuth2Client({ clientId: OAUTH_CLIENT_ID, clientSecret: OAUTH_CLIENT_SECRET, @@ -102,20 +124,35 @@ async function initOauthClient( } }); - // If there are cached creds on disk, they always take precedence - if (await loadCachedCredentials(client)) { - // Found valid cached credentials. - // Check if we need to retrieve Google Account ID or Email - if (!userAccountManager.getCachedGoogleAccount()) { - try { - await fetchAndCacheUserInfo(client); - } catch (error) { - // Non-fatal, continue with existing auth. - debugLogger.warn('Failed to fetch user info:', getErrorMessage(error)); + if (credentials) { + client.setCredentials(credentials as Credentials); + try { + // This will verify locally that the credentials look good. + const { token } = await client.getAccessToken(); + if (token) { + // This will check with the server to see if it hasn't been revoked. + await client.getTokenInfo(token); + + if (!userAccountManager.getCachedGoogleAccount()) { + try { + await fetchAndCacheUserInfo(client); + } catch (error) { + // Non-fatal, continue with existing auth. + debugLogger.warn( + 'Failed to fetch user info:', + getErrorMessage(error), + ); + } + } + debugLogger.log('Loaded cached credentials.'); + return client; } + } catch (error) { + debugLogger.debug( + `Cached credentials are not valid:`, + getErrorMessage(error), + ); } - debugLogger.log('Loaded cached credentials.'); - return client; } // In Google Cloud Shell, we can use Application Default Credentials (ADC) @@ -218,7 +255,7 @@ async function initOauthClient( export async function getOauthClient( authType: AuthType, config: Config, -): Promise { +): Promise { if (!oauthClientPromises.has(authType)) { oauthClientPromises.set(authType, initOauthClient(authType, config)); } @@ -432,15 +469,12 @@ export function getAvailablePort(): Promise { }); } -async function loadCachedCredentials(client: OAuth2Client): Promise { +async function fetchCachedCredentials(): Promise< + Credentials | JWTInput | null +> { const useEncryptedStorage = getUseEncryptedStorageFlag(); if (useEncryptedStorage) { - const credentials = await OAuthCredentialStorage.loadCredentials(); - if (credentials) { - client.setCredentials(credentials); - return true; - } - return false; + return await OAuthCredentialStorage.loadCredentials(); } const pathsToTry = [ @@ -450,19 +484,8 @@ async function loadCachedCredentials(client: OAuth2Client): Promise { for (const keyFile of pathsToTry) { try { - const creds = await fs.readFile(keyFile, 'utf-8'); - client.setCredentials(JSON.parse(creds)); - - // This will verify locally that the credentials look good. - const { token } = await client.getAccessToken(); - if (!token) { - continue; - } - - // This will check with the server to see if it hasn't been revoked. - await client.getTokenInfo(token); - - return true; + const keyFileString = await fs.readFile(keyFile, 'utf-8'); + return JSON.parse(keyFileString); } catch (error) { // Log specific error for debugging, but continue trying other paths debugLogger.debug( @@ -472,7 +495,7 @@ async function loadCachedCredentials(client: OAuth2Client): Promise { } } - return false; + return null; } async function cacheCredentials(credentials: Credentials) { diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index 915d07c1df..8859d56083 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { OAuth2Client } from 'google-auth-library'; +import type { AuthClient } from 'google-auth-library'; import type { CodeAssistGlobalUserSettingResponse, GoogleRpcResponse, @@ -47,7 +47,7 @@ export const CODE_ASSIST_API_VERSION = 'v1internal'; export class CodeAssistServer implements ContentGenerator { constructor( - readonly client: OAuth2Client, + readonly client: AuthClient, readonly projectId?: string, readonly httpOptions: HttpOptions = {}, readonly sessionId?: string, diff --git a/packages/core/src/code_assist/setup.ts b/packages/core/src/code_assist/setup.ts index 055a0dbb57..d33c019d6c 100644 --- a/packages/core/src/code_assist/setup.ts +++ b/packages/core/src/code_assist/setup.ts @@ -12,7 +12,7 @@ import type { } from './types.js'; import { UserTierId } from './types.js'; import { CodeAssistServer } from './server.js'; -import type { OAuth2Client } from 'google-auth-library'; +import type { AuthClient } from 'google-auth-library'; export class ProjectIdRequiredError extends Error { constructor() { @@ -32,7 +32,7 @@ export interface UserData { * @param projectId the user's project id, if any * @returns the user's actual project id */ -export async function setupUser(client: OAuth2Client): Promise { +export async function setupUser(client: AuthClient): Promise { const projectId = process.env['GOOGLE_CLOUD_PROJECT'] || process.env['GOOGLE_CLOUD_PROJECT_ID'] ||