Pull contentGenerator out of GeminiClient and into Config. (#7825)

This commit is contained in:
Tommaso Sciortino
2025-09-07 13:00:03 -07:00
committed by GitHub
parent d33defde68
commit 6e26c88c2c
16 changed files with 359 additions and 781 deletions
+1 -1
View File
@@ -40,7 +40,7 @@ export async function createCodeAssistContentGenerator(
export function getCodeAssistServer(
config: Config,
): CodeAssistServer | undefined {
let server = config.getGeminiClient().getContentGenerator();
let server = config.getContentGenerator();
// Unwrap LoggingContentGenerator if present
if (server instanceof LoggingContentGenerator) {
+21 -159
View File
@@ -71,18 +71,12 @@ vi.mock('../tools/memoryTool', () => ({
GEMINI_CONFIG_DIR: '.gemini',
}));
vi.mock('../core/contentGenerator.js', async (importOriginal) => {
const actual =
await importOriginal<typeof import('../core/contentGenerator.js')>();
return {
...actual,
createContentGeneratorConfig: vi.fn(),
};
});
vi.mock('../core/contentGenerator.js');
vi.mock('../core/client.js', () => ({
GeminiClient: vi.fn().mockImplementation(() => ({
initialize: vi.fn().mockResolvedValue(undefined),
stripThoughtsFromHistory: vi.fn(),
})),
}));
@@ -196,7 +190,9 @@ describe('Server Config (config.ts)', () => {
apiKey: 'test-key',
};
(createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig);
vi.mocked(createContentGeneratorConfig).mockReturnValue(
mockContentConfig,
);
// Set fallback mode to true to ensure it gets reset
config.setFallbackMode(true);
@@ -217,172 +213,38 @@ describe('Server Config (config.ts)', () => {
expect(config.isInFallbackMode()).toBe(false);
});
it('should preserve conversation history when refreshing auth', async () => {
const config = new Config(baseParams);
const authType = AuthType.USE_GEMINI;
const mockContentConfig = {
model: 'gemini-pro',
apiKey: 'test-key',
};
(createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig);
// Mock the existing client with some history
const mockExistingHistory = [
{ role: 'user', parts: [{ text: 'Hello' }] },
{ role: 'model', parts: [{ text: 'Hi there!' }] },
{ role: 'user', parts: [{ text: 'How are you?' }] },
];
const mockExistingClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue(mockExistingHistory),
};
const mockNewClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue([]),
setHistory: vi.fn(),
initialize: vi.fn().mockResolvedValue(undefined),
};
// Set the existing client
(
config as unknown as { geminiClient: typeof mockExistingClient }
).geminiClient = mockExistingClient;
(GeminiClient as Mock).mockImplementation(() => mockNewClient);
await config.refreshAuth(authType);
// Verify that existing history was retrieved
expect(mockExistingClient.getHistory).toHaveBeenCalled();
// Verify that new client was created and initialized
expect(GeminiClient).toHaveBeenCalledWith(config);
expect(mockNewClient.initialize).toHaveBeenCalledWith(mockContentConfig);
// Verify that history was restored to the new client
expect(mockNewClient.setHistory).toHaveBeenCalledWith(
mockExistingHistory,
{ stripThoughts: false },
);
});
it('should handle case when no existing client is initialized', async () => {
const config = new Config(baseParams);
const authType = AuthType.USE_GEMINI;
const mockContentConfig = {
model: 'gemini-pro',
apiKey: 'test-key',
};
(createContentGeneratorConfig as Mock).mockReturnValue(mockContentConfig);
const mockNewClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue([]),
setHistory: vi.fn(),
initialize: vi.fn().mockResolvedValue(undefined),
};
// No existing client
(config as unknown as { geminiClient: null }).geminiClient = null;
(GeminiClient as Mock).mockImplementation(() => mockNewClient);
await config.refreshAuth(authType);
// Verify that new client was created and initialized
expect(GeminiClient).toHaveBeenCalledWith(config);
expect(mockNewClient.initialize).toHaveBeenCalledWith(mockContentConfig);
// Verify that setHistory was not called since there was no existing history
expect(mockNewClient.setHistory).not.toHaveBeenCalled();
});
it('should strip thoughts when switching from GenAI to Vertex', async () => {
const config = new Config(baseParams);
const mockContentConfig = {
model: 'gemini-pro',
apiKey: 'test-key',
authType: AuthType.USE_GEMINI,
};
(
config as unknown as { contentGeneratorConfig: ContentGeneratorConfig }
).contentGeneratorConfig = mockContentConfig;
(createContentGeneratorConfig as Mock).mockReturnValue({
...mockContentConfig,
authType: AuthType.LOGIN_WITH_GOOGLE,
});
vi.mocked(createContentGeneratorConfig).mockImplementation(
(_: Config, authType: AuthType | undefined) =>
({ authType }) as unknown as ContentGeneratorConfig,
);
const mockExistingHistory = [
{ role: 'user', parts: [{ text: 'Hello' }] },
];
const mockExistingClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue(mockExistingHistory),
};
const mockNewClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue([]),
setHistory: vi.fn(),
initialize: vi.fn().mockResolvedValue(undefined),
};
(
config as unknown as { geminiClient: typeof mockExistingClient }
).geminiClient = mockExistingClient;
(GeminiClient as Mock).mockImplementation(() => mockNewClient);
await config.refreshAuth(AuthType.USE_GEMINI);
await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE);
expect(mockNewClient.setHistory).toHaveBeenCalledWith(
mockExistingHistory,
{ stripThoughts: true },
);
expect(
config.getGeminiClient().stripThoughtsFromHistory,
).toHaveBeenCalledWith();
});
it('should not strip thoughts when switching from Vertex to GenAI', async () => {
const config = new Config(baseParams);
const mockContentConfig = {
model: 'gemini-pro',
apiKey: 'test-key',
authType: AuthType.LOGIN_WITH_GOOGLE,
};
(
config as unknown as { contentGeneratorConfig: ContentGeneratorConfig }
).contentGeneratorConfig = mockContentConfig;
(createContentGeneratorConfig as Mock).mockReturnValue({
...mockContentConfig,
authType: AuthType.USE_GEMINI,
});
vi.mocked(createContentGeneratorConfig).mockImplementation(
(_: Config, authType: AuthType | undefined) =>
({ authType }) as unknown as ContentGeneratorConfig,
);
const mockExistingHistory = [
{ role: 'user', parts: [{ text: 'Hello' }] },
];
const mockExistingClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue(mockExistingHistory),
};
const mockNewClient = {
isInitialized: vi.fn().mockReturnValue(true),
getHistory: vi.fn().mockReturnValue([]),
setHistory: vi.fn(),
initialize: vi.fn().mockResolvedValue(undefined),
};
(
config as unknown as { geminiClient: typeof mockExistingClient }
).geminiClient = mockExistingClient;
(GeminiClient as Mock).mockImplementation(() => mockNewClient);
await config.refreshAuth(AuthType.USE_VERTEX_AI);
await config.refreshAuth(AuthType.USE_GEMINI);
expect(mockNewClient.setHistory).toHaveBeenCalledWith(
mockExistingHistory,
{ stripThoughts: false },
);
expect(
config.getGeminiClient().stripThoughtsFromHistory,
).not.toHaveBeenCalledWith();
});
});
+36 -26
View File
@@ -6,9 +6,13 @@
import * as path from 'node:path';
import process from 'node:process';
import type { ContentGeneratorConfig } from '../core/contentGenerator.js';
import type {
ContentGenerator,
ContentGeneratorConfig,
} from '../core/contentGenerator.js';
import {
AuthType,
createContentGenerator,
createContentGeneratorConfig,
} from '../core/contentGenerator.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
@@ -44,7 +48,6 @@ import { shouldAttemptBrowserLaunch } from '../utils/browser.js';
import type { MCPOAuthConfig } from '../mcp/oauth-provider.js';
import { IdeClient } from '../ide/ide-client.js';
import { ideContext } from '../ide/ideContext.js';
import type { Content } from '@google/genai';
import type { FileSystemService } from '../services/fileSystemService.js';
import { StandardFileSystemService } from '../services/fileSystemService.js';
import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js';
@@ -57,6 +60,8 @@ import { WorkspaceContext } from '../utils/workspaceContext.js';
import { Storage } from './storage.js';
import { FileExclusions } from '../utils/ignorePatterns.js';
import type { EventEmitter } from 'node:events';
import type { UserTierId } from '../code_assist/types.js';
import { ProxyAgent, setGlobalDispatcher } from 'undici';
export enum ApprovalMode {
DEFAULT = 'default',
@@ -226,6 +231,7 @@ export class Config {
private readonly sessionId: string;
private fileSystemService: FileSystemService;
private contentGeneratorConfig!: ContentGeneratorConfig;
private contentGenerator!: ContentGenerator;
private readonly embeddingModel: string;
private readonly sandbox: SandboxConfig | undefined;
private readonly targetDir: string;
@@ -386,6 +392,11 @@ export class Config {
if (this.telemetrySettings.enabled) {
initializeTelemetry(this);
}
if (this.getProxy()) {
setGlobalDispatcher(new ProxyAgent(this.getProxy() as string));
}
this.geminiClient = new GeminiClient(this);
}
/**
@@ -410,46 +421,45 @@ export class Config {
this.promptRegistry = new PromptRegistry();
this.toolRegistry = await this.createToolRegistry();
logCliConfiguration(this, new StartSessionEvent(this, this.toolRegistry));
await this.geminiClient.initialize();
}
getContentGenerator(): ContentGenerator {
return this.contentGenerator;
}
async refreshAuth(authMethod: AuthType) {
// Save the current conversation history before creating a new client
let existingHistory: Content[] = [];
if (this.geminiClient && this.geminiClient.isInitialized()) {
existingHistory = this.geminiClient.getHistory();
// Vertex and Genai have incompatible encryption and sending history with
// throughtSignature from Genai to Vertex will fail, we need to strip them
if (
this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI &&
authMethod === AuthType.LOGIN_WITH_GOOGLE
) {
// Restore the conversation history to the new client
this.geminiClient.stripThoughtsFromHistory();
}
// Create new content generator config
const newContentGeneratorConfig = createContentGeneratorConfig(
this,
authMethod,
);
// Create and initialize new client in local variable first
const newGeminiClient = new GeminiClient(this);
await newGeminiClient.initialize(newContentGeneratorConfig);
// Vertex and Genai have incompatible encryption and sending history with
// throughtSignature from Genai to Vertex will fail, we need to strip them
const fromGenaiToVertex =
this.contentGeneratorConfig?.authType === AuthType.USE_GEMINI &&
authMethod === AuthType.LOGIN_WITH_GOOGLE;
this.contentGenerator = await createContentGenerator(
newContentGeneratorConfig,
this,
this.getSessionId(),
);
// Only assign to instance properties after successful initialization
this.contentGeneratorConfig = newContentGeneratorConfig;
this.geminiClient = newGeminiClient;
// Restore the conversation history to the new client
if (existingHistory.length > 0) {
this.geminiClient.setHistory(existingHistory, {
stripThoughts: fromGenaiToVertex,
});
}
// Reset the session flag since we're explicitly changing auth and using default model
this.inFallbackMode = false;
}
getUserTier(): UserTierId | undefined {
return this.contentGenerator?.userTier;
}
getSessionId(): string {
return this.sessionId;
}
+98 -355
View File
@@ -4,24 +4,9 @@
* SPDX-License-Identifier: Apache-2.0
*/
import {
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mocked,
} from 'vitest';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import type {
Chat,
Content,
EmbedContentResponse,
GenerateContentResponse,
Part,
} from '@google/genai';
import { GoogleGenAI } from '@google/genai';
import type { Content, GenerateContentResponse, Part } from '@google/genai';
import {
findIndexAfterFraction,
isThinkingDefault,
@@ -34,7 +19,7 @@ import {
type ContentGeneratorConfig,
} from './contentGenerator.js';
import { type GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js';
import type { Config } from '../config/config.js';
import {
CompressionStatus,
GeminiEventType,
@@ -76,12 +61,8 @@ vi.mock('node:fs', () => {
});
// --- Mocks ---
const mockChatCreateFn = vi.fn();
const mockGenerateContentFn = vi.fn();
const mockEmbedContentFn = vi.fn();
const mockTurnRunFn = vi.fn();
vi.mock('@google/genai');
vi.mock('./turn', async (importOriginal) => {
const actual = await importOriginal<typeof import('./turn.js')>();
// Define a mock class that has the same shape as the real Turn
@@ -228,6 +209,8 @@ describe('isThinkingDefault', () => {
});
describe('Gemini Client (client.ts)', () => {
let mockContentGenerator: ContentGenerator;
let mockConfig: Config;
let client: GeminiClient;
beforeEach(async () => {
vi.resetAllMocks();
@@ -235,29 +218,15 @@ describe('Gemini Client (client.ts)', () => {
// Disable 429 simulation for tests
setSimulate429(false);
// Set up the mock for GoogleGenAI constructor and its methods
const MockedGoogleGenAI = vi.mocked(GoogleGenAI);
MockedGoogleGenAI.mockImplementation(() => {
const mock = {
chats: { create: mockChatCreateFn },
models: {
generateContent: mockGenerateContentFn,
embedContent: mockEmbedContentFn,
},
};
return mock as unknown as GoogleGenAI;
});
mockChatCreateFn.mockResolvedValue({} as Chat);
mockGenerateContentFn.mockResolvedValue({
candidates: [
{
content: {
parts: [{ text: '{"key": "value"}' }],
},
},
],
} as unknown as GenerateContentResponse);
mockContentGenerator = {
generateContent: vi.fn().mockResolvedValue({
candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }],
}),
generateContentStream: vi.fn(),
countTokens: vi.fn(),
embedContent: vi.fn(),
batchEmbedContents: vi.fn(),
} as unknown as ContentGenerator;
// Because the GeminiClient constructor kicks off an async process (startChat)
// that depends on a fully-formed Config object, we need to mock the
@@ -273,7 +242,7 @@ describe('Gemini Client (client.ts)', () => {
vertexai: false,
authType: AuthType.USE_GEMINI,
};
const mockConfigObject = {
mockConfig = {
getContentGeneratorConfig: vi
.fn()
.mockReturnValue(contentGeneratorConfig),
@@ -309,46 +278,18 @@ describe('Gemini Client (client.ts)', () => {
storage: {
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
},
};
const MockedConfig = vi.mocked(Config, true);
MockedConfig.mockImplementation(
() => mockConfigObject as unknown as Config,
);
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
} as unknown as Config;
// We can instantiate the client here since Config is mocked
// and the constructor will use the mocked GoogleGenAI
client = new GeminiClient(
new Config({ sessionId: 'test-session-id' } as never),
);
mockConfigObject.getGeminiClient.mockReturnValue(client);
await client.initialize(contentGeneratorConfig);
client = new GeminiClient(mockConfig);
await client.initialize();
vi.mocked(mockConfig.getGeminiClient).mockReturnValue(client);
});
afterEach(() => {
vi.restoreAllMocks();
});
// NOTE: The following tests for startChat were removed due to persistent issues with
// the @google/genai mock. Specifically, the mockChatCreateFn (representing instance.chats.create)
// was not being detected as called by the GeminiClient instance.
// This likely points to a subtle issue in how the GoogleGenerativeAI class constructor
// and its instance methods are mocked and then used by the class under test.
// For future debugging, ensure that the `this.client` in `GeminiClient` (which is an
// instance of the mocked GoogleGenerativeAI) correctly has its `chats.create` method
// pointing to `mockChatCreateFn`.
// it('startChat should call getCoreSystemPrompt with userMemory and pass to chats.create', async () => { ... });
// it('startChat should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... });
// NOTE: The following tests for generateJson were removed due to persistent issues with
// the @google/genai mock, similar to the startChat tests. The mockGenerateContentFn
// (representing instance.models.generateContent) was not being detected as called, or the mock
// was not preventing an actual API call (leading to API key errors).
// For future debugging, ensure `this.client.models.generateContent` in `GeminiClient` correctly
// uses the `mockGenerateContentFn`.
// it('generateJson should call getCoreSystemPrompt with userMemory and pass to generateContent', async () => { ... });
// it('generateJson should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... });
describe('generateEmbedding', () => {
const texts = ['hello world', 'goodbye world'];
const testEmbeddingModel = 'test-embedding-model';
@@ -358,18 +299,17 @@ describe('Gemini Client (client.ts)', () => {
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
];
const mockResponse: EmbedContentResponse = {
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
embeddings: [
{ values: mockEmbeddings[0] },
{ values: mockEmbeddings[1] },
],
};
mockEmbedContentFn.mockResolvedValue(mockResponse);
});
const result = await client.generateEmbedding(texts);
expect(mockEmbedContentFn).toHaveBeenCalledTimes(1);
expect(mockEmbedContentFn).toHaveBeenCalledWith({
expect(mockContentGenerator.embedContent).toHaveBeenCalledTimes(1);
expect(mockContentGenerator.embedContent).toHaveBeenCalledWith({
model: testEmbeddingModel,
contents: texts,
});
@@ -379,11 +319,11 @@ describe('Gemini Client (client.ts)', () => {
it('should return an empty array if an empty array is passed', async () => {
const result = await client.generateEmbedding([]);
expect(result).toEqual([]);
expect(mockEmbedContentFn).not.toHaveBeenCalled();
expect(mockContentGenerator.embedContent).not.toHaveBeenCalled();
});
it('should throw an error if API response has no embeddings array', async () => {
mockEmbedContentFn.mockResolvedValue({} as EmbedContentResponse); // No `embeddings` key
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'No embeddings found in API response.',
@@ -391,20 +331,19 @@ describe('Gemini Client (client.ts)', () => {
});
it('should throw an error if API response has an empty embeddings array', async () => {
const mockResponse: EmbedContentResponse = {
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
embeddings: [],
};
mockEmbedContentFn.mockResolvedValue(mockResponse);
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'No embeddings found in API response.',
);
});
it('should throw an error if API returns a mismatched number of embeddings', async () => {
const mockResponse: EmbedContentResponse = {
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
embeddings: [{ values: [1, 2, 3] }], // Only one for two texts
};
mockEmbedContentFn.mockResolvedValue(mockResponse);
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned a mismatched number of embeddings. Expected 2, got 1.',
@@ -412,10 +351,9 @@ describe('Gemini Client (client.ts)', () => {
});
it('should throw an error if any embedding has nullish values', async () => {
const mockResponse: EmbedContentResponse = {
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad
};
mockEmbedContentFn.mockResolvedValue(mockResponse);
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned an empty embedding for input text at index 1: "goodbye world"',
@@ -423,10 +361,9 @@ describe('Gemini Client (client.ts)', () => {
});
it('should throw an error if any embedding has an empty values array', async () => {
const mockResponse: EmbedContentResponse = {
vi.mocked(mockContentGenerator.embedContent).mockResolvedValue({
embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad
};
mockEmbedContentFn.mockResolvedValue(mockResponse);
});
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned an empty embedding for input text at index 0: "hello world"',
@@ -434,8 +371,9 @@ describe('Gemini Client (client.ts)', () => {
});
it('should propagate errors from the API call', async () => {
const apiError = new Error('API Failure');
mockEmbedContentFn.mockRejectedValue(apiError);
vi.mocked(mockContentGenerator.embedContent).mockRejectedValue(
new Error('API Failure'),
);
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API Failure',
@@ -449,12 +387,9 @@ describe('Gemini Client (client.ts)', () => {
const schema = { type: 'string' };
const abortSignal = new AbortController().signal;
// Mock countTokens
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
totalTokens: 1,
});
await client.generateJson(
contents,
@@ -463,7 +398,7 @@ describe('Gemini Client (client.ts)', () => {
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(mockGenerateContentFn).toHaveBeenCalledWith(
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
{
model: DEFAULT_GEMINI_FLASH_MODEL,
config: {
@@ -489,11 +424,9 @@ describe('Gemini Client (client.ts)', () => {
const customModel = 'custom-json-model';
const customConfig = { temperature: 0.9, topK: 20 };
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
totalTokens: 1,
});
await client.generateJson(
contents,
@@ -503,7 +436,7 @@ describe('Gemini Client (client.ts)', () => {
customConfig,
);
expect(mockGenerateContentFn).toHaveBeenCalledWith(
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
{
model: customModel,
config: {
@@ -524,10 +457,10 @@ describe('Gemini Client (client.ts)', () => {
describe('addHistory', () => {
it('should call chat.addHistory with the provided content', async () => {
const mockChat: Partial<GeminiChat> = {
const mockChat = {
addHistory: vi.fn(),
};
client['chat'] = mockChat as GeminiChat;
} as unknown as GeminiChat;
client['chat'] = mockChat;
const newContent = {
role: 'user',
@@ -568,7 +501,6 @@ describe('Gemini Client (client.ts)', () => {
});
describe('tryCompressChat', () => {
const mockCountTokens = vi.fn();
const mockSendMessage = vi.fn();
const mockGetHistory = vi.fn();
@@ -577,10 +509,6 @@ describe('Gemini Client (client.ts)', () => {
tokenLimit: vi.fn(),
}));
client['contentGenerator'] = {
countTokens: mockCountTokens,
} as unknown as ContentGenerator;
client['chat'] = {
getHistory: mockGetHistory,
addHistory: vi.fn(),
@@ -600,27 +528,21 @@ describe('Gemini Client (client.ts)', () => {
setHistory: vi.fn(),
sendMessage: vi.fn().mockResolvedValue({ text: 'Summary' }),
};
const mockCountTokens = vi
.fn()
vi.mocked(mockContentGenerator.countTokens)
.mockResolvedValueOnce({ totalTokens: 1000 })
.mockResolvedValueOnce({ totalTokens: 5000 });
const mockGenerator: Partial<Mocked<ContentGenerator>> = {
countTokens: mockCountTokens,
};
client['chat'] = mockChat as GeminiChat;
client['contentGenerator'] = mockGenerator as ContentGenerator;
client['startChat'] = vi.fn().mockResolvedValue({ ...mockChat });
return { client, mockChat, mockGenerator };
return { client, mockChat };
}
describe('when compression inflates the token count', () => {
it('uses the truncated history for compression');
it('allows compression to be forced/manual after a failure', async () => {
const { client, mockGenerator } = setup();
mockGenerator.countTokens?.mockResolvedValue({
const { client } = setup();
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
totalTokens: 1000,
});
await client.tryCompressChat('prompt-id-4'); // Fails
@@ -635,6 +557,9 @@ describe('Gemini Client (client.ts)', () => {
it('yields the result even if the compression inflated the tokens', async () => {
const { client } = setup();
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
totalTokens: 1000,
});
const result = await client.tryCompressChat('prompt-id-4', true);
expect(result).toEqual({
@@ -654,7 +579,7 @@ describe('Gemini Client (client.ts)', () => {
it('restores the history back to the original', async () => {
vi.mocked(tokenLimit).mockReturnValue(1000);
mockCountTokens.mockResolvedValue({
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
totalTokens: 999,
});
@@ -679,13 +604,13 @@ describe('Gemini Client (client.ts)', () => {
});
it('will not attempt to compress context after a failure', async () => {
const { client, mockGenerator } = setup();
const { client } = setup();
await client.tryCompressChat('prompt-id-4');
const result = await client.tryCompressChat('prompt-id-5');
// it counts tokens for {original, compressed} and then never again
expect(mockGenerator.countTokens).toHaveBeenCalledTimes(2);
expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2);
expect(result).toEqual({
compressionStatus: CompressionStatus.NOOP,
newTokenCount: 0,
@@ -696,7 +621,7 @@ describe('Gemini Client (client.ts)', () => {
it('attempts to compress with a maxOutputTokens set to the original token count', async () => {
vi.mocked(tokenLimit).mockReturnValue(1000);
mockCountTokens.mockResolvedValue({
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
totalTokens: 999,
});
@@ -728,8 +653,7 @@ describe('Gemini Client (client.ts)', () => {
mockGetHistory.mockReturnValue([
{ role: 'user', parts: [{ text: '...history...' }] },
]);
mockCountTokens.mockResolvedValue({
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7
});
@@ -763,7 +687,7 @@ describe('Gemini Client (client.ts)', () => {
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
const newTokenCount = 100;
mockCountTokens
vi.mocked(mockContentGenerator.countTokens)
.mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
@@ -800,7 +724,7 @@ describe('Gemini Client (client.ts)', () => {
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
const newTokenCount = 100;
mockCountTokens
vi.mocked(mockContentGenerator.countTokens)
.mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
@@ -853,7 +777,7 @@ describe('Gemini Client (client.ts)', () => {
const originalTokenCount = 1000 * 0.7;
const newTokenCount = 100;
mockCountTokens
vi.mocked(mockContentGenerator.countTokens)
.mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
@@ -895,7 +819,7 @@ describe('Gemini Client (client.ts)', () => {
const originalTokenCount = 10; // Well below threshold
const newTokenCount = 5;
mockCountTokens
vi.mocked(mockContentGenerator.countTokens)
.mockResolvedValueOnce({ totalTokens: originalTokenCount })
.mockResolvedValueOnce({ totalTokens: newTokenCount });
@@ -922,10 +846,16 @@ describe('Gemini Client (client.ts)', () => {
});
it('should use current model from config for token counting after sendMessage', async () => {
const initialModel = client['config'].getModel();
const initialModel = mockConfig.getModel();
const mockCountTokens = vi
.fn()
// mock the model has been changed between calls of `countTokens`
const firstCurrentModel = initialModel + '-changed-1';
const secondCurrentModel = initialModel + '-changed-2';
vi.mocked(mockConfig.getModel)
.mockReturnValueOnce(firstCurrentModel)
.mockReturnValueOnce(secondCurrentModel);
vi.mocked(mockContentGenerator.countTokens)
.mockResolvedValueOnce({ totalTokens: 100000 })
.mockResolvedValueOnce({ totalTokens: 5000 });
@@ -936,35 +866,23 @@ describe('Gemini Client (client.ts)', () => {
{ role: 'model', parts: [{ text: 'Long response' }] },
];
const mockChat: Partial<GeminiChat> = {
const mockChat = {
getHistory: vi.fn().mockReturnValue(mockChatHistory),
setHistory: vi.fn(),
sendMessage: mockSendMessage,
};
} as unknown as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: mockCountTokens,
};
// mock the model has been changed between calls of `countTokens`
const firstCurrentModel = initialModel + '-changed-1';
const secondCurrentModel = initialModel + '-changed-2';
vi.spyOn(client['config'], 'getModel')
.mockReturnValueOnce(firstCurrentModel)
.mockReturnValueOnce(secondCurrentModel);
client['chat'] = mockChat as GeminiChat;
client['contentGenerator'] = mockGenerator as ContentGenerator;
client['chat'] = mockChat;
client['startChat'] = vi.fn().mockResolvedValue(mockChat);
const result = await client.tryCompressChat('prompt-id-4', true);
expect(mockCountTokens).toHaveBeenCalledTimes(2);
expect(mockCountTokens).toHaveBeenNthCalledWith(1, {
expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2);
expect(mockContentGenerator.countTokens).toHaveBeenNthCalledWith(1, {
model: firstCurrentModel,
contents: mockChatHistory,
});
expect(mockCountTokens).toHaveBeenNthCalledWith(2, {
expect(mockContentGenerator.countTokens).toHaveBeenNthCalledWith(2, {
model: secondCurrentModel,
contents: expect.any(Array),
});
@@ -980,22 +898,11 @@ describe('Gemini Client (client.ts)', () => {
describe('sendMessageStream', () => {
it('emits a compression event when the context was automatically compressed', async () => {
// Arrange
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
mockTurnRunFn.mockReturnValue(
(async function* () {
yield { type: 'content', value: 'Hello' };
})(),
);
const compressionInfo: ChatCompressionInfo = {
compressionStatus: CompressionStatus.COMPRESSED,
@@ -1042,18 +949,6 @@ describe('Gemini Client (client.ts)', () => {
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
const compressionInfo: ChatCompressionInfo = {
compressionStatus,
originalTokenCount: 1000,
@@ -1105,24 +1000,19 @@ describe('Gemini Client (client.ts)', () => {
},
});
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
vi.mocked(mockConfig.getIdeMode).mockReturnValue(true);
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
mockTurnRunFn.mockReturnValue(
(async function* () {
yield { type: 'content', value: 'Hello' };
})(),
);
const mockChat: Partial<GeminiChat> = {
const mockChat = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
} as unknown as GeminiChat;
client['chat'] = mockChat;
const initialRequest: Part[] = [{ text: 'Hi' }];
@@ -1186,12 +1076,6 @@ ${JSON.stringify(
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
const initialRequest = [{ text: 'Hi' }];
// Act
@@ -1241,12 +1125,6 @@ ${JSON.stringify(
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
const initialRequest = [{ text: 'Hi' }];
// Act
@@ -1317,12 +1195,6 @@ ${JSON.stringify(
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
const initialRequest = [{ text: 'Hi' }];
// Act
@@ -1369,12 +1241,6 @@ ${JSON.stringify(
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
@@ -1419,12 +1285,6 @@ ${JSON.stringify(
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Use a signal that never gets aborted
const abortController = new AbortController();
const signal = abortController.signal;
@@ -1512,12 +1372,6 @@ ${JSON.stringify(
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Act & Assert
// Run up to the limit
for (let i = 0; i < MAX_SESSION_TURNS; i++) {
@@ -1574,12 +1428,6 @@ ${JSON.stringify(
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Use a signal that never gets aborted
const abortController = new AbortController();
const signal = abortController.signal;
@@ -1659,12 +1507,6 @@ ${JSON.stringify(
]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
});
const testCases = [
@@ -1921,11 +1763,6 @@ ${JSON.stringify(
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
vi.mocked(ideContext.getIdeContext).mockReturnValue({
workspaceState: {
@@ -2266,12 +2103,6 @@ ${JSON.stringify(
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
@@ -2308,12 +2139,6 @@ ${JSON.stringify(
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
@@ -2335,13 +2160,6 @@ ${JSON.stringify(
const generationConfig = { temperature: 0.5 };
const abortSignal = new AbortController().signal;
// Mock countTokens
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
await client.generateContent(
contents,
generationConfig,
@@ -2349,7 +2167,7 @@ ${JSON.stringify(
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(mockGenerateContentFn).toHaveBeenCalledWith(
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
{
model: DEFAULT_GEMINI_FLASH_MODEL,
config: {
@@ -2371,12 +2189,6 @@ ${JSON.stringify(
vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
await client.generateContent(
contents,
{},
@@ -2384,12 +2196,12 @@ ${JSON.stringify(
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(mockGenerateContentFn).not.toHaveBeenCalledWith({
expect(mockContentGenerator.generateContent).not.toHaveBeenCalledWith({
model: initialModel,
config: expect.any(Object),
contents,
});
expect(mockGenerateContentFn).toHaveBeenCalledWith(
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
{
model: DEFAULT_GEMINI_FLASH_MODEL,
config: expect.any(Object),
@@ -2427,73 +2239,4 @@ ${JSON.stringify(
);
});
});
describe('setHistory', () => {
it('should strip thought signatures when stripThoughts is true', () => {
const mockChat = {
setHistory: vi.fn(),
};
client['chat'] = mockChat as unknown as GeminiChat;
const historyWithThoughts: Content[] = [
{
role: 'user',
parts: [{ text: 'hello' }],
},
{
role: 'model',
parts: [
{ text: 'thinking...', thoughtSignature: 'thought-123' },
{
functionCall: { name: 'test', args: {} },
thoughtSignature: 'thought-456',
},
],
},
];
client.setHistory(historyWithThoughts, { stripThoughts: true });
const expectedHistory: Content[] = [
{
role: 'user',
parts: [{ text: 'hello' }],
},
{
role: 'model',
parts: [
{ text: 'thinking...' },
{ functionCall: { name: 'test', args: {} } },
],
},
];
expect(mockChat.setHistory).toHaveBeenCalledWith(expectedHistory);
});
it('should not strip thought signatures when stripThoughts is false', () => {
const mockChat = {
setHistory: vi.fn(),
};
client['chat'] = mockChat as unknown as GeminiChat;
const historyWithThoughts: Content[] = [
{
role: 'user',
parts: [{ text: 'hello' }],
},
{
role: 'model',
parts: [
{ text: 'thinking...', thoughtSignature: 'thought-123' },
{ text: 'ok', thoughtSignature: 'thought-456' },
],
},
];
client.setHistory(historyWithThoughts, { stripThoughts: false });
expect(mockChat.setHistory).toHaveBeenCalledWith(historyWithThoughts);
});
});
});
+21 -61
View File
@@ -20,7 +20,6 @@ import type { ServerGeminiStreamEvent, ChatCompressionInfo } from './turn.js';
import { CompressionStatus } from './turn.js';
import { Turn, GeminiEventType } from './turn.js';
import type { Config } from '../config/config.js';
import type { UserTierId } from '../code_assist/types.js';
import { getCoreSystemPrompt, getCompressionPrompt } from './prompts.js';
import { getResponseText } from '../utils/partUtils.js';
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
@@ -31,12 +30,8 @@ import { getErrorMessage } from '../utils/errors.js';
import { isFunctionResponse } from '../utils/messageInspectors.js';
import { tokenLimit } from './tokenLimits.js';
import type { ChatRecordingService } from '../services/chatRecordingService.js';
import type {
ContentGenerator,
ContentGeneratorConfig,
} from './contentGenerator.js';
import { AuthType, createContentGenerator } from './contentGenerator.js';
import { ProxyAgent, setGlobalDispatcher } from 'undici';
import type { ContentGenerator } from './contentGenerator.js';
import { AuthType } from './contentGenerator.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_THINKING_MODE,
@@ -115,8 +110,6 @@ const COMPRESSION_PRESERVE_THRESHOLD = 0.3;
export class GeminiClient {
private chat?: GeminiChat;
private contentGenerator?: ContentGenerator;
private readonly embeddingModel: string;
private readonly generateContentConfig: GenerateContentConfig = {
temperature: 0,
topP: 1,
@@ -135,33 +128,19 @@ export class GeminiClient {
private hasFailedCompressionAttempt = false;
constructor(private readonly config: Config) {
if (config.getProxy()) {
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
}
this.embeddingModel = config.getEmbeddingModel();
this.loopDetector = new LoopDetectionService(config);
this.lastPromptId = this.config.getSessionId();
}
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
this.contentGenerator = await createContentGenerator(
contentGeneratorConfig,
this.config,
this.config.getSessionId(),
);
async initialize() {
this.chat = await this.startChat();
}
getContentGenerator(): ContentGenerator {
if (!this.contentGenerator) {
private getContentGeneratorOrFail(): ContentGenerator {
if (!this.config.getContentGenerator()) {
throw new Error('Content generator not initialized');
}
return this.contentGenerator;
}
getUserTier(): UserTierId | undefined {
return this.contentGenerator?.userTier;
return this.config.getContentGenerator();
}
async addHistory(content: Content) {
@@ -176,39 +155,19 @@ export class GeminiClient {
}
isInitialized(): boolean {
return this.chat !== undefined && this.contentGenerator !== undefined;
return this.chat !== undefined;
}
getHistory(): Content[] {
return this.getChat().getHistory();
}
setHistory(
history: Content[],
{ stripThoughts = false }: { stripThoughts?: boolean } = {},
) {
const historyToSet = stripThoughts
? history.map((content) => {
const newContent = { ...content };
if (newContent.parts) {
newContent.parts = newContent.parts.map((part) => {
if (
part &&
typeof part === 'object' &&
'thoughtSignature' in part
) {
const newPart = { ...part };
delete (newPart as { thoughtSignature?: string })
.thoughtSignature;
return newPart;
}
return part;
});
}
return newContent;
})
: history;
this.getChat().setHistory(historyToSet);
stripThoughtsFromHistory() {
this.getChat().stripThoughtsFromHistory();
}
setHistory(history: Content[]) {
this.getChat().setHistory(history);
this.forceFullIdeContext = true;
}
@@ -242,9 +201,11 @@ export class GeminiClient {
this.forceFullIdeContext = true;
this.hasFailedCompressionAttempt = false;
const envParts = await getEnvironmentContext(this.config);
const toolRegistry = this.config.getToolRegistry();
const toolDeclarations = toolRegistry.getFunctionDeclarations();
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
const history: Content[] = [
{
role: 'user',
@@ -274,7 +235,6 @@ export class GeminiClient {
: this.generateContentConfig;
return new GeminiChat(
this.config,
this.getContentGenerator(),
{
systemInstruction,
...generateContentConfigWithThinking,
@@ -600,7 +560,7 @@ export class GeminiClient {
};
const apiCall = () =>
this.getContentGenerator().generateContent(
this.getContentGeneratorOrFail().generateContent(
{
model,
config: {
@@ -711,7 +671,7 @@ export class GeminiClient {
};
const apiCall = () =>
this.getContentGenerator().generateContent(
this.getContentGeneratorOrFail().generateContent(
{
model,
config: requestConfig,
@@ -751,12 +711,12 @@ export class GeminiClient {
return [];
}
const embedModelParams: EmbedContentParameters = {
model: this.embeddingModel,
model: this.config.getEmbeddingModel(),
contents: texts,
};
const embedContentResponse =
await this.getContentGenerator().embedContent(embedModelParams);
await this.getContentGeneratorOrFail().embedContent(embedModelParams);
if (
!embedContentResponse.embeddings ||
embedContentResponse.embeddings.length === 0
@@ -802,7 +762,7 @@ export class GeminiClient {
const model = this.config.getModel();
const { totalTokens: originalTokenCount } =
await this.getContentGenerator().countTokens({
await this.getContentGeneratorOrFail().countTokens({
model,
contents: curatedHistory,
});
@@ -877,7 +837,7 @@ export class GeminiClient {
this.forceFullIdeContext = true;
const { totalTokens: newTokenCount } =
await this.getContentGenerator().countTokens({
await this.getContentGeneratorOrFail().countTokens({
// model might change after calling `sendMessage`, so we get the newest value from config
model: this.config.getModel(),
contents: chat.getHistory(),
+89 -45
View File
@@ -7,11 +7,11 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import type {
Content,
Models,
GenerateContentConfig,
Part,
GenerateContentResponse,
} from '@google/genai';
import type { ContentGenerator } from '../core/contentGenerator.js';
import {
GeminiChat,
EmptyStreamError,
@@ -47,15 +47,6 @@ vi.mock('node:fs', () => {
};
});
// Mocks
const mockModelsModule = {
generateContent: vi.fn(),
generateContentStream: vi.fn(),
countTokens: vi.fn(),
embedContent: vi.fn(),
batchEmbedContents: vi.fn(),
} as unknown as Models;
const { mockLogInvalidChunk, mockLogContentRetry, mockLogContentRetryFailure } =
vi.hoisted(() => ({
mockLogInvalidChunk: vi.fn(),
@@ -70,12 +61,21 @@ vi.mock('../telemetry/loggers.js', () => ({
}));
describe('GeminiChat', () => {
let mockContentGenerator: ContentGenerator;
let chat: GeminiChat;
let mockConfig: Config;
const config: GenerateContentConfig = {};
beforeEach(() => {
vi.clearAllMocks();
mockContentGenerator = {
generateContent: vi.fn(),
generateContentStream: vi.fn(),
countTokens: vi.fn(),
embedContent: vi.fn(),
batchEmbedContents: vi.fn(),
} as unknown as ContentGenerator;
mockConfig = {
getSessionId: () => 'test-session-id',
getTelemetryLogPromptsEnabled: () => true,
@@ -97,12 +97,13 @@ describe('GeminiChat', () => {
getToolRegistry: vi.fn().mockReturnValue({
getTool: vi.fn(),
}),
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
} as unknown as Config;
// Disable 429 simulation for tests
setSimulate429(false);
// Reset history for each test by creating a new instance
chat = new GeminiChat(mockConfig, mockModelsModule, config, []);
chat = new GeminiChat(mockConfig, config, []);
});
afterEach(() => {
@@ -160,7 +161,7 @@ describe('GeminiChat', () => {
],
} as unknown as GenerateContentResponse;
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
mockAfcResponse,
);
@@ -218,7 +219,7 @@ describe('GeminiChat', () => {
{ content: { role: 'model', parts: [{ text: 'some response' }] } },
],
} as unknown as GenerateContentResponse;
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
mockResponse,
);
@@ -247,7 +248,7 @@ describe('GeminiChat', () => {
],
} as unknown as GenerateContentResponse;
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
mixedContentResponse,
);
@@ -302,7 +303,7 @@ describe('GeminiChat', () => {
{ content: { role: 'model', parts: [{ thought: true }] } },
],
} as unknown as GenerateContentResponse;
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
emptyModelResponse,
);
@@ -348,11 +349,13 @@ describe('GeminiChat', () => {
],
text: () => 'response',
} as unknown as GenerateContentResponse;
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(response);
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
response,
);
await chat.sendMessage({ message: 'hello' }, 'prompt-id-1');
expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
{
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
@@ -390,7 +393,7 @@ describe('GeminiChat', () => {
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
streamWithToolCall,
);
@@ -442,7 +445,7 @@ describe('GeminiChat', () => {
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
streamWithNoFinish,
);
@@ -487,7 +490,7 @@ describe('GeminiChat', () => {
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
streamWithInvalidEnd,
);
@@ -543,7 +546,7 @@ describe('GeminiChat', () => {
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
multiChunkStream,
);
@@ -593,7 +596,7 @@ describe('GeminiChat', () => {
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
multiChunkStream,
);
@@ -650,7 +653,7 @@ describe('GeminiChat', () => {
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
multiChunkStream,
);
@@ -697,7 +700,7 @@ describe('GeminiChat', () => {
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
mixedContentStream,
);
@@ -759,7 +762,7 @@ describe('GeminiChat', () => {
],
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
emptyStreamResponse,
);
@@ -811,7 +814,7 @@ describe('GeminiChat', () => {
text: () => 'response',
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
response,
);
@@ -823,7 +826,7 @@ describe('GeminiChat', () => {
// consume stream to trigger internal logic
}
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith(
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith(
{
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
@@ -1012,7 +1015,7 @@ describe('GeminiChat', () => {
describe('sendMessageStream with retries', () => {
it('should yield a RETRY event when an invalid stream is encountered', async () => {
// ARRANGE: Mock the stream to fail once, then succeed.
vi.mocked(mockModelsModule.generateContentStream)
vi.mocked(mockContentGenerator.generateContentStream)
.mockImplementationOnce(async () =>
// First attempt: An invalid stream with an empty text part.
(async function* () {
@@ -1053,7 +1056,7 @@ describe('GeminiChat', () => {
});
it('should retry on invalid content, succeed, and report metrics', async () => {
// Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt.
vi.mocked(mockModelsModule.generateContentStream)
vi.mocked(mockContentGenerator.generateContentStream)
.mockImplementationOnce(async () =>
// First call returns an invalid stream
(async function* () {
@@ -1089,7 +1092,9 @@ describe('GeminiChat', () => {
expect(mockLogInvalidChunk).toHaveBeenCalledTimes(1);
expect(mockLogContentRetry).toHaveBeenCalledTimes(1);
expect(mockLogContentRetryFailure).not.toHaveBeenCalled();
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
2,
);
// Check for a retry event
expect(chunks.some((c) => c.type === StreamEventType.RETRY)).toBe(true);
@@ -1118,7 +1123,7 @@ describe('GeminiChat', () => {
});
it('should fail after all retries on persistent invalid content and report metrics', async () => {
vi.mocked(mockModelsModule.generateContentStream).mockImplementation(
vi.mocked(mockContentGenerator.generateContentStream).mockImplementation(
async () =>
(async function* () {
yield {
@@ -1150,7 +1155,9 @@ describe('GeminiChat', () => {
);
// Should be called 3 times (initial + 2 retries)
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(3);
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
3,
);
expect(mockLogInvalidChunk).toHaveBeenCalledTimes(3);
expect(mockLogContentRetry).toHaveBeenCalledTimes(2);
expect(mockLogContentRetryFailure).toHaveBeenCalledTimes(1);
@@ -1169,7 +1176,7 @@ describe('GeminiChat', () => {
chat.setHistory(initialHistory);
// 2. Mock the API to fail once with an empty stream, then succeed.
vi.mocked(mockModelsModule.generateContentStream)
vi.mocked(mockContentGenerator.generateContentStream)
.mockImplementationOnce(async () =>
(async function* () {
yield {
@@ -1264,7 +1271,7 @@ describe('GeminiChat', () => {
} as unknown as GenerateContentResponse;
// 2. Mock the API to return our controllable promises in order
vi.mocked(mockModelsModule.generateContent)
vi.mocked(mockContentGenerator.generateContent)
.mockReturnValueOnce(firstCallPromise)
.mockReturnValueOnce(secondCallPromise);
@@ -1285,8 +1292,8 @@ describe('GeminiChat', () => {
// 5. CRUCIAL CHECK: At this point, only the first API call should have been made.
// The second call should be waiting on `sendPromise`.
expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(1);
expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1);
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
contents: expect.arrayContaining([
expect.objectContaining({ parts: [{ text: 'first' }] }),
@@ -1303,8 +1310,8 @@ describe('GeminiChat', () => {
await new Promise(process.nextTick);
// 7. CRUCIAL CHECK: Now, the second API call should have been made.
expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(2);
expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2);
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
contents: expect.arrayContaining([
expect.objectContaining({ parts: [{ text: 'second' }] }),
@@ -1320,7 +1327,7 @@ describe('GeminiChat', () => {
});
it('should retry if the model returns a completely empty stream (no chunks)', async () => {
// 1. Mock the API to return an empty stream first, then a valid one.
vi.mocked(mockModelsModule.generateContentStream)
vi.mocked(mockContentGenerator.generateContentStream)
.mockImplementationOnce(
// First call resolves to an async generator that yields nothing.
async () => (async function* () {})(),
@@ -1353,7 +1360,7 @@ describe('GeminiChat', () => {
}
// 3. Assert the results.
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2);
expect(
chunks.some(
(c) =>
@@ -1417,7 +1424,7 @@ describe('GeminiChat', () => {
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream)
vi.mocked(mockContentGenerator.generateContentStream)
.mockResolvedValueOnce(firstStreamGenerator)
.mockResolvedValueOnce(secondStreamGenerator);
@@ -1436,7 +1443,7 @@ describe('GeminiChat', () => {
);
// 5. Assert that only one API call has been made so far.
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(1);
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(1);
// 6. Unblock and fully consume the first stream to completion.
continueFirstStream!();
@@ -1451,7 +1458,7 @@ describe('GeminiChat', () => {
await secondStreamIterator.next();
// 9. The second API call should now have been made.
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2);
// 10. FIX: Fully consume the second stream to ensure recordHistory is called.
await secondStreamIterator.next(); // This finishes the iterator.
@@ -1471,7 +1478,7 @@ describe('GeminiChat', () => {
it('should discard valid partial content from a failed attempt upon retry', async () => {
// ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content.
vi.mocked(mockModelsModule.generateContentStream)
vi.mocked(mockContentGenerator.generateContentStream)
.mockImplementationOnce(async () =>
// First attempt: yields one valid chunk, then one invalid chunk
(async function* () {
@@ -1517,7 +1524,7 @@ describe('GeminiChat', () => {
// ASSERT
// Check that a retry happened
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2);
expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true);
// Check the final recorded history
@@ -1532,4 +1539,41 @@ describe('GeminiChat', () => {
'This valid part should be discarded',
);
});
describe('stripThoughtsFromHistory', () => {
it('should strip thought signatures', () => {
chat.setHistory([
{
role: 'user',
parts: [{ text: 'hello' }],
},
{
role: 'model',
parts: [
{ text: 'thinking...', thoughtSignature: 'thought-123' },
{
functionCall: { name: 'test', args: {} },
thoughtSignature: 'thought-456',
},
],
},
]);
chat.stripThoughtsFromHistory();
expect(chat.getHistory()).toEqual([
{
role: 'user',
parts: [{ text: 'hello' }],
},
{
role: 'model',
parts: [
{ text: 'thinking...' },
{ functionCall: { name: 'test', args: {} } },
],
},
]);
});
});
});
+20 -4
View File
@@ -18,7 +18,6 @@ import type {
import { toParts } from '../code_assist/converter.js';
import { createUserContent } from '@google/genai';
import { retryWithBackoff } from '../utils/retry.js';
import type { ContentGenerator } from './contentGenerator.js';
import { AuthType } from './contentGenerator.js';
import type { Config } from '../config/config.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
@@ -172,7 +171,6 @@ export class GeminiChat {
constructor(
private readonly config: Config,
private readonly contentGenerator: ContentGenerator,
private readonly generationConfig: GenerateContentConfig = {},
private history: Content[] = [],
) {
@@ -287,7 +285,7 @@ export class GeminiChat {
);
}
return this.contentGenerator.generateContent(
return this.config.getContentGenerator().generateContent(
{
model: modelToUse,
contents: requestContents,
@@ -498,7 +496,7 @@ export class GeminiChat {
);
}
return this.contentGenerator.generateContentStream(
return this.config.getContentGenerator().generateContentStream(
{
model: modelToUse,
contents: requestContents,
@@ -570,10 +568,28 @@ export class GeminiChat {
addHistory(content: Content): void {
this.history.push(content);
}
setHistory(history: Content[]): void {
this.history = history;
}
stripThoughtsFromHistory(): void {
this.history = this.history.map((content) => {
const newContent = { ...content };
if (newContent.parts) {
newContent.parts = newContent.parts.map((part) => {
if (part && typeof part === 'object' && 'thoughtSignature' in part) {
const newPart = { ...part };
delete (newPart as { thoughtSignature?: string }).thoughtSignature;
return newPart;
}
return part;
});
}
return newContent;
});
}
setTools(tools: Tool[]): void {
this.generationConfig.tools = tools;
}
+5 -9
View File
@@ -55,8 +55,6 @@ async function createMockConfig(
};
const config = new Config(configParams);
await config.initialize();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
await config.refreshAuth('test-auth' as any);
// Mock ToolRegistry
const mockToolRegistry = {
@@ -164,15 +162,13 @@ describe('subagent.ts', () => {
// Helper to safely access generationConfig from mock calls
const getGenerationConfigFromMock = (
callIndex = 0,
): GenerateContentConfig & { systemInstruction?: string | Content } => {
): GenerateContentConfig => {
const callArgs = vi.mocked(GeminiChat).mock.calls[callIndex];
const generationConfig = callArgs?.[2];
const generationConfig = callArgs?.[1];
// Ensure it's defined before proceeding
expect(generationConfig).toBeDefined();
if (!generationConfig) throw new Error('generationConfig is undefined');
return generationConfig as GenerateContentConfig & {
systemInstruction?: string | Content;
};
return generationConfig as GenerateContentConfig;
};
describe('create (Tool Validation)', () => {
@@ -347,7 +343,7 @@ describe('subagent.ts', () => {
);
// Check History (should include environment context)
const history = callArgs[3];
const history = callArgs[2];
expect(history).toEqual([
{ role: 'user', parts: [{ text: 'Env Context' }] },
{
@@ -420,7 +416,7 @@ describe('subagent.ts', () => {
const callArgs = vi.mocked(GeminiChat).mock.calls[0];
const generationConfig = getGenerationConfigFromMock();
const history = callArgs[3];
const history = callArgs[2];
expect(generationConfig.systemInstruction).toBeUndefined();
expect(history).toEqual([
+1 -11
View File
@@ -10,7 +10,6 @@ import type { AnyDeclarativeTool } from '../tools/tools.js';
import type { Config } from '../config/config.js';
import type { ToolCallRequestInfo } from './turn.js';
import { executeToolCall } from './nonInteractiveToolExecutor.js';
import { createContentGenerator } from './contentGenerator.js';
import { getEnvironmentContext } from '../utils/environmentContext.js';
import type {
Content,
@@ -635,9 +634,7 @@ export class SubAgentScope {
: undefined;
try {
const generationConfig: GenerateContentConfig & {
systemInstruction?: string | Content;
} = {
const generationConfig: GenerateContentConfig = {
temperature: this.modelConfig.temp,
topP: this.modelConfig.top_p,
};
@@ -646,17 +643,10 @@ export class SubAgentScope {
generationConfig.systemInstruction = systemInstruction;
}
const contentGenerator = await createContentGenerator(
this.runtimeContext.getContentGeneratorConfig(),
this.runtimeContext,
this.runtimeContext.getSessionId(),
);
this.runtimeContext.setModel(this.modelConfig.model);
return new GeminiChat(
this.runtimeContext,
contentGenerator,
generationConfig,
start_history,
);
@@ -6,10 +6,10 @@
import type { Mock } from 'vitest';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import type { Content, GoogleGenAI, Models } from '@google/genai';
import type { Content } from '@google/genai';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { GeminiClient } from '../core/client.js';
import { Config } from '../config/config.js';
import type { Config } from '../config/config.js';
import type { NextSpeakerResponse } from './nextSpeakerChecker.js';
import { checkNextSpeaker } from './nextSpeakerChecker.js';
import { GeminiChat } from '../core/geminiChat.js';
@@ -44,73 +44,28 @@ vi.mock('node:fs', () => {
vi.mock('../core/client.js');
vi.mock('../config/config.js');
// Define mocks for GoogleGenAI and Models instances that will be used across tests
const mockModelsInstance = {
generateContent: vi.fn(),
generateContentStream: vi.fn(),
countTokens: vi.fn(),
embedContent: vi.fn(),
batchEmbedContents: vi.fn(),
} as unknown as Models;
const mockGoogleGenAIInstance = {
getGenerativeModel: vi.fn().mockReturnValue(mockModelsInstance),
// Add other methods of GoogleGenAI if they are directly used by GeminiChat constructor or its methods
} as unknown as GoogleGenAI;
vi.mock('@google/genai', async () => {
const actualGenAI =
await vi.importActual<typeof import('@google/genai')>('@google/genai');
return {
...actualGenAI,
GoogleGenAI: vi.fn(() => mockGoogleGenAIInstance), // Mock constructor to return the predefined instance
// If Models is instantiated directly in GeminiChat, mock its constructor too
// For now, assuming Models instance is obtained via getGenerativeModel
};
});
describe('checkNextSpeaker', () => {
let chatInstance: GeminiChat;
let mockConfig: Config;
let mockGeminiClient: GeminiClient;
let MockConfig: Mock;
const abortSignal = new AbortController().signal;
beforeEach(() => {
MockConfig = vi.mocked(Config);
const mockConfigInstance = new MockConfig(
'test-api-key',
'gemini-pro',
false,
'.',
false,
undefined,
false,
undefined,
undefined,
undefined,
);
vi.resetAllMocks();
mockConfig = {
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
getModel: () => 'test-model',
storage: {
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
},
} as unknown as Config;
// Mock the methods that ChatRecordingService needs
mockConfigInstance.getSessionId = vi
.fn()
.mockReturnValue('test-session-id');
mockConfigInstance.getProjectRoot = vi
.fn()
.mockReturnValue('/test/project/root');
mockConfigInstance.storage = {
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
};
mockGeminiClient = new GeminiClient(mockConfigInstance);
// Reset mocks before each test to ensure test isolation
vi.mocked(mockModelsInstance.generateContent).mockReset();
vi.mocked(mockModelsInstance.generateContentStream).mockReset();
mockGeminiClient = new GeminiClient(mockConfig);
// GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
chatInstance = new GeminiChat(
mockConfigInstance,
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
mockConfig,
{},
[], // initial history
);
@@ -120,7 +75,7 @@ describe('checkNextSpeaker', () => {
});
afterEach(() => {
vi.clearAllMocks();
vi.restoreAllMocks();
});
it('should return null if history is empty', async () => {
@@ -135,9 +90,9 @@ describe('checkNextSpeaker', () => {
});
it('should return null if the last speaker was the user', async () => {
(chatInstance.getHistory as Mock).mockReturnValue([
vi.mocked(chatInstance.getHistory).mockReturnValue([
{ role: 'user', parts: [{ text: 'Hello' }] },
] as Content[]);
]);
const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,