diff --git a/packages/a2a-server/src/agent/executor.ts b/packages/a2a-server/src/agent/executor.ts index 302224cbf0..4fdea984ae 100644 --- a/packages/a2a-server/src/agent/executor.ts +++ b/packages/a2a-server/src/agent/executor.ts @@ -125,9 +125,7 @@ export class CoderAgentExecutor implements AgentExecutor { eventBus, ); runtimeTask.taskState = persistedState._taskState; - await runtimeTask.geminiClient.initialize( - runtimeTask.config.getContentGeneratorConfig(), - ); + await runtimeTask.geminiClient.initialize(); const wrapper = new TaskWrapper(runtimeTask, agentSettings); this.tasks.set(sdkTask.id, wrapper); @@ -144,9 +142,7 @@ export class CoderAgentExecutor implements AgentExecutor { const agentSettings = agentSettingsInput || ({} as AgentSettings); const config = await this.getConfig(agentSettings, taskId); const runtimeTask = await Task.create(taskId, contextId, config, eventBus); - await runtimeTask.geminiClient.initialize( - runtimeTask.config.getContentGeneratorConfig(), - ); + await runtimeTask.geminiClient.initialize(); const wrapper = new TaskWrapper(runtimeTask, agentSettings); this.tasks.set(taskId, wrapper); diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index f8c756992a..4514aac011 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -6,7 +6,7 @@ import { CoreToolScheduler, - GeminiClient, + type GeminiClient, GeminiEventType, ToolConfirmationOutcome, ApprovalMode, @@ -82,7 +82,7 @@ export class Task { this.contextId = contextId; this.config = config; this.scheduler = this.createScheduler(); - this.geminiClient = new GeminiClient(this.config); + this.geminiClient = this.config.getGeminiClient(); this.pendingToolConfirmationDetails = new Map(); this.taskState = 'submitted'; this.eventBus = eventBus; @@ -227,7 +227,7 @@ export class Task { } = { coderAgent: coderAgentMessage, model: this.config.getModel(), - userTier: this.geminiClient.getUserTier(), + userTier: this.config.getUserTier(), }; if (metadataError) { diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index 1c4dea7133..92ef706d3c 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -13,6 +13,7 @@ import { ApprovalMode, DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + GeminiClient, } from '@google/gemini-cli-core'; import type { Config, Storage } from '@google/gemini-cli-core'; import { expect, vi } from 'vitest'; @@ -38,7 +39,6 @@ export function createMockConfig( getTruncateToolOutputThreshold: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, - getGeminiClient: vi.fn(), getDebugMode: vi.fn().mockReturnValue(false), getContentGeneratorConfig: vi.fn().mockReturnValue({ model: 'gemini-pro' }), getModel: vi.fn().mockReturnValue('gemini-pro'), @@ -49,8 +49,13 @@ export function createMockConfig( getHistory: vi.fn().mockReturnValue([]), getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'), getSessionId: vi.fn().mockReturnValue('test-session-id'), + getUserTier: vi.fn(), ...overrides, - }; + } as unknown as Config; + + mockConfig.getGeminiClient = vi + .fn() + .mockReturnValue(new GeminiClient(mockConfig)); return mockConfig; } diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index a006053e39..0444a3d771 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -316,7 +316,7 @@ Logging in with Google... Please restart Gemini CLI to continue. useEffect(() => { // Only sync when not currently authenticating if (authState === AuthState.Authenticated) { - setUserTier(config.getGeminiClient()?.getUserTier()); + setUserTier(config.getUserTier()); } }, [config, authState]); diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 1d164445e9..4cf5bce5ed 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -4,6 +4,27 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { act, renderHook, waitFor } from '@testing-library/react'; +import { vi, describe, it, expect, beforeEach } from 'vitest'; +import { useSlashCommandProcessor } from './slashCommandProcessor.js'; +import type { + CommandContext, + ConfirmShellCommandsActionReturn, + SlashCommand, +} from '../commands/types.js'; +import { CommandKind } from '../commands/types.js'; +import type { LoadedSettings } from '../../config/settings.js'; +import { MessageType } from '../types.js'; +import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; +import { FileCommandLoader } from '../../services/FileCommandLoader.js'; +import { McpPromptLoader } from '../../services/McpPromptLoader.js'; +import { + type GeminiClient, + SlashCommandStatus, + ToolConfirmationOutcome, + makeFakeConfig, +} from '@google/gemini-cli-core'; + const { logSlashCommand } = vi.hoisted(() => ({ logSlashCommand: vi.fn(), })); @@ -68,26 +89,6 @@ vi.mock('../../utils/cleanup.js', () => ({ runExitCleanup: mockRunExitCleanup, })); -import { act, renderHook, waitFor } from '@testing-library/react'; -import { vi, describe, it, expect, beforeEach, type Mock } from 'vitest'; -import { useSlashCommandProcessor } from './slashCommandProcessor.js'; -import type { - CommandContext, - ConfirmShellCommandsActionReturn, - SlashCommand, -} from '../commands/types.js'; -import { CommandKind } from '../commands/types.js'; -import type { LoadedSettings } from '../../config/settings.js'; -import { MessageType } from '../types.js'; -import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; -import { FileCommandLoader } from '../../services/FileCommandLoader.js'; -import { McpPromptLoader } from '../../services/McpPromptLoader.js'; -import { - SlashCommandStatus, - ToolConfirmationOutcome, - makeFakeConfig, -} from '@google/gemini-cli-core'; - function createTestCommand( overrides: Partial, kind: CommandKind = CommandKind.BUILT_IN, @@ -113,7 +114,7 @@ describe('useSlashCommandProcessor', () => { beforeEach(() => { vi.clearAllMocks(); - (vi.mocked(BuiltinCommandLoader) as Mock).mockClear(); + vi.mocked(BuiltinCommandLoader).mockClear(); mockBuiltinLoadCommands.mockResolvedValue([]); mockFileLoadCommands.mockResolvedValue([]); mockMcpLoadCommands.mockResolvedValue([]); @@ -391,6 +392,12 @@ describe('useSlashCommandProcessor', () => { }); it('should handle "load_history" action', async () => { + const mockClient = { + setHistory: vi.fn(), + stripThoughtsFromHistory: vi.fn(), + } as unknown as GeminiClient; + vi.spyOn(mockConfig, 'getGeminiClient').mockReturnValue(mockClient); + const command = createTestCommand({ name: 'load', action: vi.fn().mockResolvedValue({ @@ -414,14 +421,11 @@ describe('useSlashCommandProcessor', () => { }); it('should strip thoughts when handling "load_history" action', async () => { - const mockSetHistory = vi.fn(); - const mockGeminiClient = { - setHistory: mockSetHistory, - }; - vi.spyOn(mockConfig, 'getGeminiClient').mockReturnValue( - // eslint-disable-next-line @typescript-eslint/no-explicit-any - mockGeminiClient as any, - ); + const mockClient = { + setHistory: vi.fn(), + stripThoughtsFromHistory: vi.fn(), + } as unknown as GeminiClient; + vi.spyOn(mockConfig, 'getGeminiClient').mockReturnValue(mockClient); const historyWithThoughts = [ { @@ -445,10 +449,8 @@ describe('useSlashCommandProcessor', () => { await result.current.handleSlashCommand('/loadwiththoughts'); }); - expect(mockSetHistory).toHaveBeenCalledTimes(1); - expect(mockSetHistory).toHaveBeenCalledWith(historyWithThoughts, { - stripThoughts: true, - }); + expect(mockClient.setHistory).toHaveBeenCalledTimes(1); + expect(mockClient.stripThoughtsFromHistory).toHaveBeenCalledWith(); }); it('should handle a "quit" action', async () => { diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index ed92d97cd1..c4dbc8d129 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -401,9 +401,8 @@ export const useSlashCommandProcessor = ( } } case 'load_history': { - config - ?.getGeminiClient() - ?.setHistory(result.clientHistory, { stripThoughts: true }); + config?.getGeminiClient()?.setHistory(result.clientHistory); + config?.getGeminiClient()?.stripThoughtsFromHistory(); fullCommandContext.ui.clear(); result.history.forEach((item, index) => { fullCommandContext.ui.addItem(item, index); diff --git a/packages/core/src/code_assist/codeAssist.ts b/packages/core/src/code_assist/codeAssist.ts index 32f314a51a..c8ade92edd 100644 --- a/packages/core/src/code_assist/codeAssist.ts +++ b/packages/core/src/code_assist/codeAssist.ts @@ -40,7 +40,7 @@ export async function createCodeAssistContentGenerator( export function getCodeAssistServer( config: Config, ): CodeAssistServer | undefined { - let server = config.getGeminiClient().getContentGenerator(); + let server = config.getContentGenerator(); // Unwrap LoggingContentGenerator if present if (server instanceof LoggingContentGenerator) { diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 8f91d5fc48..459d42a40a 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -71,18 +71,12 @@ vi.mock('../tools/memoryTool', () => ({ GEMINI_CONFIG_DIR: '.gemini', })); -vi.mock('../core/contentGenerator.js', async (importOriginal) => { - const actual = - await importOriginal(); - return { - ...actual, - createContentGeneratorConfig: vi.fn(), - }; -}); +vi.mock('../core/contentGenerator.js'); vi.mock('../core/client.js', () => ({ GeminiClient: vi.fn().mockImplementation(() => ({ initialize: vi.fn().mockResolvedValue(undefined), + stripThoughtsFromHistory: vi.fn(), })), })); @@ -196,7 +190,9 @@ describe('Server Config (config.ts)', () => { apiKey: 'test-key', }; - (createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig); + vi.mocked(createContentGeneratorConfig).mockReturnValue( + mockContentConfig, + ); // Set fallback mode to true to ensure it gets reset config.setFallbackMode(true); @@ -217,172 +213,38 @@ describe('Server Config (config.ts)', () => { expect(config.isInFallbackMode()).toBe(false); }); - it('should preserve conversation history when refreshing auth', async () => { - const config = new Config(baseParams); - const authType = AuthType.USE_GEMINI; - const mockContentConfig = { - model: 'gemini-pro', - apiKey: 'test-key', - }; - - (createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig); - - // Mock the existing client with some history - const mockExistingHistory = [ - { role: 'user', parts: [{ text: 'Hello' }] }, - { role: 'model', parts: [{ text: 'Hi there!' }] }, - { role: 'user', parts: [{ text: 'How are you?' }] }, - ]; - - const mockExistingClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue(mockExistingHistory), - }; - - const mockNewClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue([]), - setHistory: vi.fn(), - initialize: vi.fn().mockResolvedValue(undefined), - }; - - // Set the existing client - ( - config as unknown as { geminiClient: typeof mockExistingClient } - ).geminiClient = mockExistingClient; - (GeminiClient as Mock).mockImplementation(() => mockNewClient); - - await config.refreshAuth(authType); - - // Verify that existing history was retrieved - expect(mockExistingClient.getHistory).toHaveBeenCalled(); - - // Verify that new client was created and initialized - expect(GeminiClient).toHaveBeenCalledWith(config); - expect(mockNewClient.initialize).toHaveBeenCalledWith(mockContentConfig); - - // Verify that history was restored to the new client - expect(mockNewClient.setHistory).toHaveBeenCalledWith( - mockExistingHistory, - { stripThoughts: false }, - ); - }); - - it('should handle case when no existing client is initialized', async () => { - const config = new Config(baseParams); - const authType = AuthType.USE_GEMINI; - const mockContentConfig = { - model: 'gemini-pro', - apiKey: 'test-key', - }; - - (createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig); - - const mockNewClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue([]), - setHistory: vi.fn(), - initialize: vi.fn().mockResolvedValue(undefined), - }; - - // No existing client - (config as unknown as { geminiClient: null }).geminiClient = null; - (GeminiClient as Mock).mockImplementation(() => mockNewClient); - - await config.refreshAuth(authType); - - // Verify that new client was created and initialized - expect(GeminiClient).toHaveBeenCalledWith(config); - expect(mockNewClient.initialize).toHaveBeenCalledWith(mockContentConfig); - - // Verify that setHistory was not called since there was no existing history - expect(mockNewClient.setHistory).not.toHaveBeenCalled(); - }); - it('should strip thoughts when switching from GenAI to Vertex', async () => { const config = new Config(baseParams); - const mockContentConfig = { - model: 'gemini-pro', - apiKey: 'test-key', - authType: AuthType.USE_GEMINI, - }; - ( - config as unknown as { contentGeneratorConfig: ContentGeneratorConfig } - ).contentGeneratorConfig = mockContentConfig; - (createContentGeneratorConfig as Mock).mockReturnValue({ - ...mockContentConfig, - authType: AuthType.LOGIN_WITH_GOOGLE, - }); + vi.mocked(createContentGeneratorConfig).mockImplementation( + (_: Config, authType: AuthType | undefined) => + ({ authType }) as unknown as ContentGeneratorConfig, + ); - const mockExistingHistory = [ - { role: 'user', parts: [{ text: 'Hello' }] }, - ]; - const mockExistingClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue(mockExistingHistory), - }; - const mockNewClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue([]), - setHistory: vi.fn(), - initialize: vi.fn().mockResolvedValue(undefined), - }; - - ( - config as unknown as { geminiClient: typeof mockExistingClient } - ).geminiClient = mockExistingClient; - (GeminiClient as Mock).mockImplementation(() => mockNewClient); + await config.refreshAuth(AuthType.USE_GEMINI); await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE); - expect(mockNewClient.setHistory).toHaveBeenCalledWith( - mockExistingHistory, - { stripThoughts: true }, - ); + expect( + config.getGeminiClient().stripThoughtsFromHistory, + ).toHaveBeenCalledWith(); }); it('should not strip thoughts when switching from Vertex to GenAI', async () => { const config = new Config(baseParams); - const mockContentConfig = { - model: 'gemini-pro', - apiKey: 'test-key', - authType: AuthType.LOGIN_WITH_GOOGLE, - }; - ( - config as unknown as { contentGeneratorConfig: ContentGeneratorConfig } - ).contentGeneratorConfig = mockContentConfig; - (createContentGeneratorConfig as Mock).mockReturnValue({ - ...mockContentConfig, - authType: AuthType.USE_GEMINI, - }); + vi.mocked(createContentGeneratorConfig).mockImplementation( + (_: Config, authType: AuthType | undefined) => + ({ authType }) as unknown as ContentGeneratorConfig, + ); - const mockExistingHistory = [ - { role: 'user', parts: [{ text: 'Hello' }] }, - ]; - const mockExistingClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue(mockExistingHistory), - }; - const mockNewClient = { - isInitialized: vi.fn().mockReturnValue(true), - getHistory: vi.fn().mockReturnValue([]), - setHistory: vi.fn(), - initialize: vi.fn().mockResolvedValue(undefined), - }; - - ( - config as unknown as { geminiClient: typeof mockExistingClient } - ).geminiClient = mockExistingClient; - (GeminiClient as Mock).mockImplementation(() => mockNewClient); + await config.refreshAuth(AuthType.USE_VERTEX_AI); await config.refreshAuth(AuthType.USE_GEMINI); - expect(mockNewClient.setHistory).toHaveBeenCalledWith( - mockExistingHistory, - { stripThoughts: false }, - ); + expect( + config.getGeminiClient().stripThoughtsFromHistory, + ).not.toHaveBeenCalledWith(); }); }); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index f7f3e6a177..05dff79c06 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -6,9 +6,13 @@ import * as path from 'node:path'; import process from 'node:process'; -import type { ContentGeneratorConfig } from '../core/contentGenerator.js'; +import type { + ContentGenerator, + ContentGeneratorConfig, +} from '../core/contentGenerator.js'; import { AuthType, + createContentGenerator, createContentGeneratorConfig, } from '../core/contentGenerator.js'; import { PromptRegistry } from '../prompts/prompt-registry.js'; @@ -44,7 +48,6 @@ import { shouldAttemptBrowserLaunch } from '../utils/browser.js'; import type { MCPOAuthConfig } from '../mcp/oauth-provider.js'; import { IdeClient } from '../ide/ide-client.js'; import { ideContext } from '../ide/ideContext.js'; -import type { Content } from '@google/genai'; import type { FileSystemService } from '../services/fileSystemService.js'; import { StandardFileSystemService } from '../services/fileSystemService.js'; import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js'; @@ -57,6 +60,8 @@ import { WorkspaceContext } from '../utils/workspaceContext.js'; import { Storage } from './storage.js'; import { FileExclusions } from '../utils/ignorePatterns.js'; import type { EventEmitter } from 'node:events'; +import type { UserTierId } from '../code_assist/types.js'; +import { ProxyAgent, setGlobalDispatcher } from 'undici'; export enum ApprovalMode { DEFAULT = 'default', @@ -226,6 +231,7 @@ export class Config { private readonly sessionId: string; private fileSystemService: FileSystemService; private contentGeneratorConfig!: ContentGeneratorConfig; + private contentGenerator!: ContentGenerator; private readonly embeddingModel: string; private readonly sandbox: SandboxConfig | undefined; private readonly targetDir: string; @@ -386,6 +392,11 @@ export class Config { if (this.telemetrySettings.enabled) { initializeTelemetry(this); } + + if (this.getProxy()) { + setGlobalDispatcher(new ProxyAgent(this.getProxy() as string)); + } + this.geminiClient = new GeminiClient(this); } /** @@ -410,46 +421,45 @@ export class Config { this.promptRegistry = new PromptRegistry(); this.toolRegistry = await this.createToolRegistry(); logCliConfiguration(this, new StartSessionEvent(this, this.toolRegistry)); + + await this.geminiClient.initialize(); + } + + getContentGenerator(): ContentGenerator { + return this.contentGenerator; } async refreshAuth(authMethod: AuthType) { - // Save the current conversation history before creating a new client - let existingHistory: Content[] = []; - if (this.geminiClient && this.geminiClient.isInitialized()) { - existingHistory = this.geminiClient.getHistory(); + // Vertex and Genai have incompatible encryption and sending history with + // throughtSignature from Genai to Vertex will fail, we need to strip them + if ( + this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI && + authMethod === AuthType.LOGIN_WITH_GOOGLE + ) { + // Restore the conversation history to the new client + this.geminiClient.stripThoughtsFromHistory(); } - // Create new content generator config const newContentGeneratorConfig = createContentGeneratorConfig( this, authMethod, ); - - // Create and initialize new client in local variable first - const newGeminiClient = new GeminiClient(this); - await newGeminiClient.initialize(newContentGeneratorConfig); - - // Vertex and Genai have incompatible encryption and sending history with - // throughtSignature from Genai to Vertex will fail, we need to strip them - const fromGenaiToVertex = - this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI && - authMethod === AuthType.LOGIN_WITH_GOOGLE; - + this.contentGenerator = await createContentGenerator( + newContentGeneratorConfig, + this, + this.getSessionId(), + ); // Only assign to instance properties after successful initialization this.contentGeneratorConfig = newContentGeneratorConfig; - this.geminiClient = newGeminiClient; - - // Restore the conversation history to the new client - if (existingHistory.length > 0) { - this.geminiClient.setHistory(existingHistory, { - stripThoughts: fromGenaiToVertex, - }); - } // Reset the session flag since we're explicitly changing auth and using default model this.inFallbackMode = false; } + getUserTier(): UserTierId | undefined { + return this.contentGenerator?.userTier; + } + getSessionId(): string { return this.sessionId; } diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 00d3b3e6d1..fe110cca4a 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -4,24 +4,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - describe, - it, - expect, - vi, - beforeEach, - afterEach, - type Mocked, -} from 'vitest'; +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import type { - Chat, - Content, - EmbedContentResponse, - GenerateContentResponse, - Part, -} from '@google/genai'; -import { GoogleGenAI } from '@google/genai'; +import type { Content, GenerateContentResponse, Part } from '@google/genai'; import { findIndexAfterFraction, isThinkingDefault, @@ -34,7 +19,7 @@ import { type ContentGeneratorConfig, } from './contentGenerator.js'; import { type GeminiChat } from './geminiChat.js'; -import { Config } from '../config/config.js'; +import type { Config } from '../config/config.js'; import { CompressionStatus, GeminiEventType, @@ -76,12 +61,8 @@ vi.mock('node:fs', () => { }); // --- Mocks --- -const mockChatCreateFn = vi.fn(); -const mockGenerateContentFn = vi.fn(); -const mockEmbedContentFn = vi.fn(); const mockTurnRunFn = vi.fn(); -vi.mock('@google/genai'); vi.mock('./turn', async (importOriginal) => { const actual = await importOriginal(); // Define a mock class that has the same shape as the real Turn @@ -228,6 +209,8 @@ describe('isThinkingDefault', () => { }); describe('Gemini Client (client.ts)', () => { + let mockContentGenerator: ContentGenerator; + let mockConfig: Config; let client: GeminiClient; beforeEach(async () => { vi.resetAllMocks(); @@ -235,29 +218,15 @@ describe('Gemini Client (client.ts)', () => { // Disable 429 simulation for tests setSimulate429(false); - // Set up the mock for GoogleGenAI constructor and its methods - const MockedGoogleGenAI = vi.mocked(GoogleGenAI); - MockedGoogleGenAI.mockImplementation(() => { - const mock = { - chats: { create: mockChatCreateFn }, - models: { - generateContent: mockGenerateContentFn, - embedContent: mockEmbedContentFn, - }, - }; - return mock as unknown as GoogleGenAI; - }); - - mockChatCreateFn.mockResolvedValue({} as Chat); - mockGenerateContentFn.mockResolvedValue({ - candidates: [ - { - content: { - parts: [{ text: '{"key": "value"}' }], - }, - }, - ], - } as unknown as GenerateContentResponse); + mockContentGenerator = { + generateContent: vi.fn().mockResolvedValue({ + candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }], + }), + generateContentStream: vi.fn(), + countTokens: vi.fn(), + embedContent: vi.fn(), + batchEmbedContents: vi.fn(), + } as unknown as ContentGenerator; // Because the GeminiClient constructor kicks off an async process (startChat) // that depends on a fully-formed Config object, we need to mock the @@ -273,7 +242,7 @@ describe('Gemini Client (client.ts)', () => { vertexai: false, authType: AuthType.USE_GEMINI, }; - const mockConfigObject = { + mockConfig = { getContentGeneratorConfig: vi .fn() .mockReturnValue(contentGeneratorConfig), @@ -309,46 +278,18 @@ describe('Gemini Client (client.ts)', () => { storage: { getProjectTempDir: vi.fn().mockReturnValue('/test/temp'), }, - }; - const MockedConfig = vi.mocked(Config, true); - MockedConfig.mockImplementation( - () => mockConfigObject as unknown as Config, - ); + getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator), + } as unknown as Config; - // We can instantiate the client here since Config is mocked - // and the constructor will use the mocked GoogleGenAI - client = new GeminiClient( - new Config({ sessionId: 'test-session-id' } as never), - ); - mockConfigObject.getGeminiClient.mockReturnValue(client); - - await client.initialize(contentGeneratorConfig); + client = new GeminiClient(mockConfig); + await client.initialize(); + vi.mocked(mockConfig.getGeminiClient).mockReturnValue(client); }); afterEach(() => { vi.restoreAllMocks(); }); - // NOTE: The following tests for startChat were removed due to persistent issues with - // the @google/genai mock. Specifically, the mockChatCreateFn (representing instance.chats.create) - // was not being detected as called by the GeminiClient instance. - // This likely points to a subtle issue in how the GoogleGenerativeAI class constructor - // and its instance methods are mocked and then used by the class under test. - // For future debugging, ensure that the `this.client` in `GeminiClient` (which is an - // instance of the mocked GoogleGenerativeAI) correctly has its `chats.create` method - // pointing to `mockChatCreateFn`. - // it('startChat should call getCoreSystemPrompt with userMemory and pass to chats.create', async () => { ... }); - // it('startChat should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... }); - - // NOTE: The following tests for generateJson were removed due to persistent issues with - // the @google/genai mock, similar to the startChat tests. The mockGenerateContentFn - // (representing instance.models.generateContent) was not being detected as called, or the mock - // was not preventing an actual API call (leading to API key errors). - // For future debugging, ensure `this.client.models.generateContent` in `GeminiClient` correctly - // uses the `mockGenerateContentFn`. - // it('generateJson should call getCoreSystemPrompt with userMemory and pass to generateContent', async () => { ... }); - // it('generateJson should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... }); - describe('generateEmbedding', () => { const texts = ['hello world', 'goodbye world']; const testEmbeddingModel = 'test-embedding-model'; @@ -358,18 +299,17 @@ describe('Gemini Client (client.ts)', () => { [0.1, 0.2, 0.3], [0.4, 0.5, 0.6], ]; - const mockResponse: EmbedContentResponse = { + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ embeddings: [ { values: mockEmbeddings[0] }, { values: mockEmbeddings[1] }, ], - }; - mockEmbedContentFn.mockResolvedValue(mockResponse); + }); const result = await client.generateEmbedding(texts); - expect(mockEmbedContentFn).toHaveBeenCalledTimes(1); - expect(mockEmbedContentFn).toHaveBeenCalledWith({ + expect(mockContentGenerator.embedContent).toHaveBeenCalledTimes(1); + expect(mockContentGenerator.embedContent).toHaveBeenCalledWith({ model: testEmbeddingModel, contents: texts, }); @@ -379,11 +319,11 @@ describe('Gemini Client (client.ts)', () => { it('should return an empty array if an empty array is passed', async () => { const result = await client.generateEmbedding([]); expect(result).toEqual([]); - expect(mockEmbedContentFn).not.toHaveBeenCalled(); + expect(mockContentGenerator.embedContent).not.toHaveBeenCalled(); }); it('should throw an error if API response has no embeddings array', async () => { - mockEmbedContentFn.mockResolvedValue({} as EmbedContentResponse); // No `embeddings` key + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({}); await expect(client.generateEmbedding(texts)).rejects.toThrow( 'No embeddings found in API response.', @@ -391,20 +331,19 @@ describe('Gemini Client (client.ts)', () => { }); it('should throw an error if API response has an empty embeddings array', async () => { - const mockResponse: EmbedContentResponse = { + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ embeddings: [], - }; - mockEmbedContentFn.mockResolvedValue(mockResponse); + }); + await expect(client.generateEmbedding(texts)).rejects.toThrow( 'No embeddings found in API response.', ); }); it('should throw an error if API returns a mismatched number of embeddings', async () => { - const mockResponse: EmbedContentResponse = { + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ embeddings: [{ values: [1, 2, 3] }], // Only one for two texts - }; - mockEmbedContentFn.mockResolvedValue(mockResponse); + }); await expect(client.generateEmbedding(texts)).rejects.toThrow( 'API returned a mismatched number of embeddings. Expected 2, got 1.', @@ -412,10 +351,9 @@ describe('Gemini Client (client.ts)', () => { }); it('should throw an error if any embedding has nullish values', async () => { - const mockResponse: EmbedContentResponse = { + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad - }; - mockEmbedContentFn.mockResolvedValue(mockResponse); + }); await expect(client.generateEmbedding(texts)).rejects.toThrow( 'API returned an empty embedding for input text at index 1: "goodbye world"', @@ -423,10 +361,9 @@ describe('Gemini Client (client.ts)', () => { }); it('should throw an error if any embedding has an empty values array', async () => { - const mockResponse: EmbedContentResponse = { + vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({ embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad - }; - mockEmbedContentFn.mockResolvedValue(mockResponse); + }); await expect(client.generateEmbedding(texts)).rejects.toThrow( 'API returned an empty embedding for input text at index 0: "hello world"', @@ -434,8 +371,9 @@ describe('Gemini Client (client.ts)', () => { }); it('should propagate errors from the API call', async () => { - const apiError = new Error('API Failure'); - mockEmbedContentFn.mockRejectedValue(apiError); + vi.mocked(mockContentGenerator.embedContent).mockRejectedValue( + new Error('API Failure'), + ); await expect(client.generateEmbedding(texts)).rejects.toThrow( 'API Failure', @@ -449,12 +387,9 @@ describe('Gemini Client (client.ts)', () => { const schema = { type: 'string' }; const abortSignal = new AbortController().signal; - // Mock countTokens - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ + totalTokens: 1, + }); await client.generateJson( contents, @@ -463,7 +398,7 @@ describe('Gemini Client (client.ts)', () => { DEFAULT_GEMINI_FLASH_MODEL, ); - expect(mockGenerateContentFn).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( { model: DEFAULT_GEMINI_FLASH_MODEL, config: { @@ -489,11 +424,9 @@ describe('Gemini Client (client.ts)', () => { const customModel = 'custom-json-model'; const customConfig = { temperature: 0.9, topK: 20 }; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ + totalTokens: 1, + }); await client.generateJson( contents, @@ -503,7 +436,7 @@ describe('Gemini Client (client.ts)', () => { customConfig, ); - expect(mockGenerateContentFn).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( { model: customModel, config: { @@ -524,10 +457,10 @@ describe('Gemini Client (client.ts)', () => { describe('addHistory', () => { it('should call chat.addHistory with the provided content', async () => { - const mockChat: Partial = { + const mockChat = { addHistory: vi.fn(), - }; - client['chat'] = mockChat as GeminiChat; + } as unknown as GeminiChat; + client['chat'] = mockChat; const newContent = { role: 'user', @@ -568,7 +501,6 @@ describe('Gemini Client (client.ts)', () => { }); describe('tryCompressChat', () => { - const mockCountTokens = vi.fn(); const mockSendMessage = vi.fn(); const mockGetHistory = vi.fn(); @@ -577,10 +509,6 @@ describe('Gemini Client (client.ts)', () => { tokenLimit: vi.fn(), })); - client['contentGenerator'] = { - countTokens: mockCountTokens, - } as unknown as ContentGenerator; - client['chat'] = { getHistory: mockGetHistory, addHistory: vi.fn(), @@ -600,27 +528,21 @@ describe('Gemini Client (client.ts)', () => { setHistory: vi.fn(), sendMessage: vi.fn().mockResolvedValue({ text: 'Summary' }), }; - const mockCountTokens = vi - .fn() + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: 1000 }) .mockResolvedValueOnce({ totalTokens: 5000 }); - const mockGenerator: Partial> = { - countTokens: mockCountTokens, - }; - client['chat'] = mockChat as GeminiChat; - client['contentGenerator'] = mockGenerator as ContentGenerator; client['startChat'] = vi.fn().mockResolvedValue({ ...mockChat }); - return { client, mockChat, mockGenerator }; + return { client, mockChat }; } describe('when compression inflates the token count', () => { - it('uses the truncated history for compression'); it('allows compression to be forced/manual after a failure', async () => { - const { client, mockGenerator } = setup(); - mockGenerator.countTokens?.mockResolvedValue({ + const { client } = setup(); + + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ totalTokens: 1000, }); await client.tryCompressChat('prompt-id-4'); // Fails @@ -635,6 +557,9 @@ describe('Gemini Client (client.ts)', () => { it('yields the result even if the compression inflated the tokens', async () => { const { client } = setup(); + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ + totalTokens: 1000, + }); const result = await client.tryCompressChat('prompt-id-4', true); expect(result).toEqual({ @@ -654,7 +579,7 @@ describe('Gemini Client (client.ts)', () => { it('restores the history back to the original', async () => { vi.mocked(tokenLimit).mockReturnValue(1000); - mockCountTokens.mockResolvedValue({ + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ totalTokens: 999, }); @@ -679,13 +604,13 @@ describe('Gemini Client (client.ts)', () => { }); it('will not attempt to compress context after a failure', async () => { - const { client, mockGenerator } = setup(); + const { client } = setup(); await client.tryCompressChat('prompt-id-4'); const result = await client.tryCompressChat('prompt-id-5'); // it counts tokens for {original, compressed} and then never again - expect(mockGenerator.countTokens).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2); expect(result).toEqual({ compressionStatus: CompressionStatus.NOOP, newTokenCount: 0, @@ -696,7 +621,7 @@ describe('Gemini Client (client.ts)', () => { it('attempts to compress with a maxOutputTokens set to the original token count', async () => { vi.mocked(tokenLimit).mockReturnValue(1000); - mockCountTokens.mockResolvedValue({ + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ totalTokens: 999, }); @@ -728,8 +653,7 @@ describe('Gemini Client (client.ts)', () => { mockGetHistory.mockReturnValue([ { role: 'user', parts: [{ text: '...history...' }] }, ]); - - mockCountTokens.mockResolvedValue({ + vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7 }); @@ -763,7 +687,7 @@ describe('Gemini Client (client.ts)', () => { MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD; const newTokenCount = 100; - mockCountTokens + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history @@ -800,7 +724,7 @@ describe('Gemini Client (client.ts)', () => { MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD; const newTokenCount = 100; - mockCountTokens + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history @@ -853,7 +777,7 @@ describe('Gemini Client (client.ts)', () => { const originalTokenCount = 1000 * 0.7; const newTokenCount = 100; - mockCountTokens + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history @@ -895,7 +819,7 @@ describe('Gemini Client (client.ts)', () => { const originalTokenCount = 10; // Well below threshold const newTokenCount = 5; - mockCountTokens + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: originalTokenCount }) .mockResolvedValueOnce({ totalTokens: newTokenCount }); @@ -922,10 +846,16 @@ describe('Gemini Client (client.ts)', () => { }); it('should use current model from config for token counting after sendMessage', async () => { - const initialModel = client['config'].getModel(); + const initialModel = mockConfig.getModel(); - const mockCountTokens = vi - .fn() + // mock the model has been changed between calls of `countTokens` + const firstCurrentModel = initialModel + '-changed-1'; + const secondCurrentModel = initialModel + '-changed-2'; + vi.mocked(mockConfig.getModel) + .mockReturnValueOnce(firstCurrentModel) + .mockReturnValueOnce(secondCurrentModel); + + vi.mocked(mockContentGenerator.countTokens) .mockResolvedValueOnce({ totalTokens: 100000 }) .mockResolvedValueOnce({ totalTokens: 5000 }); @@ -936,35 +866,23 @@ describe('Gemini Client (client.ts)', () => { { role: 'model', parts: [{ text: 'Long response' }] }, ]; - const mockChat: Partial = { + const mockChat = { getHistory: vi.fn().mockReturnValue(mockChatHistory), setHistory: vi.fn(), sendMessage: mockSendMessage, - }; + } as unknown as GeminiChat; - const mockGenerator: Partial = { - countTokens: mockCountTokens, - }; - - // mock the model has been changed between calls of `countTokens` - const firstCurrentModel = initialModel + '-changed-1'; - const secondCurrentModel = initialModel + '-changed-2'; - vi.spyOn(client['config'], 'getModel') - .mockReturnValueOnce(firstCurrentModel) - .mockReturnValueOnce(secondCurrentModel); - - client['chat'] = mockChat as GeminiChat; - client['contentGenerator'] = mockGenerator as ContentGenerator; + client['chat'] = mockChat; client['startChat'] = vi.fn().mockResolvedValue(mockChat); const result = await client.tryCompressChat('prompt-id-4', true); - expect(mockCountTokens).toHaveBeenCalledTimes(2); - expect(mockCountTokens).toHaveBeenNthCalledWith(1, { + expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.countTokens).toHaveBeenNthCalledWith(1, { model: firstCurrentModel, contents: mockChatHistory, }); - expect(mockCountTokens).toHaveBeenNthCalledWith(2, { + expect(mockContentGenerator.countTokens).toHaveBeenNthCalledWith(2, { model: secondCurrentModel, contents: expect.any(Array), }); @@ -980,22 +898,11 @@ describe('Gemini Client (client.ts)', () => { describe('sendMessageStream', () => { it('emits a compression event when the context was automatically compressed', async () => { // Arrange - const mockStream = (async function* () { - yield { type: 'content', value: 'Hello' }; - })(); - mockTurnRunFn.mockReturnValue(mockStream); - - const mockChat: Partial = { - addHistory: vi.fn(), - getHistory: vi.fn().mockReturnValue([]), - }; - client['chat'] = mockChat as GeminiChat; - - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; + mockTurnRunFn.mockReturnValue( + (async function* () { + yield { type: 'content', value: 'Hello' }; + })(), + ); const compressionInfo: ChatCompressionInfo = { compressionStatus: CompressionStatus.COMPRESSED, @@ -1042,18 +949,6 @@ describe('Gemini Client (client.ts)', () => { })(); mockTurnRunFn.mockReturnValue(mockStream); - const mockChat: Partial = { - addHistory: vi.fn(), - getHistory: vi.fn().mockReturnValue([]), - }; - client['chat'] = mockChat as GeminiChat; - - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - const compressionInfo: ChatCompressionInfo = { compressionStatus, originalTokenCount: 1000, @@ -1105,24 +1000,19 @@ describe('Gemini Client (client.ts)', () => { }, }); - vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true); + vi.mocked(mockConfig.getIdeMode).mockReturnValue(true); - const mockStream = (async function* () { - yield { type: 'content', value: 'Hello' }; - })(); - mockTurnRunFn.mockReturnValue(mockStream); + mockTurnRunFn.mockReturnValue( + (async function* () { + yield { type: 'content', value: 'Hello' }; + })(), + ); - const mockChat: Partial = { + const mockChat = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), - }; - client['chat'] = mockChat as GeminiChat; - - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; + } as unknown as GeminiChat; + client['chat'] = mockChat; const initialRequest: Part[] = [{ text: 'Hi' }]; @@ -1186,12 +1076,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - const initialRequest = [{ text: 'Hi' }]; // Act @@ -1241,12 +1125,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - const initialRequest = [{ text: 'Hi' }]; // Act @@ -1317,12 +1195,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - const initialRequest = [{ text: 'Hi' }]; // Act @@ -1369,12 +1241,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Act const stream = client.sendMessageStream( [{ text: 'Hi' }], @@ -1419,12 +1285,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Use a signal that never gets aborted const abortController = new AbortController(); const signal = abortController.signal; @@ -1512,12 +1372,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Act & Assert // Run up to the limit for (let i = 0; i < MAX_SESSION_TURNS; i++) { @@ -1574,12 +1428,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Use a signal that never gets aborted const abortController = new AbortController(); const signal = abortController.signal; @@ -1659,12 +1507,6 @@ ${JSON.stringify( ]), }; client['chat'] = mockChat as GeminiChat; - - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; }); const testCases = [ @@ -1921,11 +1763,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true); vi.mocked(ideContext.getIdeContext).mockReturnValue({ workspaceState: { @@ -2266,12 +2103,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Act const stream = client.sendMessageStream( [{ text: 'Hi' }], @@ -2308,12 +2139,6 @@ ${JSON.stringify( }; client['chat'] = mockChat as GeminiChat; - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - // Act const stream = client.sendMessageStream( [{ text: 'Hi' }], @@ -2335,13 +2160,6 @@ ${JSON.stringify( const generationConfig = { temperature: 0.5 }; const abortSignal = new AbortController().signal; - // Mock countTokens - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - await client.generateContent( contents, generationConfig, @@ -2349,7 +2167,7 @@ ${JSON.stringify( DEFAULT_GEMINI_FLASH_MODEL, ); - expect(mockGenerateContentFn).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( { model: DEFAULT_GEMINI_FLASH_MODEL, config: { @@ -2371,12 +2189,6 @@ ${JSON.stringify( vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel); - const mockGenerator: Partial = { - countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), - generateContent: mockGenerateContentFn, - }; - client['contentGenerator'] = mockGenerator as ContentGenerator; - await client.generateContent( contents, {}, @@ -2384,12 +2196,12 @@ ${JSON.stringify( DEFAULT_GEMINI_FLASH_MODEL, ); - expect(mockGenerateContentFn).not.toHaveBeenCalledWith({ + expect(mockContentGenerator.generateContent).not.toHaveBeenCalledWith({ model: initialModel, config: expect.any(Object), contents, }); - expect(mockGenerateContentFn).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( { model: DEFAULT_GEMINI_FLASH_MODEL, config: expect.any(Object), @@ -2427,73 +2239,4 @@ ${JSON.stringify( ); }); }); - - describe('setHistory', () => { - it('should strip thought signatures when stripThoughts is true', () => { - const mockChat = { - setHistory: vi.fn(), - }; - client['chat'] = mockChat as unknown as GeminiChat; - - const historyWithThoughts: Content[] = [ - { - role: 'user', - parts: [{ text: 'hello' }], - }, - { - role: 'model', - parts: [ - { text: 'thinking...', thoughtSignature: 'thought-123' }, - { - functionCall: { name: 'test', args: {} }, - thoughtSignature: 'thought-456', - }, - ], - }, - ]; - - client.setHistory(historyWithThoughts, { stripThoughts: true }); - - const expectedHistory: Content[] = [ - { - role: 'user', - parts: [{ text: 'hello' }], - }, - { - role: 'model', - parts: [ - { text: 'thinking...' }, - { functionCall: { name: 'test', args: {} } }, - ], - }, - ]; - - expect(mockChat.setHistory).toHaveBeenCalledWith(expectedHistory); - }); - - it('should not strip thought signatures when stripThoughts is false', () => { - const mockChat = { - setHistory: vi.fn(), - }; - client['chat'] = mockChat as unknown as GeminiChat; - - const historyWithThoughts: Content[] = [ - { - role: 'user', - parts: [{ text: 'hello' }], - }, - { - role: 'model', - parts: [ - { text: 'thinking...', thoughtSignature: 'thought-123' }, - { text: 'ok', thoughtSignature: 'thought-456' }, - ], - }, - ]; - - client.setHistory(historyWithThoughts, { stripThoughts: false }); - - expect(mockChat.setHistory).toHaveBeenCalledWith(historyWithThoughts); - }); - }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index d00504dc5b..01c52c4c92 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -20,7 +20,6 @@ import type { ServerGeminiStreamEvent, ChatCompressionInfo } from './turn.js'; import { CompressionStatus } from './turn.js'; import { Turn, GeminiEventType } from './turn.js'; import type { Config } from '../config/config.js'; -import type { UserTierId } from '../code_assist/types.js'; import { getCoreSystemPrompt, getCompressionPrompt } from './prompts.js'; import { getResponseText } from '../utils/partUtils.js'; import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js'; @@ -31,12 +30,8 @@ import { getErrorMessage } from '../utils/errors.js'; import { isFunctionResponse } from '../utils/messageInspectors.js'; import { tokenLimit } from './tokenLimits.js'; import type { ChatRecordingService } from '../services/chatRecordingService.js'; -import type { - ContentGenerator, - ContentGeneratorConfig, -} from './contentGenerator.js'; -import { AuthType, createContentGenerator } from './contentGenerator.js'; -import { ProxyAgent, setGlobalDispatcher } from 'undici'; +import type { ContentGenerator } from './contentGenerator.js'; +import { AuthType } from './contentGenerator.js'; import { DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_THINKING_MODE, @@ -115,8 +110,6 @@ const COMPRESSION_PRESERVE_THRESHOLD = 0.3; export class GeminiClient { private chat?: GeminiChat; - private contentGenerator?: ContentGenerator; - private readonly embeddingModel: string; private readonly generateContentConfig: GenerateContentConfig = { temperature: 0, topP: 1, @@ -135,33 +128,19 @@ export class GeminiClient { private hasFailedCompressionAttempt = false; constructor(private readonly config: Config) { - if (config.getProxy()) { - setGlobalDispatcher(new ProxyAgent(config.getProxy() as string)); - } - - this.embeddingModel = config.getEmbeddingModel(); this.loopDetector = new LoopDetectionService(config); this.lastPromptId = this.config.getSessionId(); } - async initialize(contentGeneratorConfig: ContentGeneratorConfig) { - this.contentGenerator = await createContentGenerator( - contentGeneratorConfig, - this.config, - this.config.getSessionId(), - ); + async initialize() { this.chat = await this.startChat(); } - getContentGenerator(): ContentGenerator { - if (!this.contentGenerator) { + private getContentGeneratorOrFail(): ContentGenerator { + if (!this.config.getContentGenerator()) { throw new Error('Content generator not initialized'); } - return this.contentGenerator; - } - - getUserTier(): UserTierId | undefined { - return this.contentGenerator?.userTier; + return this.config.getContentGenerator(); } async addHistory(content: Content) { @@ -176,39 +155,19 @@ export class GeminiClient { } isInitialized(): boolean { - return this.chat !== undefined && this.contentGenerator !== undefined; + return this.chat !== undefined; } getHistory(): Content[] { return this.getChat().getHistory(); } - setHistory( - history: Content[], - { stripThoughts = false }: { stripThoughts?: boolean } = {}, - ) { - const historyToSet = stripThoughts - ? history.map((content) => { - const newContent = { ...content }; - if (newContent.parts) { - newContent.parts = newContent.parts.map((part) => { - if ( - part && - typeof part === 'object' && - 'thoughtSignature' in part - ) { - const newPart = { ...part }; - delete (newPart as { thoughtSignature?: string }) - .thoughtSignature; - return newPart; - } - return part; - }); - } - return newContent; - }) - : history; - this.getChat().setHistory(historyToSet); + stripThoughtsFromHistory() { + this.getChat().stripThoughtsFromHistory(); + } + + setHistory(history: Content[]) { + this.getChat().setHistory(history); this.forceFullIdeContext = true; } @@ -242,9 +201,11 @@ export class GeminiClient { this.forceFullIdeContext = true; this.hasFailedCompressionAttempt = false; const envParts = await getEnvironmentContext(this.config); + const toolRegistry = this.config.getToolRegistry(); const toolDeclarations = toolRegistry.getFunctionDeclarations(); const tools: Tool[] = [{ functionDeclarations: toolDeclarations }]; + const history: Content[] = [ { role: 'user', @@ -274,7 +235,6 @@ export class GeminiClient { : this.generateContentConfig; return new GeminiChat( this.config, - this.getContentGenerator(), { systemInstruction, ...generateContentConfigWithThinking, @@ -600,7 +560,7 @@ export class GeminiClient { }; const apiCall = () => - this.getContentGenerator().generateContent( + this.getContentGeneratorOrFail().generateContent( { model, config: { @@ -711,7 +671,7 @@ export class GeminiClient { }; const apiCall = () => - this.getContentGenerator().generateContent( + this.getContentGeneratorOrFail().generateContent( { model, config: requestConfig, @@ -751,12 +711,12 @@ export class GeminiClient { return []; } const embedModelParams: EmbedContentParameters = { - model: this.embeddingModel, + model: this.config.getEmbeddingModel(), contents: texts, }; const embedContentResponse = - await this.getContentGenerator().embedContent(embedModelParams); + await this.getContentGeneratorOrFail().embedContent(embedModelParams); if ( !embedContentResponse.embeddings || embedContentResponse.embeddings.length === 0 @@ -802,7 +762,7 @@ export class GeminiClient { const model = this.config.getModel(); const { totalTokens: originalTokenCount } = - await this.getContentGenerator().countTokens({ + await this.getContentGeneratorOrFail().countTokens({ model, contents: curatedHistory, }); @@ -877,7 +837,7 @@ export class GeminiClient { this.forceFullIdeContext = true; const { totalTokens: newTokenCount } = - await this.getContentGenerator().countTokens({ + await this.getContentGeneratorOrFail().countTokens({ // model might change after calling `sendMessage`, so we get the newest value from config model: this.config.getModel(), contents: chat.getHistory(), diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index c7660441fd..8305b1f359 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -7,11 +7,11 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import type { Content, - Models, GenerateContentConfig, Part, GenerateContentResponse, } from '@google/genai'; +import type { ContentGenerator } from '../core/contentGenerator.js'; import { GeminiChat, EmptyStreamError, @@ -47,15 +47,6 @@ vi.mock('node:fs', () => { }; }); -// Mocks -const mockModelsModule = { - generateContent: vi.fn(), - generateContentStream: vi.fn(), - countTokens: vi.fn(), - embedContent: vi.fn(), - batchEmbedContents: vi.fn(), -} as unknown as Models; - const { mockLogInvalidChunk, mockLogContentRetry, mockLogContentRetryFailure } = vi.hoisted(() => ({ mockLogInvalidChunk: vi.fn(), @@ -70,12 +61,21 @@ vi.mock('../telemetry/loggers.js', () => ({ })); describe('GeminiChat', () => { + let mockContentGenerator: ContentGenerator; let chat: GeminiChat; let mockConfig: Config; const config: GenerateContentConfig = {}; beforeEach(() => { vi.clearAllMocks(); + mockContentGenerator = { + generateContent: vi.fn(), + generateContentStream: vi.fn(), + countTokens: vi.fn(), + embedContent: vi.fn(), + batchEmbedContents: vi.fn(), + } as unknown as ContentGenerator; + mockConfig = { getSessionId: () => 'test-session-id', getTelemetryLogPromptsEnabled: () => true, @@ -97,12 +97,13 @@ describe('GeminiChat', () => { getToolRegistry: vi.fn().mockReturnValue({ getTool: vi.fn(), }), + getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator), } as unknown as Config; // Disable 429 simulation for tests setSimulate429(false); // Reset history for each test by creating a new instance - chat = new GeminiChat(mockConfig, mockModelsModule, config, []); + chat = new GeminiChat(mockConfig, config, []); }); afterEach(() => { @@ -160,7 +161,7 @@ describe('GeminiChat', () => { ], } as unknown as GenerateContentResponse; - vi.mocked(mockModelsModule.generateContent).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContent).mockResolvedValue( mockAfcResponse, ); @@ -218,7 +219,7 @@ describe('GeminiChat', () => { { content: { role: 'model', parts: [{ text: 'some response' }] } }, ], } as unknown as GenerateContentResponse; - vi.mocked(mockModelsModule.generateContent).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContent).mockResolvedValue( mockResponse, ); @@ -247,7 +248,7 @@ describe('GeminiChat', () => { ], } as unknown as GenerateContentResponse; - vi.mocked(mockModelsModule.generateContent).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContent).mockResolvedValue( mixedContentResponse, ); @@ -302,7 +303,7 @@ describe('GeminiChat', () => { { content: { role: 'model', parts: [{ thought: true }] } }, ], } as unknown as GenerateContentResponse; - vi.mocked(mockModelsModule.generateContent).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContent).mockResolvedValue( emptyModelResponse, ); @@ -348,11 +349,13 @@ describe('GeminiChat', () => { ], text: () => 'response', } as unknown as GenerateContentResponse; - vi.mocked(mockModelsModule.generateContent).mockResolvedValue(response); + vi.mocked(mockContentGenerator.generateContent).mockResolvedValue( + response, + ); await chat.sendMessage({ message: 'hello' }, 'prompt-id-1'); - expect(mockModelsModule.generateContent).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'hello' }] }], @@ -390,7 +393,7 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( streamWithToolCall, ); @@ -442,7 +445,7 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( streamWithNoFinish, ); @@ -487,7 +490,7 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( streamWithInvalidEnd, ); @@ -543,7 +546,7 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( multiChunkStream, ); @@ -593,7 +596,7 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( multiChunkStream, ); @@ -650,7 +653,7 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( multiChunkStream, ); @@ -697,7 +700,7 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( mixedContentStream, ); @@ -759,7 +762,7 @@ describe('GeminiChat', () => { ], } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( emptyStreamResponse, ); @@ -811,7 +814,7 @@ describe('GeminiChat', () => { text: () => 'response', } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( response, ); @@ -823,7 +826,7 @@ describe('GeminiChat', () => { // consume stream to trigger internal logic } - expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith( { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'hello' }] }], @@ -1012,7 +1015,7 @@ describe('GeminiChat', () => { describe('sendMessageStream with retries', () => { it('should yield a RETRY event when an invalid stream is encountered', async () => { // ARRANGE: Mock the stream to fail once, then succeed. - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce(async () => // First attempt: An invalid stream with an empty text part. (async function* () { @@ -1053,7 +1056,7 @@ describe('GeminiChat', () => { }); it('should retry on invalid content, succeed, and report metrics', async () => { // Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt. - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce(async () => // First call returns an invalid stream (async function* () { @@ -1089,7 +1092,9 @@ describe('GeminiChat', () => { expect(mockLogInvalidChunk).toHaveBeenCalledTimes(1); expect(mockLogContentRetry).toHaveBeenCalledTimes(1); expect(mockLogContentRetryFailure).not.toHaveBeenCalled(); - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes( + 2, + ); // Check for a retry event expect(chunks.some((c) => c.type === StreamEventType.RETRY)).toBe(true); @@ -1118,7 +1123,7 @@ describe('GeminiChat', () => { }); it('should fail after all retries on persistent invalid content and report metrics', async () => { - vi.mocked(mockModelsModule.generateContentStream).mockImplementation( + vi.mocked(mockContentGenerator.generateContentStream).mockImplementation( async () => (async function* () { yield { @@ -1150,7 +1155,9 @@ describe('GeminiChat', () => { ); // Should be called 3 times (initial + 2 retries) - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(3); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes( + 3, + ); expect(mockLogInvalidChunk).toHaveBeenCalledTimes(3); expect(mockLogContentRetry).toHaveBeenCalledTimes(2); expect(mockLogContentRetryFailure).toHaveBeenCalledTimes(1); @@ -1169,7 +1176,7 @@ describe('GeminiChat', () => { chat.setHistory(initialHistory); // 2. Mock the API to fail once with an empty stream, then succeed. - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce(async () => (async function* () { yield { @@ -1264,7 +1271,7 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; // 2. Mock the API to return our controllable promises in order - vi.mocked(mockModelsModule.generateContent) + vi.mocked(mockContentGenerator.generateContent) .mockReturnValueOnce(firstCallPromise) .mockReturnValueOnce(secondCallPromise); @@ -1285,8 +1292,8 @@ describe('GeminiChat', () => { // 5. CRUCIAL CHECK: At this point, only the first API call should have been made. // The second call should be waiting on `sendPromise`. - expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(1); - expect(mockModelsModule.generateContent).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1); + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( expect.objectContaining({ contents: expect.arrayContaining([ expect.objectContaining({ parts: [{ text: 'first' }] }), @@ -1303,8 +1310,8 @@ describe('GeminiChat', () => { await new Promise(process.nextTick); // 7. CRUCIAL CHECK: Now, the second API call should have been made. - expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(2); - expect(mockModelsModule.generateContent).toHaveBeenCalledWith( + expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( expect.objectContaining({ contents: expect.arrayContaining([ expect.objectContaining({ parts: [{ text: 'second' }] }), @@ -1320,7 +1327,7 @@ describe('GeminiChat', () => { }); it('should retry if the model returns a completely empty stream (no chunks)', async () => { // 1. Mock the API to return an empty stream first, then a valid one. - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce( // First call resolves to an async generator that yields nothing. async () => (async function* () {})(), @@ -1353,7 +1360,7 @@ describe('GeminiChat', () => { } // 3. Assert the results. - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2); expect( chunks.some( (c) => @@ -1417,7 +1424,7 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; })(); - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockResolvedValueOnce(firstStreamGenerator) .mockResolvedValueOnce(secondStreamGenerator); @@ -1436,7 +1443,7 @@ describe('GeminiChat', () => { ); // 5. Assert that only one API call has been made so far. - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(1); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(1); // 6. Unblock and fully consume the first stream to completion. continueFirstStream!(); @@ -1451,7 +1458,7 @@ describe('GeminiChat', () => { await secondStreamIterator.next(); // 9. The second API call should now have been made. - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2); // 10. FIX: Fully consume the second stream to ensure recordHistory is called. await secondStreamIterator.next(); // This finishes the iterator. @@ -1471,7 +1478,7 @@ describe('GeminiChat', () => { it('should discard valid partial content from a failed attempt upon retry', async () => { // ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content. - vi.mocked(mockModelsModule.generateContentStream) + vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce(async () => // First attempt: yields one valid chunk, then one invalid chunk (async function* () { @@ -1517,7 +1524,7 @@ describe('GeminiChat', () => { // ASSERT // Check that a retry happened - expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2); expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true); // Check the final recorded history @@ -1532,4 +1539,41 @@ describe('GeminiChat', () => { 'This valid part should be discarded', ); }); + + describe('stripThoughtsFromHistory', () => { + it('should strip thought signatures', () => { + chat.setHistory([ + { + role: 'user', + parts: [{ text: 'hello' }], + }, + { + role: 'model', + parts: [ + { text: 'thinking...', thoughtSignature: 'thought-123' }, + { + functionCall: { name: 'test', args: {} }, + thoughtSignature: 'thought-456', + }, + ], + }, + ]); + + chat.stripThoughtsFromHistory(); + + expect(chat.getHistory()).toEqual([ + { + role: 'user', + parts: [{ text: 'hello' }], + }, + { + role: 'model', + parts: [ + { text: 'thinking...' }, + { functionCall: { name: 'test', args: {} } }, + ], + }, + ]); + }); + }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 159e3560c3..62942b25f8 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -18,7 +18,6 @@ import type { import { toParts } from '../code_assist/converter.js'; import { createUserContent } from '@google/genai'; import { retryWithBackoff } from '../utils/retry.js'; -import type { ContentGenerator } from './contentGenerator.js'; import { AuthType } from './contentGenerator.js'; import type { Config } from '../config/config.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; @@ -172,7 +171,6 @@ export class GeminiChat { constructor( private readonly config: Config, - private readonly contentGenerator: ContentGenerator, private readonly generationConfig: GenerateContentConfig = {}, private history: Content[] = [], ) { @@ -287,7 +285,7 @@ export class GeminiChat { ); } - return this.contentGenerator.generateContent( + return this.config.getContentGenerator().generateContent( { model: modelToUse, contents: requestContents, @@ -498,7 +496,7 @@ export class GeminiChat { ); } - return this.contentGenerator.generateContentStream( + return this.config.getContentGenerator().generateContentStream( { model: modelToUse, contents: requestContents, @@ -570,10 +568,28 @@ export class GeminiChat { addHistory(content: Content): void { this.history.push(content); } + setHistory(history: Content[]): void { this.history = history; } + stripThoughtsFromHistory(): void { + this.history = this.history.map((content) => { + const newContent = { ...content }; + if (newContent.parts) { + newContent.parts = newContent.parts.map((part) => { + if (part && typeof part === 'object' && 'thoughtSignature' in part) { + const newPart = { ...part }; + delete (newPart as { thoughtSignature?: string }).thoughtSignature; + return newPart; + } + return part; + }); + } + return newContent; + }); + } + setTools(tools: Tool[]): void { this.generationConfig.tools = tools; } diff --git a/packages/core/src/core/subagent.test.ts b/packages/core/src/core/subagent.test.ts index cc54037bad..e3b51cacc7 100644 --- a/packages/core/src/core/subagent.test.ts +++ b/packages/core/src/core/subagent.test.ts @@ -55,8 +55,6 @@ async function createMockConfig( }; const config = new Config(configParams); await config.initialize(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - await config.refreshAuth('test-auth' as any); // Mock ToolRegistry const mockToolRegistry = { @@ -164,15 +162,13 @@ describe('subagent.ts', () => { // Helper to safely access generationConfig from mock calls const getGenerationConfigFromMock = ( callIndex = 0, - ): GenerateContentConfig & { systemInstruction?: string | Content } => { + ): GenerateContentConfig => { const callArgs = vi.mocked(GeminiChat).mock.calls[callIndex]; - const generationConfig = callArgs?.[2]; + const generationConfig = callArgs?.[1]; // Ensure it's defined before proceeding expect(generationConfig).toBeDefined(); if (!generationConfig) throw new Error('generationConfig is undefined'); - return generationConfig as GenerateContentConfig & { - systemInstruction?: string | Content; - }; + return generationConfig as GenerateContentConfig; }; describe('create (Tool Validation)', () => { @@ -347,7 +343,7 @@ describe('subagent.ts', () => { ); // Check History (should include environment context) - const history = callArgs[3]; + const history = callArgs[2]; expect(history).toEqual([ { role: 'user', parts: [{ text: 'Env Context' }] }, { @@ -420,7 +416,7 @@ describe('subagent.ts', () => { const callArgs = vi.mocked(GeminiChat).mock.calls[0]; const generationConfig = getGenerationConfigFromMock(); - const history = callArgs[3]; + const history = callArgs[2]; expect(generationConfig.systemInstruction).toBeUndefined(); expect(history).toEqual([ diff --git a/packages/core/src/core/subagent.ts b/packages/core/src/core/subagent.ts index 41de5978a1..fe5ac1ff87 100644 --- a/packages/core/src/core/subagent.ts +++ b/packages/core/src/core/subagent.ts @@ -10,7 +10,6 @@ import type { AnyDeclarativeTool } from '../tools/tools.js'; import type { Config } from '../config/config.js'; import type { ToolCallRequestInfo } from './turn.js'; import { executeToolCall } from './nonInteractiveToolExecutor.js'; -import { createContentGenerator } from './contentGenerator.js'; import { getEnvironmentContext } from '../utils/environmentContext.js'; import type { Content, @@ -635,9 +634,7 @@ export class SubAgentScope { : undefined; try { - const generationConfig: GenerateContentConfig & { - systemInstruction?: string | Content; - } = { + const generationConfig: GenerateContentConfig = { temperature: this.modelConfig.temp, topP: this.modelConfig.top_p, }; @@ -646,17 +643,10 @@ export class SubAgentScope { generationConfig.systemInstruction = systemInstruction; } - const contentGenerator = await createContentGenerator( - this.runtimeContext.getContentGeneratorConfig(), - this.runtimeContext, - this.runtimeContext.getSessionId(), - ); - this.runtimeContext.setModel(this.modelConfig.model); return new GeminiChat( this.runtimeContext, - contentGenerator, generationConfig, start_history, ); diff --git a/packages/core/src/utils/nextSpeakerChecker.test.ts b/packages/core/src/utils/nextSpeakerChecker.test.ts index b9e861998e..dab9099d69 100644 --- a/packages/core/src/utils/nextSpeakerChecker.test.ts +++ b/packages/core/src/utils/nextSpeakerChecker.test.ts @@ -6,10 +6,10 @@ import type { Mock } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import type { Content, GoogleGenAI, Models } from '@google/genai'; +import type { Content } from '@google/genai'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { GeminiClient } from '../core/client.js'; -import { Config } from '../config/config.js'; +import type { Config } from '../config/config.js'; import type { NextSpeakerResponse } from './nextSpeakerChecker.js'; import { checkNextSpeaker } from './nextSpeakerChecker.js'; import { GeminiChat } from '../core/geminiChat.js'; @@ -44,73 +44,28 @@ vi.mock('node:fs', () => { vi.mock('../core/client.js'); vi.mock('../config/config.js'); -// Define mocks for GoogleGenAI and Models instances that will be used across tests -const mockModelsInstance = { - generateContent: vi.fn(), - generateContentStream: vi.fn(), - countTokens: vi.fn(), - embedContent: vi.fn(), - batchEmbedContents: vi.fn(), -} as unknown as Models; - -const mockGoogleGenAIInstance = { - getGenerativeModel: vi.fn().mockReturnValue(mockModelsInstance), - // Add other methods of GoogleGenAI if they are directly used by GeminiChat constructor or its methods -} as unknown as GoogleGenAI; - -vi.mock('@google/genai', async () => { - const actualGenAI = - await vi.importActual('@google/genai'); - return { - ...actualGenAI, - GoogleGenAI: vi.fn(() => mockGoogleGenAIInstance), // Mock constructor to return the predefined instance - // If Models is instantiated directly in GeminiChat, mock its constructor too - // For now, assuming Models instance is obtained via getGenerativeModel - }; -}); - describe('checkNextSpeaker', () => { let chatInstance: GeminiChat; + let mockConfig: Config; let mockGeminiClient: GeminiClient; - let MockConfig: Mock; const abortSignal = new AbortController().signal; beforeEach(() => { - MockConfig = vi.mocked(Config); - const mockConfigInstance = new MockConfig( - 'test-api-key', - 'gemini-pro', - false, - '.', - false, - undefined, - false, - undefined, - undefined, - undefined, - ); + vi.resetAllMocks(); + mockConfig = { + getProjectRoot: vi.fn().mockReturnValue('/test/project/root'), + getSessionId: vi.fn().mockReturnValue('test-session-id'), + getModel: () => 'test-model', + storage: { + getProjectTempDir: vi.fn().mockReturnValue('/test/temp'), + }, + } as unknown as Config; - // Mock the methods that ChatRecordingService needs - mockConfigInstance.getSessionId = vi - .fn() - .mockReturnValue('test-session-id'); - mockConfigInstance.getProjectRoot = vi - .fn() - .mockReturnValue('/test/project/root'); - mockConfigInstance.storage = { - getProjectTempDir: vi.fn().mockReturnValue('/test/temp'), - }; - - mockGeminiClient = new GeminiClient(mockConfigInstance); - - // Reset mocks before each test to ensure test isolation - vi.mocked(mockModelsInstance.generateContent).mockReset(); - vi.mocked(mockModelsInstance.generateContentStream).mockReset(); + mockGeminiClient = new GeminiClient(mockConfig); // GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor chatInstance = new GeminiChat( - mockConfigInstance, - mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel + mockConfig, {}, [], // initial history ); @@ -120,7 +75,7 @@ describe('checkNextSpeaker', () => { }); afterEach(() => { - vi.clearAllMocks(); + vi.restoreAllMocks(); }); it('should return null if history is empty', async () => { @@ -135,9 +90,9 @@ describe('checkNextSpeaker', () => { }); it('should return null if the last speaker was the user', async () => { - (chatInstance.getHistory as Mock).mockReturnValue([ + vi.mocked(chatInstance.getHistory).mockReturnValue([ { role: 'user', parts: [{ text: 'Hello' }] }, - ] as Content[]); + ]); const result = await checkNextSpeaker( chatInstance, mockGeminiClient,