diff --git a/evals/generalist_agent.eval.ts b/evals/generalist_agent.eval.ts index f93005d509..5a51a925cb 100644 --- a/evals/generalist_agent.eval.ts +++ b/evals/generalist_agent.eval.ts @@ -10,7 +10,7 @@ import path from 'node:path'; import fs from 'node:fs/promises'; describe('generalist_agent', () => { - evalTest('ALWAYS_PASSES', { + evalTest('USUALLY_PASSES', { name: 'should be able to use generalist agent by explicitly asking the main agent to invoke it', params: { settings: { diff --git a/packages/cli/src/ui/commands/aboutCommand.test.ts b/packages/cli/src/ui/commands/aboutCommand.test.ts index 9b93641958..f1c010678e 100644 --- a/packages/cli/src/ui/commands/aboutCommand.test.ts +++ b/packages/cli/src/ui/commands/aboutCommand.test.ts @@ -39,6 +39,7 @@ describe('aboutCommand', () => { config: { getModel: vi.fn(), getIdeMode: vi.fn().mockReturnValue(true), + getUserTierName: vi.fn().mockReturnValue(undefined), }, settings: { merged: { @@ -97,6 +98,7 @@ describe('aboutCommand', () => { gcpProject: 'test-gcp-project', ideClient: 'test-ide', userEmail: 'test-email@example.com', + tier: undefined, }); }); @@ -156,4 +158,21 @@ describe('aboutCommand', () => { }), ); }); + + it('should display the tier when getUserTierName returns a value', async () => { + vi.mocked(mockContext.services.config!.getUserTierName).mockReturnValue( + 'Enterprise Tier', + ); + if (!aboutCommand.action) { + throw new Error('The about command must have an action.'); + } + + await aboutCommand.action(mockContext, ''); + + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + expect.objectContaining({ + tier: 'Enterprise Tier', + }), + ); + }); }); diff --git a/packages/cli/src/ui/commands/aboutCommand.ts b/packages/cli/src/ui/commands/aboutCommand.ts index 3def750895..cf21d9b0d5 100644 --- a/packages/cli/src/ui/commands/aboutCommand.ts +++ b/packages/cli/src/ui/commands/aboutCommand.ts @@ -44,6 +44,8 @@ export const aboutCommand: SlashCommand = { }); const userEmail = cachedAccount ?? undefined; + const tier = context.services.config?.getUserTierName(); + const aboutItem: Omit = { type: MessageType.ABOUT, cliVersion, @@ -54,6 +56,7 @@ export const aboutCommand: SlashCommand = { gcpProject, ideClient, userEmail, + tier, }; context.ui.addItem(aboutItem); diff --git a/packages/cli/src/ui/components/AboutBox.test.tsx b/packages/cli/src/ui/components/AboutBox.test.tsx index b6e5968e53..1e4810fec5 100644 --- a/packages/cli/src/ui/components/AboutBox.test.tsx +++ b/packages/cli/src/ui/components/AboutBox.test.tsx @@ -33,13 +33,13 @@ describe('AboutBox', () => { expect(output).toContain('gemini-pro'); expect(output).toContain('default'); expect(output).toContain('macOS'); - expect(output).toContain('OAuth'); + expect(output).toContain('Logged in with Google'); }); it.each([ - ['userEmail', 'test@example.com', 'User Email'], ['gcpProject', 'my-project', 'GCP Project'], ['ideClient', 'vscode', 'IDE Client'], + ['tier', 'Enterprise', 'Tier'], ])('renders optional prop %s', (prop, value, label) => { const props = { ...defaultProps, [prop]: value }; const { lastFrame } = render(); @@ -48,6 +48,13 @@ describe('AboutBox', () => { expect(output).toContain(value); }); + it('renders Auth Method with email when userEmail is provided', () => { + const props = { ...defaultProps, userEmail: 'test@example.com' }; + const { lastFrame } = render(); + const output = lastFrame(); + expect(output).toContain('Logged in with Google (test@example.com)'); + }); + it('renders Auth Method correctly when not oauth', () => { const props = { ...defaultProps, selectedAuthType: 'api-key' }; const { lastFrame } = render(); diff --git a/packages/cli/src/ui/components/AboutBox.tsx b/packages/cli/src/ui/components/AboutBox.tsx index b14b814f03..4b45a55b37 100644 --- a/packages/cli/src/ui/components/AboutBox.tsx +++ b/packages/cli/src/ui/components/AboutBox.tsx @@ -18,6 +18,7 @@ interface AboutBoxProps { gcpProject: string; ideClient: string; userEmail?: string; + tier?: string; } export const AboutBox: React.FC = ({ @@ -29,6 +30,7 @@ export const AboutBox: React.FC = ({ gcpProject, ideClient, userEmail, + tier, }) => ( = ({ - {selectedAuthType.startsWith('oauth') ? 'OAuth' : selectedAuthType} + {selectedAuthType.startsWith('oauth') + ? userEmail + ? `Logged in with Google (${userEmail})` + : 'Logged in with Google' + : selectedAuthType} - {userEmail && ( + {tier && ( - User Email + Tier - {userEmail} + {tier} )} diff --git a/packages/cli/src/ui/components/HistoryItemDisplay.tsx b/packages/cli/src/ui/components/HistoryItemDisplay.tsx index 509645eda5..7a72dc6120 100644 --- a/packages/cli/src/ui/components/HistoryItemDisplay.tsx +++ b/packages/cli/src/ui/components/HistoryItemDisplay.tsx @@ -112,6 +112,7 @@ export const HistoryItemDisplay: React.FC = ({ gcpProject={itemForDisplay.gcpProject} ideClient={itemForDisplay.ideClient} userEmail={itemForDisplay.userEmail} + tier={itemForDisplay.tier} /> )} {itemForDisplay.type === 'help' && commands && ( diff --git a/packages/cli/src/ui/types.ts b/packages/cli/src/ui/types.ts index dcadfbcffd..ae865d2488 100644 --- a/packages/cli/src/ui/types.ts +++ b/packages/cli/src/ui/types.ts @@ -145,6 +145,7 @@ export type HistoryItemAbout = HistoryItemBase & { gcpProject: string; ideClient: string; userEmail?: string; + tier?: string; }; export type HistoryItemHelp = HistoryItemBase & { diff --git a/packages/core/src/code_assist/codeAssist.test.ts b/packages/core/src/code_assist/codeAssist.test.ts index 0974e2237e..90ebfb1d9c 100644 --- a/packages/core/src/code_assist/codeAssist.test.ts +++ b/packages/core/src/code_assist/codeAssist.test.ts @@ -64,6 +64,7 @@ describe('codeAssist', () => { httpOptions, 'session-123', 'free-tier', + undefined, ); expect(generator).toBeInstanceOf(MockedCodeAssistServer); }); @@ -89,6 +90,7 @@ describe('codeAssist', () => { httpOptions, undefined, // No session ID 'free-tier', + undefined, ); expect(generator).toBeInstanceOf(MockedCodeAssistServer); }); diff --git a/packages/core/src/code_assist/codeAssist.ts b/packages/core/src/code_assist/codeAssist.ts index f8c9ac47b8..fee43e9c45 100644 --- a/packages/core/src/code_assist/codeAssist.ts +++ b/packages/core/src/code_assist/codeAssist.ts @@ -31,6 +31,7 @@ export async function createCodeAssistContentGenerator( httpOptions, sessionId, userData.userTier, + userData.userTierName, ); } diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index fca17b6d95..bf57bc55b7 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -69,6 +69,7 @@ export class CodeAssistServer implements ContentGenerator { readonly httpOptions: HttpOptions = {}, readonly sessionId?: string, readonly userTier?: UserTierId, + readonly userTierName?: string, ) {} async generateContentStream( diff --git a/packages/core/src/code_assist/setup.test.ts b/packages/core/src/code_assist/setup.test.ts index 2a9640f703..bd43ed2e88 100644 --- a/packages/core/src/code_assist/setup.test.ts +++ b/packages/core/src/code_assist/setup.test.ts @@ -67,6 +67,7 @@ describe('setupUser for existing user', () => { {}, '', undefined, + undefined, ); }); @@ -83,10 +84,12 @@ describe('setupUser for existing user', () => { {}, '', undefined, + undefined, ); expect(projectId).toEqual({ projectId: 'server-project', userTier: 'standard-tier', + userTierName: 'paid', }); }); @@ -148,6 +151,7 @@ describe('setupUser for new user', () => { {}, '', undefined, + undefined, ); expect(mockLoad).toHaveBeenCalled(); expect(mockOnboardUser).toHaveBeenCalledWith({ @@ -163,6 +167,7 @@ describe('setupUser for new user', () => { expect(userData).toEqual({ projectId: 'server-project', userTier: 'standard-tier', + userTierName: 'paid', }); }); @@ -178,6 +183,7 @@ describe('setupUser for new user', () => { {}, '', undefined, + undefined, ); expect(mockLoad).toHaveBeenCalled(); expect(mockOnboardUser).toHaveBeenCalledWith({ @@ -192,6 +198,7 @@ describe('setupUser for new user', () => { expect(userData).toEqual({ projectId: 'server-project', userTier: 'free-tier', + userTierName: 'free', }); }); @@ -210,6 +217,7 @@ describe('setupUser for new user', () => { expect(userData).toEqual({ projectId: 'test-project', userTier: 'standard-tier', + userTierName: 'paid', }); }); @@ -268,6 +276,7 @@ describe('setupUser for new user', () => { expect(userData).toEqual({ projectId: 'server-project', userTier: 'standard-tier', + userTierName: 'paid', }); }); @@ -294,6 +303,7 @@ describe('setupUser for new user', () => { expect(userData).toEqual({ projectId: 'server-project', userTier: 'standard-tier', + userTierName: 'paid', }); }); }); diff --git a/packages/core/src/code_assist/setup.ts b/packages/core/src/code_assist/setup.ts index 2d137607a2..994bb99568 100644 --- a/packages/core/src/code_assist/setup.ts +++ b/packages/core/src/code_assist/setup.ts @@ -25,6 +25,7 @@ export class ProjectIdRequiredError extends Error { export interface UserData { projectId: string; userTier: UserTierId; + userTierName?: string; } /** @@ -37,7 +38,14 @@ export async function setupUser(client: AuthClient): Promise { process.env['GOOGLE_CLOUD_PROJECT'] || process.env['GOOGLE_CLOUD_PROJECT_ID'] || undefined; - const caServer = new CodeAssistServer(client, projectId, {}, '', undefined); + const caServer = new CodeAssistServer( + client, + projectId, + {}, + '', + undefined, + undefined, + ); const coreClientMetadata: ClientMetadata = { ideType: 'IDE_UNSPECIFIED', platform: 'PLATFORM_UNSPECIFIED', @@ -58,6 +66,7 @@ export async function setupUser(client: AuthClient): Promise { return { projectId, userTier: loadRes.currentTier.id, + userTierName: loadRes.currentTier.name, }; } throw new ProjectIdRequiredError(); @@ -65,6 +74,7 @@ export async function setupUser(client: AuthClient): Promise { return { projectId: loadRes.cloudaicompanionProject, userTier: loadRes.currentTier.id, + userTierName: loadRes.currentTier.name, }; } @@ -103,6 +113,7 @@ export async function setupUser(client: AuthClient): Promise { return { projectId, userTier: tier.id, + userTierName: tier.name, }; } throw new ProjectIdRequiredError(); @@ -111,6 +122,7 @@ export async function setupUser(client: AuthClient): Promise { return { projectId: lroRes.response.cloudaicompanionProject.id, userTier: tier.id, + userTierName: tier.name, }; } diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 02d431f2d7..7b9fbf1a80 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -963,6 +963,10 @@ export class Config { return this.contentGenerator?.userTier; } + getUserTierName(): string | undefined { + return this.contentGenerator?.userTierName; + } + /** * Provides access to the BaseLlmClient for stateless LLM operations. */ diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 740bede47c..eb45c9f218 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -44,6 +44,8 @@ export interface ContentGenerator { embedContent(request: EmbedContentParameters): Promise; userTier?: UserTierId; + + userTierName?: string; } export enum AuthType { diff --git a/packages/core/src/core/fakeContentGenerator.ts b/packages/core/src/core/fakeContentGenerator.ts index a464c4f8fa..e6d7bbf8ff 100644 --- a/packages/core/src/core/fakeContentGenerator.ts +++ b/packages/core/src/core/fakeContentGenerator.ts @@ -42,6 +42,7 @@ export type FakeResponse = export class FakeContentGenerator implements ContentGenerator { private callCounter = 0; userTier?: UserTierId; + userTierName?: string; constructor(private readonly responses: FakeResponse[]) {} diff --git a/packages/core/src/core/loggingContentGenerator.test.ts b/packages/core/src/core/loggingContentGenerator.test.ts index 92286d207c..4b99f8a06c 100644 --- a/packages/core/src/core/loggingContentGenerator.test.ts +++ b/packages/core/src/core/loggingContentGenerator.test.ts @@ -31,6 +31,7 @@ import type { ContentGenerator } from './contentGenerator.js'; import { LoggingContentGenerator } from './loggingContentGenerator.js'; import type { Config } from '../config/config.js'; import { ApiRequestEvent } from '../telemetry/types.js'; +import { UserTierId } from '../code_assist/types.js'; describe('LoggingContentGenerator', () => { let wrapped: ContentGenerator; @@ -302,4 +303,16 @@ describe('LoggingContentGenerator', () => { expect(result).toBe(response); }); }); + + describe('delegation', () => { + it('should delegate userTier to wrapped', () => { + wrapped.userTier = UserTierId.STANDARD; + expect(loggingContentGenerator.userTier).toBe(UserTierId.STANDARD); + }); + + it('should delegate userTierName to wrapped', () => { + wrapped.userTierName = 'Standard Tier'; + expect(loggingContentGenerator.userTierName).toBe('Standard Tier'); + }); + }); }); diff --git a/packages/core/src/core/loggingContentGenerator.ts b/packages/core/src/core/loggingContentGenerator.ts index cc5ab05890..fd89f86f54 100644 --- a/packages/core/src/core/loggingContentGenerator.ts +++ b/packages/core/src/core/loggingContentGenerator.ts @@ -23,6 +23,7 @@ import { ApiErrorEvent, } from '../telemetry/types.js'; import type { Config } from '../config/config.js'; +import type { UserTierId } from '../code_assist/types.js'; import { logApiError, logApiRequest, @@ -51,6 +52,14 @@ export class LoggingContentGenerator implements ContentGenerator { return this.wrapped; } + get userTier(): UserTierId | undefined { + return this.wrapped.userTier; + } + + get userTierName(): string | undefined { + return this.wrapped.userTierName; + } + private logApiRequest( contents: Content[], model: string, diff --git a/packages/core/src/core/recordingContentGenerator.ts b/packages/core/src/core/recordingContentGenerator.ts index 27abcb418f..510a20b8c1 100644 --- a/packages/core/src/core/recordingContentGenerator.ts +++ b/packages/core/src/core/recordingContentGenerator.ts @@ -25,13 +25,19 @@ import { safeJsonStringify } from '../utils/safeJsonStringify.js'; // // Note that only the "interesting" bits of the responses are actually kept. export class RecordingContentGenerator implements ContentGenerator { - userTier?: UserTierId; - constructor( private readonly realGenerator: ContentGenerator, private readonly filePath: string, ) {} + get userTier(): UserTierId | undefined { + return this.realGenerator.userTier; + } + + get userTierName(): string | undefined { + return this.realGenerator.userTierName; + } + async generateContent( request: GenerateContentParameters, userPromptId: string,