Merge branch 'main' into restart-resume

This commit is contained in:
Jack Wotherspoon
2026-03-10 16:41:56 +01:00
committed by GitHub
35 changed files with 1794 additions and 65 deletions
+1 -1
View File
@@ -25,7 +25,7 @@
"dist"
],
"dependencies": {
"@a2a-js/sdk": "^0.3.8",
"@a2a-js/sdk": "0.3.11",
"@google-cloud/storage": "^7.16.0",
"@google/gemini-cli-core": "file:../core",
"express": "^5.1.0",
@@ -11,6 +11,7 @@ import { createMockCommandContext } from '../../test-utils/mockCommandContext.js
import {
AuthType,
openBrowserSecurely,
shouldLaunchBrowser,
UPGRADE_URL_PAGE,
} from '@google/gemini-cli-core';
@@ -20,6 +21,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
return {
...actual,
openBrowserSecurely: vi.fn(),
shouldLaunchBrowser: vi.fn().mockReturnValue(true),
UPGRADE_URL_PAGE: 'https://goo.gle/set-up-gemini-code-assist',
};
});
@@ -96,4 +98,21 @@ describe('upgradeCommand', () => {
content: 'Failed to open upgrade page: Failed to open',
});
});
it('should return URL message when shouldLaunchBrowser returns false', async () => {
vi.mocked(shouldLaunchBrowser).mockReturnValue(false);
if (!upgradeCommand.action) {
throw new Error('The upgrade command must have an action.');
}
const result = await upgradeCommand.action(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'info',
content: `Please open this URL in a browser: ${UPGRADE_URL_PAGE}`,
});
expect(openBrowserSecurely).not.toHaveBeenCalled();
});
});
@@ -7,6 +7,7 @@
import {
AuthType,
openBrowserSecurely,
shouldLaunchBrowser,
UPGRADE_URL_PAGE,
} from '@google/gemini-cli-core';
import type { SlashCommand } from './types.js';
@@ -35,6 +36,14 @@ export const upgradeCommand: SlashCommand = {
};
}
if (!shouldLaunchBrowser()) {
return {
type: 'message',
messageType: 'info',
content: `Please open this URL in a browser: ${UPGRADE_URL_PAGE}`,
};
}
try {
await openBrowserSecurely(UPGRADE_URL_PAGE);
} catch (error) {
@@ -18,6 +18,12 @@ Spinner Connecting to MCP servers... (0/5) - Waiting for: s1, s2, s3, +2 more
"
`;
exports[`ConfigInitDisplay > truncates list of waiting servers if too many 2`] = `
"
Spinner Connecting to MCP servers... (0/5) - Waiting for: s1, s2, s3, +2 more
"
`;
exports[`ConfigInitDisplay > updates message on McpClientUpdate event 1`] = `
"
Spinner Connecting to MCP servers... (1/2) - Waiting for: server2
@@ -15,6 +15,7 @@ import {
shouldAutoUseCredits,
shouldShowOverageMenu,
shouldShowEmptyWalletMenu,
shouldLaunchBrowser,
logBillingEvent,
G1_CREDIT_TYPE,
UserTierId,
@@ -32,6 +33,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
shouldShowEmptyWalletMenu: vi.fn(),
logBillingEvent: vi.fn(),
openBrowserSecurely: vi.fn(),
shouldLaunchBrowser: vi.fn().mockReturnValue(true),
};
});
@@ -237,4 +239,49 @@ describe('handleCreditsFlow', () => {
expect(isDialogPending.current).toBe(false);
expect(mockSetEmptyWalletRequest).toHaveBeenCalledWith(null);
});
describe('headless mode (shouldLaunchBrowser=false)', () => {
beforeEach(() => {
vi.mocked(shouldLaunchBrowser).mockReturnValue(false);
});
it('should show manage URL in history when manage selected in headless mode', async () => {
vi.mocked(shouldShowOverageMenu).mockReturnValue(true);
const flowPromise = handleCreditsFlow(makeArgs());
const request = mockSetOverageMenuRequest.mock.calls[0][0];
request.resolve('manage');
const result = await flowPromise;
expect(result).toBe('stop');
expect(mockHistoryManager.addItem).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageType.INFO,
text: expect.stringContaining('Please open this URL in a browser:'),
}),
expect.any(Number),
);
});
it('should show credits URL in history when get_credits selected in headless mode', async () => {
vi.mocked(shouldShowEmptyWalletMenu).mockReturnValue(true);
const flowPromise = handleCreditsFlow(makeArgs());
const request = mockSetEmptyWalletRequest.mock.calls[0][0];
// Trigger onGetCredits callback and wait for it
await request.onGetCredits();
expect(mockHistoryManager.addItem).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageType.INFO,
text: expect.stringContaining('Please open this URL in a browser:'),
}),
expect.any(Number),
);
request.resolve('get_credits');
await flowPromise;
});
});
});
@@ -14,6 +14,7 @@ import {
shouldShowOverageMenu,
shouldShowEmptyWalletMenu,
openBrowserSecurely,
shouldLaunchBrowser,
logBillingEvent,
OverageMenuShownEvent,
OverageOptionSelectedEvent,
@@ -159,10 +160,23 @@ async function handleOverageMenu(
case 'use_fallback':
return 'retry_always';
case 'manage':
case 'manage': {
logCreditPurchaseClick(config, 'manage', usageLimitReachedModel);
await openG1Url('activity', G1_UTM_CAMPAIGNS.MANAGE_ACTIVITY);
const manageUrl = await openG1Url(
'activity',
G1_UTM_CAMPAIGNS.MANAGE_ACTIVITY,
);
if (manageUrl) {
args.historyManager.addItem(
{
type: MessageType.INFO,
text: `Please open this URL in a browser: ${manageUrl}`,
},
Date.now(),
);
}
return 'stop';
}
case 'stop':
default:
@@ -205,13 +219,25 @@ async function handleEmptyWalletMenu(
failedModel: usageLimitReachedModel,
fallbackModel,
resetTime,
onGetCredits: () => {
onGetCredits: async () => {
logCreditPurchaseClick(
config,
'empty_wallet_menu',
usageLimitReachedModel,
);
void openG1Url('credits', G1_UTM_CAMPAIGNS.EMPTY_WALLET_ADD_CREDITS);
const creditsUrl = await openG1Url(
'credits',
G1_UTM_CAMPAIGNS.EMPTY_WALLET_ADD_CREDITS,
);
if (creditsUrl) {
args.historyManager.addItem(
{
type: MessageType.INFO,
text: `Please open this URL in a browser: ${creditsUrl}`,
},
Date.now(),
);
}
},
resolve,
});
@@ -272,11 +298,16 @@ function logCreditPurchaseClick(
async function openG1Url(
path: 'activity' | 'credits',
campaign: string,
): Promise<void> {
): Promise<string | undefined> {
try {
const userEmail = new UserAccountManager().getCachedGoogleAccount() ?? '';
await openBrowserSecurely(buildG1Url(path, userEmail, campaign));
const url = buildG1Url(path, userEmail, campaign);
if (!shouldLaunchBrowser()) {
return url;
}
await openBrowserSecurely(url);
} catch {
// Ignore browser open errors
}
return undefined;
}
+1 -1
View File
@@ -21,7 +21,7 @@
"dist"
],
"dependencies": {
"@a2a-js/sdk": "^0.3.10",
"@a2a-js/sdk": "0.3.11",
"@bufbuild/protobuf": "^2.11.0",
"@google-cloud/logging": "^11.2.1",
"@google-cloud/opentelemetry-cloud-monitoring-exporter": "^0.21.0",
@@ -140,7 +140,7 @@ describe('A2AClientManager', () => {
expect(createAuthenticatingFetchWithRetry).not.toHaveBeenCalled();
});
it('should use provided custom authentication handler', async () => {
it('should use provided custom authentication handler for transports only', async () => {
const customAuthHandler = {
headers: vi.fn(),
shouldRetryWithHeaders: vi.fn(),
@@ -155,6 +155,66 @@ describe('A2AClientManager', () => {
expect.anything(),
customAuthHandler,
);
// Card resolver should NOT use the authenticated fetch by default.
const resolverInstance = vi.mocked(DefaultAgentCardResolver).mock
.instances[0];
expect(resolverInstance).toBeDefined();
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
.calls[0][0];
expect(resolverOptions?.fetchImpl).not.toBe(authFetchMock);
});
it('should use unauthenticated fetch for card resolver and avoid authenticated fetch if success', async () => {
const customAuthHandler = {
headers: vi.fn(),
shouldRetryWithHeaders: vi.fn(),
};
await manager.loadAgent(
'AuthCardAgent',
'http://authcard.agent/card',
customAuthHandler as unknown as AuthenticationHandler,
);
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
.calls[0][0];
const cardFetch = resolverOptions?.fetchImpl as typeof fetch;
expect(cardFetch).toBeDefined();
await cardFetch('http://test.url');
expect(fetch).toHaveBeenCalledWith('http://test.url', expect.anything());
expect(authFetchMock).not.toHaveBeenCalled();
});
it('should retry with authenticating fetch if agent card fetch returns 401', async () => {
const customAuthHandler = {
headers: vi.fn(),
shouldRetryWithHeaders: vi.fn(),
};
// Mock the initial unauthenticated fetch to fail with 401
vi.mocked(fetch).mockResolvedValueOnce({
ok: false,
status: 401,
json: async () => ({}),
} as Response);
await manager.loadAgent(
'AuthCardAgent401',
'http://authcard.agent/card',
customAuthHandler as unknown as AuthenticationHandler,
);
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
.calls[0][0];
const cardFetch = resolverOptions?.fetchImpl as typeof fetch;
await cardFetch('http://test.url');
expect(fetch).toHaveBeenCalledWith('http://test.url', expect.anything());
expect(authFetchMock).toHaveBeenCalledWith('http://test.url', undefined);
});
it('should log a debug message upon loading an agent', async () => {
+23 -5
View File
@@ -95,19 +95,37 @@ export class A2AClientManager {
throw new Error(`Agent with name '${name}' is already loaded.`);
}
let fetchImpl: typeof fetch = a2aFetch;
// Authenticated fetch for API calls (transports).
let authFetch: typeof fetch = a2aFetch;
if (authHandler) {
fetchImpl = createAuthenticatingFetchWithRetry(a2aFetch, authHandler);
authFetch = createAuthenticatingFetchWithRetry(a2aFetch, authHandler);
}
const resolver = new DefaultAgentCardResolver({ fetchImpl });
// Use unauthenticated fetch for the agent card unless explicitly required.
// Some servers reject unexpected auth headers on the card endpoint (e.g. 400).
const cardFetch = async (
input: RequestInfo | URL,
init?: RequestInit,
): Promise<Response> => {
// Try without auth first
const response = await a2aFetch(input, init);
// Retry with auth if we hit a 401/403
if ((response.status === 401 || response.status === 403) && authFetch) {
return authFetch(input, init);
}
return response;
};
const resolver = new DefaultAgentCardResolver({ fetchImpl: cardFetch });
const options = ClientFactoryOptions.createFrom(
ClientFactoryOptions.default,
{
transports: [
new RestTransportFactory({ fetchImpl }),
new JsonRpcTransportFactory({ fetchImpl }),
new RestTransportFactory({ fetchImpl: authFetch }),
new JsonRpcTransportFactory({ fetchImpl: authFetch }),
],
cardResolver: resolver,
},
@@ -576,5 +576,139 @@ auth:
},
});
});
it('should parse remote agent with oauth2 auth', async () => {
const filePath = await writeAgentMarkdown(`---
kind: remote
name: oauth2-agent
agent_card_url: https://example.com/card
auth:
type: oauth2
client_id: $MY_OAUTH_CLIENT_ID
scopes:
- read
- write
---
`);
const result = await parseAgentMarkdown(filePath);
expect(result).toHaveLength(1);
expect(result[0]).toMatchObject({
kind: 'remote',
name: 'oauth2-agent',
auth: {
type: 'oauth2',
client_id: '$MY_OAUTH_CLIENT_ID',
scopes: ['read', 'write'],
},
});
});
it('should parse remote agent with oauth2 auth including all fields', async () => {
const filePath = await writeAgentMarkdown(`---
kind: remote
name: oauth2-full-agent
agent_card_url: https://example.com/card
auth:
type: oauth2
client_id: my-client-id
client_secret: my-client-secret
scopes:
- openid
- profile
authorization_url: https://auth.example.com/authorize
token_url: https://auth.example.com/token
---
`);
const result = await parseAgentMarkdown(filePath);
expect(result).toHaveLength(1);
expect(result[0]).toMatchObject({
kind: 'remote',
name: 'oauth2-full-agent',
auth: {
type: 'oauth2',
client_id: 'my-client-id',
client_secret: 'my-client-secret',
scopes: ['openid', 'profile'],
authorization_url: 'https://auth.example.com/authorize',
token_url: 'https://auth.example.com/token',
},
});
});
it('should parse remote agent with minimal oauth2 config (type only)', async () => {
const filePath = await writeAgentMarkdown(`---
kind: remote
name: oauth2-minimal-agent
agent_card_url: https://example.com/card
auth:
type: oauth2
---
`);
const result = await parseAgentMarkdown(filePath);
expect(result).toHaveLength(1);
expect(result[0]).toMatchObject({
kind: 'remote',
name: 'oauth2-minimal-agent',
auth: {
type: 'oauth2',
},
});
});
it('should reject oauth2 auth with invalid authorization_url', async () => {
const filePath = await writeAgentMarkdown(`---
kind: remote
name: invalid-oauth2-agent
agent_card_url: https://example.com/card
auth:
type: oauth2
client_id: my-client
authorization_url: not-a-valid-url
---
`);
await expect(parseAgentMarkdown(filePath)).rejects.toThrow(/Invalid url/);
});
it('should reject oauth2 auth with invalid token_url', async () => {
const filePath = await writeAgentMarkdown(`---
kind: remote
name: invalid-oauth2-agent
agent_card_url: https://example.com/card
auth:
type: oauth2
client_id: my-client
token_url: not-a-valid-url
---
`);
await expect(parseAgentMarkdown(filePath)).rejects.toThrow(/Invalid url/);
});
it('should convert oauth2 auth config in markdownToAgentDefinition', () => {
const markdown = {
kind: 'remote' as const,
name: 'oauth2-convert-agent',
agent_card_url: 'https://example.com/card',
auth: {
type: 'oauth2' as const,
client_id: '$MY_CLIENT_ID',
scopes: ['read'],
authorization_url: 'https://auth.example.com/authorize',
token_url: 'https://auth.example.com/token',
},
};
const result = markdownToAgentDefinition(markdown);
expect(result).toMatchObject({
kind: 'remote',
name: 'oauth2-convert-agent',
auth: {
type: 'oauth2',
client_id: '$MY_CLIENT_ID',
scopes: ['read'],
authorization_url: 'https://auth.example.com/authorize',
token_url: 'https://auth.example.com/token',
},
});
});
});
});
+37 -2
View File
@@ -44,7 +44,7 @@ interface FrontmatterLocalAgentDefinition
* Authentication configuration for remote agents in frontmatter format.
*/
interface FrontmatterAuthConfig {
type: 'apiKey' | 'http';
type: 'apiKey' | 'http' | 'oauth2';
agent_card_requires_auth?: boolean;
// API Key
key?: string;
@@ -55,6 +55,12 @@ interface FrontmatterAuthConfig {
username?: string;
password?: string;
value?: string;
// OAuth2
client_id?: string;
client_secret?: string;
scopes?: string[];
authorization_url?: string;
token_url?: string;
}
interface FrontmatterRemoteAgentDefinition
@@ -147,8 +153,26 @@ const httpAuthSchema = z.object({
value: z.string().min(1).optional(),
});
/**
* OAuth2 auth schema.
* authorization_url and token_url can be discovered from the agent card if omitted.
*/
const oauth2AuthSchema = z.object({
...baseAuthFields,
type: z.literal('oauth2'),
client_id: z.string().optional(),
client_secret: z.string().optional(),
scopes: z.array(z.string()).optional(),
authorization_url: z.string().url().optional(),
token_url: z.string().url().optional(),
});
const authConfigSchema = z
.discriminatedUnion('type', [apiKeyAuthSchema, httpAuthSchema])
.discriminatedUnion('type', [
apiKeyAuthSchema,
httpAuthSchema,
oauth2AuthSchema,
])
.superRefine((data, ctx) => {
if (data.type === 'http') {
if (data.value) {
@@ -395,6 +419,17 @@ function convertFrontmatterAuthToConfig(
}
}
case 'oauth2':
return {
...base,
type: 'oauth2',
client_id: frontmatter.client_id,
client_secret: frontmatter.client_secret,
scopes: frontmatter.scopes,
authorization_url: frontmatter.authorization_url,
token_url: frontmatter.token_url,
};
default: {
const exhaustive: never = frontmatter.type;
throw new Error(`Unknown auth type: ${exhaustive}`);
@@ -4,11 +4,22 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { describe, it, expect, vi } from 'vitest';
import { A2AAuthProviderFactory } from './factory.js';
import type { AgentCard, SecurityScheme } from '@a2a-js/sdk';
import type { A2AAuthConfig } from './types.js';
// Mock token storage so OAuth2AuthProvider.initialize() works without disk I/O.
vi.mock('../../mcp/oauth-token-storage.js', () => {
const MCPOAuthTokenStorage = vi.fn().mockImplementation(() => ({
getCredentials: vi.fn().mockResolvedValue(null),
saveToken: vi.fn().mockResolvedValue(undefined),
deleteCredentials: vi.fn().mockResolvedValue(undefined),
isTokenExpired: vi.fn().mockReturnValue(false),
}));
return { MCPOAuthTokenStorage };
});
describe('A2AAuthProviderFactory', () => {
describe('validateAuthConfig', () => {
describe('when no security schemes required', () => {
@@ -492,5 +503,62 @@ describe('A2AAuthProviderFactory', () => {
const headers = await provider!.headers();
expect(headers).toEqual({ 'X-API-Key': 'factory-test-key' });
});
it('should create an OAuth2AuthProvider for oauth2 config', async () => {
const provider = await A2AAuthProviderFactory.create({
agentName: 'my-oauth-agent',
authConfig: {
type: 'oauth2',
client_id: 'my-client',
authorization_url: 'https://auth.example.com/authorize',
token_url: 'https://auth.example.com/token',
scopes: ['read'],
},
});
expect(provider).toBeDefined();
expect(provider!.type).toBe('oauth2');
});
it('should create an OAuth2AuthProvider with agent card defaults', async () => {
const provider = await A2AAuthProviderFactory.create({
agentName: 'card-oauth-agent',
authConfig: {
type: 'oauth2',
client_id: 'my-client',
},
agentCard: {
securitySchemes: {
oauth: {
type: 'oauth2',
flows: {
authorizationCode: {
authorizationUrl: 'https://card.example.com/authorize',
tokenUrl: 'https://card.example.com/token',
scopes: { read: 'Read access' },
},
},
},
},
} as unknown as AgentCard,
});
expect(provider).toBeDefined();
expect(provider!.type).toBe('oauth2');
});
it('should use "unknown" as agent name when agentName is not provided for oauth2', async () => {
const provider = await A2AAuthProviderFactory.create({
authConfig: {
type: 'oauth2',
client_id: 'my-client',
authorization_url: 'https://auth.example.com/authorize',
token_url: 'https://auth.example.com/token',
},
});
expect(provider).toBeDefined();
expect(provider!.type).toBe('oauth2');
});
});
});
@@ -18,6 +18,8 @@ export interface CreateAuthProviderOptions {
agentName?: string;
authConfig?: A2AAuthConfig;
agentCard?: AgentCard;
/** URL to fetch the agent card from, used for OAuth2 URL discovery. */
agentCardUrl?: string;
}
/**
@@ -57,9 +59,20 @@ export class A2AAuthProviderFactory {
return provider;
}
case 'oauth2':
// TODO: Implement
throw new Error('oauth2 auth provider not yet implemented');
case 'oauth2': {
// Dynamic import to avoid pulling MCPOAuthTokenStorage into the
// factory's static module graph, which causes initialization
// conflicts with code_assist/oauth-credential-storage.ts.
const { OAuth2AuthProvider } = await import('./oauth2-provider.js');
const provider = new OAuth2AuthProvider(
authConfig,
options.agentName ?? 'unknown',
agentCard,
options.agentCardUrl,
);
await provider.initialize();
return provider;
}
case 'openIdConnect':
// TODO: Implement
@@ -0,0 +1,651 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { OAuth2AuthProvider } from './oauth2-provider.js';
import type { OAuth2AuthConfig } from './types.js';
import type { AgentCard } from '@a2a-js/sdk';
// Mock DefaultAgentCardResolver from @a2a-js/sdk/client.
const mockResolve = vi.fn();
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
const actual = await importOriginal<typeof import('@a2a-js/sdk/client')>();
return {
...actual,
DefaultAgentCardResolver: vi.fn().mockImplementation(() => ({
resolve: mockResolve,
})),
};
});
// Mock all external dependencies.
vi.mock('../../mcp/oauth-token-storage.js', () => {
const MCPOAuthTokenStorage = vi.fn().mockImplementation(() => ({
getCredentials: vi.fn().mockResolvedValue(null),
saveToken: vi.fn().mockResolvedValue(undefined),
deleteCredentials: vi.fn().mockResolvedValue(undefined),
isTokenExpired: vi.fn().mockReturnValue(false),
}));
return { MCPOAuthTokenStorage };
});
vi.mock('../../utils/oauth-flow.js', () => ({
generatePKCEParams: vi.fn().mockReturnValue({
codeVerifier: 'test-verifier',
codeChallenge: 'test-challenge',
state: 'test-state',
}),
startCallbackServer: vi.fn().mockReturnValue({
port: Promise.resolve(12345),
response: Promise.resolve({ code: 'test-code', state: 'test-state' }),
}),
getPortFromUrl: vi.fn().mockReturnValue(undefined),
buildAuthorizationUrl: vi
.fn()
.mockReturnValue('https://auth.example.com/authorize?foo=bar'),
exchangeCodeForToken: vi.fn().mockResolvedValue({
access_token: 'new-access-token',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'new-refresh-token',
}),
refreshAccessToken: vi.fn().mockResolvedValue({
access_token: 'refreshed-access-token',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'refreshed-refresh-token',
}),
}));
vi.mock('../../utils/secure-browser-launcher.js', () => ({
openBrowserSecurely: vi.fn().mockResolvedValue(undefined),
}));
vi.mock('../../utils/authConsent.js', () => ({
getConsentForOauth: vi.fn().mockResolvedValue(true),
}));
vi.mock('../../utils/events.js', () => ({
coreEvents: {
emitFeedback: vi.fn(),
},
}));
vi.mock('../../utils/debugLogger.js', () => ({
debugLogger: {
debug: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
log: vi.fn(),
},
}));
// Re-import mocked modules for assertions.
const { MCPOAuthTokenStorage } = await import(
'../../mcp/oauth-token-storage.js'
);
const {
refreshAccessToken,
exchangeCodeForToken,
generatePKCEParams,
startCallbackServer,
buildAuthorizationUrl,
} = await import('../../utils/oauth-flow.js');
const { getConsentForOauth } = await import('../../utils/authConsent.js');
function createConfig(
overrides: Partial<OAuth2AuthConfig> = {},
): OAuth2AuthConfig {
return {
type: 'oauth2',
client_id: 'test-client-id',
authorization_url: 'https://auth.example.com/authorize',
token_url: 'https://auth.example.com/token',
scopes: ['read', 'write'],
...overrides,
};
}
function getTokenStorage() {
// Access the mocked MCPOAuthTokenStorage instance created in the constructor.
const instance = vi.mocked(MCPOAuthTokenStorage).mock.results.at(-1)!.value;
return instance as {
getCredentials: ReturnType<typeof vi.fn>;
saveToken: ReturnType<typeof vi.fn>;
deleteCredentials: ReturnType<typeof vi.fn>;
isTokenExpired: ReturnType<typeof vi.fn>;
};
}
describe('OAuth2AuthProvider', () => {
beforeEach(() => {
vi.clearAllMocks();
});
describe('constructor', () => {
it('should set type to oauth2', () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
expect(provider.type).toBe('oauth2');
});
it('should use config values for authorization_url and token_url', () => {
const config = createConfig({
authorization_url: 'https://custom.example.com/authorize',
token_url: 'https://custom.example.com/token',
});
const provider = new OAuth2AuthProvider(config, 'test-agent');
// Verify by calling headers which will trigger interactive flow with these URLs.
expect(provider.type).toBe('oauth2');
});
it('should merge agent card defaults when config values are missing', () => {
const config = createConfig({
authorization_url: undefined,
token_url: undefined,
scopes: undefined,
});
const agentCard = {
securitySchemes: {
oauth: {
type: 'oauth2' as const,
flows: {
authorizationCode: {
authorizationUrl: 'https://card.example.com/authorize',
tokenUrl: 'https://card.example.com/token',
scopes: { read: 'Read access', write: 'Write access' },
},
},
},
},
} as unknown as AgentCard;
const provider = new OAuth2AuthProvider(config, 'test-agent', agentCard);
expect(provider.type).toBe('oauth2');
});
it('should prefer config values over agent card values', async () => {
const config = createConfig({
authorization_url: 'https://config.example.com/authorize',
token_url: 'https://config.example.com/token',
scopes: ['custom-scope'],
});
const agentCard = {
securitySchemes: {
oauth: {
type: 'oauth2' as const,
flows: {
authorizationCode: {
authorizationUrl: 'https://card.example.com/authorize',
tokenUrl: 'https://card.example.com/token',
scopes: { read: 'Read access' },
},
},
},
},
} as unknown as AgentCard;
const provider = new OAuth2AuthProvider(config, 'test-agent', agentCard);
await provider.headers();
// The config URLs should be used, not the agent card ones.
expect(vi.mocked(buildAuthorizationUrl)).toHaveBeenCalledWith(
expect.objectContaining({
authorizationUrl: 'https://config.example.com/authorize',
tokenUrl: 'https://config.example.com/token',
scopes: ['custom-scope'],
}),
expect.anything(),
expect.anything(),
undefined,
);
});
});
describe('initialize', () => {
it('should load a valid token from storage', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
storage.getCredentials.mockResolvedValue({
serverName: 'test-agent',
token: {
accessToken: 'stored-token',
tokenType: 'Bearer',
},
updatedAt: Date.now(),
});
storage.isTokenExpired.mockReturnValue(false);
await provider.initialize();
const headers = await provider.headers();
expect(headers).toEqual({ Authorization: 'Bearer stored-token' });
});
it('should not cache an expired token from storage', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
storage.getCredentials.mockResolvedValue({
serverName: 'test-agent',
token: {
accessToken: 'expired-token',
tokenType: 'Bearer',
expiresAt: Date.now() - 1000,
},
updatedAt: Date.now(),
});
storage.isTokenExpired.mockReturnValue(true);
await provider.initialize();
// Should trigger interactive flow since cached token is null.
const headers = await provider.headers();
expect(headers).toEqual({ Authorization: 'Bearer new-access-token' });
});
it('should handle no stored credentials gracefully', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
storage.getCredentials.mockResolvedValue(null);
await provider.initialize();
// Should trigger interactive flow.
const headers = await provider.headers();
expect(headers).toEqual({ Authorization: 'Bearer new-access-token' });
});
});
describe('headers', () => {
it('should return cached token if valid', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
storage.getCredentials.mockResolvedValue({
serverName: 'test-agent',
token: { accessToken: 'cached-token', tokenType: 'Bearer' },
updatedAt: Date.now(),
});
storage.isTokenExpired.mockReturnValue(false);
await provider.initialize();
const headers = await provider.headers();
expect(headers).toEqual({ Authorization: 'Bearer cached-token' });
expect(vi.mocked(exchangeCodeForToken)).not.toHaveBeenCalled();
expect(vi.mocked(refreshAccessToken)).not.toHaveBeenCalled();
});
it('should refresh token when expired with refresh_token available', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
// First call: load from storage (expired but with refresh token).
storage.getCredentials.mockResolvedValue({
serverName: 'test-agent',
token: {
accessToken: 'expired-token',
tokenType: 'Bearer',
refreshToken: 'my-refresh-token',
expiresAt: Date.now() - 1000,
},
updatedAt: Date.now(),
});
// isTokenExpired: false for initialize (to cache it), true for headers check.
storage.isTokenExpired
.mockReturnValueOnce(false) // initialize: cache the token
.mockReturnValueOnce(true); // headers: token is expired
await provider.initialize();
const headers = await provider.headers();
expect(vi.mocked(refreshAccessToken)).toHaveBeenCalledWith(
expect.objectContaining({ clientId: 'test-client-id' }),
'my-refresh-token',
'https://auth.example.com/token',
);
expect(headers).toEqual({
Authorization: 'Bearer refreshed-access-token',
});
expect(storage.saveToken).toHaveBeenCalled();
});
it('should fall back to interactive flow when refresh fails', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
storage.getCredentials.mockResolvedValue({
serverName: 'test-agent',
token: {
accessToken: 'expired-token',
tokenType: 'Bearer',
refreshToken: 'bad-refresh-token',
expiresAt: Date.now() - 1000,
},
updatedAt: Date.now(),
});
storage.isTokenExpired
.mockReturnValueOnce(false) // initialize
.mockReturnValueOnce(true); // headers
vi.mocked(refreshAccessToken).mockRejectedValueOnce(
new Error('Refresh failed'),
);
await provider.initialize();
const headers = await provider.headers();
// Should have deleted stale credentials and done interactive flow.
expect(storage.deleteCredentials).toHaveBeenCalledWith('test-agent');
expect(headers).toEqual({ Authorization: 'Bearer new-access-token' });
});
it('should trigger interactive flow when no token exists', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
storage.getCredentials.mockResolvedValue(null);
await provider.initialize();
const headers = await provider.headers();
expect(vi.mocked(generatePKCEParams)).toHaveBeenCalled();
expect(vi.mocked(startCallbackServer)).toHaveBeenCalled();
expect(vi.mocked(exchangeCodeForToken)).toHaveBeenCalled();
expect(storage.saveToken).toHaveBeenCalledWith(
'test-agent',
expect.objectContaining({ accessToken: 'new-access-token' }),
'test-client-id',
'https://auth.example.com/token',
);
expect(headers).toEqual({ Authorization: 'Bearer new-access-token' });
});
it('should throw when user declines consent', async () => {
vi.mocked(getConsentForOauth).mockResolvedValueOnce(false);
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
await provider.initialize();
await expect(provider.headers()).rejects.toThrow(
'Authentication cancelled by user',
);
});
it('should throw when client_id is missing', async () => {
const config = createConfig({ client_id: undefined });
const provider = new OAuth2AuthProvider(config, 'test-agent');
await provider.initialize();
await expect(provider.headers()).rejects.toThrow(/requires a client_id/);
});
it('should throw when authorization_url and token_url are missing', async () => {
const config = createConfig({
authorization_url: undefined,
token_url: undefined,
});
const provider = new OAuth2AuthProvider(config, 'test-agent');
await provider.initialize();
await expect(provider.headers()).rejects.toThrow(
/requires authorization_url and token_url/,
);
});
});
describe('shouldRetryWithHeaders', () => {
it('should clear token and re-authenticate on 401', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
storage.getCredentials.mockResolvedValue({
serverName: 'test-agent',
token: { accessToken: 'old-token', tokenType: 'Bearer' },
updatedAt: Date.now(),
});
storage.isTokenExpired.mockReturnValue(false);
await provider.initialize();
const res = new Response(null, { status: 401 });
const retryHeaders = await provider.shouldRetryWithHeaders({}, res);
expect(storage.deleteCredentials).toHaveBeenCalledWith('test-agent');
expect(retryHeaders).toBeDefined();
expect(retryHeaders).toHaveProperty('Authorization');
});
it('should clear token and re-authenticate on 403', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
storage.getCredentials.mockResolvedValue({
serverName: 'test-agent',
token: { accessToken: 'old-token', tokenType: 'Bearer' },
updatedAt: Date.now(),
});
storage.isTokenExpired.mockReturnValue(false);
await provider.initialize();
const res = new Response(null, { status: 403 });
const retryHeaders = await provider.shouldRetryWithHeaders({}, res);
expect(retryHeaders).toBeDefined();
});
it('should return undefined for non-auth errors', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const res = new Response(null, { status: 500 });
const retryHeaders = await provider.shouldRetryWithHeaders({}, res);
expect(retryHeaders).toBeUndefined();
});
it('should respect MAX_AUTH_RETRIES', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const res401 = new Response(null, { status: 401 });
// First retry — should succeed.
const first = await provider.shouldRetryWithHeaders({}, res401);
expect(first).toBeDefined();
// Second retry — should succeed.
const second = await provider.shouldRetryWithHeaders({}, res401);
expect(second).toBeDefined();
// Third retry — should be blocked.
const third = await provider.shouldRetryWithHeaders({}, res401);
expect(third).toBeUndefined();
});
it('should reset retry count on non-auth response', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const res401 = new Response(null, { status: 401 });
const res200 = new Response(null, { status: 200 });
await provider.shouldRetryWithHeaders({}, res401);
await provider.shouldRetryWithHeaders({}, res200); // resets
// Should be able to retry again.
const result = await provider.shouldRetryWithHeaders({}, res401);
expect(result).toBeDefined();
});
});
describe('token persistence', () => {
it('should persist token after successful interactive auth', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
await provider.initialize();
await provider.headers();
expect(storage.saveToken).toHaveBeenCalledWith(
'test-agent',
expect.objectContaining({
accessToken: 'new-access-token',
tokenType: 'Bearer',
refreshToken: 'new-refresh-token',
}),
'test-client-id',
'https://auth.example.com/token',
);
});
it('should persist token after successful refresh', async () => {
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
const storage = getTokenStorage();
storage.getCredentials.mockResolvedValue({
serverName: 'test-agent',
token: {
accessToken: 'expired-token',
tokenType: 'Bearer',
refreshToken: 'my-refresh-token',
},
updatedAt: Date.now(),
});
storage.isTokenExpired
.mockReturnValueOnce(false)
.mockReturnValueOnce(true);
await provider.initialize();
await provider.headers();
expect(storage.saveToken).toHaveBeenCalledWith(
'test-agent',
expect.objectContaining({
accessToken: 'refreshed-access-token',
}),
'test-client-id',
'https://auth.example.com/token',
);
});
});
describe('agent card integration', () => {
it('should discover URLs from agent card when not in config', async () => {
const config = createConfig({
authorization_url: undefined,
token_url: undefined,
scopes: undefined,
});
const agentCard = {
securitySchemes: {
myOauth: {
type: 'oauth2' as const,
flows: {
authorizationCode: {
authorizationUrl: 'https://card.example.com/auth',
tokenUrl: 'https://card.example.com/token',
scopes: { profile: 'View profile', email: 'View email' },
},
},
},
},
} as unknown as AgentCard;
const provider = new OAuth2AuthProvider(config, 'card-agent', agentCard);
await provider.initialize();
await provider.headers();
expect(vi.mocked(buildAuthorizationUrl)).toHaveBeenCalledWith(
expect.objectContaining({
authorizationUrl: 'https://card.example.com/auth',
tokenUrl: 'https://card.example.com/token',
scopes: ['profile', 'email'],
}),
expect.anything(),
expect.anything(),
undefined,
);
});
it('should discover URLs from agentCardUrl via DefaultAgentCardResolver during initialize', async () => {
const config = createConfig({
authorization_url: undefined,
token_url: undefined,
scopes: undefined,
});
// Simulate a normalized agent card returned by DefaultAgentCardResolver.
mockResolve.mockResolvedValue({
securitySchemes: {
myOauth: {
type: 'oauth2' as const,
flows: {
authorizationCode: {
authorizationUrl: 'https://discovered.example.com/auth',
tokenUrl: 'https://discovered.example.com/token',
scopes: { openid: 'OpenID', profile: 'Profile' },
},
},
},
},
} as unknown as AgentCard);
// No agentCard passed to constructor — only agentCardUrl.
const provider = new OAuth2AuthProvider(
config,
'discover-agent',
undefined,
'https://example.com/.well-known/agent-card.json',
);
await provider.initialize();
await provider.headers();
expect(mockResolve).toHaveBeenCalledWith(
'https://example.com/.well-known/agent-card.json',
'',
);
expect(vi.mocked(buildAuthorizationUrl)).toHaveBeenCalledWith(
expect.objectContaining({
authorizationUrl: 'https://discovered.example.com/auth',
tokenUrl: 'https://discovered.example.com/token',
scopes: ['openid', 'profile'],
}),
expect.anything(),
expect.anything(),
undefined,
);
});
it('should ignore agent card with no authorizationCode flow', () => {
const config = createConfig({
authorization_url: undefined,
token_url: undefined,
});
const agentCard = {
securitySchemes: {
myOauth: {
type: 'oauth2' as const,
flows: {
clientCredentials: {
tokenUrl: 'https://card.example.com/token',
scopes: {},
},
},
},
},
} as unknown as AgentCard;
// Should not throw — just won't have URLs.
const provider = new OAuth2AuthProvider(config, 'card-agent', agentCard);
expect(provider.type).toBe('oauth2');
});
});
});
@@ -0,0 +1,340 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { type HttpHeaders, DefaultAgentCardResolver } from '@a2a-js/sdk/client';
import type { AgentCard } from '@a2a-js/sdk';
import { BaseA2AAuthProvider } from './base-provider.js';
import type { OAuth2AuthConfig } from './types.js';
import { MCPOAuthTokenStorage } from '../../mcp/oauth-token-storage.js';
import type { OAuthToken } from '../../mcp/token-storage/types.js';
import {
generatePKCEParams,
startCallbackServer,
getPortFromUrl,
buildAuthorizationUrl,
exchangeCodeForToken,
refreshAccessToken,
type OAuthFlowConfig,
} from '../../utils/oauth-flow.js';
import { openBrowserSecurely } from '../../utils/secure-browser-launcher.js';
import { getConsentForOauth } from '../../utils/authConsent.js';
import { FatalCancellationError, getErrorMessage } from '../../utils/errors.js';
import { coreEvents } from '../../utils/events.js';
import { debugLogger } from '../../utils/debugLogger.js';
import { Storage } from '../../config/storage.js';
/**
* Authentication provider for OAuth 2.0 Authorization Code flow with PKCE.
*
* Used by A2A remote agents whose security scheme is `oauth2`.
* Reuses the shared OAuth flow primitives from `utils/oauth-flow.ts`
* and persists tokens via `MCPOAuthTokenStorage`.
*/
export class OAuth2AuthProvider extends BaseA2AAuthProvider {
readonly type = 'oauth2' as const;
private readonly tokenStorage: MCPOAuthTokenStorage;
private cachedToken: OAuthToken | null = null;
/** Resolved OAuth URLs — may come from config or agent card. */
private authorizationUrl: string | undefined;
private tokenUrl: string | undefined;
private scopes: string[] | undefined;
constructor(
private readonly config: OAuth2AuthConfig,
private readonly agentName: string,
agentCard?: AgentCard,
private readonly agentCardUrl?: string,
) {
super();
this.tokenStorage = new MCPOAuthTokenStorage(
Storage.getA2AOAuthTokensPath(),
'gemini-cli-a2a',
);
// Seed from user config.
this.authorizationUrl = config.authorization_url;
this.tokenUrl = config.token_url;
this.scopes = config.scopes;
// Fall back to agent card's OAuth2 security scheme if user config is incomplete.
this.mergeAgentCardDefaults(agentCard);
}
/**
* Initialize the provider by loading any persisted token from storage.
* Also discovers OAuth URLs from the agent card if not yet resolved.
*/
override async initialize(): Promise<void> {
// If OAuth URLs are still missing, fetch the agent card to discover them.
if ((!this.authorizationUrl || !this.tokenUrl) && this.agentCardUrl) {
await this.fetchAgentCardDefaults();
}
const credentials = await this.tokenStorage.getCredentials(this.agentName);
if (credentials && !this.tokenStorage.isTokenExpired(credentials.token)) {
this.cachedToken = credentials.token;
debugLogger.debug(
`[OAuth2AuthProvider] Loaded valid cached token for "${this.agentName}"`,
);
}
}
/**
* Return an Authorization header with a valid Bearer token.
* Refreshes or triggers interactive auth as needed.
*/
override async headers(): Promise<HttpHeaders> {
// 1. Valid cached token → return immediately.
if (
this.cachedToken &&
!this.tokenStorage.isTokenExpired(this.cachedToken)
) {
return { Authorization: `Bearer ${this.cachedToken.accessToken}` };
}
// 2. Expired but has refresh token → attempt silent refresh.
if (
this.cachedToken?.refreshToken &&
this.tokenUrl &&
this.config.client_id
) {
try {
const refreshed = await refreshAccessToken(
{
clientId: this.config.client_id,
clientSecret: this.config.client_secret,
scopes: this.scopes,
},
this.cachedToken.refreshToken,
this.tokenUrl,
);
this.cachedToken = this.toOAuthToken(
refreshed,
this.cachedToken.refreshToken,
);
await this.persistToken();
return { Authorization: `Bearer ${this.cachedToken.accessToken}` };
} catch (error) {
debugLogger.debug(
`[OAuth2AuthProvider] Refresh failed, falling back to interactive flow: ${getErrorMessage(error)}`,
);
// Clear stale credentials and fall through to interactive flow.
await this.tokenStorage.deleteCredentials(this.agentName);
}
}
// 3. No valid token → interactive browser-based auth.
this.cachedToken = await this.authenticateInteractively();
return { Authorization: `Bearer ${this.cachedToken.accessToken}` };
}
/**
* On 401/403, clear the cached token and re-authenticate (up to MAX_AUTH_RETRIES).
*/
override async shouldRetryWithHeaders(
_req: RequestInit,
res: Response,
): Promise<HttpHeaders | undefined> {
if (res.status !== 401 && res.status !== 403) {
this.authRetryCount = 0;
return undefined;
}
if (this.authRetryCount >= BaseA2AAuthProvider.MAX_AUTH_RETRIES) {
return undefined;
}
this.authRetryCount++;
debugLogger.debug(
'[OAuth2AuthProvider] Auth failure, clearing token and re-authenticating',
);
this.cachedToken = null;
await this.tokenStorage.deleteCredentials(this.agentName);
return this.headers();
}
// ---------------------------------------------------------------------------
// Private helpers
// ---------------------------------------------------------------------------
/**
* Merge authorization_url, token_url, and scopes from the agent card's
* `securitySchemes` when not already provided via user config.
*/
private mergeAgentCardDefaults(
agentCard?: Pick<AgentCard, 'securitySchemes'> | null,
): void {
if (!agentCard?.securitySchemes) return;
for (const scheme of Object.values(agentCard.securitySchemes)) {
if (scheme.type === 'oauth2' && scheme.flows.authorizationCode) {
const flow = scheme.flows.authorizationCode;
this.authorizationUrl ??= flow.authorizationUrl;
this.tokenUrl ??= flow.tokenUrl;
this.scopes ??= Object.keys(flow.scopes);
break; // Use the first matching scheme.
}
}
}
/**
* Fetch the agent card from `agentCardUrl` using `DefaultAgentCardResolver`
* (which normalizes proto-format cards) and extract OAuth2 URLs.
*/
private async fetchAgentCardDefaults(): Promise<void> {
if (!this.agentCardUrl) return;
try {
debugLogger.debug(
`[OAuth2AuthProvider] Fetching agent card from ${this.agentCardUrl}`,
);
const resolver = new DefaultAgentCardResolver();
const card = await resolver.resolve(this.agentCardUrl, '');
this.mergeAgentCardDefaults(card);
} catch (error) {
debugLogger.warn(
`[OAuth2AuthProvider] Could not fetch agent card for OAuth URL discovery: ${getErrorMessage(error)}`,
);
}
}
/**
* Run a full OAuth 2.0 Authorization Code + PKCE flow through the browser.
*/
private async authenticateInteractively(): Promise<OAuthToken> {
if (!this.config.client_id) {
throw new Error(
`OAuth2 authentication for agent "${this.agentName}" requires a client_id. ` +
'Add client_id to the auth config in your agent definition.',
);
}
if (!this.authorizationUrl || !this.tokenUrl) {
throw new Error(
`OAuth2 authentication for agent "${this.agentName}" requires authorization_url and token_url. ` +
'Provide them in the auth config or ensure the agent card exposes an oauth2 security scheme.',
);
}
const flowConfig: OAuthFlowConfig = {
clientId: this.config.client_id,
clientSecret: this.config.client_secret,
authorizationUrl: this.authorizationUrl,
tokenUrl: this.tokenUrl,
scopes: this.scopes,
};
const pkceParams = generatePKCEParams();
const preferredPort = getPortFromUrl(flowConfig.redirectUri);
const callbackServer = startCallbackServer(pkceParams.state, preferredPort);
const redirectPort = await callbackServer.port;
const authUrl = buildAuthorizationUrl(
flowConfig,
pkceParams,
redirectPort,
/* resource= */ undefined, // No MCP resource parameter for A2A.
);
const consent = await getConsentForOauth(
`Authentication required for A2A agent: '${this.agentName}'.`,
);
if (!consent) {
throw new FatalCancellationError('Authentication cancelled by user.');
}
coreEvents.emitFeedback(
'info',
`→ Opening your browser for OAuth sign-in...
` +
`If the browser does not open, copy and paste this URL into your browser:
` +
`${authUrl}
` +
`💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser.
` +
`⚠️ Make sure to copy the COMPLETE URL - it may wrap across multiple lines.`,
);
try {
await openBrowserSecurely(authUrl);
} catch (error) {
debugLogger.warn(
'Failed to open browser automatically:',
getErrorMessage(error),
);
}
const { code } = await callbackServer.response;
debugLogger.debug(
'✓ Authorization code received, exchanging for tokens...',
);
const tokenResponse = await exchangeCodeForToken(
flowConfig,
code,
pkceParams.codeVerifier,
redirectPort,
/* resource= */ undefined,
);
if (!tokenResponse.access_token) {
throw new Error('No access token received from token endpoint');
}
const token = this.toOAuthToken(tokenResponse);
this.cachedToken = token;
await this.persistToken();
debugLogger.debug('✓ OAuth2 authentication successful! Token saved.');
return token;
}
/**
* Convert an `OAuthTokenResponse` into the internal `OAuthToken` format.
*/
private toOAuthToken(
response: {
access_token: string;
token_type?: string;
expires_in?: number;
refresh_token?: string;
scope?: string;
},
fallbackRefreshToken?: string,
): OAuthToken {
const token: OAuthToken = {
accessToken: response.access_token,
tokenType: response.token_type || 'Bearer',
refreshToken: response.refresh_token || fallbackRefreshToken,
scope: response.scope,
};
if (response.expires_in) {
token.expiresAt = Date.now() + response.expires_in * 1000;
}
return token;
}
/**
* Persist the current cached token to disk.
*/
private async persistToken(): Promise<void> {
if (!this.cachedToken) return;
await this.tokenStorage.saveToken(
this.agentName,
this.cachedToken,
this.config.client_id,
this.tokenUrl,
);
}
}
@@ -74,6 +74,10 @@ export interface OAuth2AuthConfig extends BaseAuthConfig {
client_id?: string;
client_secret?: string;
scopes?: string[];
/** Override or provide the authorization endpoint URL. Discovered from agent card if omitted. */
authorization_url?: string;
/** Override or provide the token endpoint URL. Discovered from agent card if omitted. */
token_url?: string;
}
/** Client config corresponding to OpenIdConnectSecurityScheme. */
@@ -591,6 +591,7 @@ describe('AgentRegistry', () => {
expect(A2AAuthProviderFactory.create).toHaveBeenCalledWith({
authConfig: mockAuth,
agentName: 'RemoteAgentWithAuth',
agentCardUrl: 'https://example.com/card',
});
expect(loadAgentSpy).toHaveBeenCalledWith(
'RemoteAgentWithAuth',
+1
View File
@@ -416,6 +416,7 @@ export class AgentRegistry {
const provider = await A2AAuthProviderFactory.create({
authConfig: definition.auth,
agentName: definition.name,
agentCardUrl: remoteDef.agentCardUrl,
});
if (!provider) {
throw new Error(
@@ -195,6 +195,7 @@ describe('RemoteAgentInvocation', () => {
expect(A2AAuthProviderFactory.create).toHaveBeenCalledWith({
authConfig: mockAuth,
agentName: 'test-agent',
agentCardUrl: 'http://test-agent/card',
});
expect(mockClientManager.loadAgent).toHaveBeenCalledWith(
'test-agent',
@@ -120,6 +120,7 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
const provider = await A2AAuthProviderFactory.create({
authConfig: this.definition.auth,
agentName: this.definition.name,
agentCardUrl: this.definition.agentCardUrl,
});
if (!provider) {
throw new Error(
+4
View File
@@ -62,6 +62,10 @@ export class Storage {
return path.join(Storage.getGlobalGeminiDir(), 'mcp-oauth-tokens.json');
}
static getA2AOAuthTokensPath(): string {
return path.join(Storage.getGlobalGeminiDir(), 'a2a-oauth-tokens.json');
}
static getGlobalSettingsPath(): string {
return path.join(Storage.getGlobalGeminiDir(), 'settings.json');
}
@@ -1154,6 +1154,7 @@ describe('GeminiChat', () => {
1,
);
expect(mockLogContentRetry).not.toHaveBeenCalled();
expect(mockLogContentRetryFailure).toHaveBeenCalledTimes(1);
});
it('should yield a RETRY event when an invalid stream is encountered', async () => {
+12 -19
View File
@@ -344,8 +344,6 @@ export class GeminiChat {
this: GeminiChat,
): AsyncGenerator<StreamEvent, void, void> {
try {
let lastError: unknown = new Error('Request failed after all retries.');
const maxAttempts = INVALID_CONTENT_RETRY_OPTIONS.maxAttempts;
for (let attempt = 0; attempt < maxAttempts; attempt++) {
@@ -374,15 +372,13 @@ export class GeminiChat {
yield { type: StreamEventType.CHUNK, value: chunk };
}
lastError = null;
break;
return;
} catch (error) {
if (error instanceof AgentExecutionStoppedError) {
yield {
type: StreamEventType.AGENT_EXECUTION_STOPPED,
reason: error.reason,
};
lastError = null; // Clear error as this is an expected stop
return; // Stop the generator
}
@@ -397,7 +393,6 @@ export class GeminiChat {
value: error.syntheticResponse,
};
}
lastError = null; // Clear error as this is an expected stop
return; // Stop the generator
}
@@ -415,8 +410,9 @@ export class GeminiChat {
}
// Fall through to retry logic for retryable connection errors
}
lastError = error;
const isContentError = error instanceof InvalidStreamError;
const errorType = isContentError ? error.type : 'NETWORK_ERROR';
if (
(isContentError && isGemini2Model(model)) ||
@@ -425,11 +421,10 @@ export class GeminiChat {
// Check if we have more attempts left.
if (attempt < maxAttempts - 1) {
const delayMs = INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs;
const retryType = isContentError ? error.type : 'NETWORK_ERROR';
logContentRetry(
this.config,
new ContentRetryEvent(attempt, retryType, delayMs, model),
new ContentRetryEvent(attempt, errorType, delayMs, model),
);
coreEvents.emitRetryAttempt({
attempt: attempt + 1,
@@ -444,21 +439,19 @@ export class GeminiChat {
continue;
}
}
break;
}
}
if (lastError) {
if (
lastError instanceof InvalidStreamError &&
isGemini2Model(model)
) {
// If we've aborted, we throw without logging a failure.
if (signal.aborted) {
throw error;
}
logContentRetryFailure(
this.config,
new ContentRetryFailureEvent(maxAttempts, lastError.type, model),
new ContentRetryFailureEvent(attempt + 1, errorType, model),
);
throw error;
}
throw lastError;
}
} finally {
streamDoneResolver!();
@@ -401,6 +401,7 @@ describe('GeminiChat Network Retries', () => {
// Should only be called once (no retry)
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(1);
expect(mockLogContentRetryFailure).not.toHaveBeenCalled();
});
it('should retry on SSL error during stream iteration (mid-stream failure)', async () => {
@@ -44,6 +44,7 @@ vi.mock('../telemetry/index.js', () => ({
}));
vi.mock('../utils/secure-browser-launcher.js', () => ({
openBrowserSecurely: vi.fn(),
shouldLaunchBrowser: vi.fn().mockReturnValue(true),
}));
// Mock debugLogger to prevent console pollution and allow spying
+10 -1
View File
@@ -6,7 +6,10 @@
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
import {
openBrowserSecurely,
shouldLaunchBrowser,
} from '../utils/secure-browser-launcher.js';
import { debugLogger } from '../utils/debugLogger.js';
import { getErrorMessage } from '../utils/errors.js';
import type { FallbackIntent, FallbackRecommendation } from './types.js';
@@ -112,6 +115,12 @@ export async function handleFallback(
}
async function handleUpgrade() {
if (!shouldLaunchBrowser()) {
debugLogger.log(
`Cannot open browser in this environment. Please visit: ${UPGRADE_URL_PAGE}`,
);
return;
}
try {
await openBrowserSecurely(UPGRADE_URL_PAGE);
} catch (error) {
+14 -5
View File
@@ -21,14 +21,23 @@ import {
} from './token-storage/index.js';
/**
* Class for managing MCP OAuth token storage and retrieval.
* Class for managing OAuth token storage and retrieval.
* Used by both MCP and A2A OAuth providers. Pass a custom `tokenFilePath`
* to store tokens in a protocol-specific file.
*/
export class MCPOAuthTokenStorage implements TokenStorage {
private readonly hybridTokenStorage = new HybridTokenStorage(
DEFAULT_SERVICE_NAME,
);
private readonly hybridTokenStorage: HybridTokenStorage;
private readonly useEncryptedFile =
process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
private readonly customTokenFilePath?: string;
constructor(
tokenFilePath?: string,
serviceName: string = DEFAULT_SERVICE_NAME,
) {
this.customTokenFilePath = tokenFilePath;
this.hybridTokenStorage = new HybridTokenStorage(serviceName);
}
/**
* Get the path to the token storage file.
@@ -36,7 +45,7 @@ export class MCPOAuthTokenStorage implements TokenStorage {
* @returns The full path to the token storage file
*/
private getTokenFilePath(): string {
return Storage.getMcpOAuthTokensPath();
return this.customTokenFilePath ?? Storage.getMcpOAuthTokensPath();
}
/**
@@ -195,6 +195,9 @@ describe('ClearcutLogger', () => {
vi.stubEnv('MONOSPACE_ENV', '');
vi.stubEnv('REPLIT_USER', '');
vi.stubEnv('__COG_BASHRC_SOURCED', '');
vi.stubEnv('GH_PR_NUMBER', '');
vi.stubEnv('GH_ISSUE_NUMBER', '');
vi.stubEnv('GH_CUSTOM_TRACKING_ID', '');
});
function setup({
@@ -596,6 +599,110 @@ describe('ClearcutLogger', () => {
});
});
describe('GITHUB_EVENT_NAME metadata', () => {
it('includes event name when GITHUB_EVENT_NAME is set', () => {
const { logger } = setup({});
vi.stubEnv('GITHUB_EVENT_NAME', 'issues');
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
expect(event?.event_metadata[0]).toContainEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_EVENT_NAME,
value: 'issues',
});
});
it('does not include event name when GITHUB_EVENT_NAME is not set', () => {
const { logger } = setup({});
vi.stubEnv('GITHUB_EVENT_NAME', undefined);
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
const hasEventName = event?.event_metadata[0].some(
(item) =>
item.gemini_cli_key === EventMetadataKey.GEMINI_CLI_GH_EVENT_NAME,
);
expect(hasEventName).toBe(false);
});
});
describe('GH_PR_NUMBER metadata', () => {
it('includes PR number when GH_PR_NUMBER is set', () => {
vi.stubEnv('GH_PR_NUMBER', '123');
const { logger } = setup({});
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
expect(event?.event_metadata[0]).toContainEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_PR_NUMBER,
value: '123',
});
});
it('does not include PR number when GH_PR_NUMBER is not set', () => {
vi.stubEnv('GH_PR_NUMBER', undefined);
const { logger } = setup({});
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
const hasPRNumber = event?.event_metadata[0].some(
(item) =>
item.gemini_cli_key === EventMetadataKey.GEMINI_CLI_GH_PR_NUMBER,
);
expect(hasPRNumber).toBe(false);
});
});
describe('GH_ISSUE_NUMBER metadata', () => {
it('includes issue number when GH_ISSUE_NUMBER is set', () => {
vi.stubEnv('GH_ISSUE_NUMBER', '456');
const { logger } = setup({});
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
expect(event?.event_metadata[0]).toContainEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_ISSUE_NUMBER,
value: '456',
});
});
it('does not include issue number when GH_ISSUE_NUMBER is not set', () => {
vi.stubEnv('GH_ISSUE_NUMBER', undefined);
const { logger } = setup({});
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
const hasIssueNumber = event?.event_metadata[0].some(
(item) =>
item.gemini_cli_key === EventMetadataKey.GEMINI_CLI_GH_ISSUE_NUMBER,
);
expect(hasIssueNumber).toBe(false);
});
});
describe('GH_CUSTOM_TRACKING_ID metadata', () => {
it('includes custom tracking ID when GH_CUSTOM_TRACKING_ID is set', () => {
vi.stubEnv('GH_CUSTOM_TRACKING_ID', 'abc-789');
const { logger } = setup({});
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
expect(event?.event_metadata[0]).toContainEqual({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_CUSTOM_TRACKING_ID,
value: 'abc-789',
});
});
it('does not include custom tracking ID when GH_CUSTOM_TRACKING_ID is not set', () => {
vi.stubEnv('GH_CUSTOM_TRACKING_ID', undefined);
const { logger } = setup({});
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
const hasTrackingId = event?.event_metadata[0].some(
(item) =>
item.gemini_cli_key ===
EventMetadataKey.GEMINI_CLI_GH_CUSTOM_TRACKING_ID,
);
expect(hasTrackingId).toBe(false);
});
});
describe('GITHUB_REPOSITORY metadata', () => {
it('includes hashed repository when GITHUB_REPOSITORY is set', () => {
vi.stubEnv('GITHUB_REPOSITORY', 'google/gemini-cli');
@@ -190,6 +190,34 @@ function determineGHRepositoryName(): string | undefined {
return process.env['GITHUB_REPOSITORY'];
}
/**
* Determines the GitHub event name if the CLI is running in a GitHub Actions environment.
*/
function determineGHEventName(): string | undefined {
return process.env['GITHUB_EVENT_NAME'];
}
/**
* Determines the GitHub Pull Request number if the CLI is running in a GitHub Actions environment.
*/
function determineGHPRNumber(): string | undefined {
return process.env['GH_PR_NUMBER'];
}
/**
* Determines the GitHub Issue number if the CLI is running in a GitHub Actions environment.
*/
function determineGHIssueNumber(): string | undefined {
return process.env['GH_ISSUE_NUMBER'];
}
/**
* Determines the GitHub custom tracking ID if the CLI is running in a GitHub Actions environment.
*/
function determineGHCustomTrackingId(): string | undefined {
return process.env['GH_CUSTOM_TRACKING_ID'];
}
/**
* Clearcut URL to send logging events to.
*/
@@ -372,6 +400,10 @@ export class ClearcutLogger {
const email = this.userAccountManager.getCachedGoogleAccount();
const surface = determineSurface();
const ghWorkflowName = determineGHWorkflowName();
const ghEventName = determineGHEventName();
const ghPRNumber = determineGHPRNumber();
const ghIssueNumber = determineGHIssueNumber();
const ghCustomTrackingId = determineGHCustomTrackingId();
const baseMetadata: EventValue[] = [
...data,
{
@@ -406,6 +438,34 @@ export class ClearcutLogger {
});
}
if (ghEventName) {
baseMetadata.push({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_EVENT_NAME,
value: ghEventName,
});
}
if (ghPRNumber) {
baseMetadata.push({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_PR_NUMBER,
value: ghPRNumber,
});
}
if (ghIssueNumber) {
baseMetadata.push({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_ISSUE_NUMBER,
value: ghIssueNumber,
});
}
if (ghCustomTrackingId) {
baseMetadata.push({
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_CUSTOM_TRACKING_ID,
value: ghCustomTrackingId,
});
}
const logEvent: LogEvent = {
console_type: 'GEMINI_CLI',
application: 102, // GEMINI_CLI
@@ -7,7 +7,7 @@
// Defines valid event metadata keys for Clearcut logging.
export enum EventMetadataKey {
// Deleted enums: 24
// Next ID: 176
// Next ID: 180
GEMINI_CLI_KEY_UNKNOWN = 0,
@@ -231,6 +231,18 @@ export enum EventMetadataKey {
// Logs the repository name of the GitHub Action that triggered the session.
GEMINI_CLI_GH_REPOSITORY_NAME_HASH = 132,
// Logs the event name of the GitHub Action that triggered the session.
GEMINI_CLI_GH_EVENT_NAME = 176,
// Logs the Pull Request number if the workflow is operating on a PR.
GEMINI_CLI_GH_PR_NUMBER = 177,
// Logs the Issue number if the workflow is operating on an Issue.
GEMINI_CLI_GH_ISSUE_NUMBER = 178,
// Logs a custom tracking string (e.g. a comma-separated list of issue IDs for scheduled batches).
GEMINI_CLI_GH_CUSTOM_TRACKING_ID = 179,
// ==========================================================================
// Loop Detected Event Keys
// ===========================================================================
@@ -134,21 +134,21 @@ describe('classifyGoogleError', () => {
expect((result as TerminalQuotaError).cause).toBe(apiError);
});
it('should return RetryableQuotaError for long retry delays', () => {
it('should return TerminalQuotaError for retry delays over 5 minutes', () => {
const apiError: GoogleApiError = {
code: 429,
message: 'Too many requests',
details: [
{
'@type': 'type.googleapis.com/google.rpc.RetryInfo',
retryDelay: '301s', // Any delay is now retryable
retryDelay: '301s', // Over 5 min threshold => terminal
},
],
};
vi.spyOn(errorParser, 'parseGoogleApiError').mockReturnValue(apiError);
const result = classifyGoogleError(new Error());
expect(result).toBeInstanceOf(RetryableQuotaError);
expect((result as RetryableQuotaError).retryDelayMs).toBe(301000);
expect(result).toBeInstanceOf(TerminalQuotaError);
expect((result as TerminalQuotaError).retryDelayMs).toBe(301000);
});
it('should return RetryableQuotaError for short retry delays', () => {
@@ -285,6 +285,34 @@ describe('classifyGoogleError', () => {
);
});
it('should return TerminalQuotaError for Cloud Code RATE_LIMIT_EXCEEDED with retry delay over 5 minutes', () => {
const apiError: GoogleApiError = {
code: 429,
message:
'You have exhausted your capacity on this model. Your quota will reset after 10m.',
details: [
{
'@type': 'type.googleapis.com/google.rpc.ErrorInfo',
reason: 'RATE_LIMIT_EXCEEDED',
domain: 'cloudcode-pa.googleapis.com',
metadata: {
uiMessage: 'true',
model: 'gemini-2.5-pro',
},
},
{
'@type': 'type.googleapis.com/google.rpc.RetryInfo',
retryDelay: '600s',
},
],
};
vi.spyOn(errorParser, 'parseGoogleApiError').mockReturnValue(apiError);
const result = classifyGoogleError(new Error());
expect(result).toBeInstanceOf(TerminalQuotaError);
expect((result as TerminalQuotaError).retryDelayMs).toBe(600000);
expect((result as TerminalQuotaError).reason).toBe('RATE_LIMIT_EXCEEDED');
});
it('should return TerminalQuotaError for Cloud Code QUOTA_EXHAUSTED', () => {
const apiError: GoogleApiError = {
code: 429,
@@ -427,6 +455,40 @@ describe('classifyGoogleError', () => {
}
});
it('should return TerminalQuotaError when fallback "Please retry in" delay exceeds 5 minutes', () => {
const errorWithEmptyDetails = {
error: {
code: 429,
message: 'Resource exhausted. Please retry in 400s',
details: [],
},
};
const result = classifyGoogleError(errorWithEmptyDetails);
expect(result).toBeInstanceOf(TerminalQuotaError);
if (result instanceof TerminalQuotaError) {
expect(result.retryDelayMs).toBe(400000);
}
});
it('should return RetryableQuotaError when retry delay is exactly 5 minutes', () => {
const apiError: GoogleApiError = {
code: 429,
message: 'Too many requests',
details: [
{
'@type': 'type.googleapis.com/google.rpc.RetryInfo',
retryDelay: '300s',
},
],
};
vi.spyOn(errorParser, 'parseGoogleApiError').mockReturnValue(apiError);
const result = classifyGoogleError(new Error());
expect(result).toBeInstanceOf(RetryableQuotaError);
expect((result as RetryableQuotaError).retryDelayMs).toBe(300000);
});
it('should return RetryableQuotaError without delay time for generic 429 without specific message', () => {
const generic429 = {
status: 429,
+33 -10
View File
@@ -100,6 +100,13 @@ function parseDurationInSeconds(duration: string): number | null {
return null;
}
/**
* Maximum retry delay (in seconds) before a retryable error is treated as terminal.
* If the server suggests waiting longer than this, the user is effectively locked out,
* so we trigger the fallback/credits flow instead of silently waiting.
*/
const MAX_RETRYABLE_DELAY_SECONDS = 300; // 5 minutes
/**
* Valid Cloud Code API domains for VALIDATION_REQUIRED errors.
*/
@@ -248,15 +255,15 @@ export function classifyGoogleError(error: unknown): unknown {
if (match?.[1]) {
const retryDelaySeconds = parseDurationInSeconds(match[1]);
if (retryDelaySeconds !== null) {
return new RetryableQuotaError(
errorMessage,
googleApiError ?? {
code: status ?? 429,
message: errorMessage,
details: [],
},
retryDelaySeconds,
);
const cause = googleApiError ?? {
code: status ?? 429,
message: errorMessage,
details: [],
};
if (retryDelaySeconds > MAX_RETRYABLE_DELAY_SECONDS) {
return new TerminalQuotaError(errorMessage, cause, retryDelaySeconds);
}
return new RetryableQuotaError(errorMessage, cause, retryDelaySeconds);
}
} else if (status === 429 || status === 499) {
// Fallback: If it is a 429 or 499 but doesn't have a specific "retry in" message,
@@ -325,10 +332,19 @@ export function classifyGoogleError(error: unknown): unknown {
if (errorInfo.domain) {
if (isCloudCodeDomain(errorInfo.domain)) {
if (errorInfo.reason === 'RATE_LIMIT_EXCEEDED') {
const effectiveDelay = delaySeconds ?? 10;
if (effectiveDelay > MAX_RETRYABLE_DELAY_SECONDS) {
return new TerminalQuotaError(
`${googleApiError.message}`,
googleApiError,
effectiveDelay,
errorInfo.reason,
);
}
return new RetryableQuotaError(
`${googleApiError.message}`,
googleApiError,
delaySeconds ?? 10,
effectiveDelay,
);
}
if (errorInfo.reason === 'QUOTA_EXHAUSTED') {
@@ -345,6 +361,13 @@ export function classifyGoogleError(error: unknown): unknown {
// 2. Check for delays in RetryInfo
if (retryInfo?.retryDelay && delaySeconds) {
if (delaySeconds > MAX_RETRYABLE_DELAY_SECONDS) {
return new TerminalQuotaError(
`${googleApiError.message}\nSuggested retry after ${retryInfo.retryDelay}.`,
googleApiError,
delaySeconds,
);
}
return new RetryableQuotaError(
`${googleApiError.message}\nSuggested retry after ${retryInfo.retryDelay}.`,
googleApiError,