diff --git a/docs/reference/commands.md b/docs/reference/commands.md index aafb8c8566..e5e39cf875 100644 --- a/docs/reference/commands.md +++ b/docs/reference/commands.md @@ -439,6 +439,12 @@ Slash commands provide meta-level control over the CLI itself. - **`nodesc`** or **`nodescriptions`**: - **Description:** Hide tool descriptions, showing only the tool names. +### `/upgrade` + +- **Description:** Open the Gemini Code Assist upgrade page in your browser. + This lets you upgrade your tier for higher usage limits. +- **Note:** This command is only available when logged in with Google. + ### `/vim` - **Description:** Toggle vim mode on or off. When vim mode is enabled, the diff --git a/docs/reference/policy-engine.md b/docs/reference/policy-engine.md index e8de8c5aff..38a0b4d50c 100644 --- a/docs/reference/policy-engine.md +++ b/docs/reference/policy-engine.md @@ -91,10 +91,17 @@ the arguments don't match the pattern, the rule does not apply. There are three possible decisions a rule can enforce: - `allow`: The tool call is executed automatically without user interaction. -- `deny`: The tool call is blocked and is not executed. +- `deny`: The tool call is blocked and is not executed. For global rules (those + without an `argsPattern`), tools that are denied are **completely excluded + from the model's memory**. This means the model will not even see the tool as + an option, which is more secure and saves context window space. - `ask_user`: The user is prompted to approve or deny the tool call. (In non-interactive mode, this is treated as `deny`.) +> **Note:** The `deny` decision is the recommended way to exclude tools. The +> legacy `tools.exclude` setting in `settings.json` is deprecated in favor of +> policy rules with a `deny` decision. + ### Priority system and tiers The policy engine uses a sophisticated priority system to resolve conflicts when diff --git a/packages/cli/src/config/extensions/variables.test.ts b/packages/cli/src/config/extensions/variables.test.ts index 576546ef04..5f57fe19fe 100644 --- a/packages/cli/src/config/extensions/variables.test.ts +++ b/packages/cli/src/config/extensions/variables.test.ts @@ -124,4 +124,30 @@ describe('recursivelyHydrateStrings', () => { const result = recursivelyHydrateStrings(obj, context); expect(result).toEqual(obj); }); + + it('should not allow prototype pollution via __proto__', () => { + const payload = JSON.parse('{"__proto__": {"polluted": "yes"}}'); + const result = recursivelyHydrateStrings(payload, context); + + expect(result.polluted).toBeUndefined(); + expect(Object.prototype.hasOwnProperty.call(result, 'polluted')).toBe( + false, + ); + }); + + it('should not allow prototype pollution via constructor', () => { + const payload = JSON.parse( + '{"constructor": {"prototype": {"polluted": "yes"}}}', + ); + const result = recursivelyHydrateStrings(payload, context); + + expect(result.polluted).toBeUndefined(); + }); + + it('should not allow prototype pollution via prototype', () => { + const payload = JSON.parse('{"prototype": {"polluted": "yes"}}'); + const result = recursivelyHydrateStrings(payload, context); + + expect(result.polluted).toBeUndefined(); + }); }); diff --git a/packages/cli/src/config/extensions/variables.ts b/packages/cli/src/config/extensions/variables.ts index 3a79fc705f..b5b14c9643 100644 --- a/packages/cli/src/config/extensions/variables.ts +++ b/packages/cli/src/config/extensions/variables.ts @@ -8,6 +8,16 @@ import * as path from 'node:path'; import { type VariableSchema, VARIABLE_SCHEMA } from './variableSchema.js'; import { GEMINI_DIR } from '@google/gemini-cli-core'; +/** + * Represents a set of keys that will be considered invalid while unmarshalling + * JSON in recursivelyHydrateStrings. + */ +const UNMARSHALL_KEY_IGNORE_LIST: Set = new Set([ + '__proto__', + 'constructor', + 'prototype', +]); + export const EXTENSIONS_DIRECTORY_NAME = path.join(GEMINI_DIR, 'extensions'); export const EXTENSIONS_CONFIG_FILENAME = 'gemini-extension.json'; export const INSTALL_METADATA_FILENAME = '.gemini-extension-install.json'; @@ -65,7 +75,10 @@ export function recursivelyHydrateStrings( if (typeof obj === 'object' && obj !== null) { const newObj: Record = {}; for (const key in obj) { - if (Object.prototype.hasOwnProperty.call(obj, key)) { + if ( + !UNMARSHALL_KEY_IGNORE_LIST.has(key) && + Object.prototype.hasOwnProperty.call(obj, key) + ) { newObj[key] = recursivelyHydrateStrings( // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion (obj as Record)[key], diff --git a/packages/cli/src/services/BuiltinCommandLoader.test.ts b/packages/cli/src/services/BuiltinCommandLoader.test.ts index 6eb27862e3..62154e3fed 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.test.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.test.ts @@ -142,6 +142,14 @@ vi.mock('../ui/commands/mcpCommand.js', () => ({ }, })); +vi.mock('../ui/commands/upgradeCommand.js', () => ({ + upgradeCommand: { + name: 'upgrade', + description: 'Upgrade command', + kind: 'BUILT_IN', + }, +})); + describe('BuiltinCommandLoader', () => { let mockConfig: Config; @@ -163,6 +171,9 @@ describe('BuiltinCommandLoader', () => { getAllSkills: vi.fn().mockReturnValue([]), isAdminEnabled: vi.fn().mockReturnValue(true), }), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'other', + }), } as unknown as Config; restoreCommandMock.mockReturnValue({ @@ -172,6 +183,27 @@ describe('BuiltinCommandLoader', () => { }); }); + it('should include upgrade command when authType is login_with_google', async () => { + const { AuthType } = await import('@google/gemini-cli-core'); + (mockConfig.getContentGeneratorConfig as Mock).mockReturnValue({ + authType: AuthType.LOGIN_WITH_GOOGLE, + }); + const loader = new BuiltinCommandLoader(mockConfig); + const commands = await loader.loadCommands(new AbortController().signal); + const upgradeCmd = commands.find((c) => c.name === 'upgrade'); + expect(upgradeCmd).toBeDefined(); + }); + + it('should exclude upgrade command when authType is NOT login_with_google', async () => { + (mockConfig.getContentGeneratorConfig as Mock).mockReturnValue({ + authType: 'other', + }); + const loader = new BuiltinCommandLoader(mockConfig); + const commands = await loader.loadCommands(new AbortController().signal); + const upgradeCmd = commands.find((c) => c.name === 'upgrade'); + expect(upgradeCmd).toBeUndefined(); + }); + it('should correctly pass the config object to restore command factory', async () => { const loader = new BuiltinCommandLoader(mockConfig); await loader.loadCommands(new AbortController().signal); @@ -364,6 +396,9 @@ describe('BuiltinCommandLoader profile', () => { getAllSkills: vi.fn().mockReturnValue([]), isAdminEnabled: vi.fn().mockReturnValue(true), }), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'other', + }), } as unknown as Config; }); diff --git a/packages/cli/src/services/BuiltinCommandLoader.ts b/packages/cli/src/services/BuiltinCommandLoader.ts index 8ee5effc59..66806f5ef1 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.ts @@ -16,6 +16,7 @@ import { isNightly, startupProfiler, getAdminErrorMessage, + AuthType, } from '@google/gemini-cli-core'; import { aboutCommand } from '../ui/commands/aboutCommand.js'; import { agentsCommand } from '../ui/commands/agentsCommand.js'; @@ -59,6 +60,7 @@ import { shellsCommand } from '../ui/commands/shellsCommand.js'; import { vimCommand } from '../ui/commands/vimCommand.js'; import { setupGithubCommand } from '../ui/commands/setupGithubCommand.js'; import { terminalSetupCommand } from '../ui/commands/terminalSetupCommand.js'; +import { upgradeCommand } from '../ui/commands/upgradeCommand.js'; /** * Loads the core, hard-coded slash commands that are an integral part @@ -223,6 +225,10 @@ export class BuiltinCommandLoader implements ICommandLoader { vimCommand, setupGithubCommand, terminalSetupCommand, + ...(this.config?.getContentGeneratorConfig()?.authType === + AuthType.LOGIN_WITH_GOOGLE + ? [upgradeCommand] + : []), ]; handle?.end(); return allDefinitions.filter((cmd): cmd is SlashCommand => cmd !== null); diff --git a/packages/cli/src/ui/commands/upgradeCommand.test.ts b/packages/cli/src/ui/commands/upgradeCommand.test.ts new file mode 100644 index 0000000000..224123612e --- /dev/null +++ b/packages/cli/src/ui/commands/upgradeCommand.test.ts @@ -0,0 +1,99 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { upgradeCommand } from './upgradeCommand.js'; +import { type CommandContext } from './types.js'; +import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; +import { + AuthType, + openBrowserSecurely, + UPGRADE_URL_PAGE, +} from '@google/gemini-cli-core'; + +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + openBrowserSecurely: vi.fn(), + UPGRADE_URL_PAGE: 'https://goo.gle/set-up-gemini-code-assist', + }; +}); + +describe('upgradeCommand', () => { + let mockContext: CommandContext; + + beforeEach(() => { + vi.clearAllMocks(); + mockContext = createMockCommandContext({ + services: { + config: { + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: AuthType.LOGIN_WITH_GOOGLE, + }), + }, + }, + } as unknown as CommandContext); + }); + + it('should have the correct name and description', () => { + expect(upgradeCommand.name).toBe('upgrade'); + expect(upgradeCommand.description).toBe( + 'Upgrade your Gemini Code Assist tier for higher limits', + ); + }); + + it('should call openBrowserSecurely with UPGRADE_URL_PAGE when logged in with Google', async () => { + if (!upgradeCommand.action) { + throw new Error('The upgrade command must have an action.'); + } + + await upgradeCommand.action(mockContext, ''); + + expect(openBrowserSecurely).toHaveBeenCalledWith(UPGRADE_URL_PAGE); + }); + + it('should return an error message when NOT logged in with Google', async () => { + vi.mocked( + mockContext.services.config!.getContentGeneratorConfig, + ).mockReturnValue({ + authType: AuthType.USE_GEMINI, + }); + + if (!upgradeCommand.action) { + throw new Error('The upgrade command must have an action.'); + } + + const result = await upgradeCommand.action(mockContext, ''); + + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: + 'The /upgrade command is only available when logged in with Google.', + }); + expect(openBrowserSecurely).not.toHaveBeenCalled(); + }); + + it('should return an error message if openBrowserSecurely fails', async () => { + vi.mocked(openBrowserSecurely).mockRejectedValue( + new Error('Failed to open'), + ); + + if (!upgradeCommand.action) { + throw new Error('The upgrade command must have an action.'); + } + + const result = await upgradeCommand.action(mockContext, ''); + + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Failed to open upgrade page: Failed to open', + }); + }); +}); diff --git a/packages/cli/src/ui/commands/upgradeCommand.ts b/packages/cli/src/ui/commands/upgradeCommand.ts new file mode 100644 index 0000000000..532ff3b481 --- /dev/null +++ b/packages/cli/src/ui/commands/upgradeCommand.ts @@ -0,0 +1,50 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + AuthType, + openBrowserSecurely, + UPGRADE_URL_PAGE, +} from '@google/gemini-cli-core'; +import type { SlashCommand } from './types.js'; +import { CommandKind } from './types.js'; + +/** + * Command to open the upgrade page for Gemini Code Assist. + * Only intended to be shown/available when the user is logged in with Google. + */ +export const upgradeCommand: SlashCommand = { + name: 'upgrade', + kind: CommandKind.BUILT_IN, + description: 'Upgrade your Gemini Code Assist tier for higher limits', + autoExecute: true, + action: async (context) => { + const authType = + context.services.config?.getContentGeneratorConfig()?.authType; + if (authType !== AuthType.LOGIN_WITH_GOOGLE) { + // This command should ideally be hidden if not logged in with Google, + // but we add a safety check here just in case. + return { + type: 'message', + messageType: 'error', + content: + 'The /upgrade command is only available when logged in with Google.', + }; + } + + try { + await openBrowserSecurely(UPGRADE_URL_PAGE); + } catch (error) { + return { + type: 'message', + messageType: 'error', + content: `Failed to open upgrade page: ${error instanceof Error ? error.message : String(error)}`, + }; + } + + return undefined; + }, +}; diff --git a/packages/cli/src/utils/settingsUtils.test.ts b/packages/cli/src/utils/settingsUtils.test.ts index a1f662af4d..9274c1b6f8 100644 --- a/packages/cli/src/utils/settingsUtils.test.ts +++ b/packages/cli/src/utils/settingsUtils.test.ts @@ -734,6 +734,55 @@ describe('SettingsUtils', () => { ); expect(result).toBe('false'); }); + + it('should display objects as JSON strings, not "[object Object]"', () => { + vi.mocked(getSettingsSchema).mockReturnValue({ + experimental: { + type: 'object', + label: 'Experimental', + category: 'Experimental', + requiresRestart: true, + default: {}, + description: 'Experimental settings', + showInDialog: false, + properties: { + gemmaModelRouter: { + type: 'object', + label: 'Gemma Model Router', + category: 'Experimental', + requiresRestart: true, + default: {}, + description: 'Gemma model router settings', + showInDialog: true, + }, + }, + }, + } as unknown as SettingsSchemaType); + + // Test with empty object (default) + const emptySettings = makeMockSettings({}); + const emptyResult = getDisplayValue( + 'experimental.gemmaModelRouter', + emptySettings, + emptySettings, + ); + expect(emptyResult).toBe('{}'); + expect(emptyResult).not.toBe('[object Object]'); + + // Test with object containing values + const settings = makeMockSettings({ + experimental: { + gemmaModelRouter: { enabled: true, host: 'localhost' }, + }, + }); + const result = getDisplayValue( + 'experimental.gemmaModelRouter', + settings, + settings, + ); + expect(result).toBe('{"enabled":true,"host":"localhost"}*'); + expect(result).not.toContain('[object Object]'); + }); }); describe('getDisplayValue with units', () => { diff --git a/packages/cli/src/utils/settingsUtils.ts b/packages/cli/src/utils/settingsUtils.ts index 11c3a9a13f..daa599826f 100644 --- a/packages/cli/src/utils/settingsUtils.ts +++ b/packages/cli/src/utils/settingsUtils.ts @@ -284,7 +284,14 @@ export function getDisplayValue( let valueString = String(value); - if (definition?.type === 'enum' && definition.options) { + // Handle object types by stringifying them + if ( + definition?.type === 'object' && + value !== null && + typeof value === 'object' + ) { + valueString = JSON.stringify(value); + } else if (definition?.type === 'enum' && definition.options) { const option = definition.options?.find((option) => option.value === value); valueString = option?.label ?? `${value}`; } diff --git a/packages/core/src/agents/browser/browserAgentInvocation.test.ts b/packages/core/src/agents/browser/browserAgentInvocation.test.ts index b58a9c409e..daf5309479 100644 --- a/packages/core/src/agents/browser/browserAgentInvocation.test.ts +++ b/packages/core/src/agents/browser/browserAgentInvocation.test.ts @@ -9,7 +9,11 @@ import { BrowserAgentInvocation } from './browserAgentInvocation.js'; import { makeFakeConfig } from '../../test-utils/config.js'; import type { Config } from '../../config/config.js'; import type { MessageBus } from '../../confirmation-bus/message-bus.js'; -import type { AgentInputs } from '../types.js'; +import { + type AgentInputs, + type SubagentProgress, + type SubagentActivityEvent, +} from '../types.js'; // Mock dependencies before imports vi.mock('../../utils/debugLogger.js', () => ({ @@ -19,6 +23,24 @@ vi.mock('../../utils/debugLogger.js', () => ({ }, })); +vi.mock('./browserAgentFactory.js', () => ({ + createBrowserAgentDefinition: vi.fn(), + cleanupBrowserAgent: vi.fn(), +})); + +vi.mock('../local-executor.js', () => ({ + LocalAgentExecutor: { + create: vi.fn(), + }, +})); + +import { + createBrowserAgentDefinition, + cleanupBrowserAgent, +} from './browserAgentFactory.js'; +import { LocalAgentExecutor } from '../local-executor.js'; +import type { ToolLiveOutput } from '../../tools/tools.js'; + describe('BrowserAgentInvocation', () => { let mockConfig: Config; let mockMessageBus: MessageBus; @@ -136,4 +158,473 @@ describe('BrowserAgentInvocation', () => { expect(locations).toEqual([]); }); }); + + describe('execute', () => { + let mockExecutor: { run: ReturnType }; + + beforeEach(() => { + vi.mocked(createBrowserAgentDefinition).mockResolvedValue({ + definition: { + name: 'browser_agent', + description: 'mock definition', + kind: 'local', + inputConfig: {} as never, + outputConfig: {} as never, + processOutput: () => '', + modelConfig: { model: 'test' }, + runConfig: {}, + promptConfig: { query: '', systemPrompt: '' }, + toolConfig: { tools: ['analyze_screenshot', 'click'] }, + }, + browserManager: {} as never, + }); + + mockExecutor = { + run: vi.fn().mockResolvedValue({ + result: JSON.stringify({ success: true }), + terminate_reason: 'GOAL', + }), + }; + + vi.mocked(LocalAgentExecutor.create).mockResolvedValue( + mockExecutor as never, + ); + vi.mocked(cleanupBrowserAgent).mockClear(); + }); + + it('should return result text and call cleanup on success', async () => { + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + const controller = new AbortController(); + const updateOutput: (output: ToolLiveOutput) => void = vi.fn(); + + const result = await invocation.execute(controller.signal, updateOutput); + + expect(Array.isArray(result.llmContent)).toBe(true); + expect((result.llmContent as Array<{ text: string }>)[0].text).toContain( + 'Browser agent finished', + ); + expect(cleanupBrowserAgent).toHaveBeenCalled(); + }); + + it('should work without updateOutput (fire-and-forget)', async () => { + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + const controller = new AbortController(); + // Should not throw even with no updateOutput + await expect( + invocation.execute(controller.signal), + ).resolves.toBeDefined(); + }); + + it('should return error result when executor throws', async () => { + mockExecutor.run.mockRejectedValue(new Error('Unexpected crash')); + + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + const controller = new AbortController(); + const result = await invocation.execute(controller.signal); + + expect(result.error).toBeDefined(); + expect(cleanupBrowserAgent).toHaveBeenCalled(); + }); + + // ─── Structured SubagentProgress emission tests ─────────────────────── + + /** + * Helper: sets up LocalAgentExecutor.create to capture the onActivity + * callback so tests can fire synthetic activity events. + */ + function setupActivityCapture(): { + capturedOnActivity: () => SubagentActivityEvent | undefined; + fireActivity: (event: SubagentActivityEvent) => void; + } { + let onActivityFn: ((e: SubagentActivityEvent) => void) | undefined; + + vi.mocked(LocalAgentExecutor.create).mockImplementation( + async (_def, _config, onActivity) => { + onActivityFn = onActivity; + return mockExecutor as never; + }, + ); + + return { + capturedOnActivity: () => undefined, + fireActivity: (event: SubagentActivityEvent) => { + onActivityFn?.(event); + }, + }; + } + + it('should emit initial SubagentProgress with running state', async () => { + const updateOutput = vi.fn(); + + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + await invocation.execute(new AbortController().signal, updateOutput); + + const firstCall = updateOutput.mock.calls[0]?.[0] as SubagentProgress; + expect(firstCall.isSubagentProgress).toBe(true); + expect(firstCall.state).toBe('running'); + expect(firstCall.recentActivity).toEqual([]); + }); + + it('should emit completed SubagentProgress on success', async () => { + const updateOutput = vi.fn(); + + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + await invocation.execute(new AbortController().signal, updateOutput); + + const lastCall = updateOutput.mock.calls[ + updateOutput.mock.calls.length - 1 + ]?.[0] as SubagentProgress; + expect(lastCall.isSubagentProgress).toBe(true); + expect(lastCall.state).toBe('completed'); + }); + + it('should handle THOUGHT_CHUNK and emit structured progress', async () => { + const { fireActivity } = setupActivityCapture(); + const updateOutput = vi.fn(); + + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + const executePromise = invocation.execute( + new AbortController().signal, + updateOutput, + ); + + // Allow createBrowserAgentDefinition to resolve and onActivity to be registered + await Promise.resolve(); + await Promise.resolve(); + + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'THOUGHT_CHUNK', + data: { text: 'Navigating to the page...' }, + }); + + await executePromise; + + const progressCalls = updateOutput.mock.calls + .map((c) => c[0] as SubagentProgress) + .filter((p) => p.isSubagentProgress); + + const thoughtProgress = progressCalls.find((p) => + p.recentActivity.some( + (a) => + a.type === 'thought' && + a.content.includes('Navigating to the page...'), + ), + ); + expect(thoughtProgress).toBeDefined(); + }); + + it('should handle TOOL_CALL_START and TOOL_CALL_END with callId tracking', async () => { + const { fireActivity } = setupActivityCapture(); + const updateOutput = vi.fn(); + + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + const executePromise = invocation.execute( + new AbortController().signal, + updateOutput, + ); + + await Promise.resolve(); + await Promise.resolve(); + + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'TOOL_CALL_START', + data: { + name: 'navigate_browser', + callId: 'call-1', + args: { url: 'https://example.com' }, + }, + }); + + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'TOOL_CALL_END', + data: { name: 'navigate_browser', id: 'call-1' }, + }); + + await executePromise; + + const progressCalls = updateOutput.mock.calls + .map((c) => c[0] as SubagentProgress) + .filter((p) => p.isSubagentProgress); + + // After TOOL_CALL_END, the tool should be completed + const finalProgress = progressCalls[progressCalls.length - 1]; + const toolItem = finalProgress?.recentActivity.find( + (a) => a.type === 'tool_call' && a.content === 'navigate_browser', + ); + expect(toolItem).toBeDefined(); + expect(toolItem?.status).toBe('completed'); + }); + + it('should sanitize sensitive data in tool call args', async () => { + const { fireActivity } = setupActivityCapture(); + const updateOutput = vi.fn(); + + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + const executePromise = invocation.execute( + new AbortController().signal, + updateOutput, + ); + + await Promise.resolve(); + await Promise.resolve(); + + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'TOOL_CALL_START', + data: { + name: 'fill_form', + callId: 'call-2', + args: { password: 'supersecret123', url: 'https://example.com' }, + }, + }); + + await executePromise; + + const progressCalls = updateOutput.mock.calls + .map((c) => c[0] as SubagentProgress) + .filter((p) => p.isSubagentProgress); + + const toolItem = progressCalls + .flatMap((p) => p.recentActivity) + .find((a) => a.type === 'tool_call' && a.content === 'fill_form'); + + expect(toolItem).toBeDefined(); + expect(toolItem?.args).not.toContain('supersecret123'); + expect(toolItem?.args).toContain('[REDACTED]'); + }); + + it('should handle ERROR event with callId and mark tool as errored', async () => { + const { fireActivity } = setupActivityCapture(); + const updateOutput = vi.fn(); + + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + const executePromise = invocation.execute( + new AbortController().signal, + updateOutput, + ); + + await Promise.resolve(); + await Promise.resolve(); + + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'TOOL_CALL_START', + data: { name: 'click_element', callId: 'call-3', args: {} }, + }); + + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'ERROR', + data: { error: 'Element not found', callId: 'call-3' }, + }); + + await executePromise; + + const progressCalls = updateOutput.mock.calls + .map((c) => c[0] as SubagentProgress) + .filter((p) => p.isSubagentProgress); + + const allItems = progressCalls.flatMap((p) => p.recentActivity); + const toolItem = allItems.find( + (a) => a.type === 'tool_call' && a.content === 'click_element', + ); + expect(toolItem?.status).toBe('error'); + }); + + it('should sanitize sensitive data in ERROR event messages', async () => { + const { fireActivity } = setupActivityCapture(); + const updateOutput = vi.fn(); + + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + const executePromise = invocation.execute( + new AbortController().signal, + updateOutput, + ); + + await Promise.resolve(); + await Promise.resolve(); + + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'ERROR', + data: { error: 'Auth failed: api_key=sk-secret-abc1234567890' }, + }); + + await executePromise; + + const progressCalls = updateOutput.mock.calls + .map((c) => c[0] as SubagentProgress) + .filter((p) => p.isSubagentProgress); + + const errorItem = progressCalls + .flatMap((p) => p.recentActivity) + .find((a) => a.type === 'thought' && a.status === 'error'); + + expect(errorItem).toBeDefined(); + expect(errorItem?.content).not.toContain('sk-secret-abc1234567890'); + expect(errorItem?.content).toContain('[REDACTED]'); + }); + + it('should sanitize inline PEM content in error messages', async () => { + const { fireActivity } = setupActivityCapture(); + const updateOutput = vi.fn(); + + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + const executePromise = invocation.execute( + new AbortController().signal, + updateOutput, + ); + + await Promise.resolve(); + await Promise.resolve(); + + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'ERROR', + data: { + error: + 'Failed to authenticate:\n-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAKCAQEA12345...\n-----END RSA PRIVATE KEY-----\nPlease check credentials.', + }, + }); + + await executePromise; + + const progressCalls = updateOutput.mock.calls + .map((c) => c[0] as SubagentProgress) + .filter((p) => p.isSubagentProgress); + + const errorItem = progressCalls + .flatMap((p) => p.recentActivity) + .find((a) => a.type === 'thought' && a.status === 'error'); + + expect(errorItem).toBeDefined(); + expect(errorItem?.content).toContain('[REDACTED_PEM]'); + expect(errorItem?.content).not.toContain('-----BEGIN'); + }); + + it('should mark all running tools as errored when ERROR has no callId', async () => { + const { fireActivity } = setupActivityCapture(); + const updateOutput = vi.fn(); + + const invocation = new BrowserAgentInvocation( + mockConfig, + mockParams, + mockMessageBus, + ); + + const executePromise = invocation.execute( + new AbortController().signal, + updateOutput, + ); + + await Promise.resolve(); + await Promise.resolve(); + + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'TOOL_CALL_START', + data: { name: 'tool_a', callId: 'c1', args: {} }, + }); + + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'TOOL_CALL_START', + data: { name: 'tool_b', callId: 'c2', args: {} }, + }); + + // ERROR with no callId should mark ALL running tools as error + fireActivity({ + isSubagentActivityEvent: true, + agentName: 'browser_agent', + type: 'ERROR', + data: { error: 'Agent crashed' }, + }); + + await executePromise; + + const progressCalls = updateOutput.mock.calls + .map((c) => c[0] as SubagentProgress) + .filter((p) => p.isSubagentProgress); + + const allItems = progressCalls.flatMap((p) => p.recentActivity); + const toolA = allItems.find( + (a) => a.type === 'tool_call' && a.content === 'tool_a', + ); + const toolB = allItems.find( + (a) => a.type === 'tool_call' && a.content === 'tool_b', + ); + + // Both should be error since no callId was specified + expect(toolA?.status).toBe('error'); + expect(toolB?.status).toBe('error'); + }); + }); }); diff --git a/packages/core/src/agents/browser/browserAgentInvocation.ts b/packages/core/src/agents/browser/browserAgentInvocation.ts index b503cc1214..3bdb4fa2d5 100644 --- a/packages/core/src/agents/browser/browserAgentInvocation.ts +++ b/packages/core/src/agents/browser/browserAgentInvocation.ts @@ -14,6 +14,7 @@ * The MCP tools are only available in the browser agent's isolated registry. */ +import { randomUUID } from 'node:crypto'; import type { Config } from '../../config/config.js'; import { LocalAgentExecutor } from '../local-executor.js'; import { @@ -22,7 +23,12 @@ import { type ToolLiveOutput, } from '../../tools/tools.js'; import { ToolErrorType } from '../../tools/tool-error.js'; -import type { AgentInputs, SubagentActivityEvent } from '../types.js'; +import { + type AgentInputs, + type SubagentActivityEvent, + type SubagentProgress, + type SubagentActivityItem, +} from '../types.js'; import type { MessageBus } from '../../confirmation-bus/message-bus.js'; import { createBrowserAgentDefinition, @@ -31,6 +37,134 @@ import { const INPUT_PREVIEW_MAX_LENGTH = 50; const DESCRIPTION_MAX_LENGTH = 200; +const MAX_RECENT_ACTIVITY = 20; + +/** + * Sensitive key patterns used for redaction. + */ +const SENSITIVE_KEY_PATTERNS = [ + 'password', + 'pwd', + 'apikey', + 'api_key', + 'api-key', + 'token', + 'secret', + 'credential', + 'auth', + 'authorization', + 'access_token', + 'access_key', + 'refresh_token', + 'session_id', + 'cookie', + 'passphrase', + 'privatekey', + 'private_key', + 'private-key', + 'secret_key', + 'client_secret', + 'client_id', +]; + +/** + * Sanitizes tool arguments by recursively redacting sensitive fields. + * Supports nested objects and arrays. + */ +function sanitizeToolArgs(args: unknown): unknown { + if (typeof args === 'string') { + return sanitizeErrorMessage(args); + } + if (typeof args !== 'object' || args === null) { + return args; + } + + if (Array.isArray(args)) { + return args.map(sanitizeToolArgs); + } + + const sanitized: Record = {}; + + for (const [key, value] of Object.entries(args)) { + // Decode key to handle URL-encoded sensitive keys (e.g., api%5fkey) + let decodedKey = key; + try { + decodedKey = decodeURIComponent(key); + } catch { + // Ignore decoding errors + } + const keyNormalized = decodedKey.toLowerCase().replace(/[-_]/g, ''); + const isSensitive = SENSITIVE_KEY_PATTERNS.some((pattern) => + keyNormalized.includes(pattern.replace(/[-_]/g, '')), + ); + if (isSensitive) { + sanitized[key] = '[REDACTED]'; + } else { + sanitized[key] = sanitizeToolArgs(value); + } + } + + return sanitized; +} + +/** + * Sanitizes error messages by redacting potential sensitive data patterns. + * Uses [^\s'"]+ to catch JWTs, tokens with dots/slashes, and other complex values. + */ +function sanitizeErrorMessage(message: string): string { + if (!message) return message; + + let sanitized = message; + + // 1. Redact inline PEM content + sanitized = sanitized.replace( + /-----BEGIN\s+[\w\s]+-----[\s\S]*?-----END\s+[\w\s]+-----/g, + '[REDACTED_PEM]', + ); + + const unquotedValue = `[^\\s]+(?:\\s+(?![a-zA-Z0-9_.-]+(?:=|:))[^\\s=:<>]+)*`; + const valuePattern = `(?:"[^"]*"|'[^']*'|${unquotedValue})`; + + // 2. Handle key-value pairs with delimiters (=, :, space, CLI-style --flag) + const urlSafeKeyPatternStr = SENSITIVE_KEY_PATTERNS.map((p) => + p.replace(/[-_]/g, '(?:[-_]|%2D|%5F|%2d|%5f)?'), + ).join('|'); + + const keyWithDelimiter = new RegExp( + `((?:--)?("|')?(${urlSafeKeyPatternStr})\\2\\s*(?:[:=]|%3A|%3D)\\s*)${valuePattern}`, + 'gi', + ); + sanitized = sanitized.replace(keyWithDelimiter, '$1[REDACTED]'); + + // 3. Handle space-separated sensitive keywords (e.g. "password mypass", "--api-key secret") + const tokenValuePattern = `[A-Za-z0-9._\\-/+=]{8,}`; + const spaceKeywords = [ + ...SENSITIVE_KEY_PATTERNS.map((p) => + p.replace(/[-_]/g, '(?:[-_]|%2D|%5F|%2d|%5f)?'), + ), + 'bearer', + ]; + const spaceSeparated = new RegExp( + `\\b((?:--)?(?:${spaceKeywords.join('|')})(?:\\s*:\\s*bearer)?\\s+)(${tokenValuePattern})`, + 'gi', + ); + sanitized = sanitized.replace(spaceSeparated, '$1[REDACTED]'); + + // 4. Handle file path redaction + sanitized = sanitized.replace( + /((?:[/\\][a-zA-Z0-9_-]+)*[/\\][a-zA-Z0-9_-]*\.(?:key|pem|p12|pfx))/gi, + '/path/to/[REDACTED].key', + ); + + return sanitized; +} + +/** + * Sanitizes LLM thought content by redacting sensitive data patterns. + */ +function sanitizeThoughtContent(text: string): string { + return sanitizeErrorMessage(text); +} /** * Browser agent invocation with async tool setup. @@ -88,15 +222,41 @@ export class BrowserAgentInvocation extends BaseToolInvocation< updateOutput?: (output: ToolLiveOutput) => void, ): Promise { let browserManager; + let recentActivity: SubagentActivityItem[] = []; try { if (updateOutput) { - updateOutput('🌐 Starting browser agent...\n'); + // Send initial state + const initialProgress: SubagentProgress = { + isSubagentProgress: true, + agentName: this['_toolName'] ?? 'browser_agent', + recentActivity: [], + state: 'running', + }; + updateOutput(initialProgress); } // Create definition with MCP tools + // Note: printOutput is used for low-level connection logs before agent starts const printOutput = updateOutput - ? (msg: string) => updateOutput(`🌐 ${msg}\n`) + ? (msg: string) => { + const sanitizedMsg = sanitizeThoughtContent(msg); + recentActivity.push({ + id: randomUUID(), + type: 'thought', + content: sanitizedMsg, + status: 'completed', + }); + if (recentActivity.length > MAX_RECENT_ACTIVITY) { + recentActivity = recentActivity.slice(-MAX_RECENT_ACTIVITY); + } + updateOutput({ + isSubagentProgress: true, + agentName: this['_toolName'] ?? 'browser_agent', + recentActivity: [...recentActivity], + state: 'running', + } as SubagentProgress); + } : undefined; const result = await createBrowserAgentDefinition( @@ -107,22 +267,141 @@ export class BrowserAgentInvocation extends BaseToolInvocation< const { definition } = result; browserManager = result.browserManager; - if (updateOutput) { - updateOutput( - `🌐 Browser connected. Tools: ${definition.toolConfig?.tools.length ?? 0}\n`, - ); - } - // Create activity callback for streaming output const onActivity = (activity: SubagentActivityEvent): void => { if (!updateOutput) return; - if ( - activity.type === 'THOUGHT_CHUNK' && - // eslint-disable-next-line no-restricted-syntax - typeof activity.data['text'] === 'string' - ) { - updateOutput(`🌐💭 ${activity.data['text']}`); + let updated = false; + + switch (activity.type) { + case 'THOUGHT_CHUNK': { + const text = String(activity.data['text']); + const lastItem = recentActivity[recentActivity.length - 1]; + if ( + lastItem && + lastItem.type === 'thought' && + lastItem.status === 'running' + ) { + lastItem.content = sanitizeThoughtContent( + lastItem.content + text, + ); + } else { + recentActivity.push({ + id: randomUUID(), + type: 'thought', + content: sanitizeThoughtContent(text), + status: 'running', + }); + } + updated = true; + break; + } + case 'TOOL_CALL_START': { + const name = String(activity.data['name']); + const displayName = activity.data['displayName'] + ? sanitizeErrorMessage(String(activity.data['displayName'])) + : undefined; + const description = activity.data['description'] + ? sanitizeErrorMessage(String(activity.data['description'])) + : undefined; + const args = JSON.stringify( + sanitizeToolArgs(activity.data['args']), + ); + const callId = activity.data['callId'] + ? String(activity.data['callId']) + : randomUUID(); + recentActivity.push({ + id: callId, + type: 'tool_call', + content: name, + displayName, + description, + args, + status: 'running', + }); + updated = true; + break; + } + case 'TOOL_CALL_END': { + const callId = activity.data['id'] + ? String(activity.data['id']) + : undefined; + // Find the tool call by ID + // Find the tool call by ID + for (let i = recentActivity.length - 1; i >= 0; i--) { + if ( + recentActivity[i].type === 'tool_call' && + callId != null && + recentActivity[i].id === callId && + recentActivity[i].status === 'running' + ) { + recentActivity[i].status = 'completed'; + updated = true; + break; + } + } + break; + } + case 'ERROR': { + const error = String(activity.data['error']); + const isCancellation = error === 'Request cancelled.'; + const callId = activity.data['callId'] + ? String(activity.data['callId']) + : undefined; + const newStatus = isCancellation ? 'cancelled' : 'error'; + + if (callId) { + // Mark the specific tool as error/cancelled + for (let i = recentActivity.length - 1; i >= 0; i--) { + if ( + recentActivity[i].type === 'tool_call' && + recentActivity[i].id === callId && + recentActivity[i].status === 'running' + ) { + recentActivity[i].status = newStatus; + updated = true; + break; + } + } + } else { + // No specific tool — mark ALL running tool_call items + for (const item of recentActivity) { + if (item.type === 'tool_call' && item.status === 'running') { + item.status = newStatus; + updated = true; + } + } + } + + // Sanitize the error message before emitting + const sanitizedError = sanitizeErrorMessage(error); + recentActivity.push({ + id: randomUUID(), + type: 'thought', + content: isCancellation + ? sanitizedError + : `Error: ${sanitizedError}`, + status: newStatus, + }); + updated = true; + break; + } + default: + break; + } + + if (updated) { + if (recentActivity.length > MAX_RECENT_ACTIVITY) { + recentActivity = recentActivity.slice(-MAX_RECENT_ACTIVITY); + } + + const progress: SubagentProgress = { + isSubagentProgress: true, + agentName: this['_toolName'] ?? 'browser_agent', + recentActivity: [...recentActivity], + state: 'running', + }; + updateOutput(progress); } }; @@ -149,17 +428,52 @@ Result: ${output.result} `; + if (updateOutput) { + updateOutput({ + isSubagentProgress: true, + agentName: this['_toolName'] ?? 'browser_agent', + recentActivity: [...recentActivity], + state: 'completed', + } as SubagentProgress); + } + return { llmContent: [{ text: resultContent }], returnDisplay: displayContent, }; } catch (error) { - const errorMessage = + const rawErrorMessage = error instanceof Error ? error.message : String(error); + const isAbort = + (error instanceof Error && error.name === 'AbortError') || + rawErrorMessage.includes('Aborted'); + const errorMessage = sanitizeErrorMessage(rawErrorMessage); + + // Mark any running items as error/cancelled + for (const item of recentActivity) { + if (item.status === 'running') { + item.status = isAbort ? 'cancelled' : 'error'; + } + } + + const progress: SubagentProgress = { + isSubagentProgress: true, + agentName: this['_toolName'] ?? 'browser_agent', + recentActivity: [...recentActivity], + state: isAbort ? 'cancelled' : 'error', + }; + + if (updateOutput) { + updateOutput(progress); + } + + const llmContent = isAbort + ? 'Browser agent execution was aborted.' + : `Browser agent failed. Error: ${errorMessage}`; return { - llmContent: `Browser agent failed. Error: ${errorMessage}`, - returnDisplay: `Browser Agent Failed\nError: ${errorMessage}`, + llmContent: [{ text: llmContent }], + returnDisplay: progress, error: { message: errorMessage, type: ToolErrorType.EXECUTION_FAILED, diff --git a/packages/core/src/agents/local-executor.test.ts b/packages/core/src/agents/local-executor.test.ts index f056c73a68..f9a518ae56 100644 --- a/packages/core/src/agents/local-executor.test.ts +++ b/packages/core/src/agents/local-executor.test.ts @@ -927,11 +927,11 @@ describe('LocalAgentExecutor', () => { expect(activities).toContainEqual( expect.objectContaining({ type: 'ERROR', - data: { + data: expect.objectContaining({ context: 'tool_call', name: TASK_COMPLETE_TOOL_NAME, error: expectedError, - }, + }), }), ); @@ -1213,11 +1213,11 @@ describe('LocalAgentExecutor', () => { expect(activities).toContainEqual( expect.objectContaining({ type: 'ERROR', - data: { + data: expect.objectContaining({ context: 'tool_call', name: TASK_COMPLETE_TOOL_NAME, error: expect.stringContaining('Output validation failed'), - }, + }), }), ); @@ -1338,11 +1338,11 @@ describe('LocalAgentExecutor', () => { expect(activities).toContainEqual( expect.objectContaining({ type: 'ERROR', - data: { + data: expect.objectContaining({ context: 'tool_call', name: LS_TOOL_NAME, error: toolErrorMessage, - }, + }), }), ); @@ -1699,15 +1699,17 @@ describe('LocalAgentExecutor', () => { expect(activities).toContainEqual( expect.objectContaining({ type: 'THOUGHT_CHUNK', - data: { + data: expect.objectContaining({ text: 'Execution limit reached (MAX_TURNS). Attempting one final recovery turn with a grace period.', - }, + }), }), ); expect(activities).toContainEqual( expect.objectContaining({ type: 'THOUGHT_CHUNK', - data: { text: 'Graceful recovery succeeded.' }, + data: expect.objectContaining({ + text: 'Graceful recovery succeeded.', + }), }), ); }); @@ -1784,9 +1786,9 @@ describe('LocalAgentExecutor', () => { expect(activities).toContainEqual( expect.objectContaining({ type: 'THOUGHT_CHUNK', - data: { + data: expect.objectContaining({ text: 'Execution limit reached (ERROR_NO_COMPLETE_TASK_CALL). Attempting one final recovery turn with a grace period.', - }, + }), }), ); }); @@ -1882,9 +1884,9 @@ describe('LocalAgentExecutor', () => { expect(activities).toContainEqual( expect.objectContaining({ type: 'THOUGHT_CHUNK', - data: { + data: expect.objectContaining({ text: 'Execution limit reached (TIMEOUT). Attempting one final recovery turn with a grace period.', - }, + }), }), ); }); diff --git a/packages/core/src/agents/local-executor.ts b/packages/core/src/agents/local-executor.ts index 7bbecdac7c..fd450c5efa 100644 --- a/packages/core/src/agents/local-executor.ts +++ b/packages/core/src/agents/local-executor.ts @@ -902,6 +902,7 @@ export class LocalAgentExecutor { displayName, description, args, + callId, }); if (toolName === TASK_COMPLETE_TOOL_NAME) { @@ -969,6 +970,7 @@ export class LocalAgentExecutor { }); this.emitActivity('TOOL_CALL_END', { name: toolName, + id: callId, output: 'Output submitted and task completed.', }); } else { @@ -985,6 +987,7 @@ export class LocalAgentExecutor { this.emitActivity('ERROR', { context: 'tool_call', name: toolName, + callId, error, }); } @@ -1009,6 +1012,7 @@ export class LocalAgentExecutor { }); this.emitActivity('TOOL_CALL_END', { name: toolName, + id: callId, output: 'Result submitted and task completed.', }); } else { @@ -1026,6 +1030,7 @@ export class LocalAgentExecutor { this.emitActivity('ERROR', { context: 'tool_call', name: toolName, + callId, error, }); } @@ -1086,18 +1091,21 @@ export class LocalAgentExecutor { if (call.status === 'success') { this.emitActivity('TOOL_CALL_END', { name: toolName, + id: call.request.callId, output: call.response.resultDisplay, }); } else if (call.status === 'error') { this.emitActivity('ERROR', { context: 'tool_call', name: toolName, + callId: call.request.callId, error: call.response.error?.message || 'Unknown error', }); } else if (call.status === 'cancelled') { this.emitActivity('ERROR', { context: 'tool_call', name: toolName, + callId: call.request.callId, error: 'Request cancelled.', }); aborted = true; diff --git a/packages/core/src/core/apiKeyCredentialStorage.test.ts b/packages/core/src/core/apiKeyCredentialStorage.test.ts index b0b0551f4b..b1eb9b21b7 100644 --- a/packages/core/src/core/apiKeyCredentialStorage.test.ts +++ b/packages/core/src/core/apiKeyCredentialStorage.test.ts @@ -9,6 +9,7 @@ import { loadApiKey, saveApiKey, clearApiKey, + resetApiKeyCacheForTesting, } from './apiKeyCredentialStorage.js'; const getCredentialsMock = vi.hoisted(() => vi.fn()); @@ -26,9 +27,10 @@ vi.mock('../mcp/token-storage/hybrid-token-storage.js', () => ({ describe('ApiKeyCredentialStorage', () => { beforeEach(() => { vi.clearAllMocks(); + resetApiKeyCacheForTesting(); }); - it('should load an API key', async () => { + it('should load an API key and cache it', async () => { getCredentialsMock.mockResolvedValue({ serverName: 'default-api-key', token: { @@ -38,19 +40,39 @@ describe('ApiKeyCredentialStorage', () => { updatedAt: Date.now(), }); - const apiKey = await loadApiKey(); - expect(apiKey).toBe('test-key'); - expect(getCredentialsMock).toHaveBeenCalledWith('default-api-key'); + const apiKey1 = await loadApiKey(); + expect(apiKey1).toBe('test-key'); + expect(getCredentialsMock).toHaveBeenCalledTimes(1); + + const apiKey2 = await loadApiKey(); + expect(apiKey2).toBe('test-key'); + expect(getCredentialsMock).toHaveBeenCalledTimes(1); // Should be cached }); - it('should return null if no API key is stored', async () => { + it('should return null if no API key is stored and cache it', async () => { getCredentialsMock.mockResolvedValue(null); - const apiKey = await loadApiKey(); - expect(apiKey).toBeNull(); - expect(getCredentialsMock).toHaveBeenCalledWith('default-api-key'); + const apiKey1 = await loadApiKey(); + expect(apiKey1).toBeNull(); + expect(getCredentialsMock).toHaveBeenCalledTimes(1); + + const apiKey2 = await loadApiKey(); + expect(apiKey2).toBeNull(); + expect(getCredentialsMock).toHaveBeenCalledTimes(1); // Should be cached }); - it('should save an API key', async () => { + it('should save an API key and clear cache', async () => { + getCredentialsMock.mockResolvedValue({ + serverName: 'default-api-key', + token: { + accessToken: 'old-key', + tokenType: 'ApiKey', + }, + updatedAt: Date.now(), + }); + + await loadApiKey(); + expect(getCredentialsMock).toHaveBeenCalledTimes(1); + await saveApiKey('new-key'); expect(setCredentialsMock).toHaveBeenCalledWith( expect.objectContaining({ @@ -61,28 +83,62 @@ describe('ApiKeyCredentialStorage', () => { }), }), ); + + getCredentialsMock.mockResolvedValue({ + serverName: 'default-api-key', + token: { + accessToken: 'new-key', + tokenType: 'ApiKey', + }, + updatedAt: Date.now(), + }); + + await loadApiKey(); + expect(getCredentialsMock).toHaveBeenCalledTimes(2); // Should have fetched again }); - it('should clear an API key when saving empty key', async () => { + it('should clear an API key and clear cache', async () => { + getCredentialsMock.mockResolvedValue({ + serverName: 'default-api-key', + token: { + accessToken: 'old-key', + tokenType: 'ApiKey', + }, + updatedAt: Date.now(), + }); + + await loadApiKey(); + expect(getCredentialsMock).toHaveBeenCalledTimes(1); + + await clearApiKey(); + expect(deleteCredentialsMock).toHaveBeenCalledWith('default-api-key'); + + getCredentialsMock.mockResolvedValue(null); + await loadApiKey(); + expect(getCredentialsMock).toHaveBeenCalledTimes(2); // Should have fetched again + }); + + it('should clear an API key and cache when saving empty key', async () => { await saveApiKey(''); expect(deleteCredentialsMock).toHaveBeenCalledWith('default-api-key'); expect(setCredentialsMock).not.toHaveBeenCalled(); }); - it('should clear an API key when saving null key', async () => { + it('should clear an API key and cache when saving null key', async () => { await saveApiKey(null); expect(deleteCredentialsMock).toHaveBeenCalledWith('default-api-key'); expect(setCredentialsMock).not.toHaveBeenCalled(); }); - it('should clear an API key', async () => { - await clearApiKey(); - expect(deleteCredentialsMock).toHaveBeenCalledWith('default-api-key'); - }); - - it('should not throw when clearing an API key fails', async () => { + it('should not throw when clearing an API key fails during saveApiKey', async () => { deleteCredentialsMock.mockRejectedValueOnce(new Error('Failed to delete')); await expect(saveApiKey('')).resolves.not.toThrow(); expect(deleteCredentialsMock).toHaveBeenCalledWith('default-api-key'); }); + + it('should not throw when clearing an API key fails during clearApiKey', async () => { + deleteCredentialsMock.mockRejectedValueOnce(new Error('Failed to delete')); + await expect(clearApiKey()).resolves.not.toThrow(); + expect(deleteCredentialsMock).toHaveBeenCalledWith('default-api-key'); + }); }); diff --git a/packages/core/src/core/apiKeyCredentialStorage.ts b/packages/core/src/core/apiKeyCredentialStorage.ts index 4836ba075b..41b3a0276a 100644 --- a/packages/core/src/core/apiKeyCredentialStorage.ts +++ b/packages/core/src/core/apiKeyCredentialStorage.ts @@ -7,29 +7,46 @@ import { HybridTokenStorage } from '../mcp/token-storage/hybrid-token-storage.js'; import type { OAuthCredentials } from '../mcp/token-storage/types.js'; import { debugLogger } from '../utils/debugLogger.js'; +import { createCache } from '../utils/cache.js'; const KEYCHAIN_SERVICE_NAME = 'gemini-cli-api-key'; const DEFAULT_API_KEY_ENTRY = 'default-api-key'; const storage = new HybridTokenStorage(KEYCHAIN_SERVICE_NAME); +// Cache to store the results of loadApiKey to avoid redundant keychain access. +const apiKeyCache = createCache>({ + storage: 'map', + defaultTtl: 30000, // 30 seconds +}); + +/** + * Resets the API key cache. Used exclusively for test isolation. + * @internal + */ +export function resetApiKeyCacheForTesting() { + apiKeyCache.clear(); +} + /** * Load cached API key */ export async function loadApiKey(): Promise { - try { - const credentials = await storage.getCredentials(DEFAULT_API_KEY_ENTRY); + return apiKeyCache.getOrCreate(DEFAULT_API_KEY_ENTRY, async () => { + try { + const credentials = await storage.getCredentials(DEFAULT_API_KEY_ENTRY); - if (credentials?.token?.accessToken) { - return credentials.token.accessToken; + if (credentials?.token?.accessToken) { + return credentials.token.accessToken; + } + + return null; + } catch (error: unknown) { + // Log other errors but don't crash, just return null so user can re-enter key + debugLogger.error('Failed to load API key from storage:', error); + return null; } - - return null; - } catch (error: unknown) { - // Log other errors but don't crash, just return null so user can re-enter key - debugLogger.error('Failed to load API key from storage:', error); - return null; - } + }); } /** @@ -38,6 +55,7 @@ export async function loadApiKey(): Promise { export async function saveApiKey( apiKey: string | null | undefined, ): Promise { + apiKeyCache.delete(DEFAULT_API_KEY_ENTRY); if (!apiKey || apiKey.trim() === '') { try { await storage.deleteCredentials(DEFAULT_API_KEY_ENTRY); @@ -65,6 +83,7 @@ export async function saveApiKey( * Clear cached API key */ export async function clearApiKey(): Promise { + apiKeyCache.delete(DEFAULT_API_KEY_ENTRY); try { await storage.deleteCredentials(DEFAULT_API_KEY_ENTRY); } catch (error: unknown) { diff --git a/packages/core/src/fallback/handler.ts b/packages/core/src/fallback/handler.ts index ed87454003..1946e3a635 100644 --- a/packages/core/src/fallback/handler.ts +++ b/packages/core/src/fallback/handler.ts @@ -18,7 +18,7 @@ import { applyAvailabilityTransition, } from '../availability/policyHelpers.js'; -const UPGRADE_URL_PAGE = 'https://goo.gle/set-up-gemini-code-assist'; +export const UPGRADE_URL_PAGE = 'https://goo.gle/set-up-gemini-code-assist'; export async function handleFallback( config: Config, diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 5dfd74ad61..47af5f76e1 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -48,6 +48,7 @@ export * from './scheduler/tool-executor.js'; export * from './core/recordingContentGenerator.js'; export * from './fallback/types.js'; +export * from './fallback/handler.js'; export * from './code_assist/codeAssist.js'; export * from './code_assist/oauth2.js';