mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-23 03:24:42 -07:00
Merge branch 'main' into restart-resume
This commit is contained in:
@@ -25,7 +25,7 @@
|
||||
"dist"
|
||||
],
|
||||
"dependencies": {
|
||||
"@a2a-js/sdk": "^0.3.8",
|
||||
"@a2a-js/sdk": "0.3.11",
|
||||
"@google-cloud/storage": "^7.16.0",
|
||||
"@google/gemini-cli-core": "file:../core",
|
||||
"express": "^5.1.0",
|
||||
|
||||
@@ -11,6 +11,7 @@ import { createMockCommandContext } from '../../test-utils/mockCommandContext.js
|
||||
import {
|
||||
AuthType,
|
||||
openBrowserSecurely,
|
||||
shouldLaunchBrowser,
|
||||
UPGRADE_URL_PAGE,
|
||||
} from '@google/gemini-cli-core';
|
||||
|
||||
@@ -20,6 +21,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
return {
|
||||
...actual,
|
||||
openBrowserSecurely: vi.fn(),
|
||||
shouldLaunchBrowser: vi.fn().mockReturnValue(true),
|
||||
UPGRADE_URL_PAGE: 'https://goo.gle/set-up-gemini-code-assist',
|
||||
};
|
||||
});
|
||||
@@ -96,4 +98,21 @@ describe('upgradeCommand', () => {
|
||||
content: 'Failed to open upgrade page: Failed to open',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return URL message when shouldLaunchBrowser returns false', async () => {
|
||||
vi.mocked(shouldLaunchBrowser).mockReturnValue(false);
|
||||
|
||||
if (!upgradeCommand.action) {
|
||||
throw new Error('The upgrade command must have an action.');
|
||||
}
|
||||
|
||||
const result = await upgradeCommand.action(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: `Please open this URL in a browser: ${UPGRADE_URL_PAGE}`,
|
||||
});
|
||||
expect(openBrowserSecurely).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import {
|
||||
AuthType,
|
||||
openBrowserSecurely,
|
||||
shouldLaunchBrowser,
|
||||
UPGRADE_URL_PAGE,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { SlashCommand } from './types.js';
|
||||
@@ -35,6 +36,14 @@ export const upgradeCommand: SlashCommand = {
|
||||
};
|
||||
}
|
||||
|
||||
if (!shouldLaunchBrowser()) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: `Please open this URL in a browser: ${UPGRADE_URL_PAGE}`,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
await openBrowserSecurely(UPGRADE_URL_PAGE);
|
||||
} catch (error) {
|
||||
|
||||
@@ -18,6 +18,12 @@ Spinner Connecting to MCP servers... (0/5) - Waiting for: s1, s2, s3, +2 more
|
||||
"
|
||||
`;
|
||||
|
||||
exports[`ConfigInitDisplay > truncates list of waiting servers if too many 2`] = `
|
||||
"
|
||||
Spinner Connecting to MCP servers... (0/5) - Waiting for: s1, s2, s3, +2 more
|
||||
"
|
||||
`;
|
||||
|
||||
exports[`ConfigInitDisplay > updates message on McpClientUpdate event 1`] = `
|
||||
"
|
||||
Spinner Connecting to MCP servers... (1/2) - Waiting for: server2
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
shouldAutoUseCredits,
|
||||
shouldShowOverageMenu,
|
||||
shouldShowEmptyWalletMenu,
|
||||
shouldLaunchBrowser,
|
||||
logBillingEvent,
|
||||
G1_CREDIT_TYPE,
|
||||
UserTierId,
|
||||
@@ -32,6 +33,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
shouldShowEmptyWalletMenu: vi.fn(),
|
||||
logBillingEvent: vi.fn(),
|
||||
openBrowserSecurely: vi.fn(),
|
||||
shouldLaunchBrowser: vi.fn().mockReturnValue(true),
|
||||
};
|
||||
});
|
||||
|
||||
@@ -237,4 +239,49 @@ describe('handleCreditsFlow', () => {
|
||||
expect(isDialogPending.current).toBe(false);
|
||||
expect(mockSetEmptyWalletRequest).toHaveBeenCalledWith(null);
|
||||
});
|
||||
|
||||
describe('headless mode (shouldLaunchBrowser=false)', () => {
|
||||
beforeEach(() => {
|
||||
vi.mocked(shouldLaunchBrowser).mockReturnValue(false);
|
||||
});
|
||||
|
||||
it('should show manage URL in history when manage selected in headless mode', async () => {
|
||||
vi.mocked(shouldShowOverageMenu).mockReturnValue(true);
|
||||
|
||||
const flowPromise = handleCreditsFlow(makeArgs());
|
||||
const request = mockSetOverageMenuRequest.mock.calls[0][0];
|
||||
request.resolve('manage');
|
||||
const result = await flowPromise;
|
||||
|
||||
expect(result).toBe('stop');
|
||||
expect(mockHistoryManager.addItem).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: MessageType.INFO,
|
||||
text: expect.stringContaining('Please open this URL in a browser:'),
|
||||
}),
|
||||
expect.any(Number),
|
||||
);
|
||||
});
|
||||
|
||||
it('should show credits URL in history when get_credits selected in headless mode', async () => {
|
||||
vi.mocked(shouldShowEmptyWalletMenu).mockReturnValue(true);
|
||||
|
||||
const flowPromise = handleCreditsFlow(makeArgs());
|
||||
const request = mockSetEmptyWalletRequest.mock.calls[0][0];
|
||||
|
||||
// Trigger onGetCredits callback and wait for it
|
||||
await request.onGetCredits();
|
||||
|
||||
expect(mockHistoryManager.addItem).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: MessageType.INFO,
|
||||
text: expect.stringContaining('Please open this URL in a browser:'),
|
||||
}),
|
||||
expect.any(Number),
|
||||
);
|
||||
|
||||
request.resolve('get_credits');
|
||||
await flowPromise;
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
shouldShowOverageMenu,
|
||||
shouldShowEmptyWalletMenu,
|
||||
openBrowserSecurely,
|
||||
shouldLaunchBrowser,
|
||||
logBillingEvent,
|
||||
OverageMenuShownEvent,
|
||||
OverageOptionSelectedEvent,
|
||||
@@ -159,10 +160,23 @@ async function handleOverageMenu(
|
||||
case 'use_fallback':
|
||||
return 'retry_always';
|
||||
|
||||
case 'manage':
|
||||
case 'manage': {
|
||||
logCreditPurchaseClick(config, 'manage', usageLimitReachedModel);
|
||||
await openG1Url('activity', G1_UTM_CAMPAIGNS.MANAGE_ACTIVITY);
|
||||
const manageUrl = await openG1Url(
|
||||
'activity',
|
||||
G1_UTM_CAMPAIGNS.MANAGE_ACTIVITY,
|
||||
);
|
||||
if (manageUrl) {
|
||||
args.historyManager.addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: `Please open this URL in a browser: ${manageUrl}`,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
}
|
||||
return 'stop';
|
||||
}
|
||||
|
||||
case 'stop':
|
||||
default:
|
||||
@@ -205,13 +219,25 @@ async function handleEmptyWalletMenu(
|
||||
failedModel: usageLimitReachedModel,
|
||||
fallbackModel,
|
||||
resetTime,
|
||||
onGetCredits: () => {
|
||||
onGetCredits: async () => {
|
||||
logCreditPurchaseClick(
|
||||
config,
|
||||
'empty_wallet_menu',
|
||||
usageLimitReachedModel,
|
||||
);
|
||||
void openG1Url('credits', G1_UTM_CAMPAIGNS.EMPTY_WALLET_ADD_CREDITS);
|
||||
const creditsUrl = await openG1Url(
|
||||
'credits',
|
||||
G1_UTM_CAMPAIGNS.EMPTY_WALLET_ADD_CREDITS,
|
||||
);
|
||||
if (creditsUrl) {
|
||||
args.historyManager.addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: `Please open this URL in a browser: ${creditsUrl}`,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
}
|
||||
},
|
||||
resolve,
|
||||
});
|
||||
@@ -272,11 +298,16 @@ function logCreditPurchaseClick(
|
||||
async function openG1Url(
|
||||
path: 'activity' | 'credits',
|
||||
campaign: string,
|
||||
): Promise<void> {
|
||||
): Promise<string | undefined> {
|
||||
try {
|
||||
const userEmail = new UserAccountManager().getCachedGoogleAccount() ?? '';
|
||||
await openBrowserSecurely(buildG1Url(path, userEmail, campaign));
|
||||
const url = buildG1Url(path, userEmail, campaign);
|
||||
if (!shouldLaunchBrowser()) {
|
||||
return url;
|
||||
}
|
||||
await openBrowserSecurely(url);
|
||||
} catch {
|
||||
// Ignore browser open errors
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
"dist"
|
||||
],
|
||||
"dependencies": {
|
||||
"@a2a-js/sdk": "^0.3.10",
|
||||
"@a2a-js/sdk": "0.3.11",
|
||||
"@bufbuild/protobuf": "^2.11.0",
|
||||
"@google-cloud/logging": "^11.2.1",
|
||||
"@google-cloud/opentelemetry-cloud-monitoring-exporter": "^0.21.0",
|
||||
|
||||
@@ -140,7 +140,7 @@ describe('A2AClientManager', () => {
|
||||
expect(createAuthenticatingFetchWithRetry).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use provided custom authentication handler', async () => {
|
||||
it('should use provided custom authentication handler for transports only', async () => {
|
||||
const customAuthHandler = {
|
||||
headers: vi.fn(),
|
||||
shouldRetryWithHeaders: vi.fn(),
|
||||
@@ -155,6 +155,66 @@ describe('A2AClientManager', () => {
|
||||
expect.anything(),
|
||||
customAuthHandler,
|
||||
);
|
||||
|
||||
// Card resolver should NOT use the authenticated fetch by default.
|
||||
const resolverInstance = vi.mocked(DefaultAgentCardResolver).mock
|
||||
.instances[0];
|
||||
expect(resolverInstance).toBeDefined();
|
||||
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
|
||||
.calls[0][0];
|
||||
expect(resolverOptions?.fetchImpl).not.toBe(authFetchMock);
|
||||
});
|
||||
|
||||
it('should use unauthenticated fetch for card resolver and avoid authenticated fetch if success', async () => {
|
||||
const customAuthHandler = {
|
||||
headers: vi.fn(),
|
||||
shouldRetryWithHeaders: vi.fn(),
|
||||
};
|
||||
await manager.loadAgent(
|
||||
'AuthCardAgent',
|
||||
'http://authcard.agent/card',
|
||||
customAuthHandler as unknown as AuthenticationHandler,
|
||||
);
|
||||
|
||||
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
|
||||
.calls[0][0];
|
||||
const cardFetch = resolverOptions?.fetchImpl as typeof fetch;
|
||||
|
||||
expect(cardFetch).toBeDefined();
|
||||
|
||||
await cardFetch('http://test.url');
|
||||
|
||||
expect(fetch).toHaveBeenCalledWith('http://test.url', expect.anything());
|
||||
expect(authFetchMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should retry with authenticating fetch if agent card fetch returns 401', async () => {
|
||||
const customAuthHandler = {
|
||||
headers: vi.fn(),
|
||||
shouldRetryWithHeaders: vi.fn(),
|
||||
};
|
||||
|
||||
// Mock the initial unauthenticated fetch to fail with 401
|
||||
vi.mocked(fetch).mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 401,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
await manager.loadAgent(
|
||||
'AuthCardAgent401',
|
||||
'http://authcard.agent/card',
|
||||
customAuthHandler as unknown as AuthenticationHandler,
|
||||
);
|
||||
|
||||
const resolverOptions = vi.mocked(DefaultAgentCardResolver).mock
|
||||
.calls[0][0];
|
||||
const cardFetch = resolverOptions?.fetchImpl as typeof fetch;
|
||||
|
||||
await cardFetch('http://test.url');
|
||||
|
||||
expect(fetch).toHaveBeenCalledWith('http://test.url', expect.anything());
|
||||
expect(authFetchMock).toHaveBeenCalledWith('http://test.url', undefined);
|
||||
});
|
||||
|
||||
it('should log a debug message upon loading an agent', async () => {
|
||||
|
||||
@@ -95,19 +95,37 @@ export class A2AClientManager {
|
||||
throw new Error(`Agent with name '${name}' is already loaded.`);
|
||||
}
|
||||
|
||||
let fetchImpl: typeof fetch = a2aFetch;
|
||||
// Authenticated fetch for API calls (transports).
|
||||
let authFetch: typeof fetch = a2aFetch;
|
||||
if (authHandler) {
|
||||
fetchImpl = createAuthenticatingFetchWithRetry(a2aFetch, authHandler);
|
||||
authFetch = createAuthenticatingFetchWithRetry(a2aFetch, authHandler);
|
||||
}
|
||||
|
||||
const resolver = new DefaultAgentCardResolver({ fetchImpl });
|
||||
// Use unauthenticated fetch for the agent card unless explicitly required.
|
||||
// Some servers reject unexpected auth headers on the card endpoint (e.g. 400).
|
||||
const cardFetch = async (
|
||||
input: RequestInfo | URL,
|
||||
init?: RequestInit,
|
||||
): Promise<Response> => {
|
||||
// Try without auth first
|
||||
const response = await a2aFetch(input, init);
|
||||
|
||||
// Retry with auth if we hit a 401/403
|
||||
if ((response.status === 401 || response.status === 403) && authFetch) {
|
||||
return authFetch(input, init);
|
||||
}
|
||||
|
||||
return response;
|
||||
};
|
||||
|
||||
const resolver = new DefaultAgentCardResolver({ fetchImpl: cardFetch });
|
||||
|
||||
const options = ClientFactoryOptions.createFrom(
|
||||
ClientFactoryOptions.default,
|
||||
{
|
||||
transports: [
|
||||
new RestTransportFactory({ fetchImpl }),
|
||||
new JsonRpcTransportFactory({ fetchImpl }),
|
||||
new RestTransportFactory({ fetchImpl: authFetch }),
|
||||
new JsonRpcTransportFactory({ fetchImpl: authFetch }),
|
||||
],
|
||||
cardResolver: resolver,
|
||||
},
|
||||
|
||||
@@ -576,5 +576,139 @@ auth:
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should parse remote agent with oauth2 auth', async () => {
|
||||
const filePath = await writeAgentMarkdown(`---
|
||||
kind: remote
|
||||
name: oauth2-agent
|
||||
agent_card_url: https://example.com/card
|
||||
auth:
|
||||
type: oauth2
|
||||
client_id: $MY_OAUTH_CLIENT_ID
|
||||
scopes:
|
||||
- read
|
||||
- write
|
||||
---
|
||||
`);
|
||||
const result = await parseAgentMarkdown(filePath);
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]).toMatchObject({
|
||||
kind: 'remote',
|
||||
name: 'oauth2-agent',
|
||||
auth: {
|
||||
type: 'oauth2',
|
||||
client_id: '$MY_OAUTH_CLIENT_ID',
|
||||
scopes: ['read', 'write'],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should parse remote agent with oauth2 auth including all fields', async () => {
|
||||
const filePath = await writeAgentMarkdown(`---
|
||||
kind: remote
|
||||
name: oauth2-full-agent
|
||||
agent_card_url: https://example.com/card
|
||||
auth:
|
||||
type: oauth2
|
||||
client_id: my-client-id
|
||||
client_secret: my-client-secret
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
authorization_url: https://auth.example.com/authorize
|
||||
token_url: https://auth.example.com/token
|
||||
---
|
||||
`);
|
||||
const result = await parseAgentMarkdown(filePath);
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]).toMatchObject({
|
||||
kind: 'remote',
|
||||
name: 'oauth2-full-agent',
|
||||
auth: {
|
||||
type: 'oauth2',
|
||||
client_id: 'my-client-id',
|
||||
client_secret: 'my-client-secret',
|
||||
scopes: ['openid', 'profile'],
|
||||
authorization_url: 'https://auth.example.com/authorize',
|
||||
token_url: 'https://auth.example.com/token',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should parse remote agent with minimal oauth2 config (type only)', async () => {
|
||||
const filePath = await writeAgentMarkdown(`---
|
||||
kind: remote
|
||||
name: oauth2-minimal-agent
|
||||
agent_card_url: https://example.com/card
|
||||
auth:
|
||||
type: oauth2
|
||||
---
|
||||
`);
|
||||
const result = await parseAgentMarkdown(filePath);
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]).toMatchObject({
|
||||
kind: 'remote',
|
||||
name: 'oauth2-minimal-agent',
|
||||
auth: {
|
||||
type: 'oauth2',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should reject oauth2 auth with invalid authorization_url', async () => {
|
||||
const filePath = await writeAgentMarkdown(`---
|
||||
kind: remote
|
||||
name: invalid-oauth2-agent
|
||||
agent_card_url: https://example.com/card
|
||||
auth:
|
||||
type: oauth2
|
||||
client_id: my-client
|
||||
authorization_url: not-a-valid-url
|
||||
---
|
||||
`);
|
||||
await expect(parseAgentMarkdown(filePath)).rejects.toThrow(/Invalid url/);
|
||||
});
|
||||
|
||||
it('should reject oauth2 auth with invalid token_url', async () => {
|
||||
const filePath = await writeAgentMarkdown(`---
|
||||
kind: remote
|
||||
name: invalid-oauth2-agent
|
||||
agent_card_url: https://example.com/card
|
||||
auth:
|
||||
type: oauth2
|
||||
client_id: my-client
|
||||
token_url: not-a-valid-url
|
||||
---
|
||||
`);
|
||||
await expect(parseAgentMarkdown(filePath)).rejects.toThrow(/Invalid url/);
|
||||
});
|
||||
|
||||
it('should convert oauth2 auth config in markdownToAgentDefinition', () => {
|
||||
const markdown = {
|
||||
kind: 'remote' as const,
|
||||
name: 'oauth2-convert-agent',
|
||||
agent_card_url: 'https://example.com/card',
|
||||
auth: {
|
||||
type: 'oauth2' as const,
|
||||
client_id: '$MY_CLIENT_ID',
|
||||
scopes: ['read'],
|
||||
authorization_url: 'https://auth.example.com/authorize',
|
||||
token_url: 'https://auth.example.com/token',
|
||||
},
|
||||
};
|
||||
|
||||
const result = markdownToAgentDefinition(markdown);
|
||||
expect(result).toMatchObject({
|
||||
kind: 'remote',
|
||||
name: 'oauth2-convert-agent',
|
||||
auth: {
|
||||
type: 'oauth2',
|
||||
client_id: '$MY_CLIENT_ID',
|
||||
scopes: ['read'],
|
||||
authorization_url: 'https://auth.example.com/authorize',
|
||||
token_url: 'https://auth.example.com/token',
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -44,7 +44,7 @@ interface FrontmatterLocalAgentDefinition
|
||||
* Authentication configuration for remote agents in frontmatter format.
|
||||
*/
|
||||
interface FrontmatterAuthConfig {
|
||||
type: 'apiKey' | 'http';
|
||||
type: 'apiKey' | 'http' | 'oauth2';
|
||||
agent_card_requires_auth?: boolean;
|
||||
// API Key
|
||||
key?: string;
|
||||
@@ -55,6 +55,12 @@ interface FrontmatterAuthConfig {
|
||||
username?: string;
|
||||
password?: string;
|
||||
value?: string;
|
||||
// OAuth2
|
||||
client_id?: string;
|
||||
client_secret?: string;
|
||||
scopes?: string[];
|
||||
authorization_url?: string;
|
||||
token_url?: string;
|
||||
}
|
||||
|
||||
interface FrontmatterRemoteAgentDefinition
|
||||
@@ -147,8 +153,26 @@ const httpAuthSchema = z.object({
|
||||
value: z.string().min(1).optional(),
|
||||
});
|
||||
|
||||
/**
|
||||
* OAuth2 auth schema.
|
||||
* authorization_url and token_url can be discovered from the agent card if omitted.
|
||||
*/
|
||||
const oauth2AuthSchema = z.object({
|
||||
...baseAuthFields,
|
||||
type: z.literal('oauth2'),
|
||||
client_id: z.string().optional(),
|
||||
client_secret: z.string().optional(),
|
||||
scopes: z.array(z.string()).optional(),
|
||||
authorization_url: z.string().url().optional(),
|
||||
token_url: z.string().url().optional(),
|
||||
});
|
||||
|
||||
const authConfigSchema = z
|
||||
.discriminatedUnion('type', [apiKeyAuthSchema, httpAuthSchema])
|
||||
.discriminatedUnion('type', [
|
||||
apiKeyAuthSchema,
|
||||
httpAuthSchema,
|
||||
oauth2AuthSchema,
|
||||
])
|
||||
.superRefine((data, ctx) => {
|
||||
if (data.type === 'http') {
|
||||
if (data.value) {
|
||||
@@ -395,6 +419,17 @@ function convertFrontmatterAuthToConfig(
|
||||
}
|
||||
}
|
||||
|
||||
case 'oauth2':
|
||||
return {
|
||||
...base,
|
||||
type: 'oauth2',
|
||||
client_id: frontmatter.client_id,
|
||||
client_secret: frontmatter.client_secret,
|
||||
scopes: frontmatter.scopes,
|
||||
authorization_url: frontmatter.authorization_url,
|
||||
token_url: frontmatter.token_url,
|
||||
};
|
||||
|
||||
default: {
|
||||
const exhaustive: never = frontmatter.type;
|
||||
throw new Error(`Unknown auth type: ${exhaustive}`);
|
||||
|
||||
@@ -4,11 +4,22 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { A2AAuthProviderFactory } from './factory.js';
|
||||
import type { AgentCard, SecurityScheme } from '@a2a-js/sdk';
|
||||
import type { A2AAuthConfig } from './types.js';
|
||||
|
||||
// Mock token storage so OAuth2AuthProvider.initialize() works without disk I/O.
|
||||
vi.mock('../../mcp/oauth-token-storage.js', () => {
|
||||
const MCPOAuthTokenStorage = vi.fn().mockImplementation(() => ({
|
||||
getCredentials: vi.fn().mockResolvedValue(null),
|
||||
saveToken: vi.fn().mockResolvedValue(undefined),
|
||||
deleteCredentials: vi.fn().mockResolvedValue(undefined),
|
||||
isTokenExpired: vi.fn().mockReturnValue(false),
|
||||
}));
|
||||
return { MCPOAuthTokenStorage };
|
||||
});
|
||||
|
||||
describe('A2AAuthProviderFactory', () => {
|
||||
describe('validateAuthConfig', () => {
|
||||
describe('when no security schemes required', () => {
|
||||
@@ -492,5 +503,62 @@ describe('A2AAuthProviderFactory', () => {
|
||||
const headers = await provider!.headers();
|
||||
expect(headers).toEqual({ 'X-API-Key': 'factory-test-key' });
|
||||
});
|
||||
|
||||
it('should create an OAuth2AuthProvider for oauth2 config', async () => {
|
||||
const provider = await A2AAuthProviderFactory.create({
|
||||
agentName: 'my-oauth-agent',
|
||||
authConfig: {
|
||||
type: 'oauth2',
|
||||
client_id: 'my-client',
|
||||
authorization_url: 'https://auth.example.com/authorize',
|
||||
token_url: 'https://auth.example.com/token',
|
||||
scopes: ['read'],
|
||||
},
|
||||
});
|
||||
|
||||
expect(provider).toBeDefined();
|
||||
expect(provider!.type).toBe('oauth2');
|
||||
});
|
||||
|
||||
it('should create an OAuth2AuthProvider with agent card defaults', async () => {
|
||||
const provider = await A2AAuthProviderFactory.create({
|
||||
agentName: 'card-oauth-agent',
|
||||
authConfig: {
|
||||
type: 'oauth2',
|
||||
client_id: 'my-client',
|
||||
},
|
||||
agentCard: {
|
||||
securitySchemes: {
|
||||
oauth: {
|
||||
type: 'oauth2',
|
||||
flows: {
|
||||
authorizationCode: {
|
||||
authorizationUrl: 'https://card.example.com/authorize',
|
||||
tokenUrl: 'https://card.example.com/token',
|
||||
scopes: { read: 'Read access' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as AgentCard,
|
||||
});
|
||||
|
||||
expect(provider).toBeDefined();
|
||||
expect(provider!.type).toBe('oauth2');
|
||||
});
|
||||
|
||||
it('should use "unknown" as agent name when agentName is not provided for oauth2', async () => {
|
||||
const provider = await A2AAuthProviderFactory.create({
|
||||
authConfig: {
|
||||
type: 'oauth2',
|
||||
client_id: 'my-client',
|
||||
authorization_url: 'https://auth.example.com/authorize',
|
||||
token_url: 'https://auth.example.com/token',
|
||||
},
|
||||
});
|
||||
|
||||
expect(provider).toBeDefined();
|
||||
expect(provider!.type).toBe('oauth2');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -18,6 +18,8 @@ export interface CreateAuthProviderOptions {
|
||||
agentName?: string;
|
||||
authConfig?: A2AAuthConfig;
|
||||
agentCard?: AgentCard;
|
||||
/** URL to fetch the agent card from, used for OAuth2 URL discovery. */
|
||||
agentCardUrl?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -57,9 +59,20 @@ export class A2AAuthProviderFactory {
|
||||
return provider;
|
||||
}
|
||||
|
||||
case 'oauth2':
|
||||
// TODO: Implement
|
||||
throw new Error('oauth2 auth provider not yet implemented');
|
||||
case 'oauth2': {
|
||||
// Dynamic import to avoid pulling MCPOAuthTokenStorage into the
|
||||
// factory's static module graph, which causes initialization
|
||||
// conflicts with code_assist/oauth-credential-storage.ts.
|
||||
const { OAuth2AuthProvider } = await import('./oauth2-provider.js');
|
||||
const provider = new OAuth2AuthProvider(
|
||||
authConfig,
|
||||
options.agentName ?? 'unknown',
|
||||
agentCard,
|
||||
options.agentCardUrl,
|
||||
);
|
||||
await provider.initialize();
|
||||
return provider;
|
||||
}
|
||||
|
||||
case 'openIdConnect':
|
||||
// TODO: Implement
|
||||
|
||||
@@ -0,0 +1,651 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { OAuth2AuthProvider } from './oauth2-provider.js';
|
||||
import type { OAuth2AuthConfig } from './types.js';
|
||||
import type { AgentCard } from '@a2a-js/sdk';
|
||||
|
||||
// Mock DefaultAgentCardResolver from @a2a-js/sdk/client.
|
||||
const mockResolve = vi.fn();
|
||||
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@a2a-js/sdk/client')>();
|
||||
return {
|
||||
...actual,
|
||||
DefaultAgentCardResolver: vi.fn().mockImplementation(() => ({
|
||||
resolve: mockResolve,
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
// Mock all external dependencies.
|
||||
vi.mock('../../mcp/oauth-token-storage.js', () => {
|
||||
const MCPOAuthTokenStorage = vi.fn().mockImplementation(() => ({
|
||||
getCredentials: vi.fn().mockResolvedValue(null),
|
||||
saveToken: vi.fn().mockResolvedValue(undefined),
|
||||
deleteCredentials: vi.fn().mockResolvedValue(undefined),
|
||||
isTokenExpired: vi.fn().mockReturnValue(false),
|
||||
}));
|
||||
return { MCPOAuthTokenStorage };
|
||||
});
|
||||
|
||||
vi.mock('../../utils/oauth-flow.js', () => ({
|
||||
generatePKCEParams: vi.fn().mockReturnValue({
|
||||
codeVerifier: 'test-verifier',
|
||||
codeChallenge: 'test-challenge',
|
||||
state: 'test-state',
|
||||
}),
|
||||
startCallbackServer: vi.fn().mockReturnValue({
|
||||
port: Promise.resolve(12345),
|
||||
response: Promise.resolve({ code: 'test-code', state: 'test-state' }),
|
||||
}),
|
||||
getPortFromUrl: vi.fn().mockReturnValue(undefined),
|
||||
buildAuthorizationUrl: vi
|
||||
.fn()
|
||||
.mockReturnValue('https://auth.example.com/authorize?foo=bar'),
|
||||
exchangeCodeForToken: vi.fn().mockResolvedValue({
|
||||
access_token: 'new-access-token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
refresh_token: 'new-refresh-token',
|
||||
}),
|
||||
refreshAccessToken: vi.fn().mockResolvedValue({
|
||||
access_token: 'refreshed-access-token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
refresh_token: 'refreshed-refresh-token',
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock('../../utils/secure-browser-launcher.js', () => ({
|
||||
openBrowserSecurely: vi.fn().mockResolvedValue(undefined),
|
||||
}));
|
||||
|
||||
vi.mock('../../utils/authConsent.js', () => ({
|
||||
getConsentForOauth: vi.fn().mockResolvedValue(true),
|
||||
}));
|
||||
|
||||
vi.mock('../../utils/events.js', () => ({
|
||||
coreEvents: {
|
||||
emitFeedback: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('../../utils/debugLogger.js', () => ({
|
||||
debugLogger: {
|
||||
debug: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
log: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Re-import mocked modules for assertions.
|
||||
const { MCPOAuthTokenStorage } = await import(
|
||||
'../../mcp/oauth-token-storage.js'
|
||||
);
|
||||
const {
|
||||
refreshAccessToken,
|
||||
exchangeCodeForToken,
|
||||
generatePKCEParams,
|
||||
startCallbackServer,
|
||||
buildAuthorizationUrl,
|
||||
} = await import('../../utils/oauth-flow.js');
|
||||
const { getConsentForOauth } = await import('../../utils/authConsent.js');
|
||||
|
||||
function createConfig(
|
||||
overrides: Partial<OAuth2AuthConfig> = {},
|
||||
): OAuth2AuthConfig {
|
||||
return {
|
||||
type: 'oauth2',
|
||||
client_id: 'test-client-id',
|
||||
authorization_url: 'https://auth.example.com/authorize',
|
||||
token_url: 'https://auth.example.com/token',
|
||||
scopes: ['read', 'write'],
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function getTokenStorage() {
|
||||
// Access the mocked MCPOAuthTokenStorage instance created in the constructor.
|
||||
const instance = vi.mocked(MCPOAuthTokenStorage).mock.results.at(-1)!.value;
|
||||
return instance as {
|
||||
getCredentials: ReturnType<typeof vi.fn>;
|
||||
saveToken: ReturnType<typeof vi.fn>;
|
||||
deleteCredentials: ReturnType<typeof vi.fn>;
|
||||
isTokenExpired: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
}
|
||||
|
||||
describe('OAuth2AuthProvider', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should set type to oauth2', () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
expect(provider.type).toBe('oauth2');
|
||||
});
|
||||
|
||||
it('should use config values for authorization_url and token_url', () => {
|
||||
const config = createConfig({
|
||||
authorization_url: 'https://custom.example.com/authorize',
|
||||
token_url: 'https://custom.example.com/token',
|
||||
});
|
||||
const provider = new OAuth2AuthProvider(config, 'test-agent');
|
||||
// Verify by calling headers which will trigger interactive flow with these URLs.
|
||||
expect(provider.type).toBe('oauth2');
|
||||
});
|
||||
|
||||
it('should merge agent card defaults when config values are missing', () => {
|
||||
const config = createConfig({
|
||||
authorization_url: undefined,
|
||||
token_url: undefined,
|
||||
scopes: undefined,
|
||||
});
|
||||
|
||||
const agentCard = {
|
||||
securitySchemes: {
|
||||
oauth: {
|
||||
type: 'oauth2' as const,
|
||||
flows: {
|
||||
authorizationCode: {
|
||||
authorizationUrl: 'https://card.example.com/authorize',
|
||||
tokenUrl: 'https://card.example.com/token',
|
||||
scopes: { read: 'Read access', write: 'Write access' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as AgentCard;
|
||||
|
||||
const provider = new OAuth2AuthProvider(config, 'test-agent', agentCard);
|
||||
expect(provider.type).toBe('oauth2');
|
||||
});
|
||||
|
||||
it('should prefer config values over agent card values', async () => {
|
||||
const config = createConfig({
|
||||
authorization_url: 'https://config.example.com/authorize',
|
||||
token_url: 'https://config.example.com/token',
|
||||
scopes: ['custom-scope'],
|
||||
});
|
||||
|
||||
const agentCard = {
|
||||
securitySchemes: {
|
||||
oauth: {
|
||||
type: 'oauth2' as const,
|
||||
flows: {
|
||||
authorizationCode: {
|
||||
authorizationUrl: 'https://card.example.com/authorize',
|
||||
tokenUrl: 'https://card.example.com/token',
|
||||
scopes: { read: 'Read access' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as AgentCard;
|
||||
|
||||
const provider = new OAuth2AuthProvider(config, 'test-agent', agentCard);
|
||||
await provider.headers();
|
||||
|
||||
// The config URLs should be used, not the agent card ones.
|
||||
expect(vi.mocked(buildAuthorizationUrl)).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
authorizationUrl: 'https://config.example.com/authorize',
|
||||
tokenUrl: 'https://config.example.com/token',
|
||||
scopes: ['custom-scope'],
|
||||
}),
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('initialize', () => {
|
||||
it('should load a valid token from storage', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
storage.getCredentials.mockResolvedValue({
|
||||
serverName: 'test-agent',
|
||||
token: {
|
||||
accessToken: 'stored-token',
|
||||
tokenType: 'Bearer',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
});
|
||||
storage.isTokenExpired.mockReturnValue(false);
|
||||
|
||||
await provider.initialize();
|
||||
|
||||
const headers = await provider.headers();
|
||||
expect(headers).toEqual({ Authorization: 'Bearer stored-token' });
|
||||
});
|
||||
|
||||
it('should not cache an expired token from storage', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
storage.getCredentials.mockResolvedValue({
|
||||
serverName: 'test-agent',
|
||||
token: {
|
||||
accessToken: 'expired-token',
|
||||
tokenType: 'Bearer',
|
||||
expiresAt: Date.now() - 1000,
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
});
|
||||
storage.isTokenExpired.mockReturnValue(true);
|
||||
|
||||
await provider.initialize();
|
||||
|
||||
// Should trigger interactive flow since cached token is null.
|
||||
const headers = await provider.headers();
|
||||
expect(headers).toEqual({ Authorization: 'Bearer new-access-token' });
|
||||
});
|
||||
|
||||
it('should handle no stored credentials gracefully', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
storage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await provider.initialize();
|
||||
|
||||
// Should trigger interactive flow.
|
||||
const headers = await provider.headers();
|
||||
expect(headers).toEqual({ Authorization: 'Bearer new-access-token' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('headers', () => {
|
||||
it('should return cached token if valid', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
storage.getCredentials.mockResolvedValue({
|
||||
serverName: 'test-agent',
|
||||
token: { accessToken: 'cached-token', tokenType: 'Bearer' },
|
||||
updatedAt: Date.now(),
|
||||
});
|
||||
storage.isTokenExpired.mockReturnValue(false);
|
||||
|
||||
await provider.initialize();
|
||||
|
||||
const headers = await provider.headers();
|
||||
expect(headers).toEqual({ Authorization: 'Bearer cached-token' });
|
||||
expect(vi.mocked(exchangeCodeForToken)).not.toHaveBeenCalled();
|
||||
expect(vi.mocked(refreshAccessToken)).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should refresh token when expired with refresh_token available', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
// First call: load from storage (expired but with refresh token).
|
||||
storage.getCredentials.mockResolvedValue({
|
||||
serverName: 'test-agent',
|
||||
token: {
|
||||
accessToken: 'expired-token',
|
||||
tokenType: 'Bearer',
|
||||
refreshToken: 'my-refresh-token',
|
||||
expiresAt: Date.now() - 1000,
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
});
|
||||
// isTokenExpired: false for initialize (to cache it), true for headers check.
|
||||
storage.isTokenExpired
|
||||
.mockReturnValueOnce(false) // initialize: cache the token
|
||||
.mockReturnValueOnce(true); // headers: token is expired
|
||||
|
||||
await provider.initialize();
|
||||
const headers = await provider.headers();
|
||||
|
||||
expect(vi.mocked(refreshAccessToken)).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ clientId: 'test-client-id' }),
|
||||
'my-refresh-token',
|
||||
'https://auth.example.com/token',
|
||||
);
|
||||
expect(headers).toEqual({
|
||||
Authorization: 'Bearer refreshed-access-token',
|
||||
});
|
||||
expect(storage.saveToken).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should fall back to interactive flow when refresh fails', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
storage.getCredentials.mockResolvedValue({
|
||||
serverName: 'test-agent',
|
||||
token: {
|
||||
accessToken: 'expired-token',
|
||||
tokenType: 'Bearer',
|
||||
refreshToken: 'bad-refresh-token',
|
||||
expiresAt: Date.now() - 1000,
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
});
|
||||
storage.isTokenExpired
|
||||
.mockReturnValueOnce(false) // initialize
|
||||
.mockReturnValueOnce(true); // headers
|
||||
|
||||
vi.mocked(refreshAccessToken).mockRejectedValueOnce(
|
||||
new Error('Refresh failed'),
|
||||
);
|
||||
|
||||
await provider.initialize();
|
||||
const headers = await provider.headers();
|
||||
|
||||
// Should have deleted stale credentials and done interactive flow.
|
||||
expect(storage.deleteCredentials).toHaveBeenCalledWith('test-agent');
|
||||
expect(headers).toEqual({ Authorization: 'Bearer new-access-token' });
|
||||
});
|
||||
|
||||
it('should trigger interactive flow when no token exists', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
storage.getCredentials.mockResolvedValue(null);
|
||||
|
||||
await provider.initialize();
|
||||
const headers = await provider.headers();
|
||||
|
||||
expect(vi.mocked(generatePKCEParams)).toHaveBeenCalled();
|
||||
expect(vi.mocked(startCallbackServer)).toHaveBeenCalled();
|
||||
expect(vi.mocked(exchangeCodeForToken)).toHaveBeenCalled();
|
||||
expect(storage.saveToken).toHaveBeenCalledWith(
|
||||
'test-agent',
|
||||
expect.objectContaining({ accessToken: 'new-access-token' }),
|
||||
'test-client-id',
|
||||
'https://auth.example.com/token',
|
||||
);
|
||||
expect(headers).toEqual({ Authorization: 'Bearer new-access-token' });
|
||||
});
|
||||
|
||||
it('should throw when user declines consent', async () => {
|
||||
vi.mocked(getConsentForOauth).mockResolvedValueOnce(false);
|
||||
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
await provider.initialize();
|
||||
|
||||
await expect(provider.headers()).rejects.toThrow(
|
||||
'Authentication cancelled by user',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw when client_id is missing', async () => {
|
||||
const config = createConfig({ client_id: undefined });
|
||||
const provider = new OAuth2AuthProvider(config, 'test-agent');
|
||||
await provider.initialize();
|
||||
|
||||
await expect(provider.headers()).rejects.toThrow(/requires a client_id/);
|
||||
});
|
||||
|
||||
it('should throw when authorization_url and token_url are missing', async () => {
|
||||
const config = createConfig({
|
||||
authorization_url: undefined,
|
||||
token_url: undefined,
|
||||
});
|
||||
const provider = new OAuth2AuthProvider(config, 'test-agent');
|
||||
await provider.initialize();
|
||||
|
||||
await expect(provider.headers()).rejects.toThrow(
|
||||
/requires authorization_url and token_url/,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('shouldRetryWithHeaders', () => {
|
||||
it('should clear token and re-authenticate on 401', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
storage.getCredentials.mockResolvedValue({
|
||||
serverName: 'test-agent',
|
||||
token: { accessToken: 'old-token', tokenType: 'Bearer' },
|
||||
updatedAt: Date.now(),
|
||||
});
|
||||
storage.isTokenExpired.mockReturnValue(false);
|
||||
|
||||
await provider.initialize();
|
||||
|
||||
const res = new Response(null, { status: 401 });
|
||||
const retryHeaders = await provider.shouldRetryWithHeaders({}, res);
|
||||
|
||||
expect(storage.deleteCredentials).toHaveBeenCalledWith('test-agent');
|
||||
expect(retryHeaders).toBeDefined();
|
||||
expect(retryHeaders).toHaveProperty('Authorization');
|
||||
});
|
||||
|
||||
it('should clear token and re-authenticate on 403', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
storage.getCredentials.mockResolvedValue({
|
||||
serverName: 'test-agent',
|
||||
token: { accessToken: 'old-token', tokenType: 'Bearer' },
|
||||
updatedAt: Date.now(),
|
||||
});
|
||||
storage.isTokenExpired.mockReturnValue(false);
|
||||
|
||||
await provider.initialize();
|
||||
|
||||
const res = new Response(null, { status: 403 });
|
||||
const retryHeaders = await provider.shouldRetryWithHeaders({}, res);
|
||||
|
||||
expect(retryHeaders).toBeDefined();
|
||||
});
|
||||
|
||||
it('should return undefined for non-auth errors', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
|
||||
const res = new Response(null, { status: 500 });
|
||||
const retryHeaders = await provider.shouldRetryWithHeaders({}, res);
|
||||
|
||||
expect(retryHeaders).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should respect MAX_AUTH_RETRIES', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
|
||||
const res401 = new Response(null, { status: 401 });
|
||||
|
||||
// First retry — should succeed.
|
||||
const first = await provider.shouldRetryWithHeaders({}, res401);
|
||||
expect(first).toBeDefined();
|
||||
|
||||
// Second retry — should succeed.
|
||||
const second = await provider.shouldRetryWithHeaders({}, res401);
|
||||
expect(second).toBeDefined();
|
||||
|
||||
// Third retry — should be blocked.
|
||||
const third = await provider.shouldRetryWithHeaders({}, res401);
|
||||
expect(third).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reset retry count on non-auth response', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
|
||||
const res401 = new Response(null, { status: 401 });
|
||||
const res200 = new Response(null, { status: 200 });
|
||||
|
||||
await provider.shouldRetryWithHeaders({}, res401);
|
||||
await provider.shouldRetryWithHeaders({}, res200); // resets
|
||||
|
||||
// Should be able to retry again.
|
||||
const result = await provider.shouldRetryWithHeaders({}, res401);
|
||||
expect(result).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('token persistence', () => {
|
||||
it('should persist token after successful interactive auth', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
await provider.initialize();
|
||||
await provider.headers();
|
||||
|
||||
expect(storage.saveToken).toHaveBeenCalledWith(
|
||||
'test-agent',
|
||||
expect.objectContaining({
|
||||
accessToken: 'new-access-token',
|
||||
tokenType: 'Bearer',
|
||||
refreshToken: 'new-refresh-token',
|
||||
}),
|
||||
'test-client-id',
|
||||
'https://auth.example.com/token',
|
||||
);
|
||||
});
|
||||
|
||||
it('should persist token after successful refresh', async () => {
|
||||
const provider = new OAuth2AuthProvider(createConfig(), 'test-agent');
|
||||
const storage = getTokenStorage();
|
||||
|
||||
storage.getCredentials.mockResolvedValue({
|
||||
serverName: 'test-agent',
|
||||
token: {
|
||||
accessToken: 'expired-token',
|
||||
tokenType: 'Bearer',
|
||||
refreshToken: 'my-refresh-token',
|
||||
},
|
||||
updatedAt: Date.now(),
|
||||
});
|
||||
storage.isTokenExpired
|
||||
.mockReturnValueOnce(false)
|
||||
.mockReturnValueOnce(true);
|
||||
|
||||
await provider.initialize();
|
||||
await provider.headers();
|
||||
|
||||
expect(storage.saveToken).toHaveBeenCalledWith(
|
||||
'test-agent',
|
||||
expect.objectContaining({
|
||||
accessToken: 'refreshed-access-token',
|
||||
}),
|
||||
'test-client-id',
|
||||
'https://auth.example.com/token',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('agent card integration', () => {
|
||||
it('should discover URLs from agent card when not in config', async () => {
|
||||
const config = createConfig({
|
||||
authorization_url: undefined,
|
||||
token_url: undefined,
|
||||
scopes: undefined,
|
||||
});
|
||||
|
||||
const agentCard = {
|
||||
securitySchemes: {
|
||||
myOauth: {
|
||||
type: 'oauth2' as const,
|
||||
flows: {
|
||||
authorizationCode: {
|
||||
authorizationUrl: 'https://card.example.com/auth',
|
||||
tokenUrl: 'https://card.example.com/token',
|
||||
scopes: { profile: 'View profile', email: 'View email' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as AgentCard;
|
||||
|
||||
const provider = new OAuth2AuthProvider(config, 'card-agent', agentCard);
|
||||
await provider.initialize();
|
||||
await provider.headers();
|
||||
|
||||
expect(vi.mocked(buildAuthorizationUrl)).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
authorizationUrl: 'https://card.example.com/auth',
|
||||
tokenUrl: 'https://card.example.com/token',
|
||||
scopes: ['profile', 'email'],
|
||||
}),
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should discover URLs from agentCardUrl via DefaultAgentCardResolver during initialize', async () => {
|
||||
const config = createConfig({
|
||||
authorization_url: undefined,
|
||||
token_url: undefined,
|
||||
scopes: undefined,
|
||||
});
|
||||
|
||||
// Simulate a normalized agent card returned by DefaultAgentCardResolver.
|
||||
mockResolve.mockResolvedValue({
|
||||
securitySchemes: {
|
||||
myOauth: {
|
||||
type: 'oauth2' as const,
|
||||
flows: {
|
||||
authorizationCode: {
|
||||
authorizationUrl: 'https://discovered.example.com/auth',
|
||||
tokenUrl: 'https://discovered.example.com/token',
|
||||
scopes: { openid: 'OpenID', profile: 'Profile' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as AgentCard);
|
||||
|
||||
// No agentCard passed to constructor — only agentCardUrl.
|
||||
const provider = new OAuth2AuthProvider(
|
||||
config,
|
||||
'discover-agent',
|
||||
undefined,
|
||||
'https://example.com/.well-known/agent-card.json',
|
||||
);
|
||||
await provider.initialize();
|
||||
await provider.headers();
|
||||
|
||||
expect(mockResolve).toHaveBeenCalledWith(
|
||||
'https://example.com/.well-known/agent-card.json',
|
||||
'',
|
||||
);
|
||||
expect(vi.mocked(buildAuthorizationUrl)).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
authorizationUrl: 'https://discovered.example.com/auth',
|
||||
tokenUrl: 'https://discovered.example.com/token',
|
||||
scopes: ['openid', 'profile'],
|
||||
}),
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should ignore agent card with no authorizationCode flow', () => {
|
||||
const config = createConfig({
|
||||
authorization_url: undefined,
|
||||
token_url: undefined,
|
||||
});
|
||||
|
||||
const agentCard = {
|
||||
securitySchemes: {
|
||||
myOauth: {
|
||||
type: 'oauth2' as const,
|
||||
flows: {
|
||||
clientCredentials: {
|
||||
tokenUrl: 'https://card.example.com/token',
|
||||
scopes: {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as AgentCard;
|
||||
|
||||
// Should not throw — just won't have URLs.
|
||||
const provider = new OAuth2AuthProvider(config, 'card-agent', agentCard);
|
||||
expect(provider.type).toBe('oauth2');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,340 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type HttpHeaders, DefaultAgentCardResolver } from '@a2a-js/sdk/client';
|
||||
import type { AgentCard } from '@a2a-js/sdk';
|
||||
import { BaseA2AAuthProvider } from './base-provider.js';
|
||||
import type { OAuth2AuthConfig } from './types.js';
|
||||
import { MCPOAuthTokenStorage } from '../../mcp/oauth-token-storage.js';
|
||||
import type { OAuthToken } from '../../mcp/token-storage/types.js';
|
||||
import {
|
||||
generatePKCEParams,
|
||||
startCallbackServer,
|
||||
getPortFromUrl,
|
||||
buildAuthorizationUrl,
|
||||
exchangeCodeForToken,
|
||||
refreshAccessToken,
|
||||
type OAuthFlowConfig,
|
||||
} from '../../utils/oauth-flow.js';
|
||||
import { openBrowserSecurely } from '../../utils/secure-browser-launcher.js';
|
||||
import { getConsentForOauth } from '../../utils/authConsent.js';
|
||||
import { FatalCancellationError, getErrorMessage } from '../../utils/errors.js';
|
||||
import { coreEvents } from '../../utils/events.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
import { Storage } from '../../config/storage.js';
|
||||
|
||||
/**
|
||||
* Authentication provider for OAuth 2.0 Authorization Code flow with PKCE.
|
||||
*
|
||||
* Used by A2A remote agents whose security scheme is `oauth2`.
|
||||
* Reuses the shared OAuth flow primitives from `utils/oauth-flow.ts`
|
||||
* and persists tokens via `MCPOAuthTokenStorage`.
|
||||
*/
|
||||
export class OAuth2AuthProvider extends BaseA2AAuthProvider {
|
||||
readonly type = 'oauth2' as const;
|
||||
|
||||
private readonly tokenStorage: MCPOAuthTokenStorage;
|
||||
private cachedToken: OAuthToken | null = null;
|
||||
|
||||
/** Resolved OAuth URLs — may come from config or agent card. */
|
||||
private authorizationUrl: string | undefined;
|
||||
private tokenUrl: string | undefined;
|
||||
private scopes: string[] | undefined;
|
||||
|
||||
constructor(
|
||||
private readonly config: OAuth2AuthConfig,
|
||||
private readonly agentName: string,
|
||||
agentCard?: AgentCard,
|
||||
private readonly agentCardUrl?: string,
|
||||
) {
|
||||
super();
|
||||
this.tokenStorage = new MCPOAuthTokenStorage(
|
||||
Storage.getA2AOAuthTokensPath(),
|
||||
'gemini-cli-a2a',
|
||||
);
|
||||
|
||||
// Seed from user config.
|
||||
this.authorizationUrl = config.authorization_url;
|
||||
this.tokenUrl = config.token_url;
|
||||
this.scopes = config.scopes;
|
||||
|
||||
// Fall back to agent card's OAuth2 security scheme if user config is incomplete.
|
||||
this.mergeAgentCardDefaults(agentCard);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the provider by loading any persisted token from storage.
|
||||
* Also discovers OAuth URLs from the agent card if not yet resolved.
|
||||
*/
|
||||
override async initialize(): Promise<void> {
|
||||
// If OAuth URLs are still missing, fetch the agent card to discover them.
|
||||
if ((!this.authorizationUrl || !this.tokenUrl) && this.agentCardUrl) {
|
||||
await this.fetchAgentCardDefaults();
|
||||
}
|
||||
|
||||
const credentials = await this.tokenStorage.getCredentials(this.agentName);
|
||||
if (credentials && !this.tokenStorage.isTokenExpired(credentials.token)) {
|
||||
this.cachedToken = credentials.token;
|
||||
debugLogger.debug(
|
||||
`[OAuth2AuthProvider] Loaded valid cached token for "${this.agentName}"`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return an Authorization header with a valid Bearer token.
|
||||
* Refreshes or triggers interactive auth as needed.
|
||||
*/
|
||||
override async headers(): Promise<HttpHeaders> {
|
||||
// 1. Valid cached token → return immediately.
|
||||
if (
|
||||
this.cachedToken &&
|
||||
!this.tokenStorage.isTokenExpired(this.cachedToken)
|
||||
) {
|
||||
return { Authorization: `Bearer ${this.cachedToken.accessToken}` };
|
||||
}
|
||||
|
||||
// 2. Expired but has refresh token → attempt silent refresh.
|
||||
if (
|
||||
this.cachedToken?.refreshToken &&
|
||||
this.tokenUrl &&
|
||||
this.config.client_id
|
||||
) {
|
||||
try {
|
||||
const refreshed = await refreshAccessToken(
|
||||
{
|
||||
clientId: this.config.client_id,
|
||||
clientSecret: this.config.client_secret,
|
||||
scopes: this.scopes,
|
||||
},
|
||||
this.cachedToken.refreshToken,
|
||||
this.tokenUrl,
|
||||
);
|
||||
|
||||
this.cachedToken = this.toOAuthToken(
|
||||
refreshed,
|
||||
this.cachedToken.refreshToken,
|
||||
);
|
||||
await this.persistToken();
|
||||
return { Authorization: `Bearer ${this.cachedToken.accessToken}` };
|
||||
} catch (error) {
|
||||
debugLogger.debug(
|
||||
`[OAuth2AuthProvider] Refresh failed, falling back to interactive flow: ${getErrorMessage(error)}`,
|
||||
);
|
||||
// Clear stale credentials and fall through to interactive flow.
|
||||
await this.tokenStorage.deleteCredentials(this.agentName);
|
||||
}
|
||||
}
|
||||
|
||||
// 3. No valid token → interactive browser-based auth.
|
||||
this.cachedToken = await this.authenticateInteractively();
|
||||
return { Authorization: `Bearer ${this.cachedToken.accessToken}` };
|
||||
}
|
||||
|
||||
/**
|
||||
* On 401/403, clear the cached token and re-authenticate (up to MAX_AUTH_RETRIES).
|
||||
*/
|
||||
override async shouldRetryWithHeaders(
|
||||
_req: RequestInit,
|
||||
res: Response,
|
||||
): Promise<HttpHeaders | undefined> {
|
||||
if (res.status !== 401 && res.status !== 403) {
|
||||
this.authRetryCount = 0;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (this.authRetryCount >= BaseA2AAuthProvider.MAX_AUTH_RETRIES) {
|
||||
return undefined;
|
||||
}
|
||||
this.authRetryCount++;
|
||||
|
||||
debugLogger.debug(
|
||||
'[OAuth2AuthProvider] Auth failure, clearing token and re-authenticating',
|
||||
);
|
||||
this.cachedToken = null;
|
||||
await this.tokenStorage.deleteCredentials(this.agentName);
|
||||
|
||||
return this.headers();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Private helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Merge authorization_url, token_url, and scopes from the agent card's
|
||||
* `securitySchemes` when not already provided via user config.
|
||||
*/
|
||||
private mergeAgentCardDefaults(
|
||||
agentCard?: Pick<AgentCard, 'securitySchemes'> | null,
|
||||
): void {
|
||||
if (!agentCard?.securitySchemes) return;
|
||||
|
||||
for (const scheme of Object.values(agentCard.securitySchemes)) {
|
||||
if (scheme.type === 'oauth2' && scheme.flows.authorizationCode) {
|
||||
const flow = scheme.flows.authorizationCode;
|
||||
this.authorizationUrl ??= flow.authorizationUrl;
|
||||
this.tokenUrl ??= flow.tokenUrl;
|
||||
this.scopes ??= Object.keys(flow.scopes);
|
||||
break; // Use the first matching scheme.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch the agent card from `agentCardUrl` using `DefaultAgentCardResolver`
|
||||
* (which normalizes proto-format cards) and extract OAuth2 URLs.
|
||||
*/
|
||||
private async fetchAgentCardDefaults(): Promise<void> {
|
||||
if (!this.agentCardUrl) return;
|
||||
|
||||
try {
|
||||
debugLogger.debug(
|
||||
`[OAuth2AuthProvider] Fetching agent card from ${this.agentCardUrl}`,
|
||||
);
|
||||
const resolver = new DefaultAgentCardResolver();
|
||||
const card = await resolver.resolve(this.agentCardUrl, '');
|
||||
this.mergeAgentCardDefaults(card);
|
||||
} catch (error) {
|
||||
debugLogger.warn(
|
||||
`[OAuth2AuthProvider] Could not fetch agent card for OAuth URL discovery: ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Run a full OAuth 2.0 Authorization Code + PKCE flow through the browser.
|
||||
*/
|
||||
private async authenticateInteractively(): Promise<OAuthToken> {
|
||||
if (!this.config.client_id) {
|
||||
throw new Error(
|
||||
`OAuth2 authentication for agent "${this.agentName}" requires a client_id. ` +
|
||||
'Add client_id to the auth config in your agent definition.',
|
||||
);
|
||||
}
|
||||
if (!this.authorizationUrl || !this.tokenUrl) {
|
||||
throw new Error(
|
||||
`OAuth2 authentication for agent "${this.agentName}" requires authorization_url and token_url. ` +
|
||||
'Provide them in the auth config or ensure the agent card exposes an oauth2 security scheme.',
|
||||
);
|
||||
}
|
||||
|
||||
const flowConfig: OAuthFlowConfig = {
|
||||
clientId: this.config.client_id,
|
||||
clientSecret: this.config.client_secret,
|
||||
authorizationUrl: this.authorizationUrl,
|
||||
tokenUrl: this.tokenUrl,
|
||||
scopes: this.scopes,
|
||||
};
|
||||
|
||||
const pkceParams = generatePKCEParams();
|
||||
const preferredPort = getPortFromUrl(flowConfig.redirectUri);
|
||||
const callbackServer = startCallbackServer(pkceParams.state, preferredPort);
|
||||
const redirectPort = await callbackServer.port;
|
||||
|
||||
const authUrl = buildAuthorizationUrl(
|
||||
flowConfig,
|
||||
pkceParams,
|
||||
redirectPort,
|
||||
/* resource= */ undefined, // No MCP resource parameter for A2A.
|
||||
);
|
||||
|
||||
const consent = await getConsentForOauth(
|
||||
`Authentication required for A2A agent: '${this.agentName}'.`,
|
||||
);
|
||||
if (!consent) {
|
||||
throw new FatalCancellationError('Authentication cancelled by user.');
|
||||
}
|
||||
|
||||
coreEvents.emitFeedback(
|
||||
'info',
|
||||
`→ Opening your browser for OAuth sign-in...
|
||||
|
||||
` +
|
||||
`If the browser does not open, copy and paste this URL into your browser:
|
||||
` +
|
||||
`${authUrl}
|
||||
|
||||
` +
|
||||
`💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser.
|
||||
` +
|
||||
`⚠️ Make sure to copy the COMPLETE URL - it may wrap across multiple lines.`,
|
||||
);
|
||||
|
||||
try {
|
||||
await openBrowserSecurely(authUrl);
|
||||
} catch (error) {
|
||||
debugLogger.warn(
|
||||
'Failed to open browser automatically:',
|
||||
getErrorMessage(error),
|
||||
);
|
||||
}
|
||||
|
||||
const { code } = await callbackServer.response;
|
||||
debugLogger.debug(
|
||||
'✓ Authorization code received, exchanging for tokens...',
|
||||
);
|
||||
|
||||
const tokenResponse = await exchangeCodeForToken(
|
||||
flowConfig,
|
||||
code,
|
||||
pkceParams.codeVerifier,
|
||||
redirectPort,
|
||||
/* resource= */ undefined,
|
||||
);
|
||||
|
||||
if (!tokenResponse.access_token) {
|
||||
throw new Error('No access token received from token endpoint');
|
||||
}
|
||||
|
||||
const token = this.toOAuthToken(tokenResponse);
|
||||
this.cachedToken = token;
|
||||
await this.persistToken();
|
||||
|
||||
debugLogger.debug('✓ OAuth2 authentication successful! Token saved.');
|
||||
return token;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert an `OAuthTokenResponse` into the internal `OAuthToken` format.
|
||||
*/
|
||||
private toOAuthToken(
|
||||
response: {
|
||||
access_token: string;
|
||||
token_type?: string;
|
||||
expires_in?: number;
|
||||
refresh_token?: string;
|
||||
scope?: string;
|
||||
},
|
||||
fallbackRefreshToken?: string,
|
||||
): OAuthToken {
|
||||
const token: OAuthToken = {
|
||||
accessToken: response.access_token,
|
||||
tokenType: response.token_type || 'Bearer',
|
||||
refreshToken: response.refresh_token || fallbackRefreshToken,
|
||||
scope: response.scope,
|
||||
};
|
||||
|
||||
if (response.expires_in) {
|
||||
token.expiresAt = Date.now() + response.expires_in * 1000;
|
||||
}
|
||||
|
||||
return token;
|
||||
}
|
||||
|
||||
/**
|
||||
* Persist the current cached token to disk.
|
||||
*/
|
||||
private async persistToken(): Promise<void> {
|
||||
if (!this.cachedToken) return;
|
||||
await this.tokenStorage.saveToken(
|
||||
this.agentName,
|
||||
this.cachedToken,
|
||||
this.config.client_id,
|
||||
this.tokenUrl,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -74,6 +74,10 @@ export interface OAuth2AuthConfig extends BaseAuthConfig {
|
||||
client_id?: string;
|
||||
client_secret?: string;
|
||||
scopes?: string[];
|
||||
/** Override or provide the authorization endpoint URL. Discovered from agent card if omitted. */
|
||||
authorization_url?: string;
|
||||
/** Override or provide the token endpoint URL. Discovered from agent card if omitted. */
|
||||
token_url?: string;
|
||||
}
|
||||
|
||||
/** Client config corresponding to OpenIdConnectSecurityScheme. */
|
||||
|
||||
@@ -591,6 +591,7 @@ describe('AgentRegistry', () => {
|
||||
expect(A2AAuthProviderFactory.create).toHaveBeenCalledWith({
|
||||
authConfig: mockAuth,
|
||||
agentName: 'RemoteAgentWithAuth',
|
||||
agentCardUrl: 'https://example.com/card',
|
||||
});
|
||||
expect(loadAgentSpy).toHaveBeenCalledWith(
|
||||
'RemoteAgentWithAuth',
|
||||
|
||||
@@ -416,6 +416,7 @@ export class AgentRegistry {
|
||||
const provider = await A2AAuthProviderFactory.create({
|
||||
authConfig: definition.auth,
|
||||
agentName: definition.name,
|
||||
agentCardUrl: remoteDef.agentCardUrl,
|
||||
});
|
||||
if (!provider) {
|
||||
throw new Error(
|
||||
|
||||
@@ -195,6 +195,7 @@ describe('RemoteAgentInvocation', () => {
|
||||
expect(A2AAuthProviderFactory.create).toHaveBeenCalledWith({
|
||||
authConfig: mockAuth,
|
||||
agentName: 'test-agent',
|
||||
agentCardUrl: 'http://test-agent/card',
|
||||
});
|
||||
expect(mockClientManager.loadAgent).toHaveBeenCalledWith(
|
||||
'test-agent',
|
||||
|
||||
@@ -120,6 +120,7 @@ export class RemoteAgentInvocation extends BaseToolInvocation<
|
||||
const provider = await A2AAuthProviderFactory.create({
|
||||
authConfig: this.definition.auth,
|
||||
agentName: this.definition.name,
|
||||
agentCardUrl: this.definition.agentCardUrl,
|
||||
});
|
||||
if (!provider) {
|
||||
throw new Error(
|
||||
|
||||
@@ -62,6 +62,10 @@ export class Storage {
|
||||
return path.join(Storage.getGlobalGeminiDir(), 'mcp-oauth-tokens.json');
|
||||
}
|
||||
|
||||
static getA2AOAuthTokensPath(): string {
|
||||
return path.join(Storage.getGlobalGeminiDir(), 'a2a-oauth-tokens.json');
|
||||
}
|
||||
|
||||
static getGlobalSettingsPath(): string {
|
||||
return path.join(Storage.getGlobalGeminiDir(), 'settings.json');
|
||||
}
|
||||
|
||||
@@ -1154,6 +1154,7 @@ describe('GeminiChat', () => {
|
||||
1,
|
||||
);
|
||||
expect(mockLogContentRetry).not.toHaveBeenCalled();
|
||||
expect(mockLogContentRetryFailure).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should yield a RETRY event when an invalid stream is encountered', async () => {
|
||||
|
||||
@@ -344,8 +344,6 @@ export class GeminiChat {
|
||||
this: GeminiChat,
|
||||
): AsyncGenerator<StreamEvent, void, void> {
|
||||
try {
|
||||
let lastError: unknown = new Error('Request failed after all retries.');
|
||||
|
||||
const maxAttempts = INVALID_CONTENT_RETRY_OPTIONS.maxAttempts;
|
||||
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
@@ -374,15 +372,13 @@ export class GeminiChat {
|
||||
yield { type: StreamEventType.CHUNK, value: chunk };
|
||||
}
|
||||
|
||||
lastError = null;
|
||||
break;
|
||||
return;
|
||||
} catch (error) {
|
||||
if (error instanceof AgentExecutionStoppedError) {
|
||||
yield {
|
||||
type: StreamEventType.AGENT_EXECUTION_STOPPED,
|
||||
reason: error.reason,
|
||||
};
|
||||
lastError = null; // Clear error as this is an expected stop
|
||||
return; // Stop the generator
|
||||
}
|
||||
|
||||
@@ -397,7 +393,6 @@ export class GeminiChat {
|
||||
value: error.syntheticResponse,
|
||||
};
|
||||
}
|
||||
lastError = null; // Clear error as this is an expected stop
|
||||
return; // Stop the generator
|
||||
}
|
||||
|
||||
@@ -415,8 +410,9 @@ export class GeminiChat {
|
||||
}
|
||||
// Fall through to retry logic for retryable connection errors
|
||||
}
|
||||
lastError = error;
|
||||
|
||||
const isContentError = error instanceof InvalidStreamError;
|
||||
const errorType = isContentError ? error.type : 'NETWORK_ERROR';
|
||||
|
||||
if (
|
||||
(isContentError && isGemini2Model(model)) ||
|
||||
@@ -425,11 +421,10 @@ export class GeminiChat {
|
||||
// Check if we have more attempts left.
|
||||
if (attempt < maxAttempts - 1) {
|
||||
const delayMs = INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs;
|
||||
const retryType = isContentError ? error.type : 'NETWORK_ERROR';
|
||||
|
||||
logContentRetry(
|
||||
this.config,
|
||||
new ContentRetryEvent(attempt, retryType, delayMs, model),
|
||||
new ContentRetryEvent(attempt, errorType, delayMs, model),
|
||||
);
|
||||
coreEvents.emitRetryAttempt({
|
||||
attempt: attempt + 1,
|
||||
@@ -444,21 +439,19 @@ export class GeminiChat {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (lastError) {
|
||||
if (
|
||||
lastError instanceof InvalidStreamError &&
|
||||
isGemini2Model(model)
|
||||
) {
|
||||
// If we've aborted, we throw without logging a failure.
|
||||
if (signal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
logContentRetryFailure(
|
||||
this.config,
|
||||
new ContentRetryFailureEvent(maxAttempts, lastError.type, model),
|
||||
new ContentRetryFailureEvent(attempt + 1, errorType, model),
|
||||
);
|
||||
|
||||
throw error;
|
||||
}
|
||||
throw lastError;
|
||||
}
|
||||
} finally {
|
||||
streamDoneResolver!();
|
||||
|
||||
@@ -401,6 +401,7 @@ describe('GeminiChat Network Retries', () => {
|
||||
|
||||
// Should only be called once (no retry)
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(1);
|
||||
expect(mockLogContentRetryFailure).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should retry on SSL error during stream iteration (mid-stream failure)', async () => {
|
||||
|
||||
@@ -44,6 +44,7 @@ vi.mock('../telemetry/index.js', () => ({
|
||||
}));
|
||||
vi.mock('../utils/secure-browser-launcher.js', () => ({
|
||||
openBrowserSecurely: vi.fn(),
|
||||
shouldLaunchBrowser: vi.fn().mockReturnValue(true),
|
||||
}));
|
||||
|
||||
// Mock debugLogger to prevent console pollution and allow spying
|
||||
|
||||
@@ -6,7 +6,10 @@
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
|
||||
import {
|
||||
openBrowserSecurely,
|
||||
shouldLaunchBrowser,
|
||||
} from '../utils/secure-browser-launcher.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import type { FallbackIntent, FallbackRecommendation } from './types.js';
|
||||
@@ -112,6 +115,12 @@ export async function handleFallback(
|
||||
}
|
||||
|
||||
async function handleUpgrade() {
|
||||
if (!shouldLaunchBrowser()) {
|
||||
debugLogger.log(
|
||||
`Cannot open browser in this environment. Please visit: ${UPGRADE_URL_PAGE}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await openBrowserSecurely(UPGRADE_URL_PAGE);
|
||||
} catch (error) {
|
||||
|
||||
@@ -21,14 +21,23 @@ import {
|
||||
} from './token-storage/index.js';
|
||||
|
||||
/**
|
||||
* Class for managing MCP OAuth token storage and retrieval.
|
||||
* Class for managing OAuth token storage and retrieval.
|
||||
* Used by both MCP and A2A OAuth providers. Pass a custom `tokenFilePath`
|
||||
* to store tokens in a protocol-specific file.
|
||||
*/
|
||||
export class MCPOAuthTokenStorage implements TokenStorage {
|
||||
private readonly hybridTokenStorage = new HybridTokenStorage(
|
||||
DEFAULT_SERVICE_NAME,
|
||||
);
|
||||
private readonly hybridTokenStorage: HybridTokenStorage;
|
||||
private readonly useEncryptedFile =
|
||||
process.env[FORCE_ENCRYPTED_FILE_ENV_VAR] === 'true';
|
||||
private readonly customTokenFilePath?: string;
|
||||
|
||||
constructor(
|
||||
tokenFilePath?: string,
|
||||
serviceName: string = DEFAULT_SERVICE_NAME,
|
||||
) {
|
||||
this.customTokenFilePath = tokenFilePath;
|
||||
this.hybridTokenStorage = new HybridTokenStorage(serviceName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the path to the token storage file.
|
||||
@@ -36,7 +45,7 @@ export class MCPOAuthTokenStorage implements TokenStorage {
|
||||
* @returns The full path to the token storage file
|
||||
*/
|
||||
private getTokenFilePath(): string {
|
||||
return Storage.getMcpOAuthTokensPath();
|
||||
return this.customTokenFilePath ?? Storage.getMcpOAuthTokensPath();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -195,6 +195,9 @@ describe('ClearcutLogger', () => {
|
||||
vi.stubEnv('MONOSPACE_ENV', '');
|
||||
vi.stubEnv('REPLIT_USER', '');
|
||||
vi.stubEnv('__COG_BASHRC_SOURCED', '');
|
||||
vi.stubEnv('GH_PR_NUMBER', '');
|
||||
vi.stubEnv('GH_ISSUE_NUMBER', '');
|
||||
vi.stubEnv('GH_CUSTOM_TRACKING_ID', '');
|
||||
});
|
||||
|
||||
function setup({
|
||||
@@ -596,6 +599,110 @@ describe('ClearcutLogger', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('GITHUB_EVENT_NAME metadata', () => {
|
||||
it('includes event name when GITHUB_EVENT_NAME is set', () => {
|
||||
const { logger } = setup({});
|
||||
vi.stubEnv('GITHUB_EVENT_NAME', 'issues');
|
||||
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
expect(event?.event_metadata[0]).toContainEqual({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_EVENT_NAME,
|
||||
value: 'issues',
|
||||
});
|
||||
});
|
||||
|
||||
it('does not include event name when GITHUB_EVENT_NAME is not set', () => {
|
||||
const { logger } = setup({});
|
||||
vi.stubEnv('GITHUB_EVENT_NAME', undefined);
|
||||
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
const hasEventName = event?.event_metadata[0].some(
|
||||
(item) =>
|
||||
item.gemini_cli_key === EventMetadataKey.GEMINI_CLI_GH_EVENT_NAME,
|
||||
);
|
||||
expect(hasEventName).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('GH_PR_NUMBER metadata', () => {
|
||||
it('includes PR number when GH_PR_NUMBER is set', () => {
|
||||
vi.stubEnv('GH_PR_NUMBER', '123');
|
||||
const { logger } = setup({});
|
||||
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
|
||||
expect(event?.event_metadata[0]).toContainEqual({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_PR_NUMBER,
|
||||
value: '123',
|
||||
});
|
||||
});
|
||||
|
||||
it('does not include PR number when GH_PR_NUMBER is not set', () => {
|
||||
vi.stubEnv('GH_PR_NUMBER', undefined);
|
||||
const { logger } = setup({});
|
||||
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
const hasPRNumber = event?.event_metadata[0].some(
|
||||
(item) =>
|
||||
item.gemini_cli_key === EventMetadataKey.GEMINI_CLI_GH_PR_NUMBER,
|
||||
);
|
||||
expect(hasPRNumber).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('GH_ISSUE_NUMBER metadata', () => {
|
||||
it('includes issue number when GH_ISSUE_NUMBER is set', () => {
|
||||
vi.stubEnv('GH_ISSUE_NUMBER', '456');
|
||||
const { logger } = setup({});
|
||||
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
|
||||
expect(event?.event_metadata[0]).toContainEqual({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_ISSUE_NUMBER,
|
||||
value: '456',
|
||||
});
|
||||
});
|
||||
|
||||
it('does not include issue number when GH_ISSUE_NUMBER is not set', () => {
|
||||
vi.stubEnv('GH_ISSUE_NUMBER', undefined);
|
||||
const { logger } = setup({});
|
||||
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
const hasIssueNumber = event?.event_metadata[0].some(
|
||||
(item) =>
|
||||
item.gemini_cli_key === EventMetadataKey.GEMINI_CLI_GH_ISSUE_NUMBER,
|
||||
);
|
||||
expect(hasIssueNumber).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('GH_CUSTOM_TRACKING_ID metadata', () => {
|
||||
it('includes custom tracking ID when GH_CUSTOM_TRACKING_ID is set', () => {
|
||||
vi.stubEnv('GH_CUSTOM_TRACKING_ID', 'abc-789');
|
||||
const { logger } = setup({});
|
||||
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
|
||||
expect(event?.event_metadata[0]).toContainEqual({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_CUSTOM_TRACKING_ID,
|
||||
value: 'abc-789',
|
||||
});
|
||||
});
|
||||
|
||||
it('does not include custom tracking ID when GH_CUSTOM_TRACKING_ID is not set', () => {
|
||||
vi.stubEnv('GH_CUSTOM_TRACKING_ID', undefined);
|
||||
const { logger } = setup({});
|
||||
|
||||
const event = logger?.createLogEvent(EventNames.API_ERROR, []);
|
||||
const hasTrackingId = event?.event_metadata[0].some(
|
||||
(item) =>
|
||||
item.gemini_cli_key ===
|
||||
EventMetadataKey.GEMINI_CLI_GH_CUSTOM_TRACKING_ID,
|
||||
);
|
||||
expect(hasTrackingId).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('GITHUB_REPOSITORY metadata', () => {
|
||||
it('includes hashed repository when GITHUB_REPOSITORY is set', () => {
|
||||
vi.stubEnv('GITHUB_REPOSITORY', 'google/gemini-cli');
|
||||
|
||||
@@ -190,6 +190,34 @@ function determineGHRepositoryName(): string | undefined {
|
||||
return process.env['GITHUB_REPOSITORY'];
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines the GitHub event name if the CLI is running in a GitHub Actions environment.
|
||||
*/
|
||||
function determineGHEventName(): string | undefined {
|
||||
return process.env['GITHUB_EVENT_NAME'];
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines the GitHub Pull Request number if the CLI is running in a GitHub Actions environment.
|
||||
*/
|
||||
function determineGHPRNumber(): string | undefined {
|
||||
return process.env['GH_PR_NUMBER'];
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines the GitHub Issue number if the CLI is running in a GitHub Actions environment.
|
||||
*/
|
||||
function determineGHIssueNumber(): string | undefined {
|
||||
return process.env['GH_ISSUE_NUMBER'];
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines the GitHub custom tracking ID if the CLI is running in a GitHub Actions environment.
|
||||
*/
|
||||
function determineGHCustomTrackingId(): string | undefined {
|
||||
return process.env['GH_CUSTOM_TRACKING_ID'];
|
||||
}
|
||||
|
||||
/**
|
||||
* Clearcut URL to send logging events to.
|
||||
*/
|
||||
@@ -372,6 +400,10 @@ export class ClearcutLogger {
|
||||
const email = this.userAccountManager.getCachedGoogleAccount();
|
||||
const surface = determineSurface();
|
||||
const ghWorkflowName = determineGHWorkflowName();
|
||||
const ghEventName = determineGHEventName();
|
||||
const ghPRNumber = determineGHPRNumber();
|
||||
const ghIssueNumber = determineGHIssueNumber();
|
||||
const ghCustomTrackingId = determineGHCustomTrackingId();
|
||||
const baseMetadata: EventValue[] = [
|
||||
...data,
|
||||
{
|
||||
@@ -406,6 +438,34 @@ export class ClearcutLogger {
|
||||
});
|
||||
}
|
||||
|
||||
if (ghEventName) {
|
||||
baseMetadata.push({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_EVENT_NAME,
|
||||
value: ghEventName,
|
||||
});
|
||||
}
|
||||
|
||||
if (ghPRNumber) {
|
||||
baseMetadata.push({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_PR_NUMBER,
|
||||
value: ghPRNumber,
|
||||
});
|
||||
}
|
||||
|
||||
if (ghIssueNumber) {
|
||||
baseMetadata.push({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_ISSUE_NUMBER,
|
||||
value: ghIssueNumber,
|
||||
});
|
||||
}
|
||||
|
||||
if (ghCustomTrackingId) {
|
||||
baseMetadata.push({
|
||||
gemini_cli_key: EventMetadataKey.GEMINI_CLI_GH_CUSTOM_TRACKING_ID,
|
||||
value: ghCustomTrackingId,
|
||||
});
|
||||
}
|
||||
|
||||
const logEvent: LogEvent = {
|
||||
console_type: 'GEMINI_CLI',
|
||||
application: 102, // GEMINI_CLI
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
// Defines valid event metadata keys for Clearcut logging.
|
||||
export enum EventMetadataKey {
|
||||
// Deleted enums: 24
|
||||
// Next ID: 176
|
||||
// Next ID: 180
|
||||
|
||||
GEMINI_CLI_KEY_UNKNOWN = 0,
|
||||
|
||||
@@ -231,6 +231,18 @@ export enum EventMetadataKey {
|
||||
// Logs the repository name of the GitHub Action that triggered the session.
|
||||
GEMINI_CLI_GH_REPOSITORY_NAME_HASH = 132,
|
||||
|
||||
// Logs the event name of the GitHub Action that triggered the session.
|
||||
GEMINI_CLI_GH_EVENT_NAME = 176,
|
||||
|
||||
// Logs the Pull Request number if the workflow is operating on a PR.
|
||||
GEMINI_CLI_GH_PR_NUMBER = 177,
|
||||
|
||||
// Logs the Issue number if the workflow is operating on an Issue.
|
||||
GEMINI_CLI_GH_ISSUE_NUMBER = 178,
|
||||
|
||||
// Logs a custom tracking string (e.g. a comma-separated list of issue IDs for scheduled batches).
|
||||
GEMINI_CLI_GH_CUSTOM_TRACKING_ID = 179,
|
||||
|
||||
// ==========================================================================
|
||||
// Loop Detected Event Keys
|
||||
// ===========================================================================
|
||||
|
||||
@@ -134,21 +134,21 @@ describe('classifyGoogleError', () => {
|
||||
expect((result as TerminalQuotaError).cause).toBe(apiError);
|
||||
});
|
||||
|
||||
it('should return RetryableQuotaError for long retry delays', () => {
|
||||
it('should return TerminalQuotaError for retry delays over 5 minutes', () => {
|
||||
const apiError: GoogleApiError = {
|
||||
code: 429,
|
||||
message: 'Too many requests',
|
||||
details: [
|
||||
{
|
||||
'@type': 'type.googleapis.com/google.rpc.RetryInfo',
|
||||
retryDelay: '301s', // Any delay is now retryable
|
||||
retryDelay: '301s', // Over 5 min threshold => terminal
|
||||
},
|
||||
],
|
||||
};
|
||||
vi.spyOn(errorParser, 'parseGoogleApiError').mockReturnValue(apiError);
|
||||
const result = classifyGoogleError(new Error());
|
||||
expect(result).toBeInstanceOf(RetryableQuotaError);
|
||||
expect((result as RetryableQuotaError).retryDelayMs).toBe(301000);
|
||||
expect(result).toBeInstanceOf(TerminalQuotaError);
|
||||
expect((result as TerminalQuotaError).retryDelayMs).toBe(301000);
|
||||
});
|
||||
|
||||
it('should return RetryableQuotaError for short retry delays', () => {
|
||||
@@ -285,6 +285,34 @@ describe('classifyGoogleError', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should return TerminalQuotaError for Cloud Code RATE_LIMIT_EXCEEDED with retry delay over 5 minutes', () => {
|
||||
const apiError: GoogleApiError = {
|
||||
code: 429,
|
||||
message:
|
||||
'You have exhausted your capacity on this model. Your quota will reset after 10m.',
|
||||
details: [
|
||||
{
|
||||
'@type': 'type.googleapis.com/google.rpc.ErrorInfo',
|
||||
reason: 'RATE_LIMIT_EXCEEDED',
|
||||
domain: 'cloudcode-pa.googleapis.com',
|
||||
metadata: {
|
||||
uiMessage: 'true',
|
||||
model: 'gemini-2.5-pro',
|
||||
},
|
||||
},
|
||||
{
|
||||
'@type': 'type.googleapis.com/google.rpc.RetryInfo',
|
||||
retryDelay: '600s',
|
||||
},
|
||||
],
|
||||
};
|
||||
vi.spyOn(errorParser, 'parseGoogleApiError').mockReturnValue(apiError);
|
||||
const result = classifyGoogleError(new Error());
|
||||
expect(result).toBeInstanceOf(TerminalQuotaError);
|
||||
expect((result as TerminalQuotaError).retryDelayMs).toBe(600000);
|
||||
expect((result as TerminalQuotaError).reason).toBe('RATE_LIMIT_EXCEEDED');
|
||||
});
|
||||
|
||||
it('should return TerminalQuotaError for Cloud Code QUOTA_EXHAUSTED', () => {
|
||||
const apiError: GoogleApiError = {
|
||||
code: 429,
|
||||
@@ -427,6 +455,40 @@ describe('classifyGoogleError', () => {
|
||||
}
|
||||
});
|
||||
|
||||
it('should return TerminalQuotaError when fallback "Please retry in" delay exceeds 5 minutes', () => {
|
||||
const errorWithEmptyDetails = {
|
||||
error: {
|
||||
code: 429,
|
||||
message: 'Resource exhausted. Please retry in 400s',
|
||||
details: [],
|
||||
},
|
||||
};
|
||||
|
||||
const result = classifyGoogleError(errorWithEmptyDetails);
|
||||
|
||||
expect(result).toBeInstanceOf(TerminalQuotaError);
|
||||
if (result instanceof TerminalQuotaError) {
|
||||
expect(result.retryDelayMs).toBe(400000);
|
||||
}
|
||||
});
|
||||
|
||||
it('should return RetryableQuotaError when retry delay is exactly 5 minutes', () => {
|
||||
const apiError: GoogleApiError = {
|
||||
code: 429,
|
||||
message: 'Too many requests',
|
||||
details: [
|
||||
{
|
||||
'@type': 'type.googleapis.com/google.rpc.RetryInfo',
|
||||
retryDelay: '300s',
|
||||
},
|
||||
],
|
||||
};
|
||||
vi.spyOn(errorParser, 'parseGoogleApiError').mockReturnValue(apiError);
|
||||
const result = classifyGoogleError(new Error());
|
||||
expect(result).toBeInstanceOf(RetryableQuotaError);
|
||||
expect((result as RetryableQuotaError).retryDelayMs).toBe(300000);
|
||||
});
|
||||
|
||||
it('should return RetryableQuotaError without delay time for generic 429 without specific message', () => {
|
||||
const generic429 = {
|
||||
status: 429,
|
||||
|
||||
@@ -100,6 +100,13 @@ function parseDurationInSeconds(duration: string): number | null {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Maximum retry delay (in seconds) before a retryable error is treated as terminal.
|
||||
* If the server suggests waiting longer than this, the user is effectively locked out,
|
||||
* so we trigger the fallback/credits flow instead of silently waiting.
|
||||
*/
|
||||
const MAX_RETRYABLE_DELAY_SECONDS = 300; // 5 minutes
|
||||
|
||||
/**
|
||||
* Valid Cloud Code API domains for VALIDATION_REQUIRED errors.
|
||||
*/
|
||||
@@ -248,15 +255,15 @@ export function classifyGoogleError(error: unknown): unknown {
|
||||
if (match?.[1]) {
|
||||
const retryDelaySeconds = parseDurationInSeconds(match[1]);
|
||||
if (retryDelaySeconds !== null) {
|
||||
return new RetryableQuotaError(
|
||||
errorMessage,
|
||||
googleApiError ?? {
|
||||
code: status ?? 429,
|
||||
message: errorMessage,
|
||||
details: [],
|
||||
},
|
||||
retryDelaySeconds,
|
||||
);
|
||||
const cause = googleApiError ?? {
|
||||
code: status ?? 429,
|
||||
message: errorMessage,
|
||||
details: [],
|
||||
};
|
||||
if (retryDelaySeconds > MAX_RETRYABLE_DELAY_SECONDS) {
|
||||
return new TerminalQuotaError(errorMessage, cause, retryDelaySeconds);
|
||||
}
|
||||
return new RetryableQuotaError(errorMessage, cause, retryDelaySeconds);
|
||||
}
|
||||
} else if (status === 429 || status === 499) {
|
||||
// Fallback: If it is a 429 or 499 but doesn't have a specific "retry in" message,
|
||||
@@ -325,10 +332,19 @@ export function classifyGoogleError(error: unknown): unknown {
|
||||
if (errorInfo.domain) {
|
||||
if (isCloudCodeDomain(errorInfo.domain)) {
|
||||
if (errorInfo.reason === 'RATE_LIMIT_EXCEEDED') {
|
||||
const effectiveDelay = delaySeconds ?? 10;
|
||||
if (effectiveDelay > MAX_RETRYABLE_DELAY_SECONDS) {
|
||||
return new TerminalQuotaError(
|
||||
`${googleApiError.message}`,
|
||||
googleApiError,
|
||||
effectiveDelay,
|
||||
errorInfo.reason,
|
||||
);
|
||||
}
|
||||
return new RetryableQuotaError(
|
||||
`${googleApiError.message}`,
|
||||
googleApiError,
|
||||
delaySeconds ?? 10,
|
||||
effectiveDelay,
|
||||
);
|
||||
}
|
||||
if (errorInfo.reason === 'QUOTA_EXHAUSTED') {
|
||||
@@ -345,6 +361,13 @@ export function classifyGoogleError(error: unknown): unknown {
|
||||
|
||||
// 2. Check for delays in RetryInfo
|
||||
if (retryInfo?.retryDelay && delaySeconds) {
|
||||
if (delaySeconds > MAX_RETRYABLE_DELAY_SECONDS) {
|
||||
return new TerminalQuotaError(
|
||||
`${googleApiError.message}\nSuggested retry after ${retryInfo.retryDelay}.`,
|
||||
googleApiError,
|
||||
delaySeconds,
|
||||
);
|
||||
}
|
||||
return new RetryableQuotaError(
|
||||
`${googleApiError.message}\nSuggested retry after ${retryInfo.retryDelay}.`,
|
||||
googleApiError,
|
||||
|
||||
Reference in New Issue
Block a user