diff --git a/.github/workflows/gemini-scheduled-stale-issue-closer.yml b/.github/workflows/gemini-scheduled-stale-issue-closer.yml index fb86d8e70e..fe9032983b 100644 --- a/.github/workflows/gemini-scheduled-stale-issue-closer.yml +++ b/.github/workflows/gemini-scheduled-stale-issue-closer.yml @@ -79,8 +79,14 @@ jobs: continue; } - // Skip if it has a maintainer label - if (issue.labels.some(label => label.name.toLowerCase().includes('maintainer'))) { + // Skip if it has a maintainer, help wanted, or Public Roadmap label + const rawLabels = issue.labels.map((l) => l.name); + const lowercaseLabels = rawLabels.map((l) => l.toLowerCase()); + if ( + lowercaseLabels.some((l) => l.includes('maintainer')) || + lowercaseLabels.includes('help wanted') || + rawLabels.includes('🗓️ Public Roadmap') + ) { continue; } diff --git a/.github/workflows/gemini-scheduled-stale-pr-closer.yml b/.github/workflows/gemini-scheduled-stale-pr-closer.yml new file mode 100644 index 0000000000..04b6e37246 --- /dev/null +++ b/.github/workflows/gemini-scheduled-stale-pr-closer.yml @@ -0,0 +1,205 @@ +name: 'Gemini Scheduled Stale PR Closer' + +on: + schedule: + - cron: '0 2 * * *' # Every day at 2 AM UTC + pull_request: + types: ['opened', 'edited'] + workflow_dispatch: + inputs: + dry_run: + description: 'Run in dry-run mode' + required: false + default: false + type: 'boolean' + +jobs: + close-stale-prs: + if: "github.repository == 'google-gemini/gemini-cli'" + runs-on: 'ubuntu-latest' + permissions: + pull-requests: 'write' + issues: 'write' + steps: + - name: 'Generate GitHub App Token' + id: 'generate_token' + uses: 'actions/create-github-app-token@v1' + with: + app-id: '${{ secrets.APP_ID }}' + private-key: '${{ secrets.PRIVATE_KEY }}' + owner: '${{ github.repository_owner }}' + repositories: 'gemini-cli' + + - name: 'Process Stale PRs' + uses: 'actions/github-script@v7' + env: + DRY_RUN: '${{ inputs.dry_run }}' + with: + github-token: '${{ steps.generate_token.outputs.token }}' + script: | + const dryRun = process.env.DRY_RUN === 'true'; + const thirtyDaysAgo = new Date(); + thirtyDaysAgo.setDate(thirtyDaysAgo.getDate() - 30); + + // 1. Fetch maintainers for verification + let maintainerLogins = new Set(); + try { + const members = await github.paginate(github.rest.teams.listMembersInOrg, { + org: context.repo.owner, + team_slug: 'gemini-cli-maintainers' + }); + maintainerLogins = new Set(members.map(m => m.login)); + } catch (e) { + core.warning('Failed to fetch team members'); + } + + const isMaintainer = (login, assoc) => { + if (maintainerLogins.size > 0) return maintainerLogins.has(login); + return ['OWNER', 'MEMBER', 'COLLABORATOR'].includes(assoc); + }; + + // 2. Determine which PRs to check + let prs = []; + if (context.eventName === 'pull_request') { + const { data: pr } = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.payload.pull_request.number + }); + prs = [pr]; + } else { + prs = await github.paginate(github.rest.pulls.list, { + owner: context.repo.owner, + repo: context.repo.repo, + state: 'open', + per_page: 100 + }); + } + + for (const pr of prs) { + const maintainerPr = isMaintainer(pr.user.login, pr.author_association); + const isBot = pr.user.type === 'Bot' || pr.user.login.endsWith('[bot]'); + + // Detection Logic for Linked Issues + // Check 1: Official GitHub "Closing Issue" link (GraphQL) + const linkedIssueQuery = `query($owner:String!, $repo:String!, $number:Int!) { + repository(owner:$owner, name:$repo) { + pullRequest(number:$number) { + closingIssuesReferences(first: 1) { totalCount } + } + } + }`; + + let hasClosingLink = false; + try { + const res = await github.graphql(linkedIssueQuery, { + owner: context.repo.owner, repo: context.repo.repo, number: pr.number + }); + hasClosingLink = res.repository.pullRequest.closingIssuesReferences.totalCount > 0; + } catch (e) {} + + // Check 2: Regex for mentions (e.g., "Related to #123", "Part of #123", "#123") + // We check for # followed by numbers or direct URLs to issues. + const body = pr.body || ''; + const mentionRegex = /(?:#|https:\/\/github\.com\/[^\/]+\/[^\/]+\/issues\/)(\d+)/i; + const hasMentionLink = mentionRegex.test(body); + + const hasLinkedIssue = hasClosingLink || hasMentionLink; + + // Logic for Closed PRs (Auto-Reopen) + if (pr.state === 'closed' && context.eventName === 'pull_request' && context.payload.action === 'edited') { + if (hasLinkedIssue) { + core.info(`PR #${pr.number} now has a linked issue. Reopening.`); + if (!dryRun) { + await github.rest.pulls.update({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: pr.number, + state: 'open' + }); + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + body: "Thank you for linking an issue! This pull request has been automatically reopened." + }); + } + } + continue; + } + + // Logic for Open PRs (Immediate Closure) + if (pr.state === 'open' && !maintainerPr && !hasLinkedIssue && !isBot) { + core.info(`PR #${pr.number} is missing a linked issue. Closing.`); + if (!dryRun) { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + body: "Hi there! Thank you for your contribution to Gemini CLI. \n\nTo improve our contribution process and better track changes, we now require all pull requests to be associated with an existing issue, as announced in our [recent discussion](https://github.com/google-gemini/gemini-cli/discussions/16706) and as detailed in our [CONTRIBUTING.md](https://github.com/google-gemini/gemini-cli/blob/main/CONTRIBUTING.md#1-link-to-an-existing-issue).\n\nThis pull request is being closed because it is not currently linked to an issue. **Once you have updated the description of this PR to link an issue (e.g., by adding `Fixes #123` or `Related to #123`), it will be automatically reopened.**\n\n**How to link an issue:**\nAdd a keyword followed by the issue number (e.g., `Fixes #123`) in the description of your pull request. For more details on supported keywords and how linking works, please refer to the [GitHub Documentation on linking pull requests to issues](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue).\n\nThank you for your understanding and for being a part of our community!" + }); + await github.rest.pulls.update({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: pr.number, + state: 'closed' + }); + } + continue; + } + + // Staleness check (Scheduled runs only) + if (pr.state === 'open' && context.eventName !== 'pull_request') { + const labels = pr.labels.map(l => l.name.toLowerCase()); + if (labels.includes('help wanted') || labels.includes('🔒 maintainer only')) continue; + + let lastActivity = new Date(0); + try { + const reviews = await github.paginate(github.rest.pulls.listReviews, { + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: pr.number + }); + for (const r of reviews) { + if (isMaintainer(r.user.login, r.author_association)) { + const d = new Date(r.submitted_at || r.updated_at); + if (d > lastActivity) lastActivity = d; + } + } + const comments = await github.paginate(github.rest.issues.listComments, { + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number + }); + for (const c of comments) { + if (isMaintainer(c.user.login, c.author_association)) { + const d = new Date(c.updated_at); + if (d > lastActivity) lastActivity = d; + } + } + } catch (e) {} + + if (maintainerPr) { + const d = new Date(pr.created_at); + if (d > lastActivity) lastActivity = d; + } + + if (lastActivity < thirtyDaysAgo) { + core.info(`PR #${pr.number} is stale.`); + if (!dryRun) { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: pr.number, + body: "Hi there! Thank you for your contribution to Gemini CLI. We really appreciate the time and effort you've put into this pull request.\n\nTo keep our backlog manageable and ensure we're focusing on current priorities, we are closing pull requests that haven't seen maintainer activity for 30 days. Currently, the team is prioritizing work associated with **🔒 maintainer only** or **help wanted** issues.\n\nIf you believe this change is still critical, please feel free to comment with updated details. Otherwise, we encourage contributors to focus on open issues labeled as **help wanted**. Thank you for your understanding!" + }); + await github.rest.pulls.update({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: pr.number, + state: 'closed' + }); + } + } + } + } diff --git a/.github/workflows/label-enforcer.yml b/.github/workflows/label-enforcer.yml deleted file mode 100644 index 98b8a3f554..0000000000 --- a/.github/workflows/label-enforcer.yml +++ /dev/null @@ -1,119 +0,0 @@ -name: '🏷️ Enforce Restricted Label Permissions' - -on: - issues: - types: - - 'labeled' - - 'unlabeled' - -jobs: - enforce-label: - # Run this job only when restricted labels are changed - if: |- - ${{ (github.event.label.name == 'help wanted' || github.event.label.name == 'status/need-triage' || github.event.label.name == '🔒 maintainer only') && - (github.repository == 'google-gemini/gemini-cli' || github.repository == 'google-gemini/maintainers-gemini-cli') }} - runs-on: 'ubuntu-latest' - permissions: - issues: 'write' - steps: - - name: 'Generate GitHub App Token' - id: 'generate_token' - env: - APP_ID: '${{ secrets.APP_ID }}' - if: |- - ${{ env.APP_ID != '' }} - uses: 'actions/create-github-app-token@a8d616148505b5069dccd32f177bb87d7f39123b' # ratchet:actions/create-github-app-token@v2 - with: - app-id: '${{ secrets.APP_ID }}' - private-key: '${{ secrets.PRIVATE_KEY }}' - - - name: 'Check if user is in the maintainers team' - uses: 'actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea' - with: - github-token: '${{ steps.generate_token.outputs.token || secrets.GITHUB_TOKEN }}' - script: |- - const org = context.repo.owner; - const username = context.payload.sender.login; - const team_slug = 'gemini-cli-maintainers'; - const action = context.payload.action; // 'labeled' or 'unlabeled' - const labelName = context.payload.label.name; - - // Skip if the change was made by a bot to avoid infinite loops - if (username === 'github-actions[bot]' || username === 'gemini-cli[bot]' || username.endsWith('[bot]')) { - core.info('Change made by a bot. Skipping.'); - return; - } - - try { - // Check repository permission level directly. - // This is more robust than team membership as it doesn't require Org-level read permissions - // and correctly handles Repo Admins/Writers who might not be in the specific team. - const { data: { permission } } = await github.rest.repos.getCollaboratorPermissionLevel({ - owner: org, - repo: context.repo.repo, - username, - }); - - if (permission === 'admin' || permission === 'write') { - core.info(`${username} has '${permission}' permission. Allowed.`); - return; - } - - core.info(`${username} has '${permission}' permission (needs 'write' or 'admin'). Reverting '${action}' action for '${labelName}' label.`); - } catch (error) { - core.error(`Failed to check permissions for ${username}: ${error.message}`); - // Fall through to revert logic if we can't verify permissions (fail safe) - } - - // If we are here, the user is NOT authorized. - if (true) { // wrapping block to preserve variable scope if needed - if (action === 'labeled') { - // 1. Remove the label if added by a non-maintainer - await github.rest.issues.removeLabel ({ - owner: org, - repo: context.repo.repo, - issue_number: context.issue.number, - name: labelName - }); - - // 2. Post a polite comment - const comment = ` - Hi @${username}, thank you for your interest in helping triage issues! - - The \`${labelName}\` label is reserved for project maintainers to apply. This helps us ensure that an issue is ready and properly vetted for community contribution. - - A maintainer will review this issue soon. Please see our [CONTRIBUTING.md](https://github.com/google-gemini/gemini-cli/blob/main/CONTRIBUTING.md) for more details on our labeling process. - `.trim().replace(/^[ ]+/gm, ''); - - await github.rest.issues.createComment ({ - owner: org, - repo: context.repo.repo, - issue_number: context.issue.number, - body: comment - }); - } else if (action === 'unlabeled') { - // 1. Add the label back if removed by a non-maintainer - await github.rest.issues.addLabels ({ - owner: org, - repo: context.repo.repo, - issue_number: context.issue.number, - labels: [labelName] - }); - - // 2. Post a polite comment - const comment = ` - Hi @${username}, it looks like the \`${labelName}\` label was removed. - - This label is managed by project maintainers. We've added it back to ensure the issue remains visible to potential contributors until a maintainer decides otherwise. - - Thank you for your understanding! - `.trim().replace(/^[ ]+/gm, ''); - - await github.rest.issues.createComment ({ - owner: org, - repo: context.repo.repo, - issue_number: context.issue.number, - body: comment - }); - } - } diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index fd79d914dc..4a975869f5 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -40,5 +40,5 @@ jobs: If this is still relevant, you are welcome to reopen or leave a comment. Thanks for contributing! days-before-stale: 60 days-before-close: 14 - exempt-issue-labels: 'pinned,security' - exempt-pr-labels: 'pinned,security' + exempt-issue-labels: 'pinned,security,🔒 maintainer only,help wanted,🗓️ Public Roadmap' + exempt-pr-labels: 'pinned,security,🔒 maintainer only,help wanted,🗓️ Public Roadmap' diff --git a/README.md b/README.md index 77a7ba3647..22e258e289 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,12 @@ npm install -g @google/gemini-cli brew install gemini-cli ``` +#### Install globally with MacPorts (macOS) + +```bash +sudo port install gemini-cli +``` + #### Install with Anaconda (for restricted environments) ```bash diff --git a/docs/cli/uninstall.md b/docs/cli/uninstall.md index 9523e34d8d..e96ddc5acf 100644 --- a/docs/cli/uninstall.md +++ b/docs/cli/uninstall.md @@ -45,3 +45,21 @@ npm uninstall -g @google/gemini-cli ``` This command completely removes the package from your system. + +## Method 3: Homebrew + +If you installed the CLI globally using Homebrew (e.g., +`brew install gemini-cli`), use the `brew uninstall` command to remove it. + +```bash +brew uninstall gemini-cli +``` + +## Method 4: MacPorts + +If you installed the CLI globally using MacPorts (e.g., +`sudo port install gemini-cli`), use the `port uninstall` command to remove it. + +```bash +sudo port uninstall gemini-cli +``` diff --git a/docs/extensions/index.md b/docs/extensions/index.md index 8f71d1c184..a2b0598388 100644 --- a/docs/extensions/index.md +++ b/docs/extensions/index.md @@ -324,7 +324,7 @@ The `hooks.json` file contains a `hooks` object where keys are ```json { "hooks": { - "before_agent": [ + "BeforeAgent": [ { "hooks": [ { diff --git a/docs/hooks/best-practices.md b/docs/hooks/best-practices.md index 559f3f18bb..316aacbc29 100644 --- a/docs/hooks/best-practices.md +++ b/docs/hooks/best-practices.md @@ -91,6 +91,7 @@ spawning a process for irrelevant events. "hooks": [ { "name": "validate-writes", + "type": "command", "command": "./validate.sh" } ] @@ -584,6 +585,7 @@ defaults to 60 seconds, but you should set stricter limits for fast hooks. "hooks": [ { "name": "fast-validator", + "type": "command", "command": "./hooks/validate.sh", "timeout": 5000 // 5 seconds } diff --git a/docs/hooks/index.md b/docs/hooks/index.md index 24c843128a..dc1c036ade 100644 --- a/docs/hooks/index.md +++ b/docs/hooks/index.md @@ -104,9 +104,8 @@ You can filter which specific tools or triggers fire your hook using the ## Configuration -Hook definitions are configured in `settings.json`. Gemini CLI merges -configurations from multiple layers in the following order of precedence -(highest to lowest): +Hooks are configured in `settings.json`. Gemini CLI merges configurations from +multiple layers in the following order of precedence (highest to lowest): 1. **Project settings**: `.gemini/settings.json` in the current directory. 2. **User settings**: `~/.gemini/settings.json`. @@ -126,8 +125,7 @@ configurations from multiple layers in the following order of precedence "name": "security-check", "type": "command", "command": "$GEMINI_PROJECT_DIR/.gemini/hooks/security.sh", - "timeout": 5000, - "sequential": false + "timeout": 5000 } ] } @@ -136,6 +134,18 @@ configurations from multiple layers in the following order of precedence } ``` +#### Hook configuration fields + +| Field | Type | Required | Description | +| :------------ | :----- | :-------- | :------------------------------------------------------------------- | +| `type` | string | **Yes** | The execution engine. Currently only `"command"` is supported. | +| `command` | string | **Yes\*** | The shell command to execute. (Required when `type` is `"command"`). | +| `name` | string | No | A friendly name for identifying the hook in logs and CLI commands. | +| `timeout` | number | No | Execution timeout in milliseconds (default: 60000). | +| `description` | string | No | A brief explanation of the hook's purpose. | + +--- + ### Environment variables Hooks are executed with a sanitized environment. diff --git a/docs/hooks/reference.md b/docs/hooks/reference.md index 2feeedf940..a86474ea85 100644 --- a/docs/hooks/reference.md +++ b/docs/hooks/reference.md @@ -18,6 +18,31 @@ including JSON schemas and API details. --- +## Configuration schema + +Hooks are defined in `settings.json` within the `hooks` object. Each event +(e.g., `BeforeTool`) contains an array of **hook definitions**. + +### Hook definition + +| Field | Type | Required | Description | +| :----------- | :-------- | :------- | :-------------------------------------------------------------------------------------- | +| `matcher` | `string` | No | A regex (for tools) or exact string (for lifecycle) to filter when the hook runs. | +| `sequential` | `boolean` | No | If `true`, hooks in this group run one after another. If `false`, they run in parallel. | +| `hooks` | `array` | **Yes** | An array of **hook configurations**. | + +### Hook configuration + +| Field | Type | Required | Description | +| :------------ | :------- | :-------- | :------------------------------------------------------------------- | +| `type` | `string` | **Yes** | The execution engine. Currently only `"command"` is supported. | +| `command` | `string` | **Yes\*** | The shell command to execute. (Required when `type` is `"command"`). | +| `name` | `string` | No | A friendly name for identifying the hook in logs and CLI commands. | +| `timeout` | `number` | No | Execution timeout in milliseconds (default: 60000). | +| `description` | `string` | No | A brief explanation of the hook's purpose. | + +--- + ## Base input schema All hooks receive these common fields via `stdin`: diff --git a/docs/hooks/writing-hooks.md b/docs/hooks/writing-hooks.md index 7b66a90a65..33357fccb2 100644 --- a/docs/hooks/writing-hooks.md +++ b/docs/hooks/writing-hooks.md @@ -194,6 +194,7 @@ main().catch((err) => { "hooks": [ { "name": "intent-filter", + "type": "command", "command": "node .gemini/hooks/filter-tools.js" } ] @@ -234,7 +235,13 @@ security. "SessionStart": [ { "matcher": "startup", - "hooks": [{ "name": "init", "command": "node .gemini/hooks/init.js" }] + "hooks": [ + { + "name": "init", + "type": "command", + "command": "node .gemini/hooks/init.js" + } + ] } ], "BeforeAgent": [ @@ -243,6 +250,7 @@ security. "hooks": [ { "name": "memory", + "type": "command", "command": "node .gemini/hooks/inject-memories.js" } ] @@ -252,7 +260,11 @@ security. { "matcher": "*", "hooks": [ - { "name": "filter", "command": "node .gemini/hooks/rag-filter.js" } + { + "name": "filter", + "type": "command", + "command": "node .gemini/hooks/rag-filter.js" + } ] } ], @@ -260,7 +272,11 @@ security. { "matcher": "write_file", "hooks": [ - { "name": "security", "command": "node .gemini/hooks/security.js" } + { + "name": "security", + "type": "command", + "command": "node .gemini/hooks/security.js" + } ] } ], @@ -268,7 +284,11 @@ security. { "matcher": "*", "hooks": [ - { "name": "record", "command": "node .gemini/hooks/record.js" } + { + "name": "record", + "type": "command", + "command": "node .gemini/hooks/record.js" + } ] } ], @@ -276,7 +296,11 @@ security. { "matcher": "*", "hooks": [ - { "name": "validate", "command": "node .gemini/hooks/validate.js" } + { + "name": "validate", + "type": "command", + "command": "node .gemini/hooks/validate.js" + } ] } ], @@ -284,7 +308,11 @@ security. { "matcher": "exit", "hooks": [ - { "name": "save", "command": "node .gemini/hooks/consolidate.js" } + { + "name": "save", + "type": "command", + "command": "node .gemini/hooks/consolidate.js" + } ] } ] diff --git a/packages/a2a-server/src/config/config.ts b/packages/a2a-server/src/config/config.ts index b9e895dde0..732d6e2f84 100644 --- a/packages/a2a-server/src/config/config.ts +++ b/packages/a2a-server/src/config/config.ts @@ -97,6 +97,7 @@ export async function loadConfig( previewFeatures: settings.general?.previewFeatures, interactive: true, enableInteractiveShell: true, + ptyInfo: 'auto', }; const fileService = new FileDiscoveryService(workspaceDir); diff --git a/packages/cli/src/core/auth.test.ts b/packages/cli/src/core/auth.test.ts index 366e5c9137..c844ee6f93 100644 --- a/packages/cli/src/core/auth.test.ts +++ b/packages/cli/src/core/auth.test.ts @@ -6,18 +6,20 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import { performInitialAuth } from './auth.js'; -import { type Config } from '@google/gemini-cli-core'; +import { + type Config, + ValidationRequiredError, + AuthType, +} from '@google/gemini-cli-core'; -vi.mock('@google/gemini-cli-core', () => ({ - AuthType: { - OAUTH: 'oauth', - }, - getErrorMessage: (e: unknown) => (e as Error).message, -})); - -const AuthType = { - OAUTH: 'oauth', -} as const; +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const actual = + await importOriginal(); + return { + ...actual, + getErrorMessage: (e: unknown) => (e as Error).message, + }; +}); describe('auth', () => { let mockConfig: Config; @@ -37,10 +39,12 @@ describe('auth', () => { it('should return null on successful auth', async () => { const result = await performInitialAuth( mockConfig, - AuthType.OAUTH as unknown as Parameters[1], + AuthType.LOGIN_WITH_GOOGLE, ); expect(result).toBeNull(); - expect(mockConfig.refreshAuth).toHaveBeenCalledWith(AuthType.OAUTH); + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.LOGIN_WITH_GOOGLE, + ); }); it('should return error message on failed auth', async () => { @@ -48,9 +52,25 @@ describe('auth', () => { vi.mocked(mockConfig.refreshAuth).mockRejectedValue(error); const result = await performInitialAuth( mockConfig, - AuthType.OAUTH as unknown as Parameters[1], + AuthType.LOGIN_WITH_GOOGLE, ); expect(result).toBe('Failed to login. Message: Auth failed'); - expect(mockConfig.refreshAuth).toHaveBeenCalledWith(AuthType.OAUTH); + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.LOGIN_WITH_GOOGLE, + ); + }); + + it('should return null if refreshAuth throws ValidationRequiredError', async () => { + vi.mocked(mockConfig.refreshAuth).mockRejectedValue( + new ValidationRequiredError('Validation required'), + ); + const result = await performInitialAuth( + mockConfig, + AuthType.LOGIN_WITH_GOOGLE, + ); + expect(result).toBeNull(); + expect(mockConfig.refreshAuth).toHaveBeenCalledWith( + AuthType.LOGIN_WITH_GOOGLE, + ); }); }); diff --git a/packages/cli/src/core/auth.ts b/packages/cli/src/core/auth.ts index f4f4963bc7..7b1e8c8277 100644 --- a/packages/cli/src/core/auth.ts +++ b/packages/cli/src/core/auth.ts @@ -8,6 +8,7 @@ import { type AuthType, type Config, getErrorMessage, + ValidationRequiredError, } from '@google/gemini-cli-core'; /** @@ -29,6 +30,11 @@ export async function performInitialAuth( // The console.log is intentionally left out here. // We can add a dedicated startup message later if needed. } catch (e) { + if (e instanceof ValidationRequiredError) { + // Don't treat validation required as a fatal auth error during startup. + // This allows the React UI to load and show the ValidationDialog. + return null; + } return `Failed to login. Message: ${getErrorMessage(e)}`; } diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index ff73dcfdfa..20f022021a 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -61,6 +61,8 @@ import { SessionStartSource, SessionEndReason, getVersion, + ValidationCancelledError, + ValidationRequiredError, type FetchAdminControlsResponse, } from '@google/gemini-cli-core'; import { @@ -406,8 +408,19 @@ export async function main() { await partialConfig.refreshAuth(authType); } } catch (err) { - debugLogger.error('Error authenticating:', err); - initialAuthFailed = true; + if (err instanceof ValidationCancelledError) { + // User cancelled verification, exit immediately. + await runExitCleanup(); + process.exit(ExitCodes.SUCCESS); + } + + // If validation is required, we don't treat it as a fatal failure. + // We allow the app to start, and the React-based ValidationDialog + // will handle it. + if (!(err instanceof ValidationRequiredError)) { + debugLogger.error('Error authenticating:', err); + initialAuthFailed = true; + } } } diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index 43553efe14..4f10e10645 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -63,6 +63,7 @@ import { SessionStartSource, SessionEndReason, generateSummary, + ChangeAuthRequestedError, } from '@google/gemini-cli-core'; import { validateAuthMethod } from '../config/auth.js'; import process from 'node:process'; @@ -527,7 +528,7 @@ export const AppContainer = (props: AppContainerProps) => { onAuthError, apiKeyDefaultValue, reloadApiKey, - } = useAuthCommand(settings, config); + } = useAuthCommand(settings, config, initializationResult.authError); const [authContext, setAuthContext] = useState<{ requiresRestart?: boolean }>( {}, ); @@ -549,6 +550,7 @@ export const AppContainer = (props: AppContainerProps) => { historyManager, userTier, setModelSwitchedFromQuotaError, + onShowAuthSelection: () => setAuthState(AuthState.Updating), }); // Derive auth state variables for backward compatibility with UIStateContext @@ -558,7 +560,7 @@ export const AppContainer = (props: AppContainerProps) => { // Session browser and resume functionality const isGeminiClientInitialized = config.getGeminiClient()?.isInitialized(); - const { loadHistoryForResume } = useSessionResume({ + const { loadHistoryForResume, isResuming } = useSessionResume({ config, historyManager, refreshStatic, @@ -598,6 +600,9 @@ export const AppContainer = (props: AppContainerProps) => { await config.refreshAuth(authType); setAuthState(AuthState.Authenticated); } catch (e) { + if (e instanceof ChangeAuthRequestedError) { + return; + } onAuthError( `Failed to authenticate: ${e instanceof Error ? e.message : String(e)}`, ); @@ -1013,6 +1018,7 @@ Logging in with Google... Restarting Gemini CLI to continue. isConfigInitialized && !initError && !isProcessing && + !isResuming && !!slashCommands && (streamingState === StreamingState.Idle || streamingState === StreamingState.Responding) && @@ -1665,6 +1671,7 @@ Logging in with Google... Restarting Gemini CLI to continue. inputWidth, suggestionsWidth, isInputActive, + isResuming, shouldShowIdePrompt, isFolderTrustDialogOpen: isFolderTrustDialogOpen ?? false, isTrustedFolder, @@ -1761,6 +1768,7 @@ Logging in with Google... Restarting Gemini CLI to continue. inputWidth, suggestionsWidth, isInputActive, + isResuming, shouldShowIdePrompt, isFolderTrustDialogOpen, isTrustedFolder, diff --git a/packages/cli/src/ui/auth/useAuth.ts b/packages/cli/src/ui/auth/useAuth.ts index 7b37e2d421..2b61265890 100644 --- a/packages/cli/src/ui/auth/useAuth.ts +++ b/packages/cli/src/ui/auth/useAuth.ts @@ -34,12 +34,16 @@ export function validateAuthMethodWithSettings( return validateAuthMethod(authType); } -export const useAuthCommand = (settings: LoadedSettings, config: Config) => { +export const useAuthCommand = ( + settings: LoadedSettings, + config: Config, + initialAuthError: string | null = null, +) => { const [authState, setAuthState] = useState( - AuthState.Unauthenticated, + initialAuthError ? AuthState.Updating : AuthState.Unauthenticated, ); - const [authError, setAuthError] = useState(null); + const [authError, setAuthError] = useState(initialAuthError); const [apiKeyDefaultValue, setApiKeyDefaultValue] = useState< string | undefined >(undefined); diff --git a/packages/cli/src/ui/components/Composer.tsx b/packages/cli/src/ui/components/Composer.tsx index 9a550a323e..de3ecebd19 100644 --- a/packages/cli/src/ui/components/Composer.tsx +++ b/packages/cli/src/ui/components/Composer.tsx @@ -71,8 +71,12 @@ export const Composer = ({ isFocused = true }: { isFocused?: boolean }) => { /> )} - {(!uiState.slashCommands || !uiState.isConfigInitialized) && ( - + {(!uiState.slashCommands || + !uiState.isConfigInitialized || + uiState.isResuming) && ( + )} diff --git a/packages/cli/src/ui/components/ConfigInitDisplay.tsx b/packages/cli/src/ui/components/ConfigInitDisplay.tsx index b1dc71ff74..a47e16daff 100644 --- a/packages/cli/src/ui/components/ConfigInitDisplay.tsx +++ b/packages/cli/src/ui/components/ConfigInitDisplay.tsx @@ -15,13 +15,17 @@ import { import { GeminiSpinner } from './GeminiRespondingSpinner.js'; import { theme } from '../semantic-colors.js'; -export const ConfigInitDisplay = () => { - const [message, setMessage] = useState('Initializing...'); +export const ConfigInitDisplay = ({ + message: initialMessage = 'Initializing...', +}: { + message?: string; +}) => { + const [message, setMessage] = useState(initialMessage); useEffect(() => { const onChange = (clients?: Map) => { if (!clients || clients.size === 0) { - setMessage(`Initializing...`); + setMessage(initialMessage); return; } let connected = 0; @@ -39,12 +43,18 @@ export const ConfigInitDisplay = () => { const displayedServers = connecting.slice(0, maxDisplay).join(', '); const remaining = connecting.length - maxDisplay; const suffix = remaining > 0 ? `, +${remaining} more` : ''; + const mcpMessage = `Connecting to MCP servers... (${connected}/${clients.size}) - Waiting for: ${displayedServers}${suffix}`; setMessage( - `Connecting to MCP servers... (${connected}/${clients.size}) - Waiting for: ${displayedServers}${suffix}`, + initialMessage && initialMessage !== 'Initializing...' + ? `${initialMessage} (${mcpMessage})` + : mcpMessage, ); } else { + const mcpMessage = `Connecting to MCP servers... (${connected}/${clients.size})`; setMessage( - `Connecting to MCP servers... (${connected}/${clients.size})`, + initialMessage && initialMessage !== 'Initializing...' + ? `${initialMessage} (${mcpMessage})` + : mcpMessage, ); } }; @@ -53,7 +63,7 @@ export const ConfigInitDisplay = () => { return () => { coreEvents.off(CoreEvent.McpClientUpdate, onChange); }; - }, []); + }, [initialMessage]); return ( diff --git a/packages/cli/src/ui/components/ValidationDialog.test.tsx b/packages/cli/src/ui/components/ValidationDialog.test.tsx index ac938202ab..0e50781342 100644 --- a/packages/cli/src/ui/components/ValidationDialog.test.tsx +++ b/packages/cli/src/ui/components/ValidationDialog.test.tsx @@ -17,6 +17,7 @@ import { } from 'vitest'; import { ValidationDialog } from './ValidationDialog.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; +import type { Key } from '../hooks/useKeypress.js'; // Mock the child components and utilities vi.mock('./shared/RadioButtonSelect.js', () => ({ @@ -41,8 +42,15 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { }; }); +// Capture keypress handler to test it +let mockKeypressHandler: (key: Key) => void; +let mockKeypressOptions: { isActive: boolean }; + vi.mock('../hooks/useKeypress.js', () => ({ - useKeypress: vi.fn(), + useKeypress: vi.fn((handler, options) => { + mockKeypressHandler = handler; + mockKeypressOptions = options; + }), })); describe('ValidationDialog', () => { @@ -99,6 +107,29 @@ describe('ValidationDialog', () => { expect(lastFrame()).toContain('https://example.com/help'); unmount(); }); + + it('should call onChoice with cancel when ESCAPE is pressed', () => { + const { unmount } = render(); + + // Verify the keypress hook is active + expect(mockKeypressOptions.isActive).toBe(true); + + // Simulate ESCAPE key press + act(() => { + mockKeypressHandler({ + name: 'escape', + ctrl: false, + shift: false, + alt: false, + cmd: false, + insertable: false, + sequence: '\x1b', + }); + }); + + expect(mockOnChoice).toHaveBeenCalledWith('cancel'); + unmount(); + }); }); describe('onChoice handling', () => { diff --git a/packages/cli/src/ui/components/ValidationDialog.tsx b/packages/cli/src/ui/components/ValidationDialog.tsx index b7ddf2878a..9c71e93403 100644 --- a/packages/cli/src/ui/components/ValidationDialog.tsx +++ b/packages/cli/src/ui/components/ValidationDialog.tsx @@ -48,17 +48,17 @@ export function ValidationDialog({ }, ]; - // Handle keypresses during 'waiting' state (ESC to cancel, Enter to confirm completion) + // Handle keypresses globally for cancellation, and specific logic for waiting state useKeypress( (key) => { if (keyMatchers[Command.ESCAPE](key) || keyMatchers[Command.QUIT](key)) { onChoice('cancel'); - } else if (keyMatchers[Command.RETURN](key)) { + } else if (state === 'waiting' && keyMatchers[Command.RETURN](key)) { // User confirmed verification is complete - transition to 'complete' state setState('complete'); } }, - { isActive: state === 'waiting' }, + { isActive: state !== 'complete' }, ); // When state becomes 'complete', show success message briefly then proceed diff --git a/packages/cli/src/ui/contexts/UIStateContext.tsx b/packages/cli/src/ui/contexts/UIStateContext.tsx index 893ee80c07..fea13285b1 100644 --- a/packages/cli/src/ui/contexts/UIStateContext.tsx +++ b/packages/cli/src/ui/contexts/UIStateContext.tsx @@ -94,6 +94,7 @@ export interface UIState { inputWidth: number; suggestionsWidth: number; isInputActive: boolean; + isResuming: boolean; shouldShowIdePrompt: boolean; isFolderTrustDialogOpen: boolean; isTrustedFolder: boolean | undefined; diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts index 61e53638ec..2a9106329e 100644 --- a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts @@ -41,6 +41,7 @@ describe('useQuotaAndFallback', () => { let mockConfig: Config; let mockHistoryManager: UseHistoryManagerReturn; let mockSetModelSwitchedFromQuotaError: Mock; + let mockOnShowAuthSelection: Mock; let setFallbackHandlerSpy: SpyInstance; let mockGoogleApiError: GoogleApiError; @@ -66,6 +67,7 @@ describe('useQuotaAndFallback', () => { loadHistory: vi.fn(), }; mockSetModelSwitchedFromQuotaError = vi.fn(); + mockOnShowAuthSelection = vi.fn(); setFallbackHandlerSpy = vi.spyOn(mockConfig, 'setFallbackModelHandler'); vi.spyOn(mockConfig, 'setQuotaErrorOccurred'); @@ -85,6 +87,7 @@ describe('useQuotaAndFallback', () => { historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -101,6 +104,7 @@ describe('useQuotaAndFallback', () => { historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); return setFallbackHandlerSpy.mock.calls[0][0] as FallbackModelHandler; @@ -127,6 +131,7 @@ describe('useQuotaAndFallback', () => { historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -178,6 +183,7 @@ describe('useQuotaAndFallback', () => { historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -243,6 +249,7 @@ describe('useQuotaAndFallback', () => { userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -297,6 +304,7 @@ describe('useQuotaAndFallback', () => { historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -345,6 +353,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -362,6 +371,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -392,6 +402,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -435,6 +446,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -470,6 +482,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -513,6 +526,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -527,6 +541,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -568,6 +583,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -602,13 +618,14 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, expect(result.current.validationRequest).toBeNull(); }); - it('should add info message when change_auth is chosen', async () => { + it('should call onShowAuthSelection when change_auth is chosen', async () => { const { result } = renderHook(() => useQuotaAndFallback({ config: mockConfig, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -628,19 +645,17 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, const intent = await promise!; expect(intent).toBe('change_auth'); - expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1); - const lastCall = (mockHistoryManager.addItem as Mock).mock.calls[0][0]; - expect(lastCall.type).toBe(MessageType.INFO); - expect(lastCall.text).toBe('Use /auth to change authentication method.'); + expect(mockOnShowAuthSelection).toHaveBeenCalledTimes(1); }); - it('should not add info message when cancel is chosen', async () => { + it('should call onShowAuthSelection when cancel is chosen', async () => { const { result } = renderHook(() => useQuotaAndFallback({ config: mockConfig, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); @@ -660,7 +675,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, const intent = await promise!; expect(intent).toBe('cancel'); - expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); + expect(mockOnShowAuthSelection).toHaveBeenCalledTimes(1); }); it('should do nothing if handleValidationChoice is called without pending request', () => { @@ -670,6 +685,7 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`, historyManager: mockHistoryManager, userTier: UserTierId.FREE, setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, }), ); diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts index 7f8b8d0f0d..bc12c60907 100644 --- a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts @@ -31,6 +31,7 @@ interface UseQuotaAndFallbackArgs { historyManager: UseHistoryManagerReturn; userTier: UserTierId | undefined; setModelSwitchedFromQuotaError: (value: boolean) => void; + onShowAuthSelection: () => void; } export function useQuotaAndFallback({ @@ -38,6 +39,7 @@ export function useQuotaAndFallback({ historyManager, userTier, setModelSwitchedFromQuotaError, + onShowAuthSelection, }: UseQuotaAndFallbackArgs) { const [proQuotaRequest, setProQuotaRequest] = useState(null); @@ -197,17 +199,11 @@ export function useQuotaAndFallback({ validationRequest.resolve(choice); setValidationRequest(null); - if (choice === 'change_auth') { - historyManager.addItem( - { - type: MessageType.INFO, - text: 'Use /auth to change authentication method.', - }, - Date.now(), - ); + if (choice === 'change_auth' || choice === 'cancel') { + onShowAuthSelection(); } }, - [validationRequest, historyManager], + [validationRequest, onShowAuthSelection], ); return { diff --git a/packages/cli/src/ui/hooks/useSessionBrowser.ts b/packages/cli/src/ui/hooks/useSessionBrowser.ts index 1dbced887d..3d9619d738 100644 --- a/packages/cli/src/ui/hooks/useSessionBrowser.ts +++ b/packages/cli/src/ui/hooks/useSessionBrowser.ts @@ -24,7 +24,7 @@ export const useSessionBrowser = ( uiHistory: HistoryItemWithoutId[], clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }>, resumedSessionData: ResumedSessionData, - ) => void, + ) => Promise, ) => { const [isSessionBrowserOpen, setIsSessionBrowserOpen] = useState(false); @@ -73,7 +73,7 @@ export const useSessionBrowser = ( const historyData = convertSessionToHistoryFormats( conversation.messages, ); - onLoadHistory( + await onLoadHistory( historyData.uiHistory, historyData.clientHistory, resumedSessionData, diff --git a/packages/cli/src/ui/hooks/useSessionResume.test.ts b/packages/cli/src/ui/hooks/useSessionResume.test.ts index 029d23d725..071fe5878b 100644 --- a/packages/cli/src/ui/hooks/useSessionResume.test.ts +++ b/packages/cli/src/ui/hooks/useSessionResume.test.ts @@ -62,7 +62,7 @@ describe('useSessionResume', () => { expect(result.current.loadHistoryForResume).toBeInstanceOf(Function); }); - it('should clear history and add items when loading history', () => { + it('should clear history and add items when loading history', async () => { const { result } = renderHook(() => useSessionResume(getDefaultProps())); const uiHistory: HistoryItemWithoutId[] = [ @@ -86,8 +86,8 @@ describe('useSessionResume', () => { filePath: '/path/to/session.json', }; - act(() => { - result.current.loadHistoryForResume( + await act(async () => { + await result.current.loadHistoryForResume( uiHistory, clientHistory, resumedData, @@ -116,7 +116,7 @@ describe('useSessionResume', () => { ); }); - it('should not load history if Gemini client is not initialized', () => { + it('should not load history if Gemini client is not initialized', async () => { const { result } = renderHook(() => useSessionResume({ ...getDefaultProps(), @@ -141,8 +141,8 @@ describe('useSessionResume', () => { filePath: '/path/to/session.json', }; - act(() => { - result.current.loadHistoryForResume( + await act(async () => { + await result.current.loadHistoryForResume( uiHistory, clientHistory, resumedData, @@ -154,7 +154,7 @@ describe('useSessionResume', () => { expect(mockGeminiClient.resumeChat).not.toHaveBeenCalled(); }); - it('should handle empty history arrays', () => { + it('should handle empty history arrays', async () => { const { result } = renderHook(() => useSessionResume(getDefaultProps())); const resumedData: ResumedSessionData = { @@ -168,8 +168,8 @@ describe('useSessionResume', () => { filePath: '/path/to/session.json', }; - act(() => { - result.current.loadHistoryForResume([], [], resumedData); + await act(async () => { + await result.current.loadHistoryForResume([], [], resumedData); }); expect(mockHistoryManager.clearItems).toHaveBeenCalled(); @@ -311,15 +311,17 @@ describe('useSessionResume', () => { ] as MessageRecord[], }; - renderHook(() => - useSessionResume({ - ...getDefaultProps(), - resumedSessionData: { - conversation, - filePath: '/path/to/session.json', - }, - }), - ); + await act(async () => { + renderHook(() => + useSessionResume({ + ...getDefaultProps(), + resumedSessionData: { + conversation, + filePath: '/path/to/session.json', + }, + }), + ); + }); await waitFor(() => { expect(mockHistoryManager.clearItems).toHaveBeenCalled(); @@ -358,20 +360,24 @@ describe('useSessionResume', () => { ] as MessageRecord[], }; - const { rerender } = renderHook( - ({ refreshStatic }: { refreshStatic: () => void }) => - useSessionResume({ - ...getDefaultProps(), - refreshStatic, - resumedSessionData: { - conversation, - filePath: '/path/to/session.json', - }, - }), - { - initialProps: { refreshStatic: mockRefreshStatic }, - }, - ); + let rerenderFunc: (props: { refreshStatic: () => void }) => void; + await act(async () => { + const { rerender } = renderHook( + ({ refreshStatic }: { refreshStatic: () => void }) => + useSessionResume({ + ...getDefaultProps(), + refreshStatic, + resumedSessionData: { + conversation, + filePath: '/path/to/session.json', + }, + }), + { + initialProps: { refreshStatic: mockRefreshStatic as () => void }, + }, + ); + rerenderFunc = rerender; + }); await waitFor(() => { expect(mockHistoryManager.clearItems).toHaveBeenCalled(); @@ -383,7 +389,9 @@ describe('useSessionResume', () => { // Rerender with different refreshStatic const newRefreshStatic = vi.fn(); - rerender({ refreshStatic: newRefreshStatic }); + await act(async () => { + rerenderFunc({ refreshStatic: newRefreshStatic }); + }); // Should not resume again expect(mockHistoryManager.clearItems).toHaveBeenCalledTimes( @@ -413,15 +421,17 @@ describe('useSessionResume', () => { ] as MessageRecord[], }; - renderHook(() => - useSessionResume({ - ...getDefaultProps(), - resumedSessionData: { - conversation, - filePath: '/path/to/session.json', - }, - }), - ); + await act(async () => { + renderHook(() => + useSessionResume({ + ...getDefaultProps(), + resumedSessionData: { + conversation, + filePath: '/path/to/session.json', + }, + }), + ); + }); await waitFor(() => { expect(mockGeminiClient.resumeChat).toHaveBeenCalled(); diff --git a/packages/cli/src/ui/hooks/useSessionResume.ts b/packages/cli/src/ui/hooks/useSessionResume.ts index 228ca6ac2c..21b9d0884f 100644 --- a/packages/cli/src/ui/hooks/useSessionResume.ts +++ b/packages/cli/src/ui/hooks/useSessionResume.ts @@ -4,8 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useCallback, useEffect, useRef } from 'react'; -import type { Config, ResumedSessionData } from '@google/gemini-cli-core'; +import { useCallback, useEffect, useRef, useState } from 'react'; +import { + coreEvents, + type Config, + type ResumedSessionData, +} from '@google/gemini-cli-core'; import type { Part } from '@google/genai'; import type { HistoryItemWithoutId } from '../types.js'; import type { UseHistoryManagerReturn } from './useHistoryManager.js'; @@ -35,6 +39,8 @@ export function useSessionResume({ resumedSessionData, isAuthenticating, }: UseSessionResumeParams) { + const [isResuming, setIsResuming] = useState(false); + // Use refs to avoid dependency chain that causes infinite loop const historyManagerRef = useRef(historyManager); const refreshStaticRef = useRef(refreshStatic); @@ -45,7 +51,7 @@ export function useSessionResume({ }); const loadHistoryForResume = useCallback( - ( + async ( uiHistory: HistoryItemWithoutId[], clientHistory: Array<{ role: 'user' | 'model'; parts: Part[] }>, resumedData: ResumedSessionData, @@ -55,17 +61,27 @@ export function useSessionResume({ return; } - // Now that we have the client, load the history into the UI and the client. - setQuittingMessages(null); - historyManagerRef.current.clearItems(); - uiHistory.forEach((item, index) => { - historyManagerRef.current.addItem(item, index, true); - }); - refreshStaticRef.current(); // Force Static component to re-render with the updated history. + setIsResuming(true); + try { + // Now that we have the client, load the history into the UI and the client. + setQuittingMessages(null); + historyManagerRef.current.clearItems(); + uiHistory.forEach((item, index) => { + historyManagerRef.current.addItem(item, index, true); + }); + refreshStaticRef.current(); // Force Static component to re-render with the updated history. - // Give the history to the Gemini client. - // eslint-disable-next-line @typescript-eslint/no-floating-promises - config.getGeminiClient()?.resumeChat(clientHistory, resumedData); + // Give the history to the Gemini client. + await config.getGeminiClient()?.resumeChat(clientHistory, resumedData); + } catch (error) { + coreEvents.emitFeedback( + 'error', + 'Failed to resume session. Please try again.', + error, + ); + } finally { + setIsResuming(false); + } }, [config, isGeminiClientInitialized, setQuittingMessages], ); @@ -84,7 +100,7 @@ export function useSessionResume({ const historyData = convertSessionToHistoryFormats( resumedSessionData.conversation.messages, ); - loadHistoryForResume( + void loadHistoryForResume( historyData.uiHistory, historyData.clientHistory, resumedSessionData, @@ -97,5 +113,5 @@ export function useSessionResume({ loadHistoryForResume, ]); - return { loadHistoryForResume }; + return { loadHistoryForResume, isResuming }; } diff --git a/packages/cli/src/utils/installationInfo.ts b/packages/cli/src/utils/installationInfo.ts index 2661014a49..ddc4afe8da 100644 --- a/packages/cli/src/utils/installationInfo.ts +++ b/packages/cli/src/utils/installationInfo.ts @@ -69,7 +69,10 @@ export function getInstallationInfo( updateMessage: 'Running via npx, update not applicable.', }; } - if (realPath.includes('/.pnpm/_pnpx')) { + if ( + realPath.includes('/.pnpm/_pnpx') || + realPath.includes('/.cache/pnpm/dlx') + ) { return { packageManager: PackageManager.PNPX, isGlobal: false, @@ -103,7 +106,10 @@ export function getInstallationInfo( } // Check for pnpm - if (realPath.includes('/.pnpm/global')) { + if ( + realPath.includes('/.pnpm/global') || + realPath.includes('/.local/share/pnpm') + ) { const updateCommand = 'pnpm add -g @google/gemini-cli@latest'; return { packageManager: PackageManager.PNPM, diff --git a/packages/core/src/code_assist/codeAssist.test.ts b/packages/core/src/code_assist/codeAssist.test.ts index 90ebfb1d9c..6efee88d69 100644 --- a/packages/core/src/code_assist/codeAssist.test.ts +++ b/packages/core/src/code_assist/codeAssist.test.ts @@ -35,7 +35,10 @@ describe('codeAssist', () => { describe('createCodeAssistContentGenerator', () => { const httpOptions = {}; - const mockConfig = {} as Config; + const mockValidationHandler = vi.fn(); + const mockConfig = { + getValidationHandler: () => mockValidationHandler, + } as unknown as Config; const mockAuthClient = { a: 'client' }; const mockUserData = { projectId: 'test-project', @@ -57,7 +60,10 @@ describe('codeAssist', () => { AuthType.LOGIN_WITH_GOOGLE, mockConfig, ); - expect(setupUser).toHaveBeenCalledWith(mockAuthClient); + expect(setupUser).toHaveBeenCalledWith( + mockAuthClient, + mockValidationHandler, + ); expect(MockedCodeAssistServer).toHaveBeenCalledWith( mockAuthClient, 'test-project', @@ -83,7 +89,10 @@ describe('codeAssist', () => { AuthType.COMPUTE_ADC, mockConfig, ); - expect(setupUser).toHaveBeenCalledWith(mockAuthClient); + expect(setupUser).toHaveBeenCalledWith( + mockAuthClient, + mockValidationHandler, + ); expect(MockedCodeAssistServer).toHaveBeenCalledWith( mockAuthClient, 'test-project', diff --git a/packages/core/src/code_assist/codeAssist.ts b/packages/core/src/code_assist/codeAssist.ts index fee43e9c45..3b87cb03e2 100644 --- a/packages/core/src/code_assist/codeAssist.ts +++ b/packages/core/src/code_assist/codeAssist.ts @@ -24,7 +24,7 @@ export async function createCodeAssistContentGenerator( authType === AuthType.COMPUTE_ADC ) { const authClient = await getOauthClient(authType, config); - const userData = await setupUser(authClient); + const userData = await setupUser(authClient, config.getValidationHandler()); return new CodeAssistServer( authClient, userData.projectId, diff --git a/packages/core/src/code_assist/setup.test.ts b/packages/core/src/code_assist/setup.test.ts index bd43ed2e88..9559c58254 100644 --- a/packages/core/src/code_assist/setup.test.ts +++ b/packages/core/src/code_assist/setup.test.ts @@ -5,7 +5,13 @@ */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { setupUser, ProjectIdRequiredError } from './setup.js'; +import { + ProjectIdRequiredError, + setupUser, + ValidationCancelledError, +} from './setup.js'; +import { ValidationRequiredError } from '../utils/googleQuotaErrors.js'; +import { ChangeAuthRequestedError } from '../utils/errors.js'; import { CodeAssistServer } from '../code_assist/server.js'; import type { OAuth2Client } from 'google-auth-library'; import type { GeminiUserTier } from './types.js'; @@ -307,3 +313,215 @@ describe('setupUser for new user', () => { }); }); }); + +describe('setupUser validation', () => { + let mockLoad: ReturnType; + + beforeEach(() => { + vi.resetAllMocks(); + mockLoad = vi.fn(); + vi.mocked(CodeAssistServer).mockImplementation( + () => + ({ + loadCodeAssist: mockLoad, + }) as unknown as CodeAssistServer, + ); + }); + + afterEach(() => { + vi.unstubAllEnvs(); + }); + + it('should throw error if LoadCodeAssist returns ineligible tiers and no current tier', async () => { + mockLoad.mockResolvedValue({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'User is not eligible', + reasonCode: 'INELIGIBLE_ACCOUNT', + tierId: 'standard-tier', + tierName: 'standard', + }, + ], + }); + + await expect(setupUser({} as OAuth2Client)).rejects.toThrow( + 'User is not eligible', + ); + }); + + it('should retry if validation handler returns verify', async () => { + // First call fails + mockLoad.mockResolvedValueOnce({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'User is not eligible', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + validationLearnMoreUrl: 'https://example.com/learn', + }, + ], + }); + // Second call succeeds + mockLoad.mockResolvedValueOnce({ + currentTier: mockPaidTier, + cloudaicompanionProject: 'test-project', + }); + + const mockValidationHandler = vi.fn().mockResolvedValue('verify'); + + const result = await setupUser({} as OAuth2Client, mockValidationHandler); + + expect(mockValidationHandler).toHaveBeenCalledWith( + 'https://example.com/verify', + 'User is not eligible', + ); + expect(mockLoad).toHaveBeenCalledTimes(2); + expect(result).toEqual({ + projectId: 'test-project', + userTier: 'standard-tier', + userTierName: 'paid', + }); + }); + + it('should throw if validation handler returns cancel', async () => { + mockLoad.mockResolvedValue({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'User is not eligible', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + }, + ], + }); + + const mockValidationHandler = vi.fn().mockResolvedValue('cancel'); + + await expect( + setupUser({} as OAuth2Client, mockValidationHandler), + ).rejects.toThrow(ValidationCancelledError); + expect(mockValidationHandler).toHaveBeenCalled(); + expect(mockLoad).toHaveBeenCalledTimes(1); + }); + + it('should throw ChangeAuthRequestedError if validation handler returns change_auth', async () => { + mockLoad.mockResolvedValue({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'User is not eligible', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + }, + ], + }); + + const mockValidationHandler = vi.fn().mockResolvedValue('change_auth'); + + await expect( + setupUser({} as OAuth2Client, mockValidationHandler), + ).rejects.toThrow(ChangeAuthRequestedError); + expect(mockValidationHandler).toHaveBeenCalled(); + expect(mockLoad).toHaveBeenCalledTimes(1); + }); + + it('should throw ValidationRequiredError without handler', async () => { + mockLoad.mockResolvedValue({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'Please verify your account', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + }, + ], + }); + + await expect(setupUser({} as OAuth2Client)).rejects.toThrow( + ValidationRequiredError, + ); + expect(mockLoad).toHaveBeenCalledTimes(1); + }); + + it('should throw error if LoadCodeAssist returns empty response', async () => { + mockLoad.mockResolvedValue(null); + + await expect(setupUser({} as OAuth2Client)).rejects.toThrow( + 'LoadCodeAssist returned empty response', + ); + }); + + it('should retry multiple times when validation handler keeps returning verify', async () => { + // First two calls fail with validation required + mockLoad + .mockResolvedValueOnce({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'Verify 1', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + }, + ], + }) + .mockResolvedValueOnce({ + currentTier: null, + ineligibleTiers: [ + { + reasonMessage: 'Verify 2', + reasonCode: 'VALIDATION_REQUIRED', + tierId: 'standard-tier', + tierName: 'standard', + validationUrl: 'https://example.com/verify', + }, + ], + }) + .mockResolvedValueOnce({ + currentTier: mockPaidTier, + cloudaicompanionProject: 'test-project', + }); + + const mockValidationHandler = vi.fn().mockResolvedValue('verify'); + + const result = await setupUser({} as OAuth2Client, mockValidationHandler); + + expect(mockValidationHandler).toHaveBeenCalledTimes(2); + expect(mockLoad).toHaveBeenCalledTimes(3); + expect(result).toEqual({ + projectId: 'test-project', + userTier: 'standard-tier', + userTierName: 'paid', + }); + }); +}); + +describe('ValidationRequiredError', () => { + const error = new ValidationRequiredError( + 'Account validation required: Please verify', + undefined, + 'https://example.com/verify', + 'Please verify', + ); + + it('should be an instance of Error', () => { + expect(error).toBeInstanceOf(Error); + expect(error).toBeInstanceOf(ValidationRequiredError); + }); + + it('should have the correct properties', () => { + expect(error.validationLink).toBe('https://example.com/verify'); + expect(error.validationDescription).toBe('Please verify'); + }); +}); diff --git a/packages/core/src/code_assist/setup.ts b/packages/core/src/code_assist/setup.ts index 994bb99568..15da70fb42 100644 --- a/packages/core/src/code_assist/setup.ts +++ b/packages/core/src/code_assist/setup.ts @@ -10,9 +10,12 @@ import type { LoadCodeAssistResponse, OnboardUserRequest, } from './types.js'; -import { UserTierId } from './types.js'; +import { UserTierId, IneligibleTierReasonCode } from './types.js'; import { CodeAssistServer } from './server.js'; import type { AuthClient } from 'google-auth-library'; +import type { ValidationHandler } from '../fallback/types.js'; +import { ChangeAuthRequestedError } from '../utils/errors.js'; +import { ValidationRequiredError } from '../utils/googleQuotaErrors.js'; export class ProjectIdRequiredError extends Error { constructor() { @@ -22,6 +25,16 @@ export class ProjectIdRequiredError extends Error { } } +/** + * Error thrown when user cancels the validation process. + * This is a non-recoverable error that should result in auth failure. + */ +export class ValidationCancelledError extends Error { + constructor() { + super('User cancelled account validation'); + } +} + export interface UserData { projectId: string; userTier: UserTierId; @@ -33,7 +46,10 @@ export interface UserData { * @param projectId the user's project id, if any * @returns the user's actual project id */ -export async function setupUser(client: AuthClient): Promise { +export async function setupUser( + client: AuthClient, + validationHandler?: ValidationHandler, +): Promise { const projectId = process.env['GOOGLE_CLOUD_PROJECT'] || process.env['GOOGLE_CLOUD_PROJECT_ID'] || @@ -52,13 +68,36 @@ export async function setupUser(client: AuthClient): Promise { pluginType: 'GEMINI', }; - const loadRes = await caServer.loadCodeAssist({ - cloudaicompanionProject: projectId, - metadata: { - ...coreClientMetadata, - duetProject: projectId, - }, - }); + let loadRes: LoadCodeAssistResponse; + while (true) { + loadRes = await caServer.loadCodeAssist({ + cloudaicompanionProject: projectId, + metadata: { + ...coreClientMetadata, + duetProject: projectId, + }, + }); + + try { + validateLoadCodeAssistResponse(loadRes); + break; + } catch (e) { + if (e instanceof ValidationRequiredError && validationHandler) { + const intent = await validationHandler( + e.validationLink, + e.validationDescription, + ); + if (intent === 'verify') { + continue; + } + if (intent === 'change_auth') { + throw new ChangeAuthRequestedError(); + } + throw new ValidationCancelledError(); + } + throw e; + } + } if (loadRes.currentTier) { if (!loadRes.cloudaicompanionProject) { @@ -139,3 +178,34 @@ function getOnboardTier(res: LoadCodeAssistResponse): GeminiUserTier { userDefinedCloudaicompanionProject: true, }; } + +function validateLoadCodeAssistResponse(res: LoadCodeAssistResponse): void { + if (!res) { + throw new Error('LoadCodeAssist returned empty response'); + } + if ( + !res.currentTier && + res.ineligibleTiers && + res.ineligibleTiers.length > 0 + ) { + // Check for VALIDATION_REQUIRED first - this is a recoverable state + const validationTier = res.ineligibleTiers.find( + (t) => + t.validationUrl && + t.reasonCode === IneligibleTierReasonCode.VALIDATION_REQUIRED, + ); + const validationUrl = validationTier?.validationUrl; + if (validationTier && validationUrl) { + throw new ValidationRequiredError( + `Account validation required: ${validationTier.reasonMessage}`, + undefined, + validationUrl, + validationTier.reasonMessage, + ); + } + + // For other ineligibility reasons, throw a generic error + const reasons = res.ineligibleTiers.map((t) => t.reasonMessage).join(', '); + throw new Error(reasons); + } +} diff --git a/packages/core/src/code_assist/types.ts b/packages/core/src/code_assist/types.ts index fd74d69b38..5e706cc207 100644 --- a/packages/core/src/code_assist/types.ts +++ b/packages/core/src/code_assist/types.ts @@ -82,6 +82,11 @@ export interface IneligibleTier { reasonMessage: string; tierId: UserTierId; tierName: string; + validationErrorMessage?: string; + validationUrl?: string; + validationUrlLinkText?: string; + validationLearnMoreUrl?: string; + validationLearnMoreLinkText?: string; } /** @@ -98,6 +103,7 @@ export enum IneligibleTierReasonCode { UNKNOWN = 'UNKNOWN', UNKNOWN_LOCATION = 'UNKNOWN_LOCATION', UNSUPPORTED_LOCATION = 'UNSUPPORTED_LOCATION', + VALIDATION_REQUIRED = 'VALIDATION_REQUIRED', // go/keep-sorted end } /** diff --git a/packages/core/src/confirmation-bus/message-bus.ts b/packages/core/src/confirmation-bus/message-bus.ts index 11dab9ca23..722cb37344 100644 --- a/packages/core/src/confirmation-bus/message-bus.ts +++ b/packages/core/src/confirmation-bus/message-bus.ts @@ -7,12 +7,8 @@ import { randomUUID } from 'node:crypto'; import { EventEmitter } from 'node:events'; import type { PolicyEngine } from '../policy/policy-engine.js'; -import { PolicyDecision, getHookSource } from '../policy/types.js'; -import { - MessageBusType, - type Message, - type HookPolicyDecision, -} from './types.js'; +import { PolicyDecision } from '../policy/types.js'; +import { MessageBusType, type Message } from './types.js'; import { safeJsonStringify } from '../utils/safeJsonStringify.js'; import { debugLogger } from '../utils/debugLogger.js'; @@ -89,39 +85,6 @@ export class MessageBus extends EventEmitter { default: throw new Error(`Unknown policy decision: ${decision}`); } - } else if (message.type === MessageBusType.HOOK_EXECUTION_REQUEST) { - // Handle hook execution requests through policy evaluation - const hookRequest = message; - const decision = await this.policyEngine.checkHook(hookRequest); - - // Map decision to allow/deny for observability (ASK_USER treated as deny for hooks) - const effectiveDecision = - decision === PolicyDecision.ALLOW ? 'allow' : 'deny'; - - // Emit policy decision for observability - this.emitMessage({ - type: MessageBusType.HOOK_POLICY_DECISION, - eventName: hookRequest.eventName, - hookSource: getHookSource(hookRequest.input), - decision: effectiveDecision, - reason: - decision !== PolicyDecision.ALLOW - ? 'Hook execution denied by policy' - : undefined, - } as HookPolicyDecision); - - // If allowed, emit the request for hook system to handle - if (decision === PolicyDecision.ALLOW) { - this.emitMessage(message); - } else { - // If denied or ASK_USER, emit error response (hooks don't support interactive confirmation) - this.emitMessage({ - type: MessageBusType.HOOK_EXECUTION_RESPONSE, - correlationId: hookRequest.correlationId, - success: false, - error: new Error('Hook execution denied by policy'), - }); - } } else { // For all other message types, just emit them this.emitMessage(message); diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index 786894a972..aeecf73b3e 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -18,9 +18,6 @@ export enum MessageBusType { TOOL_EXECUTION_SUCCESS = 'tool-execution-success', TOOL_EXECUTION_FAILURE = 'tool-execution-failure', UPDATE_POLICY = 'update-policy', - HOOK_EXECUTION_REQUEST = 'hook-execution-request', - HOOK_EXECUTION_RESPONSE = 'hook-execution-response', - HOOK_POLICY_DECISION = 'hook-policy-decision', TOOL_CALLS_UPDATE = 'tool-calls-update', ASK_USER_REQUEST = 'ask-user-request', ASK_USER_RESPONSE = 'ask-user-response', @@ -120,29 +117,6 @@ export interface ToolExecutionFailure { error: E; } -export interface HookExecutionRequest { - type: MessageBusType.HOOK_EXECUTION_REQUEST; - eventName: string; - input: Record; - correlationId: string; -} - -export interface HookExecutionResponse { - type: MessageBusType.HOOK_EXECUTION_RESPONSE; - correlationId: string; - success: boolean; - output?: Record; - error?: Error; -} - -export interface HookPolicyDecision { - type: MessageBusType.HOOK_POLICY_DECISION; - eventName: string; - hookSource: 'project' | 'user' | 'system' | 'extension'; - decision: 'allow' | 'deny'; - reason?: string; -} - export interface QuestionOption { label: string; description: string; @@ -186,9 +160,6 @@ export type Message = | ToolExecutionSuccess | ToolExecutionFailure | UpdatePolicy - | HookExecutionRequest - | HookExecutionResponse - | HookPolicyDecision | AskUserRequest | AskUserResponse | ToolCallsUpdateMessage; diff --git a/packages/core/src/hooks/hookEventHandler.test.ts b/packages/core/src/hooks/hookEventHandler.test.ts index af7a6be37a..3dbc4ae881 100644 --- a/packages/core/src/hooks/hookEventHandler.test.ts +++ b/packages/core/src/hooks/hookEventHandler.test.ts @@ -8,7 +8,6 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import { HookEventHandler } from './hookEventHandler.js'; import type { Config } from '../config/config.js'; import type { HookConfig } from './types.js'; -import type { Logger } from '@opentelemetry/api-logs'; import type { HookPlanner } from './hookPlanner.js'; import type { HookRunner } from './hookRunner.js'; import type { HookAggregator } from './hookAggregator.js'; @@ -18,7 +17,6 @@ import { SessionStartSource, type HookExecutionResult, } from './types.js'; -import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; // Mock debugLogger const mockDebugLogger = vi.hoisted(() => ({ @@ -54,7 +52,6 @@ vi.mock('../telemetry/clearcut-logger/clearcut-logger.js', () => ({ describe('HookEventHandler', () => { let hookEventHandler: HookEventHandler; let mockConfig: Config; - let mockLogger: Logger; let mockHookPlanner: HookPlanner; let mockHookRunner: HookRunner; let mockHookAggregator: HookAggregator; @@ -74,8 +71,6 @@ describe('HookEventHandler', () => { }), } as unknown as Config; - mockLogger = {} as Logger; - mockHookPlanner = { createExecutionPlan: vi.fn(), } as unknown as HookPlanner; @@ -91,11 +86,9 @@ describe('HookEventHandler', () => { hookEventHandler = new HookEventHandler( mockConfig, - mockLogger, mockHookPlanner, mockHookRunner, mockHookAggregator, - createMockMessageBus(), ); }); diff --git a/packages/core/src/hooks/hookEventHandler.ts b/packages/core/src/hooks/hookEventHandler.ts index e208dd1ed4..cae3c61625 100644 --- a/packages/core/src/hooks/hookEventHandler.ts +++ b/packages/core/src/hooks/hookEventHandler.ts @@ -4,7 +4,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { Logger } from '@opentelemetry/api-logs'; import type { Config } from '../config/config.js'; import type { HookPlanner, HookEventContext } from './hookPlanner.js'; import type { HookRunner } from './hookRunner.js'; @@ -38,265 +37,9 @@ import type { } from '@google/genai'; import { logHookCall } from '../telemetry/loggers.js'; import { HookCallEvent } from '../telemetry/types.js'; -import type { MessageBus } from '../confirmation-bus/message-bus.js'; -import { - MessageBusType, - type HookExecutionRequest, -} from '../confirmation-bus/types.js'; import { debugLogger } from '../utils/debugLogger.js'; import { coreEvents } from '../utils/events.js'; -/** - * Validates that a value is a non-null object - */ -function isObject(value: unknown): value is Record { - return typeof value === 'object' && value !== null; -} - -/** - * Validates BeforeTool input fields - */ -function validateBeforeToolInput(input: Record): { - toolName: string; - toolInput: Record; - mcpContext?: McpToolContext; -} { - const toolName = input['tool_name']; - const toolInput = input['tool_input']; - const mcpContext = input['mcp_context']; - if (typeof toolName !== 'string') { - throw new Error( - 'Invalid input for BeforeTool hook event: tool_name must be a string', - ); - } - if (!isObject(toolInput)) { - throw new Error( - 'Invalid input for BeforeTool hook event: tool_input must be an object', - ); - } - if (mcpContext !== undefined && !isObject(mcpContext)) { - throw new Error( - 'Invalid input for BeforeTool hook event: mcp_context must be an object', - ); - } - return { - toolName, - toolInput, - mcpContext: mcpContext as McpToolContext | undefined, - }; -} - -/** - * Validates AfterTool input fields - */ -function validateAfterToolInput(input: Record): { - toolName: string; - toolInput: Record; - toolResponse: Record; - mcpContext?: McpToolContext; -} { - const toolName = input['tool_name']; - const toolInput = input['tool_input']; - const toolResponse = input['tool_response']; - const mcpContext = input['mcp_context']; - if (typeof toolName !== 'string') { - throw new Error( - 'Invalid input for AfterTool hook event: tool_name must be a string', - ); - } - if (!isObject(toolInput)) { - throw new Error( - 'Invalid input for AfterTool hook event: tool_input must be an object', - ); - } - if (!isObject(toolResponse)) { - throw new Error( - 'Invalid input for AfterTool hook event: tool_response must be an object', - ); - } - if (mcpContext !== undefined && !isObject(mcpContext)) { - throw new Error( - 'Invalid input for AfterTool hook event: mcp_context must be an object', - ); - } - return { - toolName, - toolInput, - toolResponse, - mcpContext: mcpContext as McpToolContext | undefined, - }; -} - -/** - * Validates BeforeAgent input fields - */ -function validateBeforeAgentInput(input: Record): { - prompt: string; -} { - const prompt = input['prompt']; - if (typeof prompt !== 'string') { - throw new Error( - 'Invalid input for BeforeAgent hook event: prompt must be a string', - ); - } - return { prompt }; -} - -/** - * Validates AfterAgent input fields - */ -function validateAfterAgentInput(input: Record): { - prompt: string; - promptResponse: string; - stopHookActive: boolean; -} { - const prompt = input['prompt']; - const promptResponse = input['prompt_response']; - const stopHookActive = input['stop_hook_active']; - if (typeof prompt !== 'string') { - throw new Error( - 'Invalid input for AfterAgent hook event: prompt must be a string', - ); - } - if (typeof promptResponse !== 'string') { - throw new Error( - 'Invalid input for AfterAgent hook event: prompt_response must be a string', - ); - } - // stopHookActive defaults to false if not a boolean - return { - prompt, - promptResponse, - stopHookActive: - typeof stopHookActive === 'boolean' ? stopHookActive : false, - }; -} - -/** - * Validates model-related input fields (llm_request) - */ -function validateModelInput( - input: Record, - eventName: string, -): { llmRequest: GenerateContentParameters } { - const llmRequest = input['llm_request']; - if (!isObject(llmRequest)) { - throw new Error( - `Invalid input for ${eventName} hook event: llm_request must be an object`, - ); - } - return { llmRequest: llmRequest as unknown as GenerateContentParameters }; -} - -/** - * Validates AfterModel input fields - */ -function validateAfterModelInput(input: Record): { - llmRequest: GenerateContentParameters; - llmResponse: GenerateContentResponse; -} { - const llmRequest = input['llm_request']; - const llmResponse = input['llm_response']; - if (!isObject(llmRequest)) { - throw new Error( - 'Invalid input for AfterModel hook event: llm_request must be an object', - ); - } - if (!isObject(llmResponse)) { - throw new Error( - 'Invalid input for AfterModel hook event: llm_response must be an object', - ); - } - return { - llmRequest: llmRequest as unknown as GenerateContentParameters, - llmResponse: llmResponse as unknown as GenerateContentResponse, - }; -} - -/** - * Validates Notification input fields - */ -function validateNotificationInput(input: Record): { - notificationType: NotificationType; - message: string; - details: Record; -} { - const notificationType = input['notification_type']; - const message = input['message']; - const details = input['details']; - if (typeof notificationType !== 'string') { - throw new Error( - 'Invalid input for Notification hook event: notification_type must be a string', - ); - } - if (typeof message !== 'string') { - throw new Error( - 'Invalid input for Notification hook event: message must be a string', - ); - } - if (!isObject(details)) { - throw new Error( - 'Invalid input for Notification hook event: details must be an object', - ); - } - return { - notificationType: notificationType as NotificationType, - message, - details, - }; -} - -/** - * Validates SessionStart input fields - */ -function validateSessionStartInput(input: Record): { - source: SessionStartSource; -} { - const source = input['source']; - if (typeof source !== 'string') { - throw new Error( - 'Invalid input for SessionStart hook event: source must be a string', - ); - } - return { - source: source as SessionStartSource, - }; -} - -/** - * Validates SessionEnd input fields - */ -function validateSessionEndInput(input: Record): { - reason: SessionEndReason; -} { - const reason = input['reason']; - if (typeof reason !== 'string') { - throw new Error( - 'Invalid input for SessionEnd hook event: reason must be a string', - ); - } - return { - reason: reason as SessionEndReason, - }; -} - -/** - * Validates PreCompress input fields - */ -function validatePreCompressInput(input: Record): { - trigger: PreCompressTrigger; -} { - const trigger = input['trigger']; - if (typeof trigger !== 'string') { - throw new Error( - 'Invalid input for PreCompress hook event: trigger must be a string', - ); - } - return { - trigger: trigger as PreCompressTrigger, - }; -} - /** * Hook event bus that coordinates hook execution across the system */ @@ -305,29 +48,17 @@ export class HookEventHandler { private readonly hookPlanner: HookPlanner; private readonly hookRunner: HookRunner; private readonly hookAggregator: HookAggregator; - private readonly messageBus: MessageBus; constructor( config: Config, - logger: Logger, hookPlanner: HookPlanner, hookRunner: HookRunner, hookAggregator: HookAggregator, - messageBus: MessageBus, ) { this.config = config; this.hookPlanner = hookPlanner; this.hookRunner = hookRunner; this.hookAggregator = hookAggregator; - this.messageBus = messageBus; - - // Subscribe to hook execution requests from MessageBus - if (this.messageBus) { - this.messageBus.subscribe( - MessageBusType.HOOK_EXECUTION_REQUEST, - (request) => this.handleHookExecutionRequest(request), - ); - } } /** @@ -729,152 +460,4 @@ export class HookEventHandler { private getHookTypeFromResult(result: HookExecutionResult): 'command' { return result.hookConfig.type; } - - /** - * Handle hook execution requests from MessageBus - * This method routes the request to the appropriate fire*Event method - * and publishes the response back through MessageBus - * - * The request input only contains event-specific fields. This method adds - * the common base fields (session_id, cwd, etc.) before routing. - */ - private async handleHookExecutionRequest( - request: HookExecutionRequest, - ): Promise { - try { - // Add base fields to the input - const enrichedInput = { - ...this.createBaseInput(request.eventName as HookEventName), - ...request.input, - } as Record; - - let result: AggregatedHookResult; - - // Route to appropriate event handler based on eventName - switch (request.eventName) { - case HookEventName.BeforeTool: { - const { toolName, toolInput, mcpContext } = - validateBeforeToolInput(enrichedInput); - result = await this.fireBeforeToolEvent( - toolName, - toolInput, - mcpContext, - ); - break; - } - case HookEventName.AfterTool: { - const { toolName, toolInput, toolResponse, mcpContext } = - validateAfterToolInput(enrichedInput); - result = await this.fireAfterToolEvent( - toolName, - toolInput, - toolResponse, - mcpContext, - ); - break; - } - case HookEventName.BeforeAgent: { - const { prompt } = validateBeforeAgentInput(enrichedInput); - result = await this.fireBeforeAgentEvent(prompt); - break; - } - case HookEventName.AfterAgent: { - const { prompt, promptResponse, stopHookActive } = - validateAfterAgentInput(enrichedInput); - result = await this.fireAfterAgentEvent( - prompt, - promptResponse, - stopHookActive, - ); - break; - } - case HookEventName.BeforeModel: { - const { llmRequest } = validateModelInput( - enrichedInput, - 'BeforeModel', - ); - const translatedRequest = - defaultHookTranslator.toHookLLMRequest(llmRequest); - // Update the enrichedInput with translated request - enrichedInput['llm_request'] = translatedRequest; - result = await this.fireBeforeModelEvent(llmRequest); - break; - } - case HookEventName.AfterModel: { - const { llmRequest, llmResponse } = - validateAfterModelInput(enrichedInput); - const translatedRequest = - defaultHookTranslator.toHookLLMRequest(llmRequest); - const translatedResponse = - defaultHookTranslator.toHookLLMResponse(llmResponse); - // Update the enrichedInput with translated versions - enrichedInput['llm_request'] = translatedRequest; - enrichedInput['llm_response'] = translatedResponse; - result = await this.fireAfterModelEvent(llmRequest, llmResponse); - break; - } - case HookEventName.BeforeToolSelection: { - const { llmRequest } = validateModelInput( - enrichedInput, - 'BeforeToolSelection', - ); - const translatedRequest = - defaultHookTranslator.toHookLLMRequest(llmRequest); - // Update the enrichedInput with translated request - enrichedInput['llm_request'] = translatedRequest; - result = await this.fireBeforeToolSelectionEvent(llmRequest); - break; - } - case HookEventName.Notification: { - const { notificationType, message, details } = - validateNotificationInput(enrichedInput); - result = await this.fireNotificationEvent( - notificationType, - message, - details, - ); - break; - } - case HookEventName.SessionStart: { - const { source } = validateSessionStartInput(enrichedInput); - result = await this.fireSessionStartEvent(source); - break; - } - case HookEventName.SessionEnd: { - const { reason } = validateSessionEndInput(enrichedInput); - result = await this.fireSessionEndEvent(reason); - break; - } - case HookEventName.PreCompress: { - const { trigger } = validatePreCompressInput(enrichedInput); - result = await this.firePreCompressEvent(trigger); - break; - } - default: - throw new Error(`Unsupported hook event: ${request.eventName}`); - } - - // Publish response through MessageBus - if (this.messageBus) { - // eslint-disable-next-line @typescript-eslint/no-floating-promises - this.messageBus.publish({ - type: MessageBusType.HOOK_EXECUTION_RESPONSE, - correlationId: request.correlationId, - success: result.success, - output: result.finalOutput as unknown as Record, - }); - } - } catch (error) { - // Publish error response - if (this.messageBus) { - // eslint-disable-next-line @typescript-eslint/no-floating-promises - this.messageBus.publish({ - type: MessageBusType.HOOK_EXECUTION_RESPONSE, - correlationId: request.correlationId, - success: false, - error: error instanceof Error ? error : new Error(String(error)), - }); - } - } - } } diff --git a/packages/core/src/hooks/hookSystem.ts b/packages/core/src/hooks/hookSystem.ts index bfb855c5d5..e3d14b4a62 100644 --- a/packages/core/src/hooks/hookSystem.ts +++ b/packages/core/src/hooks/hookSystem.ts @@ -11,8 +11,6 @@ import { HookAggregator } from './hookAggregator.js'; import { HookPlanner } from './hookPlanner.js'; import { HookEventHandler } from './hookEventHandler.js'; import type { HookRegistryEntry } from './hookRegistry.js'; -import { logs, type Logger } from '@opentelemetry/api-logs'; -import { SERVICE_NAME } from '../telemetry/constants.js'; import { debugLogger } from '../utils/debugLogger.js'; import type { SessionStartSource, @@ -155,9 +153,6 @@ export class HookSystem { private readonly hookEventHandler: HookEventHandler; constructor(config: Config) { - const logger: Logger = logs.getLogger(SERVICE_NAME); - const messageBus = config.getMessageBus(); - // Initialize components this.hookRegistry = new HookRegistry(config); this.hookRunner = new HookRunner(config); @@ -165,11 +160,9 @@ export class HookSystem { this.hookPlanner = new HookPlanner(this.hookRegistry); this.hookEventHandler = new HookEventHandler( config, - logger, this.hookPlanner, this.hookRunner, this.hookAggregator, - messageBus, // Pass MessageBus to enable mediated hook execution ); } diff --git a/packages/core/src/ide/ide-client.test.ts b/packages/core/src/ide/ide-client.test.ts index 64bfc022b1..24a87143cb 100644 --- a/packages/core/src/ide/ide-client.test.ts +++ b/packages/core/src/ide/ide-client.test.ts @@ -1131,4 +1131,92 @@ describe('getIdeServerHost', () => { '/run/.containerenv', ); // Short-circuiting }); + + describe('validateWorkspacePath', () => { + describe('with special characters and encoding', () => { + it('should return true for a URI-encoded path with spaces', () => { + const workspacePath = 'file:///test/my%20workspace'; + const cwd = '/test/my workspace/sub-dir'; + const result = IdeClient.validateWorkspacePath(workspacePath, cwd); + expect(result.isValid).toBe(true); + }); + + it('should return true for a URI-encoded path with Korean characters', () => { + const workspacePath = 'file:///test/%ED%85%8C%EC%8A%A4%ED%8A%B8'; // "테스트" + const cwd = '/test/테스트/sub-dir'; + const result = IdeClient.validateWorkspacePath(workspacePath, cwd); + expect(result.isValid).toBe(true); + }); + + it('should return true for a plain decoded path with Korean characters', () => { + const workspacePath = '/test/테스트'; + const cwd = '/test/테스트/sub-dir'; + const result = IdeClient.validateWorkspacePath(workspacePath, cwd); + expect(result.isValid).toBe(true); + }); + + it('should return true when one of multi-root paths is a valid URI-encoded path', () => { + const workspacePath = [ + '/another/workspace', + 'file:///test/%ED%85%8C%EC%8A%A4%ED%8A%B8', // "테스트" + ].join(path.delimiter); + const cwd = '/test/테스트/sub-dir'; + const result = IdeClient.validateWorkspacePath(workspacePath, cwd); + expect(result.isValid).toBe(true); + }); + + it('should return true for paths containing a literal % sign', () => { + const workspacePath = '/test/a%path'; + const cwd = '/test/a%path/sub-dir'; + const result = IdeClient.validateWorkspacePath(workspacePath, cwd); + expect(result.isValid).toBe(true); + }); + + it.skipIf(process.platform !== 'win32')( + 'should correctly convert a Windows file URI', + () => { + const workspacePath = 'file:///C:\\Users\\test'; + const cwd = 'C:\\Users\\test\\sub-dir'; + + const result = IdeClient.validateWorkspacePath(workspacePath, cwd); + + expect(result.isValid).toBe(true); + }, + ); + }); + }); + + describe('validateWorkspacePath (sanitization)', () => { + it.each([ + { + description: 'should return true for identical paths', + workspacePath: '/test/ws', + cwd: '/test/ws', + expectedValid: true, + }, + { + description: 'should return true when workspace has file:// protocol', + workspacePath: 'file:///test/ws', + cwd: '/test/ws', + expectedValid: true, + }, + { + description: 'should return true when workspace has encoded spaces', + workspacePath: '/test/my%20ws', + cwd: '/test/my ws', + expectedValid: true, + }, + { + description: + 'should return true when cwd needs normalization matching workspace', + workspacePath: '/test/my ws', + cwd: '/test/my%20ws', + expectedValid: true, + }, + ])('$description', ({ workspacePath, cwd, expectedValid }) => { + expect(IdeClient.validateWorkspacePath(workspacePath, cwd)).toMatchObject( + { isValid: expectedValid }, + ); + }); + }); }); diff --git a/packages/core/src/ide/ide-client.ts b/packages/core/src/ide/ide-client.ts index a4d9234bd0..928c411395 100644 --- a/packages/core/src/ide/ide-client.ts +++ b/packages/core/src/ide/ide-client.ts @@ -5,7 +5,7 @@ */ import * as fs from 'node:fs'; -import { isSubpath } from '../utils/paths.js'; +import { isSubpath, resolveToRealPath } from '../utils/paths.js'; import { detectIde, type IdeInfo } from '../ide/detect-ide.js'; import { ideContextStore } from './ideContext.js'; import { @@ -65,16 +65,6 @@ type ConnectionConfig = { stdio?: StdioConfig; }; -function getRealPath(path: string): string { - try { - return fs.realpathSync(path); - } catch (_e) { - // If realpathSync fails, it might be because the path doesn't exist. - // In that case, we can fall back to the original path. - return path; - } -} - /** * Manages the connection to and interaction with the IDE server. */ @@ -521,12 +511,14 @@ export class IdeClient { }; } - const ideWorkspacePaths = ideWorkspacePath.split(path.delimiter); - const realCwd = getRealPath(cwd); - const isWithinWorkspace = ideWorkspacePaths.some((workspacePath) => { - const idePath = getRealPath(workspacePath); - return isSubpath(idePath, realCwd); - }); + const ideWorkspacePaths = ideWorkspacePath + .split(path.delimiter) + .map((p) => resolveToRealPath(p)) + .filter((e) => !!e); + const realCwd = resolveToRealPath(cwd); + const isWithinWorkspace = ideWorkspacePaths.some((workspacePath) => + isSubpath(workspacePath, realCwd), + ); if (!isWithinWorkspace) { return { diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index fdd54c5150..348df878d5 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -47,6 +47,7 @@ export * from './fallback/types.js'; export * from './code_assist/codeAssist.js'; export * from './code_assist/oauth2.js'; export * from './code_assist/server.js'; +export * from './code_assist/setup.js'; export * from './code_assist/types.js'; export * from './code_assist/telemetry.js'; export * from './core/apiKeyCredentialStorage.js'; diff --git a/packages/core/src/policy/policy-engine.test.ts b/packages/core/src/policy/policy-engine.test.ts index a5df8e8167..782123cfb3 100644 --- a/packages/core/src/policy/policy-engine.test.ts +++ b/packages/core/src/policy/policy-engine.test.ts @@ -1821,291 +1821,4 @@ describe('PolicyEngine', () => { expect(result.decision).toBe(PolicyDecision.DENY); }); }); - - describe('checkHook', () => { - it('should allow hooks by default', async () => { - engine = new PolicyEngine({}, mockCheckerRunner); - const decision = await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - }); - expect(decision).toBe(PolicyDecision.ALLOW); - }); - - it('should deny all hooks when allowHooks is false', async () => { - engine = new PolicyEngine({ allowHooks: false }, mockCheckerRunner); - const decision = await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - }); - expect(decision).toBe(PolicyDecision.DENY); - }); - - it('should deny project hooks in untrusted folders', async () => { - engine = new PolicyEngine({}, mockCheckerRunner); - const decision = await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'project', - trustedFolder: false, - }); - expect(decision).toBe(PolicyDecision.DENY); - }); - - it('should allow project hooks in trusted folders', async () => { - engine = new PolicyEngine({}, mockCheckerRunner); - const decision = await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'project', - trustedFolder: true, - }); - expect(decision).toBe(PolicyDecision.ALLOW); - }); - - it('should allow user hooks in untrusted folders', async () => { - engine = new PolicyEngine({}, mockCheckerRunner); - const decision = await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - trustedFolder: false, - }); - expect(decision).toBe(PolicyDecision.ALLOW); - }); - - it('should run hook checkers and deny on DENY decision', async () => { - const hookCheckers = [ - { - eventName: 'BeforeTool', - checker: { type: 'external' as const, name: 'test-hook-checker' }, - }, - ]; - engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); - - vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ - decision: SafetyCheckDecision.DENY, - reason: 'Hook checker denied', - }); - - const decision = await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - }); - - expect(decision).toBe(PolicyDecision.DENY); - expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( - expect.objectContaining({ name: 'hook:BeforeTool' }), - expect.objectContaining({ name: 'test-hook-checker' }), - ); - }); - - it('should run hook checkers and allow on ALLOW decision', async () => { - const hookCheckers = [ - { - eventName: 'BeforeTool', - checker: { type: 'external' as const, name: 'test-hook-checker' }, - }, - ]; - engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); - - vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ - decision: SafetyCheckDecision.ALLOW, - }); - - const decision = await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - }); - - expect(decision).toBe(PolicyDecision.ALLOW); - }); - - it('should return ASK_USER when checker requests it', async () => { - const hookCheckers = [ - { - checker: { type: 'external' as const, name: 'test-hook-checker' }, - }, - ]; - engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); - - vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ - decision: SafetyCheckDecision.ASK_USER, - reason: 'Needs confirmation', - }); - - const decision = await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - }); - - expect(decision).toBe(PolicyDecision.ASK_USER); - }); - - it('should return DENY for ASK_USER in non-interactive mode', async () => { - const hookCheckers = [ - { - checker: { type: 'external' as const, name: 'test-hook-checker' }, - }, - ]; - engine = new PolicyEngine( - { hookCheckers, nonInteractive: true }, - mockCheckerRunner, - ); - - vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ - decision: SafetyCheckDecision.ASK_USER, - reason: 'Needs confirmation', - }); - - const decision = await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - }); - - expect(decision).toBe(PolicyDecision.DENY); - }); - - it('should match hook checkers by eventName', async () => { - const hookCheckers = [ - { - eventName: 'AfterTool', - checker: { type: 'external' as const, name: 'after-tool-checker' }, - }, - { - eventName: 'BeforeTool', - checker: { type: 'external' as const, name: 'before-tool-checker' }, - }, - ]; - engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); - - vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ - decision: SafetyCheckDecision.ALLOW, - }); - - await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - }); - - expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ name: 'before-tool-checker' }), - ); - expect(mockCheckerRunner.runChecker).not.toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ name: 'after-tool-checker' }), - ); - }); - - it('should match hook checkers by hookSource', async () => { - const hookCheckers = [ - { - hookSource: 'project' as const, - checker: { type: 'external' as const, name: 'project-checker' }, - }, - { - hookSource: 'user' as const, - checker: { type: 'external' as const, name: 'user-checker' }, - }, - ]; - engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); - - vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ - decision: SafetyCheckDecision.ALLOW, - }); - - await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - }); - - expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ name: 'user-checker' }), - ); - expect(mockCheckerRunner.runChecker).not.toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ name: 'project-checker' }), - ); - }); - - it('should deny when hook checker throws an error', async () => { - const hookCheckers = [ - { - checker: { type: 'external' as const, name: 'failing-checker' }, - }, - ]; - engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); - - vi.mocked(mockCheckerRunner.runChecker).mockRejectedValue( - new Error('Checker failed'), - ); - - const decision = await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - }); - - expect(decision).toBe(PolicyDecision.DENY); - }); - - it('should run hook checkers in priority order', async () => { - const hookCheckers = [ - { - priority: 5, - checker: { type: 'external' as const, name: 'low-priority' }, - }, - { - priority: 20, - checker: { type: 'external' as const, name: 'high-priority' }, - }, - { - priority: 10, - checker: { type: 'external' as const, name: 'medium-priority' }, - }, - ]; - engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); - - vi.mocked(mockCheckerRunner.runChecker).mockImplementation( - async (_call, config) => { - if (config.name === 'high-priority') { - return { decision: SafetyCheckDecision.DENY, reason: 'denied' }; - } - return { decision: SafetyCheckDecision.ALLOW }; - }, - ); - - await engine.checkHook({ - eventName: 'BeforeTool', - hookSource: 'user', - }); - - // Should only call the high-priority checker (first in sorted order) - expect(mockCheckerRunner.runChecker).toHaveBeenCalledTimes(1); - expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ name: 'high-priority' }), - ); - }); - }); - - describe('addHookChecker', () => { - it('should add a new hook checker and maintain priority order', () => { - engine = new PolicyEngine({}, mockCheckerRunner); - - engine.addHookChecker({ - priority: 5, - checker: { type: 'external', name: 'checker1' }, - }); - engine.addHookChecker({ - priority: 10, - checker: { type: 'external', name: 'checker2' }, - }); - - const checkers = engine.getHookCheckers(); - expect(checkers).toHaveLength(2); - expect(checkers[0].priority).toBe(10); - expect(checkers[0].checker.name).toBe('checker2'); - expect(checkers[1].priority).toBe(5); - expect(checkers[1].checker.name).toBe('checker1'); - }); - }); }); diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts index 48feb537e6..be5536a9df 100644 --- a/packages/core/src/policy/policy-engine.ts +++ b/packages/core/src/policy/policy-engine.ts @@ -11,8 +11,6 @@ import { type PolicyRule, type SafetyCheckerRule, type HookCheckerRule, - type HookExecutionContext, - getHookSource, ApprovalMode, type CheckResult, } from './types.js'; @@ -20,7 +18,6 @@ import { stableStringify } from './stable-stringify.js'; import { debugLogger } from '../utils/debugLogger.js'; import type { CheckerRunner } from '../safety/checker-runner.js'; import { SafetyCheckDecision } from '../safety/protocol.js'; -import type { HookExecutionRequest } from '../confirmation-bus/types.js'; import { SHELL_TOOL_NAMES, initializeShellParsers, @@ -81,26 +78,6 @@ function ruleMatches( return true; } -/** - * Check if a hook checker rule matches a hook execution context. - */ -function hookCheckerMatches( - rule: HookCheckerRule, - context: HookExecutionContext, -): boolean { - // Check event name if specified - if (rule.eventName && rule.eventName !== context.eventName) { - return false; - } - - // Check hook source if specified - if (rule.hookSource && rule.hookSource !== context.hookSource) { - return false; - } - - return true; -} - export class PolicyEngine { private rules: PolicyRule[]; private checkers: SafetyCheckerRule[]; @@ -108,7 +85,6 @@ export class PolicyEngine { private readonly defaultDecision: PolicyDecision; private readonly nonInteractive: boolean; private readonly checkerRunner?: CheckerRunner; - private readonly allowHooks: boolean; private approvalMode: ApprovalMode; constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) { @@ -124,7 +100,6 @@ export class PolicyEngine { this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER; this.nonInteractive = config.nonInteractive ?? false; this.checkerRunner = checkerRunner; - this.allowHooks = config.allowHooks ?? true; this.approvalMode = config.approvalMode ?? ApprovalMode.DEFAULT; } @@ -495,84 +470,6 @@ export class PolicyEngine { return this.hookCheckers; } - /** - * Check if a hook execution is allowed based on the configured policies. - * Runs hook-specific safety checkers if configured. - */ - async checkHook( - request: HookExecutionRequest | HookExecutionContext, - ): Promise { - // If hooks are globally disabled, deny all hook executions - if (!this.allowHooks) { - return PolicyDecision.DENY; - } - - const context: HookExecutionContext = - 'input' in request - ? { - eventName: request.eventName, - hookSource: getHookSource(request.input), - trustedFolder: - typeof request.input['trusted_folder'] === 'boolean' - ? request.input['trusted_folder'] - : undefined, - } - : request; - - // In untrusted folders, deny project-level hooks - if (context.trustedFolder === false && context.hookSource === 'project') { - return PolicyDecision.DENY; - } - - // Run hook-specific safety checkers if configured - if (this.checkerRunner && this.hookCheckers.length > 0) { - for (const checkerRule of this.hookCheckers) { - if (hookCheckerMatches(checkerRule, context)) { - debugLogger.debug( - `[PolicyEngine.checkHook] Running hook checker: ${checkerRule.checker.name} for event: ${context.eventName}`, - ); - try { - // Create a synthetic function call for the checker runner - // This allows reusing the existing checker infrastructure - const syntheticCall = { - name: `hook:${context.eventName}`, - args: { - hookSource: context.hookSource, - trustedFolder: context.trustedFolder, - }, - }; - - const result = await this.checkerRunner.runChecker( - syntheticCall, - checkerRule.checker, - ); - - if (result.decision === SafetyCheckDecision.DENY) { - debugLogger.debug( - `[PolicyEngine.checkHook] Hook checker denied: ${result.reason}`, - ); - return PolicyDecision.DENY; - } else if (result.decision === SafetyCheckDecision.ASK_USER) { - debugLogger.debug( - `[PolicyEngine.checkHook] Hook checker requested ASK_USER: ${result.reason}`, - ); - // For hooks, ASK_USER is treated as DENY in non-interactive mode - return this.applyNonInteractiveMode(PolicyDecision.ASK_USER); - } - } catch (error) { - debugLogger.debug( - `[PolicyEngine.checkHook] Hook checker failed: ${error}`, - ); - return PolicyDecision.DENY; - } - } - } - } - - // Default: Allow hooks - return PolicyDecision.ALLOW; - } - private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision { // In non-interactive mode, ASK_USER becomes DENY if (this.nonInteractive && decision === PolicyDecision.ASK_USER) { diff --git a/packages/core/src/test-utils/mock-message-bus.ts b/packages/core/src/test-utils/mock-message-bus.ts index 1bd18c2f55..c28f077bf2 100644 --- a/packages/core/src/test-utils/mock-message-bus.ts +++ b/packages/core/src/test-utils/mock-message-bus.ts @@ -6,12 +6,7 @@ import { vi } from 'vitest'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; -import { - MessageBusType, - type Message, - type HookExecutionRequest, - type HookExecutionResponse, -} from '../confirmation-bus/types.js'; +import { MessageBusType, type Message } from '../confirmation-bus/types.js'; /** * Mock MessageBus for testing hook execution through MessageBus @@ -22,8 +17,6 @@ export class MockMessageBus { Set<(message: Message) => void> >(); publishedMessages: Message[] = []; - hookRequests: HookExecutionRequest[] = []; - hookResponses: HookExecutionResponse[] = []; defaultToolDecision: 'allow' | 'deny' | 'ask_user' = 'allow'; /** @@ -32,26 +25,6 @@ export class MockMessageBus { publish = vi.fn((message: Message) => { this.publishedMessages.push(message); - // Capture hook-specific messages - if (message.type === MessageBusType.HOOK_EXECUTION_REQUEST) { - this.hookRequests.push(message); - - // Auto-respond with success for testing - const response: HookExecutionResponse = { - type: MessageBusType.HOOK_EXECUTION_RESPONSE, - correlationId: message.correlationId, - success: true, - output: { - decision: 'allow', - reason: 'Mock hook execution successful', - }, - }; - this.hookResponses.push(response); - - // Emit response to subscribers - this.emit(MessageBusType.HOOK_EXECUTION_RESPONSE, response); - } - // Handle tool confirmation requests if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { if (this.defaultToolDecision === 'allow') { @@ -115,78 +88,13 @@ export class MockMessageBus { } } - /** - * Manually trigger a hook response (for testing custom scenarios) - */ - triggerHookResponse( - correlationId: string, - success: boolean, - output?: Record, - error?: Error, - ) { - const response: HookExecutionResponse = { - type: MessageBusType.HOOK_EXECUTION_RESPONSE, - correlationId, - success, - output, - error, - }; - this.hookResponses.push(response); - this.emit(MessageBusType.HOOK_EXECUTION_RESPONSE, response); - } - - /** - * Get the last hook request published - */ - getLastHookRequest(): HookExecutionRequest | undefined { - return this.hookRequests[this.hookRequests.length - 1]; - } - - /** - * Get all hook requests for a specific event - */ - getHookRequestsForEvent(eventName: string): HookExecutionRequest[] { - return this.hookRequests.filter((req) => req.eventName === eventName); - } - /** * Clear all captured messages (for test isolation) */ clear() { this.publishedMessages = []; - this.hookRequests = []; - this.hookResponses = []; this.subscriptions.clear(); } - - /** - * Verify that a hook execution request was published - */ - expectHookRequest( - eventName: string, - input?: Partial>, - ) { - const request = this.hookRequests.find( - (req) => req.eventName === eventName, - ); - if (!request) { - throw new Error( - `Expected hook request for event "${eventName}" but none was found`, - ); - } - - if (input) { - Object.entries(input).forEach(([key, value]) => { - if (request.input[key] !== value) { - throw new Error( - `Expected hook input.${key} to be ${JSON.stringify(value)} but got ${JSON.stringify(request.input[key])}`, - ); - } - }); - } - - return request; - } } /** diff --git a/packages/core/src/tools/ripGrep.test.ts b/packages/core/src/tools/ripGrep.test.ts index e8eafc9b23..415db097e3 100644 --- a/packages/core/src/tools/ripGrep.test.ts +++ b/packages/core/src/tools/ripGrep.test.ts @@ -1009,10 +1009,10 @@ describe('RipGrepTool', () => { const result = await invocation.execute(controller.signal); expect(result.llmContent).toContain( - 'Error during grep search operation: ripgrep exited with code null', + 'Error during grep search operation: ripgrep was terminated by signal:', ); expect(result.returnDisplay).toContain( - 'Error: ripgrep exited with code null', + 'Error: ripgrep was terminated by signal:', ); }); }); diff --git a/packages/core/src/tools/ripGrep.ts b/packages/core/src/tools/ripGrep.ts index 0e52884b14..12f6d720e2 100644 --- a/packages/core/src/tools/ripGrep.ts +++ b/packages/core/src/tools/ripGrep.ts @@ -432,7 +432,7 @@ class GrepToolInvocation extends BaseToolInvocation< ); }); - child.on('close', (code) => { + child.on('close', (code, signal) => { options.signal.removeEventListener('abort', cleanup); const stdoutData = Buffer.concat(stdoutChunks).toString('utf8'); const stderrData = Buffer.concat(stderrChunks).toString('utf8'); @@ -442,9 +442,13 @@ class GrepToolInvocation extends BaseToolInvocation< } else if (code === 1) { resolve(''); // No matches found } else { - reject( - new Error(`ripgrep exited with code ${code}: ${stderrData}`), - ); + if (signal) { + reject(new Error(`ripgrep was terminated by signal: ${signal}`)); + } else { + reject( + new Error(`ripgrep exited with code ${code}: ${stderrData}`), + ); + } } }); }); diff --git a/packages/core/src/utils/errors.ts b/packages/core/src/utils/errors.ts index 8db1153d92..86f1cc9b86 100644 --- a/packages/core/src/utils/errors.ts +++ b/packages/core/src/utils/errors.ts @@ -81,6 +81,13 @@ export class ForbiddenError extends Error {} export class UnauthorizedError extends Error {} export class BadRequestError extends Error {} +export class ChangeAuthRequestedError extends Error { + constructor() { + super('User requested to change authentication method'); + this.name = 'ChangeAuthRequestedError'; + } +} + interface ResponseData { error?: { code?: number; diff --git a/packages/core/src/utils/googleQuotaErrors.ts b/packages/core/src/utils/googleQuotaErrors.ts index f3a909a20a..0ecc14d93f 100644 --- a/packages/core/src/utils/googleQuotaErrors.ts +++ b/packages/core/src/utils/googleQuotaErrors.ts @@ -63,7 +63,7 @@ export class ValidationRequiredError extends Error { constructor( message: string, - override readonly cause: GoogleApiError, + override readonly cause?: GoogleApiError, validationLink?: string, validationDescription?: string, learnMoreUrl?: string, diff --git a/packages/core/src/utils/paths.test.ts b/packages/core/src/utils/paths.test.ts index 210dc8b448..38b00628e5 100644 --- a/packages/core/src/utils/paths.test.ts +++ b/packages/core/src/utils/paths.test.ts @@ -4,8 +4,23 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, beforeAll, afterAll } from 'vitest'; -import { escapePath, unescapePath, isSubpath, shortenPath } from './paths.js'; +import { describe, it, expect, beforeAll, afterAll, vi } from 'vitest'; +import * as fs from 'node:fs'; +import { + escapePath, + unescapePath, + isSubpath, + shortenPath, + resolveToRealPath, +} from './paths.js'; + +vi.mock('node:fs', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...(actual as object), + realpathSync: (p: string) => p, + }; +}); describe('escapePath', () => { it.each([ @@ -472,3 +487,42 @@ describe('shortenPath', () => { }); }); }); + +describe('resolveToRealPath', () => { + it.each([ + { + description: + 'should return path as-is if no special characters or protocol', + input: '/simple/path', + expected: '/simple/path', + }, + { + description: 'should remove file:// protocol', + input: 'file:///path/to/file', + expected: '/path/to/file', + }, + { + description: 'should decode URI components', + input: '/path/to/some%20folder', + expected: '/path/to/some folder', + }, + { + description: 'should handle both file protocol and encoding', + input: 'file:///path/to/My%20Project', + expected: '/path/to/My Project', + }, + ])('$description', ({ input, expected }) => { + expect(resolveToRealPath(input)).toBe(expected); + }); + + it('should return decoded path even if fs.realpathSync fails', () => { + vi.spyOn(fs, 'realpathSync').mockImplementationOnce(() => { + throw new Error('File not found'); + }); + + const input = 'file:///path/to/New%20Project'; + const expected = '/path/to/New Project'; + + expect(resolveToRealPath(input)).toBe(expected); + }); +}); diff --git a/packages/core/src/utils/paths.ts b/packages/core/src/utils/paths.ts index 4d14a6d230..94ccd96cf3 100644 --- a/packages/core/src/utils/paths.ts +++ b/packages/core/src/utils/paths.ts @@ -8,6 +8,8 @@ import path from 'node:path'; import os from 'node:os'; import process from 'node:process'; import * as crypto from 'node:crypto'; +import * as fs from 'node:fs'; +import { fileURLToPath } from 'node:url'; export const GEMINI_DIR = '.gemini'; export const GOOGLE_ACCOUNTS_FILENAME = 'google_accounts.json'; @@ -343,3 +345,34 @@ export function isSubpath(parentPath: string, childPath: string): boolean { !pathModule.isAbsolute(relative) ); } + +/** + * Resolves a path to its real path, sanitizing it first. + * - Removes 'file://' protocol if present. + * - Decodes URI components (e.g. %20 -> space). + * - Resolves symbolic links using fs.realpathSync. + * + * @param pathStr The path string to resolve. + * @returns The resolved real path. + */ +export function resolveToRealPath(path: string): string { + let resolvedPath = path; + + try { + if (resolvedPath.startsWith('file://')) { + resolvedPath = fileURLToPath(resolvedPath); + } + + resolvedPath = decodeURIComponent(resolvedPath); + } catch (_e) { + // Ignore error (e.g. malformed URI), keep path from previous step + } + + try { + return fs.realpathSync(resolvedPath); + } catch (_e) { + // If realpathSync fails, it might be because the path doesn't exist. + // In that case, we can fall back to the path processed. + return resolvedPath; + } +}