mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-17 01:21:10 -07:00
feat(core): add google credentials provider for remote agents (#21024)
This commit is contained in:
@@ -44,7 +44,7 @@ interface FrontmatterLocalAgentDefinition
|
||||
* Authentication configuration for remote agents in frontmatter format.
|
||||
*/
|
||||
interface FrontmatterAuthConfig {
|
||||
type: 'apiKey' | 'http' | 'oauth2';
|
||||
type: 'apiKey' | 'http' | 'google-credentials' | 'oauth2';
|
||||
// API Key
|
||||
key?: string;
|
||||
name?: string;
|
||||
@@ -54,10 +54,11 @@ interface FrontmatterAuthConfig {
|
||||
username?: string;
|
||||
password?: string;
|
||||
value?: string;
|
||||
// Google Credentials
|
||||
scopes?: string[];
|
||||
// OAuth2
|
||||
client_id?: string;
|
||||
client_secret?: string;
|
||||
scopes?: string[];
|
||||
authorization_url?: string;
|
||||
token_url?: string;
|
||||
}
|
||||
@@ -152,6 +153,15 @@ const httpAuthSchema = z.object({
|
||||
value: z.string().min(1).optional(),
|
||||
});
|
||||
|
||||
/**
|
||||
* Google Credentials auth schema.
|
||||
*/
|
||||
const googleCredentialsAuthSchema = z.object({
|
||||
...baseAuthFields,
|
||||
type: z.literal('google-credentials'),
|
||||
scopes: z.array(z.string()).optional(),
|
||||
});
|
||||
|
||||
/**
|
||||
* OAuth2 auth schema.
|
||||
* authorization_url and token_url can be discovered from the agent card if omitted.
|
||||
@@ -170,6 +180,7 @@ const authConfigSchema = z
|
||||
.discriminatedUnion('type', [
|
||||
apiKeyAuthSchema,
|
||||
httpAuthSchema,
|
||||
googleCredentialsAuthSchema,
|
||||
oauth2AuthSchema,
|
||||
])
|
||||
.superRefine((data, ctx) => {
|
||||
@@ -369,6 +380,13 @@ function convertFrontmatterAuthToConfig(
|
||||
name: frontmatter.name,
|
||||
};
|
||||
|
||||
case 'google-credentials':
|
||||
return {
|
||||
...base,
|
||||
type: 'google-credentials',
|
||||
scopes: frontmatter.scopes,
|
||||
};
|
||||
|
||||
case 'http': {
|
||||
if (!frontmatter.scheme) {
|
||||
throw new Error(
|
||||
|
||||
@@ -12,12 +12,15 @@ import type {
|
||||
} from './types.js';
|
||||
import { ApiKeyAuthProvider } from './api-key-provider.js';
|
||||
import { HttpAuthProvider } from './http-provider.js';
|
||||
import { GoogleCredentialsAuthProvider } from './google-credentials-provider.js';
|
||||
|
||||
export interface CreateAuthProviderOptions {
|
||||
/** Required for OAuth/OIDC token storage. */
|
||||
agentName?: string;
|
||||
authConfig?: A2AAuthConfig;
|
||||
agentCard?: AgentCard;
|
||||
/** Required by some providers (like google-credentials) to determine token audience. */
|
||||
targetUrl?: string;
|
||||
/** URL to fetch the agent card from, used for OAuth2 URL discovery. */
|
||||
agentCardUrl?: string;
|
||||
}
|
||||
@@ -43,9 +46,14 @@ export class A2AAuthProviderFactory {
|
||||
}
|
||||
|
||||
switch (authConfig.type) {
|
||||
case 'google-credentials':
|
||||
// TODO: Implement
|
||||
throw new Error('google-credentials auth provider not yet implemented');
|
||||
case 'google-credentials': {
|
||||
const provider = new GoogleCredentialsAuthProvider(
|
||||
authConfig,
|
||||
options.targetUrl,
|
||||
);
|
||||
await provider.initialize();
|
||||
return provider;
|
||||
}
|
||||
|
||||
case 'apiKey': {
|
||||
const provider = new ApiKeyAuthProvider(authConfig);
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||
import { GoogleCredentialsAuthProvider } from './google-credentials-provider.js';
|
||||
import type { GoogleCredentialsAuthConfig } from './types.js';
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
import { OAuthUtils } from '../../mcp/oauth-utils.js';
|
||||
|
||||
// Mock the external dependencies
|
||||
vi.mock('google-auth-library', () => ({
|
||||
GoogleAuth: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('GoogleCredentialsAuthProvider', () => {
|
||||
const mockConfig: GoogleCredentialsAuthConfig = {
|
||||
type: 'google-credentials',
|
||||
};
|
||||
|
||||
let mockGetClient: Mock;
|
||||
let mockGetAccessToken: Mock;
|
||||
let mockGetIdTokenClient: Mock;
|
||||
let mockFetchIdToken: Mock;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
mockGetAccessToken = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ token: 'mock-access-token' });
|
||||
mockGetClient = vi.fn().mockResolvedValue({
|
||||
getAccessToken: mockGetAccessToken,
|
||||
credentials: { expiry_date: Date.now() + 3600 * 1000 },
|
||||
});
|
||||
|
||||
mockFetchIdToken = vi.fn().mockResolvedValue('mock-id-token');
|
||||
mockGetIdTokenClient = vi.fn().mockResolvedValue({
|
||||
idTokenProvider: {
|
||||
fetchIdToken: mockFetchIdToken,
|
||||
},
|
||||
});
|
||||
|
||||
(GoogleAuth as unknown as Mock).mockImplementation(() => ({
|
||||
getClient: mockGetClient,
|
||||
getIdTokenClient: mockGetIdTokenClient,
|
||||
}));
|
||||
});
|
||||
|
||||
describe('Initialization', () => {
|
||||
it('throws if no targetUrl is provided', () => {
|
||||
expect(() => new GoogleCredentialsAuthProvider(mockConfig)).toThrow(
|
||||
/targetUrl must be provided/,
|
||||
);
|
||||
});
|
||||
|
||||
it('throws if targetHost is not allowed', () => {
|
||||
expect(
|
||||
() =>
|
||||
new GoogleCredentialsAuthProvider(mockConfig, 'https://example.com'),
|
||||
).toThrow(/is not an allowed host/);
|
||||
});
|
||||
|
||||
it('initializes seamlessly with .googleapis.com', () => {
|
||||
expect(
|
||||
() =>
|
||||
new GoogleCredentialsAuthProvider(
|
||||
mockConfig,
|
||||
'https://language.googleapis.com/v1/models',
|
||||
),
|
||||
).not.toThrow();
|
||||
});
|
||||
|
||||
it('initializes seamlessly with .run.app', () => {
|
||||
expect(
|
||||
() =>
|
||||
new GoogleCredentialsAuthProvider(
|
||||
mockConfig,
|
||||
'https://my-cloud-run-service.run.app',
|
||||
),
|
||||
).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Token Fetching', () => {
|
||||
it('fetches an access token for googleapis.com endpoint', async () => {
|
||||
const provider = new GoogleCredentialsAuthProvider(
|
||||
mockConfig,
|
||||
'https://language.googleapis.com',
|
||||
);
|
||||
const headers = await provider.headers();
|
||||
|
||||
expect(headers).toEqual({ Authorization: 'Bearer mock-access-token' });
|
||||
expect(mockGetClient).toHaveBeenCalled();
|
||||
expect(mockGetAccessToken).toHaveBeenCalled();
|
||||
expect(mockGetIdTokenClient).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('fetches an identity token for run.app endpoint', async () => {
|
||||
// Mock OAuthUtils.parseTokenExpiry to avoid Base64 decoding issues in tests
|
||||
vi.spyOn(OAuthUtils, 'parseTokenExpiry').mockReturnValue(
|
||||
Date.now() + 1000000,
|
||||
);
|
||||
|
||||
const provider = new GoogleCredentialsAuthProvider(
|
||||
mockConfig,
|
||||
'https://my-service.run.app/some-path',
|
||||
);
|
||||
const headers = await provider.headers();
|
||||
|
||||
expect(headers).toEqual({ Authorization: 'Bearer mock-id-token' });
|
||||
expect(mockGetIdTokenClient).toHaveBeenCalledWith('my-service.run.app');
|
||||
expect(mockFetchIdToken).toHaveBeenCalledWith('my-service.run.app');
|
||||
expect(mockGetClient).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('returns cached access token on subsequent calls', async () => {
|
||||
const provider = new GoogleCredentialsAuthProvider(
|
||||
mockConfig,
|
||||
'https://language.googleapis.com',
|
||||
);
|
||||
|
||||
await provider.headers();
|
||||
await provider.headers();
|
||||
|
||||
// Should only call getClient/getAccessToken once due to caching
|
||||
expect(mockGetClient).toHaveBeenCalledTimes(1);
|
||||
expect(mockGetAccessToken).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('returns cached id token on subsequent calls', async () => {
|
||||
vi.spyOn(OAuthUtils, 'parseTokenExpiry').mockReturnValue(
|
||||
Date.now() + 1000000,
|
||||
);
|
||||
|
||||
const provider = new GoogleCredentialsAuthProvider(
|
||||
mockConfig,
|
||||
'https://my-service.run.app',
|
||||
);
|
||||
|
||||
await provider.headers();
|
||||
await provider.headers();
|
||||
|
||||
expect(mockGetIdTokenClient).toHaveBeenCalledTimes(1);
|
||||
expect(mockFetchIdToken).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('re-fetches access token on 401 (shouldRetryWithHeaders)', async () => {
|
||||
const provider = new GoogleCredentialsAuthProvider(
|
||||
mockConfig,
|
||||
'https://language.googleapis.com',
|
||||
);
|
||||
|
||||
// Prime the cache
|
||||
await provider.headers();
|
||||
expect(mockGetAccessToken).toHaveBeenCalledTimes(1);
|
||||
|
||||
const req = {} as RequestInit;
|
||||
const res = { status: 401 } as Response;
|
||||
|
||||
const retryHeaders = await provider.shouldRetryWithHeaders(req, res);
|
||||
|
||||
expect(retryHeaders).toEqual({
|
||||
Authorization: 'Bearer mock-access-token',
|
||||
});
|
||||
// Cache was cleared, so getAccessToken was called again
|
||||
expect(mockGetAccessToken).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('re-fetches token on 403', async () => {
|
||||
const provider = new GoogleCredentialsAuthProvider(
|
||||
mockConfig,
|
||||
'https://language.googleapis.com',
|
||||
);
|
||||
|
||||
const req = {} as RequestInit;
|
||||
const res = { status: 403 } as Response;
|
||||
|
||||
const retryHeaders = await provider.shouldRetryWithHeaders(req, res);
|
||||
|
||||
expect(retryHeaders).toEqual({
|
||||
Authorization: 'Bearer mock-access-token',
|
||||
});
|
||||
});
|
||||
|
||||
it('stops retrying after MAX_AUTH_RETRIES', async () => {
|
||||
const provider = new GoogleCredentialsAuthProvider(
|
||||
mockConfig,
|
||||
'https://language.googleapis.com',
|
||||
);
|
||||
|
||||
const req = {} as RequestInit;
|
||||
const res = { status: 401 } as Response;
|
||||
|
||||
// First two retries should succeed (MAX_AUTH_RETRIES = 2)
|
||||
expect(await provider.shouldRetryWithHeaders(req, res)).toBeDefined();
|
||||
expect(await provider.shouldRetryWithHeaders(req, res)).toBeDefined();
|
||||
|
||||
// Third should return undefined (exhausted)
|
||||
expect(await provider.shouldRetryWithHeaders(req, res)).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,161 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { HttpHeaders } from '@a2a-js/sdk/client';
|
||||
import { BaseA2AAuthProvider } from './base-provider.js';
|
||||
import type { GoogleCredentialsAuthConfig } from './types.js';
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import { OAuthUtils, FIVE_MIN_BUFFER_MS } from '../../mcp/oauth-utils.js';
|
||||
|
||||
const CLOUD_RUN_HOST_REGEX = /^(.*\.)?run\.app$/;
|
||||
const ALLOWED_HOSTS = [/^.+\.googleapis\.com$/, CLOUD_RUN_HOST_REGEX];
|
||||
|
||||
/**
|
||||
* Authentication provider for Google ADC (Application Default Credentials).
|
||||
* Automatically decides whether to use identity tokens or access tokens
|
||||
* based on the target endpoint URL.
|
||||
*/
|
||||
export class GoogleCredentialsAuthProvider extends BaseA2AAuthProvider {
|
||||
readonly type = 'google-credentials' as const;
|
||||
|
||||
private readonly auth: GoogleAuth;
|
||||
private readonly useIdToken: boolean = false;
|
||||
private readonly audience?: string;
|
||||
private cachedToken?: string;
|
||||
private tokenExpiryTime?: number;
|
||||
|
||||
constructor(
|
||||
private readonly config: GoogleCredentialsAuthConfig,
|
||||
targetUrl?: string,
|
||||
) {
|
||||
super();
|
||||
|
||||
if (!targetUrl) {
|
||||
throw new Error(
|
||||
'targetUrl must be provided to GoogleCredentialsAuthProvider to determine token audience.',
|
||||
);
|
||||
}
|
||||
|
||||
const hostname = new URL(targetUrl).hostname;
|
||||
const isRunAppHost = CLOUD_RUN_HOST_REGEX.test(hostname);
|
||||
|
||||
if (isRunAppHost) {
|
||||
this.useIdToken = true;
|
||||
}
|
||||
this.audience = hostname;
|
||||
|
||||
if (
|
||||
!this.useIdToken &&
|
||||
!ALLOWED_HOSTS.some((pattern) => pattern.test(hostname))
|
||||
) {
|
||||
throw new Error(
|
||||
`Host "${hostname}" is not an allowed host for Google Credential provider.`,
|
||||
);
|
||||
}
|
||||
|
||||
// A2A spec requires scopes if configured, otherwise use default cloud-platform
|
||||
const scopes =
|
||||
this.config.scopes && this.config.scopes.length > 0
|
||||
? this.config.scopes
|
||||
: ['https://www.googleapis.com/auth/cloud-platform'];
|
||||
|
||||
this.auth = new GoogleAuth({
|
||||
scopes,
|
||||
});
|
||||
}
|
||||
|
||||
override async initialize(): Promise<void> {
|
||||
// We can pre-fetch or validate if necessary here,
|
||||
// but deferred fetching is usually better for auth tokens.
|
||||
}
|
||||
|
||||
async headers(): Promise<HttpHeaders> {
|
||||
// Check cache
|
||||
if (
|
||||
this.cachedToken &&
|
||||
this.tokenExpiryTime &&
|
||||
Date.now() < this.tokenExpiryTime - FIVE_MIN_BUFFER_MS
|
||||
) {
|
||||
return { Authorization: `Bearer ${this.cachedToken}` };
|
||||
}
|
||||
|
||||
// Clear expired cache
|
||||
this.cachedToken = undefined;
|
||||
this.tokenExpiryTime = undefined;
|
||||
|
||||
if (this.useIdToken) {
|
||||
try {
|
||||
const idClient = await this.auth.getIdTokenClient(this.audience!);
|
||||
const idToken = await idClient.idTokenProvider.fetchIdToken(
|
||||
this.audience!,
|
||||
);
|
||||
|
||||
const expiryTime = OAuthUtils.parseTokenExpiry(idToken);
|
||||
if (expiryTime) {
|
||||
this.tokenExpiryTime = expiryTime;
|
||||
this.cachedToken = idToken;
|
||||
}
|
||||
|
||||
return { Authorization: `Bearer ${idToken}` };
|
||||
} catch (e) {
|
||||
const errorMessage = `Failed to get ADC ID token: ${
|
||||
e instanceof Error ? e.message : String(e)
|
||||
}`;
|
||||
debugLogger.error(errorMessage, e);
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, access token
|
||||
try {
|
||||
const client = await this.auth.getClient();
|
||||
const token = await client.getAccessToken();
|
||||
|
||||
if (token.token) {
|
||||
this.cachedToken = token.token;
|
||||
// Use expiry_date from the underlying credentials if available.
|
||||
const creds = client.credentials;
|
||||
if (creds.expiry_date) {
|
||||
this.tokenExpiryTime = creds.expiry_date;
|
||||
}
|
||||
return { Authorization: `Bearer ${token.token}` };
|
||||
}
|
||||
throw new Error('Failed to retrieve ADC access token.');
|
||||
} catch (e) {
|
||||
const errorMessage = `Failed to get ADC access token: ${
|
||||
e instanceof Error ? e.message : String(e)
|
||||
}`;
|
||||
debugLogger.error(errorMessage, e);
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
}
|
||||
|
||||
override async shouldRetryWithHeaders(
|
||||
_req: RequestInit,
|
||||
res: Response,
|
||||
): Promise<HttpHeaders | undefined> {
|
||||
if (res.status !== 401 && res.status !== 403) {
|
||||
this.authRetryCount = 0;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (this.authRetryCount >= BaseA2AAuthProvider.MAX_AUTH_RETRIES) {
|
||||
return undefined;
|
||||
}
|
||||
this.authRetryCount++;
|
||||
|
||||
debugLogger.debug(
|
||||
'[GoogleCredentialsAuthProvider] Re-fetching token after auth failure',
|
||||
);
|
||||
|
||||
// Clear cache to force a re-fetch
|
||||
this.cachedToken = undefined;
|
||||
this.tokenExpiryTime = undefined;
|
||||
|
||||
return this.headers();
|
||||
}
|
||||
}
|
||||
@@ -593,6 +593,7 @@ describe('AgentRegistry', () => {
|
||||
expect(A2AAuthProviderFactory.create).toHaveBeenCalledWith({
|
||||
authConfig: mockAuth,
|
||||
agentName: 'RemoteAgentWithAuth',
|
||||
targetUrl: 'https://example.com/card',
|
||||
agentCardUrl: 'https://example.com/card',
|
||||
});
|
||||
expect(loadAgentSpy).toHaveBeenCalledWith(
|
||||
|
||||
@@ -420,6 +420,7 @@ export class AgentRegistry {
|
||||
const provider = await A2AAuthProviderFactory.create({
|
||||
authConfig: definition.auth,
|
||||
agentName: definition.name,
|
||||
targetUrl: definition.agentCardUrl,
|
||||
agentCardUrl: remoteDef.agentCardUrl,
|
||||
});
|
||||
if (!provider) {
|
||||
|
||||
@@ -195,6 +195,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
expect(A2AAuthProviderFactory.create).toHaveBeenCalledWith({
|
||||
authConfig: mockAuth,
|
||||
agentName: 'test-agent',
|
||||
targetUrl: 'http://test-agent/card',
|
||||
agentCardUrl: 'http://test-agent/card',
|
||||
});
|
||||
expect(mockClientManager.loadAgent).toHaveBeenCalledWith(
|
||||
|
||||
@@ -22,7 +22,6 @@ import {
|
||||
type SendMessageResult,
|
||||
} from './a2a-client-manager.js';
|
||||
import { extractIdsFromResponse, A2AResultReassembler } from './a2aUtils.js';
|
||||
import { GoogleAuth } from 'google-auth-library';
|
||||
import type { AuthenticationHandler } from '@a2a-js/sdk/client';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { safeJsonToMarkdown } from '../utils/markdownUtils.js';
|
||||
@@ -30,39 +29,6 @@ import type { AnsiOutput } from '../utils/terminalSerializer.js';
|
||||
import { A2AAuthProviderFactory } from './auth-provider/factory.js';
|
||||
import { A2AAgentError } from './a2a-errors.js';
|
||||
|
||||
/**
|
||||
* Authentication handler implementation using Google Application Default Credentials (ADC).
|
||||
*/
|
||||
export class ADCHandler implements AuthenticationHandler {
|
||||
private auth = new GoogleAuth({
|
||||
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
|
||||
});
|
||||
|
||||
async headers(): Promise<Record<string, string>> {
|
||||
try {
|
||||
const client = await this.auth.getClient();
|
||||
const token = await client.getAccessToken();
|
||||
if (token.token) {
|
||||
return { Authorization: `Bearer ${token.token}` };
|
||||
}
|
||||
throw new Error('Failed to retrieve ADC access token.');
|
||||
} catch (e) {
|
||||
const errorMessage = `Failed to get ADC token: ${
|
||||
e instanceof Error ? e.message : String(e)
|
||||
}`;
|
||||
debugLogger.log('ERROR', errorMessage);
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
}
|
||||
|
||||
async shouldRetryWithHeaders(
|
||||
_response: unknown,
|
||||
): Promise<Record<string, string> | undefined> {
|
||||
// For ADC, we usually just re-fetch the token if needed.
|
||||
return this.headers();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A tool invocation that proxies to a remote A2A agent.
|
||||
*
|
||||
@@ -121,6 +87,7 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
const provider = await A2AAuthProviderFactory.create({
|
||||
authConfig: this.definition.auth,
|
||||
agentName: this.definition.name,
|
||||
targetUrl: this.definition.agentCardUrl,
|
||||
agentCardUrl: this.definition.agentCardUrl,
|
||||
});
|
||||
if (!provider) {
|
||||
|
||||
Reference in New Issue
Block a user